Fix overrides in array size.
This CL fixes the usage of overrides in array sizes. Currently
the usage will generate a validation error as we check that the
array size is const.
Bug: tint:1660
Change-Id: Ibf440905c30a73b581d55b0c071b8621b61605e6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101900
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: dan sinclair <dsinclair@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 8208236..c8c6547 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -394,8 +394,10 @@
return builder.create<Splat>(type, zero_el, m->columns());
},
[&](const sem::Array* a) -> const ImplConstant* {
- if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
- return builder.create<Splat>(type, zero_el, a->Count());
+ if (auto n = a->ConstantCount()) {
+ if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
+ return builder.create<Splat>(type, zero_el, n.value());
+ }
}
return nullptr;
},
@@ -451,12 +453,16 @@
return true;
},
[&](const sem::Array* arr) {
- for (size_t i = 0; i < arr->Count(); i++) {
- if (!Equal(a->Index(i), b->Index(i))) {
- return false;
+ if (auto count = arr->ConstantCount()) {
+ for (size_t i = 0; i < count; i++) {
+ if (!Equal(a->Index(i), b->Index(i))) {
+ return false;
+ }
}
+ return true;
}
- return true;
+
+ return false;
},
[&](Default) { return a->Value() == b->Value(); });
}
diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc
index 3b1d7ee..f36e374 100644
--- a/src/tint/resolver/const_eval_test.cc
+++ b/src/tint/resolver/const_eval_test.cc
@@ -1700,7 +1700,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
- EXPECT_EQ(arr->Count(), 4u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1738,7 +1738,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
- EXPECT_EQ(arr->Count(), 4u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1776,7 +1776,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
- EXPECT_EQ(arr->Count(), 2u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1828,7 +1828,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
- EXPECT_EQ(arr->Count(), 2u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1866,7 +1866,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
- EXPECT_EQ(arr->Count(), 4u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@@ -1904,7 +1904,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
- EXPECT_EQ(arr->Count(), 4u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@@ -1943,7 +1943,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
- EXPECT_EQ(arr->Count(), 2u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@@ -1973,7 +1973,7 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
- EXPECT_EQ(arr->Count(), 2u);
+ EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index cf87296..5d88258 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -877,7 +877,7 @@
Member("m", ty.atomic(ty.i32())),
});
auto* ret_type = ty.type_name(Source{{12, 34}}, "S");
- auto* bar = Param(Source{{12, 34}}, "bar", ret_type);
+ auto* bar = Param("bar", ret_type);
Func("f", utils::Vector{bar}, ty.void_(), utils::Empty);
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc
index 047ac78..2521959 100644
--- a/src/tint/resolver/inferred_type_test.cc
+++ b/src/tint/resolver/inferred_type_test.cc
@@ -135,7 +135,8 @@
TEST_F(ResolverInferredTypeTest, InferArray_Pass) {
auto* type = ty.array(ty.u32(), 10_u);
- auto* expected_type = create<sem::Array>(create<sem::U32>(), 10u, 4u, 4u * 10u, 4u, 4u);
+ auto* expected_type =
+ create<sem::Array>(create<sem::U32>(), sem::ConstantArrayCount{10u}, 4u, 4u * 10u, 4u, 4u);
auto* ctor_expr = Construct(type);
auto* var = Var("a", ast::StorageClass::kFunction, ctor_expr);
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index 4db3cb4..8533940 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -513,7 +513,7 @@
}
if (auto* a = ty->As<sem::Array>()) {
- if (a->Count() == 0) {
+ if (a->IsRuntimeSized()) {
T = a->ElemType();
return true;
}
@@ -523,7 +523,7 @@
const sem::Array* build_array(MatchState& state, const sem::Type* el) {
return state.builder.create<sem::Array>(el,
- /* count */ 0u,
+ /* count */ sem::RuntimeArrayCount{},
/* align */ 0u,
/* size */ 0u,
/* stride */ 0u,
diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc
index 0ffc011..b8f5629 100644
--- a/src/tint/resolver/intrinsic_table_test.cc
+++ b/src/tint/resolver/intrinsic_table_test.cc
@@ -235,7 +235,7 @@
}
TEST_F(IntrinsicTableTest, MatchArray) {
- auto* arr = create<sem::Array>(create<sem::U32>(), 0u, 4u, 4u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr_ptr = create<sem::Pointer>(arr, ast::StorageClass::kStorage, ast::Access::kReadWrite);
auto result = table->Lookup(BuiltinType::kArrayLength, utils::Vector{arr_ptr}, Source{});
ASSERT_NE(result.sem, nullptr) << Diagnostics().str();
@@ -798,7 +798,7 @@
}
TEST_F(IntrinsicTableTest, MismatchTypeConversion) {
- auto* arr = create<sem::Array>(create<sem::U32>(), 0u, 4u, 4u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* f32 = create<sem::F32>();
auto result =
table->Lookup(CtorConvIntrinsic::kVec3, f32, utils::Vector{arr}, Source{{12, 34}});
diff --git a/src/tint/resolver/is_host_shareable_test.cc b/src/tint/resolver/is_host_shareable_test.cc
index 0ac3352..04f9b89 100644
--- a/src/tint/resolver/is_host_shareable_test.cc
+++ b/src/tint/resolver/is_host_shareable_test.cc
@@ -106,12 +106,13 @@
}
TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
diff --git a/src/tint/resolver/is_storeable_test.cc b/src/tint/resolver/is_storeable_test.cc
index 0423f72..3662f38 100644
--- a/src/tint/resolver/is_storeable_test.cc
+++ b/src/tint/resolver/is_storeable_test.cc
@@ -89,12 +89,13 @@
}
TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}
TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index af22c53..dd200bc 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -118,7 +118,9 @@
}
},
[&](const sem::Array* a) {
- for (uint32_t i = 0; i < a->Count(); i++) {
+ auto count = a->ConstantCount();
+ ASSERT_NE(count, 0u);
+ for (uint32_t i = 0; i < count; i++) {
auto* el = value->Index(i);
ASSERT_NE(el, nullptr);
EXPECT_TYPE(el->Type(), a->ElemType());
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index c1920d8..a7e917d 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1489,7 +1489,7 @@
target_el_ty = target_arr_ty->ElemType();
}
if (auto* el_ty = ConcreteType(a->ElemType(), target_el_ty, source)) {
- return Array(source, el_ty, a->Count(), /* explicit_stride */ 0);
+ return Array(source, source, el_ty, a->Count(), /* explicit_stride */ 0);
}
return nullptr;
});
@@ -1879,7 +1879,8 @@
[&](const ast::Array* a) -> sem::Call* {
Mark(a);
// array element type must be inferred if it was not specified.
- auto el_count = static_cast<uint32_t>(args.Length());
+ sem::ArrayCount el_count =
+ sem::ConstantArrayCount{static_cast<uint32_t>(args.Length())};
const sem::Type* el_ty = nullptr;
if (a->type) {
el_ty = Type(a->type);
@@ -1921,7 +1922,9 @@
return nullptr;
}
- auto* arr = Array(a->source, el_ty, el_count, explicit_stride);
+ auto* arr = Array(a->type ? a->type->source : a->source,
+ a->count ? a->count->source : a->source, //
+ el_ty, el_count, explicit_stride);
if (!arr) {
return nullptr;
}
@@ -2591,7 +2594,7 @@
return nullptr;
}
- uint32_t el_count = 0; // sem::Array uses a size of 0 for a runtime-sized array.
+ sem::ArrayCount el_count = sem::RuntimeArrayCount{};
// Evaluate the constant array size expression.
if (auto* count_expr = arr->count) {
@@ -2602,7 +2605,9 @@
}
}
- auto* out = Array(arr->source, el_ty, el_count, explicit_stride);
+ auto* out = Array(arr->type->source, //
+ arr->count ? arr->count->source : arr->source, //
+ el_ty, el_count, explicit_stride);
if (out == nullptr) {
return nullptr;
}
@@ -2619,16 +2624,27 @@
return out;
}
-utils::Result<uint32_t> Resolver::ArrayCount(const ast::Expression* count_expr) {
+utils::Result<sem::ArrayCount> Resolver::ArrayCount(const ast::Expression* count_expr) {
// Evaluate the constant array size expression.
const auto* count_sem = Materialize(Expression(count_expr));
if (!count_sem) {
return utils::Failure;
}
+ // Note: If the array count is an 'override', but not a identifier expression, we do not return
+ // here, but instead continue to the ConstantValue() check below.
+ if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
+ if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
+ if (global->Declaration()->Is<ast::Override>()) {
+ return sem::ArrayCount{sem::OverrideArrayCount{global}};
+ }
+ }
+ }
+
auto* count_val = count_sem->ConstantValue();
if (!count_val) {
- AddError("array size must evaluate to a constant integer expression", count_expr->source);
+ AddError("array size must evaluate to a constant integer expression or override variable",
+ count_expr->source);
return utils::Failure;
}
@@ -2646,7 +2662,7 @@
return utils::Failure;
}
- return static_cast<uint32_t>(count);
+ return sem::ArrayCount{sem::ConstantArrayCount{static_cast<uint32_t>(count)}};
}
bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attributes,
@@ -2673,27 +2689,33 @@
return true;
}
-sem::Array* Resolver::Array(const Source& source,
+sem::Array* Resolver::Array(const Source& el_source,
+ const Source& count_source,
const sem::Type* el_ty,
- uint32_t el_count,
+ sem::ArrayCount el_count,
uint32_t explicit_stride) {
uint32_t el_align = el_ty->Align();
uint32_t el_size = el_ty->Size();
uint64_t implicit_stride = el_size ? utils::RoundUp<uint64_t>(el_align, el_size) : 0;
uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
+ uint64_t size = 0;
- auto size = std::max<uint64_t>(el_count, 1u) * stride;
- if (size > std::numeric_limits<uint32_t>::max()) {
- std::stringstream msg;
- msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes";
- AddError(msg.str(), source);
- return nullptr;
+ if (auto const_count = std::get_if<sem::ConstantArrayCount>(&el_count)) {
+ size = const_count->value * stride;
+ if (size > std::numeric_limits<uint32_t>::max()) {
+ std::stringstream msg;
+ msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes";
+ AddError(msg.str(), count_source);
+ return nullptr;
+ }
+ } else if (std::holds_alternative<sem::RuntimeArrayCount>(el_count)) {
+ size = stride;
}
auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size),
static_cast<uint32_t>(stride),
static_cast<uint32_t>(implicit_stride));
- if (!validator_.Array(out, source)) {
+ if (!validator_.Array(out, el_source)) {
return nullptr;
}
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 4c61c47..7e115b3 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -302,7 +302,7 @@
/// Resolves and validates the expression used as the count parameter of an array.
/// @param count_expr the expression used as the second template parameter to an array<>.
/// @returns the number of elements in the array.
- utils::Result<uint32_t> ArrayCount(const ast::Expression* count_expr);
+ utils::Result<sem::ArrayCount> ArrayCount(const ast::Expression* count_expr);
/// Resolves and validates the attributes on an array.
/// @param attributes the attributes on the array type.
@@ -315,13 +315,17 @@
/// Builds and returns the semantic information for an array.
/// @returns the semantic Array information, or nullptr if an error is raised.
- /// @param source the source of the array declaration
+ /// @param el_source the source of the array element, or the array if the array does not have a
+ /// locally-declared element AST node.
+ /// @param count_source the source of the array count, or the array if the array does not have a
+ /// locally-declared element AST node.
/// @param el_ty the Array element type
- /// @param el_count the number of elements in the array. Zero means runtime-sized.
+ /// @param el_count the number of elements in the array.
/// @param explicit_stride the explicit byte stride of the array. Zero means implicit stride.
- sem::Array* Array(const Source& source,
+ sem::Array* Array(const Source& el_source,
+ const Source& count_source,
const sem::Type* el_ty,
- uint32_t el_count,
+ sem::ArrayCount el_count,
uint32_t explicit_stride);
/// Builds and returns the semantic information for the alias `alias`.
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 79cc139..cde4735 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -432,7 +432,7 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), 10u);
+ EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_SignedLiteral) {
@@ -445,7 +445,7 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), 10u);
+ EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_UnsignedConst) {
@@ -460,7 +460,7 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), 10u);
+ EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_SignedConst) {
@@ -475,7 +475,51 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), 10u);
+ EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
+}
+
+TEST_F(ResolverTest, ArraySize_Override) {
+ // override size = 0;
+ // var<workgroup> a : array<f32, size>;
+ auto* override = Override("size", Expr(10_i));
+ auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(a), nullptr);
+ auto* ref = TypeOf(a)->As<sem::Reference>();
+ ASSERT_NE(ref, nullptr);
+ auto* ary = ref->StoreType()->As<sem::Array>();
+ auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+ ASSERT_NE(sem_override, nullptr);
+ EXPECT_EQ(ary->Count(), sem::OverrideArrayCount{sem_override});
+}
+
+TEST_F(ResolverTest, ArraySize_Override_Equivalence) {
+ // override size = 0;
+ // var<workgroup> a : array<f32, size>;
+ // var<workgroup> b : array<f32, size>;
+ auto* override = Override("size", Expr(10_i));
+ auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
+ auto* b = GlobalVar("b", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(a), nullptr);
+ auto* ref_a = TypeOf(a)->As<sem::Reference>();
+ ASSERT_NE(ref_a, nullptr);
+ auto* ary_a = ref_a->StoreType()->As<sem::Array>();
+
+ ASSERT_NE(TypeOf(b), nullptr);
+ auto* ref_b = TypeOf(b)->As<sem::Reference>();
+ ASSERT_NE(ref_b, nullptr);
+ auto* ary_b = ref_b->StoreType()->As<sem::Array>();
+
+ auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+ ASSERT_NE(sem_override, nullptr);
+ EXPECT_EQ(ary_a->Count(), sem::OverrideArrayCount{sem_override});
+ EXPECT_EQ(ary_b->Count(), sem::OverrideArrayCount{sem_override});
+ EXPECT_EQ(ary_a, ary_b);
}
TEST_F(ResolverTest, Expr_Bitcast) {
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index 3ad88c6..b22e14b 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -654,9 +654,13 @@
/// @return the semantic array type
static inline const sem::Type* Sem(ProgramBuilder& b) {
auto* el = DataType<T>::Sem(b);
+ sem::ArrayCount count = sem::ConstantArrayCount{N};
+ if (N == 0) {
+ count = sem::RuntimeArrayCount{};
+ }
return b.create<sem::Array>(
/* element */ el,
- /* count */ N,
+ /* count */ count,
/* align */ el->Align(),
/* size */ N * el->Size(),
/* stride */ el->Align(),
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
index a5f7f37..418cae0 100644
--- a/src/tint/resolver/type_validation_test.cc
+++ b/src/tint/resolver/type_validation_test.cc
@@ -310,7 +310,8 @@
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
// var<private> a : array<f32, 0x40000000u>;
- GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000_u), ast::StorageClass::kPrivate);
+ GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0x40000000_u)),
+ ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes");
@@ -318,21 +319,157 @@
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
// var<private> a : @stride(8) array<f32, 0x20000000u>;
- GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000_u, 8),
+ GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0x20000000_u), 8),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes");
}
-TEST_F(ResolverTypeValidationTest, ArraySize_Overridable) {
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_PrivateVar) {
// override size = 10i;
// var<private> a : array<f32, size>;
Override("size", Expr(10_i));
- GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")), ast::StorageClass::kPrivate);
+ GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), "size"), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: array size must evaluate to a constant integer expression");
+ "12:34 error: array with an 'override' element count can only be used as the store "
+ "type of a 'var<workgroup>'");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_ComplexExpr) {
+ // override size = 10i;
+ // var<workgroup> a : array<f32, size + 1>;
+ Override("size", Expr(10_i));
+ GlobalVar("a", ty.array(ty.f32(), Add(Source{{12, 34}}, "size", 1_i)),
+ ast::StorageClass::kWorkgroup);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array size must evaluate to a constant integer expression or override "
+ "variable");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_InArray) {
+ // override size = 10i;
+ // var<workgroup> a : array<array<f32, size>, 4>;
+ Override("size", Expr(10_i));
+ GlobalVar("a", ty.array(ty.array(Source{{12, 34}}, ty.f32(), "size"), 4_a),
+ ast::StorageClass::kWorkgroup);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array with an 'override' element count can only be used as the store "
+ "type of a 'var<workgroup>'");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_InStruct) {
+ // override size = 10i;
+ // struct S {
+ // a : array<f32, size>
+ // };
+ Override("size", Expr(10_i));
+ Structure("S", utils::Vector{Member("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array with an 'override' element count can only be used as the store "
+ "type of a 'var<workgroup>'");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Explicit) {
+ // override size = 10i;
+ // fn f() {
+ // var a : array<f32, size>;
+ // }
+ Override("size", Expr(10_i));
+ Func("f", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array with an 'override' element count can only be used as the store "
+ "type of a 'var<workgroup>'");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Explicit) {
+ // override size = 10i;
+ // fn f() {
+ // var a : array<f32, size>;
+ // }
+ Override("size", Expr(10_i));
+ Func("f", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array with an 'override' element count can only be used as the store "
+ "type of a 'var<workgroup>'");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Implicit) {
+ // override size = 10i;
+ // var<workgroup> w : array<f32, size>;
+ // fn f() {
+ // var a = w;
+ // }
+ Override("size", Expr(10_i));
+ GlobalVar("w", ty.array(ty.f32(), "size"), ast::StorageClass::kWorkgroup);
+ Func("f", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("a", Expr(Source{{12, 34}}, "w"))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array with an 'override' element count can only be used as the store "
+ "type of a 'var<workgroup>'");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Implicit) {
+ // override size = 10i;
+ // var<workgroup> w : array<f32, size>;
+ // fn f() {
+ // let a = w;
+ // }
+ Override("size", Expr(10_i));
+ GlobalVar("w", ty.array(ty.f32(), "size"), ast::StorageClass::kWorkgroup);
+ Func("f", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Let("a", Expr(Source{{12, 34}}, "w"))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array with an 'override' element count can only be used as the store "
+ "type of a 'var<workgroup>'");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_Param) {
+ // override size = 10i;
+ // fn f(a : array<f32, size>) {
+ // }
+ Override("size", Expr(10_i));
+ Func("f", utils::Vector{Param("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))}, ty.void_(),
+ utils::Empty);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: type of function parameter must be constructible");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Override_ReturnType) {
+ // override size = 10i;
+ // fn f() -> array<f32, size> {
+ // }
+ Override("size", Expr(10_i));
+ Func("f", utils::Empty, ty.array(Source{{12, 34}}, ty.f32(), "size"), utils::Empty);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: function return type must be a constructible type");
+}
+
+TEST_F(ResolverTypeValidationTest, ArraySize_Workgroup_Overridable) {
+ // override size = 10i;
+ // var<workgroup> a : array<f32, size>;
+ Override("size", Expr(10_i));
+ GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
+ ast::StorageClass::kWorkgroup);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeValidationTest, ArraySize_ModuleVar) {
@@ -367,7 +504,8 @@
WrapInFunction(size, a);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: array size must evaluate to a constant integer expression");
+ "12:34 error: array size must evaluate to a constant integer expression or override "
+ "variable");
}
TEST_F(ResolverTypeValidationTest, ArraySize_ComplexExpr) {
@@ -477,7 +615,7 @@
// };
Structure("Foo", utils::Vector{
- Member("rt", ty.array(Source{{12, 34}}, ty.array<f32>(), 4_u)),
+ Member("rt", ty.array(ty.array(Source{{12, 34}}, ty.f32()), 4_u)),
});
EXPECT_FALSE(r()->Resolve()) << r()->error();
@@ -491,10 +629,9 @@
// };
// var<private> a : array<Foo, 4>;
- auto* foo = Structure("Foo", utils::Vector{
- Member("rt", ty.array<f32>()),
- });
- GlobalVar("v", ty.array(Source{{12, 34}}, ty.Of(foo), 4_u), ast::StorageClass::kPrivate);
+ Structure("Foo", utils::Vector{Member("rt", ty.array<f32>())});
+ GlobalVar("v", ty.array(ty.type_name(Source{{12, 34}}, "Foo"), 4_u),
+ ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
@@ -636,8 +773,8 @@
}
TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) {
- auto* tex_ty = ty.sampled_texture(ast::TextureDimension::k2d, ty.f32());
- GlobalVar("arr", ty.array(Source{{12, 34}}, tex_ty, 4_i), ast::StorageClass::kPrivate);
+ auto* tex_ty = ty.sampled_texture(Source{{12, 34}}, ast::TextureDimension::k2d, ty.f32());
+ GlobalVar("arr", ty.array(tex_ty, 4_i), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 1812789..c1f52fb 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -548,23 +548,28 @@
return true;
}
-bool Validator::LocalVariable(const sem::Variable* v) const {
- auto* decl = v->Declaration();
+bool Validator::LocalVariable(const sem::Variable* local) const {
+ auto* decl = local->Declaration();
+ if (IsArrayWithOverrideCount(local->Type())) {
+ RaiseArrayWithOverrideCountError(decl->type ? decl->type->source
+ : decl->constructor->source);
+ return false;
+ }
return Switch(
decl, //
[&](const ast::Var* var) {
if (IsValidationEnabled(var->attributes,
ast::DisabledValidation::kIgnoreStorageClass)) {
- if (!v->Type()->UnwrapRef()->IsConstructible()) {
+ if (!local->Type()->UnwrapRef()->IsConstructible()) {
AddError("function-scope 'var' must have a constructible type",
var->type ? var->type->source : var->source);
return false;
}
}
- return Var(v);
- }, //
- [&](const ast::Let*) { return Let(v); }, //
- [&](const ast::Const*) { return true; }, //
+ return Var(local);
+ }, //
+ [&](const ast::Let*) { return Let(local); }, //
+ [&](const ast::Const*) { return true; }, //
[&](Default) {
TINT_ICE(Resolver, diagnostics_)
<< "Validator::Variable() called with a unknown variable type: "
@@ -578,6 +583,12 @@
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const {
auto* decl = global->Declaration();
+ if (global->StorageClass() != ast::StorageClass::kWorkgroup &&
+ IsArrayWithOverrideCount(global->Type())) {
+ RaiseArrayWithOverrideCountError(decl->type ? decl->type->source
+ : decl->constructor->source);
+ return false;
+ }
bool ok = Switch(
decl, //
[&](const ast::Var* var) {
@@ -875,7 +886,7 @@
if (IsPlain(var->Type())) {
if (!var->Type()->IsConstructible()) {
- AddError("type of function parameter must be constructible", decl->source);
+ AddError("type of function parameter must be constructible", decl->type->source);
return false;
}
} else if (!var->Type()->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) {
@@ -1898,20 +1909,23 @@
if (array_type->IsRuntimeSized()) {
AddError("cannot construct a runtime-sized array", ctor->source);
return false;
- } else if (!elem_ty->IsConstructible()) {
+ }
+
+ if (array_type->IsOverrideSized()) {
+ AddError("cannot construct an array that has an override expression count", ctor->source);
+ return false;
+ }
+
+ if (!elem_ty->IsConstructible()) {
AddError("array constructor has non-constructible element type", ctor->source);
return false;
- } else if (!values.IsEmpty() && (values.Length() != array_type->Count())) {
- std::string fm = values.Length() < array_type->Count() ? "few" : "many";
+ }
+
+ const auto count = std::get<sem::ConstantArrayCount>(array_type->Count()).value;
+ if (!values.IsEmpty() && (values.Length() != count)) {
+ std::string fm = values.Length() < count ? "few" : "many";
AddError("array constructor has too " + fm + " elements: expected " +
- std::to_string(array_type->Count()) + ", found " +
- std::to_string(values.Length()),
- ctor->source);
- return false;
- } else if (values.Length() > array_type->Count()) {
- AddError("array constructor has too many elements: expected " +
- std::to_string(array_type->Count()) + ", found " +
- std::to_string(values.Length()),
+ std::to_string(count) + ", found " + std::to_string(values.Length()),
ctor->source);
return false;
}
@@ -2086,18 +2100,25 @@
return true;
}
-bool Validator::Array(const sem::Array* arr, const Source& source) const {
+bool Validator::Array(const sem::Array* arr, const Source& el_source) const {
auto* el_ty = arr->ElemType();
if (!IsPlain(el_ty)) {
- AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array", source);
+ AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array",
+ el_source);
return false;
}
if (!IsFixedFootprint(el_ty)) {
- AddError("an array element type cannot contain a runtime-sized array", source);
+ AddError("an array element type cannot contain a runtime-sized array", el_source);
return false;
}
+
+ if (IsArrayWithOverrideCount(el_ty)) {
+ RaiseArrayWithOverrideCountError(el_source);
+ return false;
+ }
+
return true;
}
@@ -2154,6 +2175,11 @@
return false;
}
}
+
+ if (IsArrayWithOverrideCount(member->Type())) {
+ RaiseArrayWithOverrideCountError(member->Declaration()->type->source);
+ return false;
+ }
} else if (!IsFixedFootprint(member->Type())) {
AddError(
"a struct that contains a runtime array cannot be nested inside "
@@ -2488,6 +2514,22 @@
return !IsValidationDisabled(attributes, validation);
}
+bool Validator::IsArrayWithOverrideCount(const sem::Type* ty) const {
+ if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
+ if (arr->IsOverrideSized()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void Validator::RaiseArrayWithOverrideCountError(const Source& source) const {
+ AddError(
+ "array with an 'override' element count can only be used as the store type of a "
+ "'var<workgroup>'",
+ source);
+}
+
std::string Validator::VectorPretty(uint32_t size, const sem::Type* element_type) const {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(symbols_);
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index a00f6ab..b79a850 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -128,9 +128,10 @@
/// Validates the array
/// @param arr the array to validate
- /// @param source the source of the array
+ /// @param el_source the source of the array element, or the array if the array does not have a
+ /// locally-declared element AST node.
/// @returns true on success, false otherwise.
- bool Array(const sem::Array* arr, const Source& source) const;
+ bool Array(const sem::Array* arr, const Source& el_source) const;
/// Validates an array stride attribute
/// @param attr the stride attribute to validate
@@ -463,6 +464,16 @@
ast::DisabledValidation validation) const;
private:
+ /// @param ty the type to check
+ /// @returns true if @p ty is an array with an `override` expression element count, otherwise
+ /// false.
+ bool IsArrayWithOverrideCount(const sem::Type* ty) const;
+
+ /// Raises an error about an array type using an `override` expression element count, outside
+ /// the single allowed use of a `var<workgroup>`.
+ /// @param source the source for the error
+ void RaiseArrayWithOverrideCountError(const Source& source) const;
+
/// Searches the current statement and up through parents of the current
/// statement looking for a loop or for-loop continuing statement.
/// @returns the closest continuing statement to the current statement that
diff --git a/src/tint/resolver/validator_is_storeable_test.cc b/src/tint/resolver/validator_is_storeable_test.cc
index 88ec911..015b095 100644
--- a/src/tint/resolver/validator_is_storeable_test.cc
+++ b/src/tint/resolver/validator_is_storeable_test.cc
@@ -89,12 +89,13 @@
}
TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
diff --git a/src/tint/sem/array.cc b/src/tint/sem/array.cc
index 624623a..31827cb 100644
--- a/src/tint/sem/array.cc
+++ b/src/tint/sem/array.cc
@@ -16,15 +16,22 @@
#include <string>
+#include "src/tint/ast/variable.h"
#include "src/tint/debug.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/symbol_table.h"
#include "src/tint/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::Array);
namespace tint::sem {
+const char* Array::kErrExpectedConstantCount =
+ "array size is an override-expression, when expected a constant-expression.\n"
+ "Was the SubstituteOverride transform run?";
+
Array::Array(const Type* element,
- uint32_t count,
+ ArrayCount count,
uint32_t align,
uint32_t size,
uint32_t stride,
@@ -35,8 +42,9 @@
size_(size),
stride_(stride),
implicit_stride_(implicit_stride),
- constructible_(count > 0 // Runtime-sized arrays are not constructible
- && element->IsConstructible()) {
+ // Only constant-expression sized arrays are constructible
+ constructible_(std::holds_alternative<ConstantArrayCount>(count) &&
+ element->IsConstructible()) {
TINT_ASSERT(Semantic, element_);
}
@@ -64,8 +72,10 @@
out << "@stride(" << stride_ << ") ";
}
out << "array<" << element_->FriendlyName(symbols);
- if (!IsRuntimeSized()) {
- out << ", " << count_;
+ if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
+ out << ", " << const_count->value;
+ } else if (auto* override_count = std::get_if<OverrideArrayCount>(&count_)) {
+ out << ", " << symbols.NameFor(override_count->variable->Declaration()->symbol);
}
out << ">";
return out.str();
diff --git a/src/tint/sem/array.h b/src/tint/sem/array.h
index 7f72d8a..d41cd98 100644
--- a/src/tint/sem/array.h
+++ b/src/tint/sem/array.h
@@ -16,29 +16,116 @@
#define SRC_TINT_SEM_ARRAY_H_
#include <stdint.h>
+#include <optional>
#include <string>
+#include <variant>
#include "src/tint/sem/node.h"
#include "src/tint/sem/type.h"
+#include "src/tint/utils/compiler_macros.h"
+
+// Forward declarations
+namespace tint::sem {
+class GlobalVariable;
+} // namespace tint::sem
namespace tint::sem {
+/// The variant of an ArrayCount when the array is a constant expression.
+/// Example:
+/// ```
+/// const N = 123;
+/// type arr = array<i32, N>
+/// ```
+struct ConstantArrayCount {
+ /// The array count constant-expression value.
+ uint32_t value;
+};
+
+/// The variant of an ArrayCount when the count is a named override variable.
+/// Example:
+/// ```
+/// override N : i32;
+/// type arr = array<i32, N>
+/// ```
+struct OverrideArrayCount {
+ /// The `override` variable.
+ const GlobalVariable* variable;
+};
+
+/// The variant of an ArrayCount when the array is is runtime-sized.
+/// Example:
+/// ```
+/// type arr = array<i32>
+/// ```
+struct RuntimeArrayCount {};
+
+/// An array count is either a constant-expression value, an override identifier, or runtime-sized.
+using ArrayCount = std::variant<ConstantArrayCount, OverrideArrayCount, RuntimeArrayCount>;
+
+/// Equality operator
+/// @param a the LHS ConstantArrayCount
+/// @param b the RHS ConstantArrayCount
+/// @returns true if @p a is equal to @p b
+inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b) {
+ return a.value == b.value;
+}
+
+/// Equality operator
+/// @param a the LHS OverrideArrayCount
+/// @param b the RHS OverrideArrayCount
+/// @returns true if @p a is equal to @p b
+inline bool operator==(const OverrideArrayCount& a, const OverrideArrayCount& b) {
+ return a.variable == b.variable;
+}
+
+/// Equality operator
+/// @returns true
+inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
+ return true;
+}
+
+/// Equality operator
+/// @param a the LHS ArrayCount
+/// @param b the RHS count
+/// @returns true if @p a is equal to @p b
+template <typename T,
+ typename = std::enable_if_t<std::is_same_v<T, ConstantArrayCount> ||
+ std::is_same_v<T, OverrideArrayCount> ||
+ std::is_same_v<T, RuntimeArrayCount>>>
+inline bool operator==(const ArrayCount& a, const T& b) {
+ TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
+ return std::visit(
+ [&](auto count) {
+ if constexpr (std::is_same_v<std::decay_t<decltype(count)>, T>) {
+ return count == b;
+ }
+ return false;
+ },
+ a);
+ TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
+}
+
/// Array holds the semantic information for Array nodes.
class Array final : public Castable<Array, Type> {
public:
+ /// An error message string stating that the array count was expected to be a constant
+ /// expression. Used by multiple writers and transforms.
+ static const char* kErrExpectedConstantCount;
+
/// Constructor
/// @param element the array element type
- /// @param count the number of elements in the array. 0 represents a
- /// runtime-sized array.
+ /// @param count the number of elements in the array.
/// @param align the byte alignment of the array
- /// @param size the byte size of the array
+ /// @param size the byte size of the array. The size will be 0 if the array element count is
+ /// pipeline overridable.
/// @param stride the number of bytes from the start of one element of the
- /// array to the start of the next element
+ /// array to the start of the next element
/// @param implicit_stride the number of bytes from the start of one element
/// of the array to the start of the next element, if there was no `@stride`
/// attribute applied.
Array(Type const* element,
- uint32_t count,
+ ArrayCount count,
uint32_t align,
uint32_t size,
uint32_t stride,
@@ -54,9 +141,16 @@
/// @return the array element type
Type const* ElemType() const { return element_; }
- /// @returns the number of elements in the array. 0 represents a runtime-sized
- /// array.
- uint32_t Count() const { return count_; }
+ /// @returns the number of elements in the array.
+ const ArrayCount& Count() const { return count_; }
+
+ /// @returns the array count if the count is a constant expression, otherwise returns nullopt.
+ inline std::optional<uint32_t> ConstantCount() const {
+ if (auto* count = std::get_if<ConstantArrayCount>(&count_)) {
+ return count->value;
+ }
+ return std::nullopt;
+ }
/// @returns the byte alignment of the array
/// @note this may differ from the alignment of a structure member of this
@@ -81,8 +175,14 @@
/// natural stride
bool IsStrideImplicit() const { return stride_ == implicit_stride_; }
+ /// @returns true if this array is sized using an constant expression
+ bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
+
+ /// @returns true if this array is sized using an override variable
+ bool IsOverrideSized() const { return std::holds_alternative<OverrideArrayCount>(count_); }
+
/// @returns true if this array is runtime sized
- bool IsRuntimeSized() const { return count_ == 0; }
+ bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
/// @returns true if constructible as per
/// https://gpuweb.github.io/gpuweb/wgsl/#constructible-types
@@ -95,7 +195,7 @@
private:
Type const* const element_;
- const uint32_t count_;
+ const ArrayCount count_;
const uint32_t align_;
const uint32_t size_;
const uint32_t stride_;
@@ -105,4 +205,38 @@
} // namespace tint::sem
+namespace std {
+
+/// Custom std::hash specialization for tint::sem::ConstantArrayCount.
+template <>
+class hash<tint::sem::ConstantArrayCount> {
+ public:
+ /// @param count the count to hash
+ /// @return the hash value
+ inline std::size_t operator()(const tint::sem::ConstantArrayCount& count) const {
+ return std::hash<decltype(count.value)>()(count.value);
+ }
+};
+
+/// Custom std::hash specialization for tint::sem::OverrideArrayCount.
+template <>
+class hash<tint::sem::OverrideArrayCount> {
+ public:
+ /// @param count the count to hash
+ /// @return the hash value
+ inline std::size_t operator()(const tint::sem::OverrideArrayCount& count) const {
+ return std::hash<decltype(count.variable)>()(count.variable);
+ }
+};
+
+/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
+template <>
+class hash<tint::sem::RuntimeArrayCount> {
+ public:
+ /// @return the hash value
+ inline std::size_t operator()(const tint::sem::RuntimeArrayCount&) const { return 42; }
+};
+
+} // namespace std
+
#endif // SRC_TINT_SEM_ARRAY_H_
diff --git a/src/tint/sem/sem_array_test.cc b/src/tint/sem/sem_array_test.cc
index 4b61fd7..d6ce9cd 100644
--- a/src/tint/sem/sem_array_test.cc
+++ b/src/tint/sem/sem_array_test.cc
@@ -21,16 +21,16 @@
using ArrayTest = TestHelper;
TEST_F(ArrayTest, CreateSizedArray) {
- auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
- auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
- auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
- auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
- auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
- auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
- auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
+ auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
+ auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
+ auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
+ auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
- EXPECT_EQ(a->Count(), 2u);
+ EXPECT_EQ(a->Count(), ConstantArrayCount{2u});
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@@ -47,15 +47,15 @@
}
TEST_F(ArrayTest, CreateRuntimeArray) {
- auto* a = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u);
- auto* b = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u);
- auto* c = create<Array>(create<U32>(), 0u, 5u, 8u, 32u, 32u);
- auto* d = create<Array>(create<U32>(), 0u, 4u, 9u, 32u, 32u);
- auto* e = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 32u);
- auto* f = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
+ auto* b = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
+ auto* c = create<Array>(create<U32>(), RuntimeArrayCount{}, 5u, 8u, 32u, 32u);
+ auto* d = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 9u, 32u, 32u);
+ auto* e = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 32u);
+ auto* f = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
- EXPECT_EQ(a->Count(), 0u);
+ EXPECT_EQ(a->Count(), sem::RuntimeArrayCount{});
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@@ -71,13 +71,13 @@
}
TEST_F(ArrayTest, Hash) {
- auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
- auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
- auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
- auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
- auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
- auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
- auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
+ auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
+ auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
+ auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
+ auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->Hash(), b->Hash());
EXPECT_NE(a->Hash(), c->Hash());
@@ -88,13 +88,13 @@
}
TEST_F(ArrayTest, Equals) {
- auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
- auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
- auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
- auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
- auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
- auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
- auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
+ auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
+ auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
+ auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
+ auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c));
@@ -106,22 +106,22 @@
}
TEST_F(ArrayTest, FriendlyNameRuntimeSized) {
- auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 4u, 4u);
+ auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSized) {
- auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 4u, 4u);
+ auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32, 5>");
}
TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) {
- auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 8u, 4u);
+ auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
- auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 8u, 4u);
+ auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32, 5>");
}
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index ccc4141..23f9164 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -246,7 +246,9 @@
},
[&](const Array* a) {
if (count) {
- *count = a->Count();
+ if (auto* const_count = std::get_if<ConstantArrayCount>(&a->Count())) {
+ *count = const_count->value;
+ }
}
return a->ElemType();
},
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index 56dcbff..5c7e97b 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -62,56 +62,56 @@
/* size_no_padding*/ 4u);
const sem::Array* arr_i32 = create<Array>(
/* element */ i32,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_ai = create<Array>(
/* element */ ai,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_vec3_i32 = create<Array>(
/* element */ vec3_i32,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_vec3_ai = create<Array>(
/* element */ vec3_ai,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_mat4x3_f16 = create<Array>(
/* element */ mat4x3_f16,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 32u,
/* size */ 5u * 32u,
/* stride */ 5u * 32u,
/* implicit_stride */ 5u * 32u);
const sem::Array* arr_mat4x3_f32 = create<Array>(
/* element */ mat4x3_f32,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_mat4x3_af = create<Array>(
/* element */ mat4x3_af,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_str = create<Array>(
/* element */ str,
- /* count */ 5u,
+ /* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc
index 767a2a3..5920d5b 100644
--- a/src/tint/transform/decompose_memory_access.cc
+++ b/src/tint/transform/decompose_memory_access.cc
@@ -471,8 +471,18 @@
auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
auto* for_init = b.Decl(i);
+ auto arr_cnt = arr_ty->ConstantCount();
+ if (!arr_cnt) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles storage and uniform.
+ // * Runtime-sized arrays are not loadable.
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unexpected non-constant array count";
+ arr_cnt = 1;
+ }
auto* for_cond = b.create<ast::BinaryExpression>(
- ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
+ ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
@@ -562,8 +572,18 @@
StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
auto* for_init = b.Decl(i);
+ auto arr_cnt = arr_ty->ConstantCount();
+ if (!arr_cnt) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup
+ // arrays, and this method only handles storage and uniform.
+ // * Runtime-sized arrays are not storable.
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unexpected non-constant array count";
+ arr_cnt = 1;
+ }
auto* for_cond = b.create<ast::BinaryExpression>(
- ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
+ ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset =
diff --git a/src/tint/transform/pad_structs.cc b/src/tint/transform/pad_structs.cc
index c0be12c..0c5da5c 100644
--- a/src/tint/transform/pad_structs.cc
+++ b/src/tint/transform/pad_structs.cc
@@ -82,7 +82,7 @@
// std140 structs should be padded out to 16 bytes.
size = utils::RoundUp(16u, size);
} else if (auto* array_ty = ty->As<sem::Array>()) {
- if (array_ty->Count() == 0) {
+ if (array_ty->IsRuntimeSized()) {
has_runtime_sized_array = true;
}
}
diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc
index e662e56..b9b6c5e 100644
--- a/src/tint/transform/robustness.cc
+++ b/src/tint/transform/robustness.cc
@@ -99,14 +99,21 @@
// Must clamp, even if the index is constant.
auto* arr_ptr = b.AddressOf(ctx.Clone(expr->object));
max = b.Sub(b.Call("arrayLength", arr_ptr), 1_u);
- } else {
+ } else if (auto count = arr->ConstantCount()) {
if (sem->Index()->ConstantValue()) {
// Index and size is constant.
// Validation will have rejected any OOB accesses.
return nullptr;
}
- max = b.Expr(u32(arr->Count() - 1u));
+ max = b.Expr(u32(count.value() - 1u));
+ } else {
+ // Note: Don't be tempted to use the array override variable as an expression
+ // here, the name might be shadowed!
+ ctx.dst->Diagnostics().add_error(diag::System::Transform,
+ sem::Array::kErrExpectedConstantCount);
+ return nullptr;
}
+
return b.Call("min", idx(), max);
},
[&](Default) {
diff --git a/src/tint/transform/robustness_test.cc b/src/tint/transform/robustness_test.cc
index e3d08f0..4cd4604 100644
--- a/src/tint/transform/robustness_test.cc
+++ b/src/tint/transform/robustness_test.cc
@@ -1316,5 +1316,23 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(RobustnessTest, WorkgroupOverrideCount) {
+ auto* src = R"(
+override N = 123;
+var<workgroup> w : array<f32, N>;
+
+fn f() {
+ var b : f32 = w[1i];
+}
+)";
+
+ auto* expect = R"(error: array size is an override-expression, when expected a constant-expression.
+Was the SubstituteOverride transform run?)";
+
+ auto got = Run<Robustness>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::transform
diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc
index a6cfd5e..e815633 100644
--- a/src/tint/transform/spirv_atomic.cc
+++ b/src/tint/transform/spirv_atomic.cc
@@ -35,6 +35,8 @@
namespace tint::transform {
+using namespace tint::number_suffixes; // NOLINT
+
/// Private implementation of transform
struct SpirvAtomic::State {
private:
@@ -189,10 +191,19 @@
[&](const sem::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const sem::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const sem::Struct* str) { return b.ty.type_name(Fork(str->Declaration()).name); },
- [&](const sem::Array* arr) {
- return arr->IsRuntimeSized()
- ? b.ty.array(AtomicTypeFor(arr->ElemType()))
- : b.ty.array(AtomicTypeFor(arr->ElemType()), u32(arr->Count()));
+ [&](const sem::Array* arr) -> const ast::Type* {
+ if (arr->IsRuntimeSized()) {
+ return b.ty.array(AtomicTypeFor(arr->ElemType()));
+ }
+ auto count = arr->ConstantCount();
+ if (!count) {
+ ctx.dst->Diagnostics().add_error(
+ diag::System::Transform,
+ "the SpirvAtomic transform does not currently support array counts that "
+ "use override values");
+ count = 1;
+ }
+ return b.ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
},
[&](const sem::Pointer* ptr) {
return b.ty.pointer(AtomicTypeFor(ptr->StoreType()), ptr->StorageClass(),
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index 06b4f5d..1f495f0 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -423,7 +423,17 @@
if (!arr->IsStrideImplicit()) {
attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride()));
}
- return b.create<ast::Array>(std140, b.Expr(u32(arr->Count())),
+ auto count = arr->ConstantCount();
+ if (!count) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles types transitively used as uniform buffers.
+ // * Runtime-sized arrays cannot be used in uniform buffers.
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unexpected non-constant array count";
+ count = 1;
+ }
+ return b.create<ast::Array>(std140, b.Expr(u32(count.value())),
std::move(attrs));
}
return nullptr;
@@ -613,7 +623,17 @@
ty, //
[&](const sem::Struct* str) { return sym.NameFor(str->Name()); },
[&](const sem::Array* arr) {
- return "arr" + std::to_string(arr->Count()) + "_" + ConvertSuffix(arr->ElemType());
+ auto count = arr->ConstantCount();
+ if (!count) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles types transitively used as uniform buffers.
+ // * Runtime-sized arrays cannot be used in uniform buffers.
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unexpected non-constant array count";
+ count = 1;
+ }
+ return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
},
[&](const sem::Matrix* mat) {
return "mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) +
@@ -710,10 +730,20 @@
auto* i = b.Var("i", b.ty.u32());
auto* dst_el = b.IndexAccessor(var, i);
auto* src_el = Convert(arr->ElemType(), b.IndexAccessor(param, i));
+ auto count = arr->ConstantCount();
+ if (!count) {
+ // Non-constant counts should not be possible:
+ // * Override-expression counts can only be applied to workgroup arrays, and
+ // this method only handles types transitively used as uniform buffers.
+ // * Runtime-sized arrays cannot be used in uniform buffers.
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "unexpected non-constant array count";
+ count = 1;
+ }
stmts.Push(b.Decl(var));
- stmts.Push(b.For(b.Decl(i), //
- b.LessThan(i, u32(arr->Count())), //
- b.Assign(i, b.Add(i, 1_a)), //
+ stmts.Push(b.For(b.Decl(i), //
+ b.LessThan(i, u32(count.value())), //
+ b.Assign(i, b.Add(i, 1_a)), //
b.Block(b.Assign(dst_el, src_el))));
stmts.Push(b.Return(var));
},
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index 3adfff6..3bcac6b 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -24,6 +24,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/sampler.h"
+#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);
@@ -112,9 +113,16 @@
}
if (a->IsRuntimeSized()) {
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
- } else {
- return ctx.dst->ty.array(el, u32(a->Count()), std::move(attrs));
}
+ if (auto* override = std::get_if<sem::OverrideArrayCount>(&a->Count())) {
+ auto* count = ctx.Clone(override->variable->Declaration());
+ return ctx.dst->ty.array(el, count, std::move(attrs));
+ }
+ if (auto count = a->ConstantCount()) {
+ return ctx.dst->ty.array(el, u32(count.value()), std::move(attrs));
+ }
+ TINT_ICE(Transform, ctx.dst->Diagnostics()) << sem::Array::kErrExpectedConstantCount;
+ return ctx.dst->ty.array(el, u32(1), std::move(attrs));
}
if (auto* s = ty->As<sem::Struct>()) {
return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name));
diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc
index 4ec4d1d..d063ba2 100644
--- a/src/tint/transform/transform_test.cc
+++ b/src/tint/transform/transform_test.cc
@@ -65,7 +65,8 @@
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
- return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 32u, 32u);
+ return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 32u,
+ 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
@@ -78,7 +79,8 @@
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
- return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 64u, 32u);
+ return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 64u,
+ 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc
index 94df3b9..0d3ed98 100644
--- a/src/tint/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/transform/zero_init_workgroup_memory.cc
@@ -307,7 +307,13 @@
// `num_values * arr->Count()`
// The index for this array is:
// `(idx % modulo) / division`
- auto modulo = num_values * arr->Count();
+ auto count = arr->ConstantCount();
+ if (!count) {
+ ctx.dst->Diagnostics().add_error(diag::System::Transform,
+ sem::Array::kErrExpectedConstantCount);
+ return Expression{};
+ }
+ auto modulo = num_values * count.value();
auto division = num_values;
auto a = get_expr(modulo);
auto array_indices = a.array_indices;
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 21840bc..3cab83d 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -2240,7 +2240,13 @@
ScopedParen sp(out);
- for (size_t i = 0; i < a->Count(); i++) {
+ auto count = a->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
+ return false;
+ }
+
+ for (size_t i = 0; i < count; i++) {
if (i > 0) {
out << ", ";
}
@@ -2356,16 +2362,23 @@
}
EmitZeroValue(out, member->Type());
}
- } else if (auto* array = type->As<sem::Array>()) {
+ } else if (auto* arr = type->As<sem::Array>()) {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
return false;
}
ScopedParen sp(out);
- for (uint32_t i = 0; i < array->Count(); i++) {
+
+ auto count = arr->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
+ return false;
+ }
+
+ for (uint32_t i = 0; i < count; i++) {
if (i != 0) {
out << ", ";
}
- EmitZeroValue(out, array->ElemType());
+ EmitZeroValue(out, arr->ElemType());
}
} else {
diagnostics_.add_error(diag::System::Writer, "Invalid type for zero emission: " +
@@ -2697,7 +2710,18 @@
const sem::Type* base_type = ary;
std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) {
- sizes.push_back(arr->Count());
+ if (arr->IsRuntimeSized()) {
+ sizes.push_back(0);
+ } else {
+ auto count = arr->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer,
+ sem::Array::kErrExpectedConstantCount);
+ return false;
+ }
+ sizes.push_back(count.value());
+ }
+
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, storage_class, access, "")) {
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 516acdd..fb774c0 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -3193,7 +3193,13 @@
out << "{";
TINT_DEFER(out << "}");
- for (size_t i = 0; i < a->Count(); i++) {
+ auto count = a->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
+ return false;
+ }
+
+ for (size_t i = 0; i < count; i++) {
if (i > 0) {
out << ", ";
}
@@ -3732,11 +3738,18 @@
while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) {
TINT_ICE(Writer, diagnostics_)
- << "Runtime arrays may only exist in storage buffers, which should have "
+ << "runtime arrays may only exist in storage buffers, which should have "
"been transformed into a ByteAddressBuffer";
return false;
}
- sizes.push_back(arr->Count());
+ const auto count = arr->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer,
+ sem::Array::kErrExpectedConstantCount);
+ return false;
+ }
+
+ sizes.push_back(count.value());
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, storage_class, access, "")) {
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index d1bab39..3708323 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -1690,7 +1690,13 @@
return true;
}
- for (size_t i = 0; i < a->Count(); i++) {
+ auto count = a->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
+ return false;
+ }
+
+ for (size_t i = 0; i < count; i++) {
if (i > 0) {
out << ", ";
}
@@ -2481,7 +2487,20 @@
if (!EmitType(out, arr->ElemType(), "")) {
return false;
}
- out << ", " << (arr->IsRuntimeSized() ? 1u : arr->Count()) << ">";
+ out << ", ";
+ if (arr->IsRuntimeSized()) {
+ out << "1";
+ } else {
+ auto count = arr->ConstantCount();
+ if (!count) {
+ diagnostics_.add_error(diag::System::Writer,
+ sem::Array::kErrExpectedConstantCount);
+ return false;
+ }
+
+ out << count.value();
+ }
+ out << ">";
return true;
},
[&](const sem::Bool*) {
@@ -3133,8 +3152,14 @@
<< "arrays with explicit strides should not exist past the SPIR-V reader";
return SizeAndAlign{};
}
- auto num_els = std::max<uint32_t>(arr->Count(), 1);
- return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
+ if (arr->IsRuntimeSized()) {
+ return SizeAndAlign{arr->Stride(), arr->Align()};
+ }
+ if (auto count = arr->ConstantCount()) {
+ return SizeAndAlign{arr->Stride() * count.value(), arr->Align()};
+ }
+ diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
+ return SizeAndAlign{};
},
[&](const sem::Struct* str) {
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 05f3716..70d16a6 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -1676,11 +1676,18 @@
},
[&](const sem::Vector* v) { return composite(v->Width()); },
[&](const sem::Matrix* m) { return composite(m->columns()); },
- [&](const sem::Array* a) { return composite(a->Count()); },
+ [&](const sem::Array* a) {
+ auto count = a->ConstantCount();
+ if (!count) {
+ error_ = sem::Array::kErrExpectedConstantCount;
+ return static_cast<uint32_t>(0);
+ }
+ return composite(count.value());
+ },
[&](const sem::Struct* s) { return composite(s->Members().size()); },
[&](Default) {
error_ = "unhandled constant type: " + builder_.FriendlyName(ty);
- return false;
+ return 0;
});
}
@@ -3852,17 +3859,23 @@
return true;
}
-bool Builder::GenerateArrayType(const sem::Array* ary, const Operand& result) {
- auto elem_type = GenerateTypeIfNeeded(ary->ElemType());
+bool Builder::GenerateArrayType(const sem::Array* arr, const Operand& result) {
+ auto elem_type = GenerateTypeIfNeeded(arr->ElemType());
if (elem_type == 0) {
return false;
}
auto result_id = std::get<uint32_t>(result);
- if (ary->IsRuntimeSized()) {
+ if (arr->IsRuntimeSized()) {
push_type(spv::Op::OpTypeRuntimeArray, {result, Operand(elem_type)});
} else {
- auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->Count()));
+ auto count = arr->ConstantCount();
+ if (!count) {
+ error_ = sem::Array::kErrExpectedConstantCount;
+ return static_cast<uint32_t>(0);
+ }
+
+ auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(count.value()));
if (len_id == 0) {
return false;
}
@@ -3871,7 +3884,7 @@
}
push_annot(spv::Op::OpDecorate,
- {Operand(result_id), U32Operand(SpvDecorationArrayStride), Operand(ary->Stride())});
+ {Operand(result_id), U32Operand(SpvDecorationArrayStride), Operand(arr->Stride())});
return true;
}
diff --git a/test/tint/bug/tint/1660.wgsl b/test/tint/bug/tint/1660.wgsl
new file mode 100644
index 0000000..f06d17a
--- /dev/null
+++ b/test/tint/bug/tint/1660.wgsl
@@ -0,0 +1,5 @@
+// flags: --transform substitute_override
+
+override size = 2;
+
+var<workgroup> a : array<f32, size>;
diff --git a/test/tint/bug/tint/1660.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1660.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..f256b90
--- /dev/null
+++ b/test/tint/bug/tint/1660.wgsl.expected.dxc.hlsl
@@ -0,0 +1,6 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+groupshared float a[2];
diff --git a/test/tint/bug/tint/1660.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1660.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..f256b90
--- /dev/null
+++ b/test/tint/bug/tint/1660.wgsl.expected.fxc.hlsl
@@ -0,0 +1,6 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+groupshared float a[2];
diff --git a/test/tint/bug/tint/1660.wgsl.expected.glsl b/test/tint/bug/tint/1660.wgsl.expected.glsl
new file mode 100644
index 0000000..8a98ebe
--- /dev/null
+++ b/test/tint/bug/tint/1660.wgsl.expected.glsl
@@ -0,0 +1,7 @@
+#version 310 es
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+ return;
+}
+shared float a[2];
diff --git a/test/tint/bug/tint/1660.wgsl.expected.msl b/test/tint/bug/tint/1660.wgsl.expected.msl
new file mode 100644
index 0000000..466ceaa
--- /dev/null
+++ b/test/tint/bug/tint/1660.wgsl.expected.msl
@@ -0,0 +1,3 @@
+#include <metal_stdlib>
+
+using namespace metal;
diff --git a/test/tint/bug/tint/1660.wgsl.expected.spvasm b/test/tint/bug/tint/1660.wgsl.expected.spvasm
new file mode 100644
index 0000000..1162af8
--- /dev/null
+++ b/test/tint/bug/tint/1660.wgsl.expected.spvasm
@@ -0,0 +1,24 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 11
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+ OpExecutionMode %unused_entry_point LocalSize 1 1 1
+ OpName %a "a"
+ OpName %unused_entry_point "unused_entry_point"
+ OpDecorate %_arr_float_uint_2 ArrayStride 4
+ %float = OpTypeFloat 32
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+%_arr_float_uint_2 = OpTypeArray %float %uint_2
+%_ptr_Workgroup__arr_float_uint_2 = OpTypePointer Workgroup %_arr_float_uint_2
+ %a = OpVariable %_ptr_Workgroup__arr_float_uint_2 Workgroup
+ %void = OpTypeVoid
+ %7 = OpTypeFunction %void
+%unused_entry_point = OpFunction %void None %7
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/bug/tint/1660.wgsl.expected.wgsl b/test/tint/bug/tint/1660.wgsl.expected.wgsl
new file mode 100644
index 0000000..f996d52
--- /dev/null
+++ b/test/tint/bug/tint/1660.wgsl.expected.wgsl
@@ -0,0 +1,3 @@
+const size = 2;
+
+var<workgroup> a : array<f32, size>;