transform/VertexPulling: Handle entry point parameters
Adds support for the new shader IO syntax by processing entry
parameters and pushing them to function-scope variables as necessary.
Module-scope variables are still supported for now.
Fixed: tint:731
Change-Id: I36d7ce4e3a990b6323292cb7c685af37187d6fda
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48960
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 98e4000..f84bb9b 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -47,12 +47,14 @@
CloneContext& ctx;
VertexPulling::Config const cfg;
- std::unordered_map<uint32_t, ast::Variable*> location_to_var;
- Symbol vertex_index_name;
- Symbol instance_index_name;
+ std::unordered_map<uint32_t, std::function<ast::Expression*()>>
+ location_to_expr;
+ std::function<ast::Expression*()> vertex_index_expr = nullptr;
+ std::function<ast::Expression*()> instance_index_expr = nullptr;
Symbol pulling_position_name;
Symbol struct_buffer_name;
std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
+ ast::VariableList new_function_parameters;
/// Generate the vertex buffer binding name
/// @param index index to append to buffer name
@@ -106,7 +108,9 @@
for (auto* d : v->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kVertexIndex) {
- vertex_index_name = ctx.Clone(v->symbol());
+ vertex_index_expr = [this, v]() {
+ return ctx.dst->Expr(ctx.Clone(v->symbol()));
+ };
return;
}
}
@@ -114,11 +118,10 @@
}
// We didn't find a vertex index builtin, so create one
- static const char kDefaultVertexIndexName[] = "tint_pulling_vertex_index";
- vertex_index_name = ctx.dst->Symbols().New(kDefaultVertexIndexName);
+ auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
+ vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
- ctx.dst->Global(vertex_index_name, ctx.dst->ty.u32(),
- ast::StorageClass::kInput, nullptr,
+ ctx.dst->Global(name, ctx.dst->ty.u32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{
ctx.dst->Builtin(ast::Builtin::kVertexIndex),
});
@@ -147,7 +150,9 @@
for (auto* d : v->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kInstanceIndex) {
- instance_index_name = ctx.Clone(v->symbol());
+ instance_index_expr = [this, v]() {
+ return ctx.dst->Expr(ctx.Clone(v->symbol()));
+ };
return;
}
}
@@ -155,12 +160,10 @@
}
// We didn't find an instance index builtin, so create one
- static const char kDefaultInstanceIndexName[] =
- "tint_pulling_instance_index";
- instance_index_name = ctx.dst->Symbols().New(kDefaultInstanceIndexName);
+ auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
+ instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
- ctx.dst->Global(instance_index_name, ctx.dst->ty.u32(),
- ast::StorageClass::kInput, nullptr,
+ ctx.dst->Global(name, ctx.dst->ty.u32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{
ctx.dst->Builtin(ast::Builtin::kInstanceIndex),
});
@@ -180,10 +183,12 @@
// This is where the replacement is created. Expressions use
// identifier strings instead of pointers, so we don't need to update
// any other place in the AST.
- auto* replacement = ctx.dst->Var(ctx.Clone(v->symbol()),
- ctx.Clone(v->declared_type()),
+ auto name = ctx.Clone(v->symbol());
+ auto* replacement = ctx.dst->Var(name, ctx.Clone(v->declared_type()),
ast::StorageClass::kPrivate);
- location_to_var[location] = replacement;
+ location_to_expr[location] = [this, name]() {
+ return ctx.dst->Expr(name);
+ };
ctx.Replace(v, replacement);
break;
}
@@ -237,30 +242,29 @@
for (const VertexAttributeDescriptor& attribute_desc :
buffer_layout.attributes) {
- auto it = location_to_var.find(attribute_desc.shader_location);
- if (it == location_to_var.end()) {
+ auto it = location_to_expr.find(attribute_desc.shader_location);
+ if (it == location_to_expr.end()) {
continue;
}
- auto* v = it->second;
+ auto* ident = it->second();
- auto name = buffer_layout.step_mode == InputStepMode::kVertex
- ? vertex_index_name
- : instance_index_name;
+ auto* index_expr = buffer_layout.step_mode == InputStepMode::kVertex
+ ? vertex_index_expr()
+ : instance_index_expr();
// An expression for the start of the read in the buffer in bytes
auto* pos_value = ctx.dst->Add(
- ctx.dst->Mul(name,
+ ctx.dst->Mul(index_expr,
static_cast<uint32_t>(buffer_layout.array_stride)),
static_cast<uint32_t>(attribute_desc.offset));
// Update position of the read
- auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
- ctx.dst->Expr(GetPullingPositionName()), pos_value);
+ auto* set_pos_expr =
+ ctx.dst->Assign(ctx.dst->Expr(GetPullingPositionName()), pos_value);
stmts.emplace_back(set_pos_expr);
- stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
- ctx.dst->create<ast::IdentifierExpression>(v->symbol()),
- AccessByFormat(i, attribute_desc.format)));
+ stmts.emplace_back(
+ ctx.dst->Assign(ident, AccessByFormat(i, attribute_desc.format)));
}
}
@@ -379,6 +383,186 @@
return ctx.dst->create<ast::TypeConstructorExpression>(
ctx.dst->create<sem::Vector>(base_type, count), std::move(expr_list));
}
+
+ /// Process a non-struct entry point parameter.
+ /// Generate function-scope variables for location parameters, and record
+ /// vertex_index and instance_index builtins if present.
+ /// @param func the entry point function
+ /// @param param the parameter to process
+ void ProcessNonStructParameter(ast::Function* func, ast::Variable* param) {
+ if (auto* location =
+ ast::GetDecoration<ast::LocationDecoration>(param->decorations())) {
+ // Create a function-scope variable to replace the parameter.
+ auto func_var_sym = ctx.Clone(param->symbol());
+ auto* func_var_type = ctx.Clone(param->declared_type());
+ auto* func_var = ctx.dst->Var(func_var_sym, func_var_type,
+ ast::StorageClass::kFunction);
+ ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+ ctx.dst->Decl(func_var));
+ // Capture mapping from location to the new variable.
+ location_to_expr[location->value()] = [this, func_var]() {
+ return ctx.dst->Expr(func_var);
+ };
+ } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
+ param->decorations())) {
+ // Check for existing vertex_index and instance_index builtins.
+ if (builtin->value() == ast::Builtin::kVertexIndex) {
+ vertex_index_expr = [this, param]() {
+ return ctx.dst->Expr(ctx.Clone(param->symbol()));
+ };
+ } else if (builtin->value() == ast::Builtin::kInstanceIndex) {
+ instance_index_expr = [this, param]() {
+ return ctx.dst->Expr(ctx.Clone(param->symbol()));
+ };
+ }
+ new_function_parameters.push_back(ctx.Clone(param));
+ } else {
+ TINT_ICE(ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+ }
+ }
+
+ /// Process a struct entry point parameter.
+ /// If the struct has members with location attributes, push the parameter to
+ /// a function-scope variable and create a new struct parameter without those
+ /// attributes. Record expressions for members that are vertex_index and
+ /// instance_index builtins.
+ /// @param func the entry point function
+ /// @param param the parameter to process
+ void ProcessStructParameter(ast::Function* func, ast::Variable* param) {
+ auto* struct_ty = param->declared_type()->As<sem::StructType>();
+ if (!struct_ty) {
+ TINT_ICE(ctx.dst->Diagnostics()) << "Invalid struct parameter";
+ }
+
+ auto param_sym = ctx.Clone(param->symbol());
+
+ // Process the struct members.
+ bool has_locations = false;
+ ast::StructMemberList members_to_clone;
+ for (auto* member : struct_ty->impl()->members()) {
+ auto member_sym = ctx.Clone(member->symbol());
+ std::function<ast::Expression*()> member_expr = [this, param_sym,
+ member_sym]() {
+ return ctx.dst->MemberAccessor(param_sym, member_sym);
+ };
+
+ if (auto* location = ast::GetDecoration<ast::LocationDecoration>(
+ member->decorations())) {
+ // Capture mapping from location to struct member.
+ location_to_expr[location->value()] = member_expr;
+ has_locations = true;
+ } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
+ member->decorations())) {
+ // Check for existing vertex_index and instance_index builtins.
+ if (builtin->value() == ast::Builtin::kVertexIndex) {
+ vertex_index_expr = member_expr;
+ } else if (builtin->value() == ast::Builtin::kInstanceIndex) {
+ instance_index_expr = member_expr;
+ }
+ members_to_clone.push_back(member);
+ } else {
+ TINT_ICE(ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+ }
+ }
+
+ if (!has_locations) {
+ // Nothing to do.
+ new_function_parameters.push_back(ctx.Clone(param));
+ return;
+ }
+
+ // Create a function-scope variable to replace the parameter.
+ auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->declared_type()),
+ ast::StorageClass::kFunction);
+ ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+ ctx.dst->Decl(func_var));
+
+ if (!members_to_clone.empty()) {
+ // Create a new struct without the location attributes.
+ ast::StructMemberList new_members;
+ for (auto* member : members_to_clone) {
+ auto member_sym = ctx.Clone(member->symbol());
+ auto member_type = ctx.Clone(member->type());
+ auto member_decos = ctx.Clone(member->decorations());
+ new_members.push_back(
+ ctx.dst->Member(member_sym, member_type, std::move(member_decos)));
+ }
+ auto new_struct =
+ ctx.dst->Structure(ctx.dst->Symbols().New(), new_members);
+
+ // Create a new function parameter with this struct.
+ auto* new_param = ctx.dst->Param(ctx.dst->Symbols().New(), new_struct);
+ new_function_parameters.push_back(new_param);
+
+ // Copy values from the new parameter to the function-scope variable.
+ for (auto* member : members_to_clone) {
+ auto member_name = ctx.Clone(member->symbol());
+ ctx.InsertBefore(
+ func->body()->statements(), *func->body()->begin(),
+ ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
+ ctx.dst->MemberAccessor(new_param, member_name)));
+ }
+ }
+ }
+
+ /// Process an entry point function.
+ /// @param func the entry point function
+ void Process(ast::Function* func) {
+ if (func->body()->empty()) {
+ return;
+ }
+
+ // Process entry point parameters.
+ for (auto* param : func->params()) {
+ auto* sem = ctx.src->Sem().Get(param);
+ if (sem->Type()->Is<sem::StructType>()) {
+ ProcessStructParameter(func, param);
+ } else {
+ ProcessNonStructParameter(func, param);
+ }
+ }
+
+ // Insert new parameters for vertex_index and instance_index if needed.
+ if (!vertex_index_expr) {
+ for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+ if (layout.step_mode == InputStepMode::kVertex) {
+ auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
+ new_function_parameters.push_back(
+ ctx.dst->Param(name, ctx.dst->ty.u32(),
+ {ctx.dst->Builtin(ast::Builtin::kVertexIndex)}));
+ vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+ break;
+ }
+ }
+ }
+ if (!instance_index_expr) {
+ for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+ if (layout.step_mode == InputStepMode::kInstance) {
+ auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
+ new_function_parameters.push_back(
+ ctx.dst->Param(name, ctx.dst->ty.u32(),
+ {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)}));
+ instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+ break;
+ }
+ }
+ }
+
+ // Generate vertex pulling preamble.
+ ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+ CreateVertexPullingPreamble());
+
+ // Rewrite the function header with the new parameters.
+ auto func_sym = ctx.Clone(func->symbol());
+ auto ret_type = ctx.Clone(func->return_type());
+ auto* body = ctx.Clone(func->body());
+ auto decos = ctx.Clone(func->decorations());
+ auto ret_decos = ctx.Clone(func->return_type_decorations());
+ auto* new_func = ctx.dst->create<ast::Function>(
+ func->source(), func_sym, new_function_parameters, ret_type, body,
+ std::move(decos), std::move(ret_decos));
+ ctx.Replace(func, new_func);
+ }
};
} // namespace
@@ -413,18 +597,26 @@
CloneContext ctx(&out, in);
State state{ctx, cfg};
- state.FindOrInsertVertexIndexIfUsed();
- state.FindOrInsertInstanceIndexIfUsed();
- state.ConvertVertexInputVariablesToPrivate();
- state.AddVertexStorageBuffers();
- ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* {
- if (f == func) {
- return CloneWithStatementsAtStart(&ctx, f,
- {state.CreateVertexPullingPreamble()});
- }
- return nullptr; // Just clone func
- });
+ if (func->params().empty()) {
+ // TODO(crbug.com/tint/697): Remove this path for the old shader IO syntax.
+ state.FindOrInsertVertexIndexIfUsed();
+ state.FindOrInsertInstanceIndexIfUsed();
+ state.ConvertVertexInputVariablesToPrivate();
+ state.AddVertexStorageBuffers();
+
+ ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* {
+ if (f == func) {
+ return CloneWithStatementsAtStart(
+ &ctx, f, {state.CreateVertexPullingPreamble()});
+ }
+ return nullptr; // Just clone func
+ });
+ } else {
+ state.AddVertexStorageBuffers();
+ state.Process(func);
+ }
+
ctx.Clone();
return Output(Program(std::move(out)));
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index 8effc23..1e6ff9e 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -109,17 +109,13 @@
TEST_F(VertexPullingTest, OneAttribute) {
auto* src = R"(
-[[location(0)]] var<in> var_a : f32;
-
[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
- return vec4<f32>();
+fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_pulling_vertex_index : u32;
-
[[block]]
struct TintVertexData {
tint_vertex_data : [[stride(4)]] array<u32>;
@@ -127,16 +123,15 @@
[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
-var<private> var_a : f32;
-
[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+ var var_a : f32;
{
var tint_pulling_pos : u32;
tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u);
var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
}
- return vec4<f32>();
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
@@ -154,17 +149,13 @@
TEST_F(VertexPullingTest, OneInstancedAttribute) {
auto* src = R"(
-[[location(0)]] var<in> var_a : f32;
-
[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
- return vec4<f32>();
+fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
auto* expect = R"(
-[[builtin(instance_index)]] var<in> tint_pulling_instance_index : u32;
-
[[block]]
struct TintVertexData {
tint_vertex_data : [[stride(4)]] array<u32>;
@@ -172,16 +163,15 @@
[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
-var<private> var_a : f32;
-
[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
+fn main([[builtin(instance_index)]] tint_pulling_instance_index : u32) -> [[builtin(position)]] vec4<f32> {
+ var var_a : f32;
{
var tint_pulling_pos : u32;
tint_pulling_pos = ((tint_pulling_instance_index * 4u) + 0u);
var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
}
- return vec4<f32>();
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
}
)";
@@ -199,6 +189,472 @@
TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(5)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+ var var_a : f32;
+ {
+ var tint_pulling_pos : u32;
+ tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u);
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ }
+ return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}};
+ cfg.pulling_group = 5;
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, OneAttribute_Struct) {
+ auto* src = R"(
+struct Inputs {
+ [[location(0)]] var_a : f32;
+};
+
+[[stage(vertex)]]
+fn main(inputs : Inputs) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(inputs.var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+struct Inputs {
+ [[location(0)]]
+ var_a : f32;
+};
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+ var inputs : Inputs;
+ {
+ var tint_pulling_pos : u32;
+ tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u);
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ }
+ return vec4<f32>(inputs.var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}};
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// We expect the transform to use an existing builtin variables if it finds them
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
+ auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32,
+ [[location(1)]] var_b : f32,
+ [[builtin(vertex_index)]] custom_vertex_index : u32,
+ [[builtin(instance_index)]] custom_instance_index : u32
+ ) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(var_a, var_b, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] custom_vertex_index : u32, [[builtin(instance_index)]] custom_instance_index : u32) -> [[builtin(position)]] vec4<f32> {
+ var var_a : f32;
+ var var_b : f32;
+ {
+ var tint_pulling_pos : u32;
+ tint_pulling_pos = ((custom_vertex_index * 4u) + 0u);
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ tint_pulling_pos = ((custom_instance_index * 4u) + 0u);
+ var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ }
+ return vec4<f32>(var_a, var_b, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ InputStepMode::kVertex,
+ {{VertexFormat::kF32, 0, 0}},
+ },
+ {
+ 4,
+ InputStepMode::kInstance,
+ {{VertexFormat::kF32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct) {
+ auto* src = R"(
+struct Inputs {
+ [[location(0)]] var_a : f32;
+ [[location(1)]] var_b : f32;
+ [[builtin(vertex_index)]] custom_vertex_index : u32;
+ [[builtin(instance_index)]] custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(inputs : Inputs) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+struct tint_symbol {
+ [[builtin(vertex_index)]]
+ custom_vertex_index : u32;
+ [[builtin(instance_index)]]
+ custom_instance_index : u32;
+};
+
+struct Inputs {
+ [[location(0)]]
+ var_a : f32;
+ [[location(1)]]
+ var_b : f32;
+ [[builtin(vertex_index)]]
+ custom_vertex_index : u32;
+ [[builtin(instance_index)]]
+ custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(tint_symbol_1 : tint_symbol) -> [[builtin(position)]] vec4<f32> {
+ var inputs : Inputs;
+ inputs.custom_vertex_index = tint_symbol_1.custom_vertex_index;
+ inputs.custom_instance_index = tint_symbol_1.custom_instance_index;
+ {
+ var tint_pulling_pos : u32;
+ tint_pulling_pos = ((inputs.custom_vertex_index * 4u) + 0u);
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ tint_pulling_pos = ((inputs.custom_instance_index * 4u) + 0u);
+ inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ }
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ InputStepMode::kVertex,
+ {{VertexFormat::kF32, 0, 0}},
+ },
+ {
+ 4,
+ InputStepMode::kInstance,
+ {{VertexFormat::kF32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct) {
+ auto* src = R"(
+struct Inputs {
+ [[location(0)]] var_a : f32;
+ [[location(1)]] var_b : f32;
+};
+
+struct Indices {
+ [[builtin(vertex_index)]] custom_vertex_index : u32;
+ [[builtin(instance_index)]] custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(inputs : Inputs, indices : Indices) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+struct Inputs {
+ [[location(0)]]
+ var_a : f32;
+ [[location(1)]]
+ var_b : f32;
+};
+
+struct Indices {
+ [[builtin(vertex_index)]]
+ custom_vertex_index : u32;
+ [[builtin(instance_index)]]
+ custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(indices : Indices) -> [[builtin(position)]] vec4<f32> {
+ var inputs : Inputs;
+ {
+ var tint_pulling_pos : u32;
+ tint_pulling_pos = ((indices.custom_vertex_index * 4u) + 0u);
+ inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ tint_pulling_pos = ((indices.custom_instance_index * 4u) + 0u);
+ inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ }
+ return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {
+ 4,
+ InputStepMode::kVertex,
+ {{VertexFormat::kF32, 0, 0}},
+ },
+ {
+ 4,
+ InputStepMode::kInstance,
+ {{VertexFormat::kF32, 0, 1}},
+ },
+ }};
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
+ auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32,
+ [[location(1)]] var_b : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+ var var_a : f32;
+ var var_b : vec4<f32>;
+ {
+ var tint_pulling_pos : u32;
+ tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+ tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
+ var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)]));
+ }
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{16,
+ InputStepMode::kVertex,
+ {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}};
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FloatVectorAttributes) {
+ auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : vec2<f32>,
+ [[location(1)]] var_b : vec3<f32>,
+ [[location(2)]] var_c : vec4<f32>
+ ) -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+[[binding(2), group(4)]] var<storage> tint_pulling_vertex_buffer_2 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+ var var_a : vec2<f32>;
+ var var_b : vec3<f32>;
+ var var_c : vec4<f32>;
+ {
+ var tint_pulling_pos : u32;
+ tint_pulling_pos = ((tint_pulling_vertex_index * 8u) + 0u);
+ var_a = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]));
+ tint_pulling_pos = ((tint_pulling_vertex_index * 12u) + 0u);
+ var_b = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]));
+ tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
+ var_c = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)]));
+ }
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {{
+ {8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}},
+ {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
+ {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}},
+ }};
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, AttemptSymbolCollision) {
+ auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32,
+ [[location(1)]] var_b : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
+ var tint_pulling_vertex_index : i32;
+ var tint_pulling_vertex_buffer_0 : i32;
+ var tint_vertex_data : i32;
+ var tint_pulling_pos : i32;
+ return vec4<f32>();
+}
+)";
+
+ auto* expect = R"(
+[[block]]
+struct TintVertexData {
+ tint_vertex_data_1 : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0_1 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index_1 : u32) -> [[builtin(position)]] vec4<f32> {
+ var var_a : f32;
+ var var_b : vec4<f32>;
+ {
+ var tint_pulling_pos_1 : u32;
+ tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
+ var_a = bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(tint_pulling_pos_1 / 4u)]);
+ tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
+ var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 12u) / 4u)]));
+ }
+ var tint_pulling_vertex_index : i32;
+ var tint_pulling_vertex_buffer_0 : i32;
+ var tint_vertex_data : i32;
+ var tint_pulling_pos : i32;
+ return vec4<f32>();
+}
+)";
+
+ VertexPulling::Config cfg;
+ cfg.vertex_state = {
+ {{16,
+ InputStepMode::kVertex,
+ {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}};
+ cfg.entry_point_name = "main";
+
+ DataMap data;
+ data.Add<VertexPulling::Config>(cfg);
+ auto got = Run<VertexPulling>(src, std::move(data));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// TODO(crbug.com/tint/697): Remove this.
+TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet_Legacy) {
+ auto* src = R"(
[[location(0)]] var<in> var_a : f32;
[[stage(vertex)]]
@@ -243,8 +699,9 @@
EXPECT_EQ(expect, str(got));
}
+// TODO(crbug.com/tint/697): Remove this.
// We expect the transform to use an existing builtin variables if it finds them
-TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Legacy) {
auto* src = R"(
[[location(0)]] var<in> var_a : f32;
[[location(1)]] var<in> var_b : f32;
@@ -310,7 +767,8 @@
EXPECT_EQ(expect, str(got));
}
-TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
+// TODO(crbug.com/tint/697): Remove this.
+TEST_F(VertexPullingTest, TwoAttributesSameBuffer_Legacy) {
auto* src = R"(
[[location(0)]] var<in> var_a : f32;
[[location(1)]] var<in> var_b : vec4<f32>;
@@ -362,127 +820,6 @@
EXPECT_EQ(expect, str(got));
}
-TEST_F(VertexPullingTest, FloatVectorAttributes) {
- auto* src = R"(
-[[location(0)]] var<in> var_a : vec2<f32>;
-[[location(1)]] var<in> var_b : vec3<f32>;
-[[location(2)]] var<in> var_c : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
- return vec4<f32>();
-}
-)";
-
- auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_pulling_vertex_index : u32;
-
-[[block]]
-struct TintVertexData {
- tint_vertex_data : [[stride(4)]] array<u32>;
-};
-
-[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
-
-[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
-
-[[binding(2), group(4)]] var<storage> tint_pulling_vertex_buffer_2 : [[access(read)]] TintVertexData;
-
-var<private> var_a : vec2<f32>;
-
-var<private> var_b : vec3<f32>;
-
-var<private> var_c : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
- {
- var tint_pulling_pos : u32;
- tint_pulling_pos = ((tint_pulling_vertex_index * 8u) + 0u);
- var_a = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]));
- tint_pulling_pos = ((tint_pulling_vertex_index * 12u) + 0u);
- var_b = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]));
- tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
- var_c = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)]));
- }
- return vec4<f32>();
-}
-)";
-
- VertexPulling::Config cfg;
- cfg.vertex_state = {{
- {8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}},
- {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
- {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}},
- }};
- cfg.entry_point_name = "main";
-
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, data);
-
- EXPECT_EQ(expect, str(got));
-}
-
-TEST_F(VertexPullingTest, AttemptSymbolCollision) {
- auto* src = R"(
-[[location(0)]] var<in> var_a : f32;
-[[location(1)]] var<in> var_b : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
- var tint_pulling_vertex_index : i32;
- var tint_pulling_vertex_buffer_0 : i32;
- var tint_vertex_data : i32;
- var tint_pulling_pos : i32;
- return vec4<f32>();
-}
-)";
-
- auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_pulling_vertex_index_1 : u32;
-
-[[block]]
-struct TintVertexData {
- tint_vertex_data_1 : [[stride(4)]] array<u32>;
-};
-
-[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0_1 : [[access(read)]] TintVertexData;
-
-var<private> var_a : f32;
-
-var<private> var_b : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
- {
- var tint_pulling_pos_1 : u32;
- tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
- var_a = bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(tint_pulling_pos_1 / 4u)]);
- tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
- var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 12u) / 4u)]));
- }
- var tint_pulling_vertex_index : i32;
- var tint_pulling_vertex_buffer_0 : i32;
- var tint_vertex_data : i32;
- var tint_pulling_pos : i32;
- return vec4<f32>();
-}
-)";
-
- VertexPulling::Config cfg;
- cfg.vertex_state = {
- {{16,
- InputStepMode::kVertex,
- {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}};
- cfg.entry_point_name = "main";
-
- DataMap data;
- data.Add<VertexPulling::Config>(cfg);
- auto got = Run<VertexPulling>(src, std::move(data));
-
- EXPECT_EQ(expect, str(got));
-}
} // namespace
} // namespace transform
} // namespace tint