Refactor ShaderModule initialization into a "task"

Refactor ShaderModule creation to have an Initialize function in the
front-end which does BlobCache logic and shader module parsing within a
lambda.

This task is currently executed immediately in-thread and validation
errors are stored in ShaderModuleBase. DeviceBase::CreateShaderModule always
returns a ShaderModule unless there was an internal backend error.

This task will be run asynchronously in the future.

Asynchronously adding errors to error scopes is not completed yet so
DeviceBase acquires all errors and consumes them immediately.

Refactored WGPU backend ShaderModule to not use the Initialize method.
Simply always create a shader module, don't try to optimize for
creating one only in the non-error case.

Bug:406522796

Change-Id: Ia8111eca4abf57de08a96d891a73fedb2d0edb44
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/251855
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Geoff Lang <geofflang@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 016900a..e1c0126 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -1458,31 +1458,40 @@
     utils::TraceLabel label = utils::GetLabelForTrace(descriptor->label);
     TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APICreateShaderModule", "label", label.label);
 
-    // parseResult is modified by CreateShaderModule via pointer to provide compilation messages in
-    // error cases.
-    ShaderModuleParseResult parseResult;
-    auto creationResult = CreateShaderModule(descriptor, /*internalExtensions=*/{}, &parseResult);
-
+    Ref<ShaderModuleBase> shaderModule;
+    std::unique_ptr<ErrorData> errorData;
+    auto creationResult = CreateShaderModule(descriptor, /*internalExtensions=*/{});
     if (creationResult.IsSuccess()) {
-        Ref<ShaderModuleBase> validShaderModule = creationResult.AcquireSuccess();
-        DAWN_ASSERT(validShaderModule != nullptr && !validShaderModule->IsError());
-        EmitCompilationLog(validShaderModule.Get());
-        return ReturnToAPI(std::move(validShaderModule));
+        // CreateShaderModule can succeed but still return a shader module which failed compilation.
+        // TODO(crbug.com/406522796): Remove this once ShaderModuleBase writes directly to the error
+        // scope.
+        shaderModule = creationResult.AcquireSuccess();
+        if (shaderModule->IsError()) {
+            errorData = shaderModule->GetInitializationError();
+        }
+    } else {
+        // If CreateShaderModule failed, it was due to internal errors that should not surface as
+        // compilation errors.
+        shaderModule = ShaderModuleBase::MakeError(this, descriptor ? descriptor->label : nullptr,
+                                                   ParsedCompilationMessages());
+        DAWN_ASSERT(shaderModule->IsError());
+        errorData = creationResult.AcquireError();
     }
 
-    // If shader creation failed, create an error shader module with compilation messages so the
-    // application can later retrieve it with GetCompilationInfo.
-    Ref<ShaderModuleBase> errorShaderModule = ShaderModuleBase::MakeError(
-        this, descriptor ? descriptor->label : nullptr, std::move(parseResult.compilationMessages));
-    DAWN_ASSERT(errorShaderModule != nullptr && errorShaderModule->IsError());
+    DAWN_ASSERT(shaderModule != nullptr);
 
-    // Acquire the device lock for error handling, and return the error shader module.
-    auto deviceGuard = GetGuard();
-    // Emit error, including Tint errors and warnings for the error shader module.
-    auto consumedError = ConsumedError(creationResult.AcquireError(), InternalErrorType::Internal,
-                                       "calling %s.CreateShaderModule(%s).", this, descriptor);
-    DAWN_ASSERT(consumedError);
-    return ReturnToAPI(std::move(errorShaderModule));
+    if (errorData != nullptr) {
+        // Acquire the device lock for error handling.
+        auto deviceGuard = GetGuard();
+        // Emit error, including Tint errors and warnings.
+        auto consumedError = ConsumedError(std::move(errorData), InternalErrorType::Internal,
+                                           "calling %s.CreateShaderModule(%s).", this, descriptor);
+        DAWN_ASSERT(consumedError);
+    }
+
+    DAWN_ASSERT(errorData == nullptr);
+
+    return ReturnToAPI(std::move(shaderModule));
 }
 
 ShaderModuleBase* DeviceBase::APICreateErrorShaderModule(const ShaderModuleDescriptor* descriptor,
@@ -2191,8 +2200,7 @@
 
 ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
     const ShaderModuleDescriptor* descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* outputParseResult) {
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
     DAWN_TRY(ValidateIsAlive());
 
     // Unpack and validate the descriptor chain before doing further validation or cache
@@ -2242,50 +2250,14 @@
     const size_t blueprintHash = blueprint.ComputeContentHash();
     blueprint.SetContentHash(blueprintHash);
 
-    // Check in-memory shader module cache first, and if missed check the blob cache, and if missed
-    // again call ParseShaderModule.
+    // Check in-memory shader module cache first, and if missed create a new ShaderModule which may
+    // use the BlobCache.
     return GetOrCreate(
         mCaches->shaderModules, &blueprint, [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
-            SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(GetPlatform(), "CreateShaderModuleUS");
-
-            auto resultOrError = [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
-                ShaderModuleParseRequest req = BuildShaderModuleParseRequest(
-                    this, blueprint.GetHash(), unpacked, internalExtensions,
-                    /* needReflection */ true);
-
-                // Check blob cache first before calling ParseShaderModule. ShaderModuleParseResult
-                // returned from blob cache or ParseShaderModule will hold compilation messages and
-                // validation errors if any. ShaderModuleParseResult from ParseShaderModule also
-                // holds tint program.
-                CacheResult<ShaderModuleParseResult> result;
-                DAWN_TRY_LOAD_OR_RUN(result, this, std::move(req),
-                                     ShaderModuleParseResult::FromBlob, ParseShaderModule,
-                                     "ShaderModuleParsing");
-                GetBlobCache()->EnsureStored(result);
-                ShaderModuleParseResult parseResult = result.Acquire();
-
-                // If ShaderModuleParseResult has validation error, move the compilation messages to
-                // *outputParseResult so that we can create an error shader module from it, and then
-                // return the validation error.
-                if (parseResult.HasError()) {
-                    auto error = parseResult.cachedValidationError->ToErrorData();
-                    if (outputParseResult) {
-                        *outputParseResult = std::move(parseResult);
-                    }
-                    return error;
-                }
-                // Otherwise with no error, create a shader module from parse result and return it.
-                Ref<ShaderModuleBase> shaderModule;
-                DAWN_TRY_ASSIGN(shaderModule,
-                                CreateShaderModuleImpl(unpacked, internalExtensions, &parseResult));
-                shaderModule->SetContentHash(blueprintHash);
-                return shaderModule;
-            }();
-
-            DAWN_HISTOGRAM_BOOLEAN(GetPlatform(), "CreateShaderModuleSuccess",
-                                   resultOrError.IsSuccess());
-
-            return resultOrError;
+            Ref<ShaderModuleBase> shaderModule;
+            DAWN_TRY_ASSIGN(shaderModule, CreateShaderModuleImpl(unpacked, internalExtensions));
+            shaderModule->SetContentHash(blueprintHash);
+            return shaderModule;
         });
 }
 
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 368038b..f6fb597 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -229,8 +229,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSampler(const SamplerDescriptor* descriptor = nullptr);
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
         const ShaderModuleDescriptor* descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions = {},
-        ShaderModuleParseResult* outputParseResult = nullptr);
+        const std::vector<tint::wgsl::Extension>& internalExtensions = {});
     ResultOrError<Ref<SwapChainBase>> CreateSwapChain(Surface* surface,
                                                       SwapChainBase* previousSwapChain,
                                                       const SurfaceConfiguration* config);
@@ -515,8 +514,7 @@
         const SamplerDescriptor* descriptor) = 0;
     virtual ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) = 0;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) = 0;
     // Note that previousSwapChain may be nullptr, or come from a different backend.
     virtual ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index c81e0fd..ee7ab48 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -39,6 +39,7 @@
 #include "dawn/native/ChainUtils.h"
 #include "dawn/native/CompilationMessages.h"
 #include "dawn/native/Device.h"
+#include "dawn/native/Error.h"
 #include "dawn/native/Instance.h"
 #include "dawn/native/ObjectContentHasher.h"
 #include "dawn/native/Pipeline.h"
@@ -1745,7 +1746,7 @@
                                    const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                                    std::vector<tint::wgsl::Extension> internalExtensions,
                                    ApiObjectBase::UntrackedByDeviceTag tag)
-    : Base(device, descriptor->label),
+    : Base(device, ObjectBase::kDelayedInitialization, descriptor->label),
       mType(Type::Undefined),
       mInternalExtensions(std::move(internalExtensions)) {
     size_t shaderCodeByteSize = 0;
@@ -1825,6 +1826,85 @@
         new ShaderModuleBase(device, ObjectBase::kError, label, std::move(compilationMessages)));
 }
 
+void ShaderModuleBase::Initialize() {
+    auto task = [&, shaderModuleRef = Ref<ShaderModuleBase>(this)]() {
+        SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(GetDevice()->GetPlatform(), "CreateShaderModuleUS");
+
+        auto taskMaybeError = [&]() -> MaybeError {
+            // Check blob cache first before calling ParseShaderModule. ShaderModuleParseResult
+            // returned from blob cache or ParseShaderModule will hold compilation messages and
+            // validation errors if any. ShaderModuleParseResult from ParseShaderModule also
+            // holds tint program.
+            CacheResult<ShaderModuleParseResult> cacheResult;
+            DAWN_TRY_LOAD_OR_RUN(cacheResult, GetDevice(), GenerateShaderModuleParseRequest(true),
+                                 ShaderModuleParseResult::FromBlob, ParseShaderModule,
+                                 "ShaderModuleParsing");
+            GetDevice()->GetBlobCache()->EnsureStored(cacheResult);
+
+            ShaderModuleParseResult parseResult = cacheResult.Acquire();
+
+            // Move the compilation messages regardless of compilation success. Compilation messages
+            // should be inject only once for each shader module.
+            DAWN_ASSERT(mCompilationMessages == nullptr);
+            // Move the compilationMessages into the shader module and emit the tint errors and
+            // warnings
+            mCompilationMessages = std::make_unique<OwnedCompilationMessages>(
+                std::move(parseResult.compilationMessages));
+
+            // If ShaderModuleParseResult has validation error, notify the caller that compilation
+            // failed. The compilation messages have already been stored.
+            if (parseResult.HasError()) {
+                return parseResult.cachedValidationError->ToErrorData();
+            }
+
+            DAWN_ASSERT(!parseResult.HasError());
+            if (parseResult.HasTintProgram()) {
+                mTintData.Use([&](auto tintData) {
+                    tintData->tintProgram =
+                        std::move(parseResult.tintProgram.UnsafeGetValue().value());
+                });
+            }
+
+            // Gather the metadata and default entry point names
+            DAWN_ASSERT(parseResult.metadataTable.has_value());
+            mEntryPoints = std::move(parseResult.metadataTable.value());
+
+            for (auto stage : IterateStages(kAllStages)) {
+                mEntryPointCounts[stage] = 0;
+            }
+            for (auto& [name, metadata] : mEntryPoints) {
+                SingleShaderStage stage = metadata->stage;
+                if (mEntryPointCounts[stage] == 0) {
+                    mDefaultEntryPointNames[stage] = name;
+                }
+                mEntryPointCounts[stage]++;
+            }
+
+            return {};
+        }();
+
+        DAWN_HISTOGRAM_BOOLEAN(GetDevice()->GetPlatform(), "CreateShaderModuleSuccess",
+                               taskMaybeError.IsSuccess());
+        if (taskMaybeError.IsError()) {
+            SetInitializedError();
+            mInitializationError = CachedValidationError(taskMaybeError.AcquireError());
+        } else {
+            SetInitializedNoError();
+
+            // On successful compilation, emit the compilation log
+            GetDevice()->EmitCompilationLog(this);
+        }
+    };
+
+    task();
+    DAWN_ASSERT(IsInitialized());
+}
+
+std::unique_ptr<ErrorData> ShaderModuleBase::GetInitializationError() {
+    DAWN_ASSERT(mInitializationError.has_value());
+    return mInitializationError->ToErrorData();
+}
+
 ObjectType ShaderModuleBase::GetType() const {
     return ObjectType::ShaderModule;
 }
@@ -1895,34 +1975,9 @@
         // shader source code, Dawn will look up from the cache and return the same
         // ShaderModuleBase. In this case, we have to recreate the released mTintProgram for
         // initializing new pipelines.
-        ShaderModuleDescriptor descriptor;
-        ShaderSourceWGSL wgslDescriptor;
-        ShaderSourceSPIRV spirvDescriptor;
-        DawnShaderModuleSPIRVOptionsDescriptor spirvOptionsDescriptor;
-
-        switch (mType) {
-            case Type::Spirv:
-                spirvOptionsDescriptor.allowNonUniformDerivatives =
-                    mAllowSpirvNonUniformDerivitives;
-                spirvDescriptor.nextInChain = &spirvOptionsDescriptor;
-
-                spirvDescriptor.codeSize = mOriginalSpirv.size();
-                spirvDescriptor.code = mOriginalSpirv.data();
-                descriptor.nextInChain = &spirvDescriptor;
-                break;
-            case Type::Wgsl:
-                wgslDescriptor.code = std::string_view(mWgsl);
-                descriptor.nextInChain = &wgslDescriptor;
-                break;
-            default:
-                DAWN_UNREACHABLE();
-        }
-
         // Assuming ParseShaderModule will not throw error for regenerating.
         ShaderModuleParseResult regeneratedParseResult =
-            ParseShaderModule(BuildShaderModuleParseRequest(GetDevice(), mHash, Unpack(&descriptor),
-                                                            mInternalExtensions,
-                                                            /* needReflection */ false))
+            ParseShaderModule(GenerateShaderModuleParseRequest(/* needReflection */ false))
                 .AcquireSuccess();
         DAWN_ASSERT(regeneratedParseResult.HasTintProgram() && !regeneratedParseResult.HasError());
 
@@ -2006,37 +2061,6 @@
     mCompilationMessages = std::move(*compilationMessages);
 }
 
-MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
-    DAWN_ASSERT(!parseResult->HasError());
-    if (parseResult->HasTintProgram()) {
-        mTintData.Use([&](auto tintData) {
-            tintData->tintProgram = std::move(parseResult->tintProgram.UnsafeGetValue().value());
-        });
-    }
-    DAWN_ASSERT(parseResult->metadataTable.has_value());
-    mEntryPoints = std::move(parseResult->metadataTable.value());
-
-    for (auto stage : IterateStages(kAllStages)) {
-        mEntryPointCounts[stage] = 0;
-    }
-    for (auto& [name, metadata] : mEntryPoints) {
-        SingleShaderStage stage = metadata->stage;
-        if (mEntryPointCounts[stage] == 0) {
-            mDefaultEntryPointNames[stage] = name;
-        }
-        mEntryPointCounts[stage]++;
-    }
-
-    // Move the compilation messages if initialized successfully. Compilation messages should be
-    // inject only once for each shader module.
-    DAWN_ASSERT(mCompilationMessages == nullptr);
-    // Move the compilationMessages into the shader module and emit the tint errors and warnings
-    mCompilationMessages =
-        std::make_unique<OwnedCompilationMessages>(std::move(parseResult->compilationMessages));
-
-    return {};
-}
-
 void ShaderModuleBase::WillDropLastExternalRef() {
     // The last external ref being dropped indicates that the application is not currently using,
     // and no pending task will use the shader module. In this case we can free the memory for the
@@ -2044,4 +2068,32 @@
     mTintData.Use([&](auto tintData) { tintData->tintProgram = nullptr; });
 }
 
+ShaderModuleParseRequest ShaderModuleBase::GenerateShaderModuleParseRequest(
+    bool needReflection) const {
+    ShaderModuleDescriptor descriptor;
+    ShaderSourceWGSL wgslDescriptor;
+    ShaderSourceSPIRV spirvDescriptor;
+    DawnShaderModuleSPIRVOptionsDescriptor spirvOptionsDescriptor;
+
+    switch (mType) {
+        case Type::Spirv:
+            spirvOptionsDescriptor.allowNonUniformDerivatives = mAllowSpirvNonUniformDerivitives;
+            spirvDescriptor.nextInChain = &spirvOptionsDescriptor;
+
+            spirvDescriptor.codeSize = mOriginalSpirv.size();
+            spirvDescriptor.code = mOriginalSpirv.data();
+            descriptor.nextInChain = &spirvDescriptor;
+            break;
+        case Type::Wgsl:
+            wgslDescriptor.code = std::string_view(mWgsl);
+            descriptor.nextInChain = &wgslDescriptor;
+            break;
+        default:
+            DAWN_UNREACHABLE();
+    }
+
+    return BuildShaderModuleParseRequest(GetDevice(), mHash, Unpack(&descriptor),
+                                         mInternalExtensions, needReflection);
+}
+
 }  // namespace dawn::native
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 2fa5c26..efb7198 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -373,6 +373,9 @@
                                            StringView label,
                                            ParsedCompilationMessages&& compilationMessages);
 
+    void Initialize();
+    std::unique_ptr<ErrorData> GetInitializationError();
+
     ObjectType GetType() const override;
 
     // Return true iff the program has an entrypoint called `entryPoint`.
@@ -423,8 +426,6 @@
   protected:
     void DestroyImpl() override;
 
-    MaybeError InitializeBase(ShaderModuleParseResult* parseResult);
-
   private:
     ShaderModuleBase(DeviceBase* device,
                      ObjectBase::ErrorTag tag,
@@ -433,6 +434,8 @@
 
     void WillDropLastExternalRef() override;
 
+    ShaderModuleParseRequest GenerateShaderModuleParseRequest(bool needReflection) const;
+
     // The original data in the descriptor for caching.
     enum class Type : uint8_t { Undefined, Spirv, Wgsl };
     Type mType;
@@ -462,6 +465,11 @@
     std::unique_ptr<const OwnedCompilationMessages> mCompilationMessages;
 
     const std::vector<tint::wgsl::Extension> mInternalExtensions;
+
+    // Storage of any error generated during initialization. When initialization is fully
+    // asynchronous, this will be removed and inserted into a stored error scope during
+    // initialization.
+    std::optional<CachedValidationError> mInitializationError;
 };
 
 }  // namespace dawn::native
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index 9a24e93..85af111 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -286,9 +286,8 @@
 
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult);
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    return ShaderModule::Create(this, descriptor, internalExtensions);
 }
 
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
diff --git a/src/dawn/native/d3d11/DeviceD3D11.h b/src/dawn/native/d3d11/DeviceD3D11.h
index ed86fed..2d14e34 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.h
+++ b/src/dawn/native/d3d11/DeviceD3D11.h
@@ -118,8 +118,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) override;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
         SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
index d57ba62..0caf74d 100644
--- a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
+++ b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
@@ -57,11 +57,10 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
-    DAWN_TRY(module->Initialize(parseResult));
-    return module;
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    Ref<ShaderModule> shader = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
+    shader->Initialize();
+    return shader;
 }
 
 ShaderModule::ShaderModule(Device* device,
@@ -69,10 +68,6 @@
                            std::vector<tint::wgsl::Extension> internalExtensions)
     : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
-MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
-    return InitializeBase(parseResult);
-}
-
 ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
     const ProgrammableStage& programmableStage,
     SingleShaderStage stage,
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.h b/src/dawn/native/d3d11/ShaderModuleD3D11.h
index 476099d..becf7a1 100644
--- a/src/dawn/native/d3d11/ShaderModuleD3D11.h
+++ b/src/dawn/native/d3d11/ShaderModuleD3D11.h
@@ -53,8 +53,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult);
+        const std::vector<tint::wgsl::Extension>& internalExtensions);
 
     ResultOrError<d3d::CompiledShader> Compile(
         const ProgrammableStage& programmableStage,
@@ -71,7 +70,6 @@
                  const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                  std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override = default;
-    MaybeError Initialize(ShaderModuleParseResult* parseResult);
 };
 
 }  // namespace dawn::native::d3d11
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index 52a5002..cef4753 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -411,9 +411,8 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult);
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    return ShaderModule::Create(this, descriptor, internalExtensions);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/d3d12/DeviceD3D12.h b/src/dawn/native/d3d12/DeviceD3D12.h
index 1ae03ac..f13e81a 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.h
+++ b/src/dawn/native/d3d12/DeviceD3D12.h
@@ -191,8 +191,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) override;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
         SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
index aa00dbb..43f0c05 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -100,11 +100,10 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
-    DAWN_TRY(module->Initialize(parseResult));
-    return module;
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    Ref<ShaderModule> shader = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
+    shader->Initialize();
+    return shader;
 }
 
 ShaderModule::ShaderModule(Device* device,
@@ -112,10 +111,6 @@
                            std::vector<tint::wgsl::Extension> internalExtensions)
     : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
-MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
-    return InitializeBase(parseResult);
-}
-
 ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
     const ProgrammableStage& programmableStage,
     SingleShaderStage stage,
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.h b/src/dawn/native/d3d12/ShaderModuleD3D12.h
index 0dfe05e..596e0df 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.h
@@ -53,8 +53,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult);
+        const std::vector<tint::wgsl::Extension>& internalExtensions);
 
     ResultOrError<d3d::CompiledShader> Compile(
         const ProgrammableStage& programmableStage,
@@ -69,7 +68,6 @@
                  const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                  std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override = default;
-    MaybeError Initialize(ShaderModuleParseResult* parseResult);
 };
 
 }  // namespace dawn::native::d3d12
diff --git a/src/dawn/native/metal/DeviceMTL.h b/src/dawn/native/metal/DeviceMTL.h
index cd38565..4568b4e 100644
--- a/src/dawn/native/metal/DeviceMTL.h
+++ b/src/dawn/native/metal/DeviceMTL.h
@@ -112,8 +112,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) override;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
         SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/metal/DeviceMTL.mm b/src/dawn/native/metal/DeviceMTL.mm
index 1cd1639..5eb43a6 100644
--- a/src/dawn/native/metal/DeviceMTL.mm
+++ b/src/dawn/native/metal/DeviceMTL.mm
@@ -225,9 +225,8 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult);
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    return ShaderModule::Create(this, descriptor, internalExtensions);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/metal/ShaderModuleMTL.h b/src/dawn/native/metal/ShaderModuleMTL.h
index cecc29a..5ead1b1 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.h
+++ b/src/dawn/native/metal/ShaderModuleMTL.h
@@ -54,8 +54,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult);
+        const std::vector<tint::wgsl::Extension>& internalExtensions);
 
     struct MetalFunctionData {
         std::string msl;
@@ -78,8 +77,6 @@
                  const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                  std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override;
-
-    MaybeError Initialize(ShaderModuleParseResult* parseResult);
 };
 
 }  // namespace dawn::native::metal
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index 0ffcfb9..8a52061 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -107,11 +107,10 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
-    DAWN_TRY(module->Initialize(parseResult));
-    return module;
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    Ref<ShaderModule> shader = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
+    shader->Initialize();
+    return shader;
 }
 
 ShaderModule::ShaderModule(Device* device,
@@ -121,10 +120,6 @@
 
 ShaderModule::~ShaderModule() = default;
 
-MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
-    return InitializeBase(parseResult);
-}
-
 namespace {
 
 tint::Bindings GenerateBindingInfo(SingleShaderStage stage,
diff --git a/src/dawn/native/null/DeviceNull.cpp b/src/dawn/native/null/DeviceNull.cpp
index 693b641..7d9da5e 100644
--- a/src/dawn/native/null/DeviceNull.cpp
+++ b/src/dawn/native/null/DeviceNull.cpp
@@ -235,10 +235,9 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
     Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor, internalExtensions));
-    DAWN_TRY(module->Initialize(parseResult));
+    module->Initialize();
     return module;
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
@@ -577,12 +576,6 @@
     }
 }
 
-// ShaderModule
-
-MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
-    return InitializeBase(parseResult);
-}
-
 uint32_t Device::GetOptimalBytesPerRowAlignment() const {
     return 1;
 }
diff --git a/src/dawn/native/null/DeviceNull.h b/src/dawn/native/null/DeviceNull.h
index 15b14ca..e498524 100644
--- a/src/dawn/native/null/DeviceNull.h
+++ b/src/dawn/native/null/DeviceNull.h
@@ -160,8 +160,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) override;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
         SwapChainBase* previousSwapChain,
@@ -323,8 +322,6 @@
 class ShaderModule final : public ShaderModuleBase {
   public:
     using ShaderModuleBase::ShaderModuleBase;
-
-    MaybeError Initialize(ShaderModuleParseResult* parseResult);
 };
 
 class SwapChain final : public SwapChainBase {
diff --git a/src/dawn/native/opengl/DeviceGL.cpp b/src/dawn/native/opengl/DeviceGL.cpp
index a3595cf..e0dbcd1 100644
--- a/src/dawn/native/opengl/DeviceGL.cpp
+++ b/src/dawn/native/opengl/DeviceGL.cpp
@@ -285,9 +285,8 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult);
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    return ShaderModule::Create(this, descriptor, internalExtensions);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/opengl/DeviceGL.h b/src/dawn/native/opengl/DeviceGL.h
index a3e08cd..fc582ef 100644
--- a/src/dawn/native/opengl/DeviceGL.h
+++ b/src/dawn/native/opengl/DeviceGL.h
@@ -131,8 +131,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) override;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
         SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/opengl/ShaderModuleGL.cpp b/src/dawn/native/opengl/ShaderModuleGL.cpp
index e748388..971f550 100644
--- a/src/dawn/native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn/native/opengl/ShaderModuleGL.cpp
@@ -307,11 +307,10 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
-    DAWN_TRY(module->Initialize(parseResult));
-    return module;
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    Ref<ShaderModule> shader = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
+    shader->Initialize();
+    return shader;
 }
 
 ShaderModule::ShaderModule(Device* device,
@@ -319,12 +318,6 @@
                            std::vector<tint::wgsl::Extension> internalExtensions)
     : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
-MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
-    DAWN_TRY(InitializeBase(parseResult));
-
-    return {};
-}
-
 std::pair<tint::glsl::writer::Bindings, BindingMap> GenerateBindingInfo(
     SingleShaderStage stage,
     const PipelineLayout* layout,
diff --git a/src/dawn/native/opengl/ShaderModuleGL.h b/src/dawn/native/opengl/ShaderModuleGL.h
index 9b67538..38281b4 100644
--- a/src/dawn/native/opengl/ShaderModuleGL.h
+++ b/src/dawn/native/opengl/ShaderModuleGL.h
@@ -84,8 +84,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult);
+        const std::vector<tint::wgsl::Extension>& internalExtensions);
 
     ResultOrError<GLuint> CompileShader(const OpenGLFunctions& gl,
                                         const ProgrammableStage& programmableStage,
@@ -104,7 +103,6 @@
                  const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                  std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override = default;
-    MaybeError Initialize(ShaderModuleParseResult* parseResult);
 };
 
 }  // namespace opengl
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 5317ec7..a51e279 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -232,9 +232,8 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult);
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    return ShaderModule::Create(this, descriptor, internalExtensions);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/vulkan/DeviceVk.h b/src/dawn/native/vulkan/DeviceVk.h
index f429d7c..8d95b55 100644
--- a/src/dawn/native/vulkan/DeviceVk.h
+++ b/src/dawn/native/vulkan/DeviceVk.h
@@ -156,8 +156,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) override;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
         SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index ec42b72..e597736 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -88,11 +88,10 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
-    DAWN_TRY(module->Initialize(parseResult));
-    return module;
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    Ref<ShaderModule> shader = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
+    shader->Initialize();
+    return shader;
 }
 
 ShaderModule::ShaderModule(Device* device,
@@ -100,14 +99,6 @@
                            std::vector<tint::wgsl::Extension> internalExtensions)
     : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
-MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
-    return InitializeBase(parseResult);
-}
-
-void ShaderModule::DestroyImpl() {
-    ShaderModuleBase::DestroyImpl();
-}
-
 ShaderModule::~ShaderModule() = default;
 
 #if TINT_BUILD_SPV_WRITER
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index b9bf6d7..ef296dd 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -61,8 +61,7 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult);
+        const std::vector<tint::wgsl::Extension>& internalExtensions);
 
     // Caller is responsible for destroying the `VkShaderModule` returned.
     ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
@@ -76,8 +75,6 @@
                  const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                  std::vector<tint::wgsl::Extension> internalExtensions);
     ~ShaderModule() override;
-    MaybeError Initialize(ShaderModuleParseResult* parseResult);
-    void DestroyImpl() override;
 };
 
 }  // namespace vulkan
diff --git a/src/dawn/native/webgpu/DeviceWGPU.cpp b/src/dawn/native/webgpu/DeviceWGPU.cpp
index 83c2b92..b66f11d 100644
--- a/src/dawn/native/webgpu/DeviceWGPU.cpp
+++ b/src/dawn/native/webgpu/DeviceWGPU.cpp
@@ -183,9 +183,8 @@
 }
 ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    return ShaderModule::Create(this, descriptor, internalExtensions, parseResult);
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    return ShaderModule::Create(this, descriptor, internalExtensions);
 }
 ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(Surface* surface,
                                                               SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/webgpu/DeviceWGPU.h b/src/dawn/native/webgpu/DeviceWGPU.h
index cabd8bf..b98c404 100644
--- a/src/dawn/native/webgpu/DeviceWGPU.h
+++ b/src/dawn/native/webgpu/DeviceWGPU.h
@@ -86,8 +86,7 @@
     ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
     ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult) override;
+        const std::vector<tint::wgsl::Extension>& internalExtensions) override;
     ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
         Surface* surface,
         SwapChainBase* previousSwapChain,
diff --git a/src/dawn/native/webgpu/ShaderModuleWGPU.cpp b/src/dawn/native/webgpu/ShaderModuleWGPU.cpp
index 894545b..180291e 100644
--- a/src/dawn/native/webgpu/ShaderModuleWGPU.cpp
+++ b/src/dawn/native/webgpu/ShaderModuleWGPU.cpp
@@ -38,36 +38,25 @@
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-    const std::vector<tint::wgsl::Extension>& internalExtensions,
-    ShaderModuleParseResult* parseResult) {
-    Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor, internalExtensions));
-    DAWN_TRY(module->Initialize(descriptor, parseResult));
-    return module;
+    const std::vector<tint::wgsl::Extension>& internalExtensions) {
+    auto desc = ToAPI(*descriptor);
+    WGPUShaderModule innerShaderModule =
+        device->wgpu.deviceCreateShaderModule(device->GetInnerHandle(), desc);
+    DAWN_ASSERT(innerShaderModule);
+
+    Ref<ShaderModule> shader =
+        AcquireRef(new ShaderModule(device, descriptor, internalExtensions, innerShaderModule));
+    shader->Initialize();
+    return shader;
 }
 
 ShaderModule::ShaderModule(Device* device,
                            const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-                           std::vector<tint::wgsl::Extension> internalExtensions)
+                           std::vector<tint::wgsl::Extension> internalExtensions,
+                           WGPUShaderModule innerShaderModule)
     : ShaderModuleBase(device, descriptor, std::move(internalExtensions)),
-      ObjectWGPU(device->wgpu.shaderModuleRelease) {}
-
-MaybeError ShaderModule::Initialize(const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-                                    ShaderModuleParseResult* parseResult) {
-    DAWN_TRY(InitializeBase(parseResult));
-
-    if (GetCompilationMessages()->HasWarningsOrErrors()) {
-        // Cached parse result already shows it is an invalid shader.
-        // No need to create the real shader module on the backend.
-        return {};
-    }
-
-    auto desc = ToAPI(*descriptor);
-    mInnerHandle =
-        ToBackend(GetDevice())
-            ->wgpu.deviceCreateShaderModule(ToBackend(GetDevice())->GetInnerHandle(), desc);
-    DAWN_ASSERT(mInnerHandle);
-
-    return {};
+      ObjectWGPU(device->wgpu.shaderModuleRelease) {
+    mInnerHandle = innerShaderModule;
 }
 
 }  // namespace dawn::native::webgpu
diff --git a/src/dawn/native/webgpu/ShaderModuleWGPU.h b/src/dawn/native/webgpu/ShaderModuleWGPU.h
index 59c6b4b4..df8f219 100644
--- a/src/dawn/native/webgpu/ShaderModuleWGPU.h
+++ b/src/dawn/native/webgpu/ShaderModuleWGPU.h
@@ -43,16 +43,14 @@
     static ResultOrError<Ref<ShaderModule>> Create(
         Device* device,
         const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-        const std::vector<tint::wgsl::Extension>& internalExtensions,
-        ShaderModuleParseResult* parseResult);
+        const std::vector<tint::wgsl::Extension>& internalExtensions);
 
   private:
     ShaderModule(Device* device,
                  const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-                 std::vector<tint::wgsl::Extension> internalExtensions);
+                 std::vector<tint::wgsl::Extension> internalExtensions,
+                 WGPUShaderModule innerShaderModule);
     ~ShaderModule() override = default;
-    MaybeError Initialize(const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
-                          ShaderModuleParseResult* parseResult);
 };
 
 }  // namespace dawn::native::webgpu
diff --git a/src/dawn/tests/unittests/native/mocks/DeviceMock.h b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
index 57d4130..99ff686 100644
--- a/src/dawn/tests/unittests/native/mocks/DeviceMock.h
+++ b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
@@ -117,8 +117,7 @@
     MOCK_METHOD(ResultOrError<Ref<ShaderModuleBase>>,
                 CreateShaderModuleImpl,
                 (const UnpackedPtr<ShaderModuleDescriptor>&,
-                 const std::vector<tint::wgsl::Extension>&,
-                 ShaderModuleParseResult*),
+                 const std::vector<tint::wgsl::Extension>&),
                 (override));
     MOCK_METHOD(ResultOrError<Ref<SwapChainBase>>,
                 CreateSwapChainImpl,
diff --git a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
index de47fcb..cd7ed94 100644
--- a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
+++ b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
@@ -54,14 +54,7 @@
 
     Ref<ShaderModuleMock> shaderModule =
         AcquireRef(new NiceMock<ShaderModuleMock>(device, descriptor));
-
-    ShaderModuleParseResult parseResult =
-        ParseShaderModule(BuildShaderModuleParseRequest(device, shaderModule->GetHash(), descriptor,
-                                                        {},
-                                                        /* needReflection*/ true))
-            .AcquireSuccess();
-
-    shaderModule->InitializeBase(&parseResult).AcquireSuccess();
+    shaderModule->Initialize();
     return shaderModule;
 }