Update ComputePipelineDescriptor to use PipelineStageDescriptor

The contents of PipelineStageDescriptor were inlined inside of
ComputePipelineDescriptor. This changes updates
ComputePipelineDescriptor to contain PipelineStageDescriptor to match
WebGPU.

Bug: chromium:877147
Change-Id: Ic030b7bd7a237945cbbaf4c567cc361940e1ad00
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6400
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/dawn.json b/dawn.json
index a2516ba..82b243a 100644
--- a/dawn.json
+++ b/dawn.json
@@ -375,8 +375,7 @@
         "extensible": true,
         "members": [
             {"name": "layout", "type": "pipeline layout"},
-            {"name": "module", "type": "shader module"},
-            {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"}
+            {"name": "compute stage", "type": "pipeline stage descriptor", "annotation": "const*"}
         ]
     },
     "device": {
diff --git a/examples/ComputeBoids.cpp b/examples/ComputeBoids.cpp
index e9f44ac..ae4234c 100644
--- a/examples/ComputeBoids.cpp
+++ b/examples/ComputeBoids.cpp
@@ -241,9 +241,13 @@
     dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
 
     dawn::ComputePipelineDescriptor csDesc;
-    csDesc.module = module;
-    csDesc.entryPoint = "main";
     csDesc.layout = pl;
+
+    dawn::PipelineStageDescriptor computeStage;
+    computeStage.module = module;
+    computeStage.entryPoint = "main";
+    csDesc.computeStage = &computeStage;
+
     updatePipeline = device.CreateComputePipeline(&csDesc);
 
     for (uint32_t i = 0; i < 2; ++i) {
diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp
index c95115b..a2fb60f 100644
--- a/src/dawn_native/ComputePipeline.cpp
+++ b/src/dawn_native/ComputePipeline.cpp
@@ -24,21 +24,9 @@
             return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");
         }
 
-        DAWN_TRY(device->ValidateObject(descriptor->module));
         DAWN_TRY(device->ValidateObject(descriptor->layout));
-
-        if (descriptor->entryPoint != std::string("main")) {
-            return DAWN_VALIDATION_ERROR("Currently the entry point has to be main()");
-        }
-
-        if (descriptor->module->GetExecutionModel() != dawn::ShaderStage::Compute) {
-            return DAWN_VALIDATION_ERROR("Setting module with wrong execution model");
-        }
-
-        if (!descriptor->module->IsCompatibleWithPipelineLayout(descriptor->layout)) {
-            return DAWN_VALIDATION_ERROR("Stage not compatible with layout");
-        }
-
+        DAWN_TRY(ValidatePipelineStageDescriptor(device, descriptor->computeStage,
+                                                 descriptor->layout, dawn::ShaderStage::Compute));
         return {};
     }
 
@@ -47,7 +35,7 @@
     ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
                                              const ComputePipelineDescriptor* descriptor)
         : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
-        ExtractModuleData(dawn::ShaderStage::Compute, descriptor->module);
+        ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module);
     }
 
     ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp
index 257bf76..e839b1b 100644
--- a/src/dawn_native/Pipeline.cpp
+++ b/src/dawn_native/Pipeline.cpp
@@ -20,6 +20,24 @@
 
 namespace dawn_native {
 
+    MaybeError ValidatePipelineStageDescriptor(DeviceBase* device,
+                                               const PipelineStageDescriptor* descriptor,
+                                               const PipelineLayoutBase* layout,
+                                               dawn::ShaderStage stage) {
+        DAWN_TRY(device->ValidateObject(descriptor->module));
+
+        if (descriptor->entryPoint != std::string("main")) {
+            return DAWN_VALIDATION_ERROR("Entry point must be \"main\"");
+        }
+        if (descriptor->module->GetExecutionModel() != stage) {
+            return DAWN_VALIDATION_ERROR("Setting module with wrong stages");
+        }
+        if (!descriptor->module->IsCompatibleWithPipelineLayout(layout)) {
+            return DAWN_VALIDATION_ERROR("Stage not compatible with layout");
+        }
+        return {};
+    }
+
     // PipelineBase
 
     PipelineBase::PipelineBase(DeviceBase* device,
diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h
index 1f141f9..c917125 100644
--- a/src/dawn_native/Pipeline.h
+++ b/src/dawn_native/Pipeline.h
@@ -34,6 +34,11 @@
         Float,
     };
 
+    MaybeError ValidatePipelineStageDescriptor(DeviceBase* device,
+                                               const PipelineStageDescriptor* descriptor,
+                                               const PipelineLayoutBase* layout,
+                                               dawn::ShaderStage stage);
+
     class PipelineBase : public ObjectBase {
       public:
         struct PushConstantInfo {
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 9c2cf4e..eb8a7b8 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -97,24 +97,6 @@
             return {};
         }
 
-        MaybeError ValidatePipelineStageDescriptor(DeviceBase* device,
-                                                   const PipelineStageDescriptor* descriptor,
-                                                   const PipelineLayoutBase* layout,
-                                                   dawn::ShaderStage stage) {
-            DAWN_TRY(device->ValidateObject(descriptor->module));
-
-            if (descriptor->entryPoint != std::string("main")) {
-                return DAWN_VALIDATION_ERROR("Entry point must be \"main\"");
-            }
-            if (descriptor->module->GetExecutionModel() != stage) {
-                return DAWN_VALIDATION_ERROR("Setting module with wrong stages");
-            }
-            if (!descriptor->module->IsCompatibleWithPipelineLayout(layout)) {
-                return DAWN_VALIDATION_ERROR("Stage not compatible with layout");
-            }
-            return {};
-        }
-
         MaybeError ValidateColorStateDescriptor(const ColorStateDescriptor* descriptor) {
             if (descriptor->nextInChain != nullptr) {
                 return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 67f4cbc..d70846e 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -32,7 +32,7 @@
         // SPRIV-cross does matrix multiplication expecting row major matrices
         compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
-        const ShaderModule* module = ToBackend(descriptor->module);
+        const ShaderModule* module = ToBackend(descriptor->computeStage->module);
         const std::string& hlslSource = module->GetHLSLSource(ToBackend(GetLayout()));
 
         ComPtr<ID3DBlob> compiledShader;
@@ -40,8 +40,8 @@
 
         const PlatformFunctions* functions = device->GetFunctions();
         if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
-                                         nullptr, descriptor->entryPoint, "cs_5_1", compileFlags, 0,
-                                         &compiledShader, &errors))) {
+                                         nullptr, descriptor->computeStage->entryPoint, "cs_5_1",
+                                         compileFlags, 0, &compiledShader, &errors))) {
             printf("%s\n", reinterpret_cast<char*>(errors->GetBufferPointer()));
             ASSERT(false);
         }
diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm
index 5c6eaa6..fc76ee1 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.mm
+++ b/src/dawn_native/metal/ComputePipelineMTL.mm
@@ -23,15 +23,14 @@
         : ComputePipelineBase(device, descriptor) {
         auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
 
-        const auto& module = ToBackend(descriptor->module);
-        const char* entryPoint = descriptor->entryPoint;
-
-        auto compilationData =
-            module->GetFunction(entryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout()));
+        const ShaderModule* computeModule = ToBackend(descriptor->computeStage->module);
+        const char* computeEntryPoint = descriptor->computeStage->entryPoint;
+        ShaderModule::MetalFunctionData computeData = computeModule->GetFunction(
+            computeEntryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout()));
 
         NSError* error = nil;
         mMtlComputePipelineState =
-            [mtlDevice newComputePipelineStateWithFunction:compilationData.function error:&error];
+            [mtlDevice newComputePipelineStateWithFunction:computeData.function error:&error];
         if (error != nil) {
             NSLog(@" error => %@", error);
             GetDevice()->HandleError("Error creating pipeline state");
@@ -39,7 +38,7 @@
         }
 
         // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
-        mLocalWorkgroupSize = compilationData.localWorkgroupSize;
+        mLocalWorkgroupSize = computeData.localWorkgroupSize;
     }
 
     ComputePipeline::~ComputePipeline() {
diff --git a/src/dawn_native/opengl/ComputePipelineGL.cpp b/src/dawn_native/opengl/ComputePipelineGL.cpp
index 815e4d7..2cbad4e 100644
--- a/src/dawn_native/opengl/ComputePipelineGL.cpp
+++ b/src/dawn_native/opengl/ComputePipelineGL.cpp
@@ -21,7 +21,7 @@
     ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
         : ComputePipelineBase(device, descriptor) {
         PerStage<const ShaderModule*> modules(nullptr);
-        modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->module);
+        modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->computeStage->module);
 
         PipelineGL::Initialize(ToBackend(descriptor->layout), modules);
     }
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp
index 06948b3..8e7c7aa 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp
@@ -35,8 +35,8 @@
         createInfo.stage.pNext = nullptr;
         createInfo.stage.flags = 0;
         createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
-        createInfo.stage.module = ToBackend(descriptor->module)->GetHandle();
-        createInfo.stage.pName = descriptor->entryPoint;
+        createInfo.stage.module = ToBackend(descriptor->computeStage->module)->GetHandle();
+        createInfo.stage.pName = descriptor->computeStage->entryPoint;
         createInfo.stage.pSpecializationInfo = nullptr;
 
         if (device->fn.CreateComputePipelines(device->GetVkDevice(), VK_NULL_HANDLE, 1, &createInfo,
diff --git a/src/tests/end2end/BindGroupTests.cpp b/src/tests/end2end/BindGroupTests.cpp
index 9bf437e..3a46286 100644
--- a/src/tests/end2end/BindGroupTests.cpp
+++ b/src/tests/end2end/BindGroupTests.cpp
@@ -68,9 +68,13 @@
     dawn::ShaderModule module =
         utils::CreateShaderModule(device, dawn::ShaderStage::Compute, shader);
     dawn::ComputePipelineDescriptor cpDesc;
-    cpDesc.module = module;
-    cpDesc.entryPoint = "main";
     cpDesc.layout = pl;
+
+    dawn::PipelineStageDescriptor computeStage;
+    computeStage.module = module;
+    computeStage.entryPoint = "main";
+    cpDesc.computeStage = &computeStage;
+
     dawn::ComputePipeline cp = device.CreateComputePipeline(&cpDesc);
 
     dawn::BufferDescriptor bufferDesc;
diff --git a/src/tests/end2end/ComputeCopyStorageBufferTests.cpp b/src/tests/end2end/ComputeCopyStorageBufferTests.cpp
index 337f2f1..6cca206 100644
--- a/src/tests/end2end/ComputeCopyStorageBufferTests.cpp
+++ b/src/tests/end2end/ComputeCopyStorageBufferTests.cpp
@@ -39,9 +39,13 @@
     auto pl = utils::MakeBasicPipelineLayout(device, &bgl);
 
     dawn::ComputePipelineDescriptor csDesc;
-    csDesc.module = module;
-    csDesc.entryPoint = "main";
     csDesc.layout = pl;
+
+    dawn::PipelineStageDescriptor computeStage;
+    computeStage.module = module;
+    computeStage.entryPoint = "main";
+    csDesc.computeStage = &computeStage;
+
     dawn::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
 
     // Set up src storage buffer
diff --git a/src/tests/end2end/PushConstantTests.cpp b/src/tests/end2end/PushConstantTests.cpp
index c85aca6..e385a17 100644
--- a/src/tests/end2end/PushConstantTests.cpp
+++ b/src/tests/end2end/PushConstantTests.cpp
@@ -149,9 +149,13 @@
             );
 
             dawn::ComputePipelineDescriptor descriptor;
-            descriptor.module = module;
-            descriptor.entryPoint = "main";
             descriptor.layout = pl;
+
+            dawn::PipelineStageDescriptor computeStage;
+            computeStage.module = module;
+            computeStage.entryPoint = "main";
+            descriptor.computeStage = &computeStage;
+
             return device.CreateComputePipeline(&descriptor);
         }