transform/hlsl: Hoist structure constructors to new var
HLSL has some pecular rules around structure constructors.
`S s = S(1,2,3)` is not valid, but `S s = {1,2,3}` is.
This matches the quirkiness with array initializers, so adjust the array
hoisting logic to also support structures.
Fixed: tint:702
Change-Id: Ifdcafd98292715ae2482f72ec06c87842176d270
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46875
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 918d6e3..9db68e0 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -32,24 +32,24 @@
Transform::Output Hlsl::Run(const Program* in, const DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
- PromoteArrayInitializerToConstVar(ctx);
+ PromoteInitializersToConstVar(ctx);
AddEmptyEntryPoint(ctx);
ctx.Clone();
return Output{Program(std::move(out))};
}
-void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const {
- // Scan the AST nodes for array initializers which need to be promoted to
- // their own constant declaration.
+void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const {
+ // Scan the AST nodes for array and structure initializers which
+ // need to be promoted to their own constant declaration.
- // Note: Correct handling of arrays-of-arrays is guaranteed due to the
+ // Note: Correct handling of nested expressions is guaranteed due to the
// depth-first traversal of the ast::Node::Clone() methods:
//
- // The inner-most array initializers are traversed first, and they are hoisted
+ // The inner-most initializers are traversed first, and they are hoisted
// to const variables declared just above the statement of use. The outer
- // array initializer will then be hoisted, inserting themselves between the
- // inner array declaration and the statement of use. This pattern applies
- // correctly to any nested depth.
+ // initializer will then be hoisted, inserting themselves between the
+ // inner declaration and the statement of use. This pattern applies correctly
+ // to any nested depth.
//
// Depth-first traversal of the AST is guaranteed because AST nodes are fully
// immutable and require their children to be constructed first so their
@@ -75,22 +75,23 @@
if (auto* src_var_decl = src_stmt->As<ast::VariableDeclStatement>()) {
if (src_var_decl->variable()->constructor() == src_init) {
- // This statement is just a variable declaration with the array
- // initializer as the constructor value. This is what we're
- // attempting to transform to, and so ignore.
+ // This statement is just a variable declaration with the initializer
+ // as the constructor value. This is what we're attempting to
+ // transform to, and so ignore.
continue;
}
}
- if (auto* src_array_ty = src_sem_expr->Type()->As<type::Array>()) {
+ auto* src_ty = src_sem_expr->Type();
+ if (src_ty->IsAnyOf<type::Array, type::Struct>()) {
// Create a new symbol for the constant
auto dst_symbol = ctx.dst->Symbols().New();
- // Clone the array type
- auto* dst_array_ty = ctx.Clone(src_array_ty);
- // Clone the array initializer
+ // Clone the type
+ auto* dst_ty = ctx.Clone(src_ty);
+ // Clone the initializer
auto* dst_init = ctx.Clone(src_init);
- // Construct the constant that holds the array
- auto* dst_var = ctx.dst->Const(dst_symbol, dst_array_ty, dst_init);
+ // Construct the constant that holds the hoisted initializer
+ auto* dst_var = ctx.dst->Const(dst_symbol, dst_ty, dst_init);
// Construct the variable declaration statement
auto* dst_var_decl =
ctx.dst->create<ast::VariableDeclStatement>(dst_var);
@@ -100,7 +101,7 @@
// Insert the constant before the usage
ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
dst_var_decl);
- // Replace the inlined array with a reference to the constant
+ // Replace the inlined initializer with a reference to the constant
ctx.Replace(src_init, dst_ident);
}
}
diff --git a/src/transform/hlsl.h b/src/transform/hlsl.h
index df903a7..ff7978f 100644
--- a/src/transform/hlsl.h
+++ b/src/transform/hlsl.h
@@ -40,10 +40,10 @@
Output Run(const Program* program, const DataMap& data = {}) override;
private:
- /// Hoists the array initializer to a constant variable, declared just before
- /// the array usage statement.
- /// See crbug.com/tint/406 for more details
- void PromoteArrayInitializerToConstVar(CloneContext& ctx) const;
+ /// Hoists the array and structure initializers to a constant variable,
+ /// declared just before the statement of usage. See crbug.com/tint/406 for
+ /// more details
+ void PromoteInitializersToConstVar(CloneContext& ctx) const;
/// Add an empty shader entry point if none exist in the module.
void AddEmptyEntryPoint(CloneContext& ctx) const;
};
diff --git a/src/transform/hlsl_test.cc b/src/transform/hlsl_test.cc
index 4ce2259..27779bd 100644
--- a/src/transform/hlsl_test.cc
+++ b/src/transform/hlsl_test.cc
@@ -51,6 +51,39 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Basic) {
+ auto* src = R"(
+struct S {
+ a : i32;
+ b : f32;
+ c : vec3<f32>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+ var x : f32 = S(1, 2.0, vec3<f32>()).b;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : i32;
+ b : f32;
+ c : vec3<f32>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+ const tint_symbol_1 : S = S(1, 2.0, vec3<f32>());
+ var x : f32 = tint_symbol_1.b;
+}
+)";
+
+ auto got = Run<Hlsl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(HlslTest, PromoteArrayInitializerToConstVar_ArrayInArray) {
auto* src = R"(
[[stage(vertex)]]
@@ -74,14 +107,115 @@
EXPECT_EQ(expect, str(got));
}
-TEST_F(HlslTest, PromoteArrayInitializerToConstVar_NoChangeOnArrayVarDecl) {
+TEST_F(HlslTest, PromoteStructureInitializerToConstVar_Nested) {
auto* src = R"(
+struct S1 {
+ a : i32;
+};
+
+struct S2 {
+ a : i32;
+ b : S1;
+ c : i32;
+};
+
+struct S3 {
+ a : S2;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+ var x : i32 = S3(S2(1, S1(2), 3)).a.b.a;
+}
+)";
+
+ auto* expect = R"(
+struct S1 {
+ a : i32;
+};
+
+struct S2 {
+ a : i32;
+ b : S1;
+ c : i32;
+};
+
+struct S3 {
+ a : S2;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+ const tint_symbol_1 : S1 = S1(2);
+ const tint_symbol_4 : S2 = S2(1, tint_symbol_1, 3);
+ const tint_symbol_8 : S3 = S3(tint_symbol_4);
+ var x : i32 = tint_symbol_8.a.b.a;
+}
+)";
+
+ auto got = Run<Hlsl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(HlslTest, PromoteInitializerToConstVar_Mixed) {
+ auto* src = R"(
+struct S1 {
+ a : i32;
+};
+
+struct S2 {
+ a : array<S1, 3>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+ var x : i32 = S2(array<S1, 3>(S1(1), S1(2), S1(3))).a[1].a;
+}
+)";
+
+ auto* expect = R"(
+struct S1 {
+ a : i32;
+};
+
+struct S2 {
+ a : array<S1, 3>;
+};
+
+[[stage(vertex)]]
+fn main() -> void {
+ const tint_symbol_1 : S1 = S1(1);
+ const tint_symbol_4 : S1 = S1(2);
+ const tint_symbol_5 : S1 = S1(3);
+ const tint_symbol_6 : array<S1, 3> = array<S1, 3>(tint_symbol_1, tint_symbol_4, tint_symbol_5);
+ const tint_symbol_7 : S2 = S2(tint_symbol_6);
+ var x : i32 = tint_symbol_7.a[1].a;
+}
+)";
+
+ auto got = Run<Hlsl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(HlslTest, PromoteInitializerToConstVar_NoChangeOnVarDecl) {
+ auto* src = R"(
+struct S {
+ a : i32;
+ b : f32;
+ c : i32;
+};
+
[[stage(vertex)]]
fn main() -> void {
var local_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
+ var local_str : S = S(1, 2.0, 3);
}
const module_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
+
+const module_str : S = S(1, 2.0, 3);
)";
auto* expect = src;
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 90ffaf3..4f82f6c 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1229,7 +1229,16 @@
bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre,
std::ostream& out,
ast::TypeConstructorExpression* expr) {
- if (expr->type()->Is<type::Array>()) {
+ // If the type constructor is empty then we need to construct with the zero
+ // value for all components.
+ if (expr->values().empty()) {
+ return EmitZeroValue(out, expr->type());
+ }
+
+ bool brackets =
+ expr->type()->UnwrapAliasIfNeeded()->IsAnyOf<type::Array, type::Struct>();
+
+ if (brackets) {
out << "{";
} else {
if (!EmitType(out, expr->type(), "")) {
@@ -1238,31 +1247,19 @@
out << "(";
}
- // If the type constructor is empty then we need to construct with the zero
- // value for all components.
- if (expr->values().empty()) {
- if (!EmitZeroValue(out, expr->type())) {
+ bool first = true;
+ for (auto* e : expr->values()) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+
+ if (!EmitExpression(pre, out, e)) {
return false;
}
- } else {
- bool first = true;
- for (auto* e : expr->values()) {
- if (!first) {
- out << ", ";
- }
- first = false;
-
- if (!EmitExpression(pre, out, e)) {
- return false;
- }
- }
}
- if (expr->type()->Is<type::Array>()) {
- out << "}";
- } else {
- out << ")";
- }
+ out << (brackets ? "}" : ")");
return true;
}
@@ -1994,6 +1991,10 @@
} else if (type->Is<type::U32>()) {
out << "0u";
} else if (auto* vec = type->As<type::Vector>()) {
+ if (!EmitType(out, type, "")) {
+ return false;
+ }
+ ScopedParen sp(out);
for (uint32_t i = 0; i < vec->size(); i++) {
if (i != 0) {
out << ", ";
@@ -2003,6 +2004,10 @@
}
}
} else if (auto* mat = type->As<type::Matrix>()) {
+ if (!EmitType(out, type, "")) {
+ return false;
+ }
+ ScopedParen sp(out);
for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
if (i != 0) {
out << ", ";
@@ -2011,6 +2016,19 @@
return false;
}
}
+ } else if (auto* str = type->As<type::Struct>()) {
+ out << "{";
+ bool first = true;
+ for (auto* member : str->impl()->members()) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+ if (!EmitZeroValue(out, member->type())) {
+ return false;
+ }
+ }
+ out << "}";
} else {
diagnostics_.add_error("Invalid type for zero emission: " +
type->type_name());
diff --git a/src/writer/hlsl/generator_impl_constructor_test.cc b/src/writer/hlsl/generator_impl_constructor_test.cc
index 2acbedb..2bee83b 100644
--- a/src/writer/hlsl/generator_impl_constructor_test.cc
+++ b/src/writer/hlsl/generator_impl_constructor_test.cc
@@ -194,6 +194,40 @@
Validate();
}
+TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct) {
+ auto* str = Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.f32()),
+ Member("c", ty.vec3<i32>()),
+ });
+
+ WrapInFunction(Construct(str, 1, 2.0f, vec3<i32>(3, 4, 5)));
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate(out)) << gen.error();
+ EXPECT_THAT(result(), HasSubstr("{1, 2.0f, int3(3, 4, 5)}"));
+
+ Validate();
+}
+
+TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Struct_Empty) {
+ auto* str = Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.f32()),
+ Member("c", ty.vec3<i32>()),
+ });
+
+ WrapInFunction(Construct(str));
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate(out)) << gen.error();
+ EXPECT_THAT(result(), HasSubstr("{0, 0.0f, int3(0, 0, 0)}"));
+
+ Validate();
+}
+
} // namespace
} // namespace hlsl
} // namespace writer
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index f9b1d6a..7d70277 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -124,16 +124,17 @@
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_1 {
+ EXPECT_EQ(result(), R"(struct tint_symbol_5 {
float foo : TEXCOORD0;
};
-struct tint_symbol_3 {
+struct tint_symbol_2 {
float value : SV_Target1;
};
-tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
- const float foo = tint_symbol_6.foo;
- return tint_symbol_3(foo);
+tint_symbol_2 frag_main(tint_symbol_5 tint_symbol_7) {
+ const float foo = tint_symbol_7.foo;
+ const tint_symbol_2 tint_symbol_1 = {foo};
+ return tint_symbol_1;
}
)");
@@ -157,16 +158,17 @@
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_1 {
+ EXPECT_EQ(result(), R"(struct tint_symbol_6 {
float4 coord : SV_Position;
};
-struct tint_symbol_3 {
+struct tint_symbol_2 {
float value : SV_Depth;
};
-tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
- const float4 coord = tint_symbol_6.coord;
- return tint_symbol_3(coord.x);
+tint_symbol_2 frag_main(tint_symbol_6 tint_symbol_8) {
+ const float4 coord = tint_symbol_8.coord;
+ const tint_symbol_2 tint_symbol_1 = {coord.x};
+ return tint_symbol_1;
}
)");
@@ -213,22 +215,23 @@
float col1;
float col2;
};
-struct tint_symbol_4 {
+struct tint_symbol_2 {
float col1 : TEXCOORD1;
float col2 : TEXCOORD2;
};
-struct tint_symbol_7 {
+struct tint_symbol_8 {
float col1 : TEXCOORD1;
float col2 : TEXCOORD2;
};
-tint_symbol_4 vert_main() {
- const Interface tint_symbol_6 = Interface(0.5f, 0.25f);
- return tint_symbol_4(tint_symbol_6.col1, tint_symbol_6.col2);
+tint_symbol_2 vert_main() {
+ const Interface tint_symbol_5 = {0.5f, 0.25f};
+ const tint_symbol_2 tint_symbol_1 = {tint_symbol_5.col1, tint_symbol_5.col2};
+ return tint_symbol_1;
}
-void frag_main(tint_symbol_7 tint_symbol_9) {
- const Interface colors = Interface(tint_symbol_9.col1, tint_symbol_9.col2);
+void frag_main(tint_symbol_8 tint_symbol_10) {
+ const Interface colors = {tint_symbol_10.col1, tint_symbol_10.col2};
const float r = colors.col1;
const float g = colors.col2;
return;
@@ -236,8 +239,7 @@
)");
- // TODO(crbug.com/tint/702): This is not legal HLSL
- // Validate();
+ Validate();
}
TEST_F(HlslGeneratorImplTest_Function,
@@ -281,25 +283,28 @@
EXPECT_EQ(result(), R"(struct VertexOutput {
float4 pos;
};
-struct tint_symbol_5 {
+struct tint_symbol_2 {
float4 pos : SV_Position;
};
-struct tint_symbol_8 {
+struct tint_symbol_6 {
float4 pos : SV_Position;
};
VertexOutput foo(float x) {
- return VertexOutput(float4(x, x, x, 1.0f));
+ const VertexOutput tint_symbol_8 = {float4(x, x, x, 1.0f)};
+ return tint_symbol_8;
}
-tint_symbol_5 vert_main1() {
- const VertexOutput tint_symbol_7 = VertexOutput(foo(0.5f));
- return tint_symbol_5(tint_symbol_7.pos);
+tint_symbol_2 vert_main1() {
+ const VertexOutput tint_symbol_4 = {foo(0.5f)};
+ const tint_symbol_2 tint_symbol_1 = {tint_symbol_4.pos};
+ return tint_symbol_1;
}
-tint_symbol_8 vert_main2() {
- const VertexOutput tint_symbol_10 = VertexOutput(foo(0.25f));
- return tint_symbol_8(tint_symbol_10.pos);
+tint_symbol_6 vert_main2() {
+ const VertexOutput tint_symbol_7 = {foo(0.25f)};
+ const tint_symbol_6 tint_symbol_5 = {tint_symbol_7.pos};
+ return tint_symbol_5;
}
)");
diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc
index 53b7cd1..8766b91 100644
--- a/src/writer/hlsl/generator_impl_sanitizer_test.cc
+++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc
@@ -51,6 +51,45 @@
EXPECT_EQ(expect, got);
}
+TEST_F(HlslSanitizerTest, PromoteStructInitializerToConstVar) {
+ auto* str = Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.vec3<f32>()),
+ Member("c", ty.i32()),
+ });
+ auto* struct_init = Construct(str, 1, vec3<f32>(2.f, 3.f, 4.f), 4);
+ auto* struct_access = MemberAccessor(struct_init, "b");
+ auto* pos =
+ Var("pos", ty.vec3<f32>(), ast::StorageClass::kFunction, struct_access);
+
+ Func("main", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ create<ast::VariableDeclStatement>(pos),
+ },
+ ast::DecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+ });
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate(out)) << gen.error();
+
+ auto got = result();
+ auto* expect = R"(struct S {
+ int a;
+ float3 b;
+ int c;
+};
+
+void main() {
+ const S tint_symbol_1 = {1, float3(2.0f, 3.0f, 4.0f), 4};
+ float3 pos = tint_symbol_1.b;
+ return;
+}
+
+)";
+ EXPECT_EQ(expect, got);
+}
} // namespace
} // namespace hlsl
} // namespace writer