tint/transform: Handle arrays of complex override lengths
Update CreateASTTypeFor() to handle a potential edge-case described in tint:1764.
We haven't seen this issue happen in production, nor can I find a way to trigger this with the tint executable, but try to handle this before we encounter a nasty bug.
Fixed: tint:1764
Change-Id: I496932955a6fdcbe26eacef8dcd04988f92545a1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111040
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index c37f3b4..5c8357c 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -114,6 +114,19 @@
return ctx.dst->ty.array(el, count, std::move(attrs));
}
if (auto* override = std::get_if<sem::UnnamedOverrideArrayCount>(&a->Count())) {
+ // If the array count is an unnamed (complex) override expression, then its not safe to
+ // redeclare this type as we'd end up with two types that would not compare equal.
+ // See crbug.com/tint/1764.
+ // Look for a type alias for this array.
+ for (auto* type_decl : ctx.src->AST().TypeDecls()) {
+ if (auto* alias = type_decl->As<ast::Alias>()) {
+ if (ty == ctx.src->Sem().Get(alias)) {
+ // Alias found. Use the alias name to ensure types compare equal.
+ return ctx.dst->ty.type_name(ctx.Clone(alias->name));
+ }
+ }
+ }
+ // Array is not aliased. Rebuild the array.
auto* count = ctx.Clone(override->expr->Declaration());
return ctx.dst->ty.array(el, count, std::move(attrs));
}
diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc
index 82fdf6a..4b8ad53 100644
--- a/src/tint/transform/transform_test.cc
+++ b/src/tint/transform/transform_test.cc
@@ -21,6 +21,8 @@
namespace tint::transform {
namespace {
+using namespace tint::number_suffixes; // NOLINT
+
// Inherit from Transform so we have access to protected methods
struct CreateASTTypeForTest : public testing::Test, public Transform {
ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
@@ -95,6 +97,28 @@
EXPECT_EQ(size->value, 2);
}
+// crbug.com/tint/1764
+TEST_F(CreateASTTypeForTest, AliasedArrayWithComplexOverrideLength) {
+ // override O = 123;
+ // type A = array<i32, O*2>;
+ //
+ // var<workgroup> W : A;
+ //
+ ProgramBuilder b;
+ auto* arr_len = b.Mul("O", 2_a);
+ b.Override("O", b.Expr(123_a));
+ auto* alias = b.Alias("A", b.ty.array(b.ty.i32(), arr_len));
+
+ Program program(std::move(b));
+
+ auto* arr_ty = program.Sem().Get(alias);
+
+ CloneContext ctx(&ast_type_builder, &program, false);
+ auto* ast_ty = tint::As<ast::TypeName>(CreateASTTypeFor(ctx, arr_ty));
+ ASSERT_NE(ast_ty, nullptr);
+ EXPECT_EQ(ast_type_builder.Symbols().NameFor(ast_ty->name), "A");
+}
+
TEST_F(CreateASTTypeForTest, Struct) {
auto* str = create([](ProgramBuilder& b) {
auto* decl = b.Structure("S", {});
diff --git a/test/tint/bug/tint/1764.wgsl b/test/tint/bug/tint/1764.wgsl
new file mode 100644
index 0000000..0b9541b
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl
@@ -0,0 +1,12 @@
+// flags: --transform substitute_override
+
+override O = 123;
+type A = array<i32, O*2>;
+
+var<workgroup> W : A;
+
+@compute @workgroup_size(1)
+fn main() {
+ let p : ptr<workgroup, A> = &W;
+ (*p)[0] = 42;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1764.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..e915d81
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.dxc.hlsl
@@ -0,0 +1,22 @@
+groupshared int W[246];
+
+struct tint_symbol_1 {
+ uint local_invocation_index : SV_GroupIndex;
+};
+
+void main_inner(uint local_invocation_index) {
+ {
+ for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+ const uint i = idx;
+ W[i] = 0;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ W[0] = 42;
+}
+
+[numthreads(1, 1, 1)]
+void main(tint_symbol_1 tint_symbol) {
+ main_inner(tint_symbol.local_invocation_index);
+ return;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1764.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..e915d81
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.fxc.hlsl
@@ -0,0 +1,22 @@
+groupshared int W[246];
+
+struct tint_symbol_1 {
+ uint local_invocation_index : SV_GroupIndex;
+};
+
+void main_inner(uint local_invocation_index) {
+ {
+ for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+ const uint i = idx;
+ W[i] = 0;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ W[0] = 42;
+}
+
+[numthreads(1, 1, 1)]
+void main(tint_symbol_1 tint_symbol) {
+ main_inner(tint_symbol.local_invocation_index);
+ return;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.glsl b/test/tint/bug/tint/1764.wgsl.expected.glsl
new file mode 100644
index 0000000..b5175cb
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.glsl
@@ -0,0 +1,19 @@
+#version 310 es
+
+shared int W[246];
+void tint_symbol(uint local_invocation_index) {
+ {
+ for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+ uint i = idx;
+ W[i] = 0;
+ }
+ }
+ barrier();
+ W[0] = 42;
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void main() {
+ tint_symbol(gl_LocalInvocationIndex);
+ return;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.msl b/test/tint/bug/tint/1764.wgsl.expected.msl
new file mode 100644
index 0000000..06b186f
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.msl
@@ -0,0 +1,31 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
+
+void tint_symbol_inner(uint local_invocation_index, threadgroup tint_array<int, 246>* const tint_symbol_1) {
+ for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+ uint const i = idx;
+ (*(tint_symbol_1))[i] = 0;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ (*(tint_symbol_1))[0] = 42;
+}
+
+kernel void tint_symbol(uint local_invocation_index [[thread_index_in_threadgroup]]) {
+ threadgroup tint_array<int, 246> tint_symbol_2;
+ tint_symbol_inner(local_invocation_index, &(tint_symbol_2));
+ return;
+}
+
diff --git a/test/tint/bug/tint/1764.wgsl.expected.spvasm b/test/tint/bug/tint/1764.wgsl.expected.spvasm
new file mode 100644
index 0000000..aeefee4
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.spvasm
@@ -0,0 +1,76 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 44
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main" %local_invocation_index_1
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %local_invocation_index_1 "local_invocation_index_1"
+ OpName %W "W"
+ OpName %main_inner "main_inner"
+ OpName %local_invocation_index "local_invocation_index"
+ OpName %idx "idx"
+ OpName %main "main"
+ OpDecorate %local_invocation_index_1 BuiltIn LocalInvocationIndex
+ OpDecorate %_arr_int_uint_246 ArrayStride 4
+ %uint = OpTypeInt 32 0
+%_ptr_Input_uint = OpTypePointer Input %uint
+%local_invocation_index_1 = OpVariable %_ptr_Input_uint Input
+ %int = OpTypeInt 32 1
+ %uint_246 = OpConstant %uint 246
+%_arr_int_uint_246 = OpTypeArray %int %uint_246
+%_ptr_Workgroup__arr_int_uint_246 = OpTypePointer Workgroup %_arr_int_uint_246
+ %W = OpVariable %_ptr_Workgroup__arr_int_uint_246 Workgroup
+ %void = OpTypeVoid
+ %9 = OpTypeFunction %void %uint
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %16 = OpConstantNull %uint
+ %bool = OpTypeBool
+%_ptr_Workgroup_int = OpTypePointer Workgroup %int
+ %30 = OpConstantNull %int
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %uint_264 = OpConstant %uint 264
+ %int_42 = OpConstant %int 42
+ %39 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %9
+%local_invocation_index = OpFunctionParameter %uint
+ %13 = OpLabel
+ %idx = OpVariable %_ptr_Function_uint Function %16
+ OpStore %idx %local_invocation_index
+ OpBranch %17
+ %17 = OpLabel
+ OpLoopMerge %18 %19 None
+ OpBranch %20
+ %20 = OpLabel
+ %22 = OpLoad %uint %idx
+ %23 = OpULessThan %bool %22 %uint_246
+ %21 = OpLogicalNot %bool %23
+ OpSelectionMerge %25 None
+ OpBranchConditional %21 %26 %25
+ %26 = OpLabel
+ OpBranch %18
+ %25 = OpLabel
+ %27 = OpLoad %uint %idx
+ %29 = OpAccessChain %_ptr_Workgroup_int %W %27
+ OpStore %29 %30
+ OpBranch %19
+ %19 = OpLabel
+ %31 = OpLoad %uint %idx
+ %33 = OpIAdd %uint %31 %uint_1
+ OpStore %idx %33
+ OpBranch %17
+ %18 = OpLabel
+ OpControlBarrier %uint_2 %uint_2 %uint_264
+ %37 = OpAccessChain %_ptr_Workgroup_int %W %30
+ OpStore %37 %int_42
+ OpReturn
+ OpFunctionEnd
+ %main = OpFunction %void None %39
+ %41 = OpLabel
+ %43 = OpLoad %uint %local_invocation_index_1
+ %42 = OpFunctionCall %void %main_inner %43
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/bug/tint/1764.wgsl.expected.wgsl b/test/tint/bug/tint/1764.wgsl.expected.wgsl
new file mode 100644
index 0000000..dd336f9
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.wgsl
@@ -0,0 +1,11 @@
+const O = 123;
+
+type A = array<i32, (O * 2)>;
+
+var<workgroup> W : A;
+
+@compute @workgroup_size(1)
+fn main() {
+ let p : ptr<workgroup, A> = &(W);
+ (*(p))[0] = 42;
+}