Metal: Make the MSL indices match the ones in the PipelineLayout
Previously we didn't tell SPIRV-Cross at which MSL resource indices the
different SPIR-V bindings should be placed, and were lucky that its
assignment of indices in increasing order matched the PipelineLayout in
all our samples.
Fix this by making SPIRV->MSL compilation depend on the PipelineLayout
so we can tell SPIRV-Cross which binding goes where.
We should eventually do the same for vertex attributes, as they are
currently hardcoded to start at kMaxBindingsPerGroup.
Also includes a couple of unrelated cleanups (removal of unused code,
use of IterateBitSet).
diff --git a/src/backend/ShaderModule.cpp b/src/backend/ShaderModule.cpp
index 267baa4..12d2054 100644
--- a/src/backend/ShaderModule.cpp
+++ b/src/backend/ShaderModule.cpp
@@ -27,6 +27,10 @@
: device(builder->device) {
}
+ // Returns the device stored from `builder->device` when this module was constructed.
+ DeviceBase* ShaderModuleBase::GetDevice() const {
+ return device;
+ }
+
void ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) {
// TODO(cwallez@chromium.org): make errors here builder-level
// currently errors here do not prevent the shadermodule from being used
diff --git a/src/backend/ShaderModule.h b/src/backend/ShaderModule.h
index 35e7412..73ec5b6 100644
--- a/src/backend/ShaderModule.h
+++ b/src/backend/ShaderModule.h
@@ -36,6 +36,8 @@
public:
ShaderModuleBase(ShaderModuleBuilder* builder);
+ DeviceBase* GetDevice() const;
+
void ExtractSpirvInfo(const spirv_cross::Compiler& compiler);
struct PushConstantInfo {
diff --git a/src/backend/d3d12/ShaderModuleD3D12.cpp b/src/backend/d3d12/ShaderModuleD3D12.cpp
index ec6fc9f..1c3f522 100644
--- a/src/backend/d3d12/ShaderModuleD3D12.cpp
+++ b/src/backend/d3d12/ShaderModuleD3D12.cpp
@@ -32,14 +32,6 @@
ExtractSpirvInfo(compiler);
- enum RegisterType {
- Buffer,
- UnorderedAccess,
- Texture,
- Sampler,
- Count,
- };
-
// rename bindings so that each register type b/u/t/s starts at 0 and then offset by kMaxBindingsPerGroup * bindGroupIndex
auto RenumberBindings = [&](std::vector<spirv_cross::Resource> resources) {
std::array<uint32_t, kMaxBindGroups> baseRegisters = {};
diff --git a/src/backend/metal/ComputePipelineMTL.mm b/src/backend/metal/ComputePipelineMTL.mm
index a3d0593..da9551a 100644
--- a/src/backend/metal/ComputePipelineMTL.mm
+++ b/src/backend/metal/ComputePipelineMTL.mm
@@ -28,11 +28,11 @@
const auto& module = ToBackend(builder->GetStageInfo(nxt::ShaderStage::Compute).module);
const auto& entryPoint = builder->GetStageInfo(nxt::ShaderStage::Compute).entryPoint;
- id<MTLFunction> function = module->GetFunction(entryPoint.c_str());
+ auto compilationData = module->GetFunction(entryPoint.c_str(), ToBackend(GetLayout()));
NSError *error = nil;
mtlComputePipelineState = [mtlDevice
- newComputePipelineStateWithFunction:function error:&error];
+ newComputePipelineStateWithFunction:compilationData.function error:&error];
if (error != nil) {
NSLog(@" error => %@", error);
builder->HandleError("Error creating pipeline state");
@@ -40,7 +40,7 @@
}
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
- localWorkgroupSize = module->GetLocalWorkGroupSize(entryPoint);
+ localWorkgroupSize = compilationData.localWorkgroupSize;
}
ComputePipeline::~ComputePipeline() {
diff --git a/src/backend/metal/InputStateMTL.mm b/src/backend/metal/InputStateMTL.mm
index adc3826..d72f1c6 100644
--- a/src/backend/metal/InputStateMTL.mm
+++ b/src/backend/metal/InputStateMTL.mm
@@ -15,6 +15,7 @@
#include "backend/metal/InputStateMTL.h"
#include "backend/metal/MetalBackend.h"
+#include "common/BitSetIterator.h"
namespace backend {
namespace metal {
@@ -62,11 +63,7 @@
[attribDesc release];
}
- const auto& inputsSetMask = GetInputsSetMask();
- for (uint32_t i = 0; i < inputsSetMask.size(); ++i) {
- if (!inputsSetMask[i]) {
- continue;
- }
+ for (uint32_t i : IterateBitSet(GetInputsSetMask())) {
const InputInfo& info = GetInput(i);
auto layoutDesc = [MTLVertexBufferLayoutDescriptor new];
@@ -83,6 +80,7 @@
layoutDesc.stepRate = 1;
layoutDesc.stride = info.stride;
}
+ // TODO(cwallez@chromium.org): make the offset depend on the pipeline layout
mtlVertexDescriptor.layouts[kMaxBindingsPerGroup + i] = layoutDesc;
[layoutDesc release];
}
diff --git a/src/backend/metal/PipelineLayoutMTL.mm b/src/backend/metal/PipelineLayoutMTL.mm
index 071c144..3714a3a 100644
--- a/src/backend/metal/PipelineLayoutMTL.mm
+++ b/src/backend/metal/PipelineLayoutMTL.mm
@@ -23,7 +23,8 @@
: PipelineLayoutBase(builder) {
// Each stage has its own numbering namespace in CompilerMSL.
for (auto stage : IterateStages(kAllStages)) {
- uint32_t bufferIndex = 0;
+ // Buffer number 0 is reserved for push constants
+ uint32_t bufferIndex = 1;
uint32_t samplerIndex = 0;
uint32_t textureIndex = 0;
diff --git a/src/backend/metal/RenderPipelineMTL.mm b/src/backend/metal/RenderPipelineMTL.mm
index 3f50799..b851056 100644
--- a/src/backend/metal/RenderPipelineMTL.mm
+++ b/src/backend/metal/RenderPipelineMTL.mm
@@ -64,7 +64,7 @@
const auto& module = ToBackend(builder->GetStageInfo(stage).module);
const auto& entryPoint = builder->GetStageInfo(stage).entryPoint;
- id<MTLFunction> function = module->GetFunction(entryPoint.c_str());
+ id<MTLFunction> function = module->GetFunction(entryPoint.c_str(), ToBackend(GetLayout())).function;
switch (stage) {
case nxt::ShaderStage::Vertex:
diff --git a/src/backend/metal/ShaderModuleMTL.h b/src/backend/metal/ShaderModuleMTL.h
index 3896922..dde12ba 100644
--- a/src/backend/metal/ShaderModuleMTL.h
+++ b/src/backend/metal/ShaderModuleMTL.h
@@ -26,16 +26,20 @@
namespace backend {
namespace metal {
+ class PipelineLayout;
+
class ShaderModule : public ShaderModuleBase {
public:
ShaderModule(ShaderModuleBuilder* builder);
~ShaderModule();
- id<MTLFunction> GetFunction(const char* functionName) const;
- MTLSize GetLocalWorkGroupSize(const std::string& entryPoint) const;
+ // Result of compiling one entry point of this module to MSL.
+ struct MetalFunctionData {
+ // The compiled Metal function; nil if MSL compilation failed.
+ // NOTE(review): obtained via newFunctionWithName:, so presumably an owned (+1)
+ // reference the caller must release — confirm.
+ id<MTLFunction> function;
+ // Workgroup size SPIRV-Cross reports for the entry point; used by compute pipelines.
+ MTLSize localWorkgroupSize;
+ };
+ // Compiles the module to MSL with resource indices matching `layout` and returns the
+ // function named `functionName`. The MSL is regenerated on every call.
+ MetalFunctionData GetFunction(const char* functionName, const PipelineLayout* layout) const;
private:
- id<MTLLibrary> mtlLibrary = nil;
spirv_cross::CompilerMSL* compiler = nullptr;
};
diff --git a/src/backend/metal/ShaderModuleMTL.mm b/src/backend/metal/ShaderModuleMTL.mm
index 5b2b804..d7abff8 100644
--- a/src/backend/metal/ShaderModuleMTL.mm
+++ b/src/backend/metal/ShaderModuleMTL.mm
@@ -15,6 +15,7 @@
#include "backend/metal/ShaderModuleMTL.h"
#include "backend/metal/MetalBackend.h"
+#include "backend/metal/PipelineLayoutMTL.h"
#include <spirv-cross/spirv_msl.hpp>
@@ -23,40 +24,103 @@
namespace backend {
namespace metal {
+ namespace {
+
+ // Maps an NXT shader stage to the SPIR-V execution model that SPIRV-Cross uses to
+ // identify the stage (e.g. in MSLResourceBinding::stage).
+ spv::ExecutionModel SpirvExecutionModelForStage(nxt::ShaderStage stage) {
+ switch(stage) {
+ case nxt::ShaderStage::Vertex:
+ return spv::ExecutionModelVertex;
+ case nxt::ShaderStage::Fragment:
+ return spv::ExecutionModelFragment;
+ case nxt::ShaderStage::Compute:
+ return spv::ExecutionModelGLCompute;
+ default:
+ UNREACHABLE();
+ // UNREACHABLE()'s definition isn't visible here; return explicitly so no path
+ // falls off the end of this non-void function (undefined behavior in C++)
+ // even if the macro is a no-op in release builds.
+ return spv::ExecutionModelVertex;
+ }
+ }
+
+ }
+
+ // Creates the SPIRV-Cross MSL compiler and extracts reflection info up front. MSL itself is
+ // generated per call in GetFunction() because the MSL resource indices depend on the
+ // PipelineLayout; MSL compile errors are therefore only detected (and currently only
+ // logged) there.
ShaderModule::ShaderModule(ShaderModuleBuilder* builder)
: ShaderModuleBase(builder) {
compiler = new spirv_cross::CompilerMSL(builder->AcquireSpirv());
ExtractSpirvInfo(*compiler);
-
- std::string msl = compiler->compile();
- NSString* mslSource = [NSString stringWithFormat:@"%s", msl.c_str()];
-
- auto mtlDevice = ToBackend(builder->GetDevice())->GetMTLDevice();
- NSError *error = nil;
- mtlLibrary = [mtlDevice newLibraryWithSource:mslSource options:nil error:&error];
- if (error != nil) {
- NSLog(@"MTLDevice newLibraryWithSource => %@", error);
- builder->HandleError("Error creating MTLLibrary from MSL source");
- }
}
+ // Frees the compiler allocated with `new` in the constructor.
ShaderModule::~ShaderModule() {
delete compiler;
}
- id<MTLFunction> ShaderModule::GetFunction(const char* functionName) const {
- // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like clean_func_name:
- // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
- if (strcmp(functionName, "main") == 0) {
- functionName = "main0";
- }
- NSString* name = [NSString stringWithFormat:@"%s", functionName];
- return [mtlLibrary newFunctionWithName:name];
- }
+ // Compiles the module to MSL, with resource indices matching those chosen by `layout`,
+ // and returns the Metal function named `functionName` together with the entry point's
+ // local workgroup size. `function` is nil if the MSL failed to compile.
+ // NOTE(review): `function` comes from newFunctionWithName:, which by Cocoa naming
+ // conventions returns an owned (+1) reference; callers presumably must release it — confirm.
+ ShaderModule::MetalFunctionData ShaderModule::GetFunction(const char* functionName,
+ const PipelineLayout* layout) const {
+ // By default SPIRV-Cross will give MSL resources indices in increasing order.
+ // To make the MSL indices match the indices chosen in the PipelineLayout, we build
+ // a table of MSLResourceBinding to give to SPIRV-Cross
+ std::vector<spirv_cross::MSLResourceBinding> mslBindings;
- MTLSize ShaderModule::GetLocalWorkGroupSize(const std::string& entryPoint) const {
- auto size = compiler->get_entry_point(entryPoint).workgroup_size;
- return MTLSizeMake(size.x, size.y, size.z);
+ // Reserve buffer index 0 in every stage for the push constants buffer
+ // (the PipelineLayout starts numbering buffers at 1).
+ for (auto stage : IterateStages(kAllStages)) {
+ spirv_cross::MSLResourceBinding binding;
+ binding.stage = SpirvExecutionModelForStage(stage);
+ binding.desc_set = spirv_cross::kPushConstDescSet;
+ binding.binding = spirv_cross::kPushConstBinding;
+ binding.msl_buffer = 0;
+
+ mslBindings.push_back(binding);
+ }
+
+ // Create one resource binding entry per stage per binding.
+ for (uint32_t group : IterateBitSet(layout->GetBindGroupsLayoutMask())) {
+
+ const auto& bgInfo = layout->GetBindGroupLayout(group)->GetBindingInfo();
+ for (uint32_t binding : IterateBitSet(bgInfo.mask)) {
+
+ for (auto stage : IterateStages(bgInfo.visibilities[binding])) {
+ uint32_t index = layout->GetBindingIndexInfo(stage)[group][binding];
+
+ spirv_cross::MSLResourceBinding mslBinding;
+ mslBinding.stage = SpirvExecutionModelForStage(stage);
+ mslBinding.desc_set = group;
+ mslBinding.binding = binding;
+ // The PipelineLayout keeps separate buffer/texture/sampler counters, so `index`
+ // is already in the right namespace; presumably SPIRV-Cross only reads the field
+ // matching the resource's type.
+ mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler = index;
+
+ mslBindings.push_back(mslBinding);
+ }
+ }
+ }
+
+ MetalFunctionData result;
+
+ {
+ auto size = compiler->get_entry_point(functionName).workgroup_size;
+ result.localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z);
+ }
+
+ {
+ // SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing
+ // by default.
+ std::string msl = compiler->compile(nullptr, &mslBindings);
+ // Decode explicitly as UTF-8; "%s" in stringWithFormat: uses the system encoding.
+ NSString* mslSource = [NSString stringWithUTF8String:msl.c_str()];
+
+ auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
+ NSError *error = nil;
+ id<MTLLibrary> library = [mtlDevice newLibraryWithSource:mslSource options:nil error:&error];
+ // Test the returned library rather than `error`: Metal can report compiler warnings
+ // through `error` even when the library was created successfully.
+ if (library == nil) {
+ // TODO(cwallez@chromium.org): forward errors to caller
+ NSLog(@"MTLDevice newLibraryWithSource => %@", error);
+ } else if (error != nil) {
+ NSLog(@"MTLDevice newLibraryWithSource warnings => %@", error);
+ }
+ // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like clean_func_name:
+ // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
+ if (strcmp(functionName, "main") == 0) {
+ functionName = "main0";
+ }
+
+ NSString* name = [NSString stringWithUTF8String:functionName];
+ // Messaging a nil library yields nil, so `function` is nil when compilation failed.
+ result.function = [library newFunctionWithName:name];
+ [library release];
+ }
+
+ return result;
}
}