[hlsl] Split module var structs and arrays.
For struct and array initializers at module scope they need to be
completely split out, no nesting. The split values are `static const`
values.
Bug: 42251045
Change-Id: I44ab10783838b8e1922f210926ff0d87e9894d65
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196034
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/cmd/fuzz/ir/BUILD.cmake b/src/tint/cmd/fuzz/ir/BUILD.cmake
index 489d2a6..3e99635 100644
--- a/src/tint/cmd/fuzz/ir/BUILD.cmake
+++ b/src/tint/cmd/fuzz/ir/BUILD.cmake
@@ -55,6 +55,7 @@
tint_lang_core_ir
tint_lang_core_ir_transform_fuzz
tint_lang_core_type
+ tint_lang_hlsl_writer_raise_fuzz
tint_lang_wgsl_program_fuzz
tint_lang_wgsl_writer_raise_fuzz
tint_lang_wgsl_fuzz
diff --git a/src/tint/cmd/fuzz/ir/BUILD.gn b/src/tint/cmd/fuzz/ir/BUILD.gn
index beedcbc..3f05aed 100644
--- a/src/tint/cmd/fuzz/ir/BUILD.gn
+++ b/src/tint/cmd/fuzz/ir/BUILD.gn
@@ -94,6 +94,7 @@
"${tint_src_dir}/lang/core/ir",
"${tint_src_dir}/lang/core/ir/transform:fuzz",
"${tint_src_dir}/lang/core/type",
+ "${tint_src_dir}/lang/hlsl/writer/raise:fuzz",
"${tint_src_dir}/lang/wgsl:fuzz",
"${tint_src_dir}/lang/wgsl/program:fuzz",
"${tint_src_dir}/lang/wgsl/writer/raise:fuzz",
diff --git a/src/tint/cmd/fuzz/wgsl/BUILD.cmake b/src/tint/cmd/fuzz/wgsl/BUILD.cmake
index aeb77bb..5138fb3 100644
--- a/src/tint/cmd/fuzz/wgsl/BUILD.cmake
+++ b/src/tint/cmd/fuzz/wgsl/BUILD.cmake
@@ -50,6 +50,7 @@
tint_lang_core_constant
tint_lang_core_ir_transform_fuzz
tint_lang_core_type
+ tint_lang_hlsl_writer_raise_fuzz
tint_lang_wgsl
tint_lang_wgsl_ast
tint_lang_wgsl_program
diff --git a/src/tint/cmd/fuzz/wgsl/BUILD.gn b/src/tint/cmd/fuzz/wgsl/BUILD.gn
index a060109..cb122cb 100644
--- a/src/tint/cmd/fuzz/wgsl/BUILD.gn
+++ b/src/tint/cmd/fuzz/wgsl/BUILD.gn
@@ -92,6 +92,7 @@
"${tint_src_dir}/lang/core/constant",
"${tint_src_dir}/lang/core/ir/transform:fuzz",
"${tint_src_dir}/lang/core/type",
+ "${tint_src_dir}/lang/hlsl/writer/raise:fuzz",
"${tint_src_dir}/lang/wgsl",
"${tint_src_dir}/lang/wgsl:fuzz",
"${tint_src_dir}/lang/wgsl/ast",
diff --git a/src/tint/lang/core/ir/call.h b/src/tint/lang/core/ir/call.h
index b56a092..0b88fc2 100644
--- a/src/tint/lang/core/ir/call.h
+++ b/src/tint/lang/core/ir/call.h
@@ -49,6 +49,10 @@
return operands_.Slice().Offset(ArgsOperandOffset());
}
+ /// Sets the argument at `idx` of `arg`. `idx` must be within bounds of the current argument
+ /// set.
+ void SetArg(size_t idx, ir::Value* arg) { SetOperand(ArgsOperandOffset() + idx, arg); }
+
/// Append a new argument to the argument list for this call instruction.
/// @param arg the argument value to append
void AppendArg(ir::Value* arg) { AddOperand(operands_.Length(), arg); }
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index ba645a7..646c246 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -754,12 +754,27 @@
AddError(inst) << "instruction in root block does not have root block as parent";
continue;
}
- auto* var = inst->As<ir::Var>();
- if (!var) {
- AddError(inst) << "root block: invalid instruction: " << inst->TypeInfo().name;
- continue;
- }
- CheckInstruction(var);
+
+ tint::Switch(
+ inst, //
+ [&](const core::ir::Var* var) { CheckInstruction(var); },
+ [&](const core::ir::Let* let) {
+ if (capabilities_.Contains(Capability::kAllowModuleScopeLets)) {
+ CheckInstruction(let);
+ } else {
+ AddError(inst) << "root block: invalid instruction: " << inst->TypeInfo().name;
+ }
+ },
+ [&](const core::ir::Construct* c) {
+ if (capabilities_.Contains(Capability::kAllowModuleScopeLets)) {
+ CheckInstruction(c);
+ } else {
+ AddError(inst) << "root block: invalid instruction: " << inst->TypeInfo().name;
+ }
+ },
+ [&](Default) {
+ AddError(inst) << "root block: invalid instruction: " << inst->TypeInfo().name;
+ });
}
}
diff --git a/src/tint/lang/core/ir/validator.h b/src/tint/lang/core/ir/validator.h
index 6e4d475..2fe187c 100644
--- a/src/tint/lang/core/ir/validator.h
+++ b/src/tint/lang/core/ir/validator.h
@@ -46,6 +46,8 @@
kAllowVectorElementPointer,
/// Allows ref types
kAllowRefTypes,
+ /// Allows module scoped lets
+ kAllowModuleScopeLets,
};
/// Capabilities is a set of Capability
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 6be75fd..3389346 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -88,6 +88,64 @@
)");
}
+TEST_F(IR_ValidatorTest, RootBlock_Let) {
+ mod.root_block->Append(b.Let("a", 1_f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:2:12 error: let: root block: invalid instruction: tint::core::ir::Let
+ %a:f32 = let 1.0f
+ ^^^
+
+:1:1 note: in block
+$B1: { # root
+^^^
+
+note: # Disassembly
+$B1: { # root
+ %a:f32 = let 1.0f
+}
+
+)");
+}
+
+TEST_F(IR_ValidatorTest, RootBlock_LetWithAllowModuleScopeLets) {
+ mod.root_block->Append(b.Let("a", 1_f));
+
+ auto res = ir::Validate(mod, Capabilities{Capability::kAllowModuleScopeLets});
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, RootBlock_Construct) {
+ mod.root_block->Append(b.Construct(ty.vec2<f32>(), 1_f, 2_f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:2:18 error: construct: root block: invalid instruction: tint::core::ir::Construct
+ %1:vec2<f32> = construct 1.0f, 2.0f
+ ^^^^^^^^^
+
+:1:1 note: in block
+$B1: { # root
+^^^
+
+note: # Disassembly
+$B1: { # root
+ %1:vec2<f32> = construct 1.0f, 2.0f
+}
+
+)");
+}
+
+TEST_F(IR_ValidatorTest, RootBlock_ConstructWithAllowModuleScopeLets) {
+ mod.root_block->Append(b.Construct(ty.vec2<f32>(), 1_f, 2_f));
+
+ auto res = ir::Validate(mod, Capabilities{Capability::kAllowModuleScopeLets});
+ ASSERT_EQ(res, Success);
+}
+
TEST_F(IR_ValidatorTest, RootBlock_VarBlockMismatch) {
auto* var = b.Var(ty.ptr<private_, i32>());
mod.root_block->Append(var);
diff --git a/src/tint/lang/hlsl/writer/constant_test.cc b/src/tint/lang/hlsl/writer/constant_test.cc
index da2ee19..0dcf0b1 100644
--- a/src/tint/lang/hlsl/writer/constant_test.cc
+++ b/src/tint/lang/hlsl/writer/constant_test.cc
@@ -539,6 +539,20 @@
)");
}
+TEST_F(HlslWriterTest, ConstantTypeArrayModuleScopeZero) {
+ b.ir.root_block->Append(b.Var<private_>("v", b.Zero<array<f32, 65536>>()));
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+static const float v_1[65536] = (float[65536])0;
+static float v[65536] = v_1;
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, ConstantTypeArrayEmpty) {
auto* f = b.Function("a", ty.void_(), core::ir::Function::PipelineStage::kCompute);
f->SetWorkgroupSize(1, 1, 1);
@@ -687,8 +701,7 @@
)");
}
-// TODO(dsinclair): Need support for `static const` variables
-TEST_F(HlslWriterTest, DISABLED_ConstantTypeLetStructCompositeModuleScoped) {
+TEST_F(HlslWriterTest, ConstantTypeLetStructCompositeModuleScoped) {
Vector members_a{
ty.Get<core::type::StructMember>(b.ir.symbols.New("e"), ty.vec4<f32>(), 0u, 0u, 16u, 16u,
core::type::StructMemberAttributes{}),
@@ -720,9 +733,10 @@
A c;
};
-static const A c_1 = {(1.f).xxxx};
-static const S c_2 = {c_1};
-static S z = c_2;
+
+static const A v = {(1.0f).xxxx};
+static const S v_1 = {v};
+static S z = v_1;
float a() {
S t = {{(1.0f).xxxx}};
return 1.0f;
@@ -868,8 +882,8 @@
};
-static
-S p = (S)0;
+static const S v = {0};
+static S p = v;
[numthreads(1, 1, 1)]
void unused_entry_point() {
}
@@ -877,8 +891,7 @@
)");
}
-// TODO(dsinclair): Need suppport for `static const` variables
-TEST_F(HlslWriterTest, DISABLED_ConstantTypeStructStatic) {
+TEST_F(HlslWriterTest, ConstantTypeStructStatic) {
Vector members{
ty.Get<core::type::StructMember>(b.ir.symbols.New("a"), ty.i32(), 0u, 0u, 4u, 4u,
core::type::StructMemberAttributes{}),
@@ -893,8 +906,73 @@
};
-static const
-S p = {3};
+static const S v = {3};
+static S p = v;
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, ConstantTypeMultiStructAndArray) {
+ auto* a_ty = ty.Struct(mod.symbols.New("A"), {
+ {mod.symbols.New("a"), ty.array<i32, 2>(),
+ core::type::StructMemberAttributes{}},
+ });
+ auto* b_ty =
+ ty.Struct(mod.symbols.New("B"), {
+ {mod.symbols.New("b"), ty.array<array<i32, 4>, 1>(),
+ core::type::StructMemberAttributes{}},
+ });
+ auto* c_ty = ty.Struct(mod.symbols.New("C"),
+ {
+ {mod.symbols.New("a"), a_ty, core::type::StructMemberAttributes{}},
+ });
+
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_>("a", b.Composite(a_ty, b.Composite(ty.array<i32, 2>(), 9_i, 10_i)));
+ b.Var<private_>(
+ "b",
+ b.Composite(b_ty, b.Composite(ty.array<array<i32, 4>, 1>(),
+ b.Composite(ty.array<i32, 4>(), 5_i, 6_i, 7_i, 8_i))));
+ b.Var<private_>(
+ "c", b.Composite(c_ty, b.Composite(a_ty, b.Composite(ty.array<i32, 2>(), 1_i, 2_i))));
+
+ b.Var<private_>("d", b.Composite(ty.array<i32, 2>(), 11_i, 12_i));
+ b.Var<private_>("e", b.Composite(ty.array<array<array<i32, 3>, 2>, 1>(),
+ b.Composite(ty.array<array<i32, 3>, 2>(),
+ b.Composite(ty.array<i32, 3>(), 1_i, 2_i, 3_i),
+ b.Composite(ty.array<i32, 3>(), 4_i, 5_i, 6_i)
+
+ )));
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct A {
+ int a[2];
+};
+
+struct B {
+ int b[1][4];
+};
+
+struct C {
+ A a_1;
+};
+
+
+static const A v = {{9, 10}};
+static A a = v;
+static const B v_1 = {{{5, 6, 7, 8}}};
+static B b = v_1;
+static const A v_2 = {{1, 2}};
+static const C v_3 = {v_2};
+static C c = v_3;
+static const int v_4[2] = {11, 12};
+static int d[2] = v_4;
+static const int v_5[1][2][3] = {{{1, 2, 3}, {4, 5, 6}}};
+static int e[1][2][3] = v_5;
[numthreads(1, 1, 1)]
void unused_entry_point() {
}
diff --git a/src/tint/lang/hlsl/writer/function_test.cc b/src/tint/lang/hlsl/writer/function_test.cc
index 20f07cc..594a37a 100644
--- a/src/tint/lang/hlsl/writer/function_test.cc
+++ b/src/tint/lang/hlsl/writer/function_test.cc
@@ -309,13 +309,16 @@
vert_main_outputs vert_main() {
Interface v_1 = vert_main_inner();
- vert_main_outputs v_2 = {v_1.col1, v_1.col2, v_1.pos};
- return v_2;
+ Interface v_2 = v_1;
+ Interface v_3 = v_1;
+ Interface v_4 = v_1;
+ vert_main_outputs v_5 = {v_3.col1, v_4.col2, v_2.pos};
+ return v_5;
}
void frag_main(frag_main_inputs inputs) {
- Interface v_3 = {float4(inputs.Interface_pos.xyz, (1.0f / inputs.Interface_pos[3u])), inputs.Interface_col1, inputs.Interface_col2};
- frag_main_inner(v_3);
+ Interface v_6 = {float4(inputs.Interface_pos.xyz, (1.0f / inputs.Interface_pos[3u])), inputs.Interface_col1, inputs.Interface_col2};
+ frag_main_inner(v_6);
}
)");
@@ -985,5 +988,30 @@
)");
}
+TEST_F(HlslWriterTest, DuplicateConstant) {
+ auto* ret_arr = b.Function("ret_arr", ty.array<vec4<i32>, 4>());
+ b.Append(ret_arr->Block(), [&] { b.Return(ret_arr, b.Zero<array<vec4<i32>, 4>>()); });
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("src_let", b.Zero<array<vec4<i32>, 4>>());
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+typedef int4 ary_ret[4];
+ary_ret ret_arr() {
+ int4 v[4] = (int4[4])0;
+ return v;
+}
+
+void foo() {
+ int4 src_let[4] = (int4[4])0;
+}
+
+)");
+}
+
} // namespace
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 24b2508e..8aa8284 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -153,7 +153,8 @@
/// @returns the generated HLSL shader
tint::Result<PrintResult> Generate() {
- auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "HLSL writer");
+ core::ir::Capabilities capabilities{core::ir::Capability::kAllowModuleScopeLets};
+ auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "HLSL writer", capabilities);
if (valid != Success) {
return std::move(valid.Failure());
}
@@ -197,13 +198,20 @@
/// Block to emit for a continuing
std::function<void()> emit_continuing_;
+ enum LetType : uint8_t {
+ kFunction,
+ kModuleScope,
+ };
+
/// Emit the root block.
/// @param root_block the root block to emit
void EmitRootBlock(core::ir::Block* root_block) {
for (auto* inst : *root_block) {
Switch(
- inst, //
- [&](core::ir::Var* v) { return EmitGlobalVar(v); }, //
+ inst, //
+ [&](core::ir::Var* v) { EmitGlobalVar(v); }, //
+ [&](core::ir::Let* l) { EmitLet(l, LetType::kModuleScope); }, //
+ [&](core::ir::Construct*) { /* inlined */ }, //
TINT_ICE_ON_NO_MATCH);
}
}
@@ -295,14 +303,14 @@
[&](const core::ir::ExitLoop*) { EmitExitLoop(); }, //
[&](const core::ir::ExitSwitch*) { EmitExitSwitch(); }, //
[&](const core::ir::If* i) { EmitIf(i); }, //
- [&](const core::ir::Let* i) { EmitLet(i); }, //
+ [&](const core::ir::Let* i) { EmitLet(i, LetType::kFunction); }, //
[&](const core::ir::StoreVectorElement* s) { EmitStoreVectorElement(s); }, //
[&](const core::ir::Loop* l) { EmitLoop(l); }, //
[&](const core::ir::Return* i) { EmitReturn(i); }, //
[&](const core::ir::Store* i) { EmitStore(i); }, //
[&](const core::ir::Switch* i) { EmitSwitch(i); }, //
[&](const core::ir::Unreachable*) { EmitUnreachable(); }, //
- [&](const core::ir::Var* v) { EmitVar(v); }, //
+ [&](const core::ir::Var* v) { EmitVar(Line(), v); }, //
//
[&](const core::ir::NextIteration*) { /* do nothing */ }, //
[&](const core::ir::ExitIf*) { /* do nothing handled by transform */ }, //
@@ -477,14 +485,17 @@
EmitHandleVariable(var);
break;
case core::AddressSpace::kPrivate: {
- Line() << "static";
- EmitVar(var);
+ auto out = Line();
+ out << "static ";
+ EmitVar(out, var);
break;
}
- case core::AddressSpace::kWorkgroup:
- Line() << "groupshared";
- EmitVar(var);
+ case core::AddressSpace::kWorkgroup: {
+ auto out = Line();
+ out << "groupshared ";
+ EmitVar(out, var);
break;
+ }
case core::AddressSpace::kPushConstant:
default: {
TINT_ICE() << "unhandled address space " << space;
@@ -555,13 +566,12 @@
out << RegisterAndSpace(register_space, bp.value()) << ";";
}
- void EmitVar(const core::ir::Var* var) {
+ void EmitVar(StringStream& out, const core::ir::Var* var) {
auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
TINT_ASSERT(ptr);
auto space = ptr->AddressSpace();
- auto out = Line();
EmitTypeAndName(out, var->Result(0)->Type(), space, ptr->Access(), NameOf(var->Result(0)));
if (var->Initializer()) {
@@ -583,9 +593,13 @@
EmitConstant(out, ir_.constant_values.Zero(ty));
}
- void EmitLet(const core::ir::Let* l) {
+ void EmitLet(const core::ir::Let* l, LetType type) {
auto out = Line();
+ if (type == LetType::kModuleScope) {
+ out << "static const ";
+ }
+
// TODO(dsinclair): Investigate using `const` here as well, the AST printer doesn't emit
// const with a let, but we should be able to.
EmitTypeAndName(out, l->Result(0)->Type(), core::AddressSpace::kUndefined,
@@ -689,9 +703,18 @@
Switch(
c->Result(0)->Type(),
[&](const core::type::Array*) {
- out << "{";
- emit_args();
- out << "}";
+ // The PromoteInitializers transform will inject splat arrays as composites of one
+ // element. These need to convert to `(type)0` in HLSL otherwise DXC will complain
+ // about missing values.
+ if (c->Args().Length() == 1) {
+ out << "(";
+ EmitType(out, c->Result(0)->Type());
+ out << ")0";
+ } else {
+ out << "{";
+ emit_args();
+ out << "}";
+ }
},
[&](const core::type::Struct*) {
out << "{";
diff --git a/src/tint/lang/hlsl/writer/raise/BUILD.cmake b/src/tint/lang/hlsl/writer/raise/BUILD.cmake
index 7e21ab8..90c19e8 100644
--- a/src/tint/lang/hlsl/writer/raise/BUILD.cmake
+++ b/src/tint/lang/hlsl/writer/raise/BUILD.cmake
@@ -119,3 +119,35 @@
tint_target_add_external_dependencies(tint_lang_hlsl_writer_raise_test test
"gtest"
)
+
+################################################################################
+# Target: tint_lang_hlsl_writer_raise_fuzz
+# Kind: fuzz
+################################################################################
+tint_add_target(tint_lang_hlsl_writer_raise_fuzz fuzz
+ lang/hlsl/writer/raise/promote_initializers_fuzz.cc
+)
+
+tint_target_add_dependencies(tint_lang_hlsl_writer_raise_fuzz fuzz
+ tint_api_common
+ tint_cmd_fuzz_ir_fuzz
+ tint_lang_core
+ tint_lang_core_constant
+ tint_lang_core_ir
+ tint_lang_core_type
+ tint_lang_hlsl_writer_raise
+ tint_utils_bytes
+ tint_utils_containers
+ tint_utils_diagnostic
+ tint_utils_ice
+ tint_utils_id
+ tint_utils_macros
+ tint_utils_math
+ tint_utils_memory
+ tint_utils_reflection
+ tint_utils_result
+ tint_utils_rtti
+ tint_utils_symbol
+ tint_utils_text
+ tint_utils_traits
+)
diff --git a/src/tint/lang/hlsl/writer/raise/BUILD.gn b/src/tint/lang/hlsl/writer/raise/BUILD.gn
index 8fdf1be..6f30f45 100644
--- a/src/tint/lang/hlsl/writer/raise/BUILD.gn
+++ b/src/tint/lang/hlsl/writer/raise/BUILD.gn
@@ -119,3 +119,30 @@
]
}
}
+
+tint_fuzz_source_set("fuzz") {
+ sources = [ "promote_initializers_fuzz.cc" ]
+ deps = [
+ "${tint_src_dir}/api/common",
+ "${tint_src_dir}/cmd/fuzz/ir:fuzz",
+ "${tint_src_dir}/lang/core",
+ "${tint_src_dir}/lang/core/constant",
+ "${tint_src_dir}/lang/core/ir",
+ "${tint_src_dir}/lang/core/type",
+ "${tint_src_dir}/lang/hlsl/writer/raise",
+ "${tint_src_dir}/utils/bytes",
+ "${tint_src_dir}/utils/containers",
+ "${tint_src_dir}/utils/diagnostic",
+ "${tint_src_dir}/utils/ice",
+ "${tint_src_dir}/utils/id",
+ "${tint_src_dir}/utils/macros",
+ "${tint_src_dir}/utils/math",
+ "${tint_src_dir}/utils/memory",
+ "${tint_src_dir}/utils/reflection",
+ "${tint_src_dir}/utils/result",
+ "${tint_src_dir}/utils/rtti",
+ "${tint_src_dir}/utils/symbol",
+ "${tint_src_dir}/utils/text",
+ "${tint_src_dir}/utils/traits",
+ ]
+}
diff --git a/src/tint/lang/hlsl/writer/raise/promote_initializers.cc b/src/tint/lang/hlsl/writer/raise/promote_initializers.cc
index b86601b..0ad0221 100644
--- a/src/tint/lang/hlsl/writer/raise/promote_initializers.cc
+++ b/src/tint/lang/hlsl/writer/raise/promote_initializers.cc
@@ -60,48 +60,141 @@
struct ValueInfo {
core::ir::Instruction* inst;
+ size_t index;
core::ir::Value* val;
};
void Process(core::ir::Block* block, bool is_root_block) {
Vector<ValueInfo, 4> worklist;
- Hashset<core::ir::Value*, 4> seen;
for (auto* inst : *block) {
if (inst->Is<core::ir::Let>()) {
continue;
}
if (inst->Is<core::ir::Var>()) {
- if (is_root_block) {
- // split root var if needed ...
+ // In the root block we need to split struct and array vars out to turn them into
+ // `static const` variables.
+ if (!is_root_block) {
+ continue;
}
- continue;
}
- for (auto* operand : inst->Operands()) {
- if (!operand || seen.Contains(operand) || !operand->Type() ||
+ // Check each operand of the instruction to determine if it's a struct or array.
+ auto operands = inst->Operands();
+ for (size_t i = 0; i < operands.Length(); ++i) {
+ auto* operand = operands[i];
+ if (!operand || !operand->Type() ||
!operand->Type()->IsAnyOf<core::type::Struct, core::type::Array>()) {
continue;
}
if (operand->IsAnyOf<core::ir::InstructionResult, core::ir::Constant>()) {
- seen.Add(operand);
- worklist.Push({inst, operand});
+ worklist.Push({inst, i, operand});
}
}
}
+ Vector<core::ir::Construct*, 4> const_worklist;
for (auto& item : worklist) {
if (auto* res = As<core::ir::InstructionResult>(item.val)) {
- PutInLet(res);
+ PutInLet(item.inst, item.index, res);
} else if (auto* val = As<core::ir::Constant>(item.val)) {
- PutInLet(item.inst, val);
+ auto* let = PutInLet(item.inst, item.index, val);
+ auto ret = HoistModuleScopeLetToConstruct(is_root_block, item.inst, let, val);
+ if (ret.has_value()) {
+ const_worklist.Insert(0, *ret);
+ }
+ }
+ }
+
+ // If any element in the constant is `struct` or `array` it needs to be pulled out
+ // into it's own `let`. That also means the `constant` value needs to turn into a
+ // `Construct`.
+ while (!const_worklist.IsEmpty()) {
+ auto item = const_worklist.Pop();
+
+ tint::Slice<core::ir::Value* const> args = item->Args();
+ for (size_t i = 0; i < args.Length(); ++i) {
+ auto ret = ProcessConstant(args[i], item, i);
+ if (ret.has_value()) {
+ const_worklist.Insert(0, *ret);
+ }
}
}
}
+ // Process a constant operand and replace if it's a struct initializer
+ std::optional<core::ir::Construct*> ProcessConstant(core::ir::Value* operand,
+ core::ir::Construct* parent,
+ size_t idx) {
+ auto* const_val = operand->As<core::ir::Constant>();
+ TINT_ASSERT(const_val);
+
+ if (!const_val->Type()->Is<core::type::Struct>()) {
+ return std::nullopt;
+ }
+
+ auto* let = b.Let(const_val->Type());
+
+ Vector<core::ir::Value*, 4> new_args = GatherArgs(const_val);
+
+ auto* construct = b.Construct(const_val->Type(), new_args);
+ let->SetValue(construct->Result(0));
+
+ // Put the `let` in before the `construct` value that we're based off of
+ let->InsertBefore(parent);
+ // Put the new `construct` in before the `let`.
+ construct->InsertBefore(let);
+
+ // Replace the argument in the originating `construct` with the new `let`.
+ parent->SetArg(idx, let->Result(0));
+
+ return {construct};
+ }
+
+ // Determine if this is a root block var which contains a struct initializer and, if
+ // so, setup the instruction for the needed replacement.
+ std::optional<core::ir::Construct*> HoistModuleScopeLetToConstruct(bool is_root_block,
+ core::ir::Instruction* inst,
+ core::ir::Let* let,
+ core::ir::Constant* val) {
+ // Only care about root-block variables
+ if (!is_root_block || !inst->Is<core::ir::Var>()) {
+ return std::nullopt;
+ }
+ // Only care about struct constants
+ if (!val->Type()->Is<core::type::Struct>()) {
+ return std::nullopt;
+ }
+
+ // This may not actually need to be a `construct` but pull it out now to
+ // make further changes, if they're necessary, easier.
+ Vector<core::ir::Value*, 4> args = GatherArgs(val);
+
+ // Turn the `constant` into a `construct` call and replace the value of the `let` that
+ // was created.
+ auto* construct = b.Construct(val->Type(), args);
+ let->SetValue(construct->Result(0));
+ construct->InsertBefore(let);
+
+ return {construct};
+ }
+
+ // Gather the arguments to the constant and create a `ir::Value` array from them which can
+ // be used in a `construct`.
+ Vector<core::ir::Value*, 4> GatherArgs(core::ir::Constant* val) {
+ Vector<core::ir::Value*, 4> args;
+ if (auto* const_val = val->Value()->As<core::constant::Composite>()) {
+ for (auto v : const_val->elements) {
+ args.Push(b.Constant(v));
+ }
+ } else if (auto* splat_val = val->Value()->As<core::constant::Splat>()) {
+ args.Push(b.Constant(splat_val->el));
+ }
+ return args;
+ }
+
core::ir::Let* MakeLet(core::ir::Value* value) {
auto* let = b.Let(value->Type());
- value->ReplaceAllUsesWith(let->Result(0));
let->SetValue(value);
auto name = b.ir.NameOf(value);
@@ -112,15 +205,12 @@
return let;
}
- void PutInLet(core::ir::Instruction* inst, core::ir::Value* value) {
+ core::ir::Let* PutInLet(core::ir::Instruction* inst, size_t index, core::ir::Value* value) {
auto* let = MakeLet(value);
let->InsertBefore(inst);
- }
- void PutInLet(core::ir::InstructionResult* value) {
- auto* inst = value->Instruction();
- auto* let = MakeLet(value);
- let->InsertAfter(inst);
+ inst->SetOperand(index, let->Result(0));
+ return let;
}
};
diff --git a/src/tint/lang/hlsl/writer/raise/promote_initializers.h b/src/tint/lang/hlsl/writer/raise/promote_initializers.h
index 340c96a..768a8ce 100644
--- a/src/tint/lang/hlsl/writer/raise/promote_initializers.h
+++ b/src/tint/lang/hlsl/writer/raise/promote_initializers.h
@@ -38,7 +38,37 @@
namespace tint::hlsl::writer::raise {
/// PromoteInitializers is a transform that moves inline struct and array initializers to a `let`
-/// unless the initializer is already in a `let ` or `var`.
+/// unless the initializer is already in a `let ` or `var`. For any `var` at the module scope it
+/// will recursively break any array or struct initializers out of the constant into their own
+/// `let`.
+///
+/// After this transform the `Capability::kAllowModuleScopeLets` must be enabled and any downstream
+/// transform/printer must under stand `let` and `construct` instructions at the module scope.
+/// (`construct` can just be skipped as they will be inlined, but the instruction still has to be
+/// handled.)
+///
+/// For example:
+///
+/// ```wgsl
+/// struct A {
+/// b: f32,
+/// }
+/// struct S {
+/// a: A
+/// }
+/// var<private> p = S(A(1.f));
+/// ```
+///
+/// Essentially creates:
+///
+/// ```wgsl
+/// struct S {
+/// a: i32,
+/// }
+/// let v: A = A(1.f);
+/// let v_1: S = S(v);
+/// var p = v_1;
+/// ```
///
/// @param module the module to transform
/// @returns error diagnostics on failure
diff --git a/src/tint/lang/hlsl/writer/raise/promote_initializers_fuzz.cc b/src/tint/lang/hlsl/writer/raise/promote_initializers_fuzz.cc
new file mode 100644
index 0000000..7683d95
--- /dev/null
+++ b/src/tint/lang/hlsl/writer/raise/promote_initializers_fuzz.cc
@@ -0,0 +1,51 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT\ OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/hlsl/writer/raise/promote_initializers.h"
+
+#include "src/tint/cmd/fuzz/ir/fuzz.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/validator.h"
+
+namespace tint::hlsl::writer::raise {
+namespace {
+
+void PromoteInitializersFuzzer(core::ir::Module& module) {
+ if (auto res = PromoteInitializers(module); res != Success) {
+ return;
+ }
+
+ core::ir::Capabilities capabilities{core::ir::Capability::kAllowModuleScopeLets};
+ if (auto res = Validate(module, capabilities); res != Success) {
+ TINT_ICE() << "result of PromoteInitializers failed IR validation\n" << res.Failure();
+ }
+}
+
+} // namespace
+} // namespace tint::hlsl::writer::raise
+
+TINT_IR_MODULE_FUZZER(tint::hlsl::writer::raise::PromoteInitializersFuzzer);
diff --git a/src/tint/lang/hlsl/writer/raise/promote_initializers_test.cc b/src/tint/lang/hlsl/writer/raise/promote_initializers_test.cc
index fd52b86..3aea052 100644
--- a/src/tint/lang/hlsl/writer/raise/promote_initializers_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/promote_initializers_test.cc
@@ -32,6 +32,701 @@
#include "gtest/gtest.h"
#include "src/tint/lang/core/ir/transform/helper_test.h"
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
namespace tint::hlsl::writer::raise {
-namespace {}
+namespace {
+
+using HlslWriterPromoteInitializersTest = core::ir::transform::TransformTest;
+
+TEST_F(HlslWriterPromoteInitializersTest, NoStructInitializers) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Var<private_>("a", b.Zero<i32>());
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = @fragment func():void {
+ $B1: {
+ %a:ptr<private, i32, read_write> = var, 0i
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, StructInVarNoChange) {
+ auto* str_ty =
+ ty.Struct(mod.symbols.New("S"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Var<private_>("a", b.Composite(str_ty, 1_i));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+S = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+%foo = @fragment func():void {
+ $B1: {
+ %a:ptr<private, S, read_write> = var, S(1i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, ArrayInVarNoChange) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Var<private_>("a", b.Zero<array<i32, 2>>());
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = @fragment func():void {
+ $B1: {
+ %a:ptr<private, array<i32, 2>, read_write> = var, array<i32, 2>(0i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, StructInLetNoChange) {
+ auto* str_ty =
+ ty.Struct(mod.symbols.New("S"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Composite(str_ty, 1_i));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+S = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+%foo = @fragment func():void {
+ $B1: {
+ %a:S = let S(1i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, ArrayInLetNoChange) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Zero<array<i32, 2>>());
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = @fragment func():void {
+ $B1: {
+ %a:array<i32, 2> = let array<i32, 2>(0i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, StructInCall) {
+ auto* str_ty =
+ ty.Struct(mod.symbols.New("S"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* p = b.FunctionParam("p", str_ty);
+ auto* dst = b.Function("dst", ty.void_());
+ dst->SetParams({p});
+ dst->Block()->Append(b.Return(dst));
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Call(dst, b.Composite(str_ty, 1_i));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+S = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+%dst = func(%p:S):void {
+ $B1: {
+ ret
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %4:void = call %dst, S(1i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+S = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+%dst = func(%p:S):void {
+ $B1: {
+ ret
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %4:S = let S(1i)
+ %5:void = call %dst, %4
+ ret
+ }
+}
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, ArrayInCall) {
+ auto* p = b.FunctionParam("p", ty.array<i32, 2>());
+ auto* dst = b.Function("dst", ty.void_());
+ dst->SetParams({p});
+ dst->Block()->Append(b.Return(dst));
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Call(dst, b.Composite(ty.array<i32, 2>(), 1_i));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%dst = func(%p:array<i32, 2>):void {
+ $B1: {
+ ret
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %4:void = call %dst, array<i32, 2>(1i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%dst = func(%p:array<i32, 2>):void {
+ $B1: {
+ ret
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %4:array<i32, 2> = let array<i32, 2>(1i)
+ %5:void = call %dst, %4
+ ret
+ }
+}
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, ModuleScopedStruct) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ auto* str_ty =
+ ty.Struct(mod.symbols.New("S"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ });
+
+ b.ir.root_block->Append(b.Var<private_>("a", b.Composite(str_ty, 1_i)));
+
+ auto* src = R"(
+S = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+$B1: { # root
+ %a:ptr<private, S, read_write> = var, S(1i)
+}
+
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+S = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+$B1: { # root
+ %1:S = construct 1i
+ %2:S = let %1
+ %a:ptr<private, S, read_write> = var, %2
+}
+
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, ModuleScopedArray) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ b.ir.root_block->Append(b.Var<private_>("a", b.Zero<array<i32, 2>>()));
+
+ auto* src = R"(
+$B1: { # root
+ %a:ptr<private, array<i32, 2>, read_write> = var, array<i32, 2>(0i)
+}
+
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %1:array<i32, 2> = let array<i32, 2>(0i)
+ %a:ptr<private, array<i32, 2>, read_write> = var, %1
+}
+
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, ModuleScopedStructNested) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ auto* b_ty =
+ ty.Struct(mod.symbols.New("B"),
+ {
+ {mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* a_ty =
+ ty.Struct(mod.symbols.New("A"),
+ {
+ {mod.symbols.New("z"), ty.i32(), core::type::StructMemberAttributes{}},
+ {mod.symbols.New("b"), b_ty, core::type::StructMemberAttributes{}},
+ });
+
+ auto* str_ty = ty.Struct(mod.symbols.New("S"),
+ {
+ {mod.symbols.New("a"), a_ty, core::type::StructMemberAttributes{}},
+ });
+
+ b.ir.root_block->Append(
+ b.Var<private_>("a", b.Composite(str_ty, b.Composite(a_ty, 1_i, b.Composite(b_ty, 1_f)))));
+
+ auto* src = R"(
+B = struct @align(4) {
+ c:f32 @offset(0)
+}
+
+A = struct @align(4) {
+ z:i32 @offset(0)
+ b:B @offset(4)
+}
+
+S = struct @align(4) {
+ a:A @offset(0)
+}
+
+$B1: { # root
+ %a:ptr<private, S, read_write> = var, S(A(1i, B(1.0f)))
+}
+
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+B = struct @align(4) {
+ c:f32 @offset(0)
+}
+
+A = struct @align(4) {
+ z:i32 @offset(0)
+ b:B @offset(4)
+}
+
+S = struct @align(4) {
+ a:A @offset(0)
+}
+
+$B1: { # root
+ %1:B = construct 1.0f
+ %2:B = let %1
+ %3:A = construct 1i, %2
+ %4:A = let %3
+ %5:S = construct %4
+ %6:S = let %5
+ %a:ptr<private, S, read_write> = var, %6
+}
+
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, ModuleScopedArrayNestedInStruct) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ auto* str_ty = ty.Struct(mod.symbols.New("S"), {
+ {mod.symbols.New("a"), ty.array<i32, 3>(),
+ core::type::StructMemberAttributes{}},
+ });
+
+ b.ir.root_block->Append(b.Var<private_>("a", b.Composite(str_ty, b.Zero(ty.array<i32, 3>()))));
+
+ auto* src = R"(
+S = struct @align(4) {
+ a:array<i32, 3> @offset(0)
+}
+
+$B1: { # root
+ %a:ptr<private, S, read_write> = var, S(array<i32, 3>(0i))
+}
+
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+S = struct @align(4) {
+ a:array<i32, 3> @offset(0)
+}
+
+$B1: { # root
+ %1:S = construct array<i32, 3>(0i)
+ %2:S = let %1
+ %a:ptr<private, S, read_write> = var, %2
+}
+
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, Many) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ auto* a_ty = ty.Struct(mod.symbols.New("A"), {
+ {mod.symbols.New("a"), ty.array<i32, 2>(),
+ core::type::StructMemberAttributes{}},
+ });
+ auto* b_ty =
+ ty.Struct(mod.symbols.New("B"), {
+ {mod.symbols.New("b"), ty.array<array<i32, 4>, 1>(),
+ core::type::StructMemberAttributes{}},
+ });
+ auto* c_ty = ty.Struct(mod.symbols.New("C"),
+ {
+ {mod.symbols.New("a"), a_ty, core::type::StructMemberAttributes{}},
+ });
+
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_>("a", b.Composite(a_ty, b.Composite(ty.array<i32, 2>(), 9_i, 10_i)));
+ b.Var<private_>(
+ "b",
+ b.Composite(b_ty, b.Composite(ty.array<array<i32, 4>, 1>(),
+ b.Composite(ty.array<i32, 4>(), 5_i, 6_i, 7_i, 8_i))));
+ b.Var<private_>(
+ "c", b.Composite(c_ty, b.Composite(a_ty, b.Composite(ty.array<i32, 2>(), 1_i, 2_i))));
+
+ b.Var<private_>("d", b.Composite(ty.array<i32, 2>(), 11_i, 12_i));
+ b.Var<private_>("e", b.Composite(ty.array<array<array<i32, 3>, 2>, 1>(),
+ b.Composite(ty.array<array<i32, 3>, 2>(),
+ b.Composite(ty.array<i32, 3>(), 1_i, 2_i, 3_i),
+ b.Composite(ty.array<i32, 3>(), 4_i, 5_i, 6_i)
+
+ )));
+ });
+
+ auto* src = R"(
+A = struct @align(4) {
+ a:array<i32, 2> @offset(0)
+}
+
+B = struct @align(4) {
+ b:array<array<i32, 4>, 1> @offset(0)
+}
+
+C = struct @align(4) {
+ a_1:A @offset(0)
+}
+
+$B1: { # root
+ %a:ptr<private, A, read_write> = var, A(array<i32, 2>(9i, 10i))
+ %b:ptr<private, B, read_write> = var, B(array<array<i32, 4>, 1>(array<i32, 4>(5i, 6i, 7i, 8i)))
+ %c:ptr<private, C, read_write> = var, C(A(array<i32, 2>(1i, 2i)))
+ %d:ptr<private, array<i32, 2>, read_write> = var, array<i32, 2>(11i, 12i)
+ %e:ptr<private, array<array<array<i32, 3>, 2>, 1>, read_write> = var, array<array<array<i32, 3>, 2>, 1>(array<array<i32, 3>, 2>(array<i32, 3>(1i, 2i, 3i), array<i32, 3>(4i, 5i, 6i)))
+}
+
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+A = struct @align(4) {
+ a:array<i32, 2> @offset(0)
+}
+
+B = struct @align(4) {
+ b:array<array<i32, 4>, 1> @offset(0)
+}
+
+C = struct @align(4) {
+ a_1:A @offset(0)
+}
+
+$B1: { # root
+ %1:A = construct array<i32, 2>(9i, 10i)
+ %2:A = let %1
+ %a:ptr<private, A, read_write> = var, %2
+ %4:B = construct array<array<i32, 4>, 1>(array<i32, 4>(5i, 6i, 7i, 8i))
+ %5:B = let %4
+ %b:ptr<private, B, read_write> = var, %5
+ %7:A = construct array<i32, 2>(1i, 2i)
+ %8:A = let %7
+ %9:C = construct %8
+ %10:C = let %9
+ %c:ptr<private, C, read_write> = var, %10
+ %12:array<i32, 2> = let array<i32, 2>(11i, 12i)
+ %d:ptr<private, array<i32, 2>, read_write> = var, %12
+ %14:array<array<array<i32, 3>, 2>, 1> = let array<array<array<i32, 3>, 2>, 1>(array<array<i32, 3>, 2>(array<i32, 3>(1i, 2i, 3i), array<i32, 3>(4i, 5i, 6i)))
+ %e:ptr<private, array<array<array<i32, 3>, 2>, 1>, read_write> = var, %14
+}
+
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, DuplicateConstantInLet) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ auto* ret_arr = b.Function("ret_arr", ty.array<vec4<i32>, 4>());
+ b.Append(ret_arr->Block(), [&] { b.Return(ret_arr, b.Zero<array<vec4<i32>, 4>>()); });
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("src_let", b.Zero<array<vec4<i32>, 4>>());
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%ret_arr = func():array<vec4<i32>, 4> {
+ $B1: {
+ ret array<vec4<i32>, 4>(vec4<i32>(0i))
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %src_let:array<vec4<i32>, 4> = let array<vec4<i32>, 4>(vec4<i32>(0i))
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%ret_arr = func():array<vec4<i32>, 4> {
+ $B1: {
+ %2:array<vec4<i32>, 4> = let array<vec4<i32>, 4>(vec4<i32>(0i))
+ ret %2
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %src_let:array<vec4<i32>, 4> = let array<vec4<i32>, 4>(vec4<i32>(0i))
+ ret
+ }
+}
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, DuplicateConstantInBlock) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ auto* a_ty =
+ ty.Struct(mod.symbols.New("A"),
+ {
+ {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+ });
+
+ auto* param = b.FunctionParam("a", a_ty);
+ auto* bar = b.Function("bar", ty.void_());
+ bar->SetParams({param});
+ bar->Block()->Append(b.Return(bar));
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Call(bar, b.Composite(a_ty, 1_i));
+ b.Call(bar, b.Composite(a_ty, 1_i));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+A = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+%bar = func(%a:A):void {
+ $B1: {
+ ret
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %4:void = call %bar, A(1i)
+ %5:void = call %bar, A(1i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+A = struct @align(4) {
+ a:i32 @offset(0)
+}
+
+%bar = func(%a:A):void {
+ $B1: {
+ ret
+ }
+}
+%foo = @fragment func():void {
+ $B2: {
+ %4:A = let A(1i)
+ %5:void = call %bar, %4
+ %6:A = let A(1i)
+ %7:void = call %bar, %6
+ ret
+ }
+}
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriterPromoteInitializersTest, DuplicateConstant) {
+ capabilities = core::ir::Capabilities{core::ir::Capability::kAllowModuleScopeLets};
+
+ auto* ret_arr = b.Function("ret_arr", ty.array<vec4<i32>, 4>());
+ b.Append(ret_arr->Block(), [&] { b.Return(ret_arr, b.Zero<array<vec4<i32>, 4>>()); });
+
+ auto* second_arr = b.Function("second_arr", ty.array<vec4<i32>, 4>());
+ b.Append(second_arr->Block(), [&] { b.Return(second_arr, b.Zero<array<vec4<i32>, 4>>()); });
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("src_let", b.Zero<array<vec4<i32>, 4>>());
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%ret_arr = func():array<vec4<i32>, 4> {
+ $B1: {
+ ret array<vec4<i32>, 4>(vec4<i32>(0i))
+ }
+}
+%second_arr = func():array<vec4<i32>, 4> {
+ $B2: {
+ ret array<vec4<i32>, 4>(vec4<i32>(0i))
+ }
+}
+%foo = @fragment func():void {
+ $B3: {
+ %src_let:array<vec4<i32>, 4> = let array<vec4<i32>, 4>(vec4<i32>(0i))
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%ret_arr = func():array<vec4<i32>, 4> {
+ $B1: {
+ %2:array<vec4<i32>, 4> = let array<vec4<i32>, 4>(vec4<i32>(0i))
+ ret %2
+ }
+}
+%second_arr = func():array<vec4<i32>, 4> {
+ $B2: {
+ %4:array<vec4<i32>, 4> = let array<vec4<i32>, 4>(vec4<i32>(0i))
+ ret %4
+ }
+}
+%foo = @fragment func():void {
+ $B3: {
+ %src_let:array<vec4<i32>, 4> = let array<vec4<i32>, 4>(vec4<i32>(0i))
+ ret
+ }
+}
+)";
+ Run(PromoteInitializers);
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
} // namespace tint::hlsl::writer::raise
diff --git a/src/tint/lang/hlsl/writer/raise/raise.cc b/src/tint/lang/hlsl/writer/raise/raise.cc
index 8a50940..257a973 100644
--- a/src/tint/lang/hlsl/writer/raise/raise.cc
+++ b/src/tint/lang/hlsl/writer/raise/raise.cc
@@ -71,9 +71,11 @@
// These transforms need to be run last as various transforms introduce terminator arguments,
// naming conflicts, and expressions that need to be explicitly not inlined.
RUN_TRANSFORM(core::ir::transform::RemoveTerminatorArgs);
- RUN_TRANSFORM(raise::PromoteInitializers);
RUN_TRANSFORM(core::ir::transform::ValueToLet);
+ // Anything which runs after this needs to handle `Capabilities::kAllowModuleScopedLets`
+ RUN_TRANSFORM(raise::PromoteInitializers);
+
return Success;
}
diff --git a/src/tint/lang/hlsl/writer/switch_test.cc b/src/tint/lang/hlsl/writer/switch_test.cc
index f6458cd..34b82ef 100644
--- a/src/tint/lang/hlsl/writer/switch_test.cc
+++ b/src/tint/lang/hlsl/writer/switch_test.cc
@@ -179,10 +179,8 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
-static
-int global = 0;
-static
-int a = 0;
+static int global = 0;
+static int a = 0;
int bar() {
global = 84;
return global;
@@ -294,10 +292,8 @@
ASSERT_TRUE(Generate(options)) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
-static
-int global = 0;
-static
-int a = 0;
+static int global = 0;
+static int a = 0;
int bar() {
global = 84;
return global;
diff --git a/src/tint/lang/hlsl/writer/var_let_test.cc b/src/tint/lang/hlsl/writer/var_let_test.cc
index 7c1109c..f00e423 100644
--- a/src/tint/lang/hlsl/writer/var_let_test.cc
+++ b/src/tint/lang/hlsl/writer/var_let_test.cc
@@ -513,8 +513,7 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
-static
-float4 u = (0.0f).xxxx;
+static float4 u = (0.0f).xxxx;
[numthreads(1, 1, 1)]
void unused_entry_point() {
}
@@ -530,8 +529,7 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
-groupshared
-float4 u;
+groupshared float4 u;
[numthreads(1, 1, 1)]
void unused_entry_point() {
}