Remove the descriptor parameter from ComputePipeline::Initialize()

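This patch removes the parameter "descriptor" from
ComputePipeline::Initialize(). The backends now read the compute
shader module and entry point from the pipeline itself
(GetStage(SingleShaderStage::Compute)) and the layout from
GetLayout(), so the descriptor no longer has to stay alive while the
pipeline is initialized asynchronously. As a result,
FlatComputePipelineDescriptor is no longer needed and is removed; the
deprecated computeStage fallback it handled now lives in
ValidateLayoutAndGetComputePipelineDescriptorWithDefaults().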

For render pipelines, descriptor->vertex is still used for vertex
pulling (it is passed into vertexModule->CreateFunction()), so the
related vertex-pulling code will be refactored first, before the
parameter "descriptor" is removed from RenderPipeline::Initialize().

BUG=dawn:529

Change-Id: Ib172ac0c76fa24070e78c0e57c3262acad9399b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/64000
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp
index 1b4753e..0575f90 100644
--- a/src/dawn_native/ComputePipeline.cpp
+++ b/src/dawn_native/ComputePipeline.cpp
@@ -19,31 +19,6 @@
 
 namespace dawn_native {
 
-    FlatComputePipelineDescriptor::FlatComputePipelineDescriptor(
-        const ComputePipelineDescriptor* descriptor)
-        : mLabel(descriptor->label != nullptr ? descriptor->label : ""),
-          mLayout(descriptor->layout) {
-        label = mLabel.c_str();
-        layout = mLayout.Get();
-
-        // TODO(dawn:800): Remove after deprecation period.
-        if (descriptor->compute.module == nullptr && descriptor->computeStage.module != nullptr) {
-            mComputeModule = descriptor->computeStage.module;
-            mEntryPoint = descriptor->computeStage.entryPoint;
-        } else {
-            mComputeModule = descriptor->compute.module;
-            mEntryPoint = descriptor->compute.entryPoint;
-        }
-
-        compute.entryPoint = mEntryPoint.c_str();
-        compute.module = mComputeModule.Get();
-    }
-
-    void FlatComputePipelineDescriptor::SetLayout(Ref<PipelineLayoutBase> appliedLayout) {
-        mLayout = std::move(appliedLayout);
-        layout = mLayout.Get();
-    }
-
     MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
                                                  const ComputePipelineDescriptor* descriptor) {
         if (descriptor->nextInChain != nullptr) {
@@ -92,7 +67,7 @@
         }
     }
 
-    MaybeError ComputePipelineBase::Initialize(const ComputePipelineDescriptor* descriptor) {
+    MaybeError ComputePipelineBase::Initialize() {
         return {};
     }
 
diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h
index ab92c45..95e58a6 100644
--- a/src/dawn_native/ComputePipeline.h
+++ b/src/dawn_native/ComputePipeline.h
@@ -23,22 +23,6 @@
     class DeviceBase;
     struct EntryPointMetadata;
 
-    // We use FlatComputePipelineDescriptor to keep all the members of ComputePipelineDescriptor
-    // (especially the members in pointers) valid in CreateComputePipelineAsyncTask when the
-    // creation of the compute pipeline is executed asynchronously.
-    struct FlatComputePipelineDescriptor : public ComputePipelineDescriptor, public NonMovable {
-      public:
-        explicit FlatComputePipelineDescriptor(const ComputePipelineDescriptor* descriptor);
-
-        void SetLayout(Ref<PipelineLayoutBase> appliedLayout);
-
-      private:
-        std::string mLabel;
-        Ref<PipelineLayoutBase> mLayout;
-        std::string mEntryPoint;
-        Ref<ShaderModuleBase> mComputeModule;
-    };
-
     MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
                                                  const ComputePipelineDescriptor* descriptor);
 
@@ -60,7 +44,7 @@
         // CreateComputePipelineAsyncTask is declared as a friend of ComputePipelineBase as it
         // needs to call the private member function ComputePipelineBase::Initialize().
         friend class CreateComputePipelineAsyncTask;
-        virtual MaybeError Initialize(const ComputePipelineDescriptor* descriptor);
+        virtual MaybeError Initialize();
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/CreatePipelineAsyncTask.cpp b/src/dawn_native/CreatePipelineAsyncTask.cpp
index 40410b9..6a35b94 100644
--- a/src/dawn_native/CreatePipelineAsyncTask.cpp
+++ b/src/dawn_native/CreatePipelineAsyncTask.cpp
@@ -103,23 +103,18 @@
 
     CreateComputePipelineAsyncTask::CreateComputePipelineAsyncTask(
         Ref<ComputePipelineBase> nonInitializedComputePipeline,
-        std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
         size_t blueprintHash,
         WGPUCreateComputePipelineAsyncCallback callback,
         void* userdata)
         : mComputePipeline(nonInitializedComputePipeline),
           mBlueprintHash(blueprintHash),
           mCallback(callback),
-          mUserdata(userdata),
-          mAppliedDescriptor(std::move(descriptor)) {
+          mUserdata(userdata) {
         ASSERT(mComputePipeline != nullptr);
-
-        // TODO(jiawei.shao@intel.com): save nextInChain when it is supported in Dawn.
-        ASSERT(mAppliedDescriptor->nextInChain == nullptr);
     }
 
     void CreateComputePipelineAsyncTask::Run() {
-        MaybeError maybeError = mComputePipeline->Initialize(mAppliedDescriptor.get());
+        MaybeError maybeError = mComputePipeline->Initialize();
         std::string errorMessage;
         if (maybeError.IsError()) {
             mComputePipeline = nullptr;
diff --git a/src/dawn_native/CreatePipelineAsyncTask.h b/src/dawn_native/CreatePipelineAsyncTask.h
index f089637..34456aa 100644
--- a/src/dawn_native/CreatePipelineAsyncTask.h
+++ b/src/dawn_native/CreatePipelineAsyncTask.h
@@ -72,7 +72,6 @@
     class CreateComputePipelineAsyncTask {
       public:
         CreateComputePipelineAsyncTask(Ref<ComputePipelineBase> nonInitializedComputePipeline,
-                                       std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
                                        size_t blueprintHash,
                                        WGPUCreateComputePipelineAsyncCallback callback,
                                        void* userdata);
@@ -87,8 +86,6 @@
         size_t mBlueprintHash;
         WGPUCreateComputePipelineAsyncCallback mCallback;
         void* mUserdata;
-
-        std::unique_ptr<FlatComputePipelineDescriptor> mAppliedDescriptor;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 6e254b6..34b4db0 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -123,18 +123,29 @@
             void* mUserdata;
         };
 
-        MaybeError ValidateLayoutAndSetDefaultLayout(
+        ResultOrError<Ref<PipelineLayoutBase>>
+        ValidateLayoutAndGetComputePipelineDescriptorWithDefaults(
             DeviceBase* device,
-            FlatComputePipelineDescriptor* appliedDescriptor) {
-            if (appliedDescriptor->layout == nullptr) {
-                Ref<PipelineLayoutBase> layoutRef;
+            const ComputePipelineDescriptor& descriptor,
+            ComputePipelineDescriptor* outDescriptor) {
+            Ref<PipelineLayoutBase> layoutRef;
+            *outDescriptor = descriptor;
+            // TODO(dawn:800): Remove after deprecation period.
+            if (outDescriptor->compute.module == nullptr &&
+                outDescriptor->computeStage.module != nullptr) {
+                outDescriptor->compute.module = outDescriptor->computeStage.module;
+                outDescriptor->compute.entryPoint = outDescriptor->computeStage.entryPoint;
+            }
+
+            if (outDescriptor->layout == nullptr) {
                 DAWN_TRY_ASSIGN(layoutRef, PipelineLayoutBase::CreateDefault(
                                                device, {{SingleShaderStage::Compute,
-                                                         appliedDescriptor->compute.module,
-                                                         appliedDescriptor->compute.entryPoint}}));
-                appliedDescriptor->SetLayout(std::move(layoutRef));
+                                                         outDescriptor->compute.module,
+                                                         outDescriptor->compute.entryPoint}}));
+                outDescriptor->layout = layoutRef.Get();
             }
-            return {};
+
+            return layoutRef;
         }
 
         ResultOrError<Ref<PipelineLayoutBase>>
@@ -1129,8 +1140,12 @@
             DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
         }
 
-        FlatComputePipelineDescriptor appliedDescriptor(descriptor);
-        DAWN_TRY(ValidateLayoutAndSetDefaultLayout(this, &appliedDescriptor));
+        // Ref will keep the pipeline layout alive until the end of the function where
+        // the pipeline will take another reference.
+        Ref<PipelineLayoutBase> layoutRef;
+        ComputePipelineDescriptor appliedDescriptor;
+        DAWN_TRY_ASSIGN(layoutRef, ValidateLayoutAndGetComputePipelineDescriptorWithDefaults(
+                                       this, *descriptor, &appliedDescriptor));
 
         auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(&appliedDescriptor);
         if (pipelineAndBlueprintFromCache.first.Get() != nullptr) {
@@ -1152,12 +1167,13 @@
             DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
         }
 
-        std::unique_ptr<FlatComputePipelineDescriptor> appliedDescriptor =
-            std::make_unique<FlatComputePipelineDescriptor>(descriptor);
-        DAWN_TRY(ValidateLayoutAndSetDefaultLayout(this, appliedDescriptor.get()));
+        Ref<PipelineLayoutBase> layoutRef;
+        ComputePipelineDescriptor appliedDescriptor;
+        DAWN_TRY_ASSIGN(layoutRef, ValidateLayoutAndGetComputePipelineDescriptorWithDefaults(
+                                       this, *descriptor, &appliedDescriptor));
 
         // Call the callback directly when we can get a cached compute pipeline object.
-        auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(appliedDescriptor.get());
+        auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(&appliedDescriptor);
         if (pipelineAndBlueprintFromCache.first.Get() != nullptr) {
             Ref<ComputePipelineBase> result = std::move(pipelineAndBlueprintFromCache.first);
             callback(WGPUCreatePipelineAsyncStatus_Success,
@@ -1167,24 +1183,22 @@
             // where the pipeline object may be created asynchronously and the result will be saved
             // to mCreatePipelineAsyncTracker.
             const size_t blueprintHash = pipelineAndBlueprintFromCache.second;
-            CreateComputePipelineAsyncImpl(std::move(appliedDescriptor), blueprintHash, callback,
-                                           userdata);
+            CreateComputePipelineAsyncImpl(&appliedDescriptor, blueprintHash, callback, userdata);
         }
 
         return {};
     }
 
-    // This function is overwritten with the async version on the backends
-    // that supports creating compute pipeline asynchronously
-    void DeviceBase::CreateComputePipelineAsyncImpl(
-        std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-        size_t blueprintHash,
-        WGPUCreateComputePipelineAsyncCallback callback,
-        void* userdata) {
+    // This function is overridden with the async version on the backends that support creating
+    // compute pipelines asynchronously.
+    void DeviceBase::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                                    size_t blueprintHash,
+                                                    WGPUCreateComputePipelineAsyncCallback callback,
+                                                    void* userdata) {
         Ref<ComputePipelineBase> result;
         std::string errorMessage;
 
-        auto resultOrError = CreateComputePipelineImpl(descriptor.get());
+        auto resultOrError = CreateComputePipelineImpl(descriptor);
         if (resultOrError.IsError()) {
             std::unique_ptr<ErrorData> error = resultOrError.AcquireError();
             errorMessage = error->GetMessage();
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 2e54a4d..80fcd70 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -46,7 +46,6 @@
     class PersistentCache;
     class StagingBufferBase;
     struct CallbackTask;
-    struct FlatComputePipelineDescriptor;
     struct InternalPipelineStore;
     struct ShaderModuleParseResult;
 
@@ -360,11 +359,10 @@
             size_t blueprintHash);
         Ref<RenderPipelineBase> AddOrGetCachedRenderPipeline(Ref<RenderPipelineBase> renderPipeline,
                                                              size_t blueprintHash);
-        virtual void CreateComputePipelineAsyncImpl(
-            std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-            size_t blueprintHash,
-            WGPUCreateComputePipelineAsyncCallback callback,
-            void* userdata);
+        virtual void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                                    size_t blueprintHash,
+                                                    WGPUCreateComputePipelineAsyncCallback callback,
+                                                    void* userdata);
         virtual void CreateRenderPipelineAsyncImpl(const RenderPipelineDescriptor* descriptor,
                                                    size_t blueprintHash,
                                                    WGPUCreateRenderPipelineAsyncCallback callback,
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 6649a47..0925b92 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -28,11 +28,11 @@
         Device* device,
         const ComputePipelineDescriptor* descriptor) {
         Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
-        DAWN_TRY(pipeline->Initialize(descriptor));
+        DAWN_TRY(pipeline->Initialize());
         return pipeline;
     }
 
-    MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) {
+    MaybeError ComputePipeline::Initialize() {
         Device* device = ToBackend(GetDevice());
         uint32_t compileFlags = 0;
 
@@ -43,14 +43,15 @@
         // SPRIV-cross does matrix multiplication expecting row major matrices
         compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
-        ShaderModule* module = ToBackend(descriptor->compute.module);
+        const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
+        ShaderModule* module = ToBackend(computeStage.module.Get());
 
         D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
         d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
 
         CompiledShader compiledShader;
         DAWN_TRY_ASSIGN(compiledShader,
-                        module->Compile(descriptor->compute.entryPoint, SingleShaderStage::Compute,
+                        module->Compile(computeStage.entryPoint.c_str(), SingleShaderStage::Compute,
                                         ToBackend(GetLayout()), compileFlags));
         d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
         auto* d3d12Device = device->GetD3D12Device();
@@ -77,14 +78,14 @@
     }
 
     void ComputePipeline::CreateAsync(Device* device,
-                                      std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
+                                      const ComputePipelineDescriptor* descriptor,
                                       size_t blueprintHash,
                                       WGPUCreateComputePipelineAsyncCallback callback,
                                       void* userdata) {
-        Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor.get()));
+        Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
         std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
-            std::make_unique<CreateComputePipelineAsyncTask>(pipeline, std::move(descriptor),
-                                                             blueprintHash, callback, userdata);
+            std::make_unique<CreateComputePipelineAsyncTask>(pipeline, blueprintHash, callback,
+                                                             userdata);
         CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
     }
 
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.h b/src/dawn_native/d3d12/ComputePipelineD3D12.h
index ddb3ea0..d945ee2 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.h
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.h
@@ -29,7 +29,7 @@
             Device* device,
             const ComputePipelineDescriptor* descriptor);
         static void CreateAsync(Device* device,
-                                std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
+                                const ComputePipelineDescriptor* descriptor,
                                 size_t blueprintHash,
                                 WGPUCreateComputePipelineAsyncCallback callback,
                                 void* userdata);
@@ -43,7 +43,7 @@
       private:
         ~ComputePipeline() override;
         using ComputePipelineBase::ComputePipelineBase;
-        MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override;
+        MaybeError Initialize() override;
         ComPtr<ID3D12PipelineState> mPipelineState;
     };
 
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index 7512352..431d6d2 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -376,13 +376,11 @@
         const TextureViewDescriptor* descriptor) {
         return TextureView::Create(texture, descriptor);
     }
-    void Device::CreateComputePipelineAsyncImpl(
-        std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-        size_t blueprintHash,
-        WGPUCreateComputePipelineAsyncCallback callback,
-        void* userdata) {
-        ComputePipeline::CreateAsync(this, std::move(descriptor), blueprintHash, callback,
-                                     userdata);
+    void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                                size_t blueprintHash,
+                                                WGPUCreateComputePipelineAsyncCallback callback,
+                                                void* userdata) {
+        ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
     }
 
     ResultOrError<std::unique_ptr<StagingBufferBase>> Device::CreateStagingBuffer(size_t size) {
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index 124cc97..8430969 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -173,11 +173,10 @@
         ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
             TextureBase* texture,
             const TextureViewDescriptor* descriptor) override;
-        void CreateComputePipelineAsyncImpl(
-            std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-            size_t blueprintHash,
-            WGPUCreateComputePipelineAsyncCallback callback,
-            void* userdata) override;
+        void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                            size_t blueprintHash,
+                                            WGPUCreateComputePipelineAsyncCallback callback,
+                                            void* userdata) override;
 
         void ShutDownImpl() override;
         MaybeError WaitForIdleForDestruction() override;
diff --git a/src/dawn_native/metal/ComputePipelineMTL.h b/src/dawn_native/metal/ComputePipelineMTL.h
index 5ea4e05..6f777b6 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.h
+++ b/src/dawn_native/metal/ComputePipelineMTL.h
@@ -31,7 +31,7 @@
             Device* device,
             const ComputePipelineDescriptor* descriptor);
         static void CreateAsync(Device* device,
-                                std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
+                                const ComputePipelineDescriptor* descriptor,
                                 size_t blueprintHash,
                                 WGPUCreateComputePipelineAsyncCallback callback,
                                 void* userdata);
@@ -42,7 +42,7 @@
 
       private:
         using ComputePipelineBase::ComputePipelineBase;
-        MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override;
+        MaybeError Initialize() override;
 
         NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState;
         MTLSize mLocalWorkgroupSize;
diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm
index 090705f..14110cc 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.mm
+++ b/src/dawn_native/metal/ComputePipelineMTL.mm
@@ -25,15 +25,16 @@
         Device* device,
         const ComputePipelineDescriptor* descriptor) {
         Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
-        DAWN_TRY(pipeline->Initialize(descriptor));
+        DAWN_TRY(pipeline->Initialize());
         return pipeline;
     }
 
-    MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) {
+    MaybeError ComputePipeline::Initialize() {
         auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
 
-        ShaderModule* computeModule = ToBackend(descriptor->compute.module);
-        const char* computeEntryPoint = descriptor->compute.entryPoint;
+        const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
+        ShaderModule* computeModule = ToBackend(computeStage.module.Get());
+        const char* computeEntryPoint = computeStage.entryPoint.c_str();
         ShaderModule::MetalFunctionData computeData;
         DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute,
                                                ToBackend(GetLayout()), &computeData));
@@ -69,14 +70,14 @@
     }
 
     void ComputePipeline::CreateAsync(Device* device,
-                                      std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
+                                      const ComputePipelineDescriptor* descriptor,
                                       size_t blueprintHash,
                                       WGPUCreateComputePipelineAsyncCallback callback,
                                       void* userdata) {
-        Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor.get()));
+        Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
         std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
-            std::make_unique<CreateComputePipelineAsyncTask>(pipeline, std::move(descriptor),
-                                                             blueprintHash, callback, userdata);
+            std::make_unique<CreateComputePipelineAsyncTask>(pipeline, blueprintHash, callback,
+                                                             userdata);
         CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
     }
 
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index 27e11d0..19b900b 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -113,11 +113,10 @@
         ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
             TextureBase* texture,
             const TextureViewDescriptor* descriptor) override;
-        void CreateComputePipelineAsyncImpl(
-            std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-            size_t blueprintHash,
-            WGPUCreateComputePipelineAsyncCallback callback,
-            void* userdata) override;
+        void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                            size_t blueprintHash,
+                                            WGPUCreateComputePipelineAsyncCallback callback,
+                                            void* userdata) override;
 
         void InitTogglesFromDriver();
         void ShutDownImpl() override;
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 664f2c6..348f5db 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -267,13 +267,11 @@
         const TextureViewDescriptor* descriptor) {
         return TextureView::Create(texture, descriptor);
     }
-    void Device::CreateComputePipelineAsyncImpl(
-        std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-        size_t blueprintHash,
-        WGPUCreateComputePipelineAsyncCallback callback,
-        void* userdata) {
-        ComputePipeline::CreateAsync(this, std::move(descriptor), blueprintHash, callback,
-                                     userdata);
+    void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                                size_t blueprintHash,
+                                                WGPUCreateComputePipelineAsyncCallback callback,
+                                                void* userdata) {
+        ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
     }
 
     ResultOrError<ExecutionSerial> Device::CheckAndUpdateCompletedSerials() {
diff --git a/src/dawn_native/opengl/ComputePipelineGL.cpp b/src/dawn_native/opengl/ComputePipelineGL.cpp
index 7680cee..fb33344 100644
--- a/src/dawn_native/opengl/ComputePipelineGL.cpp
+++ b/src/dawn_native/opengl/ComputePipelineGL.cpp
@@ -23,11 +23,11 @@
         Device* device,
         const ComputePipelineDescriptor* descriptor) {
         Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
-        DAWN_TRY(pipeline->Initialize(descriptor));
+        DAWN_TRY(pipeline->Initialize());
         return pipeline;
     }
 
-    MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor*) {
+    MaybeError ComputePipeline::Initialize() {
         DAWN_TRY(
             InitializeBase(ToBackend(GetDevice())->gl, ToBackend(GetLayout()), GetAllStages()));
         return {};
diff --git a/src/dawn_native/opengl/ComputePipelineGL.h b/src/dawn_native/opengl/ComputePipelineGL.h
index e84e366..8e77440 100644
--- a/src/dawn_native/opengl/ComputePipelineGL.h
+++ b/src/dawn_native/opengl/ComputePipelineGL.h
@@ -36,7 +36,7 @@
       private:
         using ComputePipelineBase::ComputePipelineBase;
         ~ComputePipeline() override = default;
-        MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override;
+        MaybeError Initialize() override;
     };
 
 }}  // namespace dawn_native::opengl
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp
index 4cabbb1..e289845 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp
@@ -29,16 +29,16 @@
         Device* device,
         const ComputePipelineDescriptor* descriptor) {
         Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
-        DAWN_TRY(pipeline->Initialize(descriptor));
+        DAWN_TRY(pipeline->Initialize());
         return pipeline;
     }
 
-    MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) {
+    MaybeError ComputePipeline::Initialize() {
         VkComputePipelineCreateInfo createInfo;
         createInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
         createInfo.pNext = nullptr;
         createInfo.flags = 0;
-        createInfo.layout = ToBackend(descriptor->layout)->GetHandle();
+        createInfo.layout = ToBackend(GetLayout())->GetHandle();
         createInfo.basePipelineHandle = ::VK_NULL_HANDLE;
         createInfo.basePipelineIndex = -1;
 
@@ -47,11 +47,12 @@
         createInfo.stage.flags = 0;
         createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
         // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
+        const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
         DAWN_TRY_ASSIGN(createInfo.stage.module,
-                        ToBackend(descriptor->compute.module)
-                            ->GetTransformedModuleHandle(descriptor->compute.entryPoint,
+                        ToBackend(computeStage.module.Get())
+                            ->GetTransformedModuleHandle(computeStage.entryPoint.c_str(),
                                                          ToBackend(GetLayout())));
-        createInfo.stage.pName = descriptor->compute.entryPoint;
+        createInfo.stage.pName = computeStage.entryPoint.c_str();
         createInfo.stage.pSpecializationInfo = nullptr;
 
         Device* device = ToBackend(GetDevice());
@@ -95,14 +96,14 @@
     }
 
     void ComputePipeline::CreateAsync(Device* device,
-                                      std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
+                                      const ComputePipelineDescriptor* descriptor,
                                       size_t blueprintHash,
                                       WGPUCreateComputePipelineAsyncCallback callback,
                                       void* userdata) {
-        Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor.get()));
+        Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
         std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
-            std::make_unique<CreateComputePipelineAsyncTask>(pipeline, std::move(descriptor),
-                                                             blueprintHash, callback, userdata);
+            std::make_unique<CreateComputePipelineAsyncTask>(pipeline, blueprintHash, callback,
+                                                             userdata);
         CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
     }
 
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.h b/src/dawn_native/vulkan/ComputePipelineVk.h
index 7afe9cf..72e2716 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.h
+++ b/src/dawn_native/vulkan/ComputePipelineVk.h
@@ -30,7 +30,7 @@
             Device* device,
             const ComputePipelineDescriptor* descriptor);
         static void CreateAsync(Device* device,
-                                std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
+                                const ComputePipelineDescriptor* descriptor,
                                 size_t blueprintHash,
                                 WGPUCreateComputePipelineAsyncCallback callback,
                                 void* userdata);
@@ -43,7 +43,7 @@
       private:
         ~ComputePipeline() override;
         using ComputePipelineBase::ComputePipelineBase;
-        MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override;
+        MaybeError Initialize() override;
 
         VkPipeline mHandle = VK_NULL_HANDLE;
     };
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index 12d883f..45b98c8 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -162,13 +162,11 @@
         const TextureViewDescriptor* descriptor) {
         return TextureView::Create(texture, descriptor);
     }
-    void Device::CreateComputePipelineAsyncImpl(
-        std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-        size_t blueprintHash,
-        WGPUCreateComputePipelineAsyncCallback callback,
-        void* userdata) {
-        ComputePipeline::CreateAsync(this, std::move(descriptor), blueprintHash, callback,
-                                     userdata);
+    void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                                size_t blueprintHash,
+                                                WGPUCreateComputePipelineAsyncCallback callback,
+                                                void* userdata) {
+        ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
     }
 
     MaybeError Device::TickImpl() {
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index 9bce52c..e053ebf 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -137,11 +137,10 @@
         ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
             TextureBase* texture,
             const TextureViewDescriptor* descriptor) override;
-        void CreateComputePipelineAsyncImpl(
-            std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
-            size_t blueprintHash,
-            WGPUCreateComputePipelineAsyncCallback callback,
-            void* userdata) override;
+        void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                            size_t blueprintHash,
+                                            WGPUCreateComputePipelineAsyncCallback callback,
+                                            void* userdata) override;
 
         ResultOrError<VulkanDeviceKnobs> CreateDevice(VkPhysicalDevice physicalDevice);
         void GatherQueueFromDevice();