transform/hlsl: Hoist structure constructors to new var

HLSL has some pecular rules around structure constructors.
`S s = S(1,2,3)` is not valid, but `S s = {1,2,3}` is.

This matches the quirkiness with array initializers, so adjust the array
hoisting logic to also support structures.

Fixed: tint:702
Change-Id: Ifdcafd98292715ae2482f72ec06c87842176d270
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46875
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 918d6e3..9db68e0 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -32,24 +32,24 @@
 Transform::Output Hlsl::Run(const Program* in, const DataMap&) {
   ProgramBuilder out;
   CloneContext ctx(&out, in);
-  PromoteArrayInitializerToConstVar(ctx);
+  PromoteInitializersToConstVar(ctx);
   AddEmptyEntryPoint(ctx);
   ctx.Clone();
   return Output{Program(std::move(out))};
 }
 
-void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const {
-  // Scan the AST nodes for array initializers which need to be promoted to
-  // their own constant declaration.
+void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const {
+  // Scan the AST nodes for array and structure initializers which
+  // need to be promoted to their own constant declaration.
 
-  // Note: Correct handling of arrays-of-arrays is guaranteed due to the
+  // Note: Correct handling of nested expressions is guaranteed due to the
   // depth-first traversal of the ast::Node::Clone() methods:
   //
-  // The inner-most array initializers are traversed first, and they are hoisted
+  // The inner-most initializers are traversed first, and they are hoisted
   // to const variables declared just above the statement of use. The outer
-  // array initializer will then be hoisted, inserting themselves between the
-  // inner array declaration and the statement of use. This pattern applies
-  // correctly to any nested depth.
+  // initializer will then be hoisted, inserting themselves between the
+  // inner declaration and the statement of use. This pattern applies correctly
+  // to any nested depth.
   //
   // Depth-first traversal of the AST is guaranteed because AST nodes are fully
   // immutable and require their children to be constructed first so their
@@ -75,22 +75,23 @@
 
       if (auto* src_var_decl = src_stmt->As<ast::VariableDeclStatement>()) {
         if (src_var_decl->variable()->constructor() == src_init) {
-          // This statement is just a variable declaration with the array
-          // initializer as the constructor value. This is what we're
-          // attempting to transform to, and so ignore.
+          // This statement is just a variable declaration with the initializer
+          // as the constructor value. This is what we're attempting to
+          // transform to, and so ignore.
           continue;
         }
       }
 
-      if (auto* src_array_ty = src_sem_expr->Type()->As<type::Array>()) {
+      auto* src_ty = src_sem_expr->Type();
+      if (src_ty->IsAnyOf<type::Array, type::Struct>()) {
         // Create a new symbol for the constant
         auto dst_symbol = ctx.dst->Symbols().New();
-        // Clone the array type
-        auto* dst_array_ty = ctx.Clone(src_array_ty);
-        // Clone the array initializer
+        // Clone the type
+        auto* dst_ty = ctx.Clone(src_ty);
+        // Clone the initializer
         auto* dst_init = ctx.Clone(src_init);
-        // Construct the constant that holds the array
-        auto* dst_var = ctx.dst->Const(dst_symbol, dst_array_ty, dst_init);
+        // Construct the constant that holds the hoisted initializer
+        auto* dst_var = ctx.dst->Const(dst_symbol, dst_ty, dst_init);
         // Construct the variable declaration statement
         auto* dst_var_decl =
             ctx.dst->create<ast::VariableDeclStatement>(dst_var);
@@ -100,7 +101,7 @@
         // Insert the constant before the usage
         ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
                          dst_var_decl);
-        // Replace the inlined array with a reference to the constant
+        // Replace the inlined initializer with a reference to the constant
         ctx.Replace(src_init, dst_ident);
       }
     }
diff --git a/src/transform/hlsl.h b/src/transform/hlsl.h
index df903a7..ff7978f 100644
--- a/src/transform/hlsl.h
+++ b/src/transform/hlsl.h
@@ -40,10 +40,10 @@
   Output Run(const Program* program, const DataMap& data = {}) override;
 
  private:
-  /// Hoists the array initializer to a constant variable, declared just before
-  /// the array usage statement.
-  /// See crbug.com/tint/406 for more details
-  void PromoteArrayInitializerToConstVar(CloneContext& ctx) const;
+  /// Hoists the array and structure initializers to a constant variable,
+  /// declared just before the statement of usage. See crbug.com/tint/406 for
+  /// more details
+  void PromoteInitializersToConstVar(CloneContext& ctx) const;
   /// Add an empty shader entry point if none exist in the module.
   void AddEmptyEntryPoint(CloneContext& ctx) const;
 };
diff --git a/src/transform/hlsl_test.cc b/src/transform/hlsl_test.cc
index 4ce2259..27779bd 100644
--- a/src/transform/hlsl_test.cc
+++ b/src/transform/hlsl_test.cc
@@ -51,6 +51,39 @@
   EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Basic) {
+  auto* src = R"(
+struct S {
+  a : i32;
+  b : f32;
+  c : vec3<f32>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+  var x : f32 = S(1, 2.0, vec3<f32>()).b;
+}
+)";
+
+  auto* expect = R"(
+struct S {
+  a : i32;
+  b : f32;
+  c : vec3<f32>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+  const tint_symbol_1 : S = S(1, 2.0, vec3<f32>());
+  var x : f32 = tint_symbol_1.b;
+}
+)";
+
+  auto got = Run<Hlsl>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(HlslTest, PromoteArrayInitializerToConstVar_ArrayInArray) {
   auto* src = R"(
 [[stage(vertex)]]
@@ -74,14 +107,115 @@
   EXPECT_EQ(expect, str(got));
 }
 
-TEST_F(HlslTest, PromoteArrayInitializerToConstVar_NoChangeOnArrayVarDecl) {
+TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Nested) {
   auto* src = R"(
+struct S1 {
+  a : i32;
+};
+
+struct S2 {
+  a : i32;
+  b : S1;
+  c : i32;
+};
+
+struct S3 {
+  a : S2;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+  var x : i32 = S3(S2(1, S1(2), 3)).a.b.a;
+}
+)";
+
+  auto* expect = R"(
+struct S1 {
+  a : i32;
+};
+
+struct S2 {
+  a : i32;
+  b : S1;
+  c : i32;
+};
+
+struct S3 {
+  a : S2;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+  const tint_symbol_1 : S1 = S1(2);
+  const tint_symbol_4 : S2 = S2(1, tint_symbol_1, 3);
+  const tint_symbol_8 : S3 = S3(tint_symbol_4);
+  var x : i32 = tint_symbol_8.a.b.a;
+}
+)";
+
+  auto got = Run<Hlsl>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(HlslTest, PromoteInitializerToConstVar_Mixed) {
+  auto* src = R"(
+struct S1 {
+  a : i32;
+};
+
+struct S2 {
+  a : array<S1, 3>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+  var x : i32 = S2(array<S1, 3>(S1(1), S1(2), S1(3))).a[1].a;
+}
+)";
+
+  auto* expect = R"(
+struct S1 {
+  a : i32;
+};
+
+struct S2 {
+  a : array<S1, 3>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+  const tint_symbol_1 : S1 = S1(1);
+  const tint_symbol_4 : S1 = S1(2);
+  const tint_symbol_5 : S1 = S1(3);
+  const tint_symbol_6 : array<S1, 3> = array<S1, 3>(tint_symbol_1, tint_symbol_4, tint_symbol_5);
+  const tint_symbol_7 : S2 = S2(tint_symbol_6);
+  var x : i32 = tint_symbol_7.a[1].a;
+}
+)";
+
+  auto got = Run<Hlsl>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(HlslTest, PromoteInitializerToConstVar_NoChangeOnVarDecl) {
+  auto* src = R"(
+struct S {
+  a : i32;
+  b : f32;
+  c : i32;
+};
+
 [[stage(vertex)]]
 fn main() -> void {
   var local_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
+  var local_str : S = S(1, 2.0, 3);
 }
 
 const module_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
+
+const module_str : S = S(1, 2.0, 3);
 )";
 
   auto* expect = src;
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 90ffaf3..4f82f6c 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1229,7 +1229,16 @@
 bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre,
                                         std::ostream& out,
                                         ast::TypeConstructorExpression* expr) {
-  if (expr->type()->Is<type::Array>()) {
+  // If the type constructor is empty then we need to construct with the zero
+  // value for all components.
+  if (expr->values().empty()) {
+    return EmitZeroValue(out, expr->type());
+  }
+
+  bool brackets =
+      expr->type()->UnwrapAliasIfNeeded()->IsAnyOf<type::Array, type::Struct>();
+
+  if (brackets) {
     out << "{";
   } else {
     if (!EmitType(out, expr->type(), "")) {
@@ -1238,31 +1247,19 @@
     out << "(";
   }
 
-  // If the type constructor is empty then we need to construct with the zero
-  // value for all components.
-  if (expr->values().empty()) {
-    if (!EmitZeroValue(out, expr->type())) {
+  bool first = true;
+  for (auto* e : expr->values()) {
+    if (!first) {
+      out << ", ";
+    }
+    first = false;
+
+    if (!EmitExpression(pre, out, e)) {
       return false;
     }
-  } else {
-    bool first = true;
-    for (auto* e : expr->values()) {
-      if (!first) {
-        out << ", ";
-      }
-      first = false;
-
-      if (!EmitExpression(pre, out, e)) {
-        return false;
-      }
-    }
   }
 
-  if (expr->type()->Is<type::Array>()) {
-    out << "}";
-  } else {
-    out << ")";
-  }
+  out << (brackets ? "}" : ")");
   return true;
 }
 
@@ -1994,6 +1991,10 @@
   } else if (type->Is<type::U32>()) {
     out << "0u";
   } else if (auto* vec = type->As<type::Vector>()) {
+    if (!EmitType(out, type, "")) {
+      return false;
+    }
+    ScopedParen sp(out);
     for (uint32_t i = 0; i < vec->size(); i++) {
       if (i != 0) {
         out << ", ";
@@ -2003,6 +2004,10 @@
       }
     }
   } else if (auto* mat = type->As<type::Matrix>()) {
+    if (!EmitType(out, type, "")) {
+      return false;
+    }
+    ScopedParen sp(out);
     for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
       if (i != 0) {
         out << ", ";
@@ -2011,6 +2016,19 @@
         return false;
       }
     }
+  } else if (auto* str = type->As<type::Struct>()) {
+    out << "{";
+    bool first = true;
+    for (auto* member : str->impl()->members()) {
+      if (!first) {
+        out << ", ";
+      }
+      first = false;
+      if (!EmitZeroValue(out, member->type())) {
+        return false;
+      }
+    }
+    out << "}";
   } else {
     diagnostics_.add_error("Invalid type for zero emission: " +
                            type->type_name());
diff --git a/src/writer/hlsl/generator_impl_constructor_test.cc b/src/writer/hlsl/generator_impl_constructor_test.cc
index 2acbedb..2bee83b 100644
--- a/src/writer/hlsl/generator_impl_constructor_test.cc
+++ b/src/writer/hlsl/generator_impl_constructor_test.cc
@@ -194,6 +194,40 @@
   Validate();
 }
 
+TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct) {
+  auto* str = Structure("S", {
+                                 Member("a", ty.i32()),
+                                 Member("b", ty.f32()),
+                                 Member("c", ty.vec3<i32>()),
+                             });
+
+  WrapInFunction(Construct(str, 1, 2.0f, vec3<i32>(3, 4, 5)));
+
+  GeneratorImpl& gen = SanitizeAndBuild();
+
+  ASSERT_TRUE(gen.Generate(out)) << gen.error();
+  EXPECT_THAT(result(), HasSubstr("{1, 2.0f, int3(3, 4, 5)}"));
+
+  Validate();
+}
+
+TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct_Empty) {
+  auto* str = Structure("S", {
+                                 Member("a", ty.i32()),
+                                 Member("b", ty.f32()),
+                                 Member("c", ty.vec3<i32>()),
+                             });
+
+  WrapInFunction(Construct(str));
+
+  GeneratorImpl& gen = SanitizeAndBuild();
+
+  ASSERT_TRUE(gen.Generate(out)) << gen.error();
+  EXPECT_THAT(result(), HasSubstr("{0, 0.0f, int3(0, 0, 0)}"));
+
+  Validate();
+}
+
 }  // namespace
 }  // namespace hlsl
 }  // namespace writer
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index f9b1d6a..7d70277 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -124,16 +124,17 @@
   GeneratorImpl& gen = SanitizeAndBuild();
 
   ASSERT_TRUE(gen.Generate(out)) << gen.error();
-  EXPECT_EQ(result(), R"(struct tint_symbol_1 {
+  EXPECT_EQ(result(), R"(struct tint_symbol_5 {
   float foo : TEXCOORD0;
 };
-struct tint_symbol_3 {
+struct tint_symbol_2 {
   float value : SV_Target1;
 };
 
-tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
-  const float foo = tint_symbol_6.foo;
-  return tint_symbol_3(foo);
+tint_symbol_2 frag_main(tint_symbol_5 tint_symbol_7) {
+  const float foo = tint_symbol_7.foo;
+  const tint_symbol_2 tint_symbol_1 = {foo};
+  return tint_symbol_1;
 }
 
 )");
@@ -157,16 +158,17 @@
   GeneratorImpl& gen = SanitizeAndBuild();
 
   ASSERT_TRUE(gen.Generate(out)) << gen.error();
-  EXPECT_EQ(result(), R"(struct tint_symbol_1 {
+  EXPECT_EQ(result(), R"(struct tint_symbol_6 {
   float4 coord : SV_Position;
 };
-struct tint_symbol_3 {
+struct tint_symbol_2 {
   float value : SV_Depth;
 };
 
-tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
-  const float4 coord = tint_symbol_6.coord;
-  return tint_symbol_3(coord.x);
+tint_symbol_2 frag_main(tint_symbol_6 tint_symbol_8) {
+  const float4 coord = tint_symbol_8.coord;
+  const tint_symbol_2 tint_symbol_1 = {coord.x};
+  return tint_symbol_1;
 }
 
 )");
@@ -213,22 +215,23 @@
   float col1;
   float col2;
 };
-struct tint_symbol_4 {
+struct tint_symbol_2 {
   float col1 : TEXCOORD1;
   float col2 : TEXCOORD2;
 };
-struct tint_symbol_7 {
+struct tint_symbol_8 {
   float col1 : TEXCOORD1;
   float col2 : TEXCOORD2;
 };
 
-tint_symbol_4 vert_main() {
-  const Interface tint_symbol_6 = Interface(0.5f, 0.25f);
-  return tint_symbol_4(tint_symbol_6.col1, tint_symbol_6.col2);
+tint_symbol_2 vert_main() {
+  const Interface tint_symbol_5 = {0.5f, 0.25f};
+  const tint_symbol_2 tint_symbol_1 = {tint_symbol_5.col1, tint_symbol_5.col2};
+  return tint_symbol_1;
 }
 
-void frag_main(tint_symbol_7 tint_symbol_9) {
-  const Interface colors = Interface(tint_symbol_9.col1, tint_symbol_9.col2);
+void frag_main(tint_symbol_8 tint_symbol_10) {
+  const Interface colors = {tint_symbol_10.col1, tint_symbol_10.col2};
   const float r = colors.col1;
   const float g = colors.col2;
   return;
@@ -236,8 +239,7 @@
 
 )");
 
-  // TODO(crbug.com/tint/702): This is not legal HLSL
-  // Validate();
+  Validate();
 }
 
 TEST_F(HlslGeneratorImplTest_Function,
@@ -281,25 +283,28 @@
   EXPECT_EQ(result(), R"(struct VertexOutput {
   float4 pos;
 };
-struct tint_symbol_5 {
+struct tint_symbol_2 {
   float4 pos : SV_Position;
 };
-struct tint_symbol_8 {
+struct tint_symbol_6 {
   float4 pos : SV_Position;
 };
 
 VertexOutput foo(float x) {
-  return VertexOutput(float4(x, x, x, 1.0f));
+  const VertexOutput tint_symbol_8 = {float4(x, x, x, 1.0f)};
+  return tint_symbol_8;
 }
 
-tint_symbol_5 vert_main1() {
-  const VertexOutput tint_symbol_7 = VertexOutput(foo(0.5f));
-  return tint_symbol_5(tint_symbol_7.pos);
+tint_symbol_2 vert_main1() {
+  const VertexOutput tint_symbol_4 = {foo(0.5f)};
+  const tint_symbol_2 tint_symbol_1 = {tint_symbol_4.pos};
+  return tint_symbol_1;
 }
 
-tint_symbol_8 vert_main2() {
-  const VertexOutput tint_symbol_10 = VertexOutput(foo(0.25f));
-  return tint_symbol_8(tint_symbol_10.pos);
+tint_symbol_6 vert_main2() {
+  const VertexOutput tint_symbol_7 = {foo(0.25f)};
+  const tint_symbol_6 tint_symbol_5 = {tint_symbol_7.pos};
+  return tint_symbol_5;
 }
 
 )");
diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc
index 53b7cd1..8766b91 100644
--- a/src/writer/hlsl/generator_impl_sanitizer_test.cc
+++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc
@@ -51,6 +51,45 @@
   EXPECT_EQ(expect, got);
 }
 
+TEST_F(HlslSanitizerTest, PromoteStructInitializerToConstVar) {
+  auto* str = Structure("S", {
+                                 Member("a", ty.i32()),
+                                 Member("b", ty.vec3<f32>()),
+                                 Member("c", ty.i32()),
+                             });
+  auto* struct_init = Construct(str, 1, vec3<f32>(2.f, 3.f, 4.f), 4);
+  auto* struct_access = MemberAccessor(struct_init, "b");
+  auto* pos =
+      Var("pos", ty.vec3<f32>(), ast::StorageClass::kFunction, struct_access);
+
+  Func("main", ast::VariableList{}, ty.void_(),
+       ast::StatementList{
+           create<ast::VariableDeclStatement>(pos),
+       },
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+       });
+
+  GeneratorImpl& gen = SanitizeAndBuild();
+
+  ASSERT_TRUE(gen.Generate(out)) << gen.error();
+
+  auto got = result();
+  auto* expect = R"(struct S {
+  int a;
+  float3 b;
+  int c;
+};
+
+void main() {
+  const S tint_symbol_1 = {1, float3(2.0f, 3.0f, 4.0f), 4};
+  float3 pos = tint_symbol_1.b;
+  return;
+}
+
+)";
+  EXPECT_EQ(expect, got);
+}
 }  // namespace
 }  // namespace hlsl
 }  // namespace writer