[spirv-writer] Handle building vectors from other vectors.
This Cl updates the composite construction to handle decomposing vectors
into smaller parts before building the composite.
Bug: tint:61
Change-Id: I7e0ac3a5c966dbcdf6429d508a392756f521b756
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20541
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 8011e74..802e34a 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -680,6 +680,9 @@
}
constructor_is_const = false;
}
+ }
+
+ for (const auto& e : init->values()) {
uint32_t id = 0;
if (constructor_is_const) {
id = GenerateConstructorExpression(e->AsConstructor(), is_global_init);
@@ -691,8 +694,32 @@
return 0;
}
- out << "_" << id;
- ops.push_back(Operand::Int(id));
+ auto* result_type = e->result_type()->UnwrapPtrIfNeeded();
+
+ // If we're putting a vector into the constructed composite we need to
+ // extract each of the values and insert them individually
+ if (result_type->IsVector()) {
+ auto* vec = result_type->AsVector();
+ auto result_type_id = GenerateTypeIfNeeded(vec->type());
+ if (result_type_id == 0) {
+ return 0;
+ }
+
+ for (uint32_t i = 0; i < vec->size(); ++i) {
+ auto extract = result_op();
+ auto extract_id = extract.to_i();
+
+ push_function_inst(spv::Op::OpCompositeExtract,
+ {Operand::Int(result_type_id), extract,
+ Operand::Int(id), Operand::Int(i)});
+
+ out << "_" << extract_id;
+ ops.push_back(Operand::Int(extract_id));
+ }
+ } else {
+ out << "_" << id;
+ ops.push_back(Operand::Int(id));
+ }
}
auto str = out.str();
diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc
index efa3875..b7af4c8 100644
--- a/src/writer/spirv/builder_constructor_expression_test.cc
+++ b/src/writer/spirv/builder_constructor_expression_test.cc
@@ -65,7 +65,11 @@
ast::TypeConstructorExpression t(&vec, std::move(vals));
+ Context ctx;
ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(&t)) << td.error();
+
Builder b(&mod);
EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u);
ASSERT_FALSE(b.has_error()) << b.error();
@@ -120,6 +124,54 @@
)");
}
+TEST_F(BuilderTest, Constructor_Type_NonConstVector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec2(&f32, 2);
+ ast::type::VectorType vec4(&f32, 4);
+
+ auto var = std::make_unique<ast::Variable>(
+ "ident", ast::StorageClass::kFunction, &vec2);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::IdentifierExpression>("ident"));
+
+ ast::TypeConstructorExpression t(&vec4, std::move(vals));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ EXPECT_TRUE(td.DetermineResultType(&t)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateFunctionVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateConstructorExpression(&t, false), 10u);
+ ASSERT_FALSE(b.has_error()) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 2
+%2 = OpTypePointer Function %3
+%5 = OpTypeVector %4 4
+%6 = OpConstant %4 1
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+ R"(%1 = OpVariable %2 Function
+)");
+
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%7 = OpLoad %3 %1
+%8 = OpCompositeExtract %4 %7 0
+%9 = OpCompositeExtract %4 %7 1
+%10 = OpCompositeConstruct %5 %6 %6 %8 %9
+)");
+}
+
TEST_F(BuilderTest, Constructor_Type_Dedups) {
ast::type::F32Type f32;
ast::type::VectorType vec(&f32, 3);
@@ -134,7 +186,11 @@
ast::TypeConstructorExpression t(&vec, std::move(vals));
+ Context ctx;
ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(&t)) << td.error();
+
Builder b(&mod);
EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u);
EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u);
diff --git a/src/writer/spirv/builder_function_variable_test.cc b/src/writer/spirv/builder_function_variable_test.cc
index 38e98bd..33ce0f1 100644
--- a/src/writer/spirv/builder_function_variable_test.cc
+++ b/src/writer/spirv/builder_function_variable_test.cc
@@ -30,6 +30,8 @@
#include "src/ast/type_constructor_expression.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decoration.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
@@ -74,10 +76,16 @@
auto init =
std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
v.set_constructor(std::move(init));
- ast::Module mod;
+ td.RegisterVariableForTesting(&v);
+
Builder b(&mod);
b.push_function(Function{});
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
@@ -161,11 +169,17 @@
auto init =
std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
v.set_constructor(std::move(init));
v.set_is_const(true);
- ast::Module mod;
+ td.RegisterVariableForTesting(&v);
+
Builder b(&mod);
EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index aad5670..8763d87 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -29,6 +29,8 @@
#include "src/ast/type_constructor_expression.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decoration.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
@@ -84,10 +86,15 @@
auto init =
std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
v.set_constructor(std::move(init));
+ td.RegisterVariableForTesting(&v);
- ast::Module mod;
Builder b(&mod);
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
@@ -119,11 +126,16 @@
auto init =
std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
v.set_constructor(std::move(init));
v.set_is_const(true);
+ td.RegisterVariableForTesting(&v);
- ast::Module mod;
Builder b(&mod);
EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
ASSERT_FALSE(b.has_error()) << b.error();
diff --git a/src/writer/spirv/builder_ident_expression_test.cc b/src/writer/spirv/builder_ident_expression_test.cc
index e69bb73..5fb39b5 100644
--- a/src/writer/spirv/builder_ident_expression_test.cc
+++ b/src/writer/spirv/builder_ident_expression_test.cc
@@ -52,13 +52,15 @@
auto init =
std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
ast::Variable v("var", ast::StorageClass::kOutput, &f32);
v.set_constructor(std::move(init));
v.set_is_const(true);
- Context ctx;
- ast::Module mod;
- TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(&v);
Builder b(&mod);
@@ -117,13 +119,14 @@
auto init =
std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
- ast::Variable v("var", ast::StorageClass::kOutput, &f32);
- v.set_constructor(std::move(init));
- v.set_is_const(true);
-
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
+ ast::Variable v("var", ast::StorageClass::kOutput, &f32);
+ v.set_constructor(std::move(init));
+ v.set_is_const(true);
td.RegisterVariableForTesting(&v);
Builder b(&mod);
diff --git a/src/writer/spirv/builder_return_test.cc b/src/writer/spirv/builder_return_test.cc
index e6ab54c..48933a4 100644
--- a/src/writer/spirv/builder_return_test.cc
+++ b/src/writer/spirv/builder_return_test.cc
@@ -21,6 +21,8 @@
#include "src/ast/type/f32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
@@ -61,7 +63,11 @@
ast::ReturnStatement ret(std::move(val));
+ Context ctx;
ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ EXPECT_TRUE(td.DetermineResultType(&ret)) << td.error();
+
Builder b(&mod);
b.push_function(Function{});
EXPECT_TRUE(b.GenerateReturnStatement(&ret));