Descriptorize ComputePipeline
Change-Id: Ic9d7014ba44d927d7f9ddf81a8870432c68941e8
diff --git a/dawn.json b/dawn.json
index b5cfc2b..036f212 100644
--- a/dawn.json
+++ b/dawn.json
@@ -500,27 +500,13 @@
"compute pipeline": {
"category": "object"
},
- "compute pipeline builder": {
- "category": "object",
- "methods": [
- {
- "name": "get result",
- "returns": "compute pipeline"
- },
- {
- "name": "set layout",
- "args": [
- {"name": "layout", "type": "pipeline layout"}
- ]
- },
- {
- "name": "set stage",
- "args": [
- {"name": "stage", "type": "shader stage"},
- {"name": "module", "type": "shader module"},
- {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"}
- ]
- }
+ "compute pipeline descriptor": {
+ "category": "structure",
+ "extensible": true,
+ "members": [
+ {"name": "layout", "type": "pipeline layout"},
+ {"name": "module", "type": "shader module"},
+ {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"}
]
},
"device": {
@@ -569,8 +555,11 @@
"returns": "input state builder"
},
{
- "name": "create compute pipeline builder",
- "returns": "compute pipeline builder"
+ "name": "create compute pipeline",
+ "returns": "compute pipeline",
+ "args": [
+ {"name": "descriptor", "type": "compute pipeline descriptor", "annotation": "const*"}
+ ]
},
{
"name": "create render pipeline builder",
diff --git a/examples/ComputeBoids.cpp b/examples/ComputeBoids.cpp
index 99e351e..4a69710 100644
--- a/examples/ComputeBoids.cpp
+++ b/examples/ComputeBoids.cpp
@@ -231,10 +231,11 @@
dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
- updatePipeline = device.CreateComputePipelineBuilder()
- .SetLayout(pl)
- .SetStage(dawn::ShaderStage::Compute, module, "main")
- .GetResult();
+ dawn::ComputePipelineDescriptor csDesc;
+ csDesc.module = module.Clone();
+ csDesc.entryPoint = "main";
+ csDesc.layout = pl.Clone();
+ updatePipeline = device.CreateComputePipeline(&csDesc);
dawn::BufferView updateParamsView = updateParams.CreateBufferViewBuilder()
.SetExtent(0, sizeof(SimParams))
diff --git a/generator/main.py b/generator/main.py
index 5935e48..c88f2b4 100644
--- a/generator/main.py
+++ b/generator/main.py
@@ -166,7 +166,10 @@
for (member, m) in zip(members, struct.record['members']):
# TODO(kainino@chromium.org): More robust pointer/length handling?
if 'length' in m:
- member.length = members_by_name[m['length']]
+ if m['length'] == 'strlen':
+ member.length = 'strlen'
+ else:
+ member.length = members_by_name[m['length']]
def parse_json(json):
category_to_parser = {
diff --git a/generator/templates/apicpp.h b/generator/templates/apicpp.h
index f561867..f0b9aad 100644
--- a/generator/templates/apicpp.h
+++ b/generator/templates/apicpp.h
@@ -56,15 +56,7 @@
{% endfor %}
{% for type in by_category["structure"] %}
- struct {{as_cppType(type.name)}} {
- {% if type.extensible %}
- const void* nextInChain = nullptr;
- {% endif %}
- {% for member in type.members %}
- {{as_annotated_cppType(member)}};
- {% endfor %}
- };
-
+ struct {{as_cppType(type.name)}};
{% endfor %}
template<typename Derived, typename CType>
@@ -158,6 +150,18 @@
{% endfor %}
+ {% for type in by_category["structure"] %}
+ struct {{as_cppType(type.name)}} {
+ {% if type.extensible %}
+ const void* nextInChain = nullptr;
+ {% endif %}
+ {% for member in type.members %}
+ {{as_annotated_cppType(member)}};
+ {% endfor %}
+ };
+
+ {% endfor %}
+
} // namespace dawn
#endif // DAWN_DAWNCPP_H_
diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp
index 0ffcfe6..8829996 100644
--- a/src/dawn_native/ComputePipeline.cpp
+++ b/src/dawn_native/ComputePipeline.cpp
@@ -18,24 +18,31 @@
namespace dawn_native {
+ MaybeError ValidateComputePipelineDescriptor(DeviceBase*,
+ const ComputePipelineDescriptor* descriptor) {
+ DAWN_TRY_ASSERT(descriptor->nextInChain == nullptr, "nextInChain must be nullptr");
+
+ if (descriptor->entryPoint != std::string("main")) {
+ DAWN_RETURN_ERROR("Currently the entry point has to be main()");
+ }
+
+ if (descriptor->module->GetExecutionModel() != dawn::ShaderStage::Compute) {
+ DAWN_RETURN_ERROR("Setting module with wrong execution model");
+ }
+
+ if (!descriptor->module->IsCompatibleWithPipelineLayout(descriptor->layout)) {
+ DAWN_RETURN_ERROR("Stage not compatible with layout");
+ }
+
+ return {};
+ }
+
// ComputePipelineBase
- ComputePipelineBase::ComputePipelineBase(ComputePipelineBuilder* builder)
- : PipelineBase(builder) {
- if (GetStageMask() != dawn::ShaderStageBit::Compute) {
- builder->HandleError("Compute pipeline should have exactly a compute stage");
- return;
- }
- }
-
- // ComputePipelineBuilder
-
- ComputePipelineBuilder::ComputePipelineBuilder(DeviceBase* device)
- : Builder(device), PipelineBuilder(this) {
- }
-
- ComputePipelineBase* ComputePipelineBuilder::GetResultImpl() {
- return mDevice->CreateComputePipeline(this);
+ ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
+ const ComputePipelineDescriptor* descriptor)
+ : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
+ ExtractModuleData(dawn::ShaderStage::Compute, descriptor->module);
}
} // namespace dawn_native
diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h
index f81ab99..0e9e02d 100644
--- a/src/dawn_native/ComputePipeline.h
+++ b/src/dawn_native/ComputePipeline.h
@@ -19,17 +19,14 @@
namespace dawn_native {
+ class DeviceBase;
+
+ MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
+ const ComputePipelineDescriptor* descriptor);
+
class ComputePipelineBase : public RefCounted, public PipelineBase {
public:
- ComputePipelineBase(ComputePipelineBuilder* builder);
- };
-
- class ComputePipelineBuilder : public Builder<ComputePipelineBase>, public PipelineBuilder {
- public:
- ComputePipelineBuilder(DeviceBase* device);
-
- private:
- ComputePipelineBase* GetResultImpl() override;
+ ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor);
};
} // namespace dawn_native
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 217cf43..57f22dc 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -123,8 +123,15 @@
CommandBufferBuilder* DeviceBase::CreateCommandBufferBuilder() {
return new CommandBufferBuilder(this);
}
- ComputePipelineBuilder* DeviceBase::CreateComputePipelineBuilder() {
- return new ComputePipelineBuilder(this);
+ ComputePipelineBase* DeviceBase::CreateComputePipeline(
+ const ComputePipelineDescriptor* descriptor) {
+ ComputePipelineBase* result = nullptr;
+
+ if (ConsumedError(CreateComputePipelineInternal(&result, descriptor))) {
+ return nullptr;
+ }
+
+ return result;
}
DepthStencilStateBuilder* DeviceBase::CreateDepthStencilStateBuilder() {
return new DepthStencilStateBuilder(this);
@@ -223,6 +230,14 @@
return {};
}
+ MaybeError DeviceBase::CreateComputePipelineInternal(
+ ComputePipelineBase** result,
+ const ComputePipelineDescriptor* descriptor) {
+ DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
+ DAWN_TRY_ASSIGN(*result, CreateComputePipelineImpl(descriptor));
+ return {};
+ }
+
MaybeError DeviceBase::CreatePipelineLayoutInternal(
PipelineLayoutBase** result,
const PipelineLayoutDescriptor* descriptor) {
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index cb76645..723455d 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -47,7 +47,6 @@
virtual BlendStateBase* CreateBlendState(BlendStateBuilder* builder) = 0;
virtual BufferViewBase* CreateBufferView(BufferViewBuilder* builder) = 0;
virtual CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) = 0;
- virtual ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) = 0;
virtual DepthStencilStateBase* CreateDepthStencilState(
DepthStencilStateBuilder* builder) = 0;
virtual InputStateBase* CreateInputState(InputStateBuilder* builder) = 0;
@@ -83,7 +82,7 @@
BlendStateBuilder* CreateBlendStateBuilder();
BufferBase* CreateBuffer(const BufferDescriptor* descriptor);
CommandBufferBuilder* CreateCommandBufferBuilder();
- ComputePipelineBuilder* CreateComputePipelineBuilder();
+ ComputePipelineBase* CreateComputePipeline(const ComputePipelineDescriptor* descriptor);
DepthStencilStateBuilder* CreateDepthStencilStateBuilder();
InputStateBuilder* CreateInputStateBuilder();
PipelineLayoutBase* CreatePipelineLayout(const PipelineLayoutDescriptor* descriptor);
@@ -108,6 +107,8 @@
virtual ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) = 0;
virtual ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) = 0;
+ virtual ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) = 0;
virtual ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) = 0;
virtual ResultOrError<QueueBase*> CreateQueueImpl() = 0;
@@ -121,6 +122,8 @@
MaybeError CreateBindGroupLayoutInternal(BindGroupLayoutBase** result,
const BindGroupLayoutDescriptor* descriptor);
MaybeError CreateBufferInternal(BufferBase** result, const BufferDescriptor* descriptor);
+ MaybeError CreateComputePipelineInternal(ComputePipelineBase** result,
+ const ComputePipelineDescriptor* descriptor);
MaybeError CreatePipelineLayoutInternal(PipelineLayoutBase** result,
const PipelineLayoutDescriptor* descriptor);
MaybeError CreateQueueInternal(QueueBase** result);
diff --git a/src/dawn_native/PerStage.h b/src/dawn_native/PerStage.h
index 5361270..b1b67af 100644
--- a/src/dawn_native/PerStage.h
+++ b/src/dawn_native/PerStage.h
@@ -48,6 +48,11 @@
template <typename T>
class PerStage {
public:
+ PerStage() = default;
+ PerStage(const T& initialValue) {
+ mData.fill(initialValue);
+ }
+
T& operator[](dawn::ShaderStage stage) {
DAWN_ASSERT(static_cast<uint32_t>(stage) < kNumStages);
return mData[static_cast<uint32_t>(stage)];
diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp
index 4f27aa2..d22b9f0 100644
--- a/src/dawn_native/Pipeline.cpp
+++ b/src/dawn_native/Pipeline.cpp
@@ -24,8 +24,14 @@
// PipelineBase
- PipelineBase::PipelineBase(PipelineBuilder* builder)
- : mStageMask(builder->mStageMask), mLayout(std::move(builder->mLayout)) {
+ PipelineBase::PipelineBase(DeviceBase* device,
+ PipelineLayoutBase* layout,
+ dawn::ShaderStageBit stages)
+ : mStageMask(stages), mLayout(layout), mDevice(device) {
+ }
+
+ PipelineBase::PipelineBase(DeviceBase* device, PipelineBuilder* builder)
+ : mStageMask(builder->mStageMask), mLayout(std::move(builder->mLayout)), mDevice(device) {
if (!mLayout) {
PipelineLayoutDescriptor descriptor;
descriptor.numBindGroupLayouts = 0;
@@ -35,30 +41,32 @@
mLayout->Release();
}
- auto FillPushConstants = [](const ShaderModuleBase* module, PushConstantInfo* info) {
- const auto& moduleInfo = module->GetPushConstants();
- info->mask = moduleInfo.mask;
-
- for (uint32_t i = 0; i < moduleInfo.names.size(); i++) {
- uint32_t size = moduleInfo.sizes[i];
- if (size == 0) {
- continue;
- }
-
- for (uint32_t offset = 0; offset < size; offset++) {
- info->types[i + offset] = moduleInfo.types[i];
- }
- i += size - 1;
- }
- };
-
- for (auto stageBit : IterateStages(builder->mStageMask)) {
- if (!builder->mStages[stageBit].module->IsCompatibleWithPipelineLayout(mLayout.Get())) {
+ for (auto stage : IterateStages(builder->mStageMask)) {
+ if (!builder->mStages[stage].module->IsCompatibleWithPipelineLayout(mLayout.Get())) {
builder->GetParentBuilder()->HandleError("Stage not compatible with layout");
return;
}
- FillPushConstants(builder->mStages[stageBit].module.Get(), &mPushConstants[stageBit]);
+ ExtractModuleData(stage, builder->mStages[stage].module.Get());
+ }
+ }
+
+ void PipelineBase::ExtractModuleData(dawn::ShaderStage stage, ShaderModuleBase* module) {
+ PushConstantInfo* info = &mPushConstants[stage];
+
+ const auto& moduleInfo = module->GetPushConstants();
+ info->mask = moduleInfo.mask;
+
+ for (uint32_t i = 0; i < moduleInfo.names.size(); i++) {
+ uint32_t size = moduleInfo.sizes[i];
+ if (size == 0) {
+ continue;
+ }
+
+ for (uint32_t offset = 0; offset < size; offset++) {
+ info->types[i + offset] = moduleInfo.types[i];
+ }
+ i += size - 1;
}
}
@@ -75,6 +83,10 @@
return mLayout.Get();
}
+ DeviceBase* PipelineBase::GetDevice() const {
+ return mDevice;
+ }
+
// PipelineBuilder
PipelineBuilder::PipelineBuilder(BuilderBase* parentBuilder)
diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h
index 31b81be..0453e8d 100644
--- a/src/dawn_native/Pipeline.h
+++ b/src/dawn_native/Pipeline.h
@@ -39,7 +39,8 @@
class PipelineBase {
public:
- PipelineBase(PipelineBuilder* builder);
+ PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages);
+ PipelineBase(DeviceBase* device, PipelineBuilder* builder);
struct PushConstantInfo {
std::bitset<kMaxPushConstants> mask;
@@ -49,11 +50,16 @@
dawn::ShaderStageBit GetStageMask() const;
PipelineLayoutBase* GetLayout();
+ DeviceBase* GetDevice() const;
+
+ protected:
+ void ExtractModuleData(dawn::ShaderStage stage, ShaderModuleBase* module);
private:
dawn::ShaderStageBit mStageMask;
Ref<PipelineLayoutBase> mLayout;
PerStage<PushConstantInfo> mPushConstants;
+ DeviceBase* mDevice;
};
class PipelineBuilder {
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 3b5dbc6..56b2c9d 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -27,7 +27,7 @@
// RenderPipelineBase
RenderPipelineBase::RenderPipelineBase(RenderPipelineBuilder* builder)
- : PipelineBase(builder),
+ : PipelineBase(builder->mDevice, builder),
mDepthStencilState(std::move(builder->mDepthStencilState)),
mIndexFormat(builder->mIndexFormat),
mInputState(std::move(builder->mInputState)),
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 8599d04..edcbd93 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -22,8 +22,8 @@
namespace dawn_native { namespace d3d12 {
- ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
- : ComputePipelineBase(builder), mDevice(ToBackend(builder->GetDevice())) {
+ ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
+ : ComputePipelineBase(device, descriptor) {
uint32_t compileFlags = 0;
#if defined(_DEBUG)
// Enable better shader debugging with the graphics debugging tools.
@@ -32,33 +32,31 @@
// SPRIV-cross does matrix multiplication expecting row major matrices
compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
- const auto& module = ToBackend(builder->GetStageInfo(dawn::ShaderStage::Compute).module);
- const auto& entryPoint = builder->GetStageInfo(dawn::ShaderStage::Compute).entryPoint;
- const auto& hlslSource = module->GetHLSLSource();
+ const ShaderModule* module = ToBackend(descriptor->module);
+ const std::string& hlslSource = module->GetHLSLSource();
ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors;
- const PlatformFunctions* functions = ToBackend(builder->GetDevice())->GetFunctions();
+ const PlatformFunctions* functions = device->GetFunctions();
if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
- nullptr, entryPoint.c_str(), "cs_5_1", compileFlags, 0,
+ nullptr, descriptor->entryPoint, "cs_5_1", compileFlags, 0,
&compiledShader, &errors))) {
printf("%s\n", reinterpret_cast<char*>(errors->GetBufferPointer()));
ASSERT(false);
}
- D3D12_COMPUTE_PIPELINE_STATE_DESC descriptor = {};
- descriptor.pRootSignature = ToBackend(GetLayout())->GetRootSignature().Get();
- descriptor.CS.pShaderBytecode = compiledShader->GetBufferPointer();
- descriptor.CS.BytecodeLength = compiledShader->GetBufferSize();
+ D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
+ d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature().Get();
+ d3dDesc.CS.pShaderBytecode = compiledShader->GetBufferPointer();
+ d3dDesc.CS.BytecodeLength = compiledShader->GetBufferSize();
- Device* device = ToBackend(builder->GetDevice());
- device->GetD3D12Device()->CreateComputePipelineState(&descriptor,
+ device->GetD3D12Device()->CreateComputePipelineState(&d3dDesc,
IID_PPV_ARGS(&mPipelineState));
}
ComputePipeline::~ComputePipeline() {
- mDevice->ReferenceUntilUnused(mPipelineState);
+ ToBackend(GetDevice())->ReferenceUntilUnused(mPipelineState);
}
ComPtr<ID3D12PipelineState> ComputePipeline::GetPipelineState() {
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.h b/src/dawn_native/d3d12/ComputePipelineD3D12.h
index e3335c3..7b1af9f 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.h
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.h
@@ -25,14 +25,13 @@
class ComputePipeline : public ComputePipelineBase {
public:
- ComputePipeline(ComputePipelineBuilder* builder);
+ ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
~ComputePipeline();
ComPtr<ID3D12PipelineState> GetPipelineState();
private:
ComPtr<ID3D12PipelineState> mPipelineState;
- Device* mDevice = nullptr;
};
}} // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index a51812e..0dc3c88 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -301,8 +301,9 @@
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
- ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
- return new ComputePipeline(builder);
+ ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) {
+ return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index cbb3523..97b2289 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -43,7 +43,6 @@
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
- ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@@ -79,6 +78,8 @@
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
+ ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
diff --git a/src/dawn_native/metal/ComputePipelineMTL.h b/src/dawn_native/metal/ComputePipelineMTL.h
index 5f216a8..6f3aca9 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.h
+++ b/src/dawn_native/metal/ComputePipelineMTL.h
@@ -21,9 +21,11 @@
namespace dawn_native { namespace metal {
+ class Device;
+
class ComputePipeline : public ComputePipelineBase {
public:
- ComputePipeline(ComputePipelineBuilder* builder);
+ ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
~ComputePipeline();
void Encode(id<MTLComputeCommandEncoder> encoder);
diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm
index 026f524..5c6eaa6 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.mm
+++ b/src/dawn_native/metal/ComputePipelineMTL.mm
@@ -19,22 +19,22 @@
namespace dawn_native { namespace metal {
- ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
- : ComputePipelineBase(builder) {
- auto mtlDevice = ToBackend(builder->GetDevice())->GetMTLDevice();
+ ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
+ : ComputePipelineBase(device, descriptor) {
+ auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
- const auto& module = ToBackend(builder->GetStageInfo(dawn::ShaderStage::Compute).module);
- const auto& entryPoint = builder->GetStageInfo(dawn::ShaderStage::Compute).entryPoint;
+ const auto& module = ToBackend(descriptor->module);
+ const char* entryPoint = descriptor->entryPoint;
- auto compilationData = module->GetFunction(entryPoint.c_str(), dawn::ShaderStage::Compute,
- ToBackend(GetLayout()));
+ auto compilationData =
+ module->GetFunction(entryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout()));
NSError* error = nil;
mMtlComputePipelineState =
[mtlDevice newComputePipelineStateWithFunction:compilationData.function error:&error];
if (error != nil) {
NSLog(@" error => %@", error);
- builder->HandleError("Error creating pipeline state");
+ GetDevice()->HandleError("Error creating pipeline state");
return;
}
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index dbcb941..5f93d15 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -39,7 +39,6 @@
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
- ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@@ -63,6 +62,8 @@
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
+ ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 6ecd6b1..79a326d 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -97,8 +97,9 @@
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
- ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
- return new ComputePipeline(builder);
+ ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) {
+ return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);
diff --git a/src/dawn_native/null/NullBackend.cpp b/src/dawn_native/null/NullBackend.cpp
index 8331614..11f13c2 100644
--- a/src/dawn_native/null/NullBackend.cpp
+++ b/src/dawn_native/null/NullBackend.cpp
@@ -52,8 +52,9 @@
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
- ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
- return new ComputePipeline(builder);
+ ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) {
+ return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);
diff --git a/src/dawn_native/null/NullBackend.h b/src/dawn_native/null/NullBackend.h
index 3b206a3..1a34814 100644
--- a/src/dawn_native/null/NullBackend.h
+++ b/src/dawn_native/null/NullBackend.h
@@ -99,7 +99,6 @@
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
- ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@@ -117,6 +116,8 @@
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
+ ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
diff --git a/src/dawn_native/opengl/ComputePipelineGL.cpp b/src/dawn_native/opengl/ComputePipelineGL.cpp
index f80b2bb..815e4d7 100644
--- a/src/dawn_native/opengl/ComputePipelineGL.cpp
+++ b/src/dawn_native/opengl/ComputePipelineGL.cpp
@@ -14,10 +14,16 @@
#include "dawn_native/opengl/ComputePipelineGL.h"
+#include "dawn_native/opengl/DeviceGL.h"
+
namespace dawn_native { namespace opengl {
- ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
- : ComputePipelineBase(builder), PipelineGL(this, builder) {
+ ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
+ : ComputePipelineBase(device, descriptor) {
+ PerStage<const ShaderModule*> modules(nullptr);
+ modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->module);
+
+ PipelineGL::Initialize(ToBackend(descriptor->layout), modules);
}
void ComputePipeline::ApplyNow() {
diff --git a/src/dawn_native/opengl/ComputePipelineGL.h b/src/dawn_native/opengl/ComputePipelineGL.h
index 654353c..12856ca 100644
--- a/src/dawn_native/opengl/ComputePipelineGL.h
+++ b/src/dawn_native/opengl/ComputePipelineGL.h
@@ -23,9 +23,11 @@
namespace dawn_native { namespace opengl {
+ class Device;
+
class ComputePipeline : public ComputePipelineBase, public PipelineGL {
public:
- ComputePipeline(ComputePipelineBuilder* builder);
+ ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
void ApplyNow();
};
diff --git a/src/dawn_native/opengl/DeviceGL.cpp b/src/dawn_native/opengl/DeviceGL.cpp
index a9dcd55..8066ebd 100644
--- a/src/dawn_native/opengl/DeviceGL.cpp
+++ b/src/dawn_native/opengl/DeviceGL.cpp
@@ -65,8 +65,9 @@
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
- ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
- return new ComputePipeline(builder);
+ ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) {
+ return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);
diff --git a/src/dawn_native/opengl/DeviceGL.h b/src/dawn_native/opengl/DeviceGL.h
index dfe2644..87c27cf 100644
--- a/src/dawn_native/opengl/DeviceGL.h
+++ b/src/dawn_native/opengl/DeviceGL.h
@@ -36,7 +36,6 @@
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
- ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@@ -51,6 +50,8 @@
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
+ ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
diff --git a/src/dawn_native/opengl/PipelineGL.cpp b/src/dawn_native/opengl/PipelineGL.cpp
index 3148d06..c3cdc33 100644
--- a/src/dawn_native/opengl/PipelineGL.cpp
+++ b/src/dawn_native/opengl/PipelineGL.cpp
@@ -43,7 +43,11 @@
} // namespace
- PipelineGL::PipelineGL(PipelineBase* parent, PipelineBuilder* builder) {
+ PipelineGL::PipelineGL() {
+ }
+
+ void PipelineGL::Initialize(const PipelineLayout* layout,
+ const PerStage<const ShaderModule*>& modules) {
auto CreateShader = [](GLenum type, const char* source) -> GLuint {
GLuint shader = glCreateShader(type);
glShaderSource(shader, 1, &source, nullptr);
@@ -91,10 +95,15 @@
mProgram = glCreateProgram();
- for (auto stage : IterateStages(parent->GetStageMask())) {
- const ShaderModule* module = ToBackend(builder->GetStageInfo(stage).module.Get());
+ dawn::ShaderStageBit activeStages = dawn::ShaderStageBit::None;
+ for (dawn::ShaderStage stage : IterateStages(kAllStages)) {
+ if (modules[stage] != nullptr) {
+ activeStages |= StageBit(stage);
+ }
+ }
- GLuint shader = CreateShader(GLShaderType(stage), module->GetSource());
+ for (dawn::ShaderStage stage : IterateStages(activeStages)) {
+ GLuint shader = CreateShader(GLShaderType(stage), modules[stage]->GetSource());
glAttachShader(mProgram, shader);
}
@@ -114,16 +123,14 @@
}
}
- for (auto stage : IterateStages(parent->GetStageMask())) {
- const ShaderModule* module = ToBackend(builder->GetStageInfo(stage).module.Get());
- FillPushConstants(module, &mGlPushConstants[stage], mProgram);
+ for (dawn::ShaderStage stage : IterateStages(activeStages)) {
+ FillPushConstants(modules[stage], &mGlPushConstants[stage], mProgram);
}
glUseProgram(mProgram);
// The uniforms are part of the program state so we can pre-bind buffer units, texture units
// etc.
- const auto& layout = ToBackend(parent->GetLayout());
const auto& indices = layout->GetBindingIndexInfo();
for (uint32_t group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
@@ -159,10 +166,8 @@
// Compute links between stages for combined samplers, then bind them to texture units
{
std::set<CombinedSampler> combinedSamplersSet;
- for (auto stage : IterateStages(parent->GetStageMask())) {
- const auto& module = ToBackend(builder->GetStageInfo(stage).module);
-
- for (const auto& combined : module->GetCombinedSamplerInfo()) {
+ for (dawn::ShaderStage stage : IterateStages(activeStages)) {
+ for (const auto& combined : modules[stage]->GetCombinedSamplerInfo()) {
combinedSamplersSet.insert(combined);
}
}
diff --git a/src/dawn_native/opengl/PipelineGL.h b/src/dawn_native/opengl/PipelineGL.h
index d216c57..400e0f5 100644
--- a/src/dawn_native/opengl/PipelineGL.h
+++ b/src/dawn_native/opengl/PipelineGL.h
@@ -25,11 +25,14 @@
class Device;
class PersistentPipelineState;
+ class PipelineLayout;
class ShaderModule;
class PipelineGL {
public:
- PipelineGL(PipelineBase* parent, PipelineBuilder* builder);
+ PipelineGL();
+
+ void Initialize(const PipelineLayout* layout, const PerStage<const ShaderModule*>& modules);
using GLPushConstantInfo = std::array<GLint, kMaxPushConstants>;
using BindingLocations =
diff --git a/src/dawn_native/opengl/RenderPipelineGL.cpp b/src/dawn_native/opengl/RenderPipelineGL.cpp
index d9911be..f6faca0 100644
--- a/src/dawn_native/opengl/RenderPipelineGL.cpp
+++ b/src/dawn_native/opengl/RenderPipelineGL.cpp
@@ -43,8 +43,13 @@
RenderPipeline::RenderPipeline(RenderPipelineBuilder* builder)
: RenderPipelineBase(builder),
- PipelineGL(this, builder),
mGlPrimitiveTopology(GLPrimitiveTopology(GetPrimitiveTopology())) {
+ PerStage<const ShaderModule*> modules(nullptr);
+ for (dawn::ShaderStage stage : IterateStages(GetStageMask())) {
+ modules[stage] = ToBackend(builder->GetStageInfo(stage).module.Get());
+ }
+
+ PipelineGL::Initialize(ToBackend(GetLayout()), modules);
}
GLenum RenderPipeline::GetGLPrimitiveTopology() const {
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp
index 7e67f62..06948b3 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp
@@ -21,34 +21,33 @@
namespace dawn_native { namespace vulkan {
- ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
- : ComputePipelineBase(builder), mDevice(ToBackend(builder->GetDevice())) {
+ ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
+ : ComputePipelineBase(device, descriptor) {
VkComputePipelineCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
createInfo.pNext = nullptr;
createInfo.flags = 0;
- createInfo.layout = ToBackend(GetLayout())->GetHandle();
+ createInfo.layout = ToBackend(descriptor->layout)->GetHandle();
createInfo.basePipelineHandle = VK_NULL_HANDLE;
createInfo.basePipelineIndex = -1;
- const auto& stageInfo = builder->GetStageInfo(dawn::ShaderStage::Compute);
createInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
createInfo.stage.pNext = nullptr;
createInfo.stage.flags = 0;
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
- createInfo.stage.module = ToBackend(stageInfo.module)->GetHandle();
- createInfo.stage.pName = stageInfo.entryPoint.c_str();
+ createInfo.stage.module = ToBackend(descriptor->module)->GetHandle();
+ createInfo.stage.pName = descriptor->entryPoint;
createInfo.stage.pSpecializationInfo = nullptr;
- if (mDevice->fn.CreateComputePipelines(mDevice->GetVkDevice(), VK_NULL_HANDLE, 1,
- &createInfo, nullptr, &mHandle) != VK_SUCCESS) {
+ if (device->fn.CreateComputePipelines(device->GetVkDevice(), VK_NULL_HANDLE, 1, &createInfo,
+ nullptr, &mHandle) != VK_SUCCESS) {
ASSERT(false);
}
}
ComputePipeline::~ComputePipeline() {
if (mHandle != VK_NULL_HANDLE) {
- mDevice->GetFencedDeleter()->DeleteWhenUnused(mHandle);
+ ToBackend(GetDevice())->GetFencedDeleter()->DeleteWhenUnused(mHandle);
mHandle = VK_NULL_HANDLE;
}
}
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.h b/src/dawn_native/vulkan/ComputePipelineVk.h
index 340d9fe..d1b589c 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.h
+++ b/src/dawn_native/vulkan/ComputePipelineVk.h
@@ -25,14 +25,13 @@
class ComputePipeline : public ComputePipelineBase {
public:
- ComputePipeline(ComputePipelineBuilder* builder);
+ ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
~ComputePipeline();
VkPipeline GetHandle() const;
private:
VkPipeline mHandle = VK_NULL_HANDLE;
- Device* mDevice = nullptr;
};
}} // namespace dawn_native::vulkan
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index e9a9183..007dead 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -236,8 +236,9 @@
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
- ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
- return new ComputePipeline(builder);
+ ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) {
+ return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index 764b125..5e5ad2e 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -67,7 +67,6 @@
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
- ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@@ -82,6 +81,8 @@
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
+ ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
+ const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
diff --git a/src/tests/end2end/ComputeCopyStorageBufferTests.cpp b/src/tests/end2end/ComputeCopyStorageBufferTests.cpp
index 15ec8b3..834185e 100644
--- a/src/tests/end2end/ComputeCopyStorageBufferTests.cpp
+++ b/src/tests/end2end/ComputeCopyStorageBufferTests.cpp
@@ -37,10 +37,12 @@
// Set up shader and pipeline
auto module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, shader);
auto pl = utils::MakeBasicPipelineLayout(device, &bgl);
- auto pipeline = device.CreateComputePipelineBuilder()
- .SetLayout(pl)
- .SetStage(dawn::ShaderStage::Compute, module, "main")
- .GetResult();
+
+ dawn::ComputePipelineDescriptor csDesc;
+ csDesc.module = module.Clone();
+ csDesc.entryPoint = "main";
+ csDesc.layout = pl.Clone();
+ dawn::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
// Set up src storage buffer
dawn::BufferDescriptor srcDesc;
diff --git a/src/tests/end2end/PushConstantTests.cpp b/src/tests/end2end/PushConstantTests.cpp
index baac0be..0417ddf 100644
--- a/src/tests/end2end/PushConstantTests.cpp
+++ b/src/tests/end2end/PushConstantTests.cpp
@@ -145,10 +145,11 @@
})").c_str()
);
- return device.CreateComputePipelineBuilder()
- .SetLayout(pl)
- .SetStage(dawn::ShaderStage::Compute, module, "main")
- .GetResult();
+ dawn::ComputePipelineDescriptor descriptor;
+ descriptor.module = module.Clone();
+ descriptor.entryPoint = "main";
+ descriptor.layout = pl.Clone();
+ return device.CreateComputePipeline(&descriptor);
}
dawn::PipelineLayout MakeEmptyLayout() {