tint/transform: Handle arrays of complex override lengths

Update CreateASTTypeFor() to handle a potential edge-case described in tint:1764.

We haven't seen this issue happen in production, nor can I find a way to trigger this with the tint executable, but try to handle this before we encounter a nasty bug.

Fixed: tint:1764
Change-Id: I496932955a6fdcbe26eacef8dcd04988f92545a1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111040
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index c37f3b4..5c8357c 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -114,6 +114,19 @@
             return ctx.dst->ty.array(el, count, std::move(attrs));
         }
         if (auto* override = std::get_if<sem::UnnamedOverrideArrayCount>(&a->Count())) {
+            // If the array count is an unnamed (complex) override expression, then its not safe to
+            // redeclare this type as we'd end up with two types that would not compare equal.
+            // See crbug.com/tint/1764.
+            // Look for a type alias for this array.
+            for (auto* type_decl : ctx.src->AST().TypeDecls()) {
+                if (auto* alias = type_decl->As<ast::Alias>()) {
+                    if (ty == ctx.src->Sem().Get(alias)) {
+                        // Alias found. Use the alias name to ensure types compare equal.
+                        return ctx.dst->ty.type_name(ctx.Clone(alias->name));
+                    }
+                }
+            }
+            // Array is not aliased. Rebuild the array.
             auto* count = ctx.Clone(override->expr->Declaration());
             return ctx.dst->ty.array(el, count, std::move(attrs));
         }
diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc
index 82fdf6a..4b8ad53 100644
--- a/src/tint/transform/transform_test.cc
+++ b/src/tint/transform/transform_test.cc
@@ -21,6 +21,8 @@
 namespace tint::transform {
 namespace {
 
+using namespace tint::number_suffixes;  // NOLINT
+
 // Inherit from Transform so we have access to protected methods
 struct CreateASTTypeForTest : public testing::Test, public Transform {
     ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
@@ -95,6 +97,28 @@
     EXPECT_EQ(size->value, 2);
 }
 
+// crbug.com/tint/1764
+TEST_F(CreateASTTypeForTest, AliasedArrayWithComplexOverrideLength) {
+    // override O = 123;
+    // type A = array<i32, O*2>;
+    //
+    // var<workgroup> W : A;
+    //
+    ProgramBuilder b;
+    auto* arr_len = b.Mul("O", 2_a);
+    b.Override("O", b.Expr(123_a));
+    auto* alias = b.Alias("A", b.ty.array(b.ty.i32(), arr_len));
+
+    Program program(std::move(b));
+
+    auto* arr_ty = program.Sem().Get(alias);
+
+    CloneContext ctx(&ast_type_builder, &program, false);
+    auto* ast_ty = tint::As<ast::TypeName>(CreateASTTypeFor(ctx, arr_ty));
+    ASSERT_NE(ast_ty, nullptr);
+    EXPECT_EQ(ast_type_builder.Symbols().NameFor(ast_ty->name), "A");
+}
+
 TEST_F(CreateASTTypeForTest, Struct) {
     auto* str = create([](ProgramBuilder& b) {
         auto* decl = b.Structure("S", {});
diff --git a/test/tint/bug/tint/1764.wgsl b/test/tint/bug/tint/1764.wgsl
new file mode 100644
index 0000000..0b9541b
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl
@@ -0,0 +1,12 @@
+// flags: --transform substitute_override
+
+override O = 123;
+type A = array<i32, O*2>;
+
+var<workgroup> W : A;
+
+@compute @workgroup_size(1)
+fn main() {
+    let p : ptr<workgroup, A> = &W;
+    (*p)[0] = 42;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1764.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..e915d81
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.dxc.hlsl
@@ -0,0 +1,22 @@
+groupshared int W[246];
+
+struct tint_symbol_1 {
+  uint local_invocation_index : SV_GroupIndex;
+};
+
+void main_inner(uint local_invocation_index) {
+  {
+    for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+      const uint i = idx;
+      W[i] = 0;
+    }
+  }
+  GroupMemoryBarrierWithGroupSync();
+  W[0] = 42;
+}
+
+[numthreads(1, 1, 1)]
+void main(tint_symbol_1 tint_symbol) {
+  main_inner(tint_symbol.local_invocation_index);
+  return;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1764.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..e915d81
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.fxc.hlsl
@@ -0,0 +1,22 @@
+groupshared int W[246];
+
+struct tint_symbol_1 {
+  uint local_invocation_index : SV_GroupIndex;
+};
+
+void main_inner(uint local_invocation_index) {
+  {
+    for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+      const uint i = idx;
+      W[i] = 0;
+    }
+  }
+  GroupMemoryBarrierWithGroupSync();
+  W[0] = 42;
+}
+
+[numthreads(1, 1, 1)]
+void main(tint_symbol_1 tint_symbol) {
+  main_inner(tint_symbol.local_invocation_index);
+  return;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.glsl b/test/tint/bug/tint/1764.wgsl.expected.glsl
new file mode 100644
index 0000000..b5175cb
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.glsl
@@ -0,0 +1,19 @@
+#version 310 es
+
+shared int W[246];
+void tint_symbol(uint local_invocation_index) {
+  {
+    for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+      uint i = idx;
+      W[i] = 0;
+    }
+  }
+  barrier();
+  W[0] = 42;
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void main() {
+  tint_symbol(gl_LocalInvocationIndex);
+  return;
+}
diff --git a/test/tint/bug/tint/1764.wgsl.expected.msl b/test/tint/bug/tint/1764.wgsl.expected.msl
new file mode 100644
index 0000000..06b186f
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.msl
@@ -0,0 +1,31 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+template<typename T, size_t N>
+struct tint_array {
+    const constant T& operator[](size_t i) const constant { return elements[i]; }
+    device T& operator[](size_t i) device { return elements[i]; }
+    const device T& operator[](size_t i) const device { return elements[i]; }
+    thread T& operator[](size_t i) thread { return elements[i]; }
+    const thread T& operator[](size_t i) const thread { return elements[i]; }
+    threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+    const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+    T elements[N];
+};
+
+void tint_symbol_inner(uint local_invocation_index, threadgroup tint_array<int, 246>* const tint_symbol_1) {
+  for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
+    uint const i = idx;
+    (*(tint_symbol_1))[i] = 0;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  (*(tint_symbol_1))[0] = 42;
+}
+
+kernel void tint_symbol(uint local_invocation_index [[thread_index_in_threadgroup]]) {
+  threadgroup tint_array<int, 246> tint_symbol_2;
+  tint_symbol_inner(local_invocation_index, &(tint_symbol_2));
+  return;
+}
+
diff --git a/test/tint/bug/tint/1764.wgsl.expected.spvasm b/test/tint/bug/tint/1764.wgsl.expected.spvasm
new file mode 100644
index 0000000..aeefee4
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.spvasm
@@ -0,0 +1,76 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 44
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %local_invocation_index_1
+               OpExecutionMode %main LocalSize 1 1 1
+               OpName %local_invocation_index_1 "local_invocation_index_1"
+               OpName %W "W"
+               OpName %main_inner "main_inner"
+               OpName %local_invocation_index "local_invocation_index"
+               OpName %idx "idx"
+               OpName %main "main"
+               OpDecorate %local_invocation_index_1 BuiltIn LocalInvocationIndex
+               OpDecorate %_arr_int_uint_246 ArrayStride 4
+       %uint = OpTypeInt 32 0
+%_ptr_Input_uint = OpTypePointer Input %uint
+%local_invocation_index_1 = OpVariable %_ptr_Input_uint Input
+        %int = OpTypeInt 32 1
+   %uint_246 = OpConstant %uint 246
+%_arr_int_uint_246 = OpTypeArray %int %uint_246
+%_ptr_Workgroup__arr_int_uint_246 = OpTypePointer Workgroup %_arr_int_uint_246
+          %W = OpVariable %_ptr_Workgroup__arr_int_uint_246 Workgroup
+       %void = OpTypeVoid
+          %9 = OpTypeFunction %void %uint
+%_ptr_Function_uint = OpTypePointer Function %uint
+         %16 = OpConstantNull %uint
+       %bool = OpTypeBool
+%_ptr_Workgroup_int = OpTypePointer Workgroup %int
+         %30 = OpConstantNull %int
+     %uint_1 = OpConstant %uint 1
+     %uint_2 = OpConstant %uint 2
+   %uint_264 = OpConstant %uint 264
+     %int_42 = OpConstant %int 42
+         %39 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %9
+%local_invocation_index = OpFunctionParameter %uint
+         %13 = OpLabel
+        %idx = OpVariable %_ptr_Function_uint Function %16
+               OpStore %idx %local_invocation_index
+               OpBranch %17
+         %17 = OpLabel
+               OpLoopMerge %18 %19 None
+               OpBranch %20
+         %20 = OpLabel
+         %22 = OpLoad %uint %idx
+         %23 = OpULessThan %bool %22 %uint_246
+         %21 = OpLogicalNot %bool %23
+               OpSelectionMerge %25 None
+               OpBranchConditional %21 %26 %25
+         %26 = OpLabel
+               OpBranch %18
+         %25 = OpLabel
+         %27 = OpLoad %uint %idx
+         %29 = OpAccessChain %_ptr_Workgroup_int %W %27
+               OpStore %29 %30
+               OpBranch %19
+         %19 = OpLabel
+         %31 = OpLoad %uint %idx
+         %33 = OpIAdd %uint %31 %uint_1
+               OpStore %idx %33
+               OpBranch %17
+         %18 = OpLabel
+               OpControlBarrier %uint_2 %uint_2 %uint_264
+         %37 = OpAccessChain %_ptr_Workgroup_int %W %30
+               OpStore %37 %int_42
+               OpReturn
+               OpFunctionEnd
+       %main = OpFunction %void None %39
+         %41 = OpLabel
+         %43 = OpLoad %uint %local_invocation_index_1
+         %42 = OpFunctionCall %void %main_inner %43
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/tint/1764.wgsl.expected.wgsl b/test/tint/bug/tint/1764.wgsl.expected.wgsl
new file mode 100644
index 0000000..dd336f9
--- /dev/null
+++ b/test/tint/bug/tint/1764.wgsl.expected.wgsl
@@ -0,0 +1,11 @@
+const O = 123;
+
+type A = array<i32, (O * 2)>;
+
+var<workgroup> W : A;
+
+@compute @workgroup_size(1)
+fn main() {
+  let p : ptr<workgroup, A> = &(W);
+  (*(p))[0] = 42;
+}