Produce tint::ast::Module in the frontend if UseTintGenerator

This factors code to move parsing of tint::ast::Module to the
frontend. All backends will use this code path when
UseTintGenerator is enabled for both SPIR-V and WGSL ingestion.

To avoid too much code explosion, parsing and validating the
shader is moved into ValidateShaderModuleDescriptor which
returns a result struct that gets passed into creation.

Bug: dawn:571
Change-Id: I598693ef36954fd0056a0744a2a0ebd7cc7d40a4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32301
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 3e838ad..72d7606 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -584,7 +584,8 @@
     }
 
     ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
-        const ShaderModuleDescriptor* descriptor) {
+        const ShaderModuleDescriptor* descriptor,
+        ShaderModuleParseResult* parseResult) {
         ShaderModuleBase blueprint(this, descriptor);
 
         const size_t blueprintHash = blueprint.ComputeContentHash();
@@ -597,7 +598,18 @@
         }
 
         ShaderModuleBase* backendObj;
-        DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor));
+        if (parseResult == nullptr) {
+            // We skip the parse on creation if validation isn't enabled which let's us quickly
+            // lookup in the cache without validating and parsing. We need the parsed module now, so
+            // call validate. Most of |ValidateShaderModuleDescriptor| is parsing, but we can
+            // consider splitting it if additional validation is added.
+            ASSERT(!IsValidationEnabled());
+            ShaderModuleParseResult localParseResult =
+                ValidateShaderModuleDescriptor(this, descriptor).AcquireSuccess();
+            DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor, &localParseResult));
+        } else {
+            DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor, parseResult));
+        }
         backendObj->SetIsCachedReference();
         backendObj->SetContentHash(blueprintHash);
         mCaches->shaderModules.insert(backendObj);
@@ -1062,10 +1074,15 @@
     MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
                                                       const ShaderModuleDescriptor* descriptor) {
         DAWN_TRY(ValidateIsAlive());
+
+        ShaderModuleParseResult parseResult = {};
+        ShaderModuleParseResult* parseResultPtr = nullptr;
         if (IsValidationEnabled()) {
-            DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
+            DAWN_TRY_ASSIGN(parseResult, ValidateShaderModuleDescriptor(this, descriptor));
+            parseResultPtr = &parseResult;
         }
-        DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor));
+
+        DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor, parseResultPtr));
         return {};
     }
 
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 08a4c80..f089c76 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -40,6 +40,7 @@
     class PersistentCache;
     class StagingBufferBase;
     struct InternalPipelineStore;
+    struct ShaderModuleParseResult;
 
     class DeviceBase {
       public:
@@ -129,7 +130,8 @@
         void UncacheSampler(SamplerBase* obj);
 
         ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
-            const ShaderModuleDescriptor* descriptor);
+            const ShaderModuleDescriptor* descriptor,
+            ShaderModuleParseResult* parseResult);
         void UncacheShaderModule(ShaderModuleBase* obj);
 
         Ref<AttachmentState> GetOrCreateAttachmentState(AttachmentStateBlueprint* blueprint);
@@ -275,7 +277,8 @@
         virtual ResultOrError<SamplerBase*> CreateSamplerImpl(
             const SamplerDescriptor* descriptor) = 0;
         virtual ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
-            const ShaderModuleDescriptor* descriptor) = 0;
+            const ShaderModuleDescriptor* descriptor,
+            ShaderModuleParseResult* parseResult) = 0;
         virtual ResultOrError<SwapChainBase*> CreateSwapChainImpl(
             const SwapChainDescriptor* descriptor) = 0;
         // Note that previousSwapChain may be nullptr, or come from a different backend.
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 334a40e..6e0781d 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -173,13 +173,12 @@
         }
 
 #ifdef DAWN_ENABLE_WGSL
-        MaybeError ValidateWGSL(const char* source) {
+        ResultOrError<tint::ast::Module> ParseWGSL(const char* wgsl) {
             std::ostringstream errorStream;
-            errorStream << "Tint WGSL failure:" << std::endl;
+            errorStream << "Tint WGSL reader failure:" << std::endl;
 
-            tint::Source::File file("", source);
+            tint::Source::File file("", wgsl);
             tint::reader::wgsl::Parser parser(&file);
-
             if (!parser.Parse()) {
                 errorStream << "Parser: " << parser.error() << std::endl;
                 return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
@@ -191,14 +190,46 @@
                 return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
             }
 
-            tint::TypeDeterminer type_determiner(&module);
-            if (!type_determiner.Determine()) {
-                errorStream << "Type Determination: " << type_determiner.error();
+            tint::TypeDeterminer typeDeterminer(&module);
+            if (!typeDeterminer.Determine()) {
+                errorStream << "Type Determination: " << typeDeterminer.error();
                 return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
             }
 
+            return std::move(module);
+        }
+
+        ResultOrError<tint::ast::Module> ParseSPIRV(const std::vector<uint32_t>& spirv) {
+            std::ostringstream errorStream;
+            errorStream << "Tint SPIRV reader failure:" << std::endl;
+
+            tint::reader::spirv::Parser parser(spirv);
+            if (!parser.Parse()) {
+                errorStream << "Parser: " << parser.error() << std::endl;
+                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+            }
+
+            tint::ast::Module module = parser.module();
+            if (!module.IsValid()) {
+                errorStream << "Invalid module generated..." << std::endl;
+                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+            }
+
+            tint::TypeDeterminer typeDeterminer(&module);
+            if (!typeDeterminer.Determine()) {
+                errorStream << "Type Determination: " << typeDeterminer.error();
+                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+            }
+
+            return std::move(module);
+        }
+
+        MaybeError ValidateModule(tint::ast::Module* module) {
+            std::ostringstream errorStream;
+            errorStream << "Tint module validation" << std::endl;
+
             tint::Validator validator;
-            if (!validator.Validate(&module)) {
+            if (!validator.Validate(module)) {
                 errorStream << "Validation: " << validator.error() << std::endl;
                 return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
             }
@@ -206,111 +237,9 @@
             return {};
         }
 
-        ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRV(const char* source) {
+        ResultOrError<std::vector<uint32_t>> ModuleToSPIRV(tint::ast::Module module) {
             std::ostringstream errorStream;
-            errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
-
-            tint::Source::File file("", source);
-            tint::reader::wgsl::Parser parser(&file);
-
-            // TODO: This is a duplicate parse with ValidateWGSL, need to store
-            // state between calls to avoid this.
-            if (!parser.Parse()) {
-                errorStream << "Parser: " << parser.error() << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::ast::Module module = parser.module();
-            if (!module.IsValid()) {
-                errorStream << "Invalid module generated..." << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::TypeDeterminer type_determiner(&module);
-            if (!type_determiner.Determine()) {
-                errorStream << "Type Determination: " << type_determiner.error();
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::writer::spirv::Generator generator(std::move(module));
-            if (!generator.Generate()) {
-                errorStream << "Generator: " << generator.error() << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            std::vector<uint32_t> spirv = generator.result();
-
-            DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
-
-            return std::move(spirv);
-        }
-
-        ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
-            const char* source,
-            const VertexStateDescriptor& vertexState,
-            const std::string& entryPoint,
-            uint32_t pullingBufferBindingSet) {
-            std::ostringstream errorStream;
-            errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
-
-            tint::Source::File file("", source);
-            tint::reader::wgsl::Parser parser(&file);
-
-            // TODO: This is a duplicate parse with ValidateWGSL, need to store
-            // state between calls to avoid this.
-            if (!parser.Parse()) {
-                errorStream << "Parser: " << parser.error() << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::ast::Module module = parser.module();
-            if (!module.IsValid()) {
-                errorStream << "Invalid module generated..." << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::transform::Manager transformManager;
-            {
-                auto transform = std::make_unique<tint::transform::VertexPulling>();
-                tint::transform::VertexStateDescriptor state;
-
-                for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
-                    auto& vertexBuffer = vertexState.vertexBuffers[i];
-                    tint::transform::VertexBufferLayoutDescriptor layout;
-                    layout.array_stride = vertexBuffer.arrayStride;
-                    layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
-
-                    for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
-                        auto& attribute = vertexBuffer.attributes[j];
-                        tint::transform::VertexAttributeDescriptor attr;
-                        attr.format = ToTintVertexFormat(attribute.format);
-                        attr.offset = attribute.offset;
-                        attr.shader_location = attribute.shaderLocation;
-
-                        layout.attributes.push_back(std::move(attr));
-                    }
-
-                    state.push_back(std::move(layout));
-                }
-                transform->SetVertexState(std::move(state));
-                transform->SetEntryPoint(entryPoint);
-                transform->SetPullingBufferBindingSet(pullingBufferBindingSet);
-                transformManager.append(std::move(transform));
-            }
-
-            auto result = transformManager.Run(&module);
-            if (result.diagnostics.contains_errors()) {
-                errorStream << "Vertex pulling transform: "
-                            << tint::diag::Formatter{}.format(result.diagnostics);
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-            module = std::move(result.module);
-
-            tint::TypeDeterminer type_determiner(&module);
-            if (!type_determiner.Determine()) {
-                errorStream << "Type Determination: " << type_determiner.error();
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
+            errorStream << "Tint SPIR-V writer failure:" << std::endl;
 
             tint::writer::spirv::Generator generator(std::move(module));
             if (!generator.Generate()) {
@@ -745,38 +674,16 @@
         // completed using PopulateMetadataUsingSPIRVCross. In the future, once
         // this function is complete, ReflectShaderUsingSPIRVCross and
         // PopulateMetadataUsingSPIRVCross will be removed.
-        ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(DeviceBase* device,
-                                                                      std::vector<uint32_t> spirv) {
+        ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(
+            DeviceBase* device,
+            const tint::ast::Module& module) {
 #ifdef DAWN_ENABLE_WGSL
+            ASSERT(module.IsValid());
+
             EntryPointMetadataTable result;
             std::ostringstream errorStream;
             errorStream << "Tint Reflection failure:" << std::endl;
 
-            tint::reader::spirv::Parser parser(spirv);
-
-            if (!parser.Parse()) {
-                errorStream << "Parser: " << parser.error() << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::ast::Module module = parser.module();
-            if (!module.IsValid()) {
-                errorStream << "Invalid module generated..." << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::TypeDeterminer typeDeterminer(&module);
-            if (!typeDeterminer.Determine()) {
-                errorStream << "Type Determination: " << typeDeterminer.error();
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
-            tint::Validator validator;
-            if (!validator.Validate(&module)) {
-                errorStream << "Validation: " << validator.error() << std::endl;
-                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-            }
-
             tint::inspector::Inspector inspector(module);
             auto entryPoints = inspector.GetEntryPoints();
             if (inspector.has_error()) {
@@ -862,8 +769,17 @@
 
     }  // anonymous namespace
 
-    MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
-                                              const ShaderModuleDescriptor* descriptor) {
+    ShaderModuleParseResult::ShaderModuleParseResult() = default;
+    ShaderModuleParseResult::~ShaderModuleParseResult() = default;
+
+    ShaderModuleParseResult::ShaderModuleParseResult(ShaderModuleParseResult&& rhs) = default;
+
+    ShaderModuleParseResult& ShaderModuleParseResult::operator=(ShaderModuleParseResult&& rhs) =
+        default;
+
+    ResultOrError<ShaderModuleParseResult> ValidateShaderModuleDescriptor(
+        DeviceBase* device,
+        const ShaderModuleDescriptor* descriptor) {
         const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
         if (chainedDescriptor == nullptr) {
             return DAWN_VALIDATION_ERROR("Shader module descriptor missing chained descriptor");
@@ -874,11 +790,29 @@
                 "Shader module descriptor chained nextInChain must be nullptr");
         }
 
+        ShaderModuleParseResult parseResult = {};
         switch (chainedDescriptor->sType) {
             case wgpu::SType::ShaderModuleSPIRVDescriptor: {
                 const auto* spirvDesc =
                     static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
-                DAWN_TRY(ValidateSpirv(spirvDesc->code, spirvDesc->codeSize));
+                std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+                if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+#ifdef DAWN_ENABLE_WGSL
+                    tint::ast::Module module;
+                    DAWN_TRY_ASSIGN(module, ParseSPIRV(spirv));
+                    if (device->IsValidationEnabled()) {
+                        DAWN_TRY(ValidateModule(&module));
+                    }
+                    parseResult.tintModule = std::make_unique<tint::ast::Module>(std::move(module));
+#else
+                    return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
+#endif  // DAWN_ENABLE_WGSL
+                } else {
+                    if (device->IsValidationEnabled()) {
+                        DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
+                    }
+                    parseResult.spirv = std::move(spirv);
+                }
                 break;
             }
 
@@ -886,17 +820,35 @@
 #ifdef DAWN_ENABLE_WGSL
                 const auto* wgslDesc =
                     static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor);
-                DAWN_TRY(ValidateWGSL(wgslDesc->source));
+
+                if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+                    tint::ast::Module module;
+                    DAWN_TRY_ASSIGN(module, ParseWGSL(wgslDesc->source));
+                    if (device->IsValidationEnabled()) {
+                        DAWN_TRY(ValidateModule(&module));
+                    }
+                    parseResult.tintModule = std::make_unique<tint::ast::Module>(std::move(module));
+                } else {
+                    tint::ast::Module module;
+                    DAWN_TRY_ASSIGN(module, ParseWGSL(wgslDesc->source));
+                    if (device->IsValidationEnabled()) {
+                        DAWN_TRY(ValidateModule(&module));
+                    }
+                    std::vector<uint32_t> spirv;
+                    DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(std::move(module)));
+                    DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
+                    parseResult.spirv = std::move(spirv);
+                }
                 break;
 #else
-                return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
+                return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
 #endif  // DAWN_ENABLE_WGSL
             }
             default:
                 return DAWN_VALIDATION_ERROR("Unsupported sType");
         }
 
-        return {};
+        return std::move(parseResult);
     }
 
     RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
@@ -910,6 +862,23 @@
         return bufferSizes;
     }
 
+#ifdef DAWN_ENABLE_WGSL
+    ResultOrError<tint::ast::Module> RunTransforms(tint::transform::Manager* manager,
+                                                   tint::ast::Module* module) {
+        tint::transform::Transform::Output output = manager->Run(module);
+        if (output.diagnostics.contains_errors()) {
+            std::string err =
+                "Tint transform failure: " + tint::diag::Formatter{}.format(output.diagnostics);
+            return DAWN_VALIDATION_ERROR(err.c_str());
+        }
+
+        if (!output.module.IsValid()) {
+            return DAWN_VALIDATION_ERROR("Tint transform did not produce valid module.");
+        }
+        return std::move(output.module);
+    }
+#endif
+
     MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
                                                        const EntryPointMetadata& entryPoint,
                                                        const PipelineLayoutBase* layout) {
@@ -994,49 +963,136 @@
     }
 
     const std::vector<uint32_t>& ShaderModuleBase::GetSpirv() const {
+        ASSERT(!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
         return mSpirv;
     }
 
 #ifdef DAWN_ENABLE_WGSL
     ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
+        const std::vector<uint32_t>& spirv,
         const VertexStateDescriptor& vertexState,
         const std::string& entryPoint,
         uint32_t pullingBufferBindingSet) const {
-        std::vector<uint32_t> spirv;
-        DAWN_TRY_ASSIGN(spirv, ConvertWGSLToSPIRVWithPulling(mWgsl.c_str(), vertexState, entryPoint,
-                                                             pullingBufferBindingSet));
+        tint::ast::Module module;
+        DAWN_TRY_ASSIGN(module, ParseSPIRV(spirv));
+
+        return GeneratePullingSpirv(&module, vertexState, entryPoint, pullingBufferBindingSet);
+    }
+
+    ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
+        tint::ast::Module* moduleIn,
+        const VertexStateDescriptor& vertexState,
+        const std::string& entryPoint,
+        uint32_t pullingBufferBindingSet) const {
+        std::ostringstream errorStream;
+        errorStream << "Tint vertex pulling failure:" << std::endl;
+
+        tint::transform::Manager transformManager;
+        {
+            auto transform = std::make_unique<tint::transform::VertexPulling>();
+            tint::transform::VertexStateDescriptor state;
+            for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
+                const auto& vertexBuffer = vertexState.vertexBuffers[i];
+                tint::transform::VertexBufferLayoutDescriptor layout;
+                layout.array_stride = vertexBuffer.arrayStride;
+                layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
+
+                for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
+                    const auto& attribute = vertexBuffer.attributes[j];
+                    tint::transform::VertexAttributeDescriptor attr;
+                    attr.format = ToTintVertexFormat(attribute.format);
+                    attr.offset = attribute.offset;
+                    attr.shader_location = attribute.shaderLocation;
+
+                    layout.attributes.push_back(std::move(attr));
+                }
+
+                state.push_back(std::move(layout));
+            }
+            transform->SetVertexState(std::move(state));
+            transform->SetEntryPoint(entryPoint);
+            transform->SetPullingBufferBindingSet(pullingBufferBindingSet);
+            transformManager.append(std::move(transform));
+        }
+        if (GetDevice()->IsRobustnessEnabled()) {
+            // TODO(enga): Run the Tint BoundArrayAccessors transform instead of the SPIRV Tools
+            // one, but it appears to crash after running VertexPulling.
+            // transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
+        }
+
+        tint::ast::Module module;
+        DAWN_TRY_ASSIGN(module, RunTransforms(&transformManager, moduleIn));
+
+        tint::writer::spirv::Generator generator(std::move(module));
+        if (!generator.Generate()) {
+            errorStream << "Generator: " << generator.error() << std::endl;
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        std::vector<uint32_t> spirv = generator.result();
         if (GetDevice()->IsRobustnessEnabled()) {
             DAWN_TRY_ASSIGN(spirv, RunRobustBufferAccessPass(spirv));
         }
-
+        DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
         return std::move(spirv);
     }
 #endif
 
-    MaybeError ShaderModuleBase::InitializeBase() {
-        std::vector<uint32_t> spirv;
-        if (mType == Type::Wgsl) {
+    MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
 #ifdef DAWN_ENABLE_WGSL
-            DAWN_TRY_ASSIGN(spirv, ConvertWGSLToSPIRV(mWgsl.c_str()));
+        tint::ast::Module* module = parseResult->tintModule.get();
+#endif
+        mSpirv = std::move(parseResult->spirv);
+
+        // If not using Tint to generate backend code, run the robust buffer access pass now since
+        // all backends will use this SPIR-V. If Tint is used, the robustness pass should be run
+        // per-backend.
+        if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator) &&
+            GetDevice()->IsRobustnessEnabled()) {
+            DAWN_TRY_ASSIGN(mSpirv, RunRobustBufferAccessPass(mSpirv));
+        }
+
+        // We still need the spirv for reflection. Remove this when we use the Tint inspector
+        // completely.
+        std::vector<uint32_t>* spirvPtr = &mSpirv;
+        std::vector<uint32_t> localSpirv;
+        if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+#ifdef DAWN_ENABLE_WGSL
+            ASSERT(module != nullptr);
+            tint::ast::Module clonedModule = module->Clone();
+            tint::TypeDeterminer typeDeterminer(&clonedModule);
+            if (!typeDeterminer.Determine()) {
+                return DAWN_VALIDATION_ERROR(typeDeterminer.error().c_str());
+            }
+            DAWN_TRY_ASSIGN(localSpirv, ModuleToSPIRV(std::move(clonedModule)));
+            DAWN_TRY(ValidateSpirv(localSpirv.data(), localSpirv.size()));
+            spirvPtr = &localSpirv;
+#else
+            UNREACHABLE();
+#endif
+        }
+
+        if (GetDevice()->IsToggleEnabled(Toggle::UseTintInspector)) {
+#ifdef DAWN_ENABLE_WGSL
+            tint::ast::Module localModule;
+
+            tint::ast::Module* modulePtr = module;
+            if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+                // We have mSpirv, but no Tint module
+                DAWN_TRY_ASSIGN(localModule, ParseSPIRV(mSpirv));
+                DAWN_TRY(ValidateModule(&localModule));
+                modulePtr = &localModule;
+            }
+
+            EntryPointMetadataTable table;
+            DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), *modulePtr));
+            DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), *spirvPtr, &table));
+            mEntryPoints = std::move(table);
 #else
             return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
-#endif  // DAWN_ENABLE_WGSL
+#endif
         } else {
-            spirv = mOriginalSpirv;
-        }
-
-        if (GetDevice()->IsRobustnessEnabled()) {
-            DAWN_TRY_ASSIGN(spirv, RunRobustBufferAccessPass(spirv));
-        }
-
-        mSpirv = std::move(spirv);
-        if (GetDevice()->IsToggleEnabled(Toggle::UseTintInspector)) {
-            EntryPointMetadataTable table;
-            DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), mSpirv));
-            DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), mSpirv, &table));
-            mEntryPoints = std::move(table);
-        } else {
-            DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), mSpirv));
+            DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), *spirvPtr));
         }
 
         return {};
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index e406f49..e05971a 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -31,6 +31,18 @@
 #include <unordered_map>
 #include <vector>
 
+namespace tint {
+
+    namespace ast {
+        class Module;
+    }  // namespace ast
+
+    namespace transform {
+        class Manager;
+    }  // namespace transform
+
+}  // namespace tint
+
 namespace spirv_cross {
     class Compiler;
 }
@@ -43,14 +55,31 @@
     using EntryPointMetadataTable =
         std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>;
 
-    MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
-                                              const ShaderModuleDescriptor* descriptor);
+    struct ShaderModuleParseResult {
+        ShaderModuleParseResult();
+        ~ShaderModuleParseResult();
+        ShaderModuleParseResult(ShaderModuleParseResult&& rhs);
+        ShaderModuleParseResult& operator=(ShaderModuleParseResult&& rhs);
+
+#ifdef DAWN_ENABLE_WGSL
+        std::unique_ptr<tint::ast::Module> tintModule;
+#endif
+        std::vector<uint32_t> spirv;
+    };
+
+    ResultOrError<ShaderModuleParseResult> ValidateShaderModuleDescriptor(
+        DeviceBase* device,
+        const ShaderModuleDescriptor* descriptor);
     MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
                                                        const EntryPointMetadata& entryPoint,
                                                        const PipelineLayoutBase* layout);
 
     RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
                                                             const PipelineLayoutBase* layout);
+#ifdef DAWN_ENABLE_WGSL
+    ResultOrError<tint::ast::Module> RunTransforms(tint::transform::Manager* manager,
+                                                   tint::ast::Module* module);
+#endif
 
     // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
     // stored in the ShaderModuleBase and destroyed only when the shader module is destroyed so
@@ -116,13 +145,20 @@
 
 #ifdef DAWN_ENABLE_WGSL
         ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
+            const std::vector<uint32_t>& spirv,
+            const VertexStateDescriptor& vertexState,
+            const std::string& entryPoint,
+            uint32_t pullingBufferBindingSet) const;
+
+        ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
+            tint::ast::Module* module,
             const VertexStateDescriptor& vertexState,
             const std::string& entryPoint,
             uint32_t pullingBufferBindingSet) const;
 #endif
 
       protected:
-        MaybeError InitializeBase();
+        MaybeError InitializeBase(ShaderModuleParseResult* parseResult);
 
       private:
         ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
diff --git a/src/dawn_native/Toggles.cpp b/src/dawn_native/Toggles.cpp
index 3ca4dad..f48a274 100644
--- a/src/dawn_native/Toggles.cpp
+++ b/src/dawn_native/Toggles.cpp
@@ -143,7 +143,7 @@
               "http://crbug.com/1138528"}},
             {Toggle::UseTintGenerator,
              {"use_tint_generator", "Use Tint instead of SPRIV-cross to generate shaders.",
-              "https://crbug.com/dawn/548"}},
+              "https://crbug.com/dawn/571"}},
             {Toggle::UseTintInspector,
              {"use_tint_inspector", "Use Tint instead of SPRIV-cross for shader reflection.",
               "https://crbug.com/dawn/578"}},
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index 4db2133..d28767c 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -325,8 +325,9 @@
         return new Sampler(this, descriptor);
     }
     ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
-        const ShaderModuleDescriptor* descriptor) {
-        return ShaderModule::Create(this, descriptor);
+        const ShaderModuleDescriptor* descriptor,
+        ShaderModuleParseResult* parseResult) {
+        return ShaderModule::Create(this, descriptor, parseResult);
     }
     ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
         const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index 5d54a3e..dc41448 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -160,7 +160,8 @@
             const RenderPipelineDescriptor* descriptor) override;
         ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
         ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
-            const ShaderModuleDescriptor* descriptor) override;
+            const ShaderModuleDescriptor* descriptor,
+            ShaderModuleParseResult* parseResult) override;
         ResultOrError<SwapChainBase*> CreateSwapChainImpl(
             const SwapChainDescriptor* descriptor) override;
         ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index d415225..8daf854 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -174,9 +174,10 @@
 
     // static
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
-                                                      const ShaderModuleDescriptor* descriptor) {
+                                                      const ShaderModuleDescriptor* descriptor,
+                                                      ShaderModuleParseResult* parseResult) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
-        DAWN_TRY(module->InitializeBase());
+        DAWN_TRY(module->Initialize(parseResult));
         return module.Detach();
     }
 
@@ -184,6 +185,14 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
+    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+        DAWN_TRY(InitializeBase(parseResult));
+#ifdef DAWN_ENABLE_WGSL
+        mTintModule = std::move(parseResult->tintModule);
+#endif
+        return {};
+    }
+
     ResultOrError<std::string> ShaderModule::TranslateToHLSLWithTint(
         const char* entryPointName,
         SingleShaderStage stage,
@@ -195,41 +204,11 @@
         std::ostringstream errorStream;
         errorStream << "Tint HLSL failure:" << std::endl;
 
-        // TODO: Remove redundant SPIRV step between WGSL and HLSL.
-        tint::reader::spirv::Parser parser(GetSpirv());
-
-        if (!parser.Parse()) {
-            errorStream << "Parser: " << parser.error() << std::endl;
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
-
-        tint::ast::Module module = parser.module();
-        if (!module.IsValid()) {
-            errorStream << "Invalid module generated..." << std::endl;
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
-
-        tint::TypeDeterminer typeDeterminer(&module);
-        if (!typeDeterminer.Determine()) {
-            errorStream << "Type Determination: " << typeDeterminer.error();
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
-
-        tint::Validator validator;
-        if (!validator.Validate(&module)) {
-            errorStream << "Validation: " << validator.error() << std::endl;
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
-
         tint::transform::Manager transformManager;
         transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
-        auto result = transformManager.Run(&module);
-        if (result.diagnostics.contains_errors()) {
-            errorStream << "Bound Array Accessors Transform: "
-                        << tint::diag::Formatter{}.format(result.diagnostics);
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
-        module = std::move(result.module);
+
+        tint::ast::Module module;
+        DAWN_TRY_ASSIGN(module, RunTransforms(&transformManager, mTintModule.get()));
 
         ASSERT(remappedEntryPointName != nullptr);
         tint::inspector::Inspector inspector(module);
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 63cb2e9..f5a94f0 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -36,7 +36,8 @@
     class ShaderModule final : public ShaderModuleBase {
       public:
         static ResultOrError<ShaderModule*> Create(Device* device,
-                                                   const ShaderModuleDescriptor* descriptor);
+                                                   const ShaderModuleDescriptor* descriptor,
+                                                   ShaderModuleParseResult* parseResult);
 
         ResultOrError<CompiledShader> Compile(const char* entryPointName,
                                               SingleShaderStage stage,
@@ -46,6 +47,7 @@
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
+        MaybeError Initialize(ShaderModuleParseResult* parseResult);
 
         ResultOrError<std::string> TranslateToHLSLWithTint(
             const char* entryPointName,
@@ -61,6 +63,10 @@
                                          SingleShaderStage stage,
                                          const std::string& hlslSource,
                                          uint32_t compileFlags) const;
+
+#ifdef DAWN_ENABLE_WGSL
+        std::unique_ptr<tint::ast::Module> mTintModule;
+#endif
     };
 
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index c582758..53da499 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -92,7 +92,8 @@
             const RenderPipelineDescriptor* descriptor) override;
         ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
         ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
-            const ShaderModuleDescriptor* descriptor) override;
+            const ShaderModuleDescriptor* descriptor,
+            ShaderModuleParseResult* parseResult) override;
         ResultOrError<SwapChainBase*> CreateSwapChainImpl(
             const SwapChainDescriptor* descriptor) override;
         ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 3846d0d..37df56a 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -151,8 +151,9 @@
         return Sampler::Create(this, descriptor);
     }
     ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
-        const ShaderModuleDescriptor* descriptor) {
-        return ShaderModule::Create(this, descriptor);
+        const ShaderModuleDescriptor* descriptor,
+        ShaderModuleParseResult* parseResult) {
+        return ShaderModule::Create(this, descriptor, parseResult);
     }
     ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
         const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 6ecf57d..3fb2dc2 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -35,7 +35,8 @@
     class ShaderModule final : public ShaderModuleBase {
       public:
         static ResultOrError<ShaderModule*> Create(Device* device,
-                                                   const ShaderModuleDescriptor* descriptor);
+                                                   const ShaderModuleDescriptor* descriptor,
+                                                   ShaderModuleParseResult* parseResult);
 
         struct MetalFunctionData {
             NSPRef<id<MTLFunction>> function;
@@ -51,7 +52,11 @@
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
-        MaybeError Initialize();
+        MaybeError Initialize(ShaderModuleParseResult* parseResult);
+
+#ifdef DAWN_ENABLE_WGSL
+        std::unique_ptr<tint::ast::Module> mTintModule;
+#endif
     };
 
 }}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 5af363b..6d52770 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -22,15 +22,24 @@
 
 #include <spirv_msl.hpp>
 
+#ifdef DAWN_ENABLE_WGSL
+// Tint include must be after spirv_msl.hpp, because spirv-cross has its own
+// version of spirv_headers. We also need to undef SPV_REVISION because SPIRV-Cross
+// is at 3 while spirv-headers is at 4.
+#    undef SPV_REVISION
+#    include <tint/tint.h>
+#endif  // DAWN_ENABLE_WGSL
+
 #include <sstream>
 
 namespace dawn_native { namespace metal {
 
     // static
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
-                                                      const ShaderModuleDescriptor* descriptor) {
+                                                      const ShaderModuleDescriptor* descriptor,
+                                                      ShaderModuleParseResult* parseResult) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
-        DAWN_TRY(module->Initialize());
+        DAWN_TRY(module->Initialize(parseResult));
         return module.Detach();
     }
 
@@ -38,8 +47,12 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
-    MaybeError ShaderModule::Initialize() {
-        return InitializeBase();
+    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+        DAWN_TRY(InitializeBase(parseResult));
+#ifdef DAWN_ENABLE_WGSL
+        mTintModule = std::move(parseResult->tintModule);
+#endif
+        return {};
     }
 
     MaybeError ShaderModule::CreateFunction(const char* entryPointName,
@@ -59,9 +72,17 @@
         std::vector<uint32_t> pullingSpirv;
         if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
             stage == SingleShaderStage::Vertex) {
-            DAWN_TRY_ASSIGN(pullingSpirv,
-                            GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
-                                                 entryPointName, kPullingBufferBindingSet));
+            if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+                DAWN_TRY_ASSIGN(pullingSpirv,
+                                GeneratePullingSpirv(mTintModule.get(),
+                                                     *renderPipeline->GetVertexStateDescriptor(),
+                                                     entryPointName, kPullingBufferBindingSet));
+            } else {
+                DAWN_TRY_ASSIGN(
+                    pullingSpirv,
+                    GeneratePullingSpirv(GetSpirv(), *renderPipeline->GetVertexStateDescriptor(),
+                                         entryPointName, kPullingBufferBindingSet));
+            }
             spirv = &pullingSpirv;
         }
 #endif
diff --git a/src/dawn_native/null/DeviceNull.cpp b/src/dawn_native/null/DeviceNull.cpp
index fe981d2..a821acf 100644
--- a/src/dawn_native/null/DeviceNull.cpp
+++ b/src/dawn_native/null/DeviceNull.cpp
@@ -127,9 +127,10 @@
         return new Sampler(this, descriptor);
     }
     ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
-        const ShaderModuleDescriptor* descriptor) {
+        const ShaderModuleDescriptor* descriptor,
+        ShaderModuleParseResult* parseResult) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
-        DAWN_TRY(module->Initialize());
+        DAWN_TRY(module->Initialize(parseResult));
         return module.Detach();
     }
     ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
@@ -395,8 +396,8 @@
 
     // ShaderModule
 
-    MaybeError ShaderModule::Initialize() {
-        return InitializeBase();
+    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+        return InitializeBase(parseResult);
     }
 
     // OldSwapChain
diff --git a/src/dawn_native/null/DeviceNull.h b/src/dawn_native/null/DeviceNull.h
index 2ffadbe..9a73a9b 100644
--- a/src/dawn_native/null/DeviceNull.h
+++ b/src/dawn_native/null/DeviceNull.h
@@ -135,7 +135,8 @@
             const RenderPipelineDescriptor* descriptor) override;
         ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
         ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
-            const ShaderModuleDescriptor* descriptor) override;
+            const ShaderModuleDescriptor* descriptor,
+            ShaderModuleParseResult* parseResult) override;
         ResultOrError<SwapChainBase*> CreateSwapChainImpl(
             const SwapChainDescriptor* descriptor) override;
         ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
@@ -246,7 +247,7 @@
       public:
         using ShaderModuleBase::ShaderModuleBase;
 
-        MaybeError Initialize();
+        MaybeError Initialize(ShaderModuleParseResult* parseResult);
     };
 
     class SwapChain final : public NewSwapChainBase {
diff --git a/src/dawn_native/opengl/DeviceGL.cpp b/src/dawn_native/opengl/DeviceGL.cpp
index 4ddb91c..9099ae3 100644
--- a/src/dawn_native/opengl/DeviceGL.cpp
+++ b/src/dawn_native/opengl/DeviceGL.cpp
@@ -128,8 +128,9 @@
         return new Sampler(this, descriptor);
     }
     ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
-        const ShaderModuleDescriptor* descriptor) {
-        return ShaderModule::Create(this, descriptor);
+        const ShaderModuleDescriptor* descriptor,
+        ShaderModuleParseResult* parseResult) {
+        return ShaderModule::Create(this, descriptor, parseResult);
     }
     ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
         const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/opengl/DeviceGL.h b/src/dawn_native/opengl/DeviceGL.h
index d19ffe2..1d1f087 100644
--- a/src/dawn_native/opengl/DeviceGL.h
+++ b/src/dawn_native/opengl/DeviceGL.h
@@ -91,7 +91,8 @@
             const RenderPipelineDescriptor* descriptor) override;
         ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
         ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
-            const ShaderModuleDescriptor* descriptor) override;
+            const ShaderModuleDescriptor* descriptor,
+            ShaderModuleParseResult* parseResult) override;
         ResultOrError<SwapChainBase*> CreateSwapChainImpl(
             const SwapChainDescriptor* descriptor) override;
         ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp
index 792a2fc..a3b7dda 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn_native/opengl/ShaderModuleGL.cpp
@@ -59,9 +59,10 @@
 
     // static
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
-                                                      const ShaderModuleDescriptor* descriptor) {
+                                                      const ShaderModuleDescriptor* descriptor,
+                                                      ShaderModuleParseResult* parseResult) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
-        DAWN_TRY(module->InitializeBase());
+        DAWN_TRY(module->InitializeBase(parseResult));
         return module.Detach();
     }
 
diff --git a/src/dawn_native/opengl/ShaderModuleGL.h b/src/dawn_native/opengl/ShaderModuleGL.h
index c18fa48..a003e19 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.h
+++ b/src/dawn_native/opengl/ShaderModuleGL.h
@@ -47,7 +47,8 @@
     class ShaderModule final : public ShaderModuleBase {
       public:
         static ResultOrError<ShaderModule*> Create(Device* device,
-                                                   const ShaderModuleDescriptor* descriptor);
+                                                   const ShaderModuleDescriptor* descriptor,
+                                                   ShaderModuleParseResult* parseResult);
 
         std::string TranslateToGLSL(const char* entryPointName,
                                     SingleShaderStage stage,
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index 71f42a8..2fda4ed 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -136,8 +136,9 @@
         return Sampler::Create(this, descriptor);
     }
     ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
-        const ShaderModuleDescriptor* descriptor) {
-        return ShaderModule::Create(this, descriptor);
+        const ShaderModuleDescriptor* descriptor,
+        ShaderModuleParseResult* parseResult) {
+        return ShaderModule::Create(this, descriptor, parseResult);
     }
     ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
         const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index 6380481..40b11f7 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -127,7 +127,8 @@
             const RenderPipelineDescriptor* descriptor) override;
         ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
         ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
-            const ShaderModuleDescriptor* descriptor) override;
+            const ShaderModuleDescriptor* descriptor,
+            ShaderModuleParseResult* parseResult) override;
         ResultOrError<SwapChainBase*> CreateSwapChainImpl(
             const SwapChainDescriptor* descriptor) override;
         ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index fb107af..f606252 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -24,12 +24,13 @@
 
     // static
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
-                                                      const ShaderModuleDescriptor* descriptor) {
+                                                      const ShaderModuleDescriptor* descriptor,
+                                                      ShaderModuleParseResult* parseResult) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
         if (module == nullptr) {
             return DAWN_VALIDATION_ERROR("Unable to create ShaderModule");
         }
-        DAWN_TRY(module->Initialize());
+        DAWN_TRY(module->Initialize(parseResult));
         return module.Detach();
     }
 
@@ -37,8 +38,8 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
-    MaybeError ShaderModule::Initialize() {
-        DAWN_TRY(InitializeBase());
+    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+        DAWN_TRY(InitializeBase(parseResult));
         const std::vector<uint32_t>& spirv = GetSpirv();
 
         VkShaderModuleCreateInfo createInfo;
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 720cc5e..621ab0e 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -27,14 +27,15 @@
     class ShaderModule final : public ShaderModuleBase {
       public:
         static ResultOrError<ShaderModule*> Create(Device* device,
-                                                   const ShaderModuleDescriptor* descriptor);
+                                                   const ShaderModuleDescriptor* descriptor,
+                                                   ShaderModuleParseResult* parseResult);
 
         VkShaderModule GetHandle() const;
 
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override;
-        MaybeError Initialize();
+        MaybeError Initialize(ShaderModuleParseResult* parseResult);
 
         VkShaderModule mHandle = VK_NULL_HANDLE;
     };