Produce tint::ast::Module in the frontend if UseTintGenerator
This factors code to move parsing of tint::ast::Module to the
frontend. All backends will use this code path when
UseTintGenerator is enabled for both SPIR-V and WGSL ingestion.
To avoid too much code explosion, parsing and validating the
shader is moved into ValidateShaderModuleDescriptor which
returns a result struct that gets passed into creation.
Bug: dawn:571
Change-Id: I598693ef36954fd0056a0744a2a0ebd7cc7d40a4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32301
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 3e838ad..72d7606 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -584,7 +584,8 @@
}
ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
- const ShaderModuleDescriptor* descriptor) {
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
ShaderModuleBase blueprint(this, descriptor);
const size_t blueprintHash = blueprint.ComputeContentHash();
@@ -597,7 +598,18 @@
}
ShaderModuleBase* backendObj;
- DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor));
+ if (parseResult == nullptr) {
+ // We skip the parse on creation if validation isn't enabled which let's us quickly
+ // lookup in the cache without validating and parsing. We need the parsed module now, so
+ // call validate. Most of |ValidateShaderModuleDescriptor| is parsing, but we can
+ // consider splitting it if additional validation is added.
+ ASSERT(!IsValidationEnabled());
+ ShaderModuleParseResult localParseResult =
+ ValidateShaderModuleDescriptor(this, descriptor).AcquireSuccess();
+ DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor, &localParseResult));
+ } else {
+ DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor, parseResult));
+ }
backendObj->SetIsCachedReference();
backendObj->SetContentHash(blueprintHash);
mCaches->shaderModules.insert(backendObj);
@@ -1062,10 +1074,15 @@
MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
const ShaderModuleDescriptor* descriptor) {
DAWN_TRY(ValidateIsAlive());
+
+ ShaderModuleParseResult parseResult = {};
+ ShaderModuleParseResult* parseResultPtr = nullptr;
if (IsValidationEnabled()) {
- DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
+ DAWN_TRY_ASSIGN(parseResult, ValidateShaderModuleDescriptor(this, descriptor));
+ parseResultPtr = &parseResult;
}
- DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor));
+
+ DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor, parseResultPtr));
return {};
}
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 08a4c80..f089c76 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -40,6 +40,7 @@
class PersistentCache;
class StagingBufferBase;
struct InternalPipelineStore;
+ struct ShaderModuleParseResult;
class DeviceBase {
public:
@@ -129,7 +130,8 @@
void UncacheSampler(SamplerBase* obj);
ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
- const ShaderModuleDescriptor* descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult);
void UncacheShaderModule(ShaderModuleBase* obj);
Ref<AttachmentState> GetOrCreateAttachmentState(AttachmentStateBlueprint* blueprint);
@@ -275,7 +277,8 @@
virtual ResultOrError<SamplerBase*> CreateSamplerImpl(
const SamplerDescriptor* descriptor) = 0;
virtual ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) = 0;
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) = 0;
virtual ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) = 0;
// Note that previousSwapChain may be nullptr, or come from a different backend.
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 334a40e..6e0781d 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -173,13 +173,12 @@
}
#ifdef DAWN_ENABLE_WGSL
- MaybeError ValidateWGSL(const char* source) {
+ ResultOrError<tint::ast::Module> ParseWGSL(const char* wgsl) {
std::ostringstream errorStream;
- errorStream << "Tint WGSL failure:" << std::endl;
+ errorStream << "Tint WGSL reader failure:" << std::endl;
- tint::Source::File file("", source);
+ tint::Source::File file("", wgsl);
tint::reader::wgsl::Parser parser(&file);
-
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
@@ -191,14 +190,46 @@
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
- tint::TypeDeterminer type_determiner(&module);
- if (!type_determiner.Determine()) {
- errorStream << "Type Determination: " << type_determiner.error();
+ tint::TypeDeterminer typeDeterminer(&module);
+ if (!typeDeterminer.Determine()) {
+ errorStream << "Type Determination: " << typeDeterminer.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
+ return std::move(module);
+ }
+
+ ResultOrError<tint::ast::Module> ParseSPIRV(const std::vector<uint32_t>& spirv) {
+ std::ostringstream errorStream;
+ errorStream << "Tint SPIRV reader failure:" << std::endl;
+
+ tint::reader::spirv::Parser parser(spirv);
+ if (!parser.Parse()) {
+ errorStream << "Parser: " << parser.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::ast::Module module = parser.module();
+ if (!module.IsValid()) {
+ errorStream << "Invalid module generated..." << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::TypeDeterminer typeDeterminer(&module);
+ if (!typeDeterminer.Determine()) {
+ errorStream << "Type Determination: " << typeDeterminer.error();
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ return std::move(module);
+ }
+
+ MaybeError ValidateModule(tint::ast::Module* module) {
+ std::ostringstream errorStream;
+ errorStream << "Tint module validation" << std::endl;
+
tint::Validator validator;
- if (!validator.Validate(&module)) {
+ if (!validator.Validate(module)) {
errorStream << "Validation: " << validator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
@@ -206,111 +237,9 @@
return {};
}
- ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRV(const char* source) {
+ ResultOrError<std::vector<uint32_t>> ModuleToSPIRV(tint::ast::Module module) {
std::ostringstream errorStream;
- errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
-
- tint::Source::File file("", source);
- tint::reader::wgsl::Parser parser(&file);
-
- // TODO: This is a duplicate parse with ValidateWGSL, need to store
- // state between calls to avoid this.
- if (!parser.Parse()) {
- errorStream << "Parser: " << parser.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::Module module = parser.module();
- if (!module.IsValid()) {
- errorStream << "Invalid module generated..." << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::TypeDeterminer type_determiner(&module);
- if (!type_determiner.Determine()) {
- errorStream << "Type Determination: " << type_determiner.error();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::writer::spirv::Generator generator(std::move(module));
- if (!generator.Generate()) {
- errorStream << "Generator: " << generator.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- std::vector<uint32_t> spirv = generator.result();
-
- DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
-
- return std::move(spirv);
- }
-
- ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
- const char* source,
- const VertexStateDescriptor& vertexState,
- const std::string& entryPoint,
- uint32_t pullingBufferBindingSet) {
- std::ostringstream errorStream;
- errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
-
- tint::Source::File file("", source);
- tint::reader::wgsl::Parser parser(&file);
-
- // TODO: This is a duplicate parse with ValidateWGSL, need to store
- // state between calls to avoid this.
- if (!parser.Parse()) {
- errorStream << "Parser: " << parser.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::Module module = parser.module();
- if (!module.IsValid()) {
- errorStream << "Invalid module generated..." << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::transform::Manager transformManager;
- {
- auto transform = std::make_unique<tint::transform::VertexPulling>();
- tint::transform::VertexStateDescriptor state;
-
- for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
- auto& vertexBuffer = vertexState.vertexBuffers[i];
- tint::transform::VertexBufferLayoutDescriptor layout;
- layout.array_stride = vertexBuffer.arrayStride;
- layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
-
- for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
- auto& attribute = vertexBuffer.attributes[j];
- tint::transform::VertexAttributeDescriptor attr;
- attr.format = ToTintVertexFormat(attribute.format);
- attr.offset = attribute.offset;
- attr.shader_location = attribute.shaderLocation;
-
- layout.attributes.push_back(std::move(attr));
- }
-
- state.push_back(std::move(layout));
- }
- transform->SetVertexState(std::move(state));
- transform->SetEntryPoint(entryPoint);
- transform->SetPullingBufferBindingSet(pullingBufferBindingSet);
- transformManager.append(std::move(transform));
- }
-
- auto result = transformManager.Run(&module);
- if (result.diagnostics.contains_errors()) {
- errorStream << "Vertex pulling transform: "
- << tint::diag::Formatter{}.format(result.diagnostics);
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
- module = std::move(result.module);
-
- tint::TypeDeterminer type_determiner(&module);
- if (!type_determiner.Determine()) {
- errorStream << "Type Determination: " << type_determiner.error();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
+ errorStream << "Tint SPIR-V writer failure:" << std::endl;
tint::writer::spirv::Generator generator(std::move(module));
if (!generator.Generate()) {
@@ -745,38 +674,16 @@
// completed using PopulateMetadataUsingSPIRVCross. In the future, once
// this function is complete, ReflectShaderUsingSPIRVCross and
// PopulateMetadataUsingSPIRVCross will be removed.
- ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(DeviceBase* device,
- std::vector<uint32_t> spirv) {
+ ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(
+ DeviceBase* device,
+ const tint::ast::Module& module) {
#ifdef DAWN_ENABLE_WGSL
+ ASSERT(module.IsValid());
+
EntryPointMetadataTable result;
std::ostringstream errorStream;
errorStream << "Tint Reflection failure:" << std::endl;
- tint::reader::spirv::Parser parser(spirv);
-
- if (!parser.Parse()) {
- errorStream << "Parser: " << parser.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::Module module = parser.module();
- if (!module.IsValid()) {
- errorStream << "Invalid module generated..." << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::TypeDeterminer typeDeterminer(&module);
- if (!typeDeterminer.Determine()) {
- errorStream << "Type Determination: " << typeDeterminer.error();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::Validator validator;
- if (!validator.Validate(&module)) {
- errorStream << "Validation: " << validator.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
tint::inspector::Inspector inspector(module);
auto entryPoints = inspector.GetEntryPoints();
if (inspector.has_error()) {
@@ -862,8 +769,17 @@
} // anonymous namespace
- MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
- const ShaderModuleDescriptor* descriptor) {
+ ShaderModuleParseResult::ShaderModuleParseResult() = default;
+ ShaderModuleParseResult::~ShaderModuleParseResult() = default;
+
+ ShaderModuleParseResult::ShaderModuleParseResult(ShaderModuleParseResult&& rhs) = default;
+
+ ShaderModuleParseResult& ShaderModuleParseResult::operator=(ShaderModuleParseResult&& rhs) =
+ default;
+
+ ResultOrError<ShaderModuleParseResult> ValidateShaderModuleDescriptor(
+ DeviceBase* device,
+ const ShaderModuleDescriptor* descriptor) {
const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
if (chainedDescriptor == nullptr) {
return DAWN_VALIDATION_ERROR("Shader module descriptor missing chained descriptor");
@@ -874,11 +790,29 @@
"Shader module descriptor chained nextInChain must be nullptr");
}
+ ShaderModuleParseResult parseResult = {};
switch (chainedDescriptor->sType) {
case wgpu::SType::ShaderModuleSPIRVDescriptor: {
const auto* spirvDesc =
static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
- DAWN_TRY(ValidateSpirv(spirvDesc->code, spirvDesc->codeSize));
+ std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+ if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+#ifdef DAWN_ENABLE_WGSL
+ tint::ast::Module module;
+ DAWN_TRY_ASSIGN(module, ParseSPIRV(spirv));
+ if (device->IsValidationEnabled()) {
+ DAWN_TRY(ValidateModule(&module));
+ }
+ parseResult.tintModule = std::make_unique<tint::ast::Module>(std::move(module));
+#else
+ return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
+#endif // DAWN_ENABLE_WGSL
+ } else {
+ if (device->IsValidationEnabled()) {
+ DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
+ }
+ parseResult.spirv = std::move(spirv);
+ }
break;
}
@@ -886,17 +820,35 @@
#ifdef DAWN_ENABLE_WGSL
const auto* wgslDesc =
static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor);
- DAWN_TRY(ValidateWGSL(wgslDesc->source));
+
+ if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+ tint::ast::Module module;
+ DAWN_TRY_ASSIGN(module, ParseWGSL(wgslDesc->source));
+ if (device->IsValidationEnabled()) {
+ DAWN_TRY(ValidateModule(&module));
+ }
+ parseResult.tintModule = std::make_unique<tint::ast::Module>(std::move(module));
+ } else {
+ tint::ast::Module module;
+ DAWN_TRY_ASSIGN(module, ParseWGSL(wgslDesc->source));
+ if (device->IsValidationEnabled()) {
+ DAWN_TRY(ValidateModule(&module));
+ }
+ std::vector<uint32_t> spirv;
+ DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(std::move(module)));
+ DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
+ parseResult.spirv = std::move(spirv);
+ }
break;
#else
- return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
+ return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
#endif // DAWN_ENABLE_WGSL
}
default:
return DAWN_VALIDATION_ERROR("Unsupported sType");
}
- return {};
+ return std::move(parseResult);
}
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
@@ -910,6 +862,23 @@
return bufferSizes;
}
+#ifdef DAWN_ENABLE_WGSL
+ ResultOrError<tint::ast::Module> RunTransforms(tint::transform::Manager* manager,
+ tint::ast::Module* module) {
+ tint::transform::Transform::Output output = manager->Run(module);
+ if (output.diagnostics.contains_errors()) {
+ std::string err =
+ "Tint transform failure: " + tint::diag::Formatter{}.format(output.diagnostics);
+ return DAWN_VALIDATION_ERROR(err.c_str());
+ }
+
+ if (!output.module.IsValid()) {
+ return DAWN_VALIDATION_ERROR("Tint transform did not produce valid module.");
+ }
+ return std::move(output.module);
+ }
+#endif
+
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout) {
@@ -994,49 +963,136 @@
}
const std::vector<uint32_t>& ShaderModuleBase::GetSpirv() const {
+ ASSERT(!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
return mSpirv;
}
#ifdef DAWN_ENABLE_WGSL
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
+ const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const {
- std::vector<uint32_t> spirv;
- DAWN_TRY_ASSIGN(spirv, ConvertWGSLToSPIRVWithPulling(mWgsl.c_str(), vertexState, entryPoint,
- pullingBufferBindingSet));
+ tint::ast::Module module;
+ DAWN_TRY_ASSIGN(module, ParseSPIRV(spirv));
+
+ return GeneratePullingSpirv(&module, vertexState, entryPoint, pullingBufferBindingSet);
+ }
+
+ ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
+ tint::ast::Module* moduleIn,
+ const VertexStateDescriptor& vertexState,
+ const std::string& entryPoint,
+ uint32_t pullingBufferBindingSet) const {
+ std::ostringstream errorStream;
+ errorStream << "Tint vertex pulling failure:" << std::endl;
+
+ tint::transform::Manager transformManager;
+ {
+ auto transform = std::make_unique<tint::transform::VertexPulling>();
+ tint::transform::VertexStateDescriptor state;
+ for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
+ const auto& vertexBuffer = vertexState.vertexBuffers[i];
+ tint::transform::VertexBufferLayoutDescriptor layout;
+ layout.array_stride = vertexBuffer.arrayStride;
+ layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
+
+ for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
+ const auto& attribute = vertexBuffer.attributes[j];
+ tint::transform::VertexAttributeDescriptor attr;
+ attr.format = ToTintVertexFormat(attribute.format);
+ attr.offset = attribute.offset;
+ attr.shader_location = attribute.shaderLocation;
+
+ layout.attributes.push_back(std::move(attr));
+ }
+
+ state.push_back(std::move(layout));
+ }
+ transform->SetVertexState(std::move(state));
+ transform->SetEntryPoint(entryPoint);
+ transform->SetPullingBufferBindingSet(pullingBufferBindingSet);
+ transformManager.append(std::move(transform));
+ }
+ if (GetDevice()->IsRobustnessEnabled()) {
+ // TODO(enga): Run the Tint BoundArrayAccessors transform instead of the SPIRV Tools
+ // one, but it appears to crash after running VertexPulling.
+ // transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
+ }
+
+ tint::ast::Module module;
+ DAWN_TRY_ASSIGN(module, RunTransforms(&transformManager, moduleIn));
+
+ tint::writer::spirv::Generator generator(std::move(module));
+ if (!generator.Generate()) {
+ errorStream << "Generator: " << generator.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ std::vector<uint32_t> spirv = generator.result();
if (GetDevice()->IsRobustnessEnabled()) {
DAWN_TRY_ASSIGN(spirv, RunRobustBufferAccessPass(spirv));
}
-
+ DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
return std::move(spirv);
}
#endif
- MaybeError ShaderModuleBase::InitializeBase() {
- std::vector<uint32_t> spirv;
- if (mType == Type::Wgsl) {
+ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
#ifdef DAWN_ENABLE_WGSL
- DAWN_TRY_ASSIGN(spirv, ConvertWGSLToSPIRV(mWgsl.c_str()));
+ tint::ast::Module* module = parseResult->tintModule.get();
+#endif
+ mSpirv = std::move(parseResult->spirv);
+
+ // If not using Tint to generate backend code, run the robust buffer access pass now since
+ // all backends will use this SPIR-V. If Tint is used, the robustness pass should be run
+ // per-backend.
+ if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator) &&
+ GetDevice()->IsRobustnessEnabled()) {
+ DAWN_TRY_ASSIGN(mSpirv, RunRobustBufferAccessPass(mSpirv));
+ }
+
+ // We still need the spirv for reflection. Remove this when we use the Tint inspector
+ // completely.
+ std::vector<uint32_t>* spirvPtr = &mSpirv;
+ std::vector<uint32_t> localSpirv;
+ if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+#ifdef DAWN_ENABLE_WGSL
+ ASSERT(module != nullptr);
+ tint::ast::Module clonedModule = module->Clone();
+ tint::TypeDeterminer typeDeterminer(&clonedModule);
+ if (!typeDeterminer.Determine()) {
+ return DAWN_VALIDATION_ERROR(typeDeterminer.error().c_str());
+ }
+ DAWN_TRY_ASSIGN(localSpirv, ModuleToSPIRV(std::move(clonedModule)));
+ DAWN_TRY(ValidateSpirv(localSpirv.data(), localSpirv.size()));
+ spirvPtr = &localSpirv;
+#else
+ UNREACHABLE();
+#endif
+ }
+
+ if (GetDevice()->IsToggleEnabled(Toggle::UseTintInspector)) {
+#ifdef DAWN_ENABLE_WGSL
+ tint::ast::Module localModule;
+
+ tint::ast::Module* modulePtr = module;
+ if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+ // We have mSpirv, but no Tint module
+ DAWN_TRY_ASSIGN(localModule, ParseSPIRV(mSpirv));
+ DAWN_TRY(ValidateModule(&localModule));
+ modulePtr = &localModule;
+ }
+
+ EntryPointMetadataTable table;
+ DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), *modulePtr));
+ DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), *spirvPtr, &table));
+ mEntryPoints = std::move(table);
#else
return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
-#endif // DAWN_ENABLE_WGSL
+#endif
} else {
- spirv = mOriginalSpirv;
- }
-
- if (GetDevice()->IsRobustnessEnabled()) {
- DAWN_TRY_ASSIGN(spirv, RunRobustBufferAccessPass(spirv));
- }
-
- mSpirv = std::move(spirv);
- if (GetDevice()->IsToggleEnabled(Toggle::UseTintInspector)) {
- EntryPointMetadataTable table;
- DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), mSpirv));
- DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), mSpirv, &table));
- mEntryPoints = std::move(table);
- } else {
- DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), mSpirv));
+ DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), *spirvPtr));
}
return {};
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index e406f49..e05971a 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -31,6 +31,18 @@
#include <unordered_map>
#include <vector>
+namespace tint {
+
+ namespace ast {
+ class Module;
+ } // namespace ast
+
+ namespace transform {
+ class Manager;
+ } // namespace transform
+
+} // namespace tint
+
namespace spirv_cross {
class Compiler;
}
@@ -43,14 +55,31 @@
using EntryPointMetadataTable =
std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>;
- MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
- const ShaderModuleDescriptor* descriptor);
+ struct ShaderModuleParseResult {
+ ShaderModuleParseResult();
+ ~ShaderModuleParseResult();
+ ShaderModuleParseResult(ShaderModuleParseResult&& rhs);
+ ShaderModuleParseResult& operator=(ShaderModuleParseResult&& rhs);
+
+#ifdef DAWN_ENABLE_WGSL
+ std::unique_ptr<tint::ast::Module> tintModule;
+#endif
+ std::vector<uint32_t> spirv;
+ };
+
+ ResultOrError<ShaderModuleParseResult> ValidateShaderModuleDescriptor(
+ DeviceBase* device,
+ const ShaderModuleDescriptor* descriptor);
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
+#ifdef DAWN_ENABLE_WGSL
+ ResultOrError<tint::ast::Module> RunTransforms(tint::transform::Manager* manager,
+ tint::ast::Module* module);
+#endif
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
// stored in the ShaderModuleBase and destroyed only when the shader module is destroyed so
@@ -116,13 +145,20 @@
#ifdef DAWN_ENABLE_WGSL
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
+ const std::vector<uint32_t>& spirv,
+ const VertexStateDescriptor& vertexState,
+ const std::string& entryPoint,
+ uint32_t pullingBufferBindingSet) const;
+
+ ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
+ tint::ast::Module* module,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const;
#endif
protected:
- MaybeError InitializeBase();
+ MaybeError InitializeBase(ShaderModuleParseResult* parseResult);
private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
diff --git a/src/dawn_native/Toggles.cpp b/src/dawn_native/Toggles.cpp
index 3ca4dad..f48a274 100644
--- a/src/dawn_native/Toggles.cpp
+++ b/src/dawn_native/Toggles.cpp
@@ -143,7 +143,7 @@
"http://crbug.com/1138528"}},
{Toggle::UseTintGenerator,
{"use_tint_generator", "Use Tint instead of SPRIV-cross to generate shaders.",
- "https://crbug.com/dawn/548"}},
+ "https://crbug.com/dawn/571"}},
{Toggle::UseTintInspector,
{"use_tint_inspector", "Use Tint instead of SPRIV-cross for shader reflection.",
"https://crbug.com/dawn/578"}},
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index 4db2133..d28767c 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -325,8 +325,9 @@
return new Sampler(this, descriptor);
}
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) {
- return ShaderModule::Create(this, descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
+ return ShaderModule::Create(this, descriptor, parseResult);
}
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index 5d54a3e..dc41448 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -160,7 +160,8 @@
const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) override;
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index d415225..8daf854 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -174,9 +174,10 @@
// static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
- const ShaderModuleDescriptor* descriptor) {
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
- DAWN_TRY(module->InitializeBase());
+ DAWN_TRY(module->Initialize(parseResult));
return module.Detach();
}
@@ -184,6 +185,14 @@
: ShaderModuleBase(device, descriptor) {
}
+ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+ DAWN_TRY(InitializeBase(parseResult));
+#ifdef DAWN_ENABLE_WGSL
+ mTintModule = std::move(parseResult->tintModule);
+#endif
+ return {};
+ }
+
ResultOrError<std::string> ShaderModule::TranslateToHLSLWithTint(
const char* entryPointName,
SingleShaderStage stage,
@@ -195,41 +204,11 @@
std::ostringstream errorStream;
errorStream << "Tint HLSL failure:" << std::endl;
- // TODO: Remove redundant SPIRV step between WGSL and HLSL.
- tint::reader::spirv::Parser parser(GetSpirv());
-
- if (!parser.Parse()) {
- errorStream << "Parser: " << parser.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::ast::Module module = parser.module();
- if (!module.IsValid()) {
- errorStream << "Invalid module generated..." << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::TypeDeterminer typeDeterminer(&module);
- if (!typeDeterminer.Determine()) {
- errorStream << "Type Determination: " << typeDeterminer.error();
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- tint::Validator validator;
- if (!validator.Validate(&module)) {
- errorStream << "Validation: " << validator.error() << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
- auto result = transformManager.Run(&module);
- if (result.diagnostics.contains_errors()) {
- errorStream << "Bound Array Accessors Transform: "
- << tint::diag::Formatter{}.format(result.diagnostics);
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
- module = std::move(result.module);
+
+ tint::ast::Module module;
+ DAWN_TRY_ASSIGN(module, RunTransforms(&transformManager, mTintModule.get()));
ASSERT(remappedEntryPointName != nullptr);
tint::inspector::Inspector inspector(module);
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 63cb2e9..f5a94f0 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -36,7 +36,8 @@
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<ShaderModule*> Create(Device* device,
- const ShaderModuleDescriptor* descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult);
ResultOrError<CompiledShader> Compile(const char* entryPointName,
SingleShaderStage stage,
@@ -46,6 +47,7 @@
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
+ MaybeError Initialize(ShaderModuleParseResult* parseResult);
ResultOrError<std::string> TranslateToHLSLWithTint(
const char* entryPointName,
@@ -61,6 +63,10 @@
SingleShaderStage stage,
const std::string& hlslSource,
uint32_t compileFlags) const;
+
+#ifdef DAWN_ENABLE_WGSL
+ std::unique_ptr<tint::ast::Module> mTintModule;
+#endif
};
}} // namespace dawn_native::d3d12
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index c582758..53da499 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -92,7 +92,8 @@
const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) override;
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 3846d0d..37df56a 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -151,8 +151,9 @@
return Sampler::Create(this, descriptor);
}
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) {
- return ShaderModule::Create(this, descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
+ return ShaderModule::Create(this, descriptor, parseResult);
}
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 6ecf57d..3fb2dc2 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -35,7 +35,8 @@
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<ShaderModule*> Create(Device* device,
- const ShaderModuleDescriptor* descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult);
struct MetalFunctionData {
NSPRef<id<MTLFunction>> function;
@@ -51,7 +52,11 @@
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
- MaybeError Initialize();
+ MaybeError Initialize(ShaderModuleParseResult* parseResult);
+
+#ifdef DAWN_ENABLE_WGSL
+ std::unique_ptr<tint::ast::Module> mTintModule;
+#endif
};
}} // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 5af363b..6d52770 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -22,15 +22,24 @@
#include <spirv_msl.hpp>
+#ifdef DAWN_ENABLE_WGSL
+// Tint include must be after spirv_msl.hpp, because spirv-cross has its own
+// version of spirv_headers. We also need to undef SPV_REVISION because SPIRV-Cross
+// is at 3 while spirv-headers is at 4.
+# undef SPV_REVISION
+# include <tint/tint.h>
+#endif // DAWN_ENABLE_WGSL
+
#include <sstream>
namespace dawn_native { namespace metal {
// static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
- const ShaderModuleDescriptor* descriptor) {
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
- DAWN_TRY(module->Initialize());
+ DAWN_TRY(module->Initialize(parseResult));
return module.Detach();
}
@@ -38,8 +47,12 @@
: ShaderModuleBase(device, descriptor) {
}
- MaybeError ShaderModule::Initialize() {
- return InitializeBase();
+ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+ DAWN_TRY(InitializeBase(parseResult));
+#ifdef DAWN_ENABLE_WGSL
+ mTintModule = std::move(parseResult->tintModule);
+#endif
+ return {};
}
MaybeError ShaderModule::CreateFunction(const char* entryPointName,
@@ -59,9 +72,17 @@
std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
stage == SingleShaderStage::Vertex) {
- DAWN_TRY_ASSIGN(pullingSpirv,
- GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
- entryPointName, kPullingBufferBindingSet));
+ if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+ DAWN_TRY_ASSIGN(pullingSpirv,
+ GeneratePullingSpirv(mTintModule.get(),
+ *renderPipeline->GetVertexStateDescriptor(),
+ entryPointName, kPullingBufferBindingSet));
+ } else {
+ DAWN_TRY_ASSIGN(
+ pullingSpirv,
+ GeneratePullingSpirv(GetSpirv(), *renderPipeline->GetVertexStateDescriptor(),
+ entryPointName, kPullingBufferBindingSet));
+ }
spirv = &pullingSpirv;
}
#endif
diff --git a/src/dawn_native/null/DeviceNull.cpp b/src/dawn_native/null/DeviceNull.cpp
index fe981d2..a821acf 100644
--- a/src/dawn_native/null/DeviceNull.cpp
+++ b/src/dawn_native/null/DeviceNull.cpp
@@ -127,9 +127,10 @@
return new Sampler(this, descriptor);
}
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) {
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
- DAWN_TRY(module->Initialize());
+ DAWN_TRY(module->Initialize(parseResult));
return module.Detach();
}
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
@@ -395,8 +396,8 @@
// ShaderModule
- MaybeError ShaderModule::Initialize() {
- return InitializeBase();
+ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+ return InitializeBase(parseResult);
}
// OldSwapChain
diff --git a/src/dawn_native/null/DeviceNull.h b/src/dawn_native/null/DeviceNull.h
index 2ffadbe..9a73a9b 100644
--- a/src/dawn_native/null/DeviceNull.h
+++ b/src/dawn_native/null/DeviceNull.h
@@ -135,7 +135,8 @@
const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) override;
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
@@ -246,7 +247,7 @@
public:
using ShaderModuleBase::ShaderModuleBase;
- MaybeError Initialize();
+ MaybeError Initialize(ShaderModuleParseResult* parseResult);
};
class SwapChain final : public NewSwapChainBase {
diff --git a/src/dawn_native/opengl/DeviceGL.cpp b/src/dawn_native/opengl/DeviceGL.cpp
index 4ddb91c..9099ae3 100644
--- a/src/dawn_native/opengl/DeviceGL.cpp
+++ b/src/dawn_native/opengl/DeviceGL.cpp
@@ -128,8 +128,9 @@
return new Sampler(this, descriptor);
}
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) {
- return ShaderModule::Create(this, descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
+ return ShaderModule::Create(this, descriptor, parseResult);
}
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/opengl/DeviceGL.h b/src/dawn_native/opengl/DeviceGL.h
index d19ffe2..1d1f087 100644
--- a/src/dawn_native/opengl/DeviceGL.h
+++ b/src/dawn_native/opengl/DeviceGL.h
@@ -91,7 +91,8 @@
const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) override;
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp
index 792a2fc..a3b7dda 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn_native/opengl/ShaderModuleGL.cpp
@@ -59,9 +59,10 @@
// static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
- const ShaderModuleDescriptor* descriptor) {
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
- DAWN_TRY(module->InitializeBase());
+ DAWN_TRY(module->InitializeBase(parseResult));
return module.Detach();
}
diff --git a/src/dawn_native/opengl/ShaderModuleGL.h b/src/dawn_native/opengl/ShaderModuleGL.h
index c18fa48..a003e19 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.h
+++ b/src/dawn_native/opengl/ShaderModuleGL.h
@@ -47,7 +47,8 @@
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<ShaderModule*> Create(Device* device,
- const ShaderModuleDescriptor* descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult);
std::string TranslateToGLSL(const char* entryPointName,
SingleShaderStage stage,
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index 71f42a8..2fda4ed 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -136,8 +136,9 @@
return Sampler::Create(this, descriptor);
}
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) {
- return ShaderModule::Create(this, descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
+ return ShaderModule::Create(this, descriptor, parseResult);
}
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index 6380481..40b11f7 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -127,7 +127,8 @@
const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
- const ShaderModuleDescriptor* descriptor) override;
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index fb107af..f606252 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -24,12 +24,13 @@
// static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
- const ShaderModuleDescriptor* descriptor) {
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
if (module == nullptr) {
return DAWN_VALIDATION_ERROR("Unable to create ShaderModule");
}
- DAWN_TRY(module->Initialize());
+ DAWN_TRY(module->Initialize(parseResult));
return module.Detach();
}
@@ -37,8 +38,8 @@
: ShaderModuleBase(device, descriptor) {
}
- MaybeError ShaderModule::Initialize() {
- DAWN_TRY(InitializeBase());
+ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
+ DAWN_TRY(InitializeBase(parseResult));
const std::vector<uint32_t>& spirv = GetSpirv();
VkShaderModuleCreateInfo createInfo;
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 720cc5e..621ab0e 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -27,14 +27,15 @@
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<ShaderModule*> Create(Device* device,
- const ShaderModuleDescriptor* descriptor);
+ const ShaderModuleDescriptor* descriptor,
+ ShaderModuleParseResult* parseResult);
VkShaderModule GetHandle() const;
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override;
- MaybeError Initialize();
+ MaybeError Initialize(ShaderModuleParseResult* parseResult);
VkShaderModule mHandle = VK_NULL_HANDLE;
};