[spirv-writer] Only extract composites for non-const constructors.
Currently we will attempt to extract composite values for constant
constructors which may happen outside of a function. This causes issues
as the extract requires us to be in a function.
Change-Id: I5724987542cc7d9d86493363ed4d9a44a391a52f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23221
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 607f2ce..df4829a 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -909,6 +909,15 @@
return 0;
}
+ auto* result_type = init->type()->UnwrapPtrIfNeeded();
+ if (result_type->IsVector()) {
+ result_type = result_type->AsVector()->type();
+ } else if (result_type->IsArray()) {
+ result_type = result_type->AsArray()->type();
+ } else if (result_type->IsMatrix()) {
+ result_type = result_type->AsMatrix()->type();
+ }
+
std::ostringstream out;
out << "__const";
@@ -924,6 +933,8 @@
}
}
+ bool result_is_constant_composite = constructor_is_const;
+ bool result_is_spec_composite = false;
for (const auto& e : init->values()) {
uint32_t id = 0;
if (constructor_is_const) {
@@ -936,27 +947,69 @@
return 0;
}
- auto* result_type = e->result_type()->UnwrapPtrIfNeeded();
+ auto* value_type = e->result_type()->UnwrapPtrIfNeeded();
- // If we're putting a vector into the constructed composite we need to
- // extract each of the values and insert them individually
- if (result_type->IsVector()) {
- auto* vec = result_type->AsVector();
- auto result_type_id = GenerateTypeIfNeeded(vec->type());
- if (result_type_id == 0) {
- return 0;
- }
+ // When handling vectors as the values there a few cases to take into
+ // consideration:
+ // 1. Module scoped vec3<f32>(vec2<f32>(1, 2), 3) -> OpSpecConstantOp
+ // 2. Function scoped vec3<f32>(vec2<f32>(1, 2), 3) -> OpCompositeExtract
+ // 3. Either array<vec3<f32>, 1>(vec3<f32>(1, 2, 3)) -> use the ID.
+ if (value_type->IsVector()) {
+ auto* vec = value_type->AsVector();
+ auto* vec_type = vec->type();
- for (uint32_t i = 0; i < vec->size(); ++i) {
- auto extract = result_op();
- auto extract_id = extract.to_i();
+ // If the value we want is the same as what we have, use it directly.
+ // This maps to case 3.
+ if (result_type == value_type) {
+ out << "_" << id;
+ ops.push_back(Operand::Int(id));
+ } else if (!is_global_init) {
+ // A non-global initializer. Case 2.
+ auto value_type_id = GenerateTypeIfNeeded(vec_type);
+ if (value_type_id == 0) {
+ return 0;
+ }
- push_function_inst(spv::Op::OpCompositeExtract,
- {Operand::Int(result_type_id), extract,
- Operand::Int(id), Operand::Int(i)});
+ for (uint32_t i = 0; i < vec->size(); ++i) {
+ auto extract = result_op();
+ auto extract_id = extract.to_i();
- out << "_" << extract_id;
- ops.push_back(Operand::Int(extract_id));
+ push_function_inst(spv::Op::OpCompositeExtract,
+ {Operand::Int(value_type_id), extract,
+ Operand::Int(id), Operand::Int(i)});
+
+ out << "_" << extract_id;
+ ops.push_back(Operand::Int(extract_id));
+
+ // We no longer have a constant composite, but have to do a
+ // composite construction as these calls are inside a function.
+ result_is_constant_composite = false;
+ }
+ } else {
+ // A global initializer, must use OpSpecConstantOp. Case 1.
+ auto value_type_id = GenerateTypeIfNeeded(vec_type);
+ if (value_type_id == 0) {
+ return 0;
+ }
+
+ for (uint32_t i = 0; i < vec->size(); ++i) {
+ auto extract = result_op();
+ auto extract_id = extract.to_i();
+
+ auto idx_id = GenerateU32Literal(i);
+ if (idx_id == 0) {
+ return 0;
+ }
+ push_type(spv::Op::OpSpecConstantOp,
+ {Operand::Int(value_type_id), extract,
+ Operand::Int(SpvOpCompositeExtract), Operand::Int(id),
+ Operand::Int(idx_id)});
+
+ out << "_" << extract_id;
+ ops.push_back(Operand::Int(extract_id));
+
+ result_is_spec_composite = true;
+ }
}
} else {
out << "_" << id;
@@ -976,7 +1029,9 @@
const_to_id_[str] = result.to_i();
- if (constructor_is_const) {
+ if (result_is_spec_composite) {
+ push_type(spv::Op::OpSpecConstantComposite, ops);
+ } else if (result_is_constant_composite) {
push_type(spv::Op::OpConstantComposite, ops);
} else {
push_function_inst(spv::Op::OpCompositeConstruct, ops);
diff --git a/src/writer/spirv/builder_assign_test.cc b/src/writer/spirv/builder_assign_test.cc
index 46875a7..ad98cc3 100644
--- a/src/writer/spirv/builder_assign_test.cc
+++ b/src/writer/spirv/builder_assign_test.cc
@@ -78,6 +78,141 @@
)");
}
+TEST_F(BuilderTest, Assign_Var_Complex_ConstructorWithExtract) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::VectorType vec2(&f32, 2);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ auto first =
+ std::make_unique<ast::TypeConstructorExpression>(&vec2, std::move(vals));
+
+ vals.push_back(std::move(first));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ auto init =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ ast::Variable v("var", ast::StorageClass::kOutput, &vec3);
+
+ ast::AssignmentStatement assign(
+ std::make_unique<ast::IdentifierExpression>("var"), std::move(init));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(&v);
+ ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+ ASSERT_FALSE(b.has_error()) << b.error();
+
+ EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error();
+ EXPECT_FALSE(b.has_error());
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Output %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Output %5
+%6 = OpTypeVector %4 2
+%7 = OpConstant %4 1
+%8 = OpConstant %4 2
+%9 = OpConstantComposite %6 %7 %8
+%12 = OpConstant %4 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%10 = OpCompositeExtract %4 %9 0
+%11 = OpCompositeExtract %4 %9 1
+%13 = OpCompositeConstruct %3 %10 %11 %12
+OpStore %1 %13
+)");
+}
+
+TEST_F(BuilderTest, Assign_Var_Complex_Constructor) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::VectorType vec(&vec3, 3);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+ auto first =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ auto second =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+ auto third =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ vals.push_back(std::move(first));
+ vals.push_back(std::move(second));
+ vals.push_back(std::move(third));
+
+ auto init =
+ std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+
+ ast::Variable v("var", ast::StorageClass::kOutput, &vec);
+
+ ast::AssignmentStatement assign(
+ std::make_unique<ast::IdentifierExpression>("var"), std::move(init));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(&v);
+ ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+ ASSERT_FALSE(b.has_error()) << b.error();
+
+ EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error();
+ EXPECT_FALSE(b.has_error());
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Output %3
+%6 = OpConstantNull %3
+%1 = OpVariable %2 Output %6
+%7 = OpConstant %5 1
+%8 = OpConstant %5 2
+%9 = OpConstant %5 3
+%10 = OpConstantComposite %4 %7 %8 %9
+%11 = OpConstantComposite %4 %9 %8 %7
+%12 = OpConstantComposite %4 %8 %7 %9
+%13 = OpConstantComposite %3 %10 %11 %12
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpStore %1 %13
+)");
+}
+
TEST_F(BuilderTest, Assign_StructMember) {
ast::type::F32Type f32;
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 306d2d6..8618f39 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -165,6 +165,123 @@
)");
}
+TEST_F(BuilderTest, GlobalVar_Complex_Constructor) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::VectorType vec(&vec3, 3);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+ auto first =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ auto second =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+ auto third =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ vals.push_back(std::move(first));
+ vals.push_back(std::move(second));
+ vals.push_back(std::move(third));
+
+ auto init =
+ std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
+ ast::Variable v("var", ast::StorageClass::kOutput, &f32);
+ v.set_constructor(std::move(init));
+ v.set_is_const(true);
+ td.RegisterVariableForTesting(&v);
+
+ Builder b(&mod);
+ EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+ ASSERT_FALSE(b.has_error()) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypeVector %3 3
+%1 = OpTypeVector %2 3
+%4 = OpConstant %3 1
+%5 = OpConstant %3 2
+%6 = OpConstant %3 3
+%7 = OpConstantComposite %2 %4 %5 %6
+%8 = OpConstantComposite %2 %6 %5 %4
+%9 = OpConstantComposite %2 %5 %4 %6
+%10 = OpConstantComposite %1 %7 %8 %9
+)");
+}
+
+TEST_F(BuilderTest, GlobalVar_Complex_ConstructorWithExtract) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::VectorType vec2(&f32, 2);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+ auto first =
+ std::make_unique<ast::TypeConstructorExpression>(&vec2, std::move(vals));
+
+ vals.push_back(std::move(first));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ auto init =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
+ ast::Variable v("var", ast::StorageClass::kOutput, &f32);
+ v.set_constructor(std::move(init));
+ v.set_is_const(true);
+ td.RegisterVariableForTesting(&v);
+
+ Builder b(&mod);
+ EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+ ASSERT_FALSE(b.has_error()) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
+%1 = OpTypeVector %2 3
+%3 = OpTypeVector %2 2
+%4 = OpConstant %2 1
+%5 = OpConstant %2 2
+%6 = OpConstantComposite %3 %4 %5
+%8 = OpTypeInt 32 0
+%9 = OpConstant %8 0
+%7 = OpSpecConstantOp %2 CompositeExtract %6 9
+%11 = OpConstant %8 1
+%10 = OpSpecConstantOp %2 CompositeExtract %6 11
+%12 = OpConstant %2 3
+%13 = OpSpecConstantComposite %1 %7 %10 %12
+)");
+}
+
TEST_F(BuilderTest, GlobalVar_WithLocation) {
ast::type::F32Type f32;
auto v =