spirv-reader: support NumWorkgroups

Fixed: tint:1065
Change-Id: Id2a8af247e7da79933703e634478f1dec25f9145
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110220
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/tint/reader/spirv/enum_converter.cc b/src/tint/reader/spirv/enum_converter.cc
index d590507..8c70156 100644
--- a/src/tint/reader/spirv/enum_converter.cc
+++ b/src/tint/reader/spirv/enum_converter.cc
@@ -82,6 +82,8 @@
             return ast::BuiltinValue::kLocalInvocationIndex;
         case spv::BuiltIn::GlobalInvocationId:
             return ast::BuiltinValue::kGlobalInvocationId;
+        case spv::BuiltIn::NumWorkgroups:
+            return ast::BuiltinValue::kNumWorkgroups;
         case spv::BuiltIn::WorkgroupId:
             return ast::BuiltinValue::kWorkgroupId;
         case spv::BuiltIn::SampleId:
diff --git a/src/tint/reader/spirv/enum_converter_test.cc b/src/tint/reader/spirv/enum_converter_test.cc
index b366ebd..bdb0247 100644
--- a/src/tint/reader/spirv/enum_converter_test.cc
+++ b/src/tint/reader/spirv/enum_converter_test.cc
@@ -192,6 +192,7 @@
         BuiltinCase{spv::BuiltIn::LocalInvocationIndex, true,
                     ast::BuiltinValue::kLocalInvocationIndex},
         BuiltinCase{spv::BuiltIn::GlobalInvocationId, true, ast::BuiltinValue::kGlobalInvocationId},
+        BuiltinCase{spv::BuiltIn::NumWorkgroups, true, ast::BuiltinValue::kNumWorkgroups},
         BuiltinCase{spv::BuiltIn::WorkgroupId, true, ast::BuiltinValue::kWorkgroupId},
         BuiltinCase{spv::BuiltIn::SampleId, true, ast::BuiltinValue::kSampleIndex},
         BuiltinCase{spv::BuiltIn::SampleMask, true, ast::BuiltinValue::kSampleMask}));
@@ -208,8 +209,6 @@
                          testing::Values(BuiltinCase{static_cast<spv::BuiltIn>(9999), false,
                                                      ast::BuiltinValue::kUndefined},
                                          BuiltinCase{static_cast<spv::BuiltIn>(9999), false,
-                                                     ast::BuiltinValue::kUndefined},
-                                         BuiltinCase{spv::BuiltIn::NumWorkgroups, false,
                                                      ast::BuiltinValue::kUndefined}));
 
 // Dim
diff --git a/src/tint/reader/spirv/parser_impl_module_var_test.cc b/src/tint/reader/spirv/parser_impl_module_var_test.cc
index 44fc4b3..be931a6 100644
--- a/src/tint/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/tint/reader/spirv/parser_impl_module_var_test.cc
@@ -3106,6 +3106,14 @@
 // Returns the start of a shader for testing LocalInvocationIndex,
 // parameterized by store type of %int or %uint
 std::string ComputeBuiltinInputPreamble(std::string builtin, std::string store_type) {
+    std::string ptr_component_type;
+    if (store_type == "%v3int") {
+        ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %int\n";
+    }
+    if (store_type == "%v3uint") {
+        ptr_component_type = " %ptr_comp_ty = OpTypePointer Input %uint\n";
+    }
+
     return R"(
     OpCapability Shader
     OpMemoryModel Logical Simple
@@ -3118,10 +3126,11 @@
     %float = OpTypeFloat 32
     %uint = OpTypeInt 32 0
     %int = OpTypeInt 32 1
+    %int_1 = OpConstant %int 1
     %v3uint = OpTypeVector %uint 3
     %v3int = OpTypeVector %int 3
     %ptr_ty = OpTypePointer Input )" +
-           store_type + R"(
+           store_type + ptr_component_type + R"(
     %1 = OpVariable %ptr_ty Input
 )";
 }
@@ -3331,14 +3340,84 @@
                              {"LocalInvocationId", "%v3int", "local_invocation_id"},
                              {"GlobalInvocationId", "%v3uint", "global_invocation_id"},
                              {"GlobalInvocationId", "%v3int", "global_invocation_id"},
+                             {"NumWorkgroups", "%v3uint", "num_workgroups"},
+                             {"NumWorkgroups", "%v3int", "num_workgroups"},
                              {"WorkgroupId", "%v3uint", "workgroup_id"},
                              {"WorkgroupId", "%v3int", "workgroup_id"}}));
 
-// TODO(dneto): crbug.com/tint/752
-// NumWorkgroups support is blocked by crbug.com/tint/752
-// When the AST supports NumWorkgroups, add these cases:
-//        {"NumWorkgroups", "%uint", "num_workgroups"}
-//        {"NumWorkgroups", "%int", "num_workgroups"}
+// For compute shader builtins that are vectors, test loading one component.
+struct ComputeBuiltinInputVectorCase {
+    std::string spirv_builtin;
+    std::string spirv_store_type;
+    std::string spirv_component_store_type;
+    std::string wgsl_builtin;
+};
+inline std::ostream& operator<<(std::ostream& o, ComputeBuiltinInputVectorCase c) {
+    return o << "ComputeBuiltinInputVectorCase(" << c.spirv_builtin << " " << c.spirv_store_type
+             << " " << c.spirv_component_store_type << " " << c.wgsl_builtin << ")";
+}
+
+using SpvModuleScopeVarParserTest_ComputeBuiltinVector =
+    SpvParserTestBase<::testing::TestWithParam<ComputeBuiltinInputVectorCase>>;
+
+TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltinVector, Load_Component_Direct) {
+    const auto wgsl_type = WgslType(GetParam().spirv_store_type);
+    const auto wgsl_component_type = WgslType(GetParam().spirv_component_store_type);
+    const auto wgsl_builtin = GetParam().wgsl_builtin;
+    const auto unsigned_wgsl_type = UnsignedWgslType(wgsl_type);
+    const auto signed_wgsl_type = SignedWgslType(wgsl_type);
+    const std::string assembly =
+        ComputeBuiltinInputPreamble(GetParam().spirv_builtin, GetParam().spirv_store_type) +
+        R"(
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %3 = OpAccessChain %ptr_comp_ty %1 %int_1
+    %2 = OpLoad )" +
+        GetParam().spirv_component_store_type + R"( %3
+    OpReturn
+    OpFunctionEnd
+ )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+    EXPECT_TRUE(p->error().empty());
+    const auto module_str = test::ToString(p->program());
+    std::string expected = R"(var<private> x_1 : ${wgsl_type};
+
+fn main_1() {
+  let x_2 : ${wgsl_component_type} = x_1.y;
+  return;
+}
+
+@compute @workgroup_size(1i, 1i, 1i)
+fn main(@builtin(${wgsl_builtin}) x_1_param : ${unsigned_wgsl_type}) {
+  x_1 = ${assignment_value};
+  main_1();
+}
+)";
+
+    expected = utils::ReplaceAll(expected, "${wgsl_type}", wgsl_type);
+    expected = utils::ReplaceAll(expected, "${wgsl_component_type}", wgsl_component_type);
+    expected = utils::ReplaceAll(expected, "${unsigned_wgsl_type}", unsigned_wgsl_type);
+    expected = utils::ReplaceAll(expected, "${wgsl_builtin}", wgsl_builtin);
+    expected = utils::ReplaceAll(expected, "${assignment_value}",
+                                 (wgsl_type == unsigned_wgsl_type)
+                                     ? "x_1_param"
+                                     : "bitcast<" + signed_wgsl_type + ">(x_1_param)");
+
+    EXPECT_EQ(module_str, expected) << module_str;
+}
+
+INSTANTIATE_TEST_SUITE_P(Samples,
+                         SpvModuleScopeVarParserTest_ComputeBuiltinVector,
+                         ::testing::ValuesIn(std::vector<ComputeBuiltinInputVectorCase>{
+                             {"LocalInvocationId", "%v3uint", "%uint", "local_invocation_id"},
+                             {"LocalInvocationId", "%v3int", "%int", "local_invocation_id"},
+                             {"GlobalInvocationId", "%v3uint", "%uint", "global_invocation_id"},
+                             {"GlobalInvocationId", "%v3int", "%int", "global_invocation_id"},
+                             {"NumWorkgroups", "%v3uint", "%uint", "num_workgroups"},
+                             {"NumWorkgroups", "%v3int", "%int", "num_workgroups"},
+                             {"WorkgroupId", "%v3uint", "%uint", "workgroup_id"},
+                             {"WorkgroupId", "%v3int", "%int", "workgroup_id"}}));
 
 TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
     const std::string assembly =