Make all backend::ShaderModule get SPIRV from the frontend

This will make it easier to support SPIRV as a chained sub-descriptor of
ShaderModuleDescriptor in follow-up CLs.

Also fix a couple style and formatting issues.

Bug: dawn:22
Change-Id: Iddaf1f87edee65687e17670b70024835918a0382
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/19864
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 41f84d9..713dd95 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -324,7 +324,7 @@
     // ShaderModuleBase
 
     ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor)
-        : CachedObject(device), mCode(descriptor->code, descriptor->code + descriptor->codeSize) {
+        : CachedObject(device), mSpirv(descriptor->code, descriptor->code + descriptor->codeSize) {
         mFragmentOutputFormatBaseTypes.fill(Format::Other);
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvcParser)) {
             mSpvcContext.SetUseSpvcParser(true);
@@ -836,7 +836,7 @@
     size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
         size_t hash = 0;
 
-        for (uint32_t word : module->mCode) {
+        for (uint32_t word : module->mSpirv) {
             HashCombine(&hash, word);
         }
 
@@ -845,7 +845,7 @@
 
     bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
                                                     const ShaderModuleBase* b) const {
-        return a->mCode == b->mCode;
+        return a->mSpirv == b->mSpirv;
     }
 
     MaybeError ShaderModuleBase::CheckSpvcSuccess(shaderc_spvc_status status,
@@ -856,7 +856,15 @@
         return {};
     }
 
-    shaderc_spvc::CompileOptions ShaderModuleBase::GetCompileOptions() {
+    shaderc_spvc::Context* ShaderModuleBase::GetContext() {
+        return &mSpvcContext;
+    }
+
+    const std::vector<uint32_t>& ShaderModuleBase::GetSpirv() const {
+        return mSpirv;
+    }
+
+    shaderc_spvc::CompileOptions ShaderModuleBase::GetCompileOptions() const {
         shaderc_spvc::CompileOptions options;
         options.SetValidate(GetDevice()->IsValidationEnabled());
         return options;
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 0653bbf..b3fd839 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -83,13 +83,12 @@
             bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
         };
 
-        shaderc_spvc::Context* GetContext() {
-            return &mSpvcContext;
-        }
+        shaderc_spvc::Context* GetContext();
+        const std::vector<uint32_t>& GetSpirv() const;
 
       protected:
         static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
-        shaderc_spvc::CompileOptions GetCompileOptions();
+        shaderc_spvc::CompileOptions GetCompileOptions() const;
 
         shaderc_spvc::Context mSpvcContext;
 
@@ -103,9 +102,7 @@
         MaybeError ExtractSpirvInfoWithSpvc();
         MaybeError ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler);
 
-        // TODO(cwallez@chromium.org): The code is only stored for deduplication. We could maybe
-        // store a cryptographic hash of the code instead?
-        std::vector<uint32_t> mCode;
+        std::vector<uint32_t> mSpirv;
 
         ModuleBindingInfo mBindingInfo;
         std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index d9289ca..e006d6f 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -28,7 +28,7 @@
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
                                                       const ShaderModuleDescriptor* descriptor) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
-        DAWN_TRY(module->Initialize(descriptor));
+        DAWN_TRY(module->Initialize());
         return module.Detach();
     }
 
@@ -36,8 +36,9 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
-    MaybeError ShaderModule::Initialize(const ShaderModuleDescriptor* descriptor) {
-        mSpirv.assign(descriptor->code, descriptor->code + descriptor->codeSize);
+    MaybeError ShaderModule::Initialize() {
+        const std::vector<uint32_t>& spirv = GetSpirv();
+
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
             shaderc_spvc::CompileOptions options = GetCompileOptions();
 
@@ -52,7 +53,7 @@
             options.SetHLSLPointSizeCompat(true);
 
             DAWN_TRY(CheckSpvcSuccess(
-                mSpvcContext.InitializeForHlsl(descriptor->code, descriptor->codeSize, options),
+                mSpvcContext.InitializeForHlsl(spirv.data(), spirv.size(), options),
                 "Unable to initialize instance of spvc"));
 
             spirv_cross::Compiler* compiler;
@@ -60,14 +61,17 @@
                                       "Unable to get cross compiler"));
             DAWN_TRY(ExtractSpirvInfo(*compiler));
         } else {
-            spirv_cross::CompilerHLSL compiler(descriptor->code, descriptor->codeSize);
+            spirv_cross::CompilerHLSL compiler(spirv);
             DAWN_TRY(ExtractSpirvInfo(compiler));
         }
         return {};
     }
 
     ResultOrError<std::string> ShaderModule::GetHLSLSource(PipelineLayout* layout) {
-        std::unique_ptr<spirv_cross::CompilerHLSL> compiler_impl;
+        ASSERT(!IsError());
+        const std::vector<uint32_t>& spirv = GetSpirv();
+
+        std::unique_ptr<spirv_cross::CompilerHLSL> compilerImpl;
         spirv_cross::CompilerHLSL* compiler = nullptr;
         if (!GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
             // If these options are changed, the values in DawnSPIRVCrossHLSLFastFuzzer.cpp need to
@@ -87,8 +91,8 @@
             options_hlsl.point_coord_compat = true;
             options_hlsl.point_size_compat = true;
 
-            compiler_impl = std::make_unique<spirv_cross::CompilerHLSL>(mSpirv);
-            compiler = compiler_impl.get();
+            compilerImpl = std::make_unique<spirv_cross::CompilerHLSL>(spirv);
+            compiler = compilerImpl.get();
             compiler->set_common_options(options_glsl);
             compiler->set_hlsl_options(options_hlsl);
         }
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 289c2db..e34d881 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -32,9 +32,7 @@
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
-        MaybeError Initialize(const ShaderModuleDescriptor* descriptor);
-
-        std::vector<uint32_t> mSpirv;
+        MaybeError Initialize();
     };
 
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 53e9f7d..d4d41ab 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -51,14 +51,9 @@
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
-        MaybeError Initialize(const ShaderModuleDescriptor* descriptor);
+        MaybeError Initialize();
 
         shaderc_spvc::CompileOptions GetMSLCompileOptions();
-
-        // Calling compile on CompilerMSL somehow changes internal state that makes subsequent
-        // compiles return invalid MSL. We keep the spirv around and recreate the compiler everytime
-        // we need to use it.
-        std::vector<uint32_t> mSpirv;
     };
 
 }}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index a06adcc..d0a8bfc 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -58,7 +58,7 @@
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
                                                       const ShaderModuleDescriptor* descriptor) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
-        DAWN_TRY(module->Initialize(descriptor));
+        DAWN_TRY(module->Initialize());
         return module.Detach();
     }
 
@@ -66,21 +66,22 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
-    MaybeError ShaderModule::Initialize(const ShaderModuleDescriptor* descriptor) {
-        mSpirv.assign(descriptor->code, descriptor->code + descriptor->codeSize);
+    MaybeError ShaderModule::Initialize() {
+        const std::vector<uint32_t>& spirv = GetSpirv();
+
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
             shaderc_spvc::CompileOptions options = GetMSLCompileOptions();
 
-            DAWN_TRY(CheckSpvcSuccess(
-                mSpvcContext.InitializeForMsl(descriptor->code, descriptor->codeSize, options),
-                "Unable to initialize instance of spvc"));
+            DAWN_TRY(
+                CheckSpvcSuccess(mSpvcContext.InitializeForMsl(spirv.data(), spirv.size(), options),
+                                 "Unable to initialize instance of spvc"));
 
             spirv_cross::CompilerMSL* compiler;
             DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetCompiler(reinterpret_cast<void**>(&compiler)),
                                       "Unable to get cross compiler"));
             DAWN_TRY(ExtractSpirvInfo(*compiler));
         } else {
-            spirv_cross::CompilerMSL compiler(mSpirv);
+            spirv_cross::CompilerMSL compiler(spirv);
             DAWN_TRY(ExtractSpirvInfo(compiler));
         }
         return {};
@@ -92,13 +93,15 @@
                                          ShaderModule::MetalFunctionData* out) {
         ASSERT(!IsError());
         ASSERT(out);
-        std::unique_ptr<spirv_cross::CompilerMSL> compiler_impl;
+        const std::vector<uint32_t>& spirv = GetSpirv();
+
+        std::unique_ptr<spirv_cross::CompilerMSL> compilerImpl;
         spirv_cross::CompilerMSL* compiler;
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
             // Initializing the compiler is needed every call, because this method uses reflection
             // to mutate the compiler's IR.
             DAWN_TRY(CheckSpvcSuccess(
-                mSpvcContext.InitializeForMsl(mSpirv.data(), mSpirv.size(), GetMSLCompileOptions()),
+                mSpvcContext.InitializeForMsl(spirv.data(), spirv.size(), GetMSLCompileOptions()),
                 "Unable to initialize instance of spvc"));
             DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetCompiler(reinterpret_cast<void**>(&compiler)),
                                       "Unable to get cross compiler"));
@@ -118,8 +121,8 @@
             // the shader storage buffer lengths.
             options_msl.buffer_size_buffer_index = kBufferLengthBufferSlot;
 
-            compiler_impl = std::make_unique<spirv_cross::CompilerMSL>(mSpirv);
-            compiler = compiler_impl.get();
+            compilerImpl = std::make_unique<spirv_cross::CompilerMSL>(spirv);
+            compiler = compilerImpl.get();
             compiler->set_msl_options(options_msl);
         }
 
diff --git a/src/dawn_native/null/DeviceNull.cpp b/src/dawn_native/null/DeviceNull.cpp
index 79da48c..701e720 100644
--- a/src/dawn_native/null/DeviceNull.cpp
+++ b/src/dawn_native/null/DeviceNull.cpp
@@ -125,14 +125,14 @@
     }
     ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
         const ShaderModuleDescriptor* descriptor) {
-        auto module = new ShaderModule(this, descriptor);
+        Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
 
         if (IsToggleEnabled(Toggle::UseSpvc)) {
             shaderc_spvc::CompileOptions options;
             options.SetValidate(IsValidationEnabled());
             shaderc_spvc::Context* context = module->GetContext();
-            shaderc_spvc_status status =
-                context->InitializeForGlsl(descriptor->code, descriptor->codeSize, options);
+            shaderc_spvc_status status = context->InitializeForGlsl(
+                module->GetSpirv().data(), module->GetSpirv().size(), options);
             if (status != shaderc_spvc_status_success) {
                 return DAWN_VALIDATION_ERROR("Unable to initialize instance of spvc");
             }
@@ -144,10 +144,10 @@
             }
             DAWN_TRY(module->ExtractSpirvInfo(*compiler));
         } else {
-            spirv_cross::Compiler compiler(descriptor->code, descriptor->codeSize);
+            spirv_cross::Compiler compiler(module->GetSpirv());
             DAWN_TRY(module->ExtractSpirvInfo(compiler));
         }
-        return module;
+        return module.Detach();
     }
     ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
         const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp
index 86398c9..8979716 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn_native/opengl/ShaderModuleGL.cpp
@@ -51,7 +51,7 @@
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
                                                       const ShaderModuleDescriptor* descriptor) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
-        DAWN_TRY(module->Initialize(descriptor));
+        DAWN_TRY(module->Initialize());
         return module.Detach();
     }
 
@@ -67,10 +67,11 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
-    MaybeError ShaderModule::Initialize(const ShaderModuleDescriptor* descriptor) {
-        std::unique_ptr<spirv_cross::CompilerGLSL> compiler_impl;
-        spirv_cross::CompilerGLSL* compiler;
+    MaybeError ShaderModule::Initialize() {
+        const std::vector<uint32_t>& spirv = GetSpirv();
 
+        std::unique_ptr<spirv_cross::CompilerGLSL> compilerImpl;
+        spirv_cross::CompilerGLSL* compiler;
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
             // If these options are changed, the values in DawnSPIRVCrossGLSLFastFuzzer.cpp need to
             // be updated.
@@ -90,7 +91,7 @@
             options.SetGLSLLanguageVersion(440);
 #endif
             DAWN_TRY(CheckSpvcSuccess(
-                mSpvcContext.InitializeForGlsl(descriptor->code, descriptor->codeSize, options),
+                mSpvcContext.InitializeForGlsl(spirv.data(), spirv.size(), options),
                 "Unable to initialize instance of spvc"));
             DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetCompiler(reinterpret_cast<void**>(&compiler)),
                                       "Unable to get cross compiler"));
@@ -108,15 +109,14 @@
 
             // TODO(cwallez@chromium.org): discover the backing context version and use that.
 #if defined(DAWN_PLATFORM_APPLE)
-        options.version = 410;
+            options.version = 410;
 #else
-        options.version = 440;
+            options.version = 440;
 #endif
 
-        compiler_impl =
-            std::make_unique<spirv_cross::CompilerGLSL>(descriptor->code, descriptor->codeSize);
-        compiler = compiler_impl.get();
-        compiler->set_common_options(options);
+            compilerImpl = std::make_unique<spirv_cross::CompilerGLSL>(spirv);
+            compiler = compilerImpl.get();
+            compiler->set_common_options(options);
         }
 
         DAWN_TRY(ExtractSpirvInfo(*compiler));
diff --git a/src/dawn_native/opengl/ShaderModuleGL.h b/src/dawn_native/opengl/ShaderModuleGL.h
index 849f2da..9e2b5c9 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.h
+++ b/src/dawn_native/opengl/ShaderModuleGL.h
@@ -51,7 +51,7 @@
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
-        MaybeError Initialize(const ShaderModuleDescriptor* descriptor);
+        MaybeError Initialize();
 
         CombinedSamplerInfo mCombinedInfo;
         std::string mGlslSource;
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index 8cc8265..36f9db8 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -28,7 +28,7 @@
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
         if (!module)
             return DAWN_VALIDATION_ERROR("Unable to create ShaderModule");
-        DAWN_TRY(module->Initialize(descriptor));
+        DAWN_TRY(module->Initialize());
         return module.Detach();
     }
 
@@ -36,14 +36,16 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
-    MaybeError ShaderModule::Initialize(const ShaderModuleDescriptor* descriptor) {
+    MaybeError ShaderModule::Initialize() {
+        const std::vector<uint32_t>& spirv = GetSpirv();
+
         // Use SPIRV-Cross to extract info from the SPIRV even if Vulkan consumes SPIRV. We want to
         // have a translation step eventually anyway.
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
             shaderc_spvc::CompileOptions options = GetCompileOptions();
 
             DAWN_TRY(CheckSpvcSuccess(
-                mSpvcContext.InitializeForVulkan(descriptor->code, descriptor->codeSize, options),
+                mSpvcContext.InitializeForVulkan(spirv.data(), spirv.size(), options),
                 "Unable to initialize instance of spvc"));
 
             spirv_cross::Compiler* compiler;
@@ -51,7 +53,7 @@
                                       "Unable to get cross compiler"));
             DAWN_TRY(ExtractSpirvInfo(*compiler));
         } else {
-            spirv_cross::Compiler compiler(descriptor->code, descriptor->codeSize);
+            spirv_cross::Compiler compiler(spirv);
             DAWN_TRY(ExtractSpirvInfo(compiler));
         }
 
@@ -69,8 +71,8 @@
             createInfo.codeSize = vulkanSource.size() * sizeof(uint32_t);
             createInfo.pCode = vulkanSource.data();
         } else {
-            createInfo.codeSize = descriptor->codeSize * sizeof(uint32_t);
-            createInfo.pCode = descriptor->code;
+            createInfo.codeSize = spirv.size() * sizeof(uint32_t);
+            createInfo.pCode = spirv.data();
         }
 
         Device* device = ToBackend(GetDevice());
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 962ccf8..720cc5e 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -34,7 +34,7 @@
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override;
-        MaybeError Initialize(const ShaderModuleDescriptor* descriptor);
+        MaybeError Initialize();
 
         VkShaderModule mHandle = VK_NULL_HANDLE;
     };