transform/VertexPulling: Use SymbolTable::New()
And clean up some code in the process.
Avoids potential symbol collisions. Simplifies the logic.
Bug: tint:712
Change-Id: Ibce5ccbd4c7fd45d5bf29906b5a83b3637b6cdcc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47633
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 951f4fd..df00a17 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -22,6 +22,7 @@
#include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h"
#include "src/semantic/variable.h"
+#include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
@@ -29,12 +30,360 @@
namespace transform {
namespace {
-static const char kVertexBufferNamePrefix[] = "_tint_pulling_vertex_buffer_";
-static const char kStructBufferName[] = "_tint_vertex_data";
-static const char kStructName[] = "TintVertexData";
-static const char kPullingPosVarName[] = "_tint_pulling_pos";
-static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index";
-static const char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index";
+struct State {
+ State(CloneContext& context, const VertexPulling::Config& c)
+ : ctx(context), cfg(c) {}
+ State(const State&) = default;
+ ~State() = default;
+
+ /// LocationReplacement describes an ast::Variable replacement for a
+ /// location input.
+ struct LocationReplacement {
+ /// The variable to replace in the source Program
+ ast::Variable* from;
+ /// The replacement to use in the target ProgramBuilder
+ ast::Variable* to;
+ };
+
+ CloneContext& ctx;
+ VertexPulling::Config const cfg;
+ std::unordered_map<uint32_t, ast::Variable*> location_to_var;
+ std::vector<LocationReplacement> location_replacements;
+ Symbol vertex_index_name;
+ Symbol instance_index_name;
+ Symbol pulling_position_name;
+ Symbol struct_buffer_name;
+ std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
+
+ /// Generate the vertex buffer binding name
+ /// @param index index to append to buffer name
+ Symbol GetVertexBufferName(uint32_t index) {
+ return utils::GetOrCreate(vertex_buffer_names, index, [&] {
+ static const char kVertexBufferNamePrefix[] =
+ "_tint_pulling_vertex_buffer_";
+ return ctx.dst->Symbols().New(kVertexBufferNamePrefix +
+ std::to_string(index));
+ });
+ }
+
+ /// Lazily generates the pulling position symbol
+ Symbol GetPullingPositionName() {
+ if (!pulling_position_name.IsValid()) {
+ static const char kPullingPosVarName[] = "_tint_pulling_pos";
+ pulling_position_name = ctx.dst->Symbols().New(kPullingPosVarName);
+ }
+ return pulling_position_name;
+ }
+
+ /// Lazily generates the structure buffer symbol
+ Symbol GetStructBufferName() {
+ if (!struct_buffer_name.IsValid()) {
+ static const char kStructBufferName[] = "_tint_vertex_data";
+ struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
+ }
+ return struct_buffer_name;
+ }
+
+ /// Inserts vertex_index binding, or finds the existing one
+ void FindOrInsertVertexIndexIfUsed() {
+ bool uses_vertex_step_mode = false;
+ for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
+ if (buffer_layout.step_mode == InputStepMode::kVertex) {
+ uses_vertex_step_mode = true;
+ break;
+ }
+ }
+ if (!uses_vertex_step_mode) {
+ return;
+ }
+
+ // Look for an existing vertex index builtin
+ for (auto* v : ctx.src->AST().GlobalVariables()) {
+ auto* sem = ctx.src->Sem().Get(v);
+ if (sem->StorageClass() != ast::StorageClass::kInput) {
+ continue;
+ }
+
+ for (auto* d : v->decorations()) {
+ if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
+ if (builtin->value() == ast::Builtin::kVertexIndex) {
+ vertex_index_name = ctx.Clone(v->symbol());
+ return;
+ }
+ }
+ }
+ }
+
+ // We didn't find a vertex index builtin, so create one
+ static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index";
+ vertex_index_name = ctx.dst->Symbols().New(kDefaultVertexIndexName);
+
+ ctx.dst->Global(
+ vertex_index_name, ctx.dst->ty.u32(), ast::StorageClass::kInput,
+ nullptr,
+ ast::DecorationList{
+ ctx.dst->create<ast::BuiltinDecoration>(ast::Builtin::kVertexIndex),
+ });
+ }
+
+ /// Inserts instance_index binding, or finds the existing one
+ void FindOrInsertInstanceIndexIfUsed() {
+ bool uses_instance_step_mode = false;
+ for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
+ if (buffer_layout.step_mode == InputStepMode::kInstance) {
+ uses_instance_step_mode = true;
+ break;
+ }
+ }
+ if (!uses_instance_step_mode) {
+ return;
+ }
+
+ // Look for an existing instance index builtin
+ for (auto* v : ctx.src->AST().GlobalVariables()) {
+ auto* sem = ctx.src->Sem().Get(v);
+ if (sem->StorageClass() != ast::StorageClass::kInput) {
+ continue;
+ }
+
+ for (auto* d : v->decorations()) {
+ if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
+ if (builtin->value() == ast::Builtin::kInstanceIndex) {
+ instance_index_name = ctx.Clone(v->symbol());
+ return;
+ }
+ }
+ }
+ }
+
+ // We didn't find an instance index builtin, so create one
+ static const char kDefaultInstanceIndexName[] =
+ "_tint_pulling_instance_index";
+ instance_index_name = ctx.dst->Symbols().New(kDefaultInstanceIndexName);
+
+ ctx.dst->Global(instance_index_name, ctx.dst->ty.u32(),
+ ast::StorageClass::kInput, nullptr,
+ ast::DecorationList{
+ ctx.dst->create<ast::BuiltinDecoration>(
+ ast::Builtin::kInstanceIndex),
+ });
+ }
+
+ /// Converts var<in> with a location decoration to var<private>
+ void ConvertVertexInputVariablesToPrivate() {
+ for (auto* v : ctx.src->AST().GlobalVariables()) {
+ auto* sem = ctx.src->Sem().Get(v);
+ if (sem->StorageClass() != ast::StorageClass::kInput) {
+ continue;
+ }
+
+ for (auto* d : v->decorations()) {
+ if (auto* l = d->As<ast::LocationDecoration>()) {
+ uint32_t location = l->value();
+ // This is where the replacement is created. Expressions use
+ // identifier strings instead of pointers, so we don't need to update
+ // any other place in the AST.
+ auto* replacement = ctx.dst->Var(ctx.Clone(v->symbol()),
+ ctx.Clone(v->declared_type()),
+ ast::StorageClass::kPrivate);
+ location_to_var[location] = replacement;
+ location_replacements.emplace_back(
+ LocationReplacement{v, replacement});
+ break;
+ }
+ }
+ }
+ }
+
+ /// Adds storage buffer decorated variables for the vertex buffers
+ void AddVertexStorageBuffers() {
+ // TODO(idanr): Make this readonly
+ // https://github.com/gpuweb/gpuweb/issues/935
+
+ // Creating the struct type
+ static const char kStructName[] = "TintVertexData";
+ auto* struct_type = ctx.dst->Structure(
+ ctx.dst->Symbols().New(kStructName),
+ {
+ ctx.dst->Member(GetStructBufferName(),
+ ctx.dst->ty.array<ProgramBuilder::u32, 0>(4)),
+ },
+ {
+ ctx.dst->create<ast::StructBlockDecoration>(),
+ });
+
+ for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
+ // The decorated variable with struct type
+ ctx.dst->Global(
+ GetVertexBufferName(i), struct_type, ast::StorageClass::kStorage,
+ nullptr,
+ ast::DecorationList{
+ ctx.dst->create<ast::BindingDecoration>(i),
+ ctx.dst->create<ast::GroupDecoration>(cfg.pulling_group),
+ });
+ }
+ }
+
+ /// Creates and returns the assignment to the variables from the buffers
+ ast::BlockStatement* CreateVertexPullingPreamble() {
+ // Assign by looking at the vertex descriptor to find attributes with
+ // matching location.
+
+ ast::StatementList stmts;
+
+ // Declare the pulling position variable in the shader
+ stmts.emplace_back(ctx.dst->create<ast::VariableDeclStatement>(
+ ctx.dst->Var(GetPullingPositionName(), ctx.dst->ty.u32(),
+ ast::StorageClass::kFunction)));
+
+ for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
+ const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i];
+
+ for (const VertexAttributeDescriptor& attribute_desc :
+ buffer_layout.attributes) {
+ auto it = location_to_var.find(attribute_desc.shader_location);
+ if (it == location_to_var.end()) {
+ continue;
+ }
+ auto* v = it->second;
+
+ auto name = buffer_layout.step_mode == InputStepMode::kVertex
+ ? vertex_index_name
+ : instance_index_name;
+
+ // An expression for the start of the read in the buffer in bytes
+ auto* pos_value = ctx.dst->Add(
+ ctx.dst->Mul(name,
+ static_cast<uint32_t>(buffer_layout.array_stride)),
+ static_cast<uint32_t>(attribute_desc.offset));
+
+ // Update position of the read
+ auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
+ ctx.dst->Expr(GetPullingPositionName()), pos_value);
+ stmts.emplace_back(set_pos_expr);
+
+ stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
+ ctx.dst->create<ast::IdentifierExpression>(v->symbol()),
+ AccessByFormat(i, attribute_desc.format)));
+ }
+ }
+
+ return ctx.dst->create<ast::BlockStatement>(stmts);
+ }
+
+ /// Generates an expression reading from a buffer a specific format.
+ /// This reads the value wherever `kPullingPosVarName` points to at the time
+ /// of the read.
+ /// @param buffer the index of the vertex buffer
+ /// @param format the format to read
+ ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) {
+ // TODO(idanr): this doesn't account for the format of the attribute in the
+ // shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
+ // right now, we would try to assign a vec4<f32> to this attribute, but we
+ // really need to assign a vec4<u32> by casting.
+ // We could split this function to first do memory accesses and unpacking
+ // into int/uint/float1-4/etc, then convert that variable to a var<in> with
+ // the conversion defined in the WebGPU spec.
+ switch (format) {
+ case VertexFormat::kU32:
+ return AccessU32(buffer, ctx.dst->Expr(GetPullingPositionName()));
+ case VertexFormat::kI32:
+ return AccessI32(buffer, ctx.dst->Expr(GetPullingPositionName()));
+ case VertexFormat::kF32:
+ return AccessF32(buffer, ctx.dst->Expr(GetPullingPositionName()));
+ case VertexFormat::kVec2F32:
+ return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 2);
+ case VertexFormat::kVec3F32:
+ return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 3);
+ case VertexFormat::kVec4F32:
+ return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 4);
+ default:
+ return nullptr;
+ }
+ }
+
+ /// Generates an expression reading a uint32 from a vertex buffer
+ /// @param buffer the index of the vertex buffer
+ /// @param pos an expression for the position of the access, in bytes
+ ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) {
+ // Here we divide by 4, since the buffer is uint32 not uint8. The input
+ // buffer has byte offsets for each attribute, and we will convert it to u32
+ // indexes by dividing. Then, that element is going to be read, and if
+ // needed, unpacked into an appropriate variable. All reads should end up
+ // here as a base case.
+ return ctx.dst->create<ast::ArrayAccessorExpression>(
+ ctx.dst->MemberAccessor(GetVertexBufferName(buffer),
+ GetStructBufferName()),
+ ctx.dst->Div(pos, 4u));
+ }
+
+ /// Generates an expression reading an int32 from a vertex buffer
+ /// @param buffer the index of the vertex buffer
+ /// @param pos an expression for the position of the access, in bytes
+ ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) {
+ // as<T> reinterprets bits
+ return ctx.dst->create<ast::BitcastExpression>(ctx.dst->ty.i32(),
+ AccessU32(buffer, pos));
+ }
+
+ /// Generates an expression reading a float from a vertex buffer
+ /// @param buffer the index of the vertex buffer
+ /// @param pos an expression for the position of the access, in bytes
+ ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) {
+ // as<T> reinterprets bits
+ return ctx.dst->create<ast::BitcastExpression>(ctx.dst->ty.f32(),
+ AccessU32(buffer, pos));
+ }
+
+ /// Generates an expression reading a basic type (u32, i32, f32) from a
+ /// vertex buffer
+ /// @param buffer the index of the vertex buffer
+ /// @param pos an expression for the position of the access, in bytes
+ /// @param format the underlying vertex format
+ ast::Expression* AccessPrimitive(uint32_t buffer,
+ ast::Expression* pos,
+ VertexFormat format) {
+ // This function uses a position expression to read, rather than using the
+ // position variable. This allows us to read from offset positions relative
+ // to |kPullingPosVarName|. We can't call AccessByFormat because it reads
+ // only from the position variable.
+ switch (format) {
+ case VertexFormat::kU32:
+ return AccessU32(buffer, pos);
+ case VertexFormat::kI32:
+ return AccessI32(buffer, pos);
+ case VertexFormat::kF32:
+ return AccessF32(buffer, pos);
+ default:
+ return nullptr;
+ }
+ }
+
+ /// Generates an expression reading a vec2/3/4 from a vertex buffer.
+ /// This reads the value wherever `kPullingPosVarName` points to at the time
+ /// of the read.
+ /// @param buffer the index of the vertex buffer
+ /// @param element_stride stride between elements, in bytes
+ /// @param base_type underlying AST type
+ /// @param base_format underlying vertex format
+ /// @param count how many elements the vector has
+ ast::Expression* AccessVec(uint32_t buffer,
+ uint32_t element_stride,
+ type::Type* base_type,
+ VertexFormat base_format,
+ uint32_t count) {
+ ast::ExpressionList expr_list;
+ for (uint32_t i = 0; i < count; ++i) {
+ // Offset read position by element_stride for each component
+ auto* cur_pos =
+ ctx.dst->Add(GetPullingPositionName(), element_stride * i);
+ expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
+ }
+
+ return ctx.dst->create<ast::TypeConstructorExpression>(
+ ctx.dst->create<type::Vector>(base_type, count), std::move(expr_list));
+ }
+};
} // namespace
@@ -93,367 +442,6 @@
VertexPulling::Config& VertexPulling::Config::operator=(const Config&) =
default;
-VertexPulling::State::State(CloneContext& context, const Config& c)
- : ctx(context), cfg(c) {}
-VertexPulling::State::State(const State&) = default;
-VertexPulling::State::~State() = default;
-
-std::string VertexPulling::State::GetVertexBufferName(uint32_t index) const {
- return kVertexBufferNamePrefix + std::to_string(index);
-}
-
-void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
- bool uses_vertex_step_mode = false;
- for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
- if (buffer_layout.step_mode == InputStepMode::kVertex) {
- uses_vertex_step_mode = true;
- break;
- }
- }
- if (!uses_vertex_step_mode) {
- return;
- }
-
- // Look for an existing vertex index builtin
- for (auto* v : ctx.src->AST().GlobalVariables()) {
- auto* sem = ctx.src->Sem().Get(v);
- if (sem->StorageClass() != ast::StorageClass::kInput) {
- continue;
- }
-
- for (auto* d : v->decorations()) {
- if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
- if (builtin->value() == ast::Builtin::kVertexIndex) {
- vertex_index_name = ctx.src->Symbols().NameFor(v->symbol());
- return;
- }
- }
- }
- }
-
- // We didn't find a vertex index builtin, so create one
- vertex_index_name = kDefaultVertexIndexName;
-
- auto* var = ctx.dst->create<ast::Variable>(
- Source{}, // source
- ctx.dst->Symbols().Register(vertex_index_name), // symbol
- ast::StorageClass::kInput, // storage_class
- GetU32Type(), // type
- false, // is_const
- nullptr, // constructor
- ast::DecorationList{
- ctx.dst->create<ast::BuiltinDecoration>(Source{},
- ast::Builtin::kVertexIndex),
- });
-
- ctx.dst->AST().AddGlobalVariable(var);
-}
-
-void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
- bool uses_instance_step_mode = false;
- for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
- if (buffer_layout.step_mode == InputStepMode::kInstance) {
- uses_instance_step_mode = true;
- break;
- }
- }
- if (!uses_instance_step_mode) {
- return;
- }
-
- // Look for an existing instance index builtin
- for (auto* v : ctx.src->AST().GlobalVariables()) {
- auto* sem = ctx.src->Sem().Get(v);
- if (sem->StorageClass() != ast::StorageClass::kInput) {
- continue;
- }
-
- for (auto* d : v->decorations()) {
- if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
- if (builtin->value() == ast::Builtin::kInstanceIndex) {
- instance_index_name = ctx.src->Symbols().NameFor(v->symbol());
- return;
- }
- }
- }
- }
-
- // We didn't find an instance index builtin, so create one
- instance_index_name = kDefaultInstanceIndexName;
-
- auto* var = ctx.dst->create<ast::Variable>(
- Source{}, // source
- ctx.dst->Symbols().Register(instance_index_name), // symbol
- ast::StorageClass::kInput, // storage_class
- GetU32Type(), // type
- false, // is_const
- nullptr, // constructor
- ast::DecorationList{
- ctx.dst->create<ast::BuiltinDecoration>(Source{},
- ast::Builtin::kInstanceIndex),
- });
- ctx.dst->AST().AddGlobalVariable(var);
-}
-
-void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
- for (auto* v : ctx.src->AST().GlobalVariables()) {
- auto* sem = ctx.src->Sem().Get(v);
- if (sem->StorageClass() != ast::StorageClass::kInput) {
- continue;
- }
-
- for (auto* d : v->decorations()) {
- if (auto* l = d->As<ast::LocationDecoration>()) {
- uint32_t location = l->value();
- // This is where the replacement is created. Expressions use identifier
- // strings instead of pointers, so we don't need to update any other
- // place in the AST.
- auto name = ctx.src->Symbols().NameFor(v->symbol());
- auto* replacement = ctx.dst->create<ast::Variable>(
- Source{}, // source
- ctx.dst->Symbols().Register(name), // symbol
- ast::StorageClass::kPrivate, // storage_class
- ctx.Clone(v->declared_type()), // type
- false, // is_const
- nullptr, // constructor
- ast::DecorationList{}); // decorations
- location_to_var[location] = replacement;
- location_replacements.emplace_back(LocationReplacement{v, replacement});
- break;
- }
- }
- }
-}
-
-void VertexPulling::State::AddVertexStorageBuffers() {
- // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
- // The array inside the struct definition
- auto* internal_array_type = ctx.dst->create<type::Array>(
- GetU32Type(), 0,
- ast::DecorationList{
- ctx.dst->create<ast::StrideDecoration>(Source{}, 4u),
- });
-
- // Creating the struct type
- ast::StructMemberList members;
- members.push_back(ctx.dst->create<ast::StructMember>(
- Source{}, ctx.dst->Symbols().Register(kStructBufferName),
- internal_array_type, ast::DecorationList{}));
-
- ast::DecorationList decos;
- decos.push_back(ctx.dst->create<ast::StructBlockDecoration>(Source{}));
-
- auto* struct_type = ctx.dst->create<type::Struct>(
- ctx.dst->Symbols().Register(kStructName),
- ctx.dst->create<ast::Struct>(Source{}, std::move(members),
- std::move(decos)));
-
- for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
- // The decorated variable with struct type
- std::string name = GetVertexBufferName(i);
- auto* var = ctx.dst->create<ast::Variable>(
- Source{}, // source
- ctx.dst->Symbols().Register(name), // symbol
- ast::StorageClass::kStorage, // storage_class
- struct_type, // type
- false, // is_const
- nullptr, // constructor
- ast::DecorationList{
- ctx.dst->create<ast::BindingDecoration>(Source{}, i),
- ctx.dst->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
- });
- ctx.dst->AST().AddGlobalVariable(var);
- }
- ctx.dst->AST().AddConstructedType(struct_type);
-}
-
-ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
- // Assign by looking at the vertex descriptor to find attributes with matching
- // location.
-
- ast::StatementList stmts;
-
- // Declare the |kPullingPosVarName| variable in the shader
- auto* pos_declaration = ctx.dst->create<ast::VariableDeclStatement>(
- Source{}, ctx.dst->create<ast::Variable>(
- Source{}, // source
- ctx.dst->Symbols().Register(kPullingPosVarName), // symbol
- ast::StorageClass::kFunction, // storage_class
- GetU32Type(), // type
- false, // is_const
- nullptr, // constructor
- ast::DecorationList{})); // decorations
-
- // |kPullingPosVarName| refers to the byte location of the current read. We
- // declare a variable in the shader to avoid having to reuse Expression
- // objects.
- stmts.emplace_back(pos_declaration);
-
- for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
- const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i];
-
- for (const VertexAttributeDescriptor& attribute_desc :
- buffer_layout.attributes) {
- auto it = location_to_var.find(attribute_desc.shader_location);
- if (it == location_to_var.end()) {
- continue;
- }
- auto* v = it->second;
-
- auto name = buffer_layout.step_mode == InputStepMode::kVertex
- ? vertex_index_name
- : instance_index_name;
- // Identifier to index by
- auto* index_identifier = ctx.dst->create<ast::IdentifierExpression>(
- Source{}, ctx.dst->Symbols().Register(name));
-
- // An expression for the start of the read in the buffer in bytes
- auto* pos_value = ctx.dst->create<ast::BinaryExpression>(
- Source{}, ast::BinaryOp::kAdd,
- ctx.dst->create<ast::BinaryExpression>(
- Source{}, ast::BinaryOp::kMultiply, index_identifier,
- GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
- GenUint(static_cast<uint32_t>(attribute_desc.offset)));
-
- // Update position of the read
- auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
- Source{}, CreatePullingPositionIdent(), pos_value);
- stmts.emplace_back(set_pos_expr);
-
- stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
- Source{},
- ctx.dst->create<ast::IdentifierExpression>(Source{}, v->symbol()),
- AccessByFormat(i, attribute_desc.format)));
- }
- }
-
- return ctx.dst->create<ast::BlockStatement>(Source{}, stmts);
-}
-
-ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
- return ctx.dst->create<ast::ScalarConstructorExpression>(
- Source{},
- ctx.dst->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
-}
-
-ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
- return ctx.dst->create<ast::IdentifierExpression>(
- Source{}, ctx.dst->Symbols().Register(kPullingPosVarName));
-}
-
-ast::Expression* VertexPulling::State::AccessByFormat(
- uint32_t buffer,
- VertexFormat format) const {
- // TODO(idanr): this doesn't account for the format of the attribute in the
- // shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
- // right now, we would try to assign a vec4<f32> to this attribute, but we
- // really need to assign a vec4<u32> by casting.
- // We could split this function to first do memory accesses and unpacking into
- // int/uint/float1-4/etc, then convert that variable to a var<in> with the
- // conversion defined in the WebGPU spec.
- switch (format) {
- case VertexFormat::kU32:
- return AccessU32(buffer, CreatePullingPositionIdent());
- case VertexFormat::kI32:
- return AccessI32(buffer, CreatePullingPositionIdent());
- case VertexFormat::kF32:
- return AccessF32(buffer, CreatePullingPositionIdent());
- case VertexFormat::kVec2F32:
- return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 2);
- case VertexFormat::kVec3F32:
- return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 3);
- case VertexFormat::kVec4F32:
- return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 4);
- default:
- return nullptr;
- }
-}
-
-ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
- ast::Expression* pos) const {
- // Here we divide by 4, since the buffer is uint32 not uint8. The input buffer
- // has byte offsets for each attribute, and we will convert it to u32 indexes
- // by dividing. Then, that element is going to be read, and if needed,
- // unpacked into an appropriate variable. All reads should end up here as a
- // base case.
- auto vbuf_name = GetVertexBufferName(buffer);
- return ctx.dst->create<ast::ArrayAccessorExpression>(
- Source{},
- ctx.dst->create<ast::MemberAccessorExpression>(
- Source{},
- ctx.dst->create<ast::IdentifierExpression>(
- Source{}, ctx.dst->Symbols().Register(vbuf_name)),
- ctx.dst->create<ast::IdentifierExpression>(
- Source{}, ctx.dst->Symbols().Register(kStructBufferName))),
- ctx.dst->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide,
- pos, GenUint(4)));
-}
-
-ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
- ast::Expression* pos) const {
- // as<T> reinterprets bits
- return ctx.dst->create<ast::BitcastExpression>(Source{}, GetI32Type(),
- AccessU32(buffer, pos));
-}
-
-ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
- ast::Expression* pos) const {
- // as<T> reinterprets bits
- return ctx.dst->create<ast::BitcastExpression>(Source{}, GetF32Type(),
- AccessU32(buffer, pos));
-}
-
-ast::Expression* VertexPulling::State::AccessPrimitive(
- uint32_t buffer,
- ast::Expression* pos,
- VertexFormat format) const {
- // This function uses a position expression to read, rather than using the
- // position variable. This allows us to read from offset positions relative to
- // |kPullingPosVarName|. We can't call AccessByFormat because it reads only
- // from the position variable.
- switch (format) {
- case VertexFormat::kU32:
- return AccessU32(buffer, pos);
- case VertexFormat::kI32:
- return AccessI32(buffer, pos);
- case VertexFormat::kF32:
- return AccessF32(buffer, pos);
- default:
- return nullptr;
- }
-}
-
-ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer,
- uint32_t element_stride,
- type::Type* base_type,
- VertexFormat base_format,
- uint32_t count) const {
- ast::ExpressionList expr_list;
- for (uint32_t i = 0; i < count; ++i) {
- // Offset read position by element_stride for each component
- auto* cur_pos = ctx.dst->create<ast::BinaryExpression>(
- Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
- GenUint(element_stride * i));
- expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
- }
-
- return ctx.dst->create<ast::TypeConstructorExpression>(
- Source{}, ctx.dst->create<type::Vector>(base_type, count),
- std::move(expr_list));
-}
-
-type::Type* VertexPulling::State::GetU32Type() const {
- return ctx.dst->create<type::U32>();
-}
-
-type::Type* VertexPulling::State::GetI32Type() const {
- return ctx.dst->create<type::I32>();
-}
-
-type::Type* VertexPulling::State::GetF32Type() const {
- return ctx.dst->create<type::F32>();
-}
-
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h
index 8be1812..576bc09 100644
--- a/src/transform/vertex_pulling.h
+++ b/src/transform/vertex_pulling.h
@@ -176,105 +176,6 @@
private:
Config cfg_;
-
- struct State {
- State(CloneContext& ctx, const Config& c);
- explicit State(const State&);
- ~State();
-
- /// Generate the vertex buffer binding name
- /// @param index index to append to buffer name
- std::string GetVertexBufferName(uint32_t index) const;
-
- /// Inserts vertex_index binding, or finds the existing one
- void FindOrInsertVertexIndexIfUsed();
-
- /// Inserts instance_index binding, or finds the existing one
- void FindOrInsertInstanceIndexIfUsed();
-
- /// Converts var<in> with a location decoration to var<private>
- void ConvertVertexInputVariablesToPrivate();
-
- /// Adds storage buffer decorated variables for the vertex buffers
- void AddVertexStorageBuffers();
-
- /// Creates and returns the assignment to the variables from the buffers
- ast::BlockStatement* CreateVertexPullingPreamble() const;
-
- /// Generates an expression holding a constant uint
- /// @param value uint value
- ast::Expression* GenUint(uint32_t value) const;
-
- /// Generates an expression to read the shader value `kPullingPosVarName`
- ast::Expression* CreatePullingPositionIdent() const;
-
- /// Generates an expression reading from a buffer a specific format.
- /// This reads the value wherever `kPullingPosVarName` points to at the time
- /// of the read.
- /// @param buffer the index of the vertex buffer
- /// @param format the format to read
- ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) const;
-
- /// Generates an expression reading a uint32 from a vertex buffer
- /// @param buffer the index of the vertex buffer
- /// @param pos an expression for the position of the access, in bytes
- ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) const;
-
- /// Generates an expression reading an int32 from a vertex buffer
- /// @param buffer the index of the vertex buffer
- /// @param pos an expression for the position of the access, in bytes
- ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) const;
-
- /// Generates an expression reading a float from a vertex buffer
- /// @param buffer the index of the vertex buffer
- /// @param pos an expression for the position of the access, in bytes
- ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) const;
-
- /// Generates an expression reading a basic type (u32, i32, f32) from a
- /// vertex buffer
- /// @param buffer the index of the vertex buffer
- /// @param pos an expression for the position of the access, in bytes
- /// @param format the underlying vertex format
- ast::Expression* AccessPrimitive(uint32_t buffer,
- ast::Expression* pos,
- VertexFormat format) const;
-
- /// Generates an expression reading a vec2/3/4 from a vertex buffer.
- /// This reads the value wherever `kPullingPosVarName` points to at the time
- /// of the read.
- /// @param buffer the index of the vertex buffer
- /// @param element_stride stride between elements, in bytes
- /// @param base_type underlying AST type
- /// @param base_format underlying vertex format
- /// @param count how many elements the vector has
- ast::Expression* AccessVec(uint32_t buffer,
- uint32_t element_stride,
- type::Type* base_type,
- VertexFormat base_format,
- uint32_t count) const;
-
- // Used to grab corresponding types from the type manager
- type::Type* GetU32Type() const;
- type::Type* GetI32Type() const;
- type::Type* GetF32Type() const;
-
- CloneContext& ctx;
- Config const cfg;
-
- /// LocationReplacement describes an ast::Variable replacement for a
- /// location input.
- struct LocationReplacement {
- /// The variable to replace in the source Program
- ast::Variable* from;
- /// The replacement to use in the target ProgramBuilder
- ast::Variable* to;
- };
-
- std::unordered_map<uint32_t, ast::Variable*> location_to_var;
- std::vector<LocationReplacement> location_replacements;
- std::string vertex_index_name;
- std::string instance_index_name;
- };
};
} // namespace transform
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index e1222e6..ce01f87 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -113,13 +113,13 @@
auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
[[block]]
struct TintVertexData {
_tint_vertex_data : [[stride(4)]] array<u32>;
};
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
var<private> var_a : f32;
[[stage(vertex)]]
@@ -155,13 +155,13 @@
auto* expect = R"(
[[builtin(instance_index)]] var<in> _tint_pulling_instance_index : u32;
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
[[block]]
struct TintVertexData {
_tint_vertex_data : [[stride(4)]] array<u32>;
};
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
var<private> var_a : f32;
[[stage(vertex)]]
@@ -197,13 +197,13 @@
auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
-[[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
[[block]]
struct TintVertexData {
_tint_vertex_data : [[stride(4)]] array<u32>;
};
+[[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
var<private> var_a : f32;
[[stage(vertex)]]
@@ -242,15 +242,15 @@
)";
auto* expect = R"(
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
-[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
-
[[block]]
struct TintVertexData {
_tint_vertex_data : [[stride(4)]] array<u32>;
};
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
+[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
+
var<private> var_a : f32;
var<private> var_b : f32;
@@ -305,13 +305,13 @@
auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
[[block]]
struct TintVertexData {
_tint_vertex_data : [[stride(4)]] array<u32>;
};
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
var<private> var_a : f32;
var<private> var_b : vec4<f32>;
@@ -355,17 +355,17 @@
auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
+[[block]]
+struct TintVertexData {
+ _tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
[[binding(2), group(4)]] var<storage> _tint_pulling_vertex_buffer_2 : TintVertexData;
-[[block]]
-struct TintVertexData {
- _tint_vertex_data : [[stride(4)]] array<u32>;
-};
-
var<private> var_a : vec2<f32>;
var<private> var_b : vec3<f32>;