[tint] Fix implicit stride in array helper
The element stride may be larger than the element alignment.
Add a test for nested arrays in the IR-based SPIR-V writer which
exposes this bug.
Change-Id: If59e32330eb49ec284487fd38e5c09aeb13de888
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/135580
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/type/manager.cc b/src/tint/type/manager.cc
index 4a98045..136f910 100644
--- a/src/tint/type/manager.cc
+++ b/src/tint/type/manager.cc
@@ -128,15 +128,18 @@
const type::Array* Manager::array(const type::Type* elem_ty,
uint32_t count,
uint32_t stride /* = 0*/) {
+ uint32_t implicit_stride = utils::RoundUp(elem_ty->Align(), elem_ty->Size());
if (stride == 0) {
- stride = elem_ty->Align();
+ stride = implicit_stride;
}
+ TINT_ASSERT(Type, stride >= implicit_stride);
+
return Get<type::Array>(/* element type */ elem_ty,
/* element count */ Get<ConstantArrayCount>(count),
/* array alignment */ elem_ty->Align(),
/* array size */ count * stride,
/* element stride */ stride,
- /* implicit stride */ elem_ty->Align());
+ /* implicit stride */ implicit_stride);
}
const type::Array* Manager::runtime_array(const type::Type* elem_ty, uint32_t stride /* = 0 */) {
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
index 7b2142c..284d421 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
@@ -149,6 +149,22 @@
EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 16\n");
}
+TEST_F(SpvGeneratorImplTest, Type_Array_NestedArray) {
+ auto* arr = mod.Types().array(mod.Types().array(mod.Types().f32(), 64u), 4u);
+ auto id = generator_.Type(arr);
+ EXPECT_EQ(id, 1u);
+ EXPECT_EQ(DumpTypes(),
+ "%3 = OpTypeFloat 32\n"
+ "%5 = OpTypeInt 32 0\n"
+ "%4 = OpConstant %5 64\n"
+ "%2 = OpTypeArray %3 %4\n"
+ "%6 = OpConstant %5 4\n"
+ "%1 = OpTypeArray %2 %6\n");
+ EXPECT_EQ(DumpInstructions(generator_.Module().Annots()),
+ "OpDecorate %2 ArrayStride 4\n"
+ "OpDecorate %1 ArrayStride 256\n");
+}
+
TEST_F(SpvGeneratorImplTest, Type_RuntimeArray_DefaultStride) {
auto* arr = mod.Types().runtime_array(mod.Types().f32());
auto id = generator_.Type(arr);