D3D12: Support [[num_workgroups]] for Dispatch

This patch implements [[num_workgroups]] on the API side for
Dispatch() calls by setting num_workgroups.xyz as root constants.

This patch also adds a temporary validation that on D3D12 backend
using a compute pipeline with [[num_workgroups]] in a
DispatchIndirect call is not supported.

BUG=dawn:839
TEST=dawn_end2end_tests

Change-Id: Iaee2ffd162e9420e4e80944fbb222f10a4600c6a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/66580
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index bd3989b..cfb45a1 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -677,6 +677,8 @@
                     metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
                     metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
                     metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
+
+                    metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
                 }
 
                 if (metadata->stage == SingleShaderStage::Vertex) {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 8f8081c..070d001 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -204,6 +204,8 @@
 
         // Store overridableConstants from tint program
         std::unordered_map<std::string, OverridableConstant> overridableConstants;
+
+        bool usesNumWorkgroups = false;
     };
 
     class ShaderModuleBase : public ApiObjectBase, public CachedObject {
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 4bdc3b0..dc32d10 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -254,6 +254,18 @@
             return {};
         }
 
+        void RecordNumWorkgroupsForDispatch(ID3D12GraphicsCommandList* commandList,
+                                            ComputePipeline* pipeline,
+                                            DispatchCmd* dispatch) {
+            if (!pipeline->UsesNumWorkgroups()) {
+                return;
+            }
+
+            PipelineLayout* layout = ToBackend(pipeline->GetLayout());
+            commandList->SetComputeRoot32BitConstants(layout->GetNumWorkgroupsParameterIndex(), 3,
+                                                      dispatch, 0);
+        }
+
         // Records the necessary barriers for a synchronization scope using the resource usage
         // data pre-computed in the frontend. Also performs lazy initialization if required.
         // Returns whether any UAV are used in the synchronization scope.
@@ -1030,6 +1042,7 @@
         ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
 
         Command type;
+        ComputePipeline* lastPipeline = nullptr;
         while (mCommands.NextCommandId(&type)) {
             switch (type) {
                 case Command::Dispatch: {
@@ -1045,6 +1058,7 @@
                                                    resourceUsages.dispatchUsages[currentDispatch]);
                     DAWN_TRY(bindingTracker->Apply(commandContext));
 
+                    RecordNumWorkgroupsForDispatch(commandList, lastPipeline, dispatch);
                     commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z);
                     currentDispatch++;
                     break;
@@ -1052,6 +1066,14 @@
 
                 case Command::DispatchIndirect: {
                     DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
+
+                    // TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
+                    if (lastPipeline->UsesNumWorkgroups()) {
+                        return DAWN_VALIDATION_ERROR(
+                            "Using a compute pipeline with [[num_workgroups]] in a "
+                            "DispatchIndirect call is not implemented");
+                    }
+
                     Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
 
                     TransitionAndClearForSyncScope(commandContext,
@@ -1078,6 +1100,7 @@
                     commandList->SetPipelineState(pipeline->GetPipelineState());
 
                     bindingTracker->OnSetPipeline(pipeline);
+                    lastPipeline = pipeline;
                     break;
                 }
 
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 29ae08a..54ddc7e 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -84,4 +84,8 @@
         CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
     }
 
+    bool ComputePipeline::UsesNumWorkgroups() const {
+        return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
+    }
+
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.h b/src/dawn_native/d3d12/ComputePipelineD3D12.h
index 7c7a02d..b652026 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.h
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.h
@@ -40,6 +40,8 @@
         // Dawn API
         void SetLabelImpl() override;
 
+        bool UsesNumWorkgroups() const;
+
       private:
         ~ComputePipeline() override;
         using ComputePipelineBase::ComputePipelineBase;
diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
index 372b61b..1a512fa 100644
--- a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
+++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
@@ -174,6 +174,21 @@
         // would need to be updated often
         rootParameters.emplace_back(indexOffsetConstants);
 
+        // Always allocate 3 constants for num_workgroups_x, num_workgroups_y and num_workgroups_z
+        // for Dispatch calls
+        // NOTE: We should consider delaying root signature creation until we know how many values
+        // we need
+        D3D12_ROOT_PARAMETER numWorkgroupsConstants{};
+        numWorkgroupsConstants.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
+        numWorkgroupsConstants.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
+        numWorkgroupsConstants.Constants.Num32BitValues = 3;
+        numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
+        numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
+        mNumWorkgroupsParamterIndex = rootParameters.size();
+        // NOTE: We should consider moving this entry to earlier in the root signature since
+        // dispatch sizes would need to be updated often
+        rootParameters.emplace_back(numWorkgroupsConstants);
+
         D3D12_ROOT_SIGNATURE_DESC rootSignatureDescriptor;
         rootSignatureDescriptor.NumParameters = rootParameters.size();
         rootSignatureDescriptor.pParameters = rootParameters.data();
@@ -230,7 +245,7 @@
     }
 
     uint32_t PipelineLayout::GetFirstIndexOffsetRegisterSpace() const {
-        return kReservedRegisterSpace;
+        return kFirstIndexOffsetRegisterSpace;
     }
 
     uint32_t PipelineLayout::GetFirstIndexOffsetShaderRegister() const {
@@ -240,4 +255,16 @@
     uint32_t PipelineLayout::GetFirstIndexOffsetParameterIndex() const {
         return mFirstIndexOffsetParameterIndex;
     }
+
+    uint32_t PipelineLayout::GetNumWorkgroupsRegisterSpace() const {
+        return kNumWorkgroupsRegisterSpace;
+    }
+
+    uint32_t PipelineLayout::GetNumWorkgroupsShaderRegister() const {
+        return kNumWorkgroupsBaseRegister;
+    }
+
+    uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
+        return mNumWorkgroupsParamterIndex;
+    }
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.h b/src/dawn_native/d3d12/PipelineLayoutD3D12.h
index b1efc0d..cf52f06 100644
--- a/src/dawn_native/d3d12/PipelineLayoutD3D12.h
+++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.h
@@ -26,6 +26,9 @@
     // We reserve a register space that a user cannot use.
     static constexpr uint32_t kReservedRegisterSpace = kMaxBindGroups + 1;
     static constexpr uint32_t kFirstOffsetInfoBaseRegister = 0;
+    static constexpr uint32_t kFirstIndexOffsetRegisterSpace = kReservedRegisterSpace;
+    static constexpr uint32_t kNumWorkgroupsRegisterSpace = kReservedRegisterSpace + 1;
+    static constexpr uint32_t kNumWorkgroupsBaseRegister = 0;
 
     class Device;
 
@@ -46,6 +49,10 @@
         uint32_t GetFirstIndexOffsetShaderRegister() const;
         uint32_t GetFirstIndexOffsetParameterIndex() const;
 
+        uint32_t GetNumWorkgroupsRegisterSpace() const;
+        uint32_t GetNumWorkgroupsShaderRegister() const;
+        uint32_t GetNumWorkgroupsParameterIndex() const;
+
         ID3D12RootSignature* GetRootSignature() const;
 
       private:
@@ -59,6 +66,7 @@
                     kMaxBindGroups>
             mDynamicRootParameterIndices;
         uint32_t mFirstIndexOffsetParameterIndex;
+        uint32_t mNumWorkgroupsParamterIndex;
         ComPtr<ID3D12RootSignature> mRootSignature;
     };
 
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 2dafc96..89b5825 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -106,6 +106,9 @@
             tint::transform::BindingRemapper::BindingPoints bindingPoints;
             tint::transform::BindingRemapper::AccessControls accessControls;
             bool isRobustnessEnabled;
+            bool usesNumWorkgroups;
+            uint32_t numWorkgroupsRegisterSpace;
+            uint32_t numWorkgroupsShaderRegister;
 
             // FXC/DXC common inputs
             bool disableWorkgroupInit;
@@ -125,7 +128,7 @@
                 uint32_t compileFlags,
                 const Device* device,
                 const tint::Program* program,
-                const BindingInfoArray& moduleBindingInfo) {
+                const EntryPointMetadata& entryPoint) {
                 Compiler compiler;
                 uint64_t dxcVersion = 0;
                 if (device->IsToggleEnabled(Toggle::UseDXC)) {
@@ -145,6 +148,7 @@
                 // Tint AST to make the "bindings" decoration match the offset chosen by
                 // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
                 // assigned to each interface variable.
+                const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
                 for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
                     const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
                     const auto& groupBindingInfo = moduleBindingInfo[group];
@@ -189,6 +193,9 @@
                 request.isRobustnessEnabled = device->IsRobustnessEnabled();
                 request.disableWorkgroupInit =
                     device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
+                request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
+                request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
+                request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
                 request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
                 request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
                 request.deviceInfo = &device->GetDeviceInfo();
@@ -234,6 +241,10 @@
                 stream << " accessControls=";
                 Serialize(stream, accessControls);
 
+                stream << " useNumWorkgroups=" << usesNumWorkgroups;
+                stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
+                stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
+
                 stream << " shaderModel=" << deviceInfo->shaderModel;
                 stream << " disableWorkgroupInit=" << disableWorkgroupInit;
                 stream << " isRobustnessEnabled=" << isRobustnessEnabled;
@@ -423,6 +434,10 @@
 
             tint::writer::hlsl::Options options;
             options.disable_workgroup_init = request.disableWorkgroupInit;
+            if (request.usesNumWorkgroups) {
+                options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
+                options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
+            }
             auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
             if (!result.success) {
                 errorStream << "Generator: " << result.error << std::endl;
@@ -547,9 +562,9 @@
         }
 
         ShaderCompilationRequest request;
-        DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(
-                                     entryPointName, stage, layout, compileFlags, device, program,
-                                     GetEntryPoint(entryPointName).bindings));
+        DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout,
+                                                                  compileFlags, device, program,
+                                                                  GetEntryPoint(entryPointName)));
 
         PersistentCacheKey shaderCacheKey;
         DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
diff --git a/src/tests/end2end/ComputeDispatchTests.cpp b/src/tests/end2end/ComputeDispatchTests.cpp
index 7ec4076..1a8b163 100644
--- a/src/tests/end2end/ComputeDispatchTests.cpp
+++ b/src/tests/end2end/ComputeDispatchTests.cpp
@@ -26,9 +26,30 @@
         DawnTest::SetUp();
 
         // Write workgroup number into the output buffer if we saw the biggest dispatch
-        // This is a workaround since D3D12 doesn't have gl_NumWorkGroups
         // To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
-        wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"(
+            [[block]] struct OutputBuf {
+                workGroups : vec3<u32>;
+            };
+
+            [[group(0), binding(0)]] var<storage, read_write> output : OutputBuf;
+
+            [[stage(compute), workgroup_size(1, 1, 1)]]
+            fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>,
+                    [[builtin(num_workgroups)]] dispatch : vec3<u32>) {
+                if (dispatch.x == 0u || dispatch.y == 0u || dispatch.z == 0u) {
+                    output.workGroups = vec3<u32>(0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu);
+                    return;
+                }
+
+                if (all(GlobalInvocationID == dispatch - vec3<u32>(1u, 1u, 1u))) {
+                    output.workGroups = dispatch;
+                }
+            })");
+
+        // TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports
+        // [[num_workgroups]] for indirect dispatch.
+        wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"(
             [[block]] struct InputBuf {
                 expectedDispatch : vec3<u32>;
             };
@@ -54,9 +75,12 @@
             })");
 
         wgpu::ComputePipelineDescriptor csDesc;
-        csDesc.compute.module = module;
+        csDesc.compute.module = moduleForDispatch;
         csDesc.compute.entryPoint = "main";
-        pipeline = device.CreateComputePipeline(&csDesc);
+        pipelineForDispatch = device.CreateComputePipeline(&csDesc);
+
+        csDesc.compute.module = moduleForDispatchIndirect;
+        pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
     }
 
     void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
@@ -66,23 +90,18 @@
             wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
             kSentinelData);
 
-        std::initializer_list<uint32_t> expectedBufferData{x, y, z};
-        wgpu::Buffer expectedBuffer = utils::CreateBufferFromData<uint32_t>(
-            device, wgpu::BufferUsage::Uniform, expectedBufferData);
-
         // Set up bind group and issue dispatch
         wgpu::BindGroup bindGroup =
-            utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
+            utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0),
                                  {
-                                     {0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
-                                     {1, dst, 0, 3 * sizeof(uint32_t)},
+                                     {0, dst, 0, 3 * sizeof(uint32_t)},
                                  });
 
         wgpu::CommandBuffer commands;
         {
             wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
             wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
-            pass.SetPipeline(pipeline);
+            pass.SetPipeline(pipelineForDispatch);
             pass.SetBindGroup(0, bindGroup);
             pass.Dispatch(x, y, z);
             pass.EndPass();
@@ -93,7 +112,7 @@
         queue.Submit(1, &commands);
 
         std::vector<uint32_t> expected =
-            x == 0 || y == 0 || z == 0 ? kSentinelData : expectedBufferData;
+            x == 0 || y == 0 || z == 0 ? kSentinelData : std::initializer_list<uint32_t>{x, y, z};
 
         // Verify the dispatch got called if all group counts are not zero
         EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
@@ -118,7 +137,7 @@
 
         // Set up bind group and issue dispatch
         wgpu::BindGroup bindGroup =
-            utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
+            utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0),
                                  {
                                      {0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
                                      {1, dst, 0, 3 * sizeof(uint32_t)},
@@ -128,7 +147,7 @@
         {
             wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
             wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
-            pass.SetPipeline(pipeline);
+            pass.SetPipeline(pipelineForDispatchIndirect);
             pass.SetBindGroup(0, bindGroup);
             pass.DispatchIndirect(indirectBuffer, indirectOffset);
             pass.EndPass();
@@ -153,7 +172,8 @@
     }
 
   private:
-    wgpu::ComputePipeline pipeline;
+    wgpu::ComputePipeline pipelineForDispatch;
+    wgpu::ComputePipeline pipelineForDispatchIndirect;
 };
 
 // Test basic direct