tint/transform: Fix NPE in ZeroInitWorkgroupMemory.

If an array uses an override expression, then we'd raise an error, but then attempt to dereference a nullptr.

Bug: chromium:1392853
Change-Id: Ib1d538bc491923b628b32f2398f8b2ace24c3bc3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112561
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc
index ed3584e9..2f3fd94 100644
--- a/src/tint/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/transform/zero_init_workgroup_memory.cc
@@ -100,6 +100,9 @@
         uint32_t num_iterations = 0;
         /// All array indices used by this expression
         ArrayIndices array_indices;
+
+        /// @returns true if the expr is not null (null usually indicates a failure)
+        operator bool() const { return expr != nullptr; }
     };
 
     /// Statement holds information about a statement that will zero workgroup
@@ -137,10 +140,13 @@
         auto* func = sem.Get(fn);
         for (auto* var : func->TransitivelyReferencedGlobals()) {
             if (var->AddressSpace() == ast::AddressSpace::kWorkgroup) {
-                BuildZeroingStatements(var->Type()->UnwrapRef(), [&](uint32_t num_values) {
+                auto get_expr = [&](uint32_t num_values) {
                     auto var_name = ctx.Clone(var->Declaration()->symbol);
                     return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
-                });
+                };
+                if (!BuildZeroingStatements(var->Type()->UnwrapRef(), get_expr)) {
+                    return;
+                }
             }
         }
 
@@ -283,41 +289,54 @@
     /// initialize the workgroup storage expression of type `ty`.
     /// @param ty the expression type
     /// @param get_expr a function that builds the AST nodes for the expression.
-    void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) {
+    /// @returns true on success, false on failure
+    [[nodiscard]] bool BuildZeroingStatements(const sem::Type* ty,
+                                              const BuildZeroingExpr& get_expr) {
         if (CanTriviallyZero(ty)) {
             auto var = get_expr(1u);
+            if (!var) {
+                return false;
+            }
             auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
             statements.emplace_back(
                 Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
-            return;
+            return true;
         }
 
         if (auto* atomic = ty->As<sem::Atomic>()) {
             auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
             auto expr = get_expr(1u);
+            if (!expr) {
+                return false;
+            }
             auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
             statements.emplace_back(
                 Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
-            return;
+            return true;
         }
 
         if (auto* str = ty->As<sem::Struct>()) {
             for (auto* member : str->Members()) {
                 auto name = ctx.Clone(member->Declaration()->symbol);
-                BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
+                auto get_member = [&](uint32_t num_values) {
                     auto s = get_expr(num_values);
+                    if (!s) {
+                        return Expression{};  // error
+                    }
                     return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
                                       s.array_indices};
-                });
+                };
+                if (!BuildZeroingStatements(member->Type(), get_member)) {
+                    return false;
+                }
             }
-            return;
+            return true;
         }
 
         if (auto* arr = ty->As<sem::Array>()) {
-            BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
+            auto get_el = [&](uint32_t num_values) {
                 // num_values is the number of values to zero for the element type.
-                // The number of iterations required to zero the array and its elements
-                // is:
+                // The number of iterations required to zero the array and its elements is:
                 //      `num_values * arr->Count()`
                 // The index for this array is:
                 //      `(idx % modulo) / division`
@@ -325,22 +344,26 @@
                 if (!count) {
                     ctx.dst->Diagnostics().add_error(diag::System::Transform,
                                                      sem::Array::kErrExpectedConstantCount);
-                    return Expression{};
+                    return Expression{};  // error
                 }
                 auto modulo = num_values * count.value();
                 auto division = num_values;
                 auto a = get_expr(modulo);
+                if (!a) {
+                    return Expression{};  // error
+                }
                 auto array_indices = a.array_indices;
                 array_indices.Add(ArrayIndex{modulo, division});
                 auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
                                                 [&] { return b.Symbols().New("i"); });
                 return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
-            });
-            return;
+            };
+            return BuildZeroingStatements(arr->ElemType(), get_el);
         }
 
         TINT_UNREACHABLE(Transform, b.Diagnostics())
             << "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols());
+        return false;
     }
 
     /// DeclareArrayIndices returns a list of statements that contain the `let`
diff --git a/src/tint/transform/zero_init_workgroup_memory_test.cc b/src/tint/transform/zero_init_workgroup_memory_test.cc
index 93f3933..c067f1b 100644
--- a/src/tint/transform/zero_init_workgroup_memory_test.cc
+++ b/src/tint/transform/zero_init_workgroup_memory_test.cc
@@ -1363,5 +1363,28 @@
     EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(ZeroInitWorkgroupMemoryTest, ArrayWithOverrideCount) {
+    auto* src =
+        R"(override O = 123;
+type A = array<i32, O*2>;
+
+var<workgroup> W : A;
+
+@compute @workgroup_size(1)
+fn main() {
+    let p : ptr<workgroup, A> = &W;
+    (*p)[0] = 42;
+}
+)";
+
+    auto* expect =
+        R"(error: array size is an override-expression, when expected a constant-expression.
+Was the SubstituteOverride transform run?)";
+
+    auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
 }  // namespace
 }  // namespace tint::transform