Descriptorize ShaderModule
Change-Id: Ic79d00380f583485de0fb05bd47b1f869919ebe6
diff --git a/dawn.json b/dawn.json
index 6a37e4f..ec2d54e 100644
--- a/dawn.json
+++ b/dawn.json
@@ -596,8 +596,11 @@
]
},
{
- "name": "create shader module builder",
- "returns": "shader module builder"
+ "name": "create shader module",
+ "returns": "shader module",
+ "args": [
+ {"name": "descriptor", "type": "shader module descriptor", "annotation": "const*"}
+ ]
},
{
"name": "create swap chain builder",
@@ -910,20 +913,12 @@
"shader module": {
"category": "object"
},
- "shader module builder": {
- "category": "object",
- "methods": [
- {
- "name": "get result",
- "returns": "shader module"
- },
- {
- "name": "set source",
- "args": [
- {"name": "code size", "type": "uint32_t"},
- {"name": "code", "type": "uint32_t", "annotation": "const*", "length": "code size"}
- ]
- }
+ "shader module descriptor": {
+ "category": "structure",
+ "extensible": true,
+ "members": [
+ {"name": "code size", "type": "uint32_t"},
+ {"name": "code", "type": "uint32_t", "annotation": "const*", "length": "code size"}
]
},
"shader stage": {
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index a5da9fe..4bd782d 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -160,8 +160,14 @@
return result;
}
- ShaderModuleBuilder* DeviceBase::CreateShaderModuleBuilder() {
- return new ShaderModuleBuilder(this);
+ ShaderModuleBase* DeviceBase::CreateShaderModule(const ShaderModuleDescriptor* descriptor) {
+ ShaderModuleBase* result = nullptr;
+
+ if (ConsumedError(CreateShaderModuleInternal(&result, descriptor))) {
+ return nullptr;
+ }
+
+ return result;
}
SwapChainBuilder* DeviceBase::CreateSwapChainBuilder() {
return new SwapChainBuilder(this);
@@ -219,6 +225,13 @@
return {};
}
+ MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
+ const ShaderModuleDescriptor* descriptor) {
+ DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
+ DAWN_TRY_ASSIGN(*result, CreateShaderModuleImpl(descriptor));
+ return {};
+ }
+
// Other implementation details
void DeviceBase::ConsumeError(ErrorData* error) {
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index fd154e7..9ff29ed 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -55,7 +55,6 @@
virtual RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) = 0;
virtual RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) = 0;
- virtual ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) = 0;
virtual SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) = 0;
virtual TextureBase* CreateTexture(TextureBuilder* builder) = 0;
virtual TextureViewBase* CreateTextureView(TextureViewBuilder* builder) = 0;
@@ -94,7 +93,7 @@
RenderPassDescriptorBuilder* CreateRenderPassDescriptorBuilder();
RenderPipelineBuilder* CreateRenderPipelineBuilder();
SamplerBase* CreateSampler(const SamplerDescriptor* descriptor);
- ShaderModuleBuilder* CreateShaderModuleBuilder();
+ ShaderModuleBase* CreateShaderModule(const ShaderModuleDescriptor* descriptor);
SwapChainBuilder* CreateSwapChainBuilder();
TextureBuilder* CreateTextureBuilder();
@@ -111,6 +110,8 @@
virtual ResultOrError<QueueBase*> CreateQueueImpl() = 0;
virtual ResultOrError<SamplerBase*> CreateSamplerImpl(
const SamplerDescriptor* descriptor) = 0;
+ virtual ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) = 0;
MaybeError CreateBindGroupLayoutInternal(BindGroupLayoutBase** result,
const BindGroupLayoutDescriptor* descriptor);
@@ -118,6 +119,8 @@
const PipelineLayoutDescriptor* descriptor);
MaybeError CreateQueueInternal(QueueBase** result);
MaybeError CreateSamplerInternal(SamplerBase** result, const SamplerDescriptor* descriptor);
+ MaybeError CreateShaderModuleInternal(ShaderModuleBase** result,
+ const ShaderModuleDescriptor* descriptor);
void ConsumeError(ErrorData* error);
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 193ebf0..db7119c 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -23,7 +23,15 @@
namespace dawn_native {
- ShaderModuleBase::ShaderModuleBase(ShaderModuleBuilder* builder) : mDevice(builder->mDevice) {
+ MaybeError ValidateShaderModuleDescriptor(DeviceBase*, const ShaderModuleDescriptor*) {
+ // TODO(cwallez@chromium.org): Use spirv-val to check the module is well-formed
+ return {};
+ }
+
+ // ShaderModuleBase
+
+ ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor*)
+ : mDevice(device) {
}
DeviceBase* ShaderModuleBase::GetDevice() const {
@@ -218,24 +226,4 @@
return true;
}
- ShaderModuleBuilder::ShaderModuleBuilder(DeviceBase* device) : Builder(device) {
- }
-
- std::vector<uint32_t> ShaderModuleBuilder::AcquireSpirv() {
- return std::move(mSpirv);
- }
-
- ShaderModuleBase* ShaderModuleBuilder::GetResultImpl() {
- if (mSpirv.size() == 0) {
- HandleError("Shader module needs to have the source set");
- return nullptr;
- }
-
- return mDevice->CreateShaderModule(this);
- }
-
- void ShaderModuleBuilder::SetSource(uint32_t codeSize, const uint32_t* code) {
- mSpirv.assign(code, code + codeSize);
- }
-
} // namespace dawn_native
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index c64cb62..c9a6085 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -17,6 +17,7 @@
#include "common/Constants.h"
#include "dawn_native/Builder.h"
+#include "dawn_native/Error.h"
#include "dawn_native/Forward.h"
#include "dawn_native/RefCounted.h"
@@ -32,9 +33,12 @@
namespace dawn_native {
+ MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
+ const ShaderModuleDescriptor* descriptor);
+
class ShaderModuleBase : public RefCounted {
public:
- ShaderModuleBase(ShaderModuleBuilder* builder);
+ ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
DeviceBase* GetDevice() const;
@@ -75,23 +79,6 @@
dawn::ShaderStage mExecutionModel;
};
- class ShaderModuleBuilder : public Builder<ShaderModuleBase> {
- public:
- ShaderModuleBuilder(DeviceBase* device);
-
- std::vector<uint32_t> AcquireSpirv();
-
- // Dawn API
- void SetSource(uint32_t codeSize, const uint32_t* code);
-
- private:
- friend class ShaderModuleBase;
-
- ShaderModuleBase* GetResultImpl() override;
-
- std::vector<uint32_t> mSpirv;
- };
-
} // namespace dawn_native
#endif // DAWNNATIVE_SHADERMODULE_H_
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index 249daa3..dddbc26 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -304,8 +304,9 @@
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor);
}
- ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) {
- return new ShaderModule(builder);
+ ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) {
+ return new ShaderModule(this, descriptor);
}
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder);
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index 4041cb9..5574153 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -49,7 +49,6 @@
RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
- ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@@ -83,6 +82,8 @@
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
+ ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) override;
uint64_t mSerial = 0;
ComPtr<ID3D12Fence> mFence;
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index c1f096d..5a9c829 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -15,6 +15,7 @@
#include "dawn_native/d3d12/ShaderModuleD3D12.h"
#include "common/Assert.h"
+#include "dawn_native/d3d12/DeviceD3D12.h"
#include <spirv-cross/spirv_hlsl.hpp>
@@ -44,8 +45,9 @@
std::array<T, kNumBindingTypes> mMap{};
};
- ShaderModule::ShaderModule(ShaderModuleBuilder* builder) : ShaderModuleBase(builder) {
- spirv_cross::CompilerHLSL compiler(builder->AcquireSpirv());
+ ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
+ : ShaderModuleBase(device, descriptor) {
+ spirv_cross::CompilerHLSL compiler(descriptor->code, descriptor->codeSize);
spirv_cross::CompilerGLSL::Options options_glsl;
options_glsl.vertex.fixup_clipspace = true;
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 032cfa8..11065c1 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -23,7 +23,7 @@
class ShaderModule : public ShaderModuleBase {
public:
- ShaderModule(ShaderModuleBuilder* builder);
+ ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
const std::string& GetHLSLSource() const;
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index 57390c1..7a87caa 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -46,7 +46,6 @@
RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
- ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@@ -69,6 +68,8 @@
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
+ ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) override;
void OnCompletedHandler();
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index c8bf183..0f41704 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -123,8 +123,9 @@
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor);
}
- ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) {
- return new ShaderModule(builder);
+ ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) {
+ return new ShaderModule(this, descriptor);
}
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 60ed58f..a021e42 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -25,11 +25,12 @@
namespace dawn_native { namespace metal {
+ class Device;
class PipelineLayout;
class ShaderModule : public ShaderModuleBase {
public:
- ShaderModule(ShaderModuleBuilder* builder);
+ ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
struct MetalFunctionData {
id<MTLFunction> function;
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 8de07d5..8e2a59a 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -40,8 +40,9 @@
}
}
- ShaderModule::ShaderModule(ShaderModuleBuilder* builder)
- : ShaderModuleBase(builder), mSpirv(builder->AcquireSpirv()) {
+ ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
+ : ShaderModuleBase(device, descriptor) {
+ mSpirv.assign(descriptor->code, descriptor->code + descriptor->codeSize);
spirv_cross::CompilerMSL compiler(mSpirv);
ExtractSpirvInfo(compiler);
}
diff --git a/src/dawn_native/null/NullBackend.cpp b/src/dawn_native/null/NullBackend.cpp
index b4c6bad..82b7c19 100644
--- a/src/dawn_native/null/NullBackend.cpp
+++ b/src/dawn_native/null/NullBackend.cpp
@@ -78,10 +78,11 @@
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor);
}
- ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) {
- auto module = new ShaderModule(builder);
+ ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) {
+ auto module = new ShaderModule(this, descriptor);
- spirv_cross::Compiler compiler(builder->AcquireSpirv());
+ spirv_cross::Compiler compiler(descriptor->code, descriptor->codeSize);
module->ExtractSpirvInfo(compiler);
return module;
diff --git a/src/dawn_native/null/NullBackend.h b/src/dawn_native/null/NullBackend.h
index d780f6d..6d06a65 100644
--- a/src/dawn_native/null/NullBackend.h
+++ b/src/dawn_native/null/NullBackend.h
@@ -106,7 +106,6 @@
RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
- ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@@ -123,6 +122,8 @@
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
+ ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) override;
std::vector<std::unique_ptr<PendingOperation>> mPendingOperations;
};
diff --git a/src/dawn_native/opengl/DeviceGL.cpp b/src/dawn_native/opengl/DeviceGL.cpp
index 55fb1b1..26bdf41 100644
--- a/src/dawn_native/opengl/DeviceGL.cpp
+++ b/src/dawn_native/opengl/DeviceGL.cpp
@@ -91,8 +91,9 @@
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor);
}
- ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) {
- return new ShaderModule(builder);
+ ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) {
+ return new ShaderModule(this, descriptor);
}
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder);
diff --git a/src/dawn_native/opengl/DeviceGL.h b/src/dawn_native/opengl/DeviceGL.h
index e77a240..b1fcdec 100644
--- a/src/dawn_native/opengl/DeviceGL.h
+++ b/src/dawn_native/opengl/DeviceGL.h
@@ -43,7 +43,6 @@
RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
- ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@@ -57,6 +56,8 @@
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
+ ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) override;
};
}} // namespace dawn_native::opengl
diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp
index c8f1f4e..84226a3 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn_native/opengl/ShaderModuleGL.cpp
@@ -16,6 +16,7 @@
#include "common/Assert.h"
#include "common/Platform.h"
+#include "dawn_native/opengl/DeviceGL.h"
#include <spirv-cross/spirv_glsl.hpp>
@@ -46,8 +47,9 @@
return o.str();
}
- ShaderModule::ShaderModule(ShaderModuleBuilder* builder) : ShaderModuleBase(builder) {
- spirv_cross::CompilerGLSL compiler(builder->AcquireSpirv());
+ ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
+ : ShaderModuleBase(device, descriptor) {
+ spirv_cross::CompilerGLSL compiler(descriptor->code, descriptor->codeSize);
spirv_cross::CompilerGLSL::Options options;
// TODO(cwallez@chromium.org): discover the backing context version and use that.
diff --git a/src/dawn_native/opengl/ShaderModuleGL.h b/src/dawn_native/opengl/ShaderModuleGL.h
index 9bd7727..8e485e6 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.h
+++ b/src/dawn_native/opengl/ShaderModuleGL.h
@@ -40,7 +40,7 @@
class ShaderModule : public ShaderModuleBase {
public:
- ShaderModule(ShaderModuleBuilder* builder);
+ ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
using CombinedSamplerInfo = std::vector<CombinedSampler>;
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index 0fce56d..1bf7769 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -262,8 +262,9 @@
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor);
}
- ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) {
- return new ShaderModule(builder);
+ ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) {
+ return new ShaderModule(this, descriptor);
}
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder);
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index f59cd98..19e1d91 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -74,7 +74,6 @@
RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
- ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@@ -88,6 +87,8 @@
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
+ ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
+ const ShaderModuleDescriptor* descriptor) override;
bool CreateInstance(VulkanGlobalKnobs* usedKnobs,
const std::vector<const char*>& requiredExtensions);
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index 1c4644e..0dd8810 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -21,22 +21,19 @@
namespace dawn_native { namespace vulkan {
- ShaderModule::ShaderModule(ShaderModuleBuilder* builder) : ShaderModuleBase(builder) {
- std::vector<uint32_t> spirv = builder->AcquireSpirv();
-
+ ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
+ : ShaderModuleBase(device, descriptor) {
// Use SPIRV-Cross to extract info from the SPIRV even if Vulkan consumes SPIRV. We want to
// have a translation step eventually anyway.
- spirv_cross::Compiler compiler(spirv);
+ spirv_cross::Compiler compiler(descriptor->code, descriptor->codeSize);
ExtractSpirvInfo(compiler);
VkShaderModuleCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
createInfo.pNext = nullptr;
createInfo.flags = 0;
- createInfo.codeSize = spirv.size() * sizeof(uint32_t);
- createInfo.pCode = spirv.data();
-
- Device* device = ToBackend(GetDevice());
+ createInfo.codeSize = descriptor->codeSize * sizeof(uint32_t);
+ createInfo.pCode = descriptor->code;
if (device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &mHandle) !=
VK_SUCCESS) {
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 4e46787..8c904d2 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -21,9 +21,11 @@
namespace dawn_native { namespace vulkan {
+ class Device;
+
class ShaderModule : public ShaderModuleBase {
public:
- ShaderModule(ShaderModuleBuilder* builder);
+ ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule();
VkShaderModule GetHandle() const;
diff --git a/src/tests/unittests/WireTests.cpp b/src/tests/unittests/WireTests.cpp
index cd0e889..d65292d 100644
--- a/src/tests/unittests/WireTests.cpp
+++ b/src/tests/unittests/WireTests.cpp
@@ -295,15 +295,13 @@
// Test that the wire is able to send C strings
TEST_F(WireTests, CStringArgument) {
// Create shader module
- dawnShaderModuleBuilder shaderModuleBuilder = dawnDeviceCreateShaderModuleBuilder(device);
- dawnShaderModule shaderModule = dawnShaderModuleBuilderGetResult(shaderModuleBuilder);
-
- dawnShaderModuleBuilder apiShaderModuleBuilder = api.GetNewShaderModuleBuilder();
- EXPECT_CALL(api, DeviceCreateShaderModuleBuilder(apiDevice))
- .WillOnce(Return(apiShaderModuleBuilder));
+ dawnShaderModuleDescriptor descriptor;
+ descriptor.nextInChain = nullptr;
+ descriptor.codeSize = 0;
+ dawnShaderModule shaderModule = dawnDeviceCreateShaderModule(device, &descriptor);
dawnShaderModule apiShaderModule = api.GetNewShaderModule();
- EXPECT_CALL(api, ShaderModuleBuilderGetResult(apiShaderModuleBuilder))
+ EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
.WillOnce(Return(apiShaderModule));
// Create pipeline
diff --git a/src/tests/unittests/validation/InputStateValidationTests.cpp b/src/tests/unittests/validation/InputStateValidationTests.cpp
index b4543d4..5f352dc 100644
--- a/src/tests/unittests/validation/InputStateValidationTests.cpp
+++ b/src/tests/unittests/validation/InputStateValidationTests.cpp
@@ -25,19 +25,14 @@
dawn::RenderPipeline CreatePipeline(bool success, const dawn::InputState& inputState, std::string vertexSource) {
DummyRenderPass renderpassData = CreateDummyRenderPass();
- dawn::ShaderModuleBuilder vsModuleBuilder = AssertWillBeSuccess(device.CreateShaderModuleBuilder());
- utils::FillShaderModuleBuilder(vsModuleBuilder, dawn::ShaderStage::Vertex, vertexSource.c_str());
- dawn::ShaderModule vsModule = vsModuleBuilder.GetResult();
-
- dawn::ShaderModuleBuilder fsModuleBuilder = AssertWillBeSuccess(device.CreateShaderModuleBuilder());
- utils::FillShaderModuleBuilder(fsModuleBuilder, dawn::ShaderStage::Fragment, R"(
+ dawn::ShaderModule vsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str());
+ dawn::ShaderModule fsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
#version 450
layout(location = 0) out vec4 fragColor;
void main() {
fragColor = vec4(1.0, 0.0, 0.0, 1.0);
}
)");
- dawn::ShaderModule fsModule = fsModuleBuilder.GetResult();
dawn::RenderPipelineBuilder builder;
if (success) {
diff --git a/src/tests/unittests/validation/PushConstantsValidationTests.cpp b/src/tests/unittests/validation/PushConstantsValidationTests.cpp
index c436697..ec9e63b 100644
--- a/src/tests/unittests/validation/PushConstantsValidationTests.cpp
+++ b/src/tests/unittests/validation/PushConstantsValidationTests.cpp
@@ -27,14 +27,12 @@
uint32_t constants[kMaxPushConstants] = {0};
void TestCreateShaderModule(bool success, std::string vertexSource) {
- dawn::ShaderModuleBuilder builder;
+ dawn::ShaderModule module;
if (success) {
- builder = AssertWillBeSuccess(device.CreateShaderModuleBuilder());
+ module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str());
} else {
- builder = AssertWillBeError(device.CreateShaderModuleBuilder());
+ ASSERT_DEVICE_ERROR(module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str()));
}
- utils::FillShaderModuleBuilder(builder, dawn::ShaderStage::Vertex, vertexSource.c_str());
- builder.GetResult();
}
private:
diff --git a/src/utils/DawnHelpers.cpp b/src/utils/DawnHelpers.cpp
index 833a0b9..65adb97 100644
--- a/src/utils/DawnHelpers.cpp
+++ b/src/utils/DawnHelpers.cpp
@@ -25,9 +25,9 @@
namespace utils {
- void FillShaderModuleBuilder(const dawn::ShaderModuleBuilder& builder,
- dawn::ShaderStage stage,
- const char* source) {
+ dawn::ShaderModule CreateShaderModule(const dawn::Device& device,
+ dawn::ShaderStage stage,
+ const char* source) {
shaderc::Compiler compiler;
shaderc::CompileOptions options;
@@ -49,7 +49,7 @@
auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?", options);
if (result.GetCompilationStatus() != shaderc_compilation_status_success) {
std::cerr << result.GetErrorMessage();
- return;
+ return {};
}
// result.cend and result.cbegin return pointers to uint32_t.
@@ -58,7 +58,10 @@
// So this size is in units of sizeof(uint32_t).
ptrdiff_t resultSize = resultEnd - resultBegin;
// SetSource takes data as uint32_t*.
- builder.SetSource(static_cast<uint32_t>(resultSize), result.cbegin());
+
+ dawn::ShaderModuleDescriptor descriptor;
+ descriptor.codeSize = static_cast<uint32_t>(resultSize);
+ descriptor.code = result.cbegin();
#ifdef DUMP_SPIRV_ASSEMBLY
{
@@ -87,14 +90,8 @@
printf("\n");
printf("SPIRV JS ARRAY DUMP END\n");
#endif
- }
- dawn::ShaderModule CreateShaderModule(const dawn::Device& device,
- dawn::ShaderStage stage,
- const char* source) {
- dawn::ShaderModuleBuilder builder = device.CreateShaderModuleBuilder();
- FillShaderModuleBuilder(builder, stage, source);
- return builder.GetResult();
+ return device.CreateShaderModule(&descriptor);
}
dawn::Buffer CreateBufferFromData(const dawn::Device& device,
diff --git a/src/utils/DawnHelpers.h b/src/utils/DawnHelpers.h
index 9929ab5..88698ce 100644
--- a/src/utils/DawnHelpers.h
+++ b/src/utils/DawnHelpers.h
@@ -18,9 +18,6 @@
namespace utils {
- void FillShaderModuleBuilder(const dawn::ShaderModuleBuilder& builder,
- dawn::ShaderStage stage,
- const char* source);
dawn::ShaderModule CreateShaderModule(const dawn::Device& device,
dawn::ShaderStage stage,
const char* source);