Metal: Use ShaderModule reflection when possible.
This change the Metal backend in preparation for supporting multiple
entrypoints:
- Explicitly set the spirv_cross entry point before compiling.
- Moves gathering of the local size to the frontend as it will be
useful for validation in the future.
- Query spirv-cross for the modified entrypoint name instead of
duplicating the code in Dawn.
- Move some conversion helpers from ShaderModule.cpp to their own
SpirvUtils file.
Bug: dawn:216
Change-Id: I87d4953428e0bfeb97e39ed22f94d86ae7987782
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/28241
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn
index 6b70082..4558f8c 100644
--- a/src/dawn_native/BUILD.gn
+++ b/src/dawn_native/BUILD.gn
@@ -253,6 +253,8 @@
"Sampler.h",
"ShaderModule.cpp",
"ShaderModule.h",
+ "SpirvUtils.cpp",
+ "SpirvUtils.h",
"StagingBuffer.cpp",
"StagingBuffer.h",
"Surface.cpp",
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt
index b7db561..ac435bb 100644
--- a/src/dawn_native/CMakeLists.txt
+++ b/src/dawn_native/CMakeLists.txt
@@ -131,6 +131,8 @@
"Sampler.h"
"ShaderModule.cpp"
"ShaderModule.h"
+ "SpirvUtils.cpp"
+ "SpirvUtils.h"
"StagingBuffer.cpp"
"StagingBuffer.h"
"Surface.cpp"
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 2bdadb9..3d68909 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -19,6 +19,7 @@
#include "dawn_native/Device.h"
#include "dawn_native/Pipeline.h"
#include "dawn_native/PipelineLayout.h"
+#include "dawn_native/SpirvUtils.h"
#include <spirv-tools/libspirv.hpp>
#include <spirv_cross.hpp>
@@ -36,114 +37,6 @@
namespace dawn_native {
namespace {
- Format::Type SpirvCrossBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) {
- switch (spirvBaseType) {
- case spirv_cross::SPIRType::Float:
- return Format::Type::Float;
- case spirv_cross::SPIRType::Int:
- return Format::Type::Sint;
- case spirv_cross::SPIRType::UInt:
- return Format::Type::Uint;
- default:
- UNREACHABLE();
- return Format::Type::Other;
- }
- }
-
- wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) {
- switch (dim) {
- case spv::Dim::Dim1D:
- return wgpu::TextureViewDimension::e1D;
- case spv::Dim::Dim2D:
- if (arrayed) {
- return wgpu::TextureViewDimension::e2DArray;
- } else {
- return wgpu::TextureViewDimension::e2D;
- }
- case spv::Dim::Dim3D:
- return wgpu::TextureViewDimension::e3D;
- case spv::Dim::DimCube:
- if (arrayed) {
- return wgpu::TextureViewDimension::CubeArray;
- } else {
- return wgpu::TextureViewDimension::Cube;
- }
- default:
- UNREACHABLE();
- return wgpu::TextureViewDimension::Undefined;
- }
- }
-
- wgpu::TextureFormat ToWGPUTextureFormat(spv::ImageFormat format) {
- switch (format) {
- case spv::ImageFormatR8:
- return wgpu::TextureFormat::R8Unorm;
- case spv::ImageFormatR8Snorm:
- return wgpu::TextureFormat::R8Snorm;
- case spv::ImageFormatR8ui:
- return wgpu::TextureFormat::R8Uint;
- case spv::ImageFormatR8i:
- return wgpu::TextureFormat::R8Sint;
- case spv::ImageFormatR16ui:
- return wgpu::TextureFormat::R16Uint;
- case spv::ImageFormatR16i:
- return wgpu::TextureFormat::R16Sint;
- case spv::ImageFormatR16f:
- return wgpu::TextureFormat::R16Float;
- case spv::ImageFormatRg8:
- return wgpu::TextureFormat::RG8Unorm;
- case spv::ImageFormatRg8Snorm:
- return wgpu::TextureFormat::RG8Snorm;
- case spv::ImageFormatRg8ui:
- return wgpu::TextureFormat::RG8Uint;
- case spv::ImageFormatRg8i:
- return wgpu::TextureFormat::RG8Sint;
- case spv::ImageFormatR32f:
- return wgpu::TextureFormat::R32Float;
- case spv::ImageFormatR32ui:
- return wgpu::TextureFormat::R32Uint;
- case spv::ImageFormatR32i:
- return wgpu::TextureFormat::R32Sint;
- case spv::ImageFormatRg16ui:
- return wgpu::TextureFormat::RG16Uint;
- case spv::ImageFormatRg16i:
- return wgpu::TextureFormat::RG16Sint;
- case spv::ImageFormatRg16f:
- return wgpu::TextureFormat::RG16Float;
- case spv::ImageFormatRgba8:
- return wgpu::TextureFormat::RGBA8Unorm;
- case spv::ImageFormatRgba8Snorm:
- return wgpu::TextureFormat::RGBA8Snorm;
- case spv::ImageFormatRgba8ui:
- return wgpu::TextureFormat::RGBA8Uint;
- case spv::ImageFormatRgba8i:
- return wgpu::TextureFormat::RGBA8Sint;
- case spv::ImageFormatRgb10A2:
- return wgpu::TextureFormat::RGB10A2Unorm;
- case spv::ImageFormatR11fG11fB10f:
- return wgpu::TextureFormat::RG11B10Ufloat;
- case spv::ImageFormatRg32f:
- return wgpu::TextureFormat::RG32Float;
- case spv::ImageFormatRg32ui:
- return wgpu::TextureFormat::RG32Uint;
- case spv::ImageFormatRg32i:
- return wgpu::TextureFormat::RG32Sint;
- case spv::ImageFormatRgba16ui:
- return wgpu::TextureFormat::RGBA16Uint;
- case spv::ImageFormatRgba16i:
- return wgpu::TextureFormat::RGBA16Sint;
- case spv::ImageFormatRgba16f:
- return wgpu::TextureFormat::RGBA16Float;
- case spv::ImageFormatRgba32f:
- return wgpu::TextureFormat::RGBA32Float;
- case spv::ImageFormatRgba32ui:
- return wgpu::TextureFormat::RGBA32Uint;
- case spv::ImageFormatRgba32i:
- return wgpu::TextureFormat::RGBA32Sint;
- default:
- return wgpu::TextureFormat::Undefined;
- }
- }
std::string GetShaderDeclarationString(BindGroupIndex group, BindingNumber binding) {
std::ostringstream ostream;
@@ -550,27 +443,15 @@
ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfo(
const DeviceBase* device,
- const spirv_cross::Compiler& compiler) {
+ const spirv_cross::Compiler& compiler,
+ const char* entryPointName) {
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
// TODO(cwallez@chromium.org): make errors here creation errors
// currently errors here do not prevent the shadermodule from being used
const auto& resources = compiler.get_shader_resources();
- switch (compiler.get_execution_model()) {
- case spv::ExecutionModelVertex:
- metadata->stage = SingleShaderStage::Vertex;
- break;
- case spv::ExecutionModelFragment:
- metadata->stage = SingleShaderStage::Fragment;
- break;
- case spv::ExecutionModelGLCompute:
- metadata->stage = SingleShaderStage::Compute;
- break;
- default:
- UNREACHABLE();
- return DAWN_VALIDATION_ERROR("Unexpected shader execution model");
- }
+ metadata->stage = ExecutionModelToShaderStage(compiler.get_execution_model());
if (resources.push_constant_buffers.size() > 0) {
return DAWN_VALIDATION_ERROR("Push constants aren't supported.");
@@ -635,7 +516,7 @@
info->viewDimension =
SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed);
info->textureComponentType =
- SpirvCrossBaseTypeToFormatType(textureComponentType);
+ SpirvBaseTypeToFormatType(textureComponentType);
info->type = bindingType;
break;
}
@@ -664,7 +545,7 @@
spirv_cross::SPIRType::ImageType imageType =
compiler.get_type(info->base_type_id).image;
wgpu::TextureFormat storageTextureFormat =
- ToWGPUTextureFormat(imageType.format);
+ SpirvImageFormatToTextureFormat(imageType.format);
if (storageTextureFormat == wgpu::TextureFormat::Undefined) {
return DAWN_VALIDATION_ERROR(
"Invalid image format declaration on storage image");
@@ -756,7 +637,7 @@
spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType =
compiler.get_type(fragmentOutput.base_type_id).basetype;
Format::Type formatType =
- SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType);
+ SpirvBaseTypeToFormatType(shaderFragmentOutputBaseType);
if (formatType == Format::Type::Other) {
return DAWN_VALIDATION_ERROR("Unexpected Fragment output type");
}
@@ -764,6 +645,14 @@
}
}
+ if (metadata->stage == SingleShaderStage::Compute) {
+ const spirv_cross::SPIREntryPoint& spirEntryPoint =
+ compiler.get_entry_point(entryPointName, spv::ExecutionModelGLCompute);
+ metadata->localWorkgroupSize.x = spirEntryPoint.workgroup_size.x;
+ metadata->localWorkgroupSize.y = spirEntryPoint.workgroup_size.y;
+ metadata->localWorkgroupSize.z = spirEntryPoint.workgroup_size.z;
+ }
+
return {std::move(metadata)};
}
@@ -935,7 +824,7 @@
}
spirv_cross::Compiler compiler(mSpirv);
- DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler));
+ DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler, "main"));
return {};
}
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 36f7c78..aef5286 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -24,7 +24,6 @@
#include "dawn_native/Forward.h"
#include "dawn_native/IntegerTypes.h"
#include "dawn_native/PerStage.h"
-
#include "dawn_native/dawn_platform.h"
#include <bitset>
@@ -82,8 +81,10 @@
ityp::array<ColorAttachmentIndex, Format::Type, kMaxColorAttachments>;
FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
- // The shader stage for this binding, TODO(dawn:216): can likely be removed once we
- // properly support multiple entrypoints per ShaderModule.
+ // The local workgroup size declared for a compute entry point (or 0s otehrwise).
+ Origin3D localWorkgroupSize;
+
+ // The shader stage for this binding.
SingleShaderStage stage;
};
diff --git a/src/dawn_native/SpirvUtils.cpp b/src/dawn_native/SpirvUtils.cpp
new file mode 100644
index 0000000..e462e0d
--- /dev/null
+++ b/src/dawn_native/SpirvUtils.cpp
@@ -0,0 +1,154 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/SpirvUtils.h"
+
+namespace dawn_native {
+
+ spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage) {
+ switch (stage) {
+ case SingleShaderStage::Vertex:
+ return spv::ExecutionModelVertex;
+ case SingleShaderStage::Fragment:
+ return spv::ExecutionModelFragment;
+ case SingleShaderStage::Compute:
+ return spv::ExecutionModelGLCompute;
+ default:
+ UNREACHABLE();
+ }
+ }
+
+ SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model) {
+ switch (model) {
+ case spv::ExecutionModelVertex:
+ return SingleShaderStage::Vertex;
+ case spv::ExecutionModelFragment:
+ return SingleShaderStage::Fragment;
+ case spv::ExecutionModelGLCompute:
+ return SingleShaderStage::Compute;
+ default:
+ UNREACHABLE();
+ }
+ }
+
+ wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) {
+ switch (dim) {
+ case spv::Dim::Dim1D:
+ return wgpu::TextureViewDimension::e1D;
+ case spv::Dim::Dim2D:
+ if (arrayed) {
+ return wgpu::TextureViewDimension::e2DArray;
+ } else {
+ return wgpu::TextureViewDimension::e2D;
+ }
+ case spv::Dim::Dim3D:
+ return wgpu::TextureViewDimension::e3D;
+ case spv::Dim::DimCube:
+ if (arrayed) {
+ return wgpu::TextureViewDimension::CubeArray;
+ } else {
+ return wgpu::TextureViewDimension::Cube;
+ }
+ default:
+ UNREACHABLE();
+ return wgpu::TextureViewDimension::Undefined;
+ }
+ }
+
+ wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format) {
+ switch (format) {
+ case spv::ImageFormatR8:
+ return wgpu::TextureFormat::R8Unorm;
+ case spv::ImageFormatR8Snorm:
+ return wgpu::TextureFormat::R8Snorm;
+ case spv::ImageFormatR8ui:
+ return wgpu::TextureFormat::R8Uint;
+ case spv::ImageFormatR8i:
+ return wgpu::TextureFormat::R8Sint;
+ case spv::ImageFormatR16ui:
+ return wgpu::TextureFormat::R16Uint;
+ case spv::ImageFormatR16i:
+ return wgpu::TextureFormat::R16Sint;
+ case spv::ImageFormatR16f:
+ return wgpu::TextureFormat::R16Float;
+ case spv::ImageFormatRg8:
+ return wgpu::TextureFormat::RG8Unorm;
+ case spv::ImageFormatRg8Snorm:
+ return wgpu::TextureFormat::RG8Snorm;
+ case spv::ImageFormatRg8ui:
+ return wgpu::TextureFormat::RG8Uint;
+ case spv::ImageFormatRg8i:
+ return wgpu::TextureFormat::RG8Sint;
+ case spv::ImageFormatR32f:
+ return wgpu::TextureFormat::R32Float;
+ case spv::ImageFormatR32ui:
+ return wgpu::TextureFormat::R32Uint;
+ case spv::ImageFormatR32i:
+ return wgpu::TextureFormat::R32Sint;
+ case spv::ImageFormatRg16ui:
+ return wgpu::TextureFormat::RG16Uint;
+ case spv::ImageFormatRg16i:
+ return wgpu::TextureFormat::RG16Sint;
+ case spv::ImageFormatRg16f:
+ return wgpu::TextureFormat::RG16Float;
+ case spv::ImageFormatRgba8:
+ return wgpu::TextureFormat::RGBA8Unorm;
+ case spv::ImageFormatRgba8Snorm:
+ return wgpu::TextureFormat::RGBA8Snorm;
+ case spv::ImageFormatRgba8ui:
+ return wgpu::TextureFormat::RGBA8Uint;
+ case spv::ImageFormatRgba8i:
+ return wgpu::TextureFormat::RGBA8Sint;
+ case spv::ImageFormatRgb10A2:
+ return wgpu::TextureFormat::RGB10A2Unorm;
+ case spv::ImageFormatR11fG11fB10f:
+ return wgpu::TextureFormat::RG11B10Ufloat;
+ case spv::ImageFormatRg32f:
+ return wgpu::TextureFormat::RG32Float;
+ case spv::ImageFormatRg32ui:
+ return wgpu::TextureFormat::RG32Uint;
+ case spv::ImageFormatRg32i:
+ return wgpu::TextureFormat::RG32Sint;
+ case spv::ImageFormatRgba16ui:
+ return wgpu::TextureFormat::RGBA16Uint;
+ case spv::ImageFormatRgba16i:
+ return wgpu::TextureFormat::RGBA16Sint;
+ case spv::ImageFormatRgba16f:
+ return wgpu::TextureFormat::RGBA16Float;
+ case spv::ImageFormatRgba32f:
+ return wgpu::TextureFormat::RGBA32Float;
+ case spv::ImageFormatRgba32ui:
+ return wgpu::TextureFormat::RGBA32Uint;
+ case spv::ImageFormatRgba32i:
+ return wgpu::TextureFormat::RGBA32Sint;
+ default:
+ return wgpu::TextureFormat::Undefined;
+ }
+ }
+
+ Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) {
+ switch (spirvBaseType) {
+ case spirv_cross::SPIRType::Float:
+ return Format::Type::Float;
+ case spirv_cross::SPIRType::Int:
+ return Format::Type::Sint;
+ case spirv_cross::SPIRType::UInt:
+ return Format::Type::Uint;
+ default:
+ UNREACHABLE();
+ return Format::Type::Other;
+ }
+ }
+
+} // namespace dawn_native
diff --git a/src/dawn_native/SpirvUtils.h b/src/dawn_native/SpirvUtils.h
new file mode 100644
index 0000000..ceb6fd6
--- /dev/null
+++ b/src/dawn_native/SpirvUtils.h
@@ -0,0 +1,44 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains utilities to convert from-to spirv.hpp datatypes without polluting other
+// headers with spirv.hpp
+
+#ifndef DAWNNATIVE_SPIRV_UTILS_H_
+#define DAWNNATIVE_SPIRV_UTILS_H_
+
+#include "dawn_native/Format.h"
+#include "dawn_native/PerStage.h"
+#include "dawn_native/dawn_platform.h"
+
+#include <spirv_cross.hpp>
+
+namespace dawn_native {
+
+ // Returns the spirv_cross equivalent for this shader stage and vice-versa.
+ spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage);
+ SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model);
+
+ // Returns the texture view dimension for corresponding to (dim, arrayed).
+ wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed);
+
+ // Returns the texture format corresponding to format.
+ wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format);
+
+ // Returns the format "component type" corresponding to the SPIRV base type.
+ Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType);
+
+} // namespace dawn_native
+
+#endif // DAWNNATIVE_SPIRV_UTILS_H_
diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm
index aca771d..6c96515 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.mm
+++ b/src/dawn_native/metal/ComputePipelineMTL.mm
@@ -34,8 +34,8 @@
ShaderModule* computeModule = ToBackend(descriptor->computeStage.module);
const char* computeEntryPoint = descriptor->computeStage.entryPoint;
ShaderModule::MetalFunctionData computeData;
- DAWN_TRY(computeModule->GetFunction(computeEntryPoint, SingleShaderStage::Compute,
- ToBackend(GetLayout()), &computeData));
+ DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute,
+ ToBackend(GetLayout()), &computeData));
NSError* error = nil;
mMtlComputePipelineState =
@@ -46,7 +46,9 @@
}
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
- mLocalWorkgroupSize = computeData.localWorkgroupSize;
+ Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize;
+ mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
+
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
return {};
}
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index ede97aa..ee5fd04 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -335,8 +335,9 @@
ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
ShaderModule::MetalFunctionData vertexData;
- DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex,
- ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this));
+ DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex,
+ ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF,
+ this));
descriptorMTL.vertexFunction = vertexData.function;
if (vertexData.needsStorageBufferLength) {
@@ -346,9 +347,9 @@
ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module);
const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint;
ShaderModule::MetalFunctionData fragmentData;
- DAWN_TRY(fragmentModule->GetFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
- ToBackend(GetLayout()), &fragmentData,
- descriptor->sampleMask));
+ DAWN_TRY(fragmentModule->CreateFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
+ ToBackend(GetLayout()), &fragmentData,
+ descriptor->sampleMask));
descriptorMTL.fragmentFunction = fragmentData.function;
if (fragmentData.needsStorageBufferLength) {
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 3a211e6..4e543c7 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -38,18 +38,17 @@
struct MetalFunctionData {
id<MTLFunction> function = nil;
- MTLSize localWorkgroupSize;
bool needsStorageBufferLength;
~MetalFunctionData() {
[function release];
}
};
- MaybeError GetFunction(const char* functionName,
- SingleShaderStage functionStage,
- const PipelineLayout* layout,
- MetalFunctionData* out,
- uint32_t sampleMask = 0xFFFFFFFF,
- const RenderPipeline* renderPipeline = nullptr);
+ MaybeError CreateFunction(const char* entryPointName,
+ SingleShaderStage stage,
+ const PipelineLayout* layout,
+ MetalFunctionData* out,
+ uint32_t sampleMask = 0xFFFFFFFF,
+ const RenderPipeline* renderPipeline = nullptr);
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 047d453..2588aa0 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -15,6 +15,7 @@
#include "dawn_native/metal/ShaderModuleMTL.h"
#include "dawn_native/BindGroupLayout.h"
+#include "dawn_native/SpirvUtils.h"
#include "dawn_native/metal/DeviceMTL.h"
#include "dawn_native/metal/PipelineLayoutMTL.h"
#include "dawn_native/metal/RenderPipelineMTL.h"
@@ -25,22 +26,6 @@
namespace dawn_native { namespace metal {
- namespace {
-
- spv::ExecutionModel SpirvExecutionModelForStage(SingleShaderStage stage) {
- switch (stage) {
- case SingleShaderStage::Vertex:
- return spv::ExecutionModelVertex;
- case SingleShaderStage::Fragment:
- return spv::ExecutionModelFragment;
- case SingleShaderStage::Compute:
- return spv::ExecutionModelGLCompute;
- default:
- UNREACHABLE();
- }
- }
- } // namespace
-
// static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor) {
@@ -57,25 +42,26 @@
return InitializeBase();
}
- MaybeError ShaderModule::GetFunction(const char* functionName,
- SingleShaderStage functionStage,
- const PipelineLayout* layout,
- ShaderModule::MetalFunctionData* out,
- uint32_t sampleMask,
- const RenderPipeline* renderPipeline) {
+ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
+ SingleShaderStage stage,
+ const PipelineLayout* layout,
+ ShaderModule::MetalFunctionData* out,
+ uint32_t sampleMask,
+ const RenderPipeline* renderPipeline) {
ASSERT(!IsError());
ASSERT(out);
const std::vector<uint32_t>* spirv = &GetSpirv();
+ spv::ExecutionModel executionModel = ShaderStageToExecutionModel(stage);
#ifdef DAWN_ENABLE_WGSL
// Use set 4 since it is bigger than what users can access currently
static const uint32_t kPullingBufferBindingSet = 4;
std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
- functionStage == SingleShaderStage::Vertex) {
+ stage == SingleShaderStage::Vertex) {
DAWN_TRY_ASSIGN(pullingSpirv,
GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
- functionName, kPullingBufferBindingSet));
+ entryPointName, kPullingBufferBindingSet));
spirv = &pullingSpirv;
}
#endif
@@ -99,6 +85,7 @@
spirv_cross::CompilerMSL compiler(*spirv);
compiler.set_msl_options(options_msl);
+ compiler.set_entry_point(entryPointName, executionModel);
// By default SPIRV-Cross will give MSL resources indices in increasing order.
// To make the MSL indices match the indices chosen in the PipelineLayout, we build
@@ -116,30 +103,33 @@
const BindingInfo& bindingInfo =
layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex);
- for (auto stage : IterateStages(bindingInfo.visibility)) {
- uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];
- spirv_cross::MSLResourceBinding mslBinding;
- mslBinding.stage = SpirvExecutionModelForStage(stage);
- mslBinding.desc_set = static_cast<uint32_t>(group);
- mslBinding.binding = static_cast<uint32_t>(bindingNumber);
- mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler =
- shaderIndex;
-
- compiler.add_msl_resource_binding(mslBinding);
+ if (!(bindingInfo.visibility & StageBit(stage))) {
+ continue;
}
+
+ uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];
+
+ spirv_cross::MSLResourceBinding mslBinding;
+ mslBinding.stage = executionModel;
+ mslBinding.desc_set = static_cast<uint32_t>(group);
+ mslBinding.binding = static_cast<uint32_t>(bindingNumber);
+ mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler =
+ shaderIndex;
+
+ compiler.add_msl_resource_binding(mslBinding);
}
}
#ifdef DAWN_ENABLE_WGSL
// Add vertex buffers bound as storage buffers
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
- functionStage == SingleShaderStage::Vertex) {
+ stage == SingleShaderStage::Vertex) {
for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex);
spirv_cross::MSLResourceBinding mslBinding;
- mslBinding.stage = SpirvExecutionModelForStage(SingleShaderStage::Vertex);
+ mslBinding.stage = spv::ExecutionModelVertex;
mslBinding.desc_set = kPullingBufferBindingSet;
mslBinding.binding = dawnIndex;
mslBinding.msl_buffer = metalIndex;
@@ -149,16 +139,16 @@
#endif
{
- spv::ExecutionModel executionModel = SpirvExecutionModelForStage(functionStage);
- auto size = compiler.get_entry_point(functionName, executionModel).workgroup_size;
- out->localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z);
- }
-
- {
// SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing
// by default.
NSString* mslSource;
std::string msl = compiler.compile();
+
+ // Some entry point names are forbidden in MSL so SPIRV-Cross modifies them. Query the
+ // modified entryPointName from it.
+ const std::string& modifiedEntryPointName =
+ compiler.get_entry_point(entryPointName, executionModel).name;
+
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
// category. -Wunused-variable in particular comes up a lot in generated code, and some
// (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
@@ -183,18 +173,7 @@
}
}
- // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like
- // clean_func_name:
- // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
- const char* metalFunctionName = functionName;
- if (strcmp(metalFunctionName, "main") == 0) {
- metalFunctionName = "main0";
- }
- if (strcmp(metalFunctionName, "saturate") == 0) {
- metalFunctionName = "saturate0";
- }
-
- NSString* name = [[NSString alloc] initWithUTF8String:metalFunctionName];
+ NSString* name = [[NSString alloc] initWithUTF8String:modifiedEntryPointName.c_str()];
out->function = [library newFunctionWithName:name];
[library release];
}
@@ -202,7 +181,7 @@
out->needsStorageBufferLength = compiler.needs_buffer_size_buffer();
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
- GetEntryPoint(functionName, functionStage).usedVertexAttributes.any()) {
+ GetEntryPoint(entryPointName, stage).usedVertexAttributes.any()) {
out->needsStorageBufferLength = true;
}