validation: validate struct constructor

Bug: tint:864
Change-Id: I57db071bcda96d45f758bcdbc47c6ef0a4a8192d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57280
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/builtins_validation_test.cc b/src/resolver/builtins_validation_test.cc
index c6aa820..4af121d 100644
--- a/src/resolver/builtins_validation_test.cc
+++ b/src/resolver/builtins_validation_test.cc
@@ -27,14 +27,13 @@
 using vec3 = builder::vec3<T>;
 template <typename T>
 using vec4 = builder::vec4<T>;
-template <typename T>
 using f32 = builder::f32;
 using i32 = builder::i32;
 using u32 = builder::u32;
 
 class ResolverBuiltinsValidationTest : public resolver::TestHelper,
                                        public testing::Test {};
-namespace TypeTemp {
+namespace StageTest {
 struct Params {
   builder::ast_type_func_ptr type;
   ast::Builtin builtin;
@@ -218,7 +217,7 @@
       "12:34 error: builtin(frag_depth) cannot be used in input of fragment "
       "pipeline stage\nnote: while analysing entry point fragShader");
 }
-}  // namespace TypeTemp
+}  // namespace StageTest
 
 TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) {
   // struct MyInputs {
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 881d380..0c48f04 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2266,6 +2266,9 @@
     if (auto* arr_type = type->As<sem::Array>()) {
       return ValidateArrayConstructor(type_ctor, arr_type);
     }
+    if (auto* struct_type = type->As<sem::Struct>()) {
+      return ValidateStructureConstructor(type_ctor, struct_type);
+    }
   } else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
     Mark(scalar_ctor->literal());
     auto* type = TypeOf(scalar_ctor->literal());
@@ -2280,6 +2283,36 @@
   return true;
 }
 
+bool Resolver::ValidateStructureConstructor(
+    const ast::TypeConstructorExpression* ctor,
+    const sem::Struct* struct_type) {
+  if (ctor->values().size() > 0) {
+    if (ctor->values().size() != struct_type->Members().size()) {
+      std::string fm = ctor->values().size() < struct_type->Members().size()
+                           ? "few"
+                           : "many";
+      AddError("struct constructor has too " + fm + " inputs: expected " +
+                   std::to_string(struct_type->Members().size()) + ", found " +
+                   std::to_string(ctor->values().size()),
+               ctor->source());
+      return false;
+    }
+    for (auto* member : struct_type->Members()) {
+      auto* value = ctor->values()[member->Index()];
+      if (member->Type() != TypeOf(value)->UnwrapRef()) {
+        AddError(
+            "type in struct constructor does not match struct member type: "
+            "expected '" +
+                member->Type()->FriendlyName(builder_->Symbols()) +
+                "', found '" + TypeNameOf(value) + "'",
+            value->source());
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 bool Resolver::ValidateArrayConstructor(
     const ast::TypeConstructorExpression* ctor,
     const sem::Array* array_type) {
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 3306294..04b7326 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -294,6 +294,8 @@
   bool ValidateStatements(const ast::StatementList& stmts);
   bool ValidateStorageTexture(const ast::StorageTexture* t);
   bool ValidateStructure(const sem::Struct* str);
+  bool ValidateStructureConstructor(const ast::TypeConstructorExpression* ctor,
+                                    const sem::Struct* struct_type);
   bool ValidateSwitch(const ast::SwitchStatement* s);
   bool ValidateVariable(const VariableInfo* info);
   bool ValidateVariableConstructor(const ast::Variable* var,
diff --git a/src/resolver/struct_pipeline_stage_use_test.cc b/src/resolver/struct_pipeline_stage_use_test.cc
index 1b2c18e..7c0c022 100644
--- a/src/resolver/struct_pipeline_stage_use_test.cc
+++ b/src/resolver/struct_pipeline_stage_use_test.cc
@@ -81,7 +81,7 @@
   auto* s = Structure(
       "S", {Member("a", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)})});
 
-  Func("main", {}, ty.Of(s), {Return(Construct(ty.Of(s), Expr(0.f)))},
+  Func("main", {}, ty.Of(s), {Return(Construct(ty.Of(s)))},
        {Stage(ast::PipelineStage::kVertex)});
 
   ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -141,8 +141,7 @@
       "S", {Member("a", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)})});
 
   Func("vert_main", {Param("param", ty.Of(s))}, ty.Of(s),
-       {Return(Construct(ty.Of(s), Expr(0.f)))},
-       {Stage(ast::PipelineStage::kVertex)});
+       {Return(Construct(ty.Of(s)))}, {Stage(ast::PipelineStage::kVertex)});
 
   Func("frag_main", {Param("param", ty.Of(s))}, ty.void_(), {},
        {Stage(ast::PipelineStage::kFragment)});
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 32b01d2..eaa2d2b 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -2335,6 +2335,176 @@
                                          MatrixDimensions{3, 4},
                                          MatrixDimensions{4, 4}));
 
+namespace StructConstructor {
+using builder::CreatePtrs;
+using builder::CreatePtrsFor;
+using builder::f32;
+using builder::i32;
+using builder::mat2x2;
+using builder::mat3x3;
+using builder::mat4x4;
+using builder::u32;
+using builder::vec2;
+using builder::vec3;
+using builder::vec4;
+
+constexpr CreatePtrs all_types[] = {
+    CreatePtrsFor<bool>(),         //
+    CreatePtrsFor<u32>(),          //
+    CreatePtrsFor<i32>(),          //
+    CreatePtrsFor<f32>(),          //
+    CreatePtrsFor<vec4<bool>>(),   //
+    CreatePtrsFor<vec2<i32>>(),    //
+    CreatePtrsFor<vec3<u32>>(),    //
+    CreatePtrsFor<vec4<f32>>(),    //
+    CreatePtrsFor<mat2x2<f32>>(),  //
+    CreatePtrsFor<mat3x3<f32>>(),  //
+    CreatePtrsFor<mat4x4<f32>>()   //
+};
+
+auto number_of_members = testing::Values(2u, 32u, 64u);
+
+using StructConstructorInputsTest =
+    ResolverTestWithParam<std::tuple<CreatePtrs,  // struct member type
+                                     uint32_t>>;  // number of struct members
+TEST_P(StructConstructorInputsTest, TooFew) {
+  auto& param = GetParam();
+  auto& str_params = std::get<0>(param);
+  uint32_t N = std::get<1>(param);
+
+  ast::StructMemberList members;
+  ast::ExpressionList values;
+  for (uint32_t i = 0; i < N; i++) {
+    auto* struct_type = str_params.ast(*this);
+    members.push_back(Member("member_" + std::to_string(i), struct_type));
+    if (i < N - 1) {
+      auto* ctor_value_expr = str_params.expr(*this, 0);
+      values.push_back(ctor_value_expr);
+    }
+  }
+  auto* s = Structure("s", members);
+  auto* tc = create<ast::TypeConstructorExpression>(Source{{12, 34}}, ty.Of(s),
+                                                    values);
+  WrapInFunction(tc);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: struct constructor has too few inputs: expected " +
+                std::to_string(N) + ", found " + std::to_string(N - 1));
+}
+
+TEST_P(StructConstructorInputsTest, TooMany) {
+  auto& param = GetParam();
+  auto& str_params = std::get<0>(param);
+  uint32_t N = std::get<1>(param);
+
+  ast::StructMemberList members;
+  ast::ExpressionList values;
+  for (uint32_t i = 0; i < N + 1; i++) {
+    if (i < N) {
+      auto* struct_type = str_params.ast(*this);
+      members.push_back(Member("member_" + std::to_string(i), struct_type));
+    }
+    auto* ctor_value_expr = str_params.expr(*this, 0);
+    values.push_back(ctor_value_expr);
+  }
+  auto* s = Structure("s", members);
+  auto* tc = create<ast::TypeConstructorExpression>(Source{{12, 34}}, ty.Of(s),
+                                                    values);
+  WrapInFunction(tc);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: struct constructor has too many inputs: expected " +
+                std::to_string(N) + ", found " + std::to_string(N + 1));
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverValidationTest,
+                         StructConstructorInputsTest,
+                         testing::Combine(testing::ValuesIn(all_types),
+                                          number_of_members));
+using StructConstructorTypeTest =
+    ResolverTestWithParam<std::tuple<CreatePtrs,  // struct member type
+                                     CreatePtrs,  // constructor value type
+                                     uint32_t>>;  // number of struct members
+TEST_P(StructConstructorTypeTest, AllTypes) {
+  auto& param = GetParam();
+  auto& str_params = std::get<0>(param);
+  auto& ctor_params = std::get<1>(param);
+  uint32_t N = std::get<2>(param);
+
+  if (str_params.ast == ctor_params.ast) {
+    return;
+  }
+
+  ast::StructMemberList members;
+  ast::ExpressionList values;
+  // make the last value of the constructor to have a different type
+  uint32_t constructor_value_with_different_type = N - 1;
+  for (uint32_t i = 0; i < N; i++) {
+    auto* struct_type = str_params.ast(*this);
+    members.push_back(Member("member_" + std::to_string(i), struct_type));
+    auto* ctor_value_expr = (i == constructor_value_with_different_type)
+                                ? ctor_params.expr(*this, 0)
+                                : str_params.expr(*this, 0);
+    values.push_back(ctor_value_expr);
+  }
+  auto* s = Structure("s", members);
+  auto* tc = create<ast::TypeConstructorExpression>(ty.Of(s), values);
+  WrapInFunction(tc);
+
+  std::string found = FriendlyName(ctor_params.ast(*this));
+  std::string expected = FriendlyName(str_params.ast(*this));
+  std::stringstream err;
+  err << "error: type in struct constructor does not match struct member ";
+  err << "type: expected '" << expected << "', found '" << found << "'";
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), err.str());
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverValidationTest,
+                         StructConstructorTypeTest,
+                         testing::Combine(testing::ValuesIn(all_types),
+                                          testing::ValuesIn(all_types),
+                                          number_of_members));
+
+TEST_F(ResolverValidationTest, Expr_Constructor_Struct_Nested) {
+  auto* inner_m = Member("m", ty.i32());
+  auto* inner_s = Structure("inner_s", {inner_m});
+
+  auto* m0 = Member("m", ty.i32());
+  auto* m1 = Member("m", ty.Of(inner_s));
+  auto* m2 = Member("m", ty.i32());
+  auto* s = Structure("s", {m0, m1, m2});
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{{12, 34}}, ty.Of(s),
+                                                    ExprList(1, 1, 1));
+  WrapInFunction(tc);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "error: type in struct constructor does not match struct member "
+            "type: expected 'inner_s', found 'i32'");
+}
+
+TEST_F(ResolverValidationTest, Expr_Constructor_Struct) {
+  auto* m = Member("m", ty.i32());
+  auto* s = Structure("MyInputs", {m});
+  auto* tc = create<ast::TypeConstructorExpression>(Source{{12, 34}}, ty.Of(s),
+                                                    ExprList());
+  WrapInFunction(tc);
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Expr_Constructor_Struct_Empty) {
+  auto* str = Structure("S", {
+                                 Member("a", ty.i32()),
+                                 Member("b", ty.f32()),
+                                 Member("c", ty.vec3<i32>()),
+                             });
+
+  WrapInFunction(Construct(ty.Of(str)));
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+}  // namespace StructConstructor
+
 }  // namespace
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
index 189e8d5..f897004 100644
--- a/src/transform/spirv_test.cc
+++ b/src/transform/spirv_test.cc
@@ -500,7 +500,7 @@
 
 [[stage(vertex)]]
 fn vert_main(in : VertexIn) -> VertexOut {
-  return VertexOut(in.i, in.u, in.vi, in.vu);
+  return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
 }
 
 [[stage(fragment)]]
@@ -561,7 +561,7 @@
 [[stage(vertex)]]
 fn vert_main() {
   let tint_symbol_4 : VertexIn = VertexIn(tint_symbol, tint_symbol_1, tint_symbol_2, tint_symbol_3);
-  tint_symbol_11(VertexOut(tint_symbol_4.i, tint_symbol_4.u, tint_symbol_4.vi, tint_symbol_4.vu));
+  tint_symbol_11(VertexOut(tint_symbol_4.i, tint_symbol_4.u, tint_symbol_4.vi, tint_symbol_4.vu, vec4<f32>()));
   return;
 }
 
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 1499596..57de37c 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -264,13 +264,10 @@
        {});
 
   Func("vert_main1", {}, ty.Of(vertex_output_struct),
-       {Return(Construct(ty.Of(vertex_output_struct),
-                         Expr(Call("foo", Expr(0.5f)))))},
-       {Stage(ast::PipelineStage::kVertex)});
+       {Return(Call("foo", Expr(0.5f)))}, {Stage(ast::PipelineStage::kVertex)});
 
   Func("vert_main2", {}, ty.Of(vertex_output_struct),
-       {Return(Construct(ty.Of(vertex_output_struct),
-                         Expr(Call("foo", Expr(0.25f)))))},
+       {Return(Call("foo", Expr(0.25f)))},
        {Stage(ast::PipelineStage::kVertex)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -290,7 +287,7 @@
 };
 
 tint_symbol vert_main1() {
-  const VertexOutput tint_symbol_1 = {foo(0.5f)};
+  const VertexOutput tint_symbol_1 = foo(0.5f);
   const tint_symbol tint_symbol_5 = {tint_symbol_1.pos};
   return tint_symbol_5;
 }
@@ -300,7 +297,7 @@
 };
 
 tint_symbol_2 vert_main2() {
-  const VertexOutput tint_symbol_3 = {foo(0.25f)};
+  const VertexOutput tint_symbol_3 = foo(0.25f);
   const tint_symbol_2 tint_symbol_6 = {tint_symbol_3.pos};
   return tint_symbol_6;
 }
diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc
index 08c130e..f1ba7c1 100644
--- a/src/writer/spirv/builder_constructor_expression_test.cc
+++ b/src/writer/spirv/builder_constructor_expression_test.cc
@@ -1793,9 +1793,9 @@
                                    });
 
   Global("a", ty.f32(), ast::StorageClass::kPrivate);
-  Global("b", ty.f32(), ast::StorageClass::kPrivate);
+  Global("b", ty.vec3<f32>(), ast::StorageClass::kPrivate);
 
-  auto* t = Construct(ty.Of(s), 2.f, "a", 2.f);
+  auto* t = Construct(ty.Of(s), "a", "b");
   WrapInFunction(t);
 
   spirv::Builder& b = Build();