[tint] Fix emission of module-scope struct initializers
Constant structure initializers need to be assigned directly to a
variable in HLSL. We were emitting these as `const` variables for both
module-scope and function-scope.
Module-scope `const` variables are not held in the program bitcode, but
are expected to be provided as a cbuffer binding, which dawn was not
providing.
Instead emit these always as global scope `static const` variables.
Add a dawn end2end test for `var<private>` variables initialized with a
`struct`
Bug: tint:2142
Change-Id: I03c7c94a1729dcfbc077bcef01495a9fbbff0290
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/169940
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/dawn/tests/end2end/ShaderTests.cpp b/src/dawn/tests/end2end/ShaderTests.cpp
index 2c2c65a..e3ba8c0 100644
--- a/src/dawn/tests/end2end/ShaderTests.cpp
+++ b/src/dawn/tests/end2end/ShaderTests.cpp
@@ -46,7 +46,7 @@
}
wgpu::ComputePipeline CreateComputePipeline(
const std::string& shader,
- const char* entryPoint,
+ const char* entryPoint = nullptr,
const std::vector<wgpu::ConstantEntry>* constants = nullptr) {
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = utils::CreateShaderModule(device, shader.c_str());
@@ -2227,6 +2227,44 @@
renderPass.color, 32, 16);
}
+TEST_P(ShaderTests, PrivateVarInitWithStruct) {
+ wgpu::ComputePipeline pipeline = CreateComputePipeline(R"(
+@binding(0) @group(0) var<storage, read_write> output : i32;
+
+struct S {
+ i : i32,
+}
+
+var<private> P = S(42);
+
+@compute @workgroup_size(1)
+fn main() {
+ output = P.i;
+}
+)");
+
+ wgpu::Buffer output = CreateBuffer(1);
+
+ wgpu::BindGroup bindGroup =
+ utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, output}});
+
+ wgpu::CommandBuffer commands;
+ {
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+ pass.SetPipeline(pipeline);
+ pass.SetBindGroup(0, bindGroup);
+ pass.DispatchWorkgroups(1);
+ pass.End();
+
+ commands = encoder.Finish();
+ }
+
+ queue.Submit(1, &commands);
+
+ EXPECT_BUFFER_U32_EQ(42, output, 0);
+}
+
DAWN_INSTANTIATE_TEST(ShaderTests,
D3D11Backend(),
D3D12Backend(),
diff --git a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
index 3d25696..3415d05 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
@@ -416,6 +416,8 @@
}
last_kind = kind;
+ global_insertion_point_ = current_buffer_->lines.size();
+
bool ok = Switch(
decl,
[&](const ast::Variable* global) { //
@@ -3847,14 +3849,20 @@
}
} else {
// HLSL requires structure initializers to be assigned directly to a variable.
+ // For these constants use 'static const' at global-scope. 'const' at global scope
+ // creates a variable who's initializer is ignored, and the value is expected to be
+ // provided in a cbuffer. 'static const' is a true value-embedded-in-the-shader-code
+ // constant. We also emit these for function-local constant expressions for
+ // consistency and to ensure that these are not computed at execution time.
auto name = UniqueIdentifier("c");
{
- auto decl = Line();
- decl << "const " << StructName(s) << " " << name << " = ";
+ StringStream decl;
+ decl << "static const " << StructName(s) << " " << name << " = ";
if (!emit_member_values(decl)) {
return false;
}
decl << ";";
+ current_buffer_->Insert(decl.str(), global_insertion_point_++, 0);
}
out << name;
}
diff --git a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h
index 633180c..3d14eac 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h
+++ b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h
@@ -619,6 +619,9 @@
std::unordered_map<const core::type::Type*, std::string> value_or_one_if_zero_;
std::unordered_set<const core::type::Struct*> emitted_structs_;
std::unordered_map<const core::type::Type*, bool> is_struct_or_array_of_matrix_;
+
+ // The line index in current_buffer_ of the current global declaration / function.
+ size_t global_insertion_point_ = 0;
};
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/ast_printer/builtin_test.cc b/src/tint/lang/hlsl/writer/ast_printer/builtin_test.cc
index d9d6d89..803be81 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/builtin_test.cc
+++ b/src/tint/lang/hlsl/writer/ast_printer/builtin_test.cc
@@ -588,10 +588,10 @@
float3 fract;
float3 whole;
};
+static const modf_result_vec3_f32 c = {(0.5f).xxx, float3(4.0f, 5.0f, 6.0f)};
[numthreads(1, 1, 1)]
void test_function() {
modf_result_vec3_f32 v = {(0.5f).xxx, float3(1.0f, 2.0f, 3.0f)};
- const modf_result_vec3_f32 c = {(0.5f).xxx, float3(4.0f, 5.0f, 6.0f)};
v = c;
return;
}
@@ -802,10 +802,10 @@
float3 fract;
int3 exp;
};
+static const frexp_result_vec3_f32 c = {float3(0.5625f, 0.6875f, 0.8125f), (3).xxx};
[numthreads(1, 1, 1)]
void test_function() {
frexp_result_vec3_f32 v = {float3(0.75f, 0.625f, 0.875f), int3(1, 2, 2)};
- const frexp_result_vec3_f32 c = {float3(0.5625f, 0.6875f, 0.8125f), (3).xxx};
v = c;
return;
}
diff --git a/src/tint/utils/generator/text_generator.cc b/src/tint/utils/generator/text_generator.cc
index 905449a..0bcab39 100644
--- a/src/tint/utils/generator/text_generator.cc
+++ b/src/tint/utils/generator/text_generator.cc
@@ -68,8 +68,8 @@
}
void TextGenerator::TextBuffer::Insert(const std::string& line, size_t before, uint32_t indent) {
- if (TINT_UNLIKELY(before >= lines.size())) {
- TINT_ICE() << "TextBuffer::Insert() called with before >= lines.size()\n"
+ if (TINT_UNLIKELY(before > lines.size())) {
+ TINT_ICE() << "TextBuffer::Insert() called with before > lines.size()\n"
<< " before:" << before << "\n"
<< " lines.size(): " << lines.size();
return;
@@ -86,8 +86,8 @@
}
void TextGenerator::TextBuffer::Insert(const TextBuffer& tb, size_t before, uint32_t indent) {
- if (TINT_UNLIKELY(before >= lines.size())) {
- TINT_ICE() << "TextBuffer::Insert() called with before >= lines.size()\n"
+ if (TINT_UNLIKELY(before > lines.size())) {
+ TINT_ICE() << "TextBuffer::Insert() called with before > lines.size()\n"
<< " before:" << before << "\n"
<< " lines.size(): " << lines.size();
return;
diff --git a/test/tint/bug/chromium/1430309.wgsl.expected.dxc.hlsl b/test/tint/bug/chromium/1430309.wgsl.expected.dxc.hlsl
index 9f62ac4..94574a4 100644
--- a/test/tint/bug/chromium/1430309.wgsl.expected.dxc.hlsl
+++ b/test/tint/bug/chromium/1430309.wgsl.expected.dxc.hlsl
@@ -7,7 +7,7 @@
};
static frexp_result_f32 a = (frexp_result_f32)0;
-const frexp_result_f32_1 c = {0.5f, 1};
+static const frexp_result_f32_1 c = {0.5f, 1};
static frexp_result_f32_1 b = c;
struct tint_symbol {
diff --git a/test/tint/bug/chromium/1430309.wgsl.expected.fxc.hlsl b/test/tint/bug/chromium/1430309.wgsl.expected.fxc.hlsl
index 9f62ac4..94574a4 100644
--- a/test/tint/bug/chromium/1430309.wgsl.expected.fxc.hlsl
+++ b/test/tint/bug/chromium/1430309.wgsl.expected.fxc.hlsl
@@ -7,7 +7,7 @@
};
static frexp_result_f32 a = (frexp_result_f32)0;
-const frexp_result_f32_1 c = {0.5f, 1};
+static const frexp_result_f32_1 c = {0.5f, 1};
static frexp_result_f32_1 b = c;
struct tint_symbol {
diff --git a/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.dxc.hlsl b/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.dxc.hlsl
index 78f8457..cc064b3 100644
--- a/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.dxc.hlsl
+++ b/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.dxc.hlsl
@@ -9,12 +9,12 @@
return result;
}
+static const frexp_result_f32 c = {0.625f, 1};
[numthreads(1, 1, 1)]
void main() {
const float runtime_in = 1.25f;
frexp_result_f32 res = {0.625f, 1};
res = tint_frexp(runtime_in);
- const frexp_result_f32 c = {0.625f, 1};
res = c;
const float fract = res.fract;
const int exp = res.exp;
diff --git a/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.fxc.hlsl b/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.fxc.hlsl
index 78f8457..cc064b3 100644
--- a/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.fxc.hlsl
+++ b/test/tint/builtins/frexp/scalar/mixed.wgsl.expected.fxc.hlsl
@@ -9,12 +9,12 @@
return result;
}
+static const frexp_result_f32 c = {0.625f, 1};
[numthreads(1, 1, 1)]
void main() {
const float runtime_in = 1.25f;
frexp_result_f32 res = {0.625f, 1};
res = tint_frexp(runtime_in);
- const frexp_result_f32 c = {0.625f, 1};
res = c;
const float fract = res.fract;
const int exp = res.exp;
diff --git a/test/tint/builtins/frexp/vector/mixed.wgsl.expected.dxc.hlsl b/test/tint/builtins/frexp/vector/mixed.wgsl.expected.dxc.hlsl
index c5fc8db..cf863a4 100644
--- a/test/tint/builtins/frexp/vector/mixed.wgsl.expected.dxc.hlsl
+++ b/test/tint/builtins/frexp/vector/mixed.wgsl.expected.dxc.hlsl
@@ -9,12 +9,12 @@
return result;
}
+static const frexp_result_vec2_f32 c = {float2(0.625f, 0.9375f), int2(1, 2)};
[numthreads(1, 1, 1)]
void main() {
const float2 runtime_in = float2(1.25f, 3.75f);
frexp_result_vec2_f32 res = {float2(0.625f, 0.9375f), int2(1, 2)};
res = tint_frexp(runtime_in);
- const frexp_result_vec2_f32 c = {float2(0.625f, 0.9375f), int2(1, 2)};
res = c;
const float2 fract = res.fract;
const int2 exp = res.exp;
diff --git a/test/tint/builtins/frexp/vector/mixed.wgsl.expected.fxc.hlsl b/test/tint/builtins/frexp/vector/mixed.wgsl.expected.fxc.hlsl
index c5fc8db..cf863a4 100644
--- a/test/tint/builtins/frexp/vector/mixed.wgsl.expected.fxc.hlsl
+++ b/test/tint/builtins/frexp/vector/mixed.wgsl.expected.fxc.hlsl
@@ -9,12 +9,12 @@
return result;
}
+static const frexp_result_vec2_f32 c = {float2(0.625f, 0.9375f), int2(1, 2)};
[numthreads(1, 1, 1)]
void main() {
const float2 runtime_in = float2(1.25f, 3.75f);
frexp_result_vec2_f32 res = {float2(0.625f, 0.9375f), int2(1, 2)};
res = tint_frexp(runtime_in);
- const frexp_result_vec2_f32 c = {float2(0.625f, 0.9375f), int2(1, 2)};
res = c;
const float2 fract = res.fract;
const int2 exp = res.exp;
diff --git a/test/tint/builtins/gen/literal/textureBarrier/3d0f7e.wgsl.expected.fxc.hlsl b/test/tint/builtins/gen/literal/textureBarrier/3d0f7e.wgsl.expected.fxc.hlsl
index 7eb8c9b..ce8ab3b8 100644
--- a/test/tint/builtins/gen/literal/textureBarrier/3d0f7e.wgsl.expected.fxc.hlsl
+++ b/test/tint/builtins/gen/literal/textureBarrier/3d0f7e.wgsl.expected.fxc.hlsl
@@ -1,5 +1,3 @@
-SKIP: FAILED
-
void textureBarrier_3d0f7e() {
DeviceMemoryBarrierWithGroupSync();
}
diff --git a/test/tint/builtins/modf/scalar/mixed.wgsl.expected.dxc.hlsl b/test/tint/builtins/modf/scalar/mixed.wgsl.expected.dxc.hlsl
index 852de4c..0cbff50 100644
--- a/test/tint/builtins/modf/scalar/mixed.wgsl.expected.dxc.hlsl
+++ b/test/tint/builtins/modf/scalar/mixed.wgsl.expected.dxc.hlsl
@@ -8,12 +8,12 @@
return result;
}
+static const modf_result_f32 c = {0.25f, 1.0f};
[numthreads(1, 1, 1)]
void main() {
const float runtime_in = 1.25f;
modf_result_f32 res = {0.25f, 1.0f};
res = tint_modf(runtime_in);
- const modf_result_f32 c = {0.25f, 1.0f};
res = c;
const float fract = res.fract;
const float whole = res.whole;
diff --git a/test/tint/builtins/modf/scalar/mixed.wgsl.expected.fxc.hlsl b/test/tint/builtins/modf/scalar/mixed.wgsl.expected.fxc.hlsl
index 852de4c..0cbff50 100644
--- a/test/tint/builtins/modf/scalar/mixed.wgsl.expected.fxc.hlsl
+++ b/test/tint/builtins/modf/scalar/mixed.wgsl.expected.fxc.hlsl
@@ -8,12 +8,12 @@
return result;
}
+static const modf_result_f32 c = {0.25f, 1.0f};
[numthreads(1, 1, 1)]
void main() {
const float runtime_in = 1.25f;
modf_result_f32 res = {0.25f, 1.0f};
res = tint_modf(runtime_in);
- const modf_result_f32 c = {0.25f, 1.0f};
res = c;
const float fract = res.fract;
const float whole = res.whole;
diff --git a/test/tint/builtins/modf/vector/mixed.wgsl.expected.dxc.hlsl b/test/tint/builtins/modf/vector/mixed.wgsl.expected.dxc.hlsl
index f917164..cd83834 100644
--- a/test/tint/builtins/modf/vector/mixed.wgsl.expected.dxc.hlsl
+++ b/test/tint/builtins/modf/vector/mixed.wgsl.expected.dxc.hlsl
@@ -8,12 +8,12 @@
return result;
}
+static const modf_result_vec2_f32 c = {float2(0.25f, 0.75f), float2(1.0f, 3.0f)};
[numthreads(1, 1, 1)]
void main() {
const float2 runtime_in = float2(1.25f, 3.75f);
modf_result_vec2_f32 res = {float2(0.25f, 0.75f), float2(1.0f, 3.0f)};
res = tint_modf(runtime_in);
- const modf_result_vec2_f32 c = {float2(0.25f, 0.75f), float2(1.0f, 3.0f)};
res = c;
const float2 fract = res.fract;
const float2 whole = res.whole;
diff --git a/test/tint/builtins/modf/vector/mixed.wgsl.expected.fxc.hlsl b/test/tint/builtins/modf/vector/mixed.wgsl.expected.fxc.hlsl
index f917164..cd83834 100644
--- a/test/tint/builtins/modf/vector/mixed.wgsl.expected.fxc.hlsl
+++ b/test/tint/builtins/modf/vector/mixed.wgsl.expected.fxc.hlsl
@@ -8,12 +8,12 @@
return result;
}
+static const modf_result_vec2_f32 c = {float2(0.25f, 0.75f), float2(1.0f, 3.0f)};
[numthreads(1, 1, 1)]
void main() {
const float2 runtime_in = float2(1.25f, 3.75f);
modf_result_vec2_f32 res = {float2(0.25f, 0.75f), float2(1.0f, 3.0f)};
res = tint_modf(runtime_in);
- const modf_result_vec2_f32 c = {float2(0.25f, 0.75f), float2(1.0f, 3.0f)};
res = c;
const float2 fract = res.fract;
const float2 whole = res.whole;
diff --git a/test/tint/var/inferred/global.wgsl.expected.dxc.hlsl b/test/tint/var/inferred/global.wgsl.expected.dxc.hlsl
index ec015d8..221f4df 100644
--- a/test/tint/var/inferred/global.wgsl.expected.dxc.hlsl
+++ b/test/tint/var/inferred/global.wgsl.expected.dxc.hlsl
@@ -8,7 +8,7 @@
static int3 v4 = (1).xxx;
static uint3 v5 = uint3(1u, 2u, 3u);
static float3 v6 = float3(1.0f, 2.0f, 3.0f);
-const MyStruct c = {1.0f};
+static const MyStruct c = {1.0f};
static MyStruct v7 = c;
static float v8[10] = (float[10])0;
static int v9 = 0;
diff --git a/test/tint/var/inferred/global.wgsl.expected.fxc.hlsl b/test/tint/var/inferred/global.wgsl.expected.fxc.hlsl
index ec015d8..221f4df 100644
--- a/test/tint/var/inferred/global.wgsl.expected.fxc.hlsl
+++ b/test/tint/var/inferred/global.wgsl.expected.fxc.hlsl
@@ -8,7 +8,7 @@
static int3 v4 = (1).xxx;
static uint3 v5 = uint3(1u, 2u, 3u);
static float3 v6 = float3(1.0f, 2.0f, 3.0f);
-const MyStruct c = {1.0f};
+static const MyStruct c = {1.0f};
static MyStruct v7 = c;
static float v8[10] = (float[10])0;
static int v9 = 0;