Propagating errors out of GetFunction in MTL backend
BUG=dawn:303
Change-Id: Iff1903aecae4c043b222208b3eab5efdf9774b52
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/14501
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/metal/ComputePipelineMTL.h b/src/dawn_native/metal/ComputePipelineMTL.h
index 71b5ba3..6ff1d01 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.h
+++ b/src/dawn_native/metal/ComputePipelineMTL.h
@@ -25,7 +25,8 @@
class ComputePipeline : public ComputePipelineBase {
public:
- ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
+ static ResultOrError<ComputePipeline*> Create(Device* device,
+ const ComputePipelineDescriptor* descriptor);
~ComputePipeline();
void Encode(id<MTLComputeCommandEncoder> encoder);
@@ -33,6 +34,9 @@
bool RequiresStorageBufferLength() const;
private:
+ using ComputePipelineBase::ComputePipelineBase;
+ MaybeError Initialize(const ComputePipelineDescriptor* descriptor);
+
id<MTLComputePipelineState> mMtlComputePipelineState = nil;
MTLSize mLocalWorkgroupSize;
bool mRequiresStorageBufferLength;
diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm
index fd72364..0a66866 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.mm
+++ b/src/dawn_native/metal/ComputePipelineMTL.mm
@@ -19,27 +19,37 @@
namespace dawn_native { namespace metal {
- ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
- : ComputePipelineBase(device, descriptor) {
+ // static
+ ResultOrError<ComputePipeline*> ComputePipeline::Create(
+ Device* device,
+ const ComputePipelineDescriptor* descriptor) {
+ std::unique_ptr<ComputePipeline> pipeline =
+ std::make_unique<ComputePipeline>(device, descriptor);
+ DAWN_TRY(pipeline->Initialize(descriptor));
+ return pipeline.release();
+ }
+
+ MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) {
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
const ShaderModule* computeModule = ToBackend(descriptor->computeStage.module);
const char* computeEntryPoint = descriptor->computeStage.entryPoint;
- ShaderModule::MetalFunctionData computeData = computeModule->GetFunction(
- computeEntryPoint, SingleShaderStage::Compute, ToBackend(GetLayout()));
+ ShaderModule::MetalFunctionData computeData;
+ DAWN_TRY(computeModule->GetFunction(computeEntryPoint, SingleShaderStage::Compute,
+ ToBackend(GetLayout()), &computeData));
NSError* error = nil;
mMtlComputePipelineState =
[mtlDevice newComputePipelineStateWithFunction:computeData.function error:&error];
if (error != nil) {
NSLog(@" error => %@", error);
- GetDevice()->HandleError(wgpu::ErrorType::DeviceLost, "Error creating pipeline state");
- return;
+ return DAWN_DEVICE_LOST_ERROR("Error creating pipeline state");
}
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
mLocalWorkgroupSize = computeData.localWorkgroupSize;
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
+ return {};
}
ComputePipeline::~ComputePipeline() {
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 504e6b3..0249df1 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -111,7 +111,7 @@
}
ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) {
- return new ComputePipeline(this, descriptor);
+ return ComputePipeline::Create(this, descriptor);
}
ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) {
@@ -122,7 +122,7 @@
}
ResultOrError<RenderPipelineBase*> Device::CreateRenderPipelineImpl(
const RenderPipelineDescriptor* descriptor) {
- return new RenderPipeline(this, descriptor);
+ return RenderPipeline::Create(this, descriptor);
}
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor);
diff --git a/src/dawn_native/metal/RenderPipelineMTL.h b/src/dawn_native/metal/RenderPipelineMTL.h
index bce358b..47fc048 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.h
+++ b/src/dawn_native/metal/RenderPipelineMTL.h
@@ -25,7 +25,8 @@
class RenderPipeline : public RenderPipelineBase {
public:
- RenderPipeline(Device* device, const RenderPipelineDescriptor* descriptor);
+ static ResultOrError<RenderPipeline*> Create(Device* device,
+ const RenderPipelineDescriptor* descriptor);
~RenderPipeline();
MTLIndexType GetMTLIndexType() const;
@@ -44,6 +45,9 @@
wgpu::ShaderStage GetStagesRequiringStorageBufferLength() const;
private:
+ using RenderPipelineBase::RenderPipelineBase;
+ MaybeError Initialize(const RenderPipelineDescriptor* descriptor);
+
MTLVertexDescriptor* MakeVertexDesc();
MTLIndexType mMtlIndexType;
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 3c4d852..c0e7af7 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -311,20 +311,31 @@
} // anonymous namespace
- RenderPipeline::RenderPipeline(Device* device, const RenderPipelineDescriptor* descriptor)
- : RenderPipelineBase(device, descriptor),
- mMtlIndexType(MTLIndexFormat(GetVertexStateDescriptor()->indexFormat)),
- mMtlPrimitiveTopology(MTLPrimitiveTopology(GetPrimitiveTopology())),
- mMtlFrontFace(MTLFrontFace(GetFrontFace())),
- mMtlCullMode(ToMTLCullMode(GetCullMode())) {
- auto mtlDevice = device->GetMTLDevice();
+ // static
+ ResultOrError<RenderPipeline*> RenderPipeline::Create(
+ Device* device,
+ const RenderPipelineDescriptor* descriptor) {
+ std::unique_ptr<RenderPipeline> pipeline =
+ std::make_unique<RenderPipeline>(device, descriptor);
+ DAWN_TRY(pipeline->Initialize(descriptor));
+ return pipeline.release();
+ }
+
+ MaybeError RenderPipeline::Initialize(const RenderPipelineDescriptor* descriptor) {
+ mMtlIndexType = MTLIndexFormat(GetVertexStateDescriptor()->indexFormat);
+ mMtlPrimitiveTopology = MTLPrimitiveTopology(GetPrimitiveTopology());
+ mMtlFrontFace = MTLFrontFace(GetFrontFace());
+ mMtlCullMode = ToMTLCullMode(GetCullMode());
+ auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
MTLRenderPipelineDescriptor* descriptorMTL = [MTLRenderPipelineDescriptor new];
const ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
- ShaderModule::MetalFunctionData vertexData = vertexModule->GetFunction(
- vertexEntryPoint, SingleShaderStage::Vertex, ToBackend(GetLayout()));
+ ShaderModule::MetalFunctionData vertexData;
+ DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex,
+ ToBackend(GetLayout()), &vertexData));
+
descriptorMTL.vertexFunction = vertexData.function;
if (vertexData.needsStorageBufferLength) {
mStagesRequiringStorageBufferLength |= wgpu::ShaderStage::Vertex;
@@ -332,8 +343,10 @@
const ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module);
const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint;
- ShaderModule::MetalFunctionData fragmentData = fragmentModule->GetFunction(
- fragmentEntryPoint, SingleShaderStage::Fragment, ToBackend(GetLayout()));
+ ShaderModule::MetalFunctionData fragmentData;
+ DAWN_TRY(fragmentModule->GetFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
+ ToBackend(GetLayout()), &fragmentData));
+
descriptorMTL.fragmentFunction = fragmentData.function;
if (fragmentData.needsStorageBufferLength) {
mStagesRequiringStorageBufferLength |= wgpu::ShaderStage::Fragment;
@@ -372,9 +385,7 @@
[descriptorMTL release];
if (error != nil) {
NSLog(@" error => %@", error);
- device->HandleError(wgpu::ErrorType::DeviceLost,
- "Error creating rendering pipeline state");
- return;
+ return DAWN_DEVICE_LOST_ERROR("Error creating rendering pipeline state");
}
}
@@ -385,6 +396,8 @@
MakeDepthStencilDesc(GetDepthStencilStateDescriptor());
mMtlDepthStencilState = [mtlDevice newDepthStencilStateWithDescriptor:depthStencilDesc];
[depthStencilDesc release];
+
+ return {};
}
RenderPipeline::~RenderPipeline() {
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index e259b69..45df04f 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -43,9 +43,10 @@
[function release];
}
};
- MetalFunctionData GetFunction(const char* functionName,
- SingleShaderStage functionStage,
- const PipelineLayout* layout) const;
+ MaybeError GetFunction(const char* functionName,
+ SingleShaderStage functionStage,
+ const PipelineLayout* layout,
+ MetalFunctionData* out) const;
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 79a9db5..5248b7d 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -92,17 +92,22 @@
return {};
}
- ShaderModule::MetalFunctionData ShaderModule::GetFunction(const char* functionName,
- SingleShaderStage functionStage,
- const PipelineLayout* layout) const {
+ MaybeError ShaderModule::GetFunction(const char* functionName,
+ SingleShaderStage functionStage,
+ const PipelineLayout* layout,
+ ShaderModule::MetalFunctionData* out) const {
+ ASSERT(!IsError());
+ ASSERT(out);
std::unique_ptr<spirv_cross::CompilerMSL> compiler_impl;
spirv_cross::CompilerMSL* compiler;
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
// Initializing the compiler is needed every call, because this method uses reflection
// to mutate the compiler's IR.
- mSpvcContext.InitializeForMsl(mSpirv.data(), mSpirv.size(), GetMSLCompileOptions());
- // TODO(rharrison): Handle initialize failing
-
+ if (mSpvcContext.InitializeForMsl(mSpirv.data(), mSpirv.size(),
+ GetMSLCompileOptions()) !=
+ shaderc_spvc_status_success) {
+ return DAWN_DEVICE_LOST_ERROR("Unable to initialize instance of spvc");
+ }
compiler = reinterpret_cast<spirv_cross::CompilerMSL*>(mSpvcContext.GetCompiler());
} else {
// If these options are changed, the values in DawnSPIRVCrossMSLFastFuzzer.cpp need to
@@ -147,12 +152,10 @@
}
}
- MetalFunctionData result;
-
{
spv::ExecutionModel executionModel = SpirvExecutionModelForStage(functionStage);
auto size = compiler->get_entry_point(functionName, executionModel).workgroup_size;
- result.localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z);
+ out->localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z);
}
{
@@ -167,9 +170,14 @@
options:nil
error:&error];
if (error != nil) {
- // TODO(cwallez@chromium.org): forward errors to caller
+ // TODO(cwallez@chromium.org): Switch that NSLog to use dawn::InfoLog or even be
+ // folded in the DAWN_VALIDATION_ERROR
NSLog(@"MTLDevice newLibraryWithSource => %@", error);
+ if (error.code != MTLLibraryErrorCompileWarning) {
+ return DAWN_VALIDATION_ERROR("Unable to create library object");
+ }
}
+
// TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like
// clean_func_name:
// https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
@@ -178,13 +186,13 @@
}
NSString* name = [NSString stringWithFormat:@"%s", functionName];
- result.function = [library newFunctionWithName:name];
+ out->function = [library newFunctionWithName:name];
[library release];
}
- result.needsStorageBufferLength = compiler->needs_buffer_size_buffer();
+ out->needsStorageBufferLength = compiler->needs_buffer_size_buffer();
- return result;
+ return {};
}
}} // namespace dawn_native::metal
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index 27aab13..2739902 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -139,14 +139,16 @@
wgpu::ShaderModule module =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
#version 450
+ shared uint i;
void main() {
- int i = 0;
+ i = 0;
})");
wgpu::ShaderModule sameModule =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
#version 450
+ shared uint i;
void main() {
- int i = 0;
+ i = 0;
})");
wgpu::ShaderModule otherModule =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
@@ -195,8 +197,9 @@
desc.computeStage.module =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
#version 450
+ shared uint i;
void main() {
- int i = 0;
+ i = 0;
})");
desc.layout = pl;
@@ -311,8 +314,9 @@
wgpu::ShaderModule otherModule =
utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment, R"(
#version 450
+ layout (location = 0) out vec4 color;
void main() {
- int i = 0;
+ color = vec4(0.0);
})");
EXPECT_NE(module.Get(), otherModule.Get());