Move sem::ArrayCount to an inherited structure
This CL changes sem::ArrayCount from a std::variant to a class hierarchy.
This will allow the sem to have array count classes that differ from the
IR's. Like types, ArrayCounts are unique (de-aliased) and are created
through the TypeManager.
Bug: tint:1718
Change-Id: Ib9c7c9df881e7a34cc3def2ff29571f536d66244
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112441
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 3d10a06..6a0e70d 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -419,6 +419,7 @@
"sem/abstract_int.h",
"sem/abstract_numeric.h",
"sem/array.h",
+ "sem/array_count.h",
"sem/atomic.h",
"sem/behavior.h",
"sem/binding_point.h",
@@ -635,6 +636,8 @@
"sem/abstract_numeric.h",
"sem/array.cc",
"sem/array.h",
+ "sem/array_count.cc",
+ "sem/array_count.h",
"sem/atomic.cc",
"sem/atomic.h",
"sem/behavior.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index cb9c96e..72cec7b 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -299,6 +299,8 @@
sem/abstract_numeric.h
sem/array.cc
sem/array.h
+ sem/array_count.cc
+ sem/array_count.h
sem/atomic.cc
sem/atomic.h
sem/behavior.cc
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index d6d1f00..4677009 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -91,6 +91,7 @@
#include "src/tint/program.h"
#include "src/tint/program_id.h"
#include "src/tint/sem/array.h"
+#include "src/tint/sem/array_count.h"
#include "src/tint/sem/bool.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/depth_texture.h"
@@ -457,7 +458,8 @@
/// @returns the node pointer
template <typename T, typename... ARGS>
traits::EnableIf<traits::IsTypeOrDerived<T, sem::Node> &&
- !traits::IsTypeOrDerived<T, sem::Type>,
+ !traits::IsTypeOrDerived<T, sem::Type> &&
+ !traits::IsTypeOrDerived<T, sem::ArrayCount>,
T>*
create(ARGS&&... args) {
AssertNotMoved();
@@ -476,17 +478,28 @@
/// Creates a new sem::Type owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, owned ProgramBuilder and the
- /// returned`Type` will also be destructed.
+ /// returned `Type` will also be destructed.
/// Types are unique (de-aliased), and so calling create() for the same `T`
/// and arguments will return the same pointer.
/// @param args the arguments to pass to the type constructor
/// @returns the de-aliased type pointer
template <typename T, typename... ARGS>
traits::EnableIfIsType<T, sem::Type>* create(ARGS&&... args) {
- static_assert(std::is_base_of<sem::Type, T>::value, "T does not derive from sem::Type");
AssertNotMoved();
return types_.Get<T>(std::forward<ARGS>(args)...);
}
+ /// Creates a new sem::ArrayCount owned by the ProgramBuilder.
+ /// When the ProgramBuilder is destructed, the returned `ArrayCount` will
+ /// also be destructed.
+ /// ArrayCounts are unique (de-aliased), so calling create() with the same
+ /// `T` and arguments will return the same pointer.
+ /// @param args the arguments to pass to the array count constructor
+ /// @returns the de-aliased array count pointer
+ template <typename T, typename... ARGS>
+ traits::EnableIfIsType<T, sem::ArrayCount>* create(ARGS&&... args) {
+ AssertNotMoved();
+ return types_.GetArrayCount<T>(std::forward<ARGS>(args)...);
+ }
/// Marks this builder as moved, preventing any further use of the builder.
void MarkAsMoved();
diff --git a/src/tint/resolver/const_eval_construction_test.cc b/src/tint/resolver/const_eval_construction_test.cc
index 9df9807..cabc4ab 100644
--- a/src/tint/resolver/const_eval_construction_test.cc
+++ b/src/tint/resolver/const_eval_construction_test.cc
@@ -1321,7 +1321,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1359,7 +1358,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1397,7 +1395,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1449,7 +1446,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@@ -1487,7 +1483,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@@ -1525,7 +1520,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@@ -1564,7 +1558,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@@ -1594,7 +1587,6 @@
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
- EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc
index 469a5e0..ddbc8f8 100644
--- a/src/tint/resolver/inferred_type_test.cc
+++ b/src/tint/resolver/inferred_type_test.cc
@@ -135,8 +135,8 @@
TEST_F(ResolverInferredTypeTest, InferArray_Pass) {
auto* type = ty.array(ty.u32(), 10_u);
- auto* expected_type =
- create<sem::Array>(create<sem::U32>(), sem::ConstantArrayCount{10u}, 4u, 4u * 10u, 4u, 4u);
+ auto* expected_type = create<sem::Array>(
+ create<sem::U32>(), create<sem::ConstantArrayCount>(10u), 4u, 4u * 10u, 4u, 4u);
auto* ctor_expr = Construct(type);
auto* var = Var("a", ast::AddressSpace::kFunction, ctor_expr);
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index e58e980..aed49c3 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -532,12 +532,13 @@
}
const sem::Array* build_array(MatchState& state, const sem::Type* el) {
- return state.builder.create<sem::Array>(el,
- /* count */ sem::RuntimeArrayCount{},
- /* align */ 0u,
- /* size */ 0u,
- /* stride */ 0u,
- /* stride_implicit */ 0u);
+ return state.builder.create<sem::Array>(
+ el,
+ /* count */ state.builder.create<sem::RuntimeArrayCount>(),
+ /* align */ 0u,
+ /* size */ 0u,
+ /* stride */ 0u,
+ /* stride_implicit */ 0u);
}
bool match_ptr(MatchState&, const sem::Type* ty, Number& S, const sem::Type*& T, Number& A) {
diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc
index 1608aa2..03255df 100644
--- a/src/tint/resolver/intrinsic_table_test.cc
+++ b/src/tint/resolver/intrinsic_table_test.cc
@@ -252,7 +252,8 @@
}
TEST_F(IntrinsicTableTest, MatchArray) {
- auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::U32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
auto* arr_ptr = create<sem::Pointer>(arr, ast::AddressSpace::kStorage, ast::Access::kReadWrite);
auto result = table->Lookup(BuiltinType::kArrayLength, utils::Vector{arr_ptr},
sem::EvaluationStage::kConstant, Source{});
@@ -955,7 +956,8 @@
}
TEST_F(IntrinsicTableTest, MismatchTypeConversion) {
- auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::U32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
auto* f32 = create<sem::F32>();
auto result = table->Lookup(InitConvIntrinsic::kVec3, f32, utils::Vector{arr},
sem::EvaluationStage::kConstant, Source{{12, 34}});
diff --git a/src/tint/resolver/is_host_shareable_test.cc b/src/tint/resolver/is_host_shareable_test.cc
index ab969fa..5e1555b 100644
--- a/src/tint/resolver/is_host_shareable_test.cc
+++ b/src/tint/resolver/is_host_shareable_test.cc
@@ -106,13 +106,14 @@
}
TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) {
- auto* arr =
- create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::I32>(), create<sem::ConstantArrayCount>(5u), 4u, 20u,
+ 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::I32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
diff --git a/src/tint/resolver/is_storeable_test.cc b/src/tint/resolver/is_storeable_test.cc
index 61cf33b..6618199 100644
--- a/src/tint/resolver/is_storeable_test.cc
+++ b/src/tint/resolver/is_storeable_test.cc
@@ -89,13 +89,14 @@
}
TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) {
- auto* arr =
- create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::I32>(), create<sem::ConstantArrayCount>(5u), 4u, 20u,
+ 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}
TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::I32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 60f8207..f909bf0 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -2143,8 +2143,7 @@
[&](const ast::Array* a) -> sem::Call* {
Mark(a);
// array element type must be inferred if it was not specified.
- sem::ArrayCount el_count =
- sem::ConstantArrayCount{static_cast<uint32_t>(args.Length())};
+ const sem::ArrayCount* el_count = nullptr;
const sem::Type* el_ty = nullptr;
if (a->type) {
el_ty = Type(a->type);
@@ -2155,14 +2154,15 @@
AddError("cannot construct a runtime-sized array", expr->source);
return nullptr;
}
- if (auto count = ArrayCount(a->count)) {
- el_count = count.Get();
- } else {
+ el_count = ArrayCount(a->count);
+ if (!el_count) {
return nullptr;
}
// Note: validation later will detect any mismatches between explicit array
// size and number of initializer expressions.
} else {
+ el_count = builder_->create<sem::ConstantArrayCount>(
+ static_cast<uint32_t>(args.Length()));
auto arg_tys =
utils::Transform(args, [](auto* arg) { return arg->Type()->UnwrapRef(); });
el_ty = sem::Type::Common(arg_tys);
@@ -2936,15 +2936,16 @@
return nullptr;
}
- sem::ArrayCount el_count = sem::RuntimeArrayCount{};
+ const sem::ArrayCount* el_count = nullptr;
// Evaluate the constant array count expression.
if (auto* count_expr = arr->count) {
- if (auto count = ArrayCount(count_expr)) {
- el_count = count.Get();
- } else {
+ el_count = ArrayCount(count_expr);
+ if (!el_count) {
return nullptr;
}
+ } else {
+ el_count = builder_->create<sem::RuntimeArrayCount>();
}
auto* out = Array(arr->type->source, //
@@ -2971,11 +2972,11 @@
return out;
}
-utils::Result<sem::ArrayCount> Resolver::ArrayCount(const ast::Expression* count_expr) {
+const sem::ArrayCount* Resolver::ArrayCount(const ast::Expression* count_expr) {
// Evaluate the constant array count expression.
const auto* count_sem = Materialize(Expression(count_expr));
if (!count_sem) {
- return utils::Failure;
+ return nullptr;
}
if (count_sem->Stage() == sem::EvaluationStage::kOverride) {
@@ -2983,34 +2984,34 @@
// Is the count a named 'override'?
if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
- return sem::ArrayCount{sem::NamedOverrideArrayCount{global}};
+ return builder_->create<sem::NamedOverrideArrayCount>(global);
}
}
- return sem::ArrayCount{sem::UnnamedOverrideArrayCount{count_sem}};
+ return builder_->create<sem::UnnamedOverrideArrayCount>(count_sem);
}
auto* count_val = count_sem->ConstantValue();
if (!count_val) {
AddError("array count must evaluate to a constant integer expression or override variable",
count_expr->source);
- return utils::Failure;
+ return nullptr;
}
if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
AddError("array count must evaluate to a constant integer expression, but is type '" +
builder_->FriendlyName(ty) + "'",
count_expr->source);
- return utils::Failure;
+ return nullptr;
}
int64_t count = count_val->As<AInt>();
if (count < 1) {
AddError("array count (" + std::to_string(count) + ") must be greater than 0",
count_expr->source);
- return utils::Failure;
+ return nullptr;
}
- return sem::ArrayCount{sem::ConstantArrayCount{static_cast<uint32_t>(count)}};
+ return builder_->create<sem::ConstantArrayCount>(static_cast<uint32_t>(count));
}
bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attributes,
@@ -3046,7 +3047,7 @@
sem::Array* Resolver::Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty,
- sem::ArrayCount el_count,
+ const sem::ArrayCount* el_count,
uint32_t explicit_stride) {
uint32_t el_align = el_ty->Align();
uint32_t el_size = el_ty->Size();
@@ -3054,7 +3055,7 @@
uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
uint64_t size = 0;
- if (auto const_count = std::get_if<sem::ConstantArrayCount>(&el_count)) {
+ if (auto const_count = el_count->As<sem::ConstantArrayCount>()) {
size = const_count->value * stride;
if (size > std::numeric_limits<uint32_t>::max()) {
std::stringstream msg;
@@ -3063,7 +3064,7 @@
AddError(msg.str(), count_source);
return nullptr;
}
- } else if (std::holds_alternative<sem::RuntimeArrayCount>(el_count)) {
+ } else if (el_count->Is<sem::RuntimeArrayCount>()) {
size = stride;
}
auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size),
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 56ae912..0deef5f 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -273,7 +273,7 @@
/// Resolves and validates the expression used as the count parameter of an array.
/// @param count_expr the expression used as the second template parameter to an array<>.
/// @returns the number of elements in the array.
- utils::Result<sem::ArrayCount> ArrayCount(const ast::Expression* count_expr);
+ const sem::ArrayCount* ArrayCount(const ast::Expression* count_expr);
/// Resolves and validates the attributes on an array.
/// @param attributes the attributes on the array type.
@@ -296,7 +296,7 @@
sem::Array* Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty,
- sem::ArrayCount el_count,
+ const sem::ArrayCount* el_count,
uint32_t explicit_stride);
/// Builds and returns the semantic information for the alias `alias`.
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 13b1da1..2a88249 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -440,7 +440,7 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
+ EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_SignedLiteral) {
@@ -453,7 +453,7 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
+ EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_UnsignedConst) {
@@ -468,7 +468,7 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
+ EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_SignedConst) {
@@ -483,7 +483,7 @@
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
- EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
+ EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_NamedOverride) {
@@ -500,7 +500,7 @@
auto* ary = ref->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
- EXPECT_EQ(ary->Count(), sem::NamedOverrideArrayCount{sem_override});
+ EXPECT_EQ(ary->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
}
TEST_F(ResolverTest, ArraySize_NamedOverride_Equivalence) {
@@ -525,8 +525,8 @@
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
- EXPECT_EQ(ary_a->Count(), sem::NamedOverrideArrayCount{sem_override});
- EXPECT_EQ(ary_b->Count(), sem::NamedOverrideArrayCount{sem_override});
+ EXPECT_EQ(ary_a->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
+ EXPECT_EQ(ary_b->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
EXPECT_EQ(ary_a, ary_b);
}
@@ -545,7 +545,7 @@
auto* ary = ref->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
- EXPECT_EQ(ary->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(cnt)});
+ EXPECT_EQ(ary->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(cnt)));
}
TEST_F(ResolverTest, ArraySize_UnamedOverride_Equivalence) {
@@ -572,8 +572,8 @@
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
- EXPECT_EQ(ary_a->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(a_cnt)});
- EXPECT_EQ(ary_b->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(b_cnt)});
+ EXPECT_EQ(ary_a->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(a_cnt)));
+ EXPECT_EQ(ary_b->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(b_cnt)));
EXPECT_NE(ary_a, ary_b);
}
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index fe33c26..f2567c8 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -659,9 +659,11 @@
/// @return the semantic array type
static inline const sem::Type* Sem(ProgramBuilder& b) {
auto* el = DataType<T>::Sem(b);
- sem::ArrayCount count = sem::ConstantArrayCount{N};
+ const sem::ArrayCount* count = nullptr;
if (N == 0) {
- count = sem::RuntimeArrayCount{};
+ count = b.create<sem::RuntimeArrayCount>();
+ } else {
+ count = b.create<sem::ConstantArrayCount>(N);
}
return b.create<sem::Array>(
/* element */ el,
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 83c8c78..5a7f9d8 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1778,7 +1778,12 @@
return false;
}
- const auto count = std::get<sem::ConstantArrayCount>(array_type->Count()).value;
+ if (!array_type->IsConstantSized()) {
+ TINT_ICE(Resolver, diagnostics_) << "Invalid ArrayCount found";
+ return false;
+ }
+
+ const auto count = array_type->Count()->As<sem::ConstantArrayCount>()->value;
if (!values.IsEmpty() && (values.Length() != count)) {
std::string fm = values.Length() < count ? "few" : "many";
AddError("array initializer has too " + fm + " elements: expected " +
diff --git a/src/tint/resolver/validator_is_storeable_test.cc b/src/tint/resolver/validator_is_storeable_test.cc
index 9fa064ac..cd079ce 100644
--- a/src/tint/resolver/validator_is_storeable_test.cc
+++ b/src/tint/resolver/validator_is_storeable_test.cc
@@ -89,13 +89,14 @@
}
TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) {
- auto* arr =
- create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
+ auto* arr = create<sem::Array>(create<sem::I32>(), create<sem::ConstantArrayCount>(5u), 4u, 20u,
+ 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) {
- auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
+ auto* arr =
+ create<sem::Array>(create<sem::I32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
diff --git a/src/tint/sem/array.cc b/src/tint/sem/array.cc
index d61d430..207a627 100644
--- a/src/tint/sem/array.cc
+++ b/src/tint/sem/array.cc
@@ -28,10 +28,10 @@
namespace {
-TypeFlags FlagsFrom(const Type* element, ArrayCount count) {
+TypeFlags FlagsFrom(const Type* element, const ArrayCount* count) {
TypeFlags flags;
// Only constant-expression sized arrays are constructible
- if (std::holds_alternative<ConstantArrayCount>(count)) {
+ if (count->Is<ConstantArrayCount>()) {
if (element->IsConstructible()) {
flags.Add(TypeFlag::kConstructable);
}
@@ -39,9 +39,7 @@
flags.Add(TypeFlag::kCreationFixedFootprint);
}
}
- if (std::holds_alternative<ConstantArrayCount>(count) ||
- std::holds_alternative<NamedOverrideArrayCount>(count) ||
- std::holds_alternative<UnnamedOverrideArrayCount>(count)) {
+ if (count->IsAnyOf<ConstantArrayCount, NamedOverrideArrayCount, UnnamedOverrideArrayCount>()) {
if (element->HasFixedFootprint()) {
flags.Add(TypeFlag::kFixedFootprint);
}
@@ -56,7 +54,7 @@
"Was the SubstituteOverride transform run?";
Array::Array(const Type* element,
- ArrayCount count,
+ const ArrayCount* count,
uint32_t align,
uint32_t size,
uint32_t stride,
@@ -91,11 +89,11 @@
out << "@stride(" << stride_ << ") ";
}
out << "array<" << element_->FriendlyName(symbols);
- if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
+ if (auto* const_count = count_->As<ConstantArrayCount>()) {
out << ", " << const_count->value;
- } else if (auto* named_override_count = std::get_if<NamedOverrideArrayCount>(&count_)) {
+ } else if (auto* named_override_count = count_->As<NamedOverrideArrayCount>()) {
out << ", " << symbols.NameFor(named_override_count->variable->Declaration()->symbol);
- } else if (std::holds_alternative<UnnamedOverrideArrayCount>(count_)) {
+ } else if (count_->Is<UnnamedOverrideArrayCount>()) {
out << ", [unnamed override-expression]";
}
out << ">";
diff --git a/src/tint/sem/array.h b/src/tint/sem/array.h
index 4d1ed7d..c068948 100644
--- a/src/tint/sem/array.h
+++ b/src/tint/sem/array.h
@@ -20,6 +20,7 @@
#include <string>
#include <variant>
+#include "src/tint/sem/array_count.h"
#include "src/tint/sem/node.h"
#include "src/tint/sem/type.h"
#include "src/tint/utils/compiler_macros.h"
@@ -33,115 +34,6 @@
namespace tint::sem {
-/// The variant of an ArrayCount when the array is a const-expression.
-/// Example:
-/// ```
-/// const N = 123;
-/// type arr = array<i32, N>
-/// ```
-struct ConstantArrayCount {
- /// The array count constant-expression value.
- uint32_t value;
-};
-
-/// The variant of an ArrayCount when the count is a named override variable.
-/// Example:
-/// ```
-/// override N : i32;
-/// type arr = array<i32, N>
-/// ```
-struct NamedOverrideArrayCount {
- /// The `override` variable.
- const GlobalVariable* variable;
-};
-
-/// The variant of an ArrayCount when the count is an unnamed override variable.
-/// Example:
-/// ```
-/// override N : i32;
-/// type arr = array<i32, N*2>
-/// ```
-struct UnnamedOverrideArrayCount {
- /// The unnamed override expression.
- /// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST
- /// expressions will not result in the same `expr` pointer. This property is important to ensure
- /// that two array declarations with equivalent AST expressions do not compare equal.
- /// For example, consider:
- /// ```
- /// override size : u32;
- /// var<workgroup> a : array<f32, size * 2>;
- /// var<workgroup> b : array<f32, size * 2>;
- /// ```
- // The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and
- // `b` must not compare equal.
- const Expression* expr;
-};
-
-/// The variant of an ArrayCount when the array is is runtime-sized.
-/// Example:
-/// ```
-/// type arr = array<i32>
-/// ```
-struct RuntimeArrayCount {};
-
-/// An array count is either a constant-expression value, a named override identifier, an unnamed
-/// override identifier, or runtime-sized.
-using ArrayCount = std::variant<ConstantArrayCount,
- NamedOverrideArrayCount,
- UnnamedOverrideArrayCount,
- RuntimeArrayCount>;
-
-/// Equality operator
-/// @param a the LHS ConstantArrayCount
-/// @param b the RHS ConstantArrayCount
-/// @returns true if @p a is equal to @p b
-inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b) {
- return a.value == b.value;
-}
-
-/// Equality operator
-/// @param a the LHS OverrideArrayCount
-/// @param b the RHS OverrideArrayCount
-/// @returns true if @p a is equal to @p b
-inline bool operator==(const NamedOverrideArrayCount& a, const NamedOverrideArrayCount& b) {
- return a.variable == b.variable;
-}
-
-/// Equality operator
-/// @param a the LHS OverrideArrayCount
-/// @param b the RHS OverrideArrayCount
-/// @returns true if @p a is equal to @p b
-inline bool operator==(const UnnamedOverrideArrayCount& a, const UnnamedOverrideArrayCount& b) {
- return a.expr == b.expr;
-}
-
-/// Equality operator
-/// @returns true
-inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
- return true;
-}
-
-/// Equality operator
-/// @param a the LHS ArrayCount
-/// @param b the RHS count
-/// @returns true if @p a is equal to @p b
-template <typename T,
- typename = std::enable_if_t<
- std::is_same_v<T, ConstantArrayCount> || std::is_same_v<T, NamedOverrideArrayCount> ||
- std::is_same_v<T, UnnamedOverrideArrayCount> || std::is_same_v<T, RuntimeArrayCount>>>
-inline bool operator==(const ArrayCount& a, const T& b) {
- TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
- return std::visit(
- [&](auto count) {
- if constexpr (std::is_same_v<std::decay_t<decltype(count)>, T>) {
- return count == b;
- }
- return false;
- },
- a);
- TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
-}
-
/// Array holds the semantic information for Array nodes.
class Array final : public Castable<Array, Type> {
public:
@@ -161,7 +53,7 @@
/// of the array to the start of the next element, if there was no `@stride`
/// attribute applied.
Array(Type const* element,
- ArrayCount count,
+ const ArrayCount* count,
uint32_t align,
uint32_t size,
uint32_t stride,
@@ -178,11 +70,11 @@
Type const* ElemType() const { return element_; }
/// @returns the number of elements in the array.
- const ArrayCount& Count() const { return count_; }
+ const ArrayCount* Count() const { return count_; }
/// @returns the array count if the count is a const-expression, otherwise returns nullopt.
inline std::optional<uint32_t> ConstantCount() const {
- if (auto* count = std::get_if<ConstantArrayCount>(&count_)) {
+ if (auto* count = count_->As<ConstantArrayCount>()) {
return count->value;
}
return std::nullopt;
@@ -212,23 +104,19 @@
bool IsStrideImplicit() const { return stride_ == implicit_stride_; }
/// @returns true if this array is sized using an const-expression
- bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
+ bool IsConstantSized() const { return count_->Is<ConstantArrayCount>(); }
/// @returns true if this array is sized using a named override variable
- bool IsNamedOverrideSized() const {
- return std::holds_alternative<NamedOverrideArrayCount>(count_);
- }
+ bool IsNamedOverrideSized() const { return count_->Is<NamedOverrideArrayCount>(); }
/// @returns true if this array is sized using an unnamed override variable
- bool IsUnnamedOverrideSized() const {
- return std::holds_alternative<UnnamedOverrideArrayCount>(count_);
- }
+ bool IsUnnamedOverrideSized() const { return count_->Is<UnnamedOverrideArrayCount>(); }
/// @returns true if this array is sized using a named or unnamed override variable
bool IsOverrideSized() const { return IsNamedOverrideSized() || IsUnnamedOverrideSized(); }
/// @returns true if this array is runtime sized
- bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
+ bool IsRuntimeSized() const { return count_->Is<RuntimeArrayCount>(); }
/// @param symbols the program's symbol table
/// @returns the name for this type that closely resembles how it would be
@@ -237,7 +125,7 @@
private:
Type const* const element_;
- const ArrayCount count_;
+ const ArrayCount* count_;
const uint32_t align_;
const uint32_t size_;
const uint32_t stride_;
@@ -246,49 +134,4 @@
} // namespace tint::sem
-namespace std {
-
-/// Custom std::hash specialization for tint::sem::ConstantArrayCount.
-template <>
-class hash<tint::sem::ConstantArrayCount> {
- public:
- /// @param count the count to hash
- /// @return the hash value
- inline std::size_t operator()(const tint::sem::ConstantArrayCount& count) const {
- return std::hash<decltype(count.value)>()(count.value);
- }
-};
-
-/// Custom std::hash specialization for tint::sem::NamedOverrideArrayCount.
-template <>
-class hash<tint::sem::NamedOverrideArrayCount> {
- public:
- /// @param count the count to hash
- /// @return the hash value
- inline std::size_t operator()(const tint::sem::NamedOverrideArrayCount& count) const {
- return std::hash<decltype(count.variable)>()(count.variable);
- }
-};
-
-/// Custom std::hash specialization for tint::sem::UnnamedOverrideArrayCount.
-template <>
-class hash<tint::sem::UnnamedOverrideArrayCount> {
- public:
- /// @param count the count to hash
- /// @return the hash value
- inline std::size_t operator()(const tint::sem::UnnamedOverrideArrayCount& count) const {
- return std::hash<decltype(count.expr)>()(count.expr);
- }
-};
-
-/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
-template <>
-class hash<tint::sem::RuntimeArrayCount> {
- public:
- /// @return the hash value
- inline std::size_t operator()(const tint::sem::RuntimeArrayCount&) const { return 42; }
-};
-
-} // namespace std
-
#endif // SRC_TINT_SEM_ARRAY_H_
diff --git a/src/tint/sem/array_count.cc b/src/tint/sem/array_count.cc
new file mode 100644
index 0000000..fa16639
--- /dev/null
+++ b/src/tint/sem/array_count.cc
@@ -0,0 +1,87 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/sem/array_count.h"
+
+#include "src/tint/utils/hash.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::sem::ArrayCount);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::ConstantArrayCount);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::RuntimeArrayCount);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::NamedOverrideArrayCount);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::UnnamedOverrideArrayCount);
+
+namespace tint::sem {
+
+ArrayCount::ArrayCount() : Base() {}
+ArrayCount::~ArrayCount() = default;
+
+ConstantArrayCount::ConstantArrayCount(uint32_t val) : Base(), value(val) {}
+ConstantArrayCount::~ConstantArrayCount() = default;
+
+size_t ConstantArrayCount::Hash() const {
+ // Mix the count into the hash: hashing only the type would place every constant-sized
+ // count into the same bucket of the hash set sem::TypeManager uses to unique array counts.
+ return utils::Hash(TypeInfo::Of<ConstantArrayCount>().full_hashcode, value);
+}
+
+bool ConstantArrayCount::Equals(const ArrayCount& other) const {
+ if (auto* v = other.As<ConstantArrayCount>()) {
+ return value == v->value;
+ }
+ return false;
+}
+
+RuntimeArrayCount::RuntimeArrayCount() : Base() {}
+RuntimeArrayCount::~RuntimeArrayCount() = default;
+
+size_t RuntimeArrayCount::Hash() const {
+ // RuntimeArrayCount carries no state, so all instances are equal and share one hash.
+ return static_cast<size_t>(TypeInfo::Of<RuntimeArrayCount>().full_hashcode);
+}
+
+bool RuntimeArrayCount::Equals(const ArrayCount& other) const {
+ return other.Is<RuntimeArrayCount>();
+}
+
+NamedOverrideArrayCount::NamedOverrideArrayCount(const GlobalVariable* var)
+ : Base(), variable(var) {}
+NamedOverrideArrayCount::~NamedOverrideArrayCount() = default;
+
+size_t NamedOverrideArrayCount::Hash() const {
+ return utils::Hash(TypeInfo::Of<NamedOverrideArrayCount>().full_hashcode, variable);
+}
+
+bool NamedOverrideArrayCount::Equals(const ArrayCount& other) const {
+ if (auto* v = other.As<NamedOverrideArrayCount>()) {
+ return variable == v->variable;
+ }
+ return false;
+}
+
+UnnamedOverrideArrayCount::UnnamedOverrideArrayCount(const Expression* e) : Base(), expr(e) {}
+UnnamedOverrideArrayCount::~UnnamedOverrideArrayCount() = default;
+
+size_t UnnamedOverrideArrayCount::Hash() const {
+ return utils::Hash(TypeInfo::Of<UnnamedOverrideArrayCount>().full_hashcode, expr);
+}
+
+bool UnnamedOverrideArrayCount::Equals(const ArrayCount& other) const {
+ if (auto* v = other.As<UnnamedOverrideArrayCount>()) {
+ return expr == v->expr;
+ }
+ return false;
+}
+
+} // namespace tint::sem
diff --git a/src/tint/sem/array_count.h b/src/tint/sem/array_count.h
new file mode 100644
index 0000000..eb1a001
--- /dev/null
+++ b/src/tint/sem/array_count.h
@@ -0,0 +1,178 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_SEM_ARRAY_COUNT_H_
+#define SRC_TINT_SEM_ARRAY_COUNT_H_
+
+#include <functional>
+#include <string>
+
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/node.h"
+#include "src/tint/sem/variable.h"
+
+namespace tint::sem {
+
+/// The base class for the element count of a sem::Array.
+/// Array counts are uniqued by sem::TypeManager, so implementations must keep Hash() and
+/// Equals() consistent: two counts that are Equals() must return the same Hash().
+class ArrayCount : public Castable<ArrayCount, Node> {
+ public:
+ /// Destructor
+ ~ArrayCount() override;
+
+ /// @returns a hash of the array count.
+ virtual size_t Hash() const = 0;
+
+ /// @param t other array count
+ /// @returns true if this array count is equal to the given array count
+ virtual bool Equals(const ArrayCount& t) const = 0;
+
+ protected:
+ /// Constructor
+ ArrayCount();
+};
+
+/// The variant of an ArrayCount when the array count is a const-expression.
+/// Example:
+/// ```
+/// const N = 123;
+/// type arr = array<i32, N>
+/// ```
+class ConstantArrayCount final : public Castable<ConstantArrayCount, ArrayCount> {
+ public:
+ /// Constructor
+ /// @param val the constant-expression value
+ explicit ConstantArrayCount(uint32_t val);
+ /// Destructor
+ ~ConstantArrayCount() override;
+
+ /// @returns a hash of the array count.
+ size_t Hash() const override;
+
+ /// @param t other array count
+ /// @returns true if this array count is equal to the given array count
+ bool Equals(const ArrayCount& t) const override;
+
+ /// The array count constant-expression value.
+ const uint32_t value;
+};
+
+/// The variant of an ArrayCount when the array is runtime-sized.
+/// Example:
+/// ```
+/// type arr = array<i32>
+/// ```
+class RuntimeArrayCount final : public Castable<RuntimeArrayCount, ArrayCount> {
+ public:
+ /// Constructor
+ RuntimeArrayCount();
+ /// Destructor
+ ~RuntimeArrayCount() override;
+
+ /// @returns a hash of the array count.
+ size_t Hash() const override;
+
+ /// @param t other array count
+ /// @returns true if this array count is equal to the given array count
+ bool Equals(const ArrayCount& t) const override;
+};
+
+/// The variant of an ArrayCount when the count is a named override variable.
+/// Example:
+/// ```
+/// override N : i32;
+/// type arr = array<i32, N>
+/// ```
+class NamedOverrideArrayCount final : public Castable<NamedOverrideArrayCount, ArrayCount> {
+ public:
+ /// Constructor
+ /// @param var the `override` variable
+ explicit NamedOverrideArrayCount(const GlobalVariable* var);
+ /// Destructor
+ ~NamedOverrideArrayCount() override;
+
+ /// @returns a hash of the array count.
+ size_t Hash() const override;
+
+ /// @param t other array count
+ /// @returns true if this array count is equal to the given array count
+ bool Equals(const ArrayCount& t) const override;
+
+ /// The `override` variable.
+ const GlobalVariable* const variable;
+};
+
+/// The variant of an ArrayCount when the count is an unnamed override expression.
+/// Example:
+/// ```
+/// override N : i32;
+/// type arr = array<i32, N*2>
+/// ```
+class UnnamedOverrideArrayCount final : public Castable<UnnamedOverrideArrayCount, ArrayCount> {
+ public:
+ /// Constructor
+ /// @param e the override expression
+ explicit UnnamedOverrideArrayCount(const Expression* e);
+ /// Destructor
+ ~UnnamedOverrideArrayCount() override;
+
+ /// @returns a hash of the array count.
+ size_t Hash() const override;
+
+ /// @param t other array count
+ /// @returns true if this array count is equal to the given array count
+ bool Equals(const ArrayCount& t) const override;
+
+ /// The unnamed override expression.
+ /// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST
+ /// expressions will not result in the same `expr` pointer. This property is important to ensure
+ /// that two array declarations with equivalent AST expressions do not compare equal.
+ /// For example, consider:
+ /// ```
+ /// override size : u32;
+ /// var<workgroup> a : array<f32, size * 2>;
+ /// var<workgroup> b : array<f32, size * 2>;
+ /// ```
+ /// The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and
+ /// `b` must not compare equal.
+ const Expression* const expr;
+};
+
+} // namespace tint::sem
+
+namespace std {
+
+/// std::hash specialization for tint::sem::ArrayCount
+template <>
+struct hash<tint::sem::ArrayCount> {
+ /// @param a the array count to obtain a hash from
+ /// @returns the hash of the array count
+ size_t operator()(const tint::sem::ArrayCount& a) const { return a.Hash(); }
+};
+
+/// std::equal_to specialization for tint::sem::ArrayCount
+template <>
+struct equal_to<tint::sem::ArrayCount> {
+ /// @param a the first array count to compare
+ /// @param b the second array count to compare
+ /// @returns true if the two array counts are equal
+ bool operator()(const tint::sem::ArrayCount& a, const tint::sem::ArrayCount& b) const {
+ return a.Equals(b);
+ }
+};
+
+} // namespace std
+
+#endif // SRC_TINT_SEM_ARRAY_COUNT_H_
diff --git a/src/tint/sem/array_test.cc b/src/tint/sem/array_test.cc
index a51b492..44073df 100644
--- a/src/tint/sem/array_test.cc
+++ b/src/tint/sem/array_test.cc
@@ -21,16 +21,16 @@
using ArrayTest = TestHelper;
TEST_F(ArrayTest, CreateSizedArray) {
- auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
- auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
- auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
- auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
- auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
- auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
- auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
+ auto* b = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
+ auto* c = create<Array>(create<U32>(), create<ConstantArrayCount>(3u), 4u, 8u, 32u, 16u);
+ auto* d = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 5u, 8u, 32u, 16u);
+ auto* e = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 9u, 32u, 16u);
+ auto* f = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 16u);
+ auto* g = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
- EXPECT_EQ(a->Count(), ConstantArrayCount{2u});
+ EXPECT_EQ(a->Count(), create<ConstantArrayCount>(2u));
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@@ -47,15 +47,15 @@
}
TEST_F(ArrayTest, CreateRuntimeArray) {
- auto* a = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
- auto* b = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
- auto* c = create<Array>(create<U32>(), RuntimeArrayCount{}, 5u, 8u, 32u, 32u);
- auto* d = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 9u, 32u, 32u);
- auto* e = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 32u);
- auto* f = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 32u);
+ auto* b = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 32u);
+ auto* c = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 5u, 8u, 32u, 32u);
+ auto* d = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 9u, 32u, 32u);
+ auto* e = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 33u, 32u);
+ auto* f = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
- EXPECT_EQ(a->Count(), sem::RuntimeArrayCount{});
+ EXPECT_EQ(a->Count(), create<sem::RuntimeArrayCount>());
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@@ -71,13 +71,13 @@
}
TEST_F(ArrayTest, Hash) {
- auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
- auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
- auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
- auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
- auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
- auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
- auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
+ auto* b = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
+ auto* c = create<Array>(create<U32>(), create<ConstantArrayCount>(3u), 4u, 8u, 32u, 16u);
+ auto* d = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 5u, 8u, 32u, 16u);
+ auto* e = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 9u, 32u, 16u);
+ auto* f = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 16u);
+ auto* g = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 17u);
EXPECT_EQ(a->Hash(), b->Hash());
EXPECT_NE(a->Hash(), c->Hash());
@@ -88,13 +88,13 @@
}
TEST_F(ArrayTest, Equals) {
- auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
- auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
- auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
- auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
- auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
- auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
- auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
+ auto* a = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
+ auto* b = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
+ auto* c = create<Array>(create<U32>(), create<ConstantArrayCount>(3u), 4u, 8u, 32u, 16u);
+ auto* d = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 5u, 8u, 32u, 16u);
+ auto* e = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 9u, 32u, 16u);
+ auto* f = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 16u);
+ auto* g = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 17u);
EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c));
@@ -106,32 +106,34 @@
}
TEST_F(ArrayTest, FriendlyNameRuntimeSized) {
- auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 4u, 4u);
+ auto* arr = create<Array>(create<I32>(), create<RuntimeArrayCount>(), 0u, 4u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSized) {
- auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
+ auto* arr = create<Array>(create<I32>(), create<ConstantArrayCount>(5u), 4u, 20u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32, 5>");
}
TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) {
- auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 8u, 4u);
+ auto* arr = create<Array>(create<I32>(), create<RuntimeArrayCount>(), 0u, 4u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
- auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 8u, 4u);
+ auto* arr = create<Array>(create<I32>(), create<ConstantArrayCount>(5u), 4u, 20u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32, 5>");
}
TEST_F(ArrayTest, IsConstructable) {
- auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* fixed_sized =
+ create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* named_override_sized =
- create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
+ create<Array>(create<U32>(), create<NamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
- create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
- auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
+ create<Array>(create<U32>(), create<UnnamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
+ auto* runtime_sized =
+ create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->IsConstructible());
EXPECT_FALSE(named_override_sized->IsConstructible());
@@ -140,12 +142,14 @@
}
TEST_F(ArrayTest, HasCreationFixedFootprint) {
- auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* fixed_sized =
+ create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* named_override_sized =
- create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
+ create<Array>(create<U32>(), create<NamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
- create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
- auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
+ create<Array>(create<U32>(), create<UnnamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
+ auto* runtime_sized =
+ create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->HasCreationFixedFootprint());
EXPECT_FALSE(named_override_sized->HasCreationFixedFootprint());
@@ -154,12 +158,14 @@
}
TEST_F(ArrayTest, HasFixedFootprint) {
- auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
+ auto* fixed_sized =
+ create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* named_override_sized =
- create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
+ create<Array>(create<U32>(), create<NamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
- create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
- auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
+ create<Array>(create<U32>(), create<UnnamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
+ auto* runtime_sized =
+ create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->HasFixedFootprint());
EXPECT_TRUE(named_override_sized->HasFixedFootprint());
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index f9db53b..3d25e7e 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -273,7 +273,7 @@
},
[&](const Array* a) {
if (count) {
- if (auto* const_count = std::get_if<ConstantArrayCount>(&a->Count())) {
+ if (auto* const_count = a->Count()->As<ConstantArrayCount>()) {
*count = const_count->value;
}
}
diff --git a/src/tint/sem/type_manager.h b/src/tint/sem/type_manager.h
index 72f843a..33c42b3 100644
--- a/src/tint/sem/type_manager.h
+++ b/src/tint/sem/type_manager.h
@@ -19,6 +19,7 @@
#include <unordered_map>
#include <utility>
+#include "src/tint/sem/array_count.h"
#include "src/tint/sem/type.h"
#include "src/tint/utils/unique_allocator.h"
@@ -56,6 +57,7 @@
static TypeManager Wrap(const TypeManager& inner) {
TypeManager out;
out.types_.Wrap(inner.types_);
+ out.array_counts_.Wrap(inner.array_counts_);
return out;
}
@@ -80,6 +82,17 @@
return types_.Find<TYPE>(std::forward<ARGS>(args)...);
}
+ /// @param args the arguments used to construct the array count.
+ /// @return a pointer to an instance of `TYPE` constructed with the provided arguments.
+ /// If an existing, equal instance of `TYPE` has been constructed, then the same
+ /// pointer is returned. `TYPE` must be sem::ArrayCount or derive from it.
+ template <typename TYPE,
+ typename = std::enable_if_t<traits::IsTypeOrDerived<TYPE, sem::ArrayCount>>,
+ typename... ARGS>
+ TYPE* GetArrayCount(ARGS&&... args) {
+ return array_counts_.Get<TYPE>(std::forward<ARGS>(args)...);
+ }
+
/// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); }
/// @returns an iterator to the end of the types
@@ -87,6 +100,7 @@
private:
utils::UniqueAllocator<Type> types_;
+ utils::UniqueAllocator<ArrayCount> array_counts_;
};
} // namespace tint::sem
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index db7616b..9148322 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -100,63 +100,63 @@
/* size_no_padding*/ 4u);
const sem::Array* arr_i32 = create<Array>(
/* element */ i32,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_ai = create<Array>(
/* element */ ai,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_vec3_i32 = create<Array>(
/* element */ vec3_i32,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_vec3_ai = create<Array>(
/* element */ vec3_ai,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_mat4x3_f16 = create<Array>(
/* element */ mat4x3_f16,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 32u,
/* size */ 5u * 32u,
/* stride */ 5u * 32u,
/* implicit_stride */ 5u * 32u);
const sem::Array* arr_mat4x3_f32 = create<Array>(
/* element */ mat4x3_f32,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_mat4x3_af = create<Array>(
/* element */ mat4x3_af,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_str_f16 = create<Array>(
/* element */ str_f16,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_str_af = create<Array>(
/* element */ str_af,
- /* count */ ConstantArrayCount{5u},
+ /* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index 5c8357c..d2eea42 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -109,11 +109,11 @@
if (a->IsRuntimeSized()) {
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
}
- if (auto* override = std::get_if<sem::NamedOverrideArrayCount>(&a->Count())) {
+ if (auto* override = a->Count()->As<sem::NamedOverrideArrayCount>()) {
auto* count = ctx.Clone(override->variable->Declaration());
return ctx.dst->ty.array(el, count, std::move(attrs));
}
- if (auto* override = std::get_if<sem::UnnamedOverrideArrayCount>(&a->Count())) {
+ if (auto* override = a->Count()->As<sem::UnnamedOverrideArrayCount>()) {
// If the array count is an unnamed (complex) override expression, then its not safe to
// redeclare this type as we'd end up with two types that would not compare equal.
// See crbug.com/tint/1764.
diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc
index 2a39094..29a21e1 100644
--- a/src/tint/transform/transform_test.cc
+++ b/src/tint/transform/transform_test.cc
@@ -69,8 +69,8 @@
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
- return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 32u,
- 32u);
+ return b.create<sem::Array>(b.create<sem::F32>(), b.create<sem::ConstantArrayCount>(2u), 4u,
+ 4u, 32u, 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
@@ -83,8 +83,8 @@
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
- return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 64u,
- 32u);
+ return b.create<sem::Array>(b.create<sem::F32>(), b.create<sem::ConstantArrayCount>(2u), 4u,
+ 4u, 64u, 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
diff --git a/src/tint/utils/unique_allocator.h b/src/tint/utils/unique_allocator.h
index 25681cd..7ba5909 100644
--- a/src/tint/utils/unique_allocator.h
+++ b/src/tint/utils/unique_allocator.h
@@ -91,7 +91,7 @@
struct Entry {
/// The pre-calculated hash of the entry
size_t hash;
- /// Tge pointer to the unique object
+ /// The pointer to the unique object
T* ptr;
};
/// Comparator is the hashing and equality function used by the unordered_set