Metal: Use ShaderModule reflection when possible.

This change the Metal backend in preparation for supporting multiple
entrypoints:

 - Explicitly set the spirv_cross entry point before compiling.
 - Moves gathering of the local size to the frontend as it will be
   useful for validation in the future.
 - Query spirv-cross for the modified entrypoint name instead of
   duplicating the code in Dawn.
 - Move some conversion helpers from ShaderModule.cpp to their own
   SpirvUtils file.

Bug: dawn:216
Change-Id: I87d4953428e0bfeb97e39ed22f94d86ae7987782
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/28241
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn
index 6b70082..4558f8c 100644
--- a/src/dawn_native/BUILD.gn
+++ b/src/dawn_native/BUILD.gn
@@ -253,6 +253,8 @@
     "Sampler.h",
     "ShaderModule.cpp",
     "ShaderModule.h",
+    "SpirvUtils.cpp",
+    "SpirvUtils.h",
     "StagingBuffer.cpp",
     "StagingBuffer.h",
     "Surface.cpp",
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt
index b7db561..ac435bb 100644
--- a/src/dawn_native/CMakeLists.txt
+++ b/src/dawn_native/CMakeLists.txt
@@ -131,6 +131,8 @@
     "Sampler.h"
     "ShaderModule.cpp"
     "ShaderModule.h"
+    "SpirvUtils.cpp"
+    "SpirvUtils.h"
     "StagingBuffer.cpp"
     "StagingBuffer.h"
     "Surface.cpp"
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 2bdadb9..3d68909 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -19,6 +19,7 @@
 #include "dawn_native/Device.h"
 #include "dawn_native/Pipeline.h"
 #include "dawn_native/PipelineLayout.h"
+#include "dawn_native/SpirvUtils.h"
 
 #include <spirv-tools/libspirv.hpp>
 #include <spirv_cross.hpp>
@@ -36,114 +37,6 @@
 namespace dawn_native {
 
     namespace {
-        Format::Type SpirvCrossBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) {
-            switch (spirvBaseType) {
-                case spirv_cross::SPIRType::Float:
-                    return Format::Type::Float;
-                case spirv_cross::SPIRType::Int:
-                    return Format::Type::Sint;
-                case spirv_cross::SPIRType::UInt:
-                    return Format::Type::Uint;
-                default:
-                    UNREACHABLE();
-                    return Format::Type::Other;
-            }
-        }
-
-        wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) {
-            switch (dim) {
-                case spv::Dim::Dim1D:
-                    return wgpu::TextureViewDimension::e1D;
-                case spv::Dim::Dim2D:
-                    if (arrayed) {
-                        return wgpu::TextureViewDimension::e2DArray;
-                    } else {
-                        return wgpu::TextureViewDimension::e2D;
-                    }
-                case spv::Dim::Dim3D:
-                    return wgpu::TextureViewDimension::e3D;
-                case spv::Dim::DimCube:
-                    if (arrayed) {
-                        return wgpu::TextureViewDimension::CubeArray;
-                    } else {
-                        return wgpu::TextureViewDimension::Cube;
-                    }
-                default:
-                    UNREACHABLE();
-                    return wgpu::TextureViewDimension::Undefined;
-            }
-        }
-
-        wgpu::TextureFormat ToWGPUTextureFormat(spv::ImageFormat format) {
-            switch (format) {
-                case spv::ImageFormatR8:
-                    return wgpu::TextureFormat::R8Unorm;
-                case spv::ImageFormatR8Snorm:
-                    return wgpu::TextureFormat::R8Snorm;
-                case spv::ImageFormatR8ui:
-                    return wgpu::TextureFormat::R8Uint;
-                case spv::ImageFormatR8i:
-                    return wgpu::TextureFormat::R8Sint;
-                case spv::ImageFormatR16ui:
-                    return wgpu::TextureFormat::R16Uint;
-                case spv::ImageFormatR16i:
-                    return wgpu::TextureFormat::R16Sint;
-                case spv::ImageFormatR16f:
-                    return wgpu::TextureFormat::R16Float;
-                case spv::ImageFormatRg8:
-                    return wgpu::TextureFormat::RG8Unorm;
-                case spv::ImageFormatRg8Snorm:
-                    return wgpu::TextureFormat::RG8Snorm;
-                case spv::ImageFormatRg8ui:
-                    return wgpu::TextureFormat::RG8Uint;
-                case spv::ImageFormatRg8i:
-                    return wgpu::TextureFormat::RG8Sint;
-                case spv::ImageFormatR32f:
-                    return wgpu::TextureFormat::R32Float;
-                case spv::ImageFormatR32ui:
-                    return wgpu::TextureFormat::R32Uint;
-                case spv::ImageFormatR32i:
-                    return wgpu::TextureFormat::R32Sint;
-                case spv::ImageFormatRg16ui:
-                    return wgpu::TextureFormat::RG16Uint;
-                case spv::ImageFormatRg16i:
-                    return wgpu::TextureFormat::RG16Sint;
-                case spv::ImageFormatRg16f:
-                    return wgpu::TextureFormat::RG16Float;
-                case spv::ImageFormatRgba8:
-                    return wgpu::TextureFormat::RGBA8Unorm;
-                case spv::ImageFormatRgba8Snorm:
-                    return wgpu::TextureFormat::RGBA8Snorm;
-                case spv::ImageFormatRgba8ui:
-                    return wgpu::TextureFormat::RGBA8Uint;
-                case spv::ImageFormatRgba8i:
-                    return wgpu::TextureFormat::RGBA8Sint;
-                case spv::ImageFormatRgb10A2:
-                    return wgpu::TextureFormat::RGB10A2Unorm;
-                case spv::ImageFormatR11fG11fB10f:
-                    return wgpu::TextureFormat::RG11B10Ufloat;
-                case spv::ImageFormatRg32f:
-                    return wgpu::TextureFormat::RG32Float;
-                case spv::ImageFormatRg32ui:
-                    return wgpu::TextureFormat::RG32Uint;
-                case spv::ImageFormatRg32i:
-                    return wgpu::TextureFormat::RG32Sint;
-                case spv::ImageFormatRgba16ui:
-                    return wgpu::TextureFormat::RGBA16Uint;
-                case spv::ImageFormatRgba16i:
-                    return wgpu::TextureFormat::RGBA16Sint;
-                case spv::ImageFormatRgba16f:
-                    return wgpu::TextureFormat::RGBA16Float;
-                case spv::ImageFormatRgba32f:
-                    return wgpu::TextureFormat::RGBA32Float;
-                case spv::ImageFormatRgba32ui:
-                    return wgpu::TextureFormat::RGBA32Uint;
-                case spv::ImageFormatRgba32i:
-                    return wgpu::TextureFormat::RGBA32Sint;
-                default:
-                    return wgpu::TextureFormat::Undefined;
-            }
-        }
 
         std::string GetShaderDeclarationString(BindGroupIndex group, BindingNumber binding) {
             std::ostringstream ostream;
@@ -550,27 +443,15 @@
 
         ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfo(
             const DeviceBase* device,
-            const spirv_cross::Compiler& compiler) {
+            const spirv_cross::Compiler& compiler,
+            const char* entryPointName) {
             std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
 
             // TODO(cwallez@chromium.org): make errors here creation errors
             // currently errors here do not prevent the shadermodule from being used
             const auto& resources = compiler.get_shader_resources();
 
-            switch (compiler.get_execution_model()) {
-                case spv::ExecutionModelVertex:
-                    metadata->stage = SingleShaderStage::Vertex;
-                    break;
-                case spv::ExecutionModelFragment:
-                    metadata->stage = SingleShaderStage::Fragment;
-                    break;
-                case spv::ExecutionModelGLCompute:
-                    metadata->stage = SingleShaderStage::Compute;
-                    break;
-                default:
-                    UNREACHABLE();
-                    return DAWN_VALIDATION_ERROR("Unexpected shader execution model");
-            }
+            metadata->stage = ExecutionModelToShaderStage(compiler.get_execution_model());
 
             if (resources.push_constant_buffers.size() > 0) {
                 return DAWN_VALIDATION_ERROR("Push constants aren't supported.");
@@ -635,7 +516,7 @@
                             info->viewDimension =
                                 SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed);
                             info->textureComponentType =
-                                SpirvCrossBaseTypeToFormatType(textureComponentType);
+                                SpirvBaseTypeToFormatType(textureComponentType);
                             info->type = bindingType;
                             break;
                         }
@@ -664,7 +545,7 @@
                             spirv_cross::SPIRType::ImageType imageType =
                                 compiler.get_type(info->base_type_id).image;
                             wgpu::TextureFormat storageTextureFormat =
-                                ToWGPUTextureFormat(imageType.format);
+                                SpirvImageFormatToTextureFormat(imageType.format);
                             if (storageTextureFormat == wgpu::TextureFormat::Undefined) {
                                 return DAWN_VALIDATION_ERROR(
                                     "Invalid image format declaration on storage image");
@@ -756,7 +637,7 @@
                     spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType =
                         compiler.get_type(fragmentOutput.base_type_id).basetype;
                     Format::Type formatType =
-                        SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType);
+                        SpirvBaseTypeToFormatType(shaderFragmentOutputBaseType);
                     if (formatType == Format::Type::Other) {
                         return DAWN_VALIDATION_ERROR("Unexpected Fragment output type");
                     }
@@ -764,6 +645,14 @@
                 }
             }
 
+            if (metadata->stage == SingleShaderStage::Compute) {
+                const spirv_cross::SPIREntryPoint& spirEntryPoint =
+                    compiler.get_entry_point(entryPointName, spv::ExecutionModelGLCompute);
+                metadata->localWorkgroupSize.x = spirEntryPoint.workgroup_size.x;
+                metadata->localWorkgroupSize.y = spirEntryPoint.workgroup_size.y;
+                metadata->localWorkgroupSize.z = spirEntryPoint.workgroup_size.z;
+            }
+
             return {std::move(metadata)};
         }
 
@@ -935,7 +824,7 @@
         }
 
         spirv_cross::Compiler compiler(mSpirv);
-        DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler));
+        DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler, "main"));
 
         return {};
     }
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 36f7c78..aef5286 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -24,7 +24,6 @@
 #include "dawn_native/Forward.h"
 #include "dawn_native/IntegerTypes.h"
 #include "dawn_native/PerStage.h"
-
 #include "dawn_native/dawn_platform.h"
 
 #include <bitset>
@@ -82,8 +81,10 @@
             ityp::array<ColorAttachmentIndex, Format::Type, kMaxColorAttachments>;
         FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
 
-        // The shader stage for this binding, TODO(dawn:216): can likely be removed once we
-        // properly support multiple entrypoints per ShaderModule.
+        // The local workgroup size declared for a compute entry point (or 0s otehrwise).
+        Origin3D localWorkgroupSize;
+
+        // The shader stage for this binding.
         SingleShaderStage stage;
     };
 
diff --git a/src/dawn_native/SpirvUtils.cpp b/src/dawn_native/SpirvUtils.cpp
new file mode 100644
index 0000000..e462e0d
--- /dev/null
+++ b/src/dawn_native/SpirvUtils.cpp
@@ -0,0 +1,154 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/SpirvUtils.h"
+
+namespace dawn_native {
+
+    spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage) {
+        switch (stage) {
+            case SingleShaderStage::Vertex:
+                return spv::ExecutionModelVertex;
+            case SingleShaderStage::Fragment:
+                return spv::ExecutionModelFragment;
+            case SingleShaderStage::Compute:
+                return spv::ExecutionModelGLCompute;
+            default:
+                UNREACHABLE();
+        }
+    }
+
+    SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model) {
+        switch (model) {
+            case spv::ExecutionModelVertex:
+                return SingleShaderStage::Vertex;
+            case spv::ExecutionModelFragment:
+                return SingleShaderStage::Fragment;
+            case spv::ExecutionModelGLCompute:
+                return SingleShaderStage::Compute;
+            default:
+                UNREACHABLE();
+        }
+    }
+
+    wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) {
+        switch (dim) {
+            case spv::Dim::Dim1D:
+                return wgpu::TextureViewDimension::e1D;
+            case spv::Dim::Dim2D:
+                if (arrayed) {
+                    return wgpu::TextureViewDimension::e2DArray;
+                } else {
+                    return wgpu::TextureViewDimension::e2D;
+                }
+            case spv::Dim::Dim3D:
+                return wgpu::TextureViewDimension::e3D;
+            case spv::Dim::DimCube:
+                if (arrayed) {
+                    return wgpu::TextureViewDimension::CubeArray;
+                } else {
+                    return wgpu::TextureViewDimension::Cube;
+                }
+            default:
+                UNREACHABLE();
+                return wgpu::TextureViewDimension::Undefined;
+        }
+    }
+
+    wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format) {
+        switch (format) {
+            case spv::ImageFormatR8:
+                return wgpu::TextureFormat::R8Unorm;
+            case spv::ImageFormatR8Snorm:
+                return wgpu::TextureFormat::R8Snorm;
+            case spv::ImageFormatR8ui:
+                return wgpu::TextureFormat::R8Uint;
+            case spv::ImageFormatR8i:
+                return wgpu::TextureFormat::R8Sint;
+            case spv::ImageFormatR16ui:
+                return wgpu::TextureFormat::R16Uint;
+            case spv::ImageFormatR16i:
+                return wgpu::TextureFormat::R16Sint;
+            case spv::ImageFormatR16f:
+                return wgpu::TextureFormat::R16Float;
+            case spv::ImageFormatRg8:
+                return wgpu::TextureFormat::RG8Unorm;
+            case spv::ImageFormatRg8Snorm:
+                return wgpu::TextureFormat::RG8Snorm;
+            case spv::ImageFormatRg8ui:
+                return wgpu::TextureFormat::RG8Uint;
+            case spv::ImageFormatRg8i:
+                return wgpu::TextureFormat::RG8Sint;
+            case spv::ImageFormatR32f:
+                return wgpu::TextureFormat::R32Float;
+            case spv::ImageFormatR32ui:
+                return wgpu::TextureFormat::R32Uint;
+            case spv::ImageFormatR32i:
+                return wgpu::TextureFormat::R32Sint;
+            case spv::ImageFormatRg16ui:
+                return wgpu::TextureFormat::RG16Uint;
+            case spv::ImageFormatRg16i:
+                return wgpu::TextureFormat::RG16Sint;
+            case spv::ImageFormatRg16f:
+                return wgpu::TextureFormat::RG16Float;
+            case spv::ImageFormatRgba8:
+                return wgpu::TextureFormat::RGBA8Unorm;
+            case spv::ImageFormatRgba8Snorm:
+                return wgpu::TextureFormat::RGBA8Snorm;
+            case spv::ImageFormatRgba8ui:
+                return wgpu::TextureFormat::RGBA8Uint;
+            case spv::ImageFormatRgba8i:
+                return wgpu::TextureFormat::RGBA8Sint;
+            case spv::ImageFormatRgb10A2:
+                return wgpu::TextureFormat::RGB10A2Unorm;
+            case spv::ImageFormatR11fG11fB10f:
+                return wgpu::TextureFormat::RG11B10Ufloat;
+            case spv::ImageFormatRg32f:
+                return wgpu::TextureFormat::RG32Float;
+            case spv::ImageFormatRg32ui:
+                return wgpu::TextureFormat::RG32Uint;
+            case spv::ImageFormatRg32i:
+                return wgpu::TextureFormat::RG32Sint;
+            case spv::ImageFormatRgba16ui:
+                return wgpu::TextureFormat::RGBA16Uint;
+            case spv::ImageFormatRgba16i:
+                return wgpu::TextureFormat::RGBA16Sint;
+            case spv::ImageFormatRgba16f:
+                return wgpu::TextureFormat::RGBA16Float;
+            case spv::ImageFormatRgba32f:
+                return wgpu::TextureFormat::RGBA32Float;
+            case spv::ImageFormatRgba32ui:
+                return wgpu::TextureFormat::RGBA32Uint;
+            case spv::ImageFormatRgba32i:
+                return wgpu::TextureFormat::RGBA32Sint;
+            default:
+                return wgpu::TextureFormat::Undefined;
+        }
+    }
+
+    Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) {
+        switch (spirvBaseType) {
+            case spirv_cross::SPIRType::Float:
+                return Format::Type::Float;
+            case spirv_cross::SPIRType::Int:
+                return Format::Type::Sint;
+            case spirv_cross::SPIRType::UInt:
+                return Format::Type::Uint;
+            default:
+                UNREACHABLE();
+                return Format::Type::Other;
+        }
+    }
+
+}  // namespace dawn_native
diff --git a/src/dawn_native/SpirvUtils.h b/src/dawn_native/SpirvUtils.h
new file mode 100644
index 0000000..ceb6fd6
--- /dev/null
+++ b/src/dawn_native/SpirvUtils.h
@@ -0,0 +1,44 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains utilities to convert from-to spirv.hpp datatypes without polluting other
+// headers with spirv.hpp
+
+#ifndef DAWNNATIVE_SPIRV_UTILS_H_
+#define DAWNNATIVE_SPIRV_UTILS_H_
+
+#include "dawn_native/Format.h"
+#include "dawn_native/PerStage.h"
+#include "dawn_native/dawn_platform.h"
+
+#include <spirv_cross.hpp>
+
+namespace dawn_native {
+
+    // Returns the spirv_cross equivalent for this shader stage and vice-versa.
+    spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage);
+    SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model);
+
+    // Returns the texture view dimension for corresponding to (dim, arrayed).
+    wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed);
+
+    // Returns the texture format corresponding to format.
+    wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format);
+
+    // Returns the format "component type" corresponding to the SPIRV base type.
+    Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType);
+
+}  // namespace dawn_native
+
+#endif  // DAWNNATIVE_SPIRV_UTILS_H_
diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm
index aca771d..6c96515 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.mm
+++ b/src/dawn_native/metal/ComputePipelineMTL.mm
@@ -34,8 +34,8 @@
         ShaderModule* computeModule = ToBackend(descriptor->computeStage.module);
         const char* computeEntryPoint = descriptor->computeStage.entryPoint;
         ShaderModule::MetalFunctionData computeData;
-        DAWN_TRY(computeModule->GetFunction(computeEntryPoint, SingleShaderStage::Compute,
-                                            ToBackend(GetLayout()), &computeData));
+        DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute,
+                                               ToBackend(GetLayout()), &computeData));
 
         NSError* error = nil;
         mMtlComputePipelineState =
@@ -46,7 +46,9 @@
         }
 
         // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
-        mLocalWorkgroupSize = computeData.localWorkgroupSize;
+        Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize;
+        mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
+
         mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
         return {};
     }
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index ede97aa..ee5fd04 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -335,8 +335,9 @@
         ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
         const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
         ShaderModule::MetalFunctionData vertexData;
-        DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex,
-                                           ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this));
+        DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex,
+                                              ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF,
+                                              this));
 
         descriptorMTL.vertexFunction = vertexData.function;
         if (vertexData.needsStorageBufferLength) {
@@ -346,9 +347,9 @@
         ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module);
         const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint;
         ShaderModule::MetalFunctionData fragmentData;
-        DAWN_TRY(fragmentModule->GetFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
-                                             ToBackend(GetLayout()), &fragmentData,
-                                             descriptor->sampleMask));
+        DAWN_TRY(fragmentModule->CreateFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
+                                                ToBackend(GetLayout()), &fragmentData,
+                                                descriptor->sampleMask));
 
         descriptorMTL.fragmentFunction = fragmentData.function;
         if (fragmentData.needsStorageBufferLength) {
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 3a211e6..4e543c7 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -38,18 +38,17 @@
 
         struct MetalFunctionData {
             id<MTLFunction> function = nil;
-            MTLSize localWorkgroupSize;
             bool needsStorageBufferLength;
             ~MetalFunctionData() {
                 [function release];
             }
         };
-        MaybeError GetFunction(const char* functionName,
-                               SingleShaderStage functionStage,
-                               const PipelineLayout* layout,
-                               MetalFunctionData* out,
-                               uint32_t sampleMask = 0xFFFFFFFF,
-                               const RenderPipeline* renderPipeline = nullptr);
+        MaybeError CreateFunction(const char* entryPointName,
+                                  SingleShaderStage stage,
+                                  const PipelineLayout* layout,
+                                  MetalFunctionData* out,
+                                  uint32_t sampleMask = 0xFFFFFFFF,
+                                  const RenderPipeline* renderPipeline = nullptr);
 
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 047d453..2588aa0 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -15,6 +15,7 @@
 #include "dawn_native/metal/ShaderModuleMTL.h"
 
 #include "dawn_native/BindGroupLayout.h"
+#include "dawn_native/SpirvUtils.h"
 #include "dawn_native/metal/DeviceMTL.h"
 #include "dawn_native/metal/PipelineLayoutMTL.h"
 #include "dawn_native/metal/RenderPipelineMTL.h"
@@ -25,22 +26,6 @@
 
 namespace dawn_native { namespace metal {
 
-    namespace {
-
-        spv::ExecutionModel SpirvExecutionModelForStage(SingleShaderStage stage) {
-            switch (stage) {
-                case SingleShaderStage::Vertex:
-                    return spv::ExecutionModelVertex;
-                case SingleShaderStage::Fragment:
-                    return spv::ExecutionModelFragment;
-                case SingleShaderStage::Compute:
-                    return spv::ExecutionModelGLCompute;
-                default:
-                    UNREACHABLE();
-            }
-        }
-    }  // namespace
-
     // static
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
                                                       const ShaderModuleDescriptor* descriptor) {
@@ -57,25 +42,26 @@
         return InitializeBase();
     }
 
-    MaybeError ShaderModule::GetFunction(const char* functionName,
-                                         SingleShaderStage functionStage,
-                                         const PipelineLayout* layout,
-                                         ShaderModule::MetalFunctionData* out,
-                                         uint32_t sampleMask,
-                                         const RenderPipeline* renderPipeline) {
+    MaybeError ShaderModule::CreateFunction(const char* entryPointName,
+                                            SingleShaderStage stage,
+                                            const PipelineLayout* layout,
+                                            ShaderModule::MetalFunctionData* out,
+                                            uint32_t sampleMask,
+                                            const RenderPipeline* renderPipeline) {
         ASSERT(!IsError());
         ASSERT(out);
         const std::vector<uint32_t>* spirv = &GetSpirv();
+        spv::ExecutionModel executionModel = ShaderStageToExecutionModel(stage);
 
 #ifdef DAWN_ENABLE_WGSL
         // Use set 4 since it is bigger than what users can access currently
         static const uint32_t kPullingBufferBindingSet = 4;
         std::vector<uint32_t> pullingSpirv;
         if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
-            functionStage == SingleShaderStage::Vertex) {
+            stage == SingleShaderStage::Vertex) {
             DAWN_TRY_ASSIGN(pullingSpirv,
                             GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
-                                                 functionName, kPullingBufferBindingSet));
+                                                 entryPointName, kPullingBufferBindingSet));
             spirv = &pullingSpirv;
         }
 #endif
@@ -99,6 +85,7 @@
 
         spirv_cross::CompilerMSL compiler(*spirv);
         compiler.set_msl_options(options_msl);
+        compiler.set_entry_point(entryPointName, executionModel);
 
         // By default SPIRV-Cross will give MSL resources indices in increasing order.
         // To make the MSL indices match the indices chosen in the PipelineLayout, we build
@@ -116,30 +103,33 @@
                 const BindingInfo& bindingInfo =
                     layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex);
 
-                for (auto stage : IterateStages(bindingInfo.visibility)) {
-                    uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];
-                    spirv_cross::MSLResourceBinding mslBinding;
-                    mslBinding.stage = SpirvExecutionModelForStage(stage);
-                    mslBinding.desc_set = static_cast<uint32_t>(group);
-                    mslBinding.binding = static_cast<uint32_t>(bindingNumber);
-                    mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler =
-                        shaderIndex;
-
-                    compiler.add_msl_resource_binding(mslBinding);
+                if (!(bindingInfo.visibility & StageBit(stage))) {
+                    continue;
                 }
+
+                uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];
+
+                spirv_cross::MSLResourceBinding mslBinding;
+                mslBinding.stage = executionModel;
+                mslBinding.desc_set = static_cast<uint32_t>(group);
+                mslBinding.binding = static_cast<uint32_t>(bindingNumber);
+                mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler =
+                    shaderIndex;
+
+                compiler.add_msl_resource_binding(mslBinding);
             }
         }
 
 #ifdef DAWN_ENABLE_WGSL
         // Add vertex buffers bound as storage buffers
         if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
-            functionStage == SingleShaderStage::Vertex) {
+            stage == SingleShaderStage::Vertex) {
             for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
                 uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex);
 
                 spirv_cross::MSLResourceBinding mslBinding;
 
-                mslBinding.stage = SpirvExecutionModelForStage(SingleShaderStage::Vertex);
+                mslBinding.stage = spv::ExecutionModelVertex;
                 mslBinding.desc_set = kPullingBufferBindingSet;
                 mslBinding.binding = dawnIndex;
                 mslBinding.msl_buffer = metalIndex;
@@ -149,16 +139,16 @@
 #endif
 
         {
-            spv::ExecutionModel executionModel = SpirvExecutionModelForStage(functionStage);
-            auto size = compiler.get_entry_point(functionName, executionModel).workgroup_size;
-            out->localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z);
-        }
-
-        {
             // SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing
             // by default.
             NSString* mslSource;
             std::string msl = compiler.compile();
+
+            // Some entry point names are forbidden in MSL so SPIRV-Cross modifies them. Query the
+            // modified entryPointName from it.
+            const std::string& modifiedEntryPointName =
+                compiler.get_entry_point(entryPointName, executionModel).name;
+
             // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
             // category. -Wunused-variable in particular comes up a lot in generated code, and some
             // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
@@ -183,18 +173,7 @@
                 }
             }
 
-            // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like
-            // clean_func_name:
-            // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
-            const char* metalFunctionName = functionName;
-            if (strcmp(metalFunctionName, "main") == 0) {
-                metalFunctionName = "main0";
-            }
-            if (strcmp(metalFunctionName, "saturate") == 0) {
-                metalFunctionName = "saturate0";
-            }
-
-            NSString* name = [[NSString alloc] initWithUTF8String:metalFunctionName];
+            NSString* name = [[NSString alloc] initWithUTF8String:modifiedEntryPointName.c_str()];
             out->function = [library newFunctionWithName:name];
             [library release];
         }
@@ -202,7 +181,7 @@
         out->needsStorageBufferLength = compiler.needs_buffer_size_buffer();
 
         if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
-            GetEntryPoint(functionName, functionStage).usedVertexAttributes.any()) {
+            GetEntryPoint(entryPointName, stage).usedVertexAttributes.any()) {
             out->needsStorageBufferLength = true;
         }