Implement and test num_workgroups for all backends

For HLSL, use the new NumWorkgroupsFromUniform transform, and expose
the binding point to use for the generated uniform as a backend
option.

The MSL mapping is trivial, and it was already implemented for WGSL
and SPIR-V.

Bug: tint:752
Change-Id: I4bd37b5d26181629d72b152fe064a60caf8ecdc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63962
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/docs/origin-trial-changes.md b/docs/origin-trial-changes.md
index 620223d..253c961 100644
--- a/docs/origin-trial-changes.md
+++ b/docs/origin-trial-changes.md
@@ -4,6 +4,7 @@
 
 ### New Features
 * The size of an array can now be defined using a non-overridable module-scope constant
+* The `num_workgroups` builtin is now supported.
 
 ### Fixes
 * Hex floats: issue an error when the magnitude is non-zero, and the exponent would cause
diff --git a/src/writer/hlsl/generator.cc b/src/writer/hlsl/generator.cc
index d2ee601..2f9a58a 100644
--- a/src/writer/hlsl/generator.cc
+++ b/src/writer/hlsl/generator.cc
@@ -28,7 +28,8 @@
   Result result;
 
   // Sanitize the program.
-  auto sanitized_result = Sanitize(program, options.disable_workgroup_init);
+  auto sanitized_result = Sanitize(program, options.root_constant_binding_point,
+                                   options.disable_workgroup_init);
   if (!sanitized_result.program.IsValid()) {
     result.success = false;
     result.error = sanitized_result.program.Diagnostics().str();
diff --git a/src/writer/hlsl/generator.h b/src/writer/hlsl/generator.h
index 1ff8d98..6930029 100644
--- a/src/writer/hlsl/generator.h
+++ b/src/writer/hlsl/generator.h
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "src/ast/pipeline_stage.h"
+#include "src/sem/binding_point.h"
 #include "src/writer/text.h"
 
 namespace tint {
@@ -36,6 +37,8 @@
 
 /// Configuration options used for generating HLSL.
 struct Options {
+  /// The binding point to use for information passed via root constants.
+  sem::BindingPoint root_constant_binding_point;
   /// Set to `true` to disable workgroup memory zero initialization
   bool disable_workgroup_init = false;
 };
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index c70e331..94e5de5 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -51,6 +51,7 @@
 #include "src/transform/inline_pointer_lets.h"
 #include "src/transform/loop_to_for_loop.h"
 #include "src/transform/manager.h"
+#include "src/transform/num_workgroups_from_uniform.h"
 #include "src/transform/pad_array_elements.h"
 #include "src/transform/promote_initializers_to_const_var.h"
 #include "src/transform/simplify.h"
@@ -113,7 +114,9 @@
 
 }  // namespace
 
-SanitizedResult Sanitize(const Program* in, bool disable_workgroup_init) {
+SanitizedResult Sanitize(const Program* in,
+                         sem::BindingPoint root_constant_binding_point,
+                         bool disable_workgroup_init) {
   transform::Manager manager;
   transform::DataMap data;
 
@@ -128,6 +131,10 @@
     manager.Add<transform::ZeroInitWorkgroupMemory>();
   }
   manager.Add<transform::CanonicalizeEntryPointIO>();
+  // NumWorkgroupsFromUniform must come after CanonicalizeEntryPointIO, as it
+  // assumes that num_workgroups builtins only appear as struct members and are
+  // only accessed directly via member accessors.
+  manager.Add<transform::NumWorkgroupsFromUniform>();
   manager.Add<transform::InlinePointerLets>();
   // Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets.
   manager.Add<transform::Simplify>();
@@ -147,6 +154,8 @@
 
   data.Add<transform::CanonicalizeEntryPointIO::Config>(
       transform::CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+  data.Add<transform::NumWorkgroupsFromUniform::Config>(
+      root_constant_binding_point);
 
   SanitizedResult result;
   result.program = std::move(manager.Run(in, data).program);
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index 17a6674..5717d71 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -33,6 +33,7 @@
 #include "src/ast/unary_op_expression.h"
 #include "src/program_builder.h"
 #include "src/scope_stack.h"
+#include "src/sem/binding_point.h"
 #include "src/transform/decompose_memory_access.h"
 #include "src/utils/hash.h"
 #include "src/writer/text_generator.h"
@@ -55,9 +56,12 @@
 };
 
 /// Sanitize a program in preparation for generating HLSL.
+/// @param root_constant_binding_point the binding point to use for information
+/// that will be passed via root constants
 /// @param disable_workgroup_init `true` to disable workgroup memory zero
 /// @returns the sanitized program and any supplementary information
 SanitizedResult Sanitize(const Program* program,
+                         sem::BindingPoint root_constant_binding_point = {},
                          bool disable_workgroup_init = false);
 
 /// Implementation class for HLSL generator
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 87d7b50..b790c44 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1587,6 +1587,8 @@
       return "thread_position_in_grid";
     case ast::Builtin::kWorkgroupId:
       return "threadgroup_position_in_grid";
+    case ast::Builtin::kNumWorkgroups:
+      return "threadgroups_per_grid";
     case ast::Builtin::kSampleIndex:
       return "sample_id";
     case ast::Builtin::kSampleMask:
diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc
index f5abc64..5ce939e 100644
--- a/src/writer/msl/generator_impl_test.cc
+++ b/src/writer/msl/generator_impl_test.cc
@@ -75,6 +75,8 @@
                                    "thread_position_in_grid"},
                     MslBuiltinData{ast::Builtin::kWorkgroupId,
                                    "threadgroup_position_in_grid"},
+                    MslBuiltinData{ast::Builtin::kNumWorkgroups,
+                                   "threadgroups_per_grid"},
                     MslBuiltinData{ast::Builtin::kSampleIndex, "sample_id"},
                     MslBuiltinData{ast::Builtin::kSampleMask, "sample_mask"},
                     MslBuiltinData{ast::Builtin::kPointSize, "point_size"}));
diff --git a/test/shader_io/compute_input_builtins.wgsl b/test/shader_io/compute_input_builtins.wgsl
index a90fc66..016ec31 100644
--- a/test/shader_io/compute_input_builtins.wgsl
+++ b/test/shader_io/compute_input_builtins.wgsl
@@ -4,11 +4,12 @@
   [[builtin(local_invocation_index)]] local_invocation_index : u32,
   [[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>,
   [[builtin(workgroup_id)]] workgroup_id : vec3<u32>,
-  // TODO(crbug.com/tint/752): [[builtin(num_workgroups)]] num_workgroups : vec3<u32>;
+  [[builtin(num_workgroups)]] num_workgroups : vec3<u32>,
 ) {
   let foo : u32 =
     local_invocation_id.x +
-    local_invocation_index + 
-    global_invocation_id.x + 
-    workgroup_id.x;
+    local_invocation_index +
+    global_invocation_id.x +
+    workgroup_id.x +
+    num_workgroups.x;
 }
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.hlsl b/test/shader_io/compute_input_builtins.wgsl.expected.hlsl
index cb71cd9..d6ab749 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.hlsl
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.hlsl
@@ -1,3 +1,7 @@
+cbuffer cbuffer_tint_symbol_3 : register(b0, space0) {
+  uint4 tint_symbol_3[1];
+};
+
 struct tint_symbol_1 {
   uint3 local_invocation_id : SV_GroupThreadID;
   uint local_invocation_index : SV_GroupIndex;
@@ -5,12 +9,12 @@
   uint3 workgroup_id : SV_GroupID;
 };
 
-void main_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id) {
-  const uint foo = (((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x);
+void main_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id, uint3 num_workgroups) {
+  const uint foo = ((((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x) + num_workgroups.x);
 }
 
 [numthreads(1, 1, 1)]
 void main(tint_symbol_1 tint_symbol) {
-  main_inner(tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id);
+  main_inner(tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id, tint_symbol_3[0].xyz);
   return;
 }
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.msl b/test/shader_io/compute_input_builtins.wgsl.expected.msl
index 60095f3..67ecc1b 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.msl
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.msl
@@ -1,12 +1,12 @@
 #include <metal_stdlib>
 
 using namespace metal;
-void tint_symbol_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id) {
-  uint const foo = (((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x);
+void tint_symbol_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id, uint3 num_workgroups) {
+  uint const foo = ((((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x) + num_workgroups.x);
 }
 
-kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]]) {
-  tint_symbol_inner(local_invocation_id, local_invocation_index, global_invocation_id, workgroup_id);
+kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]], uint3 num_workgroups [[threadgroups_per_grid]]) {
+  tint_symbol_inner(local_invocation_id, local_invocation_index, global_invocation_id, workgroup_id, num_workgroups);
   return;
 }
 
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.spvasm b/test/shader_io/compute_input_builtins.wgsl.expected.spvasm
index a1dddd0..6c5e5d6 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.spvasm
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.spvasm
@@ -1,26 +1,29 @@
 ; SPIR-V
 ; Version: 1.3
 ; Generator: Google Tint Compiler; 0
-; Bound: 31
+; Bound: 36
 ; Schema: 0
                OpCapability Shader
                OpMemoryModel Logical GLSL450
-               OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1
+               OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1 %num_workgroups_1
                OpExecutionMode %main LocalSize 1 1 1
                OpName %local_invocation_id_1 "local_invocation_id_1"
                OpName %local_invocation_index_1 "local_invocation_index_1"
                OpName %global_invocation_id_1 "global_invocation_id_1"
                OpName %workgroup_id_1 "workgroup_id_1"
+               OpName %num_workgroups_1 "num_workgroups_1"
                OpName %main_inner "main_inner"
                OpName %local_invocation_id "local_invocation_id"
                OpName %local_invocation_index "local_invocation_index"
                OpName %global_invocation_id "global_invocation_id"
                OpName %workgroup_id "workgroup_id"
+               OpName %num_workgroups "num_workgroups"
                OpName %main "main"
                OpDecorate %local_invocation_id_1 BuiltIn LocalInvocationId
                OpDecorate %local_invocation_index_1 BuiltIn LocalInvocationIndex
                OpDecorate %global_invocation_id_1 BuiltIn GlobalInvocationId
                OpDecorate %workgroup_id_1 BuiltIn WorkgroupId
+               OpDecorate %num_workgroups_1 BuiltIn NumWorkgroups
        %uint = OpTypeInt 32 0
      %v3uint = OpTypeVector %uint 3
 %_ptr_Input_v3uint = OpTypePointer Input %v3uint
@@ -29,29 +32,34 @@
 %local_invocation_index_1 = OpVariable %_ptr_Input_uint Input
 %global_invocation_id_1 = OpVariable %_ptr_Input_v3uint Input
 %workgroup_id_1 = OpVariable %_ptr_Input_v3uint Input
+%num_workgroups_1 = OpVariable %_ptr_Input_v3uint Input
        %void = OpTypeVoid
-          %9 = OpTypeFunction %void %v3uint %uint %v3uint %v3uint
-         %23 = OpTypeFunction %void
- %main_inner = OpFunction %void None %9
+         %10 = OpTypeFunction %void %v3uint %uint %v3uint %v3uint %v3uint
+         %27 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %10
 %local_invocation_id = OpFunctionParameter %v3uint
 %local_invocation_index = OpFunctionParameter %uint
 %global_invocation_id = OpFunctionParameter %v3uint
 %workgroup_id = OpFunctionParameter %v3uint
-         %16 = OpLabel
-         %17 = OpCompositeExtract %uint %local_invocation_id 0
-         %18 = OpIAdd %uint %17 %local_invocation_index
-         %19 = OpCompositeExtract %uint %global_invocation_id 0
-         %20 = OpIAdd %uint %18 %19
-         %21 = OpCompositeExtract %uint %workgroup_id 0
+%num_workgroups = OpFunctionParameter %v3uint
+         %18 = OpLabel
+         %19 = OpCompositeExtract %uint %local_invocation_id 0
+         %20 = OpIAdd %uint %19 %local_invocation_index
+         %21 = OpCompositeExtract %uint %global_invocation_id 0
          %22 = OpIAdd %uint %20 %21
+         %23 = OpCompositeExtract %uint %workgroup_id 0
+         %24 = OpIAdd %uint %22 %23
+         %25 = OpCompositeExtract %uint %num_workgroups 0
+         %26 = OpIAdd %uint %24 %25
                OpReturn
                OpFunctionEnd
-       %main = OpFunction %void None %23
-         %25 = OpLabel
-         %27 = OpLoad %v3uint %local_invocation_id_1
-         %28 = OpLoad %uint %local_invocation_index_1
-         %29 = OpLoad %v3uint %global_invocation_id_1
-         %30 = OpLoad %v3uint %workgroup_id_1
-         %26 = OpFunctionCall %void %main_inner %27 %28 %29 %30
+       %main = OpFunction %void None %27
+         %29 = OpLabel
+         %31 = OpLoad %v3uint %local_invocation_id_1
+         %32 = OpLoad %uint %local_invocation_index_1
+         %33 = OpLoad %v3uint %global_invocation_id_1
+         %34 = OpLoad %v3uint %workgroup_id_1
+         %35 = OpLoad %v3uint %num_workgroups_1
+         %30 = OpFunctionCall %void %main_inner %31 %32 %33 %34 %35
                OpReturn
                OpFunctionEnd
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.wgsl b/test/shader_io/compute_input_builtins.wgsl.expected.wgsl
index 518b32e..42c8adb 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.wgsl
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.wgsl
@@ -1,4 +1,4 @@
 [[stage(compute), workgroup_size(1)]]
-fn main([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index : u32, [[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>, [[builtin(workgroup_id)]] workgroup_id : vec3<u32>) {
-  let foo : u32 = (((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x);
+fn main([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index : u32, [[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>, [[builtin(workgroup_id)]] workgroup_id : vec3<u32>, [[builtin(num_workgroups)]] num_workgroups : vec3<u32>) {
+  let foo : u32 = ((((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x) + num_workgroups.x);
 }
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl b/test/shader_io/compute_input_builtins_struct.wgsl
index 61b8f64..c9691c9 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl
@@ -3,14 +3,15 @@
   [[builtin(local_invocation_index)]] local_invocation_index : u32;
   [[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>;
   [[builtin(workgroup_id)]] workgroup_id : vec3<u32>;
-  // TODO(crbug.com/tint/752): [[builtin(num_workgroups)]] num_workgroups : vec3<u32>;
+  [[builtin(num_workgroups)]] num_workgroups : vec3<u32>;
 };
 
 [[stage(compute), workgroup_size(1)]]
 fn main(inputs : ComputeInputs) {
   let foo : u32 =
     inputs.local_invocation_id.x +
-    inputs.local_invocation_index + 
-    inputs.global_invocation_id.x + 
-    inputs.workgroup_id.x;
+    inputs.local_invocation_index +
+    inputs.global_invocation_id.x +
+    inputs.workgroup_id.x +
+    inputs.num_workgroups.x;
 }
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl b/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl
index 8e00493..783117e 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl
@@ -1,8 +1,13 @@
+cbuffer cbuffer_tint_symbol_3 : register(b0, space0) {
+  uint4 tint_symbol_3[1];
+};
+
 struct ComputeInputs {
   uint3 local_invocation_id;
   uint local_invocation_index;
   uint3 global_invocation_id;
   uint3 workgroup_id;
+  uint3 num_workgroups;
 };
 struct tint_symbol_1 {
   uint3 local_invocation_id : SV_GroupThreadID;
@@ -12,12 +17,12 @@
 };
 
 void main_inner(ComputeInputs inputs) {
-  const uint foo = (((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x);
+  const uint foo = ((((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x) + inputs.num_workgroups.x);
 }
 
 [numthreads(1, 1, 1)]
 void main(tint_symbol_1 tint_symbol) {
-  const ComputeInputs tint_symbol_2 = {tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id};
-  main_inner(tint_symbol_2);
+  const ComputeInputs tint_symbol_5 = {tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id, tint_symbol_3[0].xyz};
+  main_inner(tint_symbol_5);
   return;
 }
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl b/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl
index 56db89c..34db10e 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl
@@ -6,14 +6,15 @@
   uint local_invocation_index;
   uint3 global_invocation_id;
   uint3 workgroup_id;
+  uint3 num_workgroups;
 };
 
 void tint_symbol_inner(ComputeInputs inputs) {
-  uint const foo = (((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x);
+  uint const foo = ((((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x) + inputs.num_workgroups.x);
 }
 
-kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]]) {
-  ComputeInputs const tint_symbol_1 = {.local_invocation_id=local_invocation_id, .local_invocation_index=local_invocation_index, .global_invocation_id=global_invocation_id, .workgroup_id=workgroup_id};
+kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]], uint3 num_workgroups [[threadgroups_per_grid]]) {
+  ComputeInputs const tint_symbol_1 = {.local_invocation_id=local_invocation_id, .local_invocation_index=local_invocation_index, .global_invocation_id=global_invocation_id, .workgroup_id=workgroup_id, .num_workgroups=num_workgroups};
   tint_symbol_inner(tint_symbol_1);
   return;
 }
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm b/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm
index d270e75..9e26e28 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm
@@ -1,21 +1,23 @@
 ; SPIR-V
 ; Version: 1.3
 ; Generator: Google Tint Compiler; 0
-; Bound: 34
+; Bound: 39
 ; Schema: 0
                OpCapability Shader
                OpMemoryModel Logical GLSL450
-               OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1
+               OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1 %num_workgroups_1
                OpExecutionMode %main LocalSize 1 1 1
                OpName %local_invocation_id_1 "local_invocation_id_1"
                OpName %local_invocation_index_1 "local_invocation_index_1"
                OpName %global_invocation_id_1 "global_invocation_id_1"
                OpName %workgroup_id_1 "workgroup_id_1"
+               OpName %num_workgroups_1 "num_workgroups_1"
                OpName %ComputeInputs "ComputeInputs"
                OpMemberName %ComputeInputs 0 "local_invocation_id"
                OpMemberName %ComputeInputs 1 "local_invocation_index"
                OpMemberName %ComputeInputs 2 "global_invocation_id"
                OpMemberName %ComputeInputs 3 "workgroup_id"
+               OpMemberName %ComputeInputs 4 "num_workgroups"
                OpName %main_inner "main_inner"
                OpName %inputs "inputs"
                OpName %main "main"
@@ -23,10 +25,12 @@
                OpDecorate %local_invocation_index_1 BuiltIn LocalInvocationIndex
                OpDecorate %global_invocation_id_1 BuiltIn GlobalInvocationId
                OpDecorate %workgroup_id_1 BuiltIn WorkgroupId
+               OpDecorate %num_workgroups_1 BuiltIn NumWorkgroups
                OpMemberDecorate %ComputeInputs 0 Offset 0
                OpMemberDecorate %ComputeInputs 1 Offset 12
                OpMemberDecorate %ComputeInputs 2 Offset 16
                OpMemberDecorate %ComputeInputs 3 Offset 32
+               OpMemberDecorate %ComputeInputs 4 Offset 48
        %uint = OpTypeInt 32 0
      %v3uint = OpTypeVector %uint 3
 %_ptr_Input_v3uint = OpTypePointer Input %v3uint
@@ -35,32 +39,37 @@
 %local_invocation_index_1 = OpVariable %_ptr_Input_uint Input
 %global_invocation_id_1 = OpVariable %_ptr_Input_v3uint Input
 %workgroup_id_1 = OpVariable %_ptr_Input_v3uint Input
+%num_workgroups_1 = OpVariable %_ptr_Input_v3uint Input
        %void = OpTypeVoid
-%ComputeInputs = OpTypeStruct %v3uint %uint %v3uint %v3uint
-          %9 = OpTypeFunction %void %ComputeInputs
-         %25 = OpTypeFunction %void
- %main_inner = OpFunction %void None %9
+%ComputeInputs = OpTypeStruct %v3uint %uint %v3uint %v3uint %v3uint
+         %10 = OpTypeFunction %void %ComputeInputs
+         %29 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %10
      %inputs = OpFunctionParameter %ComputeInputs
-         %14 = OpLabel
-         %15 = OpCompositeExtract %v3uint %inputs 0
-         %16 = OpCompositeExtract %uint %15 0
-         %17 = OpCompositeExtract %uint %inputs 1
-         %18 = OpIAdd %uint %16 %17
-         %19 = OpCompositeExtract %v3uint %inputs 2
-         %20 = OpCompositeExtract %uint %19 0
-         %21 = OpIAdd %uint %18 %20
-         %22 = OpCompositeExtract %v3uint %inputs 3
-         %23 = OpCompositeExtract %uint %22 0
-         %24 = OpIAdd %uint %21 %23
+         %15 = OpLabel
+         %16 = OpCompositeExtract %v3uint %inputs 0
+         %17 = OpCompositeExtract %uint %16 0
+         %18 = OpCompositeExtract %uint %inputs 1
+         %19 = OpIAdd %uint %17 %18
+         %20 = OpCompositeExtract %v3uint %inputs 2
+         %21 = OpCompositeExtract %uint %20 0
+         %22 = OpIAdd %uint %19 %21
+         %23 = OpCompositeExtract %v3uint %inputs 3
+         %24 = OpCompositeExtract %uint %23 0
+         %25 = OpIAdd %uint %22 %24
+         %26 = OpCompositeExtract %v3uint %inputs 4
+         %27 = OpCompositeExtract %uint %26 0
+         %28 = OpIAdd %uint %25 %27
                OpReturn
                OpFunctionEnd
-       %main = OpFunction %void None %25
-         %27 = OpLabel
-         %29 = OpLoad %v3uint %local_invocation_id_1
-         %30 = OpLoad %uint %local_invocation_index_1
-         %31 = OpLoad %v3uint %global_invocation_id_1
-         %32 = OpLoad %v3uint %workgroup_id_1
-         %33 = OpCompositeConstruct %ComputeInputs %29 %30 %31 %32
-         %28 = OpFunctionCall %void %main_inner %33
+       %main = OpFunction %void None %29
+         %31 = OpLabel
+         %33 = OpLoad %v3uint %local_invocation_id_1
+         %34 = OpLoad %uint %local_invocation_index_1
+         %35 = OpLoad %v3uint %global_invocation_id_1
+         %36 = OpLoad %v3uint %workgroup_id_1
+         %37 = OpLoad %v3uint %num_workgroups_1
+         %38 = OpCompositeConstruct %ComputeInputs %33 %34 %35 %36 %37
+         %32 = OpFunctionCall %void %main_inner %38
                OpReturn
                OpFunctionEnd
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl b/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl
index 99b0c4f..23d354a 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl
@@ -7,9 +7,11 @@
   global_invocation_id : vec3<u32>;
   [[builtin(workgroup_id)]]
   workgroup_id : vec3<u32>;
+  [[builtin(num_workgroups)]]
+  num_workgroups : vec3<u32>;
 };
 
 [[stage(compute), workgroup_size(1)]]
 fn main(inputs : ComputeInputs) {
-  let foo : u32 = (((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x);
+  let foo : u32 = ((((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x) + inputs.num_workgroups.x);
 }