[spirv-reader][ir] Convert subgroup size when needed.
In SPIR-V the subgroup size can be i32 or u32. In WGSL it must be u32.
Make sure we do any required conversions to match types.
Bug: 42250952
Change-Id: Iebdfb6272c06a1152f3a0f12d87276603751c6b0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/245756
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/tint/lang/spirv/reader/lower/shader_io.cc b/src/tint/lang/spirv/reader/lower/shader_io.cc
index d638f07..9c9254f 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io.cc
@@ -419,6 +419,7 @@
case core::BuiltinValue::kVertexIndex:
case core::BuiltinValue::kLocalInvocationIndex:
case core::BuiltinValue::kSubgroupInvocationId:
+ case core::BuiltinValue::kSubgroupSize:
case core::BuiltinValue::kSampleIndex: {
var_type = ty.u32();
break;
@@ -490,6 +491,7 @@
case core::BuiltinValue::kVertexIndex:
case core::BuiltinValue::kLocalInvocationIndex:
case core::BuiltinValue::kSubgroupInvocationId:
+ case core::BuiltinValue::kSubgroupSize:
case core::BuiltinValue::kSampleIndex: {
auto* idx_ty = var->Result()->Type()->UnwrapPtr();
if (idx_ty->IsSignedIntegerScalar()) {
diff --git a/src/tint/lang/spirv/reader/lower/shader_io_test.cc b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
index 16a024a..0d54fda 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io_test.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
@@ -2482,6 +2482,96 @@
EXPECT_EQ(expect, str());
}
+TEST_F(SpirvReader_ShaderIOTest, SubgroupSize_i32) {
+ auto* idx = b.Var("idx", ty.ptr(core::AddressSpace::kIn, ty.i32()));
+ idx->SetBuiltin(core::BuiltinValue::kSubgroupSize);
+ mod.root_block->Append(idx);
+
+ auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ ep->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
+ b.Append(ep->Block(), [&] {
+ auto* idx_value = b.Load(idx);
+ b.Let("a", b.Multiply(ty.i32(), idx_value, 2_i));
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %idx:ptr<__in, i32, read> = var undef @builtin(subgroup_size)
+}
+
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B2: {
+ %3:i32 = load %idx
+ %4:i32 = mul %3, 2i
+ %a:i32 = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func(%idx:u32 [@subgroup_size]):void {
+ $B1: {
+ %3:i32 = convert %idx
+ %4:i32 = mul %3, 2i
+ %a:i32 = let %4
+ ret
+ }
+}
+)";
+
+ Run(ShaderIO);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvReader_ShaderIOTest, SubgroupSize_u32) {
+ auto* idx = b.Var("idx", ty.ptr(core::AddressSpace::kIn, ty.u32()));
+ idx->SetBuiltin(core::BuiltinValue::kSubgroupSize);
+ mod.root_block->Append(idx);
+
+ auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ ep->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
+ b.Append(ep->Block(), [&] {
+ auto* idx_value = b.Load(idx);
+ b.Let("a", b.Multiply(ty.u32(), idx_value, 2_u));
+
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %idx:ptr<__in, u32, read> = var undef @builtin(subgroup_size)
+}
+
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B2: {
+ %3:u32 = load %idx
+ %4:u32 = mul %3, 2u
+ %a:u32 = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func(%idx:u32 [@subgroup_size]):void {
+ $B1: {
+ %3:u32 = mul %idx, 2u
+ %a:u32 = let %3
+ ret
+ }
+}
+)";
+
+ Run(ShaderIO);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(SpirvReader_ShaderIOTest, LocalInvocationId_i32) {
auto* idx = b.Var("idx", ty.ptr(core::AddressSpace::kIn, ty.vec3<i32>()));
idx->SetBuiltin(core::BuiltinValue::kLocalInvocationId);