[spirv-writer] Only extract composites for non-const constructors.

Currently we will attempt to extract composite values for constant
constructors which may happen outside of a function. This causes issues
as the extract requires us to be in a function.

Change-Id: I5724987542cc7d9d86493363ed4d9a44a391a52f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23221
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 607f2ce..df4829a 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -909,6 +909,15 @@
     return 0;
   }
 
+  auto* result_type = init->type()->UnwrapPtrIfNeeded();
+  if (result_type->IsVector()) {
+    result_type = result_type->AsVector()->type();
+  } else if (result_type->IsArray()) {
+    result_type = result_type->AsArray()->type();
+  } else if (result_type->IsMatrix()) {
+    result_type = result_type->AsMatrix()->type();
+  }
+
   std::ostringstream out;
   out << "__const";
 
@@ -924,6 +933,8 @@
     }
   }
 
+  bool result_is_constant_composite = constructor_is_const;
+  bool result_is_spec_composite = false;
   for (const auto& e : init->values()) {
     uint32_t id = 0;
     if (constructor_is_const) {
@@ -936,27 +947,69 @@
       return 0;
     }
 
-    auto* result_type = e->result_type()->UnwrapPtrIfNeeded();
+    auto* value_type = e->result_type()->UnwrapPtrIfNeeded();
 
-    // If we're putting a vector into the constructed composite we need to
-    // extract each of the values and insert them individually
-    if (result_type->IsVector()) {
-      auto* vec = result_type->AsVector();
-      auto result_type_id = GenerateTypeIfNeeded(vec->type());
-      if (result_type_id == 0) {
-        return 0;
-      }
+    // When handling vectors as the values there a few cases to take into
+    // consideration:
+    //  1. Module scoped vec3<f32>(vec2<f32>(1, 2), 3)  -> OpSpecConstantOp
+    //  2. Function scoped vec3<f32>(vec2<f32>(1, 2), 3) -> OpCompositeExtract
+    //  3. Either array<vec3<f32>, 1>(vec3<f32>(1, 2, 3))  -> use the ID.
+    if (value_type->IsVector()) {
+      auto* vec = value_type->AsVector();
+      auto* vec_type = vec->type();
 
-      for (uint32_t i = 0; i < vec->size(); ++i) {
-        auto extract = result_op();
-        auto extract_id = extract.to_i();
+      // If the value we want is the same as what we have, use it directly.
+      // This maps to case 3.
+      if (result_type == value_type) {
+        out << "_" << id;
+        ops.push_back(Operand::Int(id));
+      } else if (!is_global_init) {
+        // A non-global initializer. Case 2.
+        auto value_type_id = GenerateTypeIfNeeded(vec_type);
+        if (value_type_id == 0) {
+          return 0;
+        }
 
-        push_function_inst(spv::Op::OpCompositeExtract,
-                           {Operand::Int(result_type_id), extract,
-                            Operand::Int(id), Operand::Int(i)});
+        for (uint32_t i = 0; i < vec->size(); ++i) {
+          auto extract = result_op();
+          auto extract_id = extract.to_i();
 
-        out << "_" << extract_id;
-        ops.push_back(Operand::Int(extract_id));
+          push_function_inst(spv::Op::OpCompositeExtract,
+                             {Operand::Int(value_type_id), extract,
+                              Operand::Int(id), Operand::Int(i)});
+
+          out << "_" << extract_id;
+          ops.push_back(Operand::Int(extract_id));
+
+          // We no longer have a constant composite, but have to do a
+          // composite construction as these calls are inside a function.
+          result_is_constant_composite = false;
+        }
+      } else {
+        // A global initializer, must use OpSpecConstantOp. Case 1.
+        auto value_type_id = GenerateTypeIfNeeded(vec_type);
+        if (value_type_id == 0) {
+          return 0;
+        }
+
+        for (uint32_t i = 0; i < vec->size(); ++i) {
+          auto extract = result_op();
+          auto extract_id = extract.to_i();
+
+          auto idx_id = GenerateU32Literal(i);
+          if (idx_id == 0) {
+            return 0;
+          }
+          push_type(spv::Op::OpSpecConstantOp,
+                    {Operand::Int(value_type_id), extract,
+                     Operand::Int(SpvOpCompositeExtract), Operand::Int(id),
+                     Operand::Int(idx_id)});
+
+          out << "_" << extract_id;
+          ops.push_back(Operand::Int(extract_id));
+
+          result_is_spec_composite = true;
+        }
       }
     } else {
       out << "_" << id;
@@ -976,7 +1029,9 @@
 
   const_to_id_[str] = result.to_i();
 
-  if (constructor_is_const) {
+  if (result_is_spec_composite) {
+    push_type(spv::Op::OpSpecConstantComposite, ops);
+  } else if (result_is_constant_composite) {
     push_type(spv::Op::OpConstantComposite, ops);
   } else {
     push_function_inst(spv::Op::OpCompositeConstruct, ops);
diff --git a/src/writer/spirv/builder_assign_test.cc b/src/writer/spirv/builder_assign_test.cc
index 46875a7..ad98cc3 100644
--- a/src/writer/spirv/builder_assign_test.cc
+++ b/src/writer/spirv/builder_assign_test.cc
@@ -78,6 +78,141 @@
 )");
 }
 
+TEST_F(BuilderTest, Assign_Var_Complex_ConstructorWithExtract) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::VectorType vec2(&f32, 2);
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  auto first =
+      std::make_unique<ast::TypeConstructorExpression>(&vec2, std::move(vals));
+
+  vals.push_back(std::move(first));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+  auto init =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  ast::Variable v("var", ast::StorageClass::kOutput, &vec3);
+
+  ast::AssignmentStatement assign(
+      std::make_unique<ast::IdentifierExpression>("var"), std::move(init));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(&v);
+  ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+  ASSERT_FALSE(b.has_error()) << b.error();
+
+  EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error();
+  EXPECT_FALSE(b.has_error());
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Output %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Output %5
+%6 = OpTypeVector %4 2
+%7 = OpConstant %4 1
+%8 = OpConstant %4 2
+%9 = OpConstantComposite %6 %7 %8
+%12 = OpConstant %4 3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%10 = OpCompositeExtract %4 %9 0
+%11 = OpCompositeExtract %4 %9 1
+%13 = OpCompositeConstruct %3 %10 %11 %12
+OpStore %1 %13
+)");
+}
+
+TEST_F(BuilderTest, Assign_Var_Complex_Constructor) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::VectorType vec(&vec3, 3);
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+  auto first =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  auto second =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+  auto third =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  vals.push_back(std::move(first));
+  vals.push_back(std::move(second));
+  vals.push_back(std::move(third));
+
+  auto init =
+      std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+
+  ast::Variable v("var", ast::StorageClass::kOutput, &vec);
+
+  ast::AssignmentStatement assign(
+      std::make_unique<ast::IdentifierExpression>("var"), std::move(init));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(&v);
+  ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+  ASSERT_FALSE(b.has_error()) << b.error();
+
+  EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error();
+  EXPECT_FALSE(b.has_error());
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Output %3
+%6 = OpConstantNull %3
+%1 = OpVariable %2 Output %6
+%7 = OpConstant %5 1
+%8 = OpConstant %5 2
+%9 = OpConstant %5 3
+%10 = OpConstantComposite %4 %7 %8 %9
+%11 = OpConstantComposite %4 %9 %8 %7
+%12 = OpConstantComposite %4 %8 %7 %9
+%13 = OpConstantComposite %3 %10 %11 %12
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpStore %1 %13
+)");
+}
+
 TEST_F(BuilderTest, Assign_StructMember) {
   ast::type::F32Type f32;
 
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 306d2d6..8618f39 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -165,6 +165,123 @@
 )");
 }
 
+TEST_F(BuilderTest, GlobalVar_Complex_Constructor) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::VectorType vec(&vec3, 3);
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+  auto first =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  auto second =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+  auto third =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  vals.push_back(std::move(first));
+  vals.push_back(std::move(second));
+  vals.push_back(std::move(third));
+
+  auto init =
+      std::make_unique<ast::TypeConstructorExpression>(&vec, std::move(vals));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
+  ast::Variable v("var", ast::StorageClass::kOutput, &f32);
+  v.set_constructor(std::move(init));
+  v.set_is_const(true);
+  td.RegisterVariableForTesting(&v);
+
+  Builder b(&mod);
+  EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+  ASSERT_FALSE(b.has_error()) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypeVector %3 3
+%1 = OpTypeVector %2 3
+%4 = OpConstant %3 1
+%5 = OpConstant %3 2
+%6 = OpConstant %3 3
+%7 = OpConstantComposite %2 %4 %5 %6
+%8 = OpConstantComposite %2 %6 %5 %4
+%9 = OpConstantComposite %2 %5 %4 %6
+%10 = OpConstantComposite %1 %7 %8 %9
+)");
+}
+
+TEST_F(BuilderTest, GlobalVar_Complex_ConstructorWithExtract) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::VectorType vec2(&f32, 2);
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0f)));
+  auto first =
+      std::make_unique<ast::TypeConstructorExpression>(&vec2, std::move(vals));
+
+  vals.push_back(std::move(first));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+  auto init =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error();
+
+  ast::Variable v("var", ast::StorageClass::kOutput, &f32);
+  v.set_constructor(std::move(init));
+  v.set_is_const(true);
+  td.RegisterVariableForTesting(&v);
+
+  Builder b(&mod);
+  EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+  ASSERT_FALSE(b.has_error()) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
+%1 = OpTypeVector %2 3
+%3 = OpTypeVector %2 2
+%4 = OpConstant %2 1
+%5 = OpConstant %2 2
+%6 = OpConstantComposite %3 %4 %5
+%8 = OpTypeInt 32 0
+%9 = OpConstant %8 0
+%7 = OpSpecConstantOp %2 CompositeExtract %6 9
+%11 = OpConstant %8 1
+%10 = OpSpecConstantOp %2 CompositeExtract %6 11
+%12 = OpConstant %2 3
+%13 = OpSpecConstantComposite %1 %7 %10 %12
+)");
+}
+
 TEST_F(BuilderTest, GlobalVar_WithLocation) {
   ast::type::F32Type f32;
   auto v =