Make ShaderModuleBase use an internal EntryPointMetadata
WGSL and SPIR-V modules can contain multiple entrypoints, for different
shader stages, that the pipelines can choose from. This is the first CL
in a stack that will change Dawn internals to not rely on ShaderModules
having a single entrypoint.
EntryPointMetadata is introduced that will contain all reflection data
for an entrypoint of a shader module. To ease review this CL doesn't
introduce any functional changes and doesn't expose the
EntryPointMetadata at the ShaderModuleBase interface. Instead
ShaderModuleBase contains a single metadata object for its single entry
point, and layout-related queries and proxied to the EntryPointMetadata
object.
Finally some small renames and formatting changes are done.
Bug: dawn:216
Change-Id: I0f4d12a5075ba14c5e8fd666be4073d34288f6f9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27240
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 6057e70..03ddf03 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -371,182 +371,330 @@
}
}
#endif
- } // anonymous namespace
- MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) {
- spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
+ MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) {
+ spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
- std::ostringstream errorStream;
- errorStream << "SPIRV Validation failure:" << std::endl;
+ std::ostringstream errorStream;
+ errorStream << "SPIRV Validation failure:" << std::endl;
- spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*,
- const spv_position_t& position,
- const char* message) {
- switch (level) {
- case SPV_MSG_FATAL:
- case SPV_MSG_INTERNAL_ERROR:
- case SPV_MSG_ERROR:
- errorStream << "error: line " << position.index << ": " << message << std::endl;
- break;
- case SPV_MSG_WARNING:
- errorStream << "warning: line " << position.index << ": " << message
- << std::endl;
- break;
- case SPV_MSG_INFO:
- errorStream << "info: line " << position.index << ": " << message << std::endl;
- break;
- default:
- break;
+ spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*,
+ const spv_position_t& position,
+ const char* message) {
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ errorStream << "error: line " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_WARNING:
+ errorStream << "warning: line " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_INFO:
+ errorStream << "info: line " << position.index << ": " << message
+ << std::endl;
+ break;
+ default:
+ break;
+ }
+ });
+
+ if (!spirvTools.Validate(code, codeSize)) {
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
- });
- if (!spirvTools.Validate(code, codeSize)) {
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ return {};
}
- return {};
- }
-
#ifdef DAWN_ENABLE_WGSL
- MaybeError ValidateWGSL(const char* source) {
- std::ostringstream errorStream;
- errorStream << "Tint WGSL failure:" << std::endl;
+ MaybeError ValidateWGSL(const char* source) {
+ std::ostringstream errorStream;
+ errorStream << "Tint WGSL failure:" << std::endl;
- tint::Context context;
- tint::reader::wgsl::Parser parser(&context, source);
+ tint::Context context;
+ tint::reader::wgsl::Parser parser(&context, source);
- if (!parser.Parse()) {
- errorStream << "Parser: " << parser.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::Module module = parser.module();
- if (!module.IsValid()) {
- errorStream << "Invalid module generated..." << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::TypeDeterminer type_determiner(&context, &module);
- if (!type_determiner.Determine()) {
- errorStream << "Type Determination: " << type_determiner.error();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::Validator validator;
- if (!validator.Validate(&module)) {
- errorStream << "Validation: " << validator.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- return {};
- }
-
- ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRV(const char* source) {
- std::ostringstream errorStream;
- errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
-
- tint::Context context;
- tint::reader::wgsl::Parser parser(&context, source);
-
- // TODO: This is a duplicate parse with ValidateWGSL, need to store
- // state between calls to avoid this.
- if (!parser.Parse()) {
- errorStream << "Parser: " << parser.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::Module module = parser.module();
- if (!module.IsValid()) {
- errorStream << "Invalid module generated..." << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::TypeDeterminer type_determiner(&context, &module);
- if (!type_determiner.Determine()) {
- errorStream << "Type Determination: " << type_determiner.error();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::writer::spirv::Generator generator(std::move(module));
- if (!generator.Generate()) {
- errorStream << "Generator: " << generator.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- std::vector<uint32_t> spirv = generator.result();
- return std::move(spirv);
- }
-
- ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
- const char* source,
- const VertexStateDescriptor& vertexState,
- const std::string& entryPoint,
- uint32_t pullingBufferBindingSet) {
- std::ostringstream errorStream;
- errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
-
- tint::Context context;
- tint::reader::wgsl::Parser parser(&context, source);
-
- // TODO: This is a duplicate parse with ValidateWGSL, need to store
- // state between calls to avoid this.
- if (!parser.Parse()) {
- errorStream << "Parser: " << parser.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::Module module = parser.module();
- if (!module.IsValid()) {
- errorStream << "Invalid module generated..." << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::transform::VertexPullingTransform transform(&context, &module);
- auto state = std::make_unique<tint::ast::transform::VertexStateDescriptor>();
- for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
- auto& vertexBuffer = vertexState.vertexBuffers[i];
- tint::ast::transform::VertexBufferLayoutDescriptor layout;
- layout.array_stride = vertexBuffer.arrayStride;
- layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
-
- for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
- auto& attribute = vertexBuffer.attributes[j];
- tint::ast::transform::VertexAttributeDescriptor attr;
- attr.format = ToTintVertexFormat(attribute.format);
- attr.offset = attribute.offset;
- attr.shader_location = attribute.shaderLocation;
-
- layout.attributes.push_back(std::move(attr));
+ if (!parser.Parse()) {
+ errorStream << "Parser: " << parser.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
- state->vertex_buffers.push_back(std::move(layout));
- }
- transform.SetVertexState(std::move(state));
- transform.SetEntryPoint(entryPoint);
- transform.SetPullingBufferBindingSet(pullingBufferBindingSet);
+ tint::ast::Module module = parser.module();
+ if (!module.IsValid()) {
+ errorStream << "Invalid module generated..." << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
- if (!transform.Run()) {
- errorStream << "Vertex pulling transform: " << transform.GetError();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ tint::TypeDeterminer type_determiner(&context, &module);
+ if (!type_determiner.Determine()) {
+ errorStream << "Type Determination: " << type_determiner.error();
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::Validator validator;
+ if (!validator.Validate(&module)) {
+ errorStream << "Validation: " << validator.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ return {};
}
- tint::TypeDeterminer type_determiner(&context, &module);
- if (!type_determiner.Determine()) {
- errorStream << "Type Determination: " << type_determiner.error();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRV(const char* source) {
+ std::ostringstream errorStream;
+ errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
+
+ tint::Context context;
+ tint::reader::wgsl::Parser parser(&context, source);
+
+ // TODO: This is a duplicate parse with ValidateWGSL, need to store
+ // state between calls to avoid this.
+ if (!parser.Parse()) {
+ errorStream << "Parser: " << parser.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::ast::Module module = parser.module();
+ if (!module.IsValid()) {
+ errorStream << "Invalid module generated..." << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::TypeDeterminer type_determiner(&context, &module);
+ if (!type_determiner.Determine()) {
+ errorStream << "Type Determination: " << type_determiner.error();
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::writer::spirv::Generator generator(std::move(module));
+ if (!generator.Generate()) {
+ errorStream << "Generator: " << generator.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ std::vector<uint32_t> spirv = generator.result();
+ return std::move(spirv);
}
- tint::writer::spirv::Generator generator(std::move(module));
- if (!generator.Generate()) {
- errorStream << "Generator: " << generator.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
+ ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
+ const char* source,
+ const VertexStateDescriptor& vertexState,
+ const std::string& entryPoint,
+ uint32_t pullingBufferBindingSet) {
+ std::ostringstream errorStream;
+ errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
- std::vector<uint32_t> spirv = generator.result();
- return std::move(spirv);
- }
+ tint::Context context;
+ tint::reader::wgsl::Parser parser(&context, source);
+
+ // TODO: This is a duplicate parse with ValidateWGSL, need to store
+ // state between calls to avoid this.
+ if (!parser.Parse()) {
+ errorStream << "Parser: " << parser.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::ast::Module module = parser.module();
+ if (!module.IsValid()) {
+ errorStream << "Invalid module generated..." << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::ast::transform::VertexPullingTransform transform(&context, &module);
+ auto state = std::make_unique<tint::ast::transform::VertexStateDescriptor>();
+ for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
+ auto& vertexBuffer = vertexState.vertexBuffers[i];
+ tint::ast::transform::VertexBufferLayoutDescriptor layout;
+ layout.array_stride = vertexBuffer.arrayStride;
+ layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
+
+ for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
+ auto& attribute = vertexBuffer.attributes[j];
+ tint::ast::transform::VertexAttributeDescriptor attr;
+ attr.format = ToTintVertexFormat(attribute.format);
+ attr.offset = attribute.offset;
+ attr.shader_location = attribute.shaderLocation;
+
+ layout.attributes.push_back(std::move(attr));
+ }
+
+ state->vertex_buffers.push_back(std::move(layout));
+ }
+ transform.SetVertexState(std::move(state));
+ transform.SetEntryPoint(entryPoint);
+ transform.SetPullingBufferBindingSet(pullingBufferBindingSet);
+
+ if (!transform.Run()) {
+ errorStream << "Vertex pulling transform: " << transform.GetError();
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::TypeDeterminer type_determiner(&context, &module);
+ if (!type_determiner.Determine()) {
+ errorStream << "Type Determination: " << type_determiner.error();
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::writer::spirv::Generator generator(std::move(module));
+ if (!generator.Generate()) {
+ errorStream << "Generator: " << generator.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ std::vector<uint32_t> spirv = generator.result();
+ return std::move(spirv);
+ }
#endif // DAWN_ENABLE_WGSL
+ std::vector<uint64_t> GetBindGroupMinBufferSizes(
+ const ShaderModuleBase::BindingInfoMap& shaderBindings,
+ const BindGroupLayoutBase* layout) {
+ std::vector<uint64_t> requiredBufferSizes(layout->GetUnverifiedBufferCount());
+ uint32_t packedIdx = 0;
+
+ for (BindingIndex bindingIndex{0}; bindingIndex < layout->GetBufferCount();
+ ++bindingIndex) {
+ const BindingInfo& bindingInfo = layout->GetBindingInfo(bindingIndex);
+ if (bindingInfo.minBufferBindingSize != 0) {
+ // Skip bindings that have minimum buffer size set in the layout
+ continue;
+ }
+
+ ASSERT(packedIdx < requiredBufferSizes.size());
+ const auto& shaderInfo = shaderBindings.find(bindingInfo.binding);
+ if (shaderInfo != shaderBindings.end()) {
+ requiredBufferSizes[packedIdx] = shaderInfo->second.minBufferBindingSize;
+ } else {
+ // We have to include buffers if they are included in the bind group's
+ // packed vector. We don't actually need to check these at draw time, so
+ // if this is a problem in the future we can optimize it further.
+ requiredBufferSizes[packedIdx] = 0;
+ }
+ ++packedIdx;
+ }
+
+ return requiredBufferSizes;
+ }
+
+ MaybeError ValidateCompatibilityWithBindGroupLayout(
+ BindGroupIndex group,
+ const ShaderModuleBase::EntryPointMetadata& entryPoint,
+ const BindGroupLayoutBase* layout) {
+ const BindGroupLayoutBase::BindingMap& layoutBindings = layout->GetBindingMap();
+
+ // Iterate over all bindings used by this group in the shader, and find the
+ // corresponding binding in the BindGroupLayout, if it exists.
+ for (const auto& it : entryPoint.bindings[group]) {
+ BindingNumber bindingNumber = it.first;
+ const ShaderModuleBase::ShaderBindingInfo& shaderInfo = it.second;
+
+ const auto& bindingIt = layoutBindings.find(bindingNumber);
+ if (bindingIt == layoutBindings.end()) {
+ return DAWN_VALIDATION_ERROR("Missing bind group layout entry for " +
+ GetShaderDeclarationString(group, bindingNumber));
+ }
+ BindingIndex bindingIndex(bindingIt->second);
+ const BindingInfo& layoutInfo = layout->GetBindingInfo(bindingIndex);
+
+ if (layoutInfo.type != shaderInfo.type) {
+ // Binding mismatch between shader and bind group is invalid. For example, a
+ // writable binding in the shader with a readonly storage buffer in the bind
+ // group layout is invalid. However, a readonly binding in the shader with a
+ // writable storage buffer in the bind group layout is valid.
+ bool validBindingConversion =
+ layoutInfo.type == wgpu::BindingType::StorageBuffer &&
+ shaderInfo.type == wgpu::BindingType::ReadonlyStorageBuffer;
+
+ // TODO(crbug.com/dawn/367): Temporarily allow using either a sampler or a
+ // comparison sampler until we can perform the proper shader analysis of what
+ // type is used in the shader module.
+ validBindingConversion |=
+ (layoutInfo.type == wgpu::BindingType::Sampler &&
+ shaderInfo.type == wgpu::BindingType::ComparisonSampler);
+ validBindingConversion |=
+ (layoutInfo.type == wgpu::BindingType::ComparisonSampler &&
+ shaderInfo.type == wgpu::BindingType::Sampler);
+
+ if (!validBindingConversion) {
+ return DAWN_VALIDATION_ERROR(
+ "The binding type of the bind group layout entry conflicts " +
+ GetShaderDeclarationString(group, bindingNumber));
+ }
+ }
+
+ if ((layoutInfo.visibility & StageBit(entryPoint.stage)) == 0) {
+ return DAWN_VALIDATION_ERROR("The bind group layout entry for " +
+ GetShaderDeclarationString(group, bindingNumber) +
+ " is not visible for the shader stage");
+ }
+
+ switch (layoutInfo.type) {
+ case wgpu::BindingType::SampledTexture: {
+ if (layoutInfo.textureComponentType != shaderInfo.textureComponentType) {
+ return DAWN_VALIDATION_ERROR(
+ "The textureComponentType of the bind group layout entry is "
+ "different from " +
+ GetShaderDeclarationString(group, bindingNumber));
+ }
+
+ if (layoutInfo.viewDimension != shaderInfo.viewDimension) {
+ return DAWN_VALIDATION_ERROR(
+ "The viewDimension of the bind group layout entry is different "
+ "from " +
+ GetShaderDeclarationString(group, bindingNumber));
+ }
+ break;
+ }
+
+ case wgpu::BindingType::ReadonlyStorageTexture:
+ case wgpu::BindingType::WriteonlyStorageTexture: {
+ ASSERT(layoutInfo.storageTextureFormat != wgpu::TextureFormat::Undefined);
+ ASSERT(shaderInfo.storageTextureFormat != wgpu::TextureFormat::Undefined);
+ if (layoutInfo.storageTextureFormat != shaderInfo.storageTextureFormat) {
+ return DAWN_VALIDATION_ERROR(
+ "The storageTextureFormat of the bind group layout entry is "
+ "different from " +
+ GetShaderDeclarationString(group, bindingNumber));
+ }
+ if (layoutInfo.viewDimension != shaderInfo.viewDimension) {
+ return DAWN_VALIDATION_ERROR(
+ "The viewDimension of the bind group layout entry is different "
+ "from " +
+ GetShaderDeclarationString(group, bindingNumber));
+ }
+ break;
+ }
+
+ case wgpu::BindingType::UniformBuffer:
+ case wgpu::BindingType::ReadonlyStorageBuffer:
+ case wgpu::BindingType::StorageBuffer: {
+ if (layoutInfo.minBufferBindingSize != 0 &&
+ shaderInfo.minBufferBindingSize > layoutInfo.minBufferBindingSize) {
+ return DAWN_VALIDATION_ERROR(
+ "The minimum buffer size of the bind group layout entry is smaller "
+ "than " +
+ GetShaderDeclarationString(group, bindingNumber));
+ }
+ break;
+ }
+ case wgpu::BindingType::Sampler:
+ case wgpu::BindingType::ComparisonSampler:
+ break;
+
+ case wgpu::BindingType::StorageTexture:
+ default:
+ UNREACHABLE();
+ return DAWN_VALIDATION_ERROR("Unsupported binding type");
+ }
+ }
+
+ return {};
+ }
+
+ } // anonymous namespace
+
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
const ShaderModuleDescriptor* descriptor) {
const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
@@ -582,7 +730,45 @@
}
return {};
- } // namespace
+ }
+
+ RequiredBufferSizes ComputeRequiredBufferSizesForLayout(
+ const ShaderModuleBase::EntryPointMetadata& entryPoint,
+ const PipelineLayoutBase* layout) {
+ RequiredBufferSizes bufferSizes;
+ for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
+ bufferSizes[group] = GetBindGroupMinBufferSizes(entryPoint.bindings[group],
+ layout->GetBindGroupLayout(group));
+ }
+
+ return bufferSizes;
+ }
+
+ MaybeError ValidateCompatibilityWithPipelineLayout(
+ const ShaderModuleBase::EntryPointMetadata& entryPoint,
+ const PipelineLayoutBase* layout) {
+ for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
+ DAWN_TRY(ValidateCompatibilityWithBindGroupLayout(group, entryPoint,
+ layout->GetBindGroupLayout(group)));
+ }
+
+ for (BindGroupIndex group : IterateBitSet(~layout->GetBindGroupLayoutsMask())) {
+ if (entryPoint.bindings[group].size() > 0) {
+ std::ostringstream ostream;
+ ostream << "No bind group layout entry matches the declaration set "
+ << static_cast<uint32_t>(group) << " in the shader module";
+ return DAWN_VALIDATION_ERROR(ostream.str());
+ }
+ }
+
+ return {};
+ }
+
+ // EntryPointMetadata
+
+ ShaderModuleBase::EntryPointMetadata::EntryPointMetadata() {
+ fragmentOutputFormatBaseTypes.fill(Format::Type::Other);
+ }
// ShaderModuleBase
@@ -608,7 +794,6 @@
UNREACHABLE();
}
- mFragmentOutputFormatBaseTypes.fill(Format::Type::Other);
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvcParser)) {
mSpvcContext.SetUseSpvcParser(true);
}
@@ -632,18 +817,22 @@
MaybeError ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) {
ASSERT(!IsError());
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
- DAWN_TRY(ExtractSpirvInfoWithSpvc());
+ DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfoWithSpvc());
} else {
- DAWN_TRY(ExtractSpirvInfoWithSpirvCross(compiler));
+ DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfoWithSpirvCross(compiler));
}
return {};
}
- MaybeError ShaderModuleBase::ExtractSpirvInfoWithSpvc() {
+ ResultOrError<std::unique_ptr<ShaderModuleBase::EntryPointMetadata>>
+ ShaderModuleBase::ExtractSpirvInfoWithSpvc() {
+ DeviceBase* device = GetDevice();
+ std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
+
shaderc_spvc_execution_model execution_model;
DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetExecutionModel(&execution_model),
"Unable to get execution model for shader."));
- mExecutionModel = ToSingleShaderStage(execution_model);
+ metadata->stage = ToSingleShaderStage(execution_model);
size_t push_constant_buffers_count;
DAWN_TRY(
@@ -658,15 +847,16 @@
// Fill in bindingInfo with the SPIRV bindings
auto ExtractResourcesBinding =
- [this](std::vector<shaderc_spvc_binding_info> bindings) -> MaybeError {
- for (const auto& binding : bindings) {
+ [](const DeviceBase* device, const std::vector<shaderc_spvc_binding_info>& spvcBindings,
+ ModuleBindingInfo* metadataBindings) -> MaybeError {
+ for (const shaderc_spvc_binding_info& binding : spvcBindings) {
BindGroupIndex bindGroupIndex(binding.set);
if (bindGroupIndex >= kMaxBindGroupsTyped) {
return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV");
}
- const auto& it = mBindingInfo[bindGroupIndex].emplace(
+ const auto& it = (*metadataBindings)[bindGroupIndex].emplace(
BindingNumber(binding.binding), ShaderBindingInfo{});
if (!it.second) {
return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
@@ -694,8 +884,7 @@
return DAWN_VALIDATION_ERROR(
"Invalid image format declaration on storage image");
}
- const Format& format =
- GetDevice()->GetValidInternalFormat(storageTextureFormat);
+ const Format& format = device->GetValidInternalFormat(storageTextureFormat);
if (!format.supportsStorageUsage) {
return DAWN_VALIDATION_ERROR(
"The storage texture format is not supported");
@@ -722,45 +911,45 @@
shaderc_spvc_shader_resource_uniform_buffers,
shaderc_spvc_binding_type_uniform_buffer, &resource_bindings),
"Unable to get binding info for uniform buffers from shader"));
- DAWN_TRY(ExtractResourcesBinding(resource_bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings));
DAWN_TRY(CheckSpvcSuccess(
mSpvcContext.GetBindingInfo(shaderc_spvc_shader_resource_separate_images,
shaderc_spvc_binding_type_sampled_texture,
&resource_bindings),
"Unable to get binding info for sampled textures from shader"));
- DAWN_TRY(ExtractResourcesBinding(resource_bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings));
DAWN_TRY(CheckSpvcSuccess(
mSpvcContext.GetBindingInfo(shaderc_spvc_shader_resource_separate_samplers,
shaderc_spvc_binding_type_sampler, &resource_bindings),
"Unable to get binding info for samples from shader"));
- DAWN_TRY(ExtractResourcesBinding(resource_bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings));
DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetBindingInfo(
shaderc_spvc_shader_resource_storage_buffers,
shaderc_spvc_binding_type_storage_buffer, &resource_bindings),
"Unable to get binding info for storage buffers from shader"));
- DAWN_TRY(ExtractResourcesBinding(resource_bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings));
DAWN_TRY(CheckSpvcSuccess(
mSpvcContext.GetBindingInfo(shaderc_spvc_shader_resource_storage_images,
shaderc_spvc_binding_type_storage_texture,
&resource_bindings),
"Unable to get binding info for storage textures from shader"));
- DAWN_TRY(ExtractResourcesBinding(resource_bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings));
std::vector<shaderc_spvc_resource_location_info> input_stage_locations;
DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetInputStageLocationInfo(&input_stage_locations),
"Unable to get input stage location information from shader"));
for (const auto& input : input_stage_locations) {
- if (mExecutionModel == SingleShaderStage::Vertex) {
+ if (metadata->stage == SingleShaderStage::Vertex) {
if (input.location >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
}
- mUsedVertexAttributes.set(input.location);
- } else if (mExecutionModel == SingleShaderStage::Fragment) {
+ metadata->usedVertexAttributes.set(input.location);
+ } else if (metadata->stage == SingleShaderStage::Fragment) {
// Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives
// them all the location 0, causing a compile error.
if (!input.has_location) {
@@ -774,13 +963,13 @@
"Unable to get output stage location information from shader"));
for (const auto& output : output_stage_locations) {
- if (mExecutionModel == SingleShaderStage::Vertex) {
+ if (metadata->stage == SingleShaderStage::Vertex) {
// Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL
// gives them all the location 0, causing a compile error.
if (!output.has_location) {
return DAWN_VALIDATION_ERROR("Need location qualifier on vertex output");
}
- } else if (mExecutionModel == SingleShaderStage::Fragment) {
+ } else if (metadata->stage == SingleShaderStage::Fragment) {
if (output.location >= kMaxColorAttachments) {
return DAWN_VALIDATION_ERROR(
"Fragment output location over limits in the SPIRV");
@@ -788,7 +977,7 @@
}
}
- if (mExecutionModel == SingleShaderStage::Fragment) {
+ if (metadata->stage == SingleShaderStage::Fragment) {
std::vector<shaderc_spvc_resource_type_info> output_types;
DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetOutputStageTypeInfo(&output_types),
"Unable to get output stage type information from shader"));
@@ -797,27 +986,32 @@
if (output.type == shaderc_spvc_texture_format_type_other) {
return DAWN_VALIDATION_ERROR("Unexpected Fragment output type");
}
- mFragmentOutputFormatBaseTypes[output.location] = ToDawnFormatType(output.type);
+ metadata->fragmentOutputFormatBaseTypes[output.location] =
+ ToDawnFormatType(output.type);
}
}
- return {};
+
+ return {std::move(metadata)};
}
- MaybeError ShaderModuleBase::ExtractSpirvInfoWithSpirvCross(
- const spirv_cross::Compiler& compiler) {
+ ResultOrError<std::unique_ptr<ShaderModuleBase::EntryPointMetadata>>
+ ShaderModuleBase::ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler) {
+ DeviceBase* device = GetDevice();
+ std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
+
// TODO(cwallez@chromium.org): make errors here creation errors
// currently errors here do not prevent the shadermodule from being used
const auto& resources = compiler.get_shader_resources();
switch (compiler.get_execution_model()) {
case spv::ExecutionModelVertex:
- mExecutionModel = SingleShaderStage::Vertex;
+ metadata->stage = SingleShaderStage::Vertex;
break;
case spv::ExecutionModelFragment:
- mExecutionModel = SingleShaderStage::Fragment;
+ metadata->stage = SingleShaderStage::Fragment;
break;
case spv::ExecutionModelGLCompute:
- mExecutionModel = SingleShaderStage::Compute;
+ metadata->stage = SingleShaderStage::Compute;
break;
default:
UNREACHABLE();
@@ -834,9 +1028,10 @@
// Fill in bindingInfo with the SPIRV bindings
auto ExtractResourcesBinding =
- [this](const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
- const spirv_cross::Compiler& compiler,
- wgpu::BindingType bindingType) -> MaybeError {
+ [](const DeviceBase* device,
+ const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
+ const spirv_cross::Compiler& compiler, wgpu::BindingType bindingType,
+ ModuleBindingInfo* metadataBindings) -> MaybeError {
for (const auto& resource : resources) {
if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) {
return DAWN_VALIDATION_ERROR("No Binding decoration set for resource");
@@ -857,7 +1052,7 @@
}
const auto& it =
- mBindingInfo[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{});
+ (*metadataBindings)[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{});
if (!it.second) {
return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
}
@@ -919,8 +1114,7 @@
return DAWN_VALIDATION_ERROR(
"Invalid image format declaration on storage image");
}
- const Format& format =
- GetDevice()->GetValidInternalFormat(storageTextureFormat);
+ const Format& format = device->GetValidInternalFormat(storageTextureFormat);
if (!format.supportsStorageUsage) {
return DAWN_VALIDATION_ERROR(
"The storage texture format is not supported");
@@ -938,19 +1132,19 @@
return {};
};
- DAWN_TRY(ExtractResourcesBinding(resources.uniform_buffers, compiler,
- wgpu::BindingType::UniformBuffer));
- DAWN_TRY(ExtractResourcesBinding(resources.separate_images, compiler,
- wgpu::BindingType::SampledTexture));
- DAWN_TRY(ExtractResourcesBinding(resources.separate_samplers, compiler,
- wgpu::BindingType::Sampler));
- DAWN_TRY(ExtractResourcesBinding(resources.storage_buffers, compiler,
- wgpu::BindingType::StorageBuffer));
- DAWN_TRY(ExtractResourcesBinding(resources.storage_images, compiler,
- wgpu::BindingType::StorageTexture));
+ DAWN_TRY(ExtractResourcesBinding(device, resources.uniform_buffers, compiler,
+ wgpu::BindingType::UniformBuffer, &metadata->bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resources.separate_images, compiler,
+ wgpu::BindingType::SampledTexture, &metadata->bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resources.separate_samplers, compiler,
+ wgpu::BindingType::Sampler, &metadata->bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resources.storage_buffers, compiler,
+ wgpu::BindingType::StorageBuffer, &metadata->bindings));
+ DAWN_TRY(ExtractResourcesBinding(device, resources.storage_images, compiler,
+ wgpu::BindingType::StorageTexture, &metadata->bindings));
// Extract the vertex attributes
- if (mExecutionModel == SingleShaderStage::Vertex) {
+ if (metadata->stage == SingleShaderStage::Vertex) {
for (const auto& attrib : resources.stage_inputs) {
if (!(compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation))) {
return DAWN_VALIDATION_ERROR(
@@ -962,7 +1156,7 @@
return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
}
- mUsedVertexAttributes.set(location);
+ metadata->usedVertexAttributes.set(location);
}
// Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives
@@ -974,7 +1168,7 @@
}
}
- if (mExecutionModel == SingleShaderStage::Fragment) {
+ if (metadata->stage == SingleShaderStage::Fragment) {
// Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives
// them all the location 0, causing a compile error.
for (const auto& attrib : resources.stage_inputs) {
@@ -1003,209 +1197,44 @@
if (formatType == Format::Type::Other) {
return DAWN_VALIDATION_ERROR("Unexpected Fragment output type");
}
- mFragmentOutputFormatBaseTypes[location] = formatType;
+ metadata->fragmentOutputFormatBaseTypes[location] = formatType;
}
}
- return {};
+
+ return {std::move(metadata)};
}
const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const {
ASSERT(!IsError());
- return mBindingInfo;
+ return mMainEntryPoint->bindings;
}
const std::bitset<kMaxVertexAttributes>& ShaderModuleBase::GetUsedVertexAttributes() const {
ASSERT(!IsError());
- return mUsedVertexAttributes;
+ return mMainEntryPoint->usedVertexAttributes;
}
const ShaderModuleBase::FragmentOutputBaseTypes& ShaderModuleBase::GetFragmentOutputBaseTypes()
const {
ASSERT(!IsError());
- return mFragmentOutputFormatBaseTypes;
+ return mMainEntryPoint->fragmentOutputFormatBaseTypes;
}
SingleShaderStage ShaderModuleBase::GetExecutionModel() const {
ASSERT(!IsError());
- return mExecutionModel;
+ return mMainEntryPoint->stage;
}
RequiredBufferSizes ShaderModuleBase::ComputeRequiredBufferSizesForLayout(
const PipelineLayoutBase* layout) const {
- RequiredBufferSizes bufferSizes;
- for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
- bufferSizes[group] =
- GetBindGroupMinBufferSizes(mBindingInfo[group], layout->GetBindGroupLayout(group));
- }
-
- return bufferSizes;
- }
-
- std::vector<uint64_t> ShaderModuleBase::GetBindGroupMinBufferSizes(
- const BindingInfoMap& shaderMap,
- const BindGroupLayoutBase* layout) const {
- std::vector<uint64_t> requiredBufferSizes(layout->GetUnverifiedBufferCount());
- uint32_t packedIdx = 0;
-
- for (BindingIndex bindingIndex{0}; bindingIndex < layout->GetBufferCount();
- ++bindingIndex) {
- const BindingInfo& bindingInfo = layout->GetBindingInfo(bindingIndex);
- if (bindingInfo.minBufferBindingSize != 0) {
- // Skip bindings that have minimum buffer size set in the layout
- continue;
- }
-
- ASSERT(packedIdx < requiredBufferSizes.size());
- const auto& shaderInfo = shaderMap.find(bindingInfo.binding);
- if (shaderInfo != shaderMap.end()) {
- requiredBufferSizes[packedIdx] = shaderInfo->second.minBufferBindingSize;
- } else {
- // We have to include buffers if they are included in the bind group's
- // packed vector. We don't actually need to check these at draw time, so
- // if this is a problem in the future we can optimize it further.
- requiredBufferSizes[packedIdx] = 0;
- }
- ++packedIdx;
- }
-
- return requiredBufferSizes;
+ ASSERT(!IsError());
+ return ::dawn_native::ComputeRequiredBufferSizesForLayout(*mMainEntryPoint, layout);
}
MaybeError ShaderModuleBase::ValidateCompatibilityWithPipelineLayout(
const PipelineLayoutBase* layout) const {
ASSERT(!IsError());
-
- for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
- DAWN_TRY(
- ValidateCompatibilityWithBindGroupLayout(group, layout->GetBindGroupLayout(group)));
- }
-
- for (BindGroupIndex group : IterateBitSet(~layout->GetBindGroupLayoutsMask())) {
- if (mBindingInfo[group].size() > 0) {
- std::ostringstream ostream;
- ostream << "No bind group layout entry matches the declaration set "
- << static_cast<uint32_t>(group) << " in the shader module";
- return DAWN_VALIDATION_ERROR(ostream.str());
- }
- }
-
- return {};
- }
-
- MaybeError ShaderModuleBase::ValidateCompatibilityWithBindGroupLayout(
- BindGroupIndex group,
- const BindGroupLayoutBase* layout) const {
- ASSERT(!IsError());
-
- const BindGroupLayoutBase::BindingMap& bindingMap = layout->GetBindingMap();
-
- // Iterate over all bindings used by this group in the shader, and find the
- // corresponding binding in the BindGroupLayout, if it exists.
- for (const auto& it : mBindingInfo[group]) {
- BindingNumber bindingNumber = it.first;
- const ShaderBindingInfo& moduleInfo = it.second;
-
- const auto& bindingIt = bindingMap.find(bindingNumber);
- if (bindingIt == bindingMap.end()) {
- return DAWN_VALIDATION_ERROR("Missing bind group layout entry for " +
- GetShaderDeclarationString(group, bindingNumber));
- }
- BindingIndex bindingIndex(bindingIt->second);
-
- const BindingInfo& bindingInfo = layout->GetBindingInfo(bindingIndex);
-
- if (bindingInfo.type != moduleInfo.type) {
- // Binding mismatch between shader and bind group is invalid. For example, a
- // writable binding in the shader with a readonly storage buffer in the bind group
- // layout is invalid. However, a readonly binding in the shader with a writable
- // storage buffer in the bind group layout is valid.
- bool validBindingConversion =
- bindingInfo.type == wgpu::BindingType::StorageBuffer &&
- moduleInfo.type == wgpu::BindingType::ReadonlyStorageBuffer;
-
- // TODO(crbug.com/dawn/367): Temporarily allow using either a sampler or a
- // comparison sampler until we can perform the proper shader analysis of what type
- // is used in the shader module.
- validBindingConversion |= (bindingInfo.type == wgpu::BindingType::Sampler &&
- moduleInfo.type == wgpu::BindingType::ComparisonSampler);
- validBindingConversion |=
- (bindingInfo.type == wgpu::BindingType::ComparisonSampler &&
- moduleInfo.type == wgpu::BindingType::Sampler);
-
- if (!validBindingConversion) {
- return DAWN_VALIDATION_ERROR(
- "The binding type of the bind group layout entry conflicts " +
- GetShaderDeclarationString(group, bindingNumber));
- }
- }
-
- if ((bindingInfo.visibility & StageBit(mExecutionModel)) == 0) {
- return DAWN_VALIDATION_ERROR("The bind group layout entry for " +
- GetShaderDeclarationString(group, bindingNumber) +
- " is not visible for the shader stage");
- }
-
- switch (bindingInfo.type) {
- case wgpu::BindingType::SampledTexture: {
- if (bindingInfo.textureComponentType != moduleInfo.textureComponentType) {
- return DAWN_VALIDATION_ERROR(
- "The textureComponentType of the bind group layout entry is different "
- "from " +
- GetShaderDeclarationString(group, bindingNumber));
- }
-
- if (bindingInfo.viewDimension != moduleInfo.viewDimension) {
- return DAWN_VALIDATION_ERROR(
- "The viewDimension of the bind group layout entry is different "
- "from " +
- GetShaderDeclarationString(group, bindingNumber));
- }
- break;
- }
-
- case wgpu::BindingType::ReadonlyStorageTexture:
- case wgpu::BindingType::WriteonlyStorageTexture: {
- ASSERT(bindingInfo.storageTextureFormat != wgpu::TextureFormat::Undefined);
- ASSERT(moduleInfo.storageTextureFormat != wgpu::TextureFormat::Undefined);
- if (bindingInfo.storageTextureFormat != moduleInfo.storageTextureFormat) {
- return DAWN_VALIDATION_ERROR(
- "The storageTextureFormat of the bind group layout entry is different "
- "from " +
- GetShaderDeclarationString(group, bindingNumber));
- }
- if (bindingInfo.viewDimension != moduleInfo.viewDimension) {
- return DAWN_VALIDATION_ERROR(
- "The viewDimension of the bind group layout entry is different "
- "from " +
- GetShaderDeclarationString(group, bindingNumber));
- }
- break;
- }
-
- case wgpu::BindingType::UniformBuffer:
- case wgpu::BindingType::ReadonlyStorageBuffer:
- case wgpu::BindingType::StorageBuffer: {
- if (bindingInfo.minBufferBindingSize != 0 &&
- moduleInfo.minBufferBindingSize > bindingInfo.minBufferBindingSize) {
- return DAWN_VALIDATION_ERROR(
- "The minimum buffer size of the bind group layout entry is smaller "
- "than " +
- GetShaderDeclarationString(group, bindingNumber));
- }
- break;
- }
- case wgpu::BindingType::Sampler:
- case wgpu::BindingType::ComparisonSampler:
- break;
-
- case wgpu::BindingType::StorageTexture:
- default:
- UNREACHABLE();
- return DAWN_VALIDATION_ERROR("Unsupported binding type");
- }
- }
-
- return {};
+ return ::dawn_native::ValidateCompatibilityWithPipelineLayout(*mMainEntryPoint, layout);
}
size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 336551e..f6779aa 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -43,8 +43,6 @@
class ShaderModuleBase : public CachedObject {
public:
- enum class Type { Undefined, Spirv, Wgsl };
-
ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
~ShaderModuleBase() override;
@@ -98,6 +96,15 @@
uint32_t pullingBufferBindingSet) const;
#endif
+ struct EntryPointMetadata {
+ EntryPointMetadata();
+
+ ModuleBindingInfo bindings;
+ std::bitset<kMaxVertexAttributes> usedVertexAttributes;
+ SingleShaderStage stage;
+ FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
+ };
+
protected:
static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
shaderc_spvc::CompileOptions GetCompileOptions() const;
@@ -108,27 +115,18 @@
private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
- MaybeError ValidateCompatibilityWithBindGroupLayout(
- BindGroupIndex group,
- const BindGroupLayoutBase* layout) const;
-
- std::vector<uint64_t> GetBindGroupMinBufferSizes(const BindingInfoMap& shaderMap,
- const BindGroupLayoutBase* layout) const;
-
// Different implementations reflection into the shader depending on
// whether using spvc, or directly accessing spirv-cross.
- MaybeError ExtractSpirvInfoWithSpvc();
- MaybeError ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler);
+ ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfoWithSpvc();
+ ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfoWithSpirvCross(
+ const spirv_cross::Compiler& compiler);
+ enum class Type { Undefined, Spirv, Wgsl };
Type mType;
std::vector<uint32_t> mSpirv;
std::string mWgsl;
- ModuleBindingInfo mBindingInfo;
- std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;
- SingleShaderStage mExecutionModel;
-
- FragmentOutputBaseTypes mFragmentOutputFormatBaseTypes;
+ std::unique_ptr<EntryPointMetadata> mMainEntryPoint;
};
} // namespace dawn_native