Dawn: Implement shader module front-end blob cache
This CL implement the blob cache for shader module front-end, currently
caching the reflection information, compilation messages, and validation
error (if any) into disk. With this CL, shader modules and pipelines can
be created without calling Tint and backend compilers at all on backends
supporting pipeline cache (i.e. D3D and vk) if cache hit.
This CL also make FeaturesSet and TogglesSet serializable so we can add
them in cache request.
Bug: 42240459, 402772740
Change-Id: Id576776f7e1b9d54f431a2fa50082bc5951880ee
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/242854
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
diff --git a/include/dawn/native/D3DBackend.h b/include/dawn/native/D3DBackend.h
index 1fad332..7c37184 100644
--- a/include/dawn/native/D3DBackend.h
+++ b/include/dawn/native/D3DBackend.h
@@ -30,7 +30,6 @@
#include <dxgi1_4.h>
#include <webgpu/webgpu_cpp_chained_struct.h>
-#include <windows.h>
#include <wrl/client.h>
#include <memory>
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 9d04c28..b866cf0 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -369,6 +369,8 @@
"Serializable.h",
"ShaderModule.cpp",
"ShaderModule.h",
+ "ShaderModuleParseRequest.cpp",
+ "ShaderModuleParseRequest.h",
"SharedBufferMemory.cpp",
"SharedBufferMemory.h",
"SharedFence.cpp",
diff --git a/src/dawn/native/Buffer.h b/src/dawn/native/Buffer.h
index d0439eb..a6484f1 100644
--- a/src/dawn/native/Buffer.h
+++ b/src/dawn/native/Buffer.h
@@ -47,6 +47,7 @@
namespace dawn::native {
struct CopyTextureToBufferCmd;
+class MemoryDump;
enum class MapType : uint32_t;
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt
index ee38f41..3ff07d8 100644
--- a/src/dawn/native/CMakeLists.txt
+++ b/src/dawn/native/CMakeLists.txt
@@ -130,6 +130,7 @@
"ScratchBuffer.h"
"Serializable.h"
"ShaderModule.h"
+ "ShaderModuleParseRequest.h"
"SharedBufferMemory.h"
"SharedFence.h"
"SharedResourceMemory.h"
@@ -238,6 +239,7 @@
"Sampler.cpp"
"ScratchBuffer.cpp"
"ShaderModule.cpp"
+ "ShaderModuleParseRequest.cpp"
"SharedBufferMemory.cpp"
"SharedFence.cpp"
"SharedResourceMemory.cpp"
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 11cd5d9..1124f13 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -48,6 +48,8 @@
#include "dawn/native/BlitBufferToDepthStencil.h"
#include "dawn/native/BlobCache.h"
#include "dawn/native/Buffer.h"
+#include "dawn/native/CacheRequest.h"
+#include "dawn/native/CacheResult.h"
#include "dawn/native/ChainUtils.h"
#include "dawn/native/CommandBuffer.h"
#include "dawn/native/CommandEncoder.h"
@@ -70,6 +72,7 @@
#include "dawn/native/RenderBundleEncoder.h"
#include "dawn/native/RenderPipeline.h"
#include "dawn/native/Sampler.h"
+#include "dawn/native/ShaderModuleParseRequest.h"
#include "dawn/native/SharedBufferMemory.h"
#include "dawn/native/SharedFence.h"
#include "dawn/native/SharedTextureMemory.h"
@@ -2131,15 +2134,6 @@
const ShaderModuleDescriptor* descriptor,
const std::vector<tint::wgsl::Extension>& internalExtensions,
ShaderModuleParseResult* outputParseResult) {
- // ShaderModuleParseResult holds OwnedCompilationMessages, which would be used for creating
- // error shader module if errors occurred. If the outputParseResult is not provided, create an
- // inplace one.
- std::optional<ShaderModuleParseResult> inplaceParseResult;
- if (outputParseResult == nullptr) {
- inplaceParseResult.emplace();
- outputParseResult = &*inplaceParseResult;
- }
-
DAWN_TRY(ValidateIsAlive());
// Unpack and validate the descriptor chain before doing further validation or cache
@@ -2182,21 +2176,42 @@
const size_t blueprintHash = blueprint.ComputeContentHash();
blueprint.SetContentHash(blueprintHash);
- // Check in-memory shader module cache first, and if missed call ParseShaderModule.
+ // Check in-memory shader module cache first, and if missed check the blob cache, and if missed
+ // again call ParseShaderModule.
return GetOrCreate(
mCaches->shaderModules, &blueprint, [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(GetPlatform(), "CreateShaderModuleUS");
auto resultOrError = [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
- // Try to validate and parse the shader code, and if an error occurred return it
- // without updating the cache.
- DAWN_TRY(ParseShaderModule(this, unpacked, internalExtensions,
- /* needReflection */ true, outputParseResult));
+ ShaderModuleParseRequest req = BuildShaderModuleParseRequest(
+ this, blueprint.GetHash(), unpacked, internalExtensions,
+ /* needReflection */ true);
+ // Check blob cache first before calling ParseShaderModule. ShaderModuleParseResult
+ // returned from blob cache or ParseShaderModule will hold compilation messages and
+ // validation errors if any. ShaderModuleParseResult from ParseShaderModule also
+ // holds tint program.
+ CacheResult<ShaderModuleParseResult> result;
+ DAWN_TRY_LOAD_OR_RUN(result, this, std::move(req),
+ ShaderModuleParseResult::FromBlob, ParseShaderModule,
+ "ShaderModuleParsing");
+ GetBlobCache()->EnsureStored(result);
+ ShaderModuleParseResult parseResult = result.Acquire();
+
+ // If ShaderModuleParseResult has validation error, move the compilation messages to
+ // *outputParseResult so that we can create an error shader module from it, and then
+ // return the validation error.
+ if (parseResult.HasError()) {
+ auto error = parseResult.cachedValidationError->ToErrorData();
+ if (outputParseResult) {
+ *outputParseResult = std::move(parseResult);
+ }
+ return error;
+ }
+ // Otherwise with no error, create a shader module from parse result and return it.
Ref<ShaderModuleBase> shaderModule;
- // If created successfully, outputParseResult are moved into the shader module.
- DAWN_TRY_ASSIGN(shaderModule, CreateShaderModuleImpl(unpacked, internalExtensions,
- outputParseResult));
+ DAWN_TRY_ASSIGN(shaderModule,
+ CreateShaderModuleImpl(unpacked, internalExtensions, &parseResult));
shaderModule->SetContentHash(blueprintHash);
return shaderModule;
}();
@@ -2283,6 +2298,10 @@
return mToggles;
}
+const FeaturesSet& DeviceBase::GetEnabledFeatures() const {
+ return mEnabledFeatures;
+}
+
void DeviceBase::ForceEnableFeatureForTesting(Feature feature) {
mEnabledFeatures.EnableFeature(feature);
mFormatTable = BuildFormatTable(this);
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 9fbdef6..8b5bcf8 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -333,6 +333,7 @@
const tint::wgsl::AllowedFeatures& GetWGSLAllowedFeatures() const;
bool IsToggleEnabled(Toggle toggle) const;
const TogglesState& GetTogglesState() const;
+ const FeaturesSet& GetEnabledFeatures() const;
bool IsValidationEnabled() const;
bool IsRobustnessEnabled() const;
bool IsCompatibilityMode() const;
diff --git a/src/dawn/native/Error.h b/src/dawn/native/Error.h
index 0cea9a9..39485e5 100644
--- a/src/dawn/native/Error.h
+++ b/src/dawn/native/Error.h
@@ -34,7 +34,6 @@
#include "dawn/common/Result.h"
#include "dawn/native/ErrorData.h"
-#include "dawn/native/Toggles.h"
#include "dawn/native/webgpu_absl_format.h"
namespace dawn::native {
diff --git a/src/dawn/native/Features.h b/src/dawn/native/Features.h
index 4869100..7ae6f03 100644
--- a/src/dawn/native/Features.h
+++ b/src/dawn/native/Features.h
@@ -35,6 +35,7 @@
#include "dawn/common/ityp_bitset.h"
#include "dawn/native/DawnNative.h"
#include "dawn/native/Features_autogen.h"
+#include "dawn/native/Serializable.h"
namespace dawn::native {
@@ -45,15 +46,16 @@
// A wrapper of the bitset to store if an feature is enabled or not. This wrapper provides the
// convenience to convert the enums of enum class Feature to the indices of a bitset.
-struct FeaturesSet {
- ityp::bitset<Feature, kEnumCount<Feature>> featuresBitSet;
-
+using FeaturesBitSet = ityp::bitset<Feature, kEnumCount<Feature>>;
+#define FEATURES_SET_MEMBER(X) X(FeaturesBitSet, featuresBitSet)
+DAWN_SERIALIZABLE(struct, FeaturesSet, FEATURES_SET_MEMBER) {
void EnableFeature(Feature feature);
void EnableFeature(wgpu::FeatureName feature);
bool IsEnabled(Feature feature) const;
bool IsEnabled(wgpu::FeatureName feature) const;
- void ToSupportedFeatures(SupportedFeatures* supportedFeatures) const;
+ void ToSupportedFeatures(SupportedFeatures * supportedFeatures) const;
};
+#undef FEATURES_SET_MEMBER
} // namespace dawn::native
diff --git a/src/dawn/native/Limits.cpp b/src/dawn/native/Limits.cpp
index bb94732..086331e 100644
--- a/src/dawn/native/Limits.cpp
+++ b/src/dawn/native/Limits.cpp
@@ -371,15 +371,14 @@
DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
return result;
}
+LimitsForShaderModuleParseRequest LimitsForShaderModuleParseRequest::Create(const Limits& limits) {
+ LimitsForShaderModuleParseRequest result;
+ DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_SHADER_MODULE_PARSE_REQUEST_MEMBERS)
+ return result;
+}
#undef DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT
#undef DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT
-template <>
-void stream::Stream<LimitsForCompilationRequest>::Write(Sink* s,
- const LimitsForCompilationRequest& t) {
- t.VisitAll([&](const auto&... members) { StreamIn(s, members...); });
-}
-
void NormalizeLimits(CombinedLimits* limits) {
// Enforce internal Dawn constants for some limits to ensure they don't go over fixed limits
// in Dawn's internal code.
diff --git a/src/dawn/native/Limits.h b/src/dawn/native/Limits.h
index 4858982..d87acdd 100644
--- a/src/dawn/native/Limits.h
+++ b/src/dawn/native/Limits.h
@@ -33,7 +33,7 @@
#include "dawn/native/ChainUtils.h"
#include "dawn/native/Error.h"
#include "dawn/native/Features.h"
-#include "dawn/native/VisitableMembers.h"
+#include "dawn/native/Serializable.h"
#include "dawn/native/dawn_platform.h"
namespace dawn::native {
@@ -83,10 +83,19 @@
X(uint32_t, maxComputeInvocationsPerWorkgroup) \
X(uint32_t, maxComputeWorkgroupStorageSize)
-struct LimitsForCompilationRequest {
+DAWN_SERIALIZABLE(struct, LimitsForCompilationRequest, LIMITS_FOR_COMPILATION_REQUEST_MEMBERS) {
static LimitsForCompilationRequest Create(const Limits& limits);
- DAWN_VISITABLE_MEMBERS(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS)
- bool operator==(const LimitsForCompilationRequest& other) const = default;
+};
+
+#define LIMITS_FOR_SHADER_MODULE_PARSE_REQUEST_MEMBERS(X) \
+ X(uint32_t, maxVertexAttributes) \
+ X(uint32_t, maxInterStageShaderVariables) \
+ X(uint32_t, maxColorAttachments)
+
+DAWN_SERIALIZABLE(struct,
+ LimitsForShaderModuleParseRequest,
+ LIMITS_FOR_SHADER_MODULE_PARSE_REQUEST_MEMBERS) {
+ static LimitsForShaderModuleParseRequest Create(const Limits& limits);
};
// Enforce restriction for limit values, including:
diff --git a/src/dawn/native/Serializable.h b/src/dawn/native/Serializable.h
index 88918c2..31ea636 100644
--- a/src/dawn/native/Serializable.h
+++ b/src/dawn/native/Serializable.h
@@ -75,6 +75,7 @@
public:
UnsafeUnserializedValue() = default;
explicit UnsafeUnserializedValue(T&& value) : mValue(std::forward<T>(value)) {}
+ explicit UnsafeUnserializedValue(const T& value) : mValue(value) {}
UnsafeUnserializedValue(const UnsafeUnserializedValue<T>& other)
: mValue(other.UnsafeGetValue()) {}
UnsafeUnserializedValue<T>& operator=(UnsafeUnserializedValue<T>&& other) {
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index da1ffe6..5f54066 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -46,6 +46,7 @@
#include "dawn/native/PipelineLayout.h"
#include "dawn/native/RenderPipeline.h"
#include "dawn/native/Sampler.h"
+#include "dawn/native/ShaderModuleParseRequest.h"
#include "dawn/native/TintUtils.h"
#ifdef DAWN_ENABLE_SPIRV_VALIDATION
@@ -58,8 +59,7 @@
namespace {
-ResultOrError<SingleShaderStage> TintPipelineStageToShaderStage(
- tint::inspector::PipelineStage stage) {
+SingleShaderStage TintPipelineStageToShaderStage(tint::inspector::PipelineStage stage) {
switch (stage) {
case tint::inspector::PipelineStage::kVertex:
return SingleShaderStage::Vertex;
@@ -376,45 +376,65 @@
DAWN_UNREACHABLE();
}
-ResultOrError<tint::Program> ParseWGSL(const tint::Source::File* file,
- const tint::wgsl::AllowedFeatures& allowedFeatures,
- const std::vector<tint::wgsl::Extension>& internalExtensions,
- ParsedCompilationMessages* outMessages) {
+// Validation errors, if any, are stored within outputParseResult instead of get returned as
+// ErrorData.
+MaybeError ParseWGSL(std::unique_ptr<tint::Source::File> file,
+ const WGSLAllowedFeatures& allowedFeatures,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
+ ShaderModuleParseResult* outputParseResult) {
tint::wgsl::reader::Options options;
- options.allowed_features = allowedFeatures;
+ options.allowed_features = allowedFeatures.ToTint();
options.allowed_features.extensions.insert(internalExtensions.begin(),
internalExtensions.end());
- tint::Program program = tint::wgsl::reader::Parse(file, options);
- if (outMessages != nullptr) {
- DAWN_TRY(outMessages->AddMessages(program.Diagnostics()));
- }
- if (!program.IsValid()) {
- return DAWN_VALIDATION_ERROR("Error while parsing WGSL: %s\n", program.Diagnostics().Str());
+ tint::Program program = tint::wgsl::reader::Parse(file.get(), options);
+
+ // Store the compilation messages into outputParseResult.
+ DAWN_TRY(outputParseResult->compilationMessages.AddMessages(program.Diagnostics()));
+
+ // If WGSL parsing succeed, store the generated Tint program with no validation error.
+ if (program.IsValid()) {
+ outputParseResult->tintProgram = UnsafeUnserializedValue<std::optional<Ref<TintProgram>>>(
+ AcquireRef(new TintProgram(std::move(program), std::move(file))));
+ DAWN_ASSERT(outputParseResult->HasTintProgram() && !outputParseResult->HasError());
+ } else {
+ // Otherwise, store the validation error messages to outputParseResult.
+ outputParseResult->SetValidationError(
+ DAWN_VALIDATION_ERROR("Error while parsing WGSL: %s\n", program.Diagnostics().Str()));
+ DAWN_ASSERT(!outputParseResult->HasTintProgram() && outputParseResult->HasError());
}
- return std::move(program);
+ return {};
}
#if TINT_BUILD_SPV_READER
-ResultOrError<tint::Program> ParseSPIRV(const std::vector<uint32_t>& spirv,
- const tint::wgsl::AllowedFeatures& allowedFeatures,
- ParsedCompilationMessages* outMessages,
- const DawnShaderModuleSPIRVOptionsDescriptor* optionsDesc) {
+// Validation errors, if any, are stored within outputParseResult instead of get returned as
+// ErrorData
+MaybeError ParseSPIRV(const std::vector<uint32_t>& spirv,
+ const WGSLAllowedFeatures& allowedFeatures,
+ ShaderModuleParseResult* outputParseResult,
+ bool allowNonUniformDerivatives) {
tint::spirv::reader::Options options;
- if (optionsDesc) {
- options.allow_non_uniform_derivatives = optionsDesc->allowNonUniformDerivatives;
- }
- options.allowed_features = allowedFeatures;
+ options.allow_non_uniform_derivatives = allowNonUniformDerivatives;
+ options.allowed_features = allowedFeatures.ToTint();
+
tint::Program program = tint::spirv::reader::Read(spirv, options);
- if (outMessages != nullptr) {
- DAWN_TRY(outMessages->AddMessages(program.Diagnostics()));
- }
- if (!program.IsValid()) {
- return DAWN_VALIDATION_ERROR("Error while parsing SPIR-V: %s\n",
- program.Diagnostics().Str());
+
+ // Store the compilation messages into outputParseResult.
+ DAWN_TRY(outputParseResult->compilationMessages.AddMessages(program.Diagnostics()));
+
+ // If SpirV parsing succeed, store the generated Tint program with no validation error.
+ if (program.IsValid()) {
+ outputParseResult->tintProgram = UnsafeUnserializedValue<std::optional<Ref<TintProgram>>>(
+ AcquireRef(new TintProgram(std::move(program), nullptr)));
+ DAWN_ASSERT(outputParseResult->HasTintProgram() && !outputParseResult->HasError());
+ } else {
+ // Otherwise, store the validation error messages to outputParseResult.
+ outputParseResult->SetValidationError(
+ DAWN_VALIDATION_ERROR("Error while parsing SPIR-V: %s\n", program.Diagnostics().Str()));
+ DAWN_ASSERT(!outputParseResult->HasTintProgram() && outputParseResult->HasError());
}
- return std::move(program);
+ return {};
}
#endif // TINT_BUILD_SPV_READER
@@ -661,7 +681,7 @@
}
ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
- const DeviceBase* device,
+ const ShaderModuleParseDeviceInfo& deviceInfo,
tint::inspector::Inspector* inspector,
const tint::inspector::EntryPoint& entryPoint) {
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
@@ -715,7 +735,7 @@
metadata->overrides[identifier] = override;
}
- DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
+ metadata->stage = TintPipelineStageToShaderStage(entryPoint.stage);
if (metadata->stage == SingleShaderStage::Compute) {
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
@@ -725,9 +745,9 @@
metadata->usesDepthTextureWithNonComparisonSampler =
entryPoint.has_depth_texture_with_non_comparison_sampler;
- const CombinedLimits& limits = device->GetLimits();
- const uint32_t maxVertexAttributes = limits.v1.maxVertexAttributes;
- const uint32_t maxInterStageShaderVariables = limits.v1.maxInterStageShaderVariables;
+ const LimitsForShaderModuleParseRequest& limits = deviceInfo.limits;
+ const uint32_t maxVertexAttributes = limits.maxVertexAttributes;
+ const uint32_t maxInterStageShaderVariables = limits.maxInterStageShaderVariables;
metadata->usedInterStageVariables.resize(maxInterStageShaderVariables);
metadata->interStageVariables.resize(maxInterStageShaderVariables);
@@ -900,7 +920,7 @@
}
// Fragment output reflection.
- uint32_t maxColorAttachments = limits.v1.maxColorAttachments;
+ uint32_t maxColorAttachments = limits.maxColorAttachments;
for (const auto& outputVar : entryPoint.output_variables) {
EntryPointMetadata::FragmentRenderAttachmentInfo variable;
DAWN_TRY_ASSIGN(variable.baseType,
@@ -944,7 +964,7 @@
// Tint should disallow using @color(N) without the respective enable, which is gated
// on the extension.
- DAWN_ASSERT(device->HasFeature(Feature::FramebufferFetch));
+ DAWN_ASSERT(deviceInfo.features.IsEnabled(Feature::FramebufferFetch));
EntryPointMetadata::FragmentRenderAttachmentInfo variable;
DAWN_TRY_ASSIGN(variable.baseType,
@@ -990,10 +1010,10 @@
info.arraySize = BindingIndex(resource.array_size.value_or(1));
DAWN_INVALID_IF(resource.array_size.has_value() &&
- device->IsToggleEnabled(Toggle::DisableBindGroupLayoutEntryArraySize),
+ deviceInfo.toggles.Has(Toggle::DisableBindGroupLayoutEntryArraySize),
"Use of binding_array is disabled.");
DAWN_INVALID_IF(
- resource.array_size.has_value() && !device->IsToggleEnabled(Toggle::AllowUnsafeAPIs),
+ resource.array_size.has_value() && !deviceInfo.toggles.Has(Toggle::AllowUnsafeAPIs),
"Use of binding_array is disabled as an unsafe API.");
DAWN_INVALID_IF(info.arraySize == BindingIndex(0), "binding_array size is 0.");
if (DelayedInvalidIf(
@@ -1138,7 +1158,7 @@
metadata->usesSubgroupMatrix = entryPoint.uses_subgroup_matrix;
// Compute the texture+sampler combination count.
- if (device->IsCompatibilityMode()) {
+ if (deviceInfo.isCompatibilityMode) {
// separate sampled from non-sampled and put sampled in set
std::set<tint::BindingPoint> sampledTextures;
std::set<tint::BindingPoint> sampledExternalTextures;
@@ -1179,27 +1199,44 @@
return std::move(metadata);
}
-MaybeError ReflectShaderUsingTint(DeviceBase* device,
- const tint::Program* program,
- EntryPointMetadataTable* entryPointMetadataTable) {
- DAWN_ASSERT(program->IsValid());
+void ReflectShaderUsingTint(const ShaderModuleParseDeviceInfo& deviceInfo,
+ ShaderModuleParseResult* outputParseResult) {
+ DAWN_ASSERT(outputParseResult->HasTintProgram());
+ const tint::Program* program =
+ &outputParseResult->tintProgram.UnsafeGetValue().value().Get()->program;
+ DAWN_ASSERT(program && program->IsValid());
tint::inspector::Inspector inspector(*program);
std::vector<tint::inspector::EntryPoint> entryPoints = inspector.GetEntryPoints();
- DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n",
- inspector.error());
+ if (inspector.has_error()) {
+ outputParseResult->SetValidationError(
+ DAWN_VALIDATION_ERROR("Tint Reflection failure: Inspector: %s\n", inspector.error()));
+ return;
+ }
+
+ // A ShaderModuleParseResult should get reflected at most once.
+ DAWN_ASSERT(!outputParseResult->metadataTable.has_value());
+ EntryPointMetadataTable& metadataTable = outputParseResult->metadataTable.emplace();
for (const tint::inspector::EntryPoint& entryPoint : entryPoints) {
- std::unique_ptr<EntryPointMetadata> metadata;
- DAWN_TRY_ASSIGN_CONTEXT(metadata,
- ReflectEntryPointUsingTint(device, &inspector, entryPoint),
- "processing entry point \"%s\".", entryPoint.name);
-
- DAWN_ASSERT(!entryPointMetadataTable->contains(entryPoint.name));
- entryPointMetadataTable->emplace(entryPoint.name, std::move(metadata));
+ auto entryPointReflectionResult =
+ ReflectEntryPointUsingTint(deviceInfo, &inspector, entryPoint);
+ // If validation error occurs, store the error into output parse result, drop the incomplete
+ // metadate table, and stop reflection.
+ if (entryPointReflectionResult.IsError()) {
+ auto error = entryPointReflectionResult.AcquireError();
+ error->AppendContext(
+ absl::StrFormat("processing entry point \"%s\".", entryPoint.name));
+ // The incomplete metadate table is also dropped in SetValidationError.
+ outputParseResult->SetValidationError(std::move(error));
+ return;
+ }
+ // Otherwise add the reflection to metadata table.
+ auto reflection = entryPointReflectionResult.AcquireSuccess();
+ DAWN_ASSERT(!metadataTable.contains(entryPoint.name));
+ metadataTable.emplace(entryPoint.name, std::move(reflection));
}
- return {};
}
} // anonymous namespace
@@ -1294,80 +1331,110 @@
return Extent3D{x, y, z};
}
+CachedValidationError::CachedValidationError(std::unique_ptr<ErrorData>&& errorData) {
+ DAWN_ASSERT(errorData->GetType() == InternalErrorType::Validation);
+ message = errorData->GetMessage();
+ contexts = errorData->GetContexts();
+ DAWN_ASSERT(!message.empty());
+}
+
+std::unique_ptr<ErrorData> CachedValidationError::ToErrorData() const {
+ DAWN_ASSERT(!message.empty());
+ auto error = std::make_unique<ErrorData>(InternalErrorType::Validation, message);
+ std::for_each(contexts.begin(), contexts.end(), [&error](auto c) { error->AppendContext(c); });
+ return error;
+}
+
bool ShaderModuleParseResult::HasTintProgram() const {
return tintProgram.UnsafeGetValue().has_value() &&
tintProgram.UnsafeGetValue().value() != nullptr;
}
-MaybeError ParseShaderModule(DeviceBase* device,
- const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
- const std::vector<tint::wgsl::Extension>& internalExtensions,
- bool needReflection,
- ShaderModuleParseResult* parseResult) {
- DAWN_ASSERT(parseResult != nullptr);
+bool ShaderModuleParseResult::HasError() const {
+ // If cachedValidationError holds error, it must have non-empty error message string.
+ DAWN_ASSERT(!cachedValidationError.has_value() || !cachedValidationError->message.empty());
+ return cachedValidationError.has_value();
+}
- ParsedCompilationMessages* outMessages = &parseResult->compilationMessages;
+std::unique_ptr<ErrorData> ShaderModuleParseResult::ToErrorData() const {
+ DAWN_ASSERT(HasError());
+ return cachedValidationError->ToErrorData();
+}
- // Parse shader module to generate uncacheable part of parse result. Assuming the descriptor
- // chain has already been validated.
+void ShaderModuleParseResult::SetValidationError(std::unique_ptr<ErrorData>&& errorData) {
+ DAWN_ASSERT(errorData->GetType() == InternalErrorType::Validation);
+ cachedValidationError = CachedValidationError(std::move(errorData));
+ // If validation error occurs, clear the Tint program and metadata table.
+ tintProgram.UnsafeGetValue().reset();
+ metadataTable.reset();
+ DAWN_ASSERT(HasError());
+}
+
+ResultOrError<ShaderModuleParseResult> ParseShaderModule(ShaderModuleParseRequest req) {
+ ShaderModuleParseResult outputParseResult;
+
+ const ShaderModuleParseDeviceInfo& deviceInfo = req.deviceInfo;
+ LogEmitter* logEmitter = req.logEmitter.UnsafeGetValue();
+ bool dumpShaders = deviceInfo.toggles.Has(Toggle::DumpShaders);
+
#if TINT_BUILD_SPV_READER
// Handling SPIR-V if enabled.
- if (const auto* spirvDesc = descriptor.Get<ShaderSourceSPIRV>()) {
+ if (std::holds_alternative<ShaderModuleParseSpirvDescription>(req.shaderDescription)) {
// SpirV toggle should have been validated before chacking cache.
- DAWN_ASSERT(!device->IsToggleEnabled(Toggle::DisallowSpirv));
- // Descriptor should not contain WGSL part.
- DAWN_ASSERT(descriptor.Get<ShaderSourceWGSL>() == nullptr);
+ DAWN_ASSERT(!deviceInfo.toggles.Has(Toggle::DisallowSpirv));
- const auto* spirvOptions = descriptor.Get<DawnShaderModuleSPIRVOptionsDescriptor>();
- DAWN_ASSERT(spirvDesc != nullptr);
-
- // TODO(dawn:2033): Avoid unnecessary copies of the SPIR-V code.
- std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+ ShaderModuleParseSpirvDescription& spirvDesc =
+ std::get<ShaderModuleParseSpirvDescription>(req.shaderDescription);
+ const std::vector<uint32_t>& spirvCode = spirvDesc.spirvCode.UnsafeGetValue();
#ifdef DAWN_ENABLE_SPIRV_VALIDATION
- const bool dumpSpirv = device->IsToggleEnabled(Toggle::DumpShaders);
- DAWN_TRY(ValidateSpirv(device, spirv.data(), spirv.size(), dumpSpirv));
+ MaybeError validationResult =
+ ValidateSpirv(logEmitter, spirvCode.data(), spirvCode.size(), dumpShaders);
+ // If SpirV validation error occurs, store it into outputParseResult and return.
+ if (validationResult.IsError()) {
+ outputParseResult.SetValidationError(validationResult.AcquireError());
+ }
#endif // DAWN_ENABLE_SPIRV_VALIDATION
- tint::Program program;
- DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, device->GetWGSLAllowedFeatures(), outMessages,
- spirvOptions));
- parseResult->tintProgram = UnsafeUnserializedValue<std::optional<Ref<TintProgram>>>(
- AcquireRef(new TintProgram(std::move(program), nullptr)));
+ // Try parsing SpirV if no validation error.
+ if (!outputParseResult.HasError()) {
+ DAWN_TRY(ParseSPIRV(spirvCode, deviceInfo.wgslAllowedFeatures, &outputParseResult,
+ spirvDesc.allowNonUniformDerivatives));
+ }
}
#else // TINT_BUILD_SPV_READER
- // SPIR-V is not enabled, so the descriptor should not contain it.
- DAWN_ASSERT(descriptor.Get<ShaderSourceSPIRV>() == nullptr);
+ // SPIR-V is not enabled, so the descriptor should not contain it.
+ DAWN_ASSERT(!std::holds_alternative<ShaderModuleParseSpirvDescription>(req.shaderDescription));
#endif // TINT_BUILD_SPV_READER
// Handling WGSL.
- if (const ShaderSourceWGSL* wgslDesc = descriptor.Get<ShaderSourceWGSL>()) {
- auto tintFile = std::make_unique<tint::Source::File>("", wgslDesc->code);
+ if (std::holds_alternative<ShaderModuleParseWGSLDescription>(req.shaderDescription)) {
+ ShaderModuleParseWGSLDescription wgslDesc =
+ std::get<ShaderModuleParseWGSLDescription>(req.shaderDescription);
+ const StringView& wgsl = wgslDesc.wgsl.UnsafeGetValue();
+ const std::vector<tint::wgsl::Extension>& internalExtensions =
+ wgslDesc.internalExtensions.UnsafeGetValue();
- if (device->IsToggleEnabled(Toggle::DumpShaders)) {
+ auto tintFile = std::make_unique<tint::Source::File>("", wgsl);
+
+ if (dumpShaders) {
std::ostringstream dumpedMsg;
- dumpedMsg << "// Dumped WGSL:\n" << std::string_view(wgslDesc->code) << "\n";
- device->EmitLog(wgpu::LoggingType::Info, dumpedMsg.str().c_str());
+ dumpedMsg << "// Dumped WGSL:\n" << std::string_view(wgsl) << "\n";
+ logEmitter->EmitLog(wgpu::LoggingType::Info, dumpedMsg.str().c_str());
}
- tint::Program program;
- DAWN_TRY_ASSIGN(program, ParseWGSL(tintFile.get(), device->GetWGSLAllowedFeatures(),
- internalExtensions, outMessages));
- parseResult->tintProgram = UnsafeUnserializedValue<std::optional<Ref<TintProgram>>>(
- AcquireRef(new TintProgram(std::move(program), std::move(tintFile))));
+ DAWN_TRY(ParseWGSL(std::move(tintFile), deviceInfo.wgslAllowedFeatures, internalExtensions,
+ &outputParseResult));
}
- // Assert parsed shader are correctly generated.
- DAWN_ASSERT(parseResult->HasTintProgram());
-
- // Generate reflection information if required.
- if (needReflection) {
- parseResult->metadataTable.emplace();
- DAWN_TRY(ReflectShaderUsingTint(device,
- &parseResult->tintProgram.UnsafeGetValue().value()->program,
- &parseResult->metadataTable.value()));
+ // Generate reflection information if required and parsed succeed.
+ if (outputParseResult.HasTintProgram() && req.needReflection) {
+ ReflectShaderUsingTint(deviceInfo, &outputParseResult);
}
- return {};
+ // Assert everything succeed and we have a Tint program, xor validation error occurs and we only
+ // get error with no Tint program (not generated or get removed).
+ DAWN_ASSERT(outputParseResult.HasTintProgram() != outputParseResult.HasError());
+ return outputParseResult;
}
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
@@ -1668,11 +1735,13 @@
DAWN_UNREACHABLE();
}
- ShaderModuleParseResult regeneratedParseResult;
- ParseShaderModule(GetDevice(), Unpack(&descriptor), mInternalExtensions,
- /* needReflection */ false, ®eneratedParseResult)
- .AcquireSuccess();
- DAWN_ASSERT(regeneratedParseResult.HasTintProgram());
+ // Assuming ParseShaderModule will not throw error for regenerating.
+ ShaderModuleParseResult regeneratedParseResult =
+ ParseShaderModule(BuildShaderModuleParseRequest(GetDevice(), mHash, Unpack(&descriptor),
+ mInternalExtensions,
+ /* needReflection */ false))
+ .AcquireSuccess();
+ DAWN_ASSERT(regeneratedParseResult.HasTintProgram() && !regeneratedParseResult.HasError());
tintData->tintProgram =
std::move(regeneratedParseResult.tintProgram.UnsafeGetValue().value());
@@ -1755,6 +1824,7 @@
}
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
+ DAWN_ASSERT(!parseResult->HasError());
if (parseResult->HasTintProgram()) {
mTintData.Use([&](auto tintData) {
tintData->tintProgram = std::move(parseResult->tintProgram.UnsafeGetValue().value());
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 54aea40..fe1ad33 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -50,6 +50,7 @@
#include "dawn/native/CachedObject.h"
#include "dawn/native/CompilationMessages.h"
#include "dawn/native/Error.h"
+#include "dawn/native/ErrorData.h"
#include "dawn/native/Format.h"
#include "dawn/native/Forward.h"
#include "dawn/native/IntegerTypes.h"
@@ -69,6 +70,7 @@
namespace dawn::native {
struct EntryPointMetadata;
+class ShaderModuleParseRequest;
// Base component type of an inter-stage variable
enum class InterStageComponentType {
@@ -113,17 +115,36 @@
const std::unique_ptr<tint::Source::File> file; // Keep the tint::Source::File alive
};
+#define CACHED_VALIDATION_ERROR_MEMBER(X) \
+ X(std::string, message) \
+ X(std::vector<std::string>, contexts)
+// clang-format off
+DAWN_SERIALIZABLE(struct, CachedValidationError, CACHED_VALIDATION_ERROR_MEMBER){
+ CachedValidationError() = default;
+ explicit CachedValidationError(std::unique_ptr<ErrorData>&& errorData);
+ std::unique_ptr<ErrorData> ToErrorData() const;
+};
+// clang-format on
+#undef CACHED_VALIDATION_ERROR_MEMBER
+
// ShaderModuleParseResult is used for shader module creation and can be generated by
// ParseShaderModule or loaded from blob cache.
#define SHADER_MODULE_PARSE_RESULT_MEMBER(X) \
X(UnsafeUnserializedValue<std::optional<Ref<TintProgram>>>, tintProgram) \
/* EntryPointMetadataTable might be unnecessary in cases like Tint Program recreation. */ \
X(std::optional<EntryPointMetadataTable>, metadataTable) \
- X(ParsedCompilationMessages, compilationMessages)
+ X(ParsedCompilationMessages, compilationMessages) \
+ /* Nullopt if no validation error occurs. */ \
+ X(std::optional<CachedValidationError>, cachedValidationError)
DAWN_SERIALIZABLE(struct, ShaderModuleParseResult, SHADER_MODULE_PARSE_RESULT_MEMBER) {
// Check if ShaderModuleParseResult holds a valid tintProgram. A ShaderModuleParseResult loaded
- // from blob cache holds no tintProgram but other information.
+ // from blob cache holds no tintProgram.
bool HasTintProgram() const;
+ // Check if ShaderModuleParseResult holds validation error.
+ bool HasError() const;
+ std::unique_ptr<ErrorData> ToErrorData() const;
+
+ void SetValidationError(std::unique_ptr<ErrorData> && errorData);
};
#undef SHADER_MODULE_PARSE_RESULT_MEMBER
@@ -133,12 +154,11 @@
};
// Parse a shader module from a validated ShaderModuleDescriptor, and generate reflection
-// information if required. Errors are returned only if the shader code itself is invalid.
-MaybeError ParseShaderModule(DeviceBase* device,
- const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
- const std::vector<tint::wgsl::Extension>& internalExtensions,
- bool needReflection,
- ShaderModuleParseResult* parseResult);
+// information if required. Validation errors generated during parsing are also made cacheable and
+// returned within ShaderModuleParseResult together with compilation messages, rather than as an
+// error (i.e. ResultOrError::IsSuccess() is true in this case). Other types of errors still get
+// returned as ErrorData in ResultOrError (i.e. ResultOrError::IsError() is true).
+ResultOrError<ShaderModuleParseResult> ParseShaderModule(ShaderModuleParseRequest req);
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint,
diff --git a/src/dawn/native/ShaderModuleParseRequest.cpp b/src/dawn/native/ShaderModuleParseRequest.cpp
new file mode 100644
index 0000000..17ec312
--- /dev/null
+++ b/src/dawn/native/ShaderModuleParseRequest.cpp
@@ -0,0 +1,109 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "dawn/native/ShaderModuleParseRequest.h"
+
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "dawn/native/ChainUtils.h"
+#include "dawn/native/Device.h"
+#include "dawn/native/ShaderModule.h"
+
+namespace dawn::native {
+
+WGSLAllowedFeatures WGSLAllowedFeatures::FromTint(tint::wgsl::AllowedFeatures allowedFeatures) {
+ return {{
+ .extensions = std::move(allowedFeatures.extensions),
+ .features = std::move(allowedFeatures.features),
+ }};
+}
+
+tint::wgsl::AllowedFeatures WGSLAllowedFeatures::ToTint() const {
+ return tint::wgsl::AllowedFeatures{
+ .extensions = extensions,
+ .features = features,
+ };
+}
+
+ShaderModuleParseRequest BuildShaderModuleParseRequest(
+ DeviceBase* device,
+ ShaderModuleBase::ShaderModuleHash shaderModuleHash,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
+ bool needReflection) {
+ ShaderModuleParseRequest req;
+ req.logEmitter = UnsafeUnserializedValue<LogEmitter*>(device);
+ req.deviceInfo = {
+ {.toggles = device->GetTogglesState().GetEnabledToggles(),
+ .features = device->GetEnabledFeatures(),
+ .limits = LimitsForShaderModuleParseRequest::Create(device->GetLimits().v1),
+ .wgslAllowedFeatures = WGSLAllowedFeatures::FromTint(device->GetWGSLAllowedFeatures()),
+ .isCompatibilityMode = device->IsCompatibilityMode()}};
+ req.shaderModuleHash = shaderModuleHash;
+ req.needReflection = needReflection;
+
+// Assuming the descriptor chain has already been validated.
+#if TINT_BUILD_SPV_READER
+ // Handling SPIR-V if enabled.
+ if (const auto* spirvDesc = descriptor.Get<ShaderSourceSPIRV>()) {
+ // SpirV toggle should have been validated before chacking cache.
+ DAWN_ASSERT(!device->IsToggleEnabled(Toggle::DisallowSpirv));
+ // Descriptor should not contain WGSL part.
+ DAWN_ASSERT(descriptor.Get<ShaderSourceWGSL>() == nullptr);
+
+ const auto* spirvOptions = descriptor.Get<DawnShaderModuleSPIRVOptionsDescriptor>();
+ DAWN_ASSERT(spirvDesc != nullptr);
+
+ ShaderModuleParseSpirvDescription spirv = {
+ {// TODO(dawn:2033): Avoid unnecessary copies of the SPIR-V code.
+ .spirvCode = UnsafeUnserializedValue(
+ std::vector<uint32_t>(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize)),
+ .allowNonUniformDerivatives =
+ spirvOptions ? static_cast<bool>(spirvOptions->allowNonUniformDerivatives)
+ : false}};
+
+ req.shaderDescription = std::move(spirv);
+ return req;
+ }
+#else // TINT_BUILD_SPV_READER
+ // SPIR-V is not enabled, so the descriptor should not contain it.
+ DAWN_ASSERT(descriptor.Get<ShaderSourceSPIRV>() == nullptr);
+#endif // TINT_BUILD_SPV_READER
+
+ // Handling WGSL.
+ const ShaderSourceWGSL* wgslDesc = descriptor.Get<ShaderSourceWGSL>();
+ DAWN_ASSERT(wgslDesc != nullptr);
+
+ req.shaderDescription = ShaderModuleParseWGSLDescription{
+ {.wgsl = UnsafeUnserializedValue(wgslDesc->code),
+ .internalExtensions = UnsafeUnserializedValue(internalExtensions)}};
+
+ return req;
+}
+} // namespace dawn::native
diff --git a/src/dawn/native/ShaderModuleParseRequest.h b/src/dawn/native/ShaderModuleParseRequest.h
new file mode 100644
index 0000000..d807bfa
--- /dev/null
+++ b/src/dawn/native/ShaderModuleParseRequest.h
@@ -0,0 +1,105 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_DAWN_NATIVE_SHADERMODULEPARSEREQUEST_H_
+#define SRC_DAWN_NATIVE_SHADERMODULEPARSEREQUEST_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "dawn/native/CacheRequest.h"
+#include "dawn/native/Limits.h"
+#include "dawn/native/Serializable.h"
+#include "dawn/native/ShaderModule.h"
+
+namespace dawn::native {
+
+// Mapping tint::wgsl::AllowedFeatures
+#define WGSL_ALLOWED_FEATURES_MEMBER(X) \
+ X(std::unordered_set<tint::wgsl::Extension>, extensions) \
+ X(std::unordered_set<tint::wgsl::LanguageFeature>, features)
+// clang-format off
+DAWN_SERIALIZABLE(struct, WGSLAllowedFeatures, WGSL_ALLOWED_FEATURES_MEMBER){
+ static WGSLAllowedFeatures FromTint(tint::wgsl::AllowedFeatures allowedFeatures);
+ tint::wgsl::AllowedFeatures ToTint() const;
+};
+// clang-format on
+#undef WGSL_ALLOWED_FEATURES_MEMBER
+
+#define SHADER_MODULE_PARSE_DEVICE_INFO_MEMBER(X) \
+ /* Toggles, features, and limits */ \
+ X(TogglesSet, toggles) \
+ X(FeaturesSet, features) \
+ X(LimitsForShaderModuleParseRequest, limits) \
+ X(WGSLAllowedFeatures, wgslAllowedFeatures) \
+ X(bool, isCompatibilityMode)
+DAWN_SERIALIZABLE(struct, ShaderModuleParseDeviceInfo, SHADER_MODULE_PARSE_DEVICE_INFO_MEMBER){};
+#undef SHADER_MODULE_PARSE_DEVICE_INFO_MEMBER
+
+#define SHADER_MODULE_PARSE_SPIRV_DESCRIPTION_MEMBER(X) \
+ /* Don't need to key the spirv code since it is hashed in shaderModuleHash. */ \
+ X(UnsafeUnserializedValue<std::vector<uint32_t>>, spirvCode) \
+ X(bool, allowNonUniformDerivatives)
+DAWN_SERIALIZABLE(struct,
+ ShaderModuleParseSpirvDescription,
+ SHADER_MODULE_PARSE_SPIRV_DESCRIPTION_MEMBER){};
+#undef SHADER_MODULE_PARSE_SPIRV_DESCRIPTION_MEMBER
+
+#define SHADER_MODULE_PARSE_WGSL_DESCRIPTION_MEMBER(X) \
+ /* Don't need to key the WGSL code and internal extensions since they are */ \
+ /* hashed in shaderModuleHash. */ \
+ X(UnsafeUnserializedValue<StringView>, wgsl) \
+ X(UnsafeUnserializedValue<std::vector<tint::wgsl::Extension>>, internalExtensions)
+DAWN_SERIALIZABLE(struct,
+ ShaderModuleParseWGSLDescription,
+ SHADER_MODULE_PARSE_WGSL_DESCRIPTION_MEMBER){};
+#undef SHADER_MODULE_PARSE_WGSL_DESCRIPTION_MEMBER
+
+using ShaderModuleParseDescriptionVariant =
+ std::variant<ShaderModuleParseSpirvDescription, ShaderModuleParseWGSLDescription>;
+
+#define SHADER_MODULE_PARSE_REQUEST_MEMBER(X) \
+ X(UnsafeUnserializedValue<LogEmitter*>, logEmitter) \
+ X(ShaderModuleParseDeviceInfo, deviceInfo) \
+ X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
+ X(ShaderModuleParseDescriptionVariant, shaderDescription) \
+ X(bool, needReflection)
+
+DAWN_MAKE_CACHE_REQUEST(ShaderModuleParseRequest, SHADER_MODULE_PARSE_REQUEST_MEMBER);
+#undef SHADER_MODULE_PARSE_REQUEST_MEMBER
+
+// Helper function to create a ShaderModuleParseRequest
+ShaderModuleParseRequest BuildShaderModuleParseRequest(
+ DeviceBase* device,
+ ShaderModuleBase::ShaderModuleHash shaderModuleHash,
+ const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
+ const std::vector<tint::wgsl::Extension>& internalExtensions,
+ bool needReflection);
+
+} // namespace dawn::native
+
+#endif // SRC_DAWN_NATIVE_SHADERMODULEPARSEREQUEST_H_
diff --git a/src/dawn/native/SpirvValidation.cpp b/src/dawn/native/SpirvValidation.cpp
index f12cee6..61aac75 100644
--- a/src/dawn/native/SpirvValidation.cpp
+++ b/src/dawn/native/SpirvValidation.cpp
@@ -32,8 +32,6 @@
#include <sstream>
#include <string>
-#include "dawn/native/Device.h"
-
namespace dawn::native {
MaybeError ValidateSpirv(LogEmitter* logEmitter,
diff --git a/src/dawn/native/Texture.h b/src/dawn/native/Texture.h
index 3743655..e9d8877 100644
--- a/src/dawn/native/Texture.h
+++ b/src/dawn/native/Texture.h
@@ -49,6 +49,8 @@
namespace dawn::native {
+class MemoryDump;
+
enum class AllowMultiPlanarTextureFormat {
No,
SingleLayerOnly,
diff --git a/src/dawn/native/Toggles.cpp b/src/dawn/native/Toggles.cpp
index 21bd277..5103ae4 100644
--- a/src/dawn/native/Toggles.cpp
+++ b/src/dawn/native/Toggles.cpp
@@ -872,9 +872,13 @@
return enabledTogglesName;
}
+const TogglesSet& TogglesState::GetEnabledToggles() const {
+ return mEnabledToggles;
+}
+
// Allowing TogglesState to be used in cache key.
void StreamIn(stream::Sink* s, const TogglesState& togglesState) {
- StreamIn(s, togglesState.mEnabledToggles.bitset);
+ StreamIn(s, togglesState.GetEnabledToggles());
}
const char* ToggleEnumToName(Toggle toggle) {
diff --git a/src/dawn/native/Toggles.h b/src/dawn/native/Toggles.h
index e2ea5c6..0c7f919 100644
--- a/src/dawn/native/Toggles.h
+++ b/src/dawn/native/Toggles.h
@@ -35,6 +35,7 @@
#include "absl/container/flat_hash_map.h"
#include "dawn/common/ityp_bitset.h"
#include "dawn/native/DawnNative.h"
+#include "dawn/native/Serializable.h"
namespace dawn::native {
@@ -176,16 +177,22 @@
// A wrapper of the bitset to store if a toggle is present or not. This wrapper provides the
// convenience to convert the enums of enum class Toggle to the indices of a bitset.
-struct TogglesSet {
- ityp::bitset<uint32_t, static_cast<size_t>(Toggle::EnumCount)> bitset;
+using TogglesBitSet = ityp::bitset<uint32_t, static_cast<size_t>(Toggle::EnumCount)>;
+#define TOGGLES_SET_MEMBER(X) X(TogglesBitSet, bitset)
+DAWN_SERIALIZABLE(struct, TogglesSet, TOGGLES_SET_MEMBER) {
using Iterator = ityp::bitset<uint32_t, static_cast<size_t>(Toggle::EnumCount)>::Iterator;
void Set(Toggle toggle, bool enabled);
bool Has(Toggle toggle) const;
size_t Count() const;
- Iterator begin() const { return bitset.begin(); }
- Iterator end() const { return bitset.end(); }
+ Iterator begin() const {
+ return bitset.begin();
+ }
+ Iterator end() const {
+ return bitset.end();
+ }
};
+#undef TOGGLES_SET_MEMBER
namespace stream {
class Sink;
@@ -230,6 +237,7 @@
ToggleStage GetStage() const;
std::vector<const char*> GetEnabledToggleNames() const;
std::vector<const char*> GetDisabledToggleNames() const;
+ const TogglesSet& GetEnabledToggles() const;
// Friend definition of StreamIn which can be found by ADL to override stream::StreamIn<T>. This
// allows writing TogglesState to stream for cache key.
diff --git a/src/dawn/native/WaitListEvent.cpp b/src/dawn/native/WaitListEvent.cpp
index 35198b6..70f4420 100644
--- a/src/dawn/native/WaitListEvent.cpp
+++ b/src/dawn/native/WaitListEvent.cpp
@@ -27,6 +27,10 @@
#include "dawn/native/WaitListEvent.h"
+#include <array>
+
+#include "dawn/common/Ref.h"
+
namespace dawn::native {
WaitListEvent::WaitListEvent() = default;
diff --git a/src/dawn/native/d3d/QueueD3D.cpp b/src/dawn/native/d3d/QueueD3D.cpp
index 5f07c57..a737869 100644
--- a/src/dawn/native/d3d/QueueD3D.cpp
+++ b/src/dawn/native/d3d/QueueD3D.cpp
@@ -28,6 +28,7 @@
#include "dawn/native/d3d/QueueD3D.h"
#include <algorithm>
+#include <array>
#include <utility>
#include "dawn/native/WaitAnySystemEvent.h"
diff --git a/src/dawn/native/metal/CommandRecordingContext.h b/src/dawn/native/metal/CommandRecordingContext.h
index 17833a7..b155aa8 100644
--- a/src/dawn/native/metal/CommandRecordingContext.h
+++ b/src/dawn/native/metal/CommandRecordingContext.h
@@ -30,6 +30,7 @@
#include "dawn/common/NSRef.h"
#include "dawn/common/NonMovable.h"
#include "dawn/native/Error.h"
+#include "partition_alloc/pointers/raw_ptr.h"
#import <Metal/Metal.h>
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 78dab01..76206b8 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -674,6 +674,7 @@
"end2end/ShaderAtomicTests.cpp",
"end2end/ShaderBuiltinPartialConstArgsErrorTests.cpp",
"end2end/ShaderF16Tests.cpp",
+ "end2end/ShaderModuleCachingTests.cpp",
"end2end/ShaderTests.cpp",
"end2end/ShaderValidationTests.cpp",
"end2end/StorageTextureTests.cpp",
diff --git a/src/dawn/tests/end2end/ShaderModuleCachingTests.cpp b/src/dawn/tests/end2end/ShaderModuleCachingTests.cpp
new file mode 100644
index 0000000..f5382a0
--- /dev/null
+++ b/src/dawn/tests/end2end/ShaderModuleCachingTests.cpp
@@ -0,0 +1,166 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <memory>
+#include <string_view>
+
+#include "dawn/tests/DawnTest.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
+#include "dawn/utils/WGPUHelpers.h"
+
+namespace dawn {
+namespace {
+
+using ::testing::NiceMock;
+
+static constexpr std::string_view kComputeShaderDefault = R"(
+ @compute @workgroup_size(1) fn main() {}
+ )";
+
+static constexpr std::string_view kComputeShaderMultipleEntryPoints = R"(
+ @compute @workgroup_size(16) fn main() {}
+ @compute @workgroup_size(64) fn main2() {}
+ )";
+
+class ShaderModuleCachingTests : public DawnTest {
+ protected:
+ std::unique_ptr<platform::Platform> CreateTestPlatform() override {
+ return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
+ }
+
+ NiceMock<CachingInterfaceMock> mMockCache;
+};
+
+// Tests that shader module creation works fine even if the cache is disabled.
+// Note: This tests needs to use more than 1 device since the frontend cache on each device
+// will prevent going out to the blob cache.
+TEST_P(ShaderModuleCachingTests, ShaderModuleNoCache) {
+ mMockCache.Disable();
+
+ // First time should create and since cache is disabled, it should not write out to the
+ // cache.
+ {
+ wgpu::Device device = CreateDevice();
+ EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0),
+ utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+ }
+
+ // Second time should create fine with no cache hits since cache is disabled.
+ {
+ wgpu::Device device = CreateDevice();
+ EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0),
+ utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+ }
+}
+
+// Tests that shader module creation on the same device uses frontend cache when possible.
+TEST_P(ShaderModuleCachingTests, ShaderModuleFrontedCache) {
+ // First creation should create a cache entry.
+ wgpu::ShaderModule shaderModule;
+ EXPECT_CACHE_STATS(
+ mMockCache, Hit(0), Add(1),
+ shaderModule = utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+
+ // Second creation on the same device should just return from frontend cache and should not
+ // call out to the blob cache.
+ EXPECT_CALL(mMockCache, LoadData).Times(0);
+ wgpu::ShaderModule sameShaderModule;
+ EXPECT_CACHE_STATS(
+ mMockCache, Hit(0), Add(0),
+ sameShaderModule = utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+
+ EXPECT_EQ(shaderModule.Get() == sameShaderModule.Get(), !UsesWire());
+}
+
+// Tests that shader module creation hits the cache when it is enabled.
+// Note: This test needs to use more than 1 device since the frontend cache on each device
+// will prevent going out to the blob cache.
+TEST_P(ShaderModuleCachingTests, ShaderModuleBlobCache) {
+ // First time should create and write out to the blob cache.
+ {
+ wgpu::Device device = CreateDevice();
+ EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(1),
+ utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+ }
+
+ // Second time should create shader module using the blob cache.
+ {
+ wgpu::Device device = CreateDevice();
+ EXPECT_CACHE_STATS(mMockCache, Hit(1), Add(0),
+ utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+ }
+}
+
+// Tests that shader module creation wouldn't hit the cache if the shader modules are not exactly
+// the same.
+TEST_P(ShaderModuleCachingTests, DifferentShaderModuleBlobCache) {
+ // First time should create and write out to the cache.
+ {
+ wgpu::Device device = CreateDevice();
+ EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(1),
+ utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+ }
+
+ // Cache should not hit: different shader module.
+ {
+ wgpu::Device device = CreateDevice();
+ EXPECT_CACHE_STATS(
+ mMockCache, Hit(0), Add(1),
+ utils::CreateShaderModule(device, kComputeShaderMultipleEntryPoints.data()));
+ }
+}
+
+// Tests that shader module creation does not hits the cache when it is enabled but we use different
+// isolation keys.
+TEST_P(ShaderModuleCachingTests, ShaderModuleBlobCacheIsolationKey) {
+ // First time should create and write out to the cache.
+ {
+ wgpu::Device device = CreateDevice("isolation key 1");
+ EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(1),
+ utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+ }
+
+ // Second time should also create and write out to the cache.
+ {
+ wgpu::Device device = CreateDevice("isolation key 2");
+ EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(1),
+ utils::CreateShaderModule(device, kComputeShaderDefault.data()));
+ }
+}
+
+DAWN_INSTANTIATE_TEST(ShaderModuleCachingTests,
+ D3D11Backend(),
+ D3D12Backend(),
+ D3D12Backend({"use_dxc"}),
+ MetalBackend(),
+ OpenGLBackend(),
+ OpenGLESBackend(),
+ VulkanBackend());
+
+} // anonymous namespace
+} // namespace dawn
diff --git a/src/dawn/tests/unittests/ChainUtilsTests.cpp b/src/dawn/tests/unittests/ChainUtilsTests.cpp
index 3cca776..819c6af 100644
--- a/src/dawn/tests/unittests/ChainUtilsTests.cpp
+++ b/src/dawn/tests/unittests/ChainUtilsTests.cpp
@@ -29,6 +29,7 @@
#include <gtest/gtest.h>
#include "dawn/native/ChainUtils.h"
+#include "dawn/native/DawnNative.h"
#include "dawn/native/dawn_platform.h"
namespace dawn::native {
diff --git a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
index 748218e..de47fcb 100644
--- a/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
+++ b/src/dawn/tests/unittests/native/mocks/ShaderModuleMock.cpp
@@ -31,6 +31,7 @@
#include <utility>
#include "dawn/native/ChainUtils.h"
+#include "dawn/native/ShaderModuleParseRequest.h"
namespace dawn::native {
@@ -50,12 +51,16 @@
Ref<ShaderModuleMock> ShaderModuleMock::Create(
DeviceMock* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor) {
- ShaderModuleParseResult parseResult{};
- ParseShaderModule(device, descriptor, {}, /* needReflection*/ true, &parseResult)
- .AcquireSuccess();
Ref<ShaderModuleMock> shaderModule =
AcquireRef(new NiceMock<ShaderModuleMock>(device, descriptor));
+
+ ShaderModuleParseResult parseResult =
+ ParseShaderModule(BuildShaderModuleParseRequest(device, shaderModule->GetHash(), descriptor,
+ {},
+ /* needReflection*/ true))
+ .AcquireSuccess();
+
shaderModule->InitializeBase(&parseResult).AcquireSuccess();
return shaderModule;
}