tint/transform: Fix NPE in ZeroInitWorkgroupMemory.
If an array uses an override expression, then we'd raise an error, but then attempt to dereference a nullptr.
Bug: chromium:1392853
Change-Id: Ib1d538bc491923b628b32f2398f8b2ace24c3bc3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112561
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc
index ed3584e9..2f3fd94 100644
--- a/src/tint/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/transform/zero_init_workgroup_memory.cc
@@ -100,6 +100,9 @@
uint32_t num_iterations = 0;
/// All array indices used by this expression
ArrayIndices array_indices;
+
+ /// @returns true if the expr is not null (null usually indicates a failure)
+ operator bool() const { return expr != nullptr; }
};
/// Statement holds information about a statement that will zero workgroup
@@ -137,10 +140,13 @@
auto* func = sem.Get(fn);
for (auto* var : func->TransitivelyReferencedGlobals()) {
if (var->AddressSpace() == ast::AddressSpace::kWorkgroup) {
- BuildZeroingStatements(var->Type()->UnwrapRef(), [&](uint32_t num_values) {
+ auto get_expr = [&](uint32_t num_values) {
auto var_name = ctx.Clone(var->Declaration()->symbol);
return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
- });
+ };
+ if (!BuildZeroingStatements(var->Type()->UnwrapRef(), get_expr)) {
+ return;
+ }
}
}
@@ -283,41 +289,54 @@
/// initialize the workgroup storage expression of type `ty`.
/// @param ty the expression type
/// @param get_expr a function that builds the AST nodes for the expression.
- void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) {
+ /// @returns true on success, false on failure
+ [[nodiscard]] bool BuildZeroingStatements(const sem::Type* ty,
+ const BuildZeroingExpr& get_expr) {
if (CanTriviallyZero(ty)) {
auto var = get_expr(1u);
+ if (!var) {
+ return false;
+ }
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
statements.emplace_back(
Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices});
- return;
+ return true;
}
if (auto* atomic = ty->As<sem::Atomic>()) {
auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
auto expr = get_expr(1u);
+ if (!expr) {
+ return false;
+ }
auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
statements.emplace_back(
Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices});
- return;
+ return true;
}
if (auto* str = ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto name = ctx.Clone(member->Declaration()->symbol);
- BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
+ auto get_member = [&](uint32_t num_values) {
auto s = get_expr(num_values);
+ if (!s) {
+ return Expression{}; // error
+ }
return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
s.array_indices};
- });
+ };
+ if (!BuildZeroingStatements(member->Type(), get_member)) {
+ return false;
+ }
}
- return;
+ return true;
}
if (auto* arr = ty->As<sem::Array>()) {
- BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
+ auto get_el = [&](uint32_t num_values) {
// num_values is the number of values to zero for the element type.
- // The number of iterations required to zero the array and its elements
- // is:
+ // The number of iterations required to zero the array and its elements is:
// `num_values * arr->Count()`
// The index for this array is:
// `(idx % modulo) / division`
@@ -325,22 +344,26 @@
if (!count) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
- return Expression{};
+ return Expression{}; // error
}
auto modulo = num_values * count.value();
auto division = num_values;
auto a = get_expr(modulo);
+ if (!a) {
+ return Expression{}; // error
+ }
auto array_indices = a.array_indices;
array_indices.Add(ArrayIndex{modulo, division});
auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
[&] { return b.Symbols().New("i"); });
return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
- });
- return;
+ };
+ return BuildZeroingStatements(arr->ElemType(), get_el);
}
TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols());
+ return false;
}
/// DeclareArrayIndices returns a list of statements that contain the `let`
diff --git a/src/tint/transform/zero_init_workgroup_memory_test.cc b/src/tint/transform/zero_init_workgroup_memory_test.cc
index 93f3933..c067f1b 100644
--- a/src/tint/transform/zero_init_workgroup_memory_test.cc
+++ b/src/tint/transform/zero_init_workgroup_memory_test.cc
@@ -1363,5 +1363,28 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(ZeroInitWorkgroupMemoryTest, ArrayWithOverrideCount) {
+ auto* src =
+ R"(override O = 123;
+type A = array<i32, O*2>;
+
+var<workgroup> W : A;
+
+@compute @workgroup_size(1)
+fn main() {
+ let p : ptr<workgroup, A> = &W;
+ (*p)[0] = 42;
+}
+)";
+
+ auto* expect =
+ R"(error: array size is an override-expression, when expected a constant-expression.
+Was the SubstituteOverride transform run?)";
+
+ auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::transform