[spirv-reader][ir] Convert local invocation id when needed.

In SPIR-V the local invocation id can be a vec3<i32 or u32>. In WGSL it
must be a vec3<u32>. Make sure we do any required conversions to match
types.

Bug: 42250952
Change-Id: Ib1d7f68e9840c73a17fdaf38a79222e7ee420e2e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/245694
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/lower/shader_io.cc b/src/tint/lang/spirv/reader/lower/shader_io.cc
index 6c0cde0..03d1b97 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io.cc
@@ -410,7 +410,6 @@
             if (entry_point && var->Attributes().builtin.has_value()) {
                 switch (var->Attributes().builtin.value()) {
                     case core::BuiltinValue::kSampleMask: {
-                        // Use a scalar u32 for sample_mask builtins
                         TINT_ASSERT(var_type->Is<core::type::Array>());
                         TINT_ASSERT(var_type->As<core::type::Array>()->ConstantCount() == 1u);
                         var_type = ty.u32();
@@ -418,10 +417,13 @@
                     }
                     case core::BuiltinValue::kInstanceIndex:
                     case core::BuiltinValue::kLocalInvocationIndex: {
-                        // Use a scalar u32
                         var_type = ty.u32();
                         break;
                     }
+                    case core::BuiltinValue::kLocalInvocationId: {
+                        var_type = ty.vec3<u32>();
+                        break;
+                    }
                     default: {
                         break;
                     }
@@ -488,6 +490,16 @@
                         }
                         break;
                     }
+                    case core::BuiltinValue::kLocalInvocationId: {
+                        auto* idx_ty = var->Result()->Type()->UnwrapPtr();
+                        auto* elem_ty = idx_ty->DeepestElement();
+                        if (elem_ty->IsSignedIntegerScalar()) {
+                            auto* conv = b.Convert(ty.MatchWidth(ty.i32(), idx_ty), result);
+                            func->Block()->Prepend(conv);
+                            result = conv->Result();
+                        }
+                        break;
+                    }
                     default: {
                         break;
                     }
diff --git a/src/tint/lang/spirv/reader/lower/shader_io_test.cc b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
index c7d7672..fbabf03 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io_test.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
@@ -2261,6 +2261,96 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(SpirvReader_ShaderIOTest, LocalInvocationId_i32) {
+    auto* idx = b.Var("idx", ty.ptr(core::AddressSpace::kIn, ty.vec3<i32>()));
+    idx->SetBuiltin(core::BuiltinValue::kLocalInvocationId);
+    mod.root_block->Append(idx);
+
+    auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+    ep->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
+    b.Append(ep->Block(), [&] {
+        auto* idx_value = b.Load(idx);
+        b.Let("a", b.Multiply(ty.vec3<i32>(), idx_value, b.Splat(ty.vec3<i32>(), 2_i)));
+        b.Return(ep);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %idx:ptr<__in, vec3<i32>, read> = var undef @builtin(local_invocation_id)
+}
+
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B2: {
+    %3:vec3<i32> = load %idx
+    %4:vec3<i32> = mul %3, vec3<i32>(2i)
+    %a:vec3<i32> = let %4
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func(%idx:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:vec3<i32> = convert %idx
+    %4:vec3<i32> = mul %3, vec3<i32>(2i)
+    %a:vec3<i32> = let %4
+    ret
+  }
+}
+)";
+
+    Run(ShaderIO);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvReader_ShaderIOTest, LocalInvocationId_u32) {
+    auto* idx = b.Var("idx", ty.ptr(core::AddressSpace::kIn, ty.vec3<u32>()));
+    idx->SetBuiltin(core::BuiltinValue::kLocalInvocationId);
+    mod.root_block->Append(idx);
+
+    auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+    ep->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
+    b.Append(ep->Block(), [&] {
+        auto* idx_value = b.Load(idx);
+        b.Let("a", b.Multiply(ty.vec3<u32>(), idx_value, b.Splat(ty.vec3<u32>(), 2_u)));
+
+        b.Return(ep);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %idx:ptr<__in, vec3<u32>, read> = var undef @builtin(local_invocation_id)
+}
+
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B2: {
+    %3:vec3<u32> = load %idx
+    %4:vec3<u32> = mul %3, vec3<u32>(2u)
+    %a:vec3<u32> = let %4
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func(%idx:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:vec3<u32> = mul %idx, vec3<u32>(2u)
+    %a:vec3<u32> = let %3
+    ret
+  }
+}
+)";
+
+    Run(ShaderIO);
+
+    EXPECT_EQ(expect, str());
+}
+
 // Test that a sample mask array is converted to a scalar u32 for the entry point.
 TEST_F(SpirvReader_ShaderIOTest, SampleMask) {
     auto* arr = ty.array<u32, 1>();