[spirv-reader][ir] Convert workgroup id when needed.
In SPIR-V the workgroup id can be a vec3<i32 or u32>. In WGSL it must be
a vec3<u32>. Make sure we do any required conversions to match types.
Bug: 42250952
Change-Id: I59b54093387270d08da5ef25798da0446634a81c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/245656
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/lower/shader_io.cc b/src/tint/lang/spirv/reader/lower/shader_io.cc
index 9daa609..5e41801 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io.cc
@@ -421,7 +421,8 @@
break;
}
case core::BuiltinValue::kLocalInvocationId:
- case core::BuiltinValue::kGlobalInvocationId: {
+ case core::BuiltinValue::kGlobalInvocationId:
+ case core::BuiltinValue::kWorkgroupId: {
var_type = ty.vec3<u32>();
break;
}
@@ -492,7 +493,8 @@
break;
}
case core::BuiltinValue::kLocalInvocationId:
- case core::BuiltinValue::kGlobalInvocationId: {
+ case core::BuiltinValue::kGlobalInvocationId:
+ case core::BuiltinValue::kWorkgroupId: {
auto* idx_ty = var->Result()->Type()->UnwrapPtr();
auto* elem_ty = idx_ty->DeepestElement();
if (elem_ty->IsSignedIntegerScalar()) {
diff --git a/src/tint/lang/spirv/reader/lower/shader_io_test.cc b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
index f04a02a..731faee 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io_test.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
@@ -2441,6 +2441,96 @@
EXPECT_EQ(expect, str());
}
+TEST_F(SpirvReader_ShaderIOTest, WorkgroupId_i32) {
+ auto* idx = b.Var("idx", ty.ptr(core::AddressSpace::kIn, ty.vec3<i32>()));
+ idx->SetBuiltin(core::BuiltinValue::kWorkgroupId);
+ mod.root_block->Append(idx);
+
+ auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ ep->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
+ b.Append(ep->Block(), [&] {
+ auto* idx_value = b.Load(idx);
+ b.Let("a", b.Multiply(ty.vec3<i32>(), idx_value, b.Splat(ty.vec3<i32>(), 2_i)));
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %idx:ptr<__in, vec3<i32>, read> = var undef @builtin(workgroup_id)
+}
+
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B2: {
+ %3:vec3<i32> = load %idx
+ %4:vec3<i32> = mul %3, vec3<i32>(2i)
+ %a:vec3<i32> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func(%idx:vec3<u32> [@workgroup_id]):void {
+ $B1: {
+ %3:vec3<i32> = convert %idx
+ %4:vec3<i32> = mul %3, vec3<i32>(2i)
+ %a:vec3<i32> = let %4
+ ret
+ }
+}
+)";
+
+ Run(ShaderIO);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvReader_ShaderIOTest, WorkgroupId_u32) {
+ auto* idx = b.Var("idx", ty.ptr(core::AddressSpace::kIn, ty.vec3<u32>()));
+ idx->SetBuiltin(core::BuiltinValue::kWorkgroupId);
+ mod.root_block->Append(idx);
+
+ auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ ep->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
+ b.Append(ep->Block(), [&] {
+ auto* idx_value = b.Load(idx);
+ b.Let("a", b.Multiply(ty.vec3<u32>(), idx_value, b.Splat(ty.vec3<u32>(), 2_u)));
+
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %idx:ptr<__in, vec3<u32>, read> = var undef @builtin(workgroup_id)
+}
+
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B2: {
+ %3:vec3<u32> = load %idx
+ %4:vec3<u32> = mul %3, vec3<u32>(2u)
+ %a:vec3<u32> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func(%idx:vec3<u32> [@workgroup_id]):void {
+ $B1: {
+ %3:vec3<u32> = mul %idx, vec3<u32>(2u)
+ %a:vec3<u32> = let %3
+ ret
+ }
+}
+)";
+
+ Run(ShaderIO);
+
+ EXPECT_EQ(expect, str());
+}
+
// Test that a sample mask array is converted to a scalar u32 for the entry point.
TEST_F(SpirvReader_ShaderIOTest, SampleMask) {
auto* arr = ty.array<u32, 1>();