Reland "Dawn: Use no Tint Program at all in shader compilation cache checking"
This reverts commit 22ee238376b8b5be900c802598e6e2514547ad37.
Reason for revert:
The right-shifting rhs not less than bits of lhs issue that cause the
roll into Skia is fixed in this CL.
Bug: 402772740, 402772408
Original change's description:
> Revert "Dawn: Use no Tint Program at all in shader compilation cache checking"
>
> This reverts commit 7fcdc9036915542a08d5b213bb717fea33a125ef.
>
> Reason for revert: This broke the roll into Skia due to UB caused by a right-shift of 64 in the Rotl function:
> https://logs.chromium.org/logs/skia/705dbe4282a5c611/+/steps/dm/0/stdout
>
> Bug: 402772740, 402772408
> Original change's description:
> > Dawn: Use no Tint Program at all in shader compilation cache checking
> >
> > This CL remove Tint Program from all backends' shader compilation cache
> > key and replace it with SHA3-512 hash of shader module. With this CL,
> > Tint Program might not be need until actual cache miss in backend shader
> > compilation or front end WGSL parsing cache miss.
> >
> > Bug: 402772740, 402772408
> > Change-Id: I76ec38c03b15cd3dc4ba8c294dafb3ce9cc61dce
> > Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/232375
> > Reviewed-by: Corentin Wallez <cwallez@chromium.org>
> > Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
> > Auto-Submit: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
>
> TBR=cwallez@chromium.org,geofflang@chromium.org,dawn-scoped@luci-project-accounts.iam.gserviceaccount.com,zhaoming.jiang@microsoft.com
>
> No-Presubmit: true
> No-Tree-Checks: true
> No-Try: true
> Bug: 402772740, 402772408
> Change-Id: I39f25c1a45b02eb42d10374a67a94e53dfb09300
> Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/237734
> Reviewed-by: James Price <jrprice@google.com>
> Reviewed-by: dan sinclair <dsinclair@chromium.org>
> Commit-Queue: James Price <jrprice@google.com>
Bug: 402772740, 402772408
Change-Id: Ia98ced1fdbbcd7822b656886d9ea660bb2899c1c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/237894
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Auto-Submit: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
diff --git a/src/dawn/common/Sha3.cpp b/src/dawn/common/Sha3.cpp
index fd15e81..f727e45 100644
--- a/src/dawn/common/Sha3.cpp
+++ b/src/dawn/common/Sha3.cpp
@@ -53,6 +53,10 @@
// The rotation of a lane by `offset` bits.
Sha3Lane Rotl(Sha3Lane l, size_t offset) {
DAWN_ASSERT(offset < kLaneBitWidth);
+ // Offset should not be 0, as the expected result should be just identical and
+ // right-shifting (kLaneBitWidth - offset) == kLaneBitWidth bits on Sha3Lane (having
+ // kLaneBitWidth bits) results in undefined behavior.
+ DAWN_ASSERT(offset > 0);
return (l << offset) | (l >> (kLaneBitWidth - offset));
}
@@ -105,7 +109,9 @@
}();
void Rho(Sha3State& a) {
- for (uint32_t i = 0; i < 25; i++) {
+ // Rotating starts from i = 1, as kRhoOffsets[0] = 0 and the lane 0 is not rotated.
+ static_assert(kRhoOffsets[0] == 0);
+ for (uint32_t i = 1; i < 25; i++) {
a[i] = Rotl(a[i], kRhoOffsets[i]);
}
}
diff --git a/src/dawn/common/Sha3.h b/src/dawn/common/Sha3.h
index f54f056..1f226dd 100644
--- a/src/dawn/common/Sha3.h
+++ b/src/dawn/common/Sha3.h
@@ -31,6 +31,7 @@
#include <array>
#include <cstddef>
#include <cstdint>
+#include <type_traits>
namespace dawn {
@@ -53,6 +54,14 @@
// APIs to stream data into the hash function chunk by chunk by calling Update repeatedly.
// After Finalize is called, it is no longer valid to use this SHA3 object.
void Update(const void* data, size_t size);
+
+ template <typename T, typename std::enable_if_t<std::is_trivially_copyable_v<T>, bool> = true>
+ void Update(const T& data) {
+ const uint8_t* dataAsBytes = reinterpret_cast<const uint8_t*>(&data);
+ size_t size = sizeof(T);
+ Update(dataAsBytes, size);
+ }
+
Output Finalize();
// Helper function to compute the hash directly.
diff --git a/src/dawn/native/ObjectContentHasher.h b/src/dawn/native/ObjectContentHasher.h
index 993fdf2..a838998 100644
--- a/src/dawn/native/ObjectContentHasher.h
+++ b/src/dawn/native/ObjectContentHasher.h
@@ -75,6 +75,13 @@
}
};
+ template <typename T, size_t N>
+ struct RecordImpl<std::array<T, N>> {
+ static constexpr void Call(ObjectContentHasher* recorder, const std::array<T, N>& array) {
+ recorder->RecordIterable<std::array<T, N>>(array);
+ }
+ };
+
template <typename T, typename E>
struct RecordImpl<std::map<T, E>> {
static constexpr void Call(ObjectContentHasher* recorder, const std::map<T, E>& map) {
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index f7c7a90..a61d41a 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -35,6 +35,7 @@
#include "dawn/common/Constants.h"
#include "dawn/common/MatchVariant.h"
+#include "dawn/common/Sha3.h"
#include "dawn/native/BindGroupLayoutInternal.h"
#include "dawn/native/ChainUtils.h"
#include "dawn/native/CompilationMessages.h"
@@ -1462,7 +1463,6 @@
}
// ShaderModuleBase
-
ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
std::vector<tint::wgsl::Extension> internalExtensions,
@@ -1470,12 +1470,19 @@
: Base(device, descriptor->label),
mType(Type::Undefined),
mInternalExtensions(std::move(internalExtensions)) {
+ size_t shaderCodeByteSize = 0;
+ uint8_t* shaderCode = nullptr;
+
if (auto* spirvDesc = descriptor.Get<ShaderSourceSPIRV>()) {
mType = Type::Spirv;
mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+ shaderCodeByteSize = mOriginalSpirv.size() * sizeof(decltype(mOriginalSpirv)::value_type);
+ shaderCode = reinterpret_cast<uint8_t*>(mOriginalSpirv.data());
} else if (auto* wgslDesc = descriptor.Get<ShaderSourceWGSL>()) {
mType = Type::Wgsl;
mWgsl = std::string(wgslDesc->code);
+ shaderCodeByteSize = mWgsl.size() * sizeof(decltype(mWgsl)::value_type);
+ shaderCode = reinterpret_cast<uint8_t*>(mWgsl.data());
} else {
DAWN_ASSERT(false);
}
@@ -1483,6 +1490,26 @@
if (const auto* compileOptions = descriptor.Get<ShaderModuleCompilationOptions>()) {
mStrictMath = compileOptions->strictMath;
}
+
+ ShaderModuleHasher hasher;
+ // Hash the metadata.
+ hasher.Update(mType);
+ // mStrictMath is a std::optional<bool>, and the bool value might not get initialized by default
+ // constructor and thus contains dirty data.
+ bool strictMathAssigned = mStrictMath.has_value();
+ bool strictMathValue = mStrictMath.value_or(false);
+ hasher.Update(strictMathAssigned);
+ hasher.Update(strictMathValue);
+ // mInternalExtensions is a length-variable vector, so we need to hash its size and its content
+ // if any.
+ hasher.Update(mInternalExtensions.size());
+ hasher.Update(mInternalExtensions.data(),
+ mInternalExtensions.size() * sizeof(decltype(mInternalExtensions)::value_type));
+ // Hash the shader code and its size.
+ hasher.Update(shaderCodeByteSize);
+ hasher.Update(shaderCode, shaderCodeByteSize);
+
+ mHash = hasher.Finalize();
}
ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
@@ -1547,17 +1574,22 @@
size_t ShaderModuleBase::ComputeContentHash() {
ObjectContentHasher recorder;
- recorder.Record(mType);
- recorder.Record(mOriginalSpirv);
- recorder.Record(mWgsl);
- recorder.Record(mStrictMath);
+ // Use mHash to represent the source content, which includes shader source and metadata.
+ recorder.Record(mHash);
return recorder.GetContentHash();
}
bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
const ShaderModuleBase* b) const {
- return a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv && a->mWgsl == b->mWgsl &&
- a->mStrictMath == b->mStrictMath;
+ bool membersEq = a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv &&
+ a->mWgsl == b->mWgsl && a->mStrictMath == b->mStrictMath;
+ // Assert that the hash is equal if and only if the members are equal.
+ DAWN_ASSERT(membersEq == (a->mHash == b->mHash));
+ return membersEq;
+}
+
+const ShaderModuleBase::ShaderModuleHash& ShaderModuleBase::GetHash() const {
+ return mHash;
}
ShaderModuleBase::ScopedUseTintProgram ShaderModuleBase::UseTintProgram() {
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 6579c89..992a1e7 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -44,6 +44,7 @@
#include "dawn/common/ContentLessObjectCacheable.h"
#include "dawn/common/MutexProtected.h"
#include "dawn/common/RefCountedWithExternalCount.h"
+#include "dawn/common/Sha3.h"
#include "dawn/common/ityp_array.h"
#include "dawn/native/BindingInfo.h"
#include "dawn/native/CachedObject.h"
@@ -370,6 +371,10 @@
std::optional<bool> GetStrictMath() const;
+ using ShaderModuleHasher = Sha3_512;
+ using ShaderModuleHash = ShaderModuleHasher::Output;
+ const ShaderModuleHash& GetHash() const;
+
using ScopedUseTintProgram = APIRef<ShaderModuleBase>;
ScopedUseTintProgram UseTintProgram();
@@ -406,6 +411,10 @@
std::vector<uint32_t> mOriginalSpirv;
std::string mWgsl;
+ // Secure hash computed from shader code and other metadata to be used as a cache key
+ // representing the shader module.
+ ShaderModuleHash mHash;
+
// TODO(dawn:2503): Remove the optional when Dawn can has a consistent default across backends.
// Right now D3D uses strictness by default, and Vulkan/Metal use fast math by default.
std::optional<bool> mStrictMath;
diff --git a/src/dawn/native/d3d/D3DCompilationRequest.h b/src/dawn/native/d3d/D3DCompilationRequest.h
index 69bbbf2..a8d04d8 100644
--- a/src/dawn/native/d3d/D3DCompilationRequest.h
+++ b/src/dawn/native/d3d/D3DCompilationRequest.h
@@ -60,40 +60,38 @@
using InterStageShaderVariablesMask = std::bitset<tint::hlsl::writer::kMaxInterStageLocations>;
using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>;
-#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
- X(const tint::Program*, inputProgram) \
- X(std::string_view, entryPointName) \
- X(SingleShaderStage, stage) \
- X(uint32_t, shaderModel) \
- X(uint32_t, compileFlags) \
- X(Compiler, compiler) \
- X(uint64_t, compilerVersion) \
- X(std::wstring_view, dxcShaderProfile) \
- X(std::string_view, fxcShaderProfile) \
- X(pD3DCompile, d3dCompile) \
- X(IDxcLibrary*, dxcLibrary) \
- X(IDxcCompiler3*, dxcCompiler) \
- X(uint32_t, firstIndexOffsetShaderRegister) \
- X(uint32_t, firstIndexOffsetRegisterSpace) \
- X(tint::hlsl::writer::Options, tintOptions) \
- X(SubstituteOverrideConfig, substituteOverrideConfig) \
- X(LimitsForCompilationRequest, limits) \
- X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
- X(uint32_t, maxSubgroupSize) \
- X(bool, disableSymbolRenaming) \
- X(bool, dumpShaders) \
+#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
+ X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
+ X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
+ X(std::string_view, entryPointName) \
+ X(SingleShaderStage, stage) \
+ X(uint32_t, shaderModel) \
+ X(uint32_t, compileFlags) \
+ X(Compiler, compiler) \
+ X(uint64_t, compilerVersion) \
+ X(std::wstring_view, dxcShaderProfile) \
+ X(std::string_view, fxcShaderProfile) \
+ X(uint32_t, firstIndexOffsetShaderRegister) \
+ X(uint32_t, firstIndexOffsetRegisterSpace) \
+ X(tint::hlsl::writer::Options, tintOptions) \
+ X(SubstituteOverrideConfig, substituteOverrideConfig) \
+ X(LimitsForCompilationRequest, limits) \
+ X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
+ X(uint32_t, maxSubgroupSize) \
+ X(bool, disableSymbolRenaming) \
+ X(bool, dumpShaders) \
X(bool, useTintIR)
-#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
- X(bool, hasShaderF16Feature) \
- X(uint32_t, compileFlags) \
- X(Compiler, compiler) \
- X(uint64_t, compilerVersion) \
- X(std::wstring_view, dxcShaderProfile) \
- X(std::string_view, fxcShaderProfile) \
- X(pD3DCompile, d3dCompile) \
- X(IDxcLibrary*, dxcLibrary) \
- X(IDxcCompiler3*, dxcCompiler)
+#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
+ X(bool, hasShaderF16Feature) \
+ X(uint32_t, compileFlags) \
+ X(Compiler, compiler) \
+ X(uint64_t, compilerVersion) \
+ X(std::wstring_view, dxcShaderProfile) \
+ X(std::string_view, fxcShaderProfile) \
+ X(CacheKey::UnsafeUnkeyedValue<pD3DCompile>, d3dCompile) \
+ X(CacheKey::UnsafeUnkeyedValue<IDxcLibrary*>, dxcLibrary) \
+ X(CacheKey::UnsafeUnkeyedValue<IDxcCompiler3*>, dxcCompiler)
DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){};
#undef HLSL_COMPILATION_REQUEST_MEMBERS
diff --git a/src/dawn/native/d3d/ShaderUtils.cpp b/src/dawn/native/d3d/ShaderUtils.cpp
index db35157..1c1a4ab 100644
--- a/src/dawn/native/d3d/ShaderUtils.cpp
+++ b/src/dawn/native/d3d/ShaderUtils.cpp
@@ -178,9 +178,10 @@
// pointers in this vector don't have static lifetime.
std::vector<const wchar_t*> arguments = GetDXCArguments(entryPointW, r);
ComPtr<IDxcResult> result;
- DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(&dxcBuffer, arguments.data(), arguments.size(),
- nullptr, IID_PPV_ARGS(&result)),
- "DXC compile"));
+ DAWN_TRY(CheckHRESULT(
+ r.dxcCompiler.UnsafeGetValue()->Compile(&dxcBuffer, arguments.data(), arguments.size(),
+ nullptr, IID_PPV_ARGS(&result)),
+ "DXC compile"));
HRESULT hr;
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
@@ -210,9 +211,9 @@
ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors;
- auto result = r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr,
- entryPointName.c_str(), r.fxcShaderProfile.data(), r.compileFlags, 0,
- &compiledShader, &errors);
+ auto result = r.d3dCompile.UnsafeGetValue()(
+ hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr, entryPointName.c_str(),
+ r.fxcShaderProfile.data(), r.compileFlags, 0, &compiledShader, &errors);
if (FAILED(result)) {
const char* resultAsString = HRESULTAsString(result);
@@ -269,12 +270,15 @@
transformInputs.Add<tint::ast::transform::SubstituteOverride::Config>(cfg);
}
+ // Requires Tint Program here right before actual using.
+ auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram();
+ const tint::Program* tintInputProgram = &(inputProgram->program);
tint::Program transformedProgram;
tint::ast::transform::DataMap transformOutputs;
if (!r.useTintIR) {
TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(transformedProgram,
- RunTransforms(&transformManager, r.inputProgram, transformInputs,
+ RunTransforms(&transformManager, tintInputProgram, transformInputs,
&transformOutputs, nullptr));
}
@@ -293,7 +297,7 @@
tint::Result<tint::hlsl::writer::Output> result;
if (r.useTintIR) {
// Convert the AST program to an IR module.
- auto ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram);
+ auto ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram);
DAWN_INVALID_IF(ir != tint::Success, "An error occurred while generating Tint IR\n%s",
ir.Failure().reason);
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
index 57f5a3a..20ae98c 100644
--- a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
+++ b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
@@ -101,7 +101,7 @@
// D3D11 only supports FXC.
req.bytecode.compiler = d3d::Compiler::FXC;
- req.bytecode.d3dCompile = device->GetFunctions()->d3dCompile;
+ req.bytecode.d3dCompile = std::move(pD3DCompile{device->GetFunctions()->d3dCompile});
req.bytecode.compilerVersion = D3D_COMPILER_VERSION;
DAWN_ASSERT(device->GetDeviceInfo().shaderModel == 50);
switch (stage) {
@@ -188,8 +188,8 @@
}
}
- auto tintProgram = GetTintProgram();
- req.hlsl.inputProgram = &(tintProgram->program);
+ req.hlsl.shaderModuleHash = GetHash();
+ req.hlsl.inputProgram = UseTintProgram();
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
req.hlsl.stage = stage;
// Put the firstIndex into the internally reserved group and binding to avoid conflicting with
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
index c5ffcca..4adf5d5 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -159,7 +159,7 @@
req.bytecode.dxcShaderProfile = device->GetDxcShaderProfiles()[stage];
} else {
req.bytecode.compiler = d3d::Compiler::FXC;
- req.bytecode.d3dCompile = device->GetFunctions()->d3dCompile;
+ req.bytecode.d3dCompile = std::move(pD3DCompile{device->GetFunctions()->d3dCompile});
req.bytecode.compilerVersion = D3D_COMPILER_VERSION;
switch (stage) {
case SingleShaderStage::Vertex:
@@ -320,8 +320,8 @@
}
}
- auto tintProgram = GetTintProgram();
- req.hlsl.inputProgram = &(tintProgram->program);
+ req.hlsl.shaderModuleHash = GetHash();
+ req.hlsl.inputProgram = UseTintProgram();
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
req.hlsl.stage = stage;
if (!useTintIR) {
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index d8c0361..8047db9 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -58,17 +58,18 @@
using OptionalVertexPullingTransformConfig = std::optional<tint::VertexPullingConfig>;
using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>;
-#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
- X(SingleShaderStage, stage) \
- X(const tint::Program*, inputProgram) \
- X(SubstituteOverrideConfig, substituteOverrideConfig) \
- X(LimitsForCompilationRequest, limits) \
- X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
- X(uint32_t, maxSubgroupSize) \
- X(std::string, entryPointName) \
- X(bool, usesSubgroupMatrix) \
- X(bool, disableSymbolRenaming) \
- X(tint::msl::writer::Options, tintOptions) \
+#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
+ X(SingleShaderStage, stage) \
+ X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
+ X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
+ X(SubstituteOverrideConfig, substituteOverrideConfig) \
+ X(LimitsForCompilationRequest, limits) \
+ X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
+ X(uint32_t, maxSubgroupSize) \
+ X(std::string, entryPointName) \
+ X(bool, usesSubgroupMatrix) \
+ X(bool, disableSymbolRenaming) \
+ X(tint::msl::writer::Options, tintOptions) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform)
DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
@@ -266,8 +267,8 @@
MslCompilationRequest req = {};
req.stage = stage;
- auto tintProgram = programmableStage.module->GetTintProgram();
- req.inputProgram = &(tintProgram->program);
+ req.shaderModuleHash = programmableStage.module->GetHash();
+ req.inputProgram = programmableStage.module->UseTintProgram();
req.substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
req.entryPointName = programmableStage.entryPoint.c_str();
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
@@ -302,8 +303,11 @@
mslCompilation, device, std::move(req), MslCompilation::FromBlob,
[](MslCompilationRequest r) -> ResultOrError<MslCompilation> {
TRACE_EVENT0(r.platform.UnsafeGetValue(), General, "tint::msl::writer::Generate");
+ // Requires Tint Program here right before actual using.
+ auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram();
+ const tint::Program* tintInputProgram = &(inputProgram->program);
// Convert the AST program to an IR module.
- auto ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram);
+ auto ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram);
DAWN_INVALID_IF(ir != tint::Success, "An error occurred while generating Tint IR\n%s",
ir.Failure().reason);
diff --git a/src/dawn/native/opengl/ShaderModuleGL.cpp b/src/dawn/native/opengl/ShaderModuleGL.cpp
index ee9adbd..5727c93 100644
--- a/src/dawn/native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn/native/opengl/ShaderModuleGL.cpp
@@ -92,16 +92,17 @@
using InterstageLocationAndName = std::pair<uint32_t, std::string>;
using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>;
-#define GLSL_COMPILATION_REQUEST_MEMBERS(X) \
- X(const tint::Program*, inputProgram) \
- X(std::string, entryPointName) \
- X(SingleShaderStage, stage) \
- X(SubstituteOverrideConfig, substituteOverrideConfig) \
- X(LimitsForCompilationRequest, limits) \
- X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
- X(bool, disableSymbolRenaming) \
- X(std::vector<InterstageLocationAndName>, interstageVariables) \
- X(tint::glsl::writer::Options, tintOptions) \
+#define GLSL_COMPILATION_REQUEST_MEMBERS(X) \
+ X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
+ X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
+ X(std::string, entryPointName) \
+ X(SingleShaderStage, stage) \
+ X(SubstituteOverrideConfig, substituteOverrideConfig) \
+ X(LimitsForCompilationRequest, limits) \
+ X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
+ X(bool, disableSymbolRenaming) \
+ X(std::vector<InterstageLocationAndName>, interstageVariables) \
+ X(tint::glsl::writer::Options, tintOptions) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform)
DAWN_MAKE_CACHE_REQUEST(GLSLCompilationRequest, GLSL_COMPILATION_REQUEST_MEMBERS);
@@ -464,8 +465,8 @@
GLSLCompilationRequest req = {};
- auto tintProgram = GetTintProgram();
- req.inputProgram = &(tintProgram->program);
+ req.shaderModuleHash = GetHash();
+ req.inputProgram = UseTintProgram();
// Since (non-Vulkan) GLSL does not support descriptor sets, generate a
// mapping from the original group/binding pair to a binding-only
@@ -562,8 +563,11 @@
DAWN_TRY_LOAD_OR_RUN(
compilationResult, GetDevice(), std::move(req), GLSLCompilation::FromBlob,
[](GLSLCompilationRequest r) -> ResultOrError<GLSLCompilation> {
+ // Requires Tint Program here right before actual using.
+ auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram();
+ const tint::Program* tintInputProgram = &(inputProgram->program);
// Convert the AST program to an IR module.
- auto ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram);
+ auto ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram);
DAWN_INVALID_IF(ir != tint::Success, "An error occurred while generating Tint IR\n%s",
ir.Failure().reason);
diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h
index 8dde2be..77106f9 100644
--- a/src/dawn/native/stream/Stream.h
+++ b/src/dawn/native/stream/Stream.h
@@ -304,6 +304,44 @@
}
};
+// Stream specialization for std::array<T, Size> of fundamental types T.
+template <typename T, size_t Size>
+class Stream<std::array<T, Size>, std::enable_if_t<std::is_fundamental_v<T>>> {
+ public:
+ static void Write(Sink* s, const std::array<T, Size>& t) {
+ static_assert(Size > 0);
+ memcpy(s->GetSpace(sizeof(t)), t.data(), sizeof(t));
+ }
+
+ static MaybeError Read(Source* s, std::array<T, Size>* t) {
+ static_assert(Size > 0);
+ const void* ptr;
+ DAWN_TRY(s->Read(&ptr, sizeof(*t)));
+ memcpy(t->data(), ptr, sizeof(*t));
+ return {};
+ }
+};
+
+// Stream specialization for std::array<T, Size> of non-fundamental types T.
+template <typename T, size_t Size>
+class Stream<std::array<T, Size>, std::enable_if_t<!std::is_fundamental_v<T>>> {
+ public:
+ static void Write(Sink* s, const std::array<T, Size>& v) {
+ static_assert(Size > 0);
+ for (const T& it : v) {
+ StreamIn(s, it);
+ }
+ }
+
+ static MaybeError Read(Source* s, std::array<T, Size>* v) {
+ static_assert(Size > 0);
+ for (auto& el : *v) {
+ DAWN_TRY(StreamOut(s, el));
+ }
+ return {};
+ }
+};
+
// Stream specialization for std::pair.
template <typename A, typename B>
class Stream<std::pair<A, B>> {
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 5792978..b7ca147 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -203,16 +203,17 @@
using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>;
-#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
- X(SingleShaderStage, stage) \
- X(const tint::Program*, inputProgram) \
- X(SubstituteOverrideConfig, substituteOverrideConfig) \
- X(LimitsForCompilationRequest, limits) \
- X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
- X(uint32_t, maxSubgroupSize) \
- X(std::string_view, entryPointName) \
- X(bool, usesSubgroupMatrix) \
- X(tint::spirv::writer::Options, tintOptions) \
+#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
+ X(SingleShaderStage, stage) \
+ X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
+ X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
+ X(SubstituteOverrideConfig, substituteOverrideConfig) \
+ X(LimitsForCompilationRequest, limits) \
+ X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
+ X(uint32_t, maxSubgroupSize) \
+ X(std::string_view, entryPointName) \
+ X(bool, usesSubgroupMatrix) \
+ X(tint::spirv::writer::Options, tintOptions) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform)
DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
@@ -341,8 +342,8 @@
SpirvCompilationRequest req = {};
req.stage = stage;
- auto tintProgram = GetTintProgram();
- req.inputProgram = &(tintProgram->program);
+ req.shaderModuleHash = GetHash();
+ req.inputProgram = UseTintProgram();
req.entryPointName = programmableStage.entryPoint;
req.platform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
req.substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
@@ -408,12 +409,15 @@
[](SpirvCompilationRequest r) -> ResultOrError<CompiledSpirv> {
TRACE_EVENT0(r.platform.UnsafeGetValue(), General, "tint::spirv::writer::Generate()");
+ // Requires Tint Program here right before actual using.
+ auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram();
+ const tint::Program* tintInputProgram = &(inputProgram->program);
// Convert the AST program to an IR module.
tint::Result<tint::core::ir::Module> ir;
{
SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(r.platform.UnsafeGetValue(),
"ShaderModuleProgramToIR");
- ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram);
+ ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram);
DAWN_INVALID_IF(ir != tint::Success,
"An error occurred while generating Tint IR\n%s",
ir.Failure().reason);
diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp
index c100f39..4bc3a82 100644
--- a/src/dawn/tests/unittests/native/StreamTests.cpp
+++ b/src/dawn/tests/unittests/native/StreamTests.cpp
@@ -25,6 +25,7 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#include <array>
#include <cstring>
#include <iomanip>
#include <limits>
@@ -441,7 +442,12 @@
// Test unordered_sets.
std::vector<std::unordered_set<int>>{{}, {4, 6, 99, 0}, {100, 300, 300}},
// Test vectors.
- std::vector<std::vector<int>>{{}, {1, 5, 2, 7, 4}, {3, 3, 3, 3, 3, 3, 3}});
+ std::vector<std::vector<int>>{{}, {1, 5, 2, 7, 4}, {3, 3, 3, 3, 3, 3, 3}},
+ // Test different size of arrays.
+ std::vector<std::array<int, 3>>{{1, 5, 2}, {-3, -3, -3}},
+ std::vector<std::array<uint8_t, 5>>{{5, 2, 7, 9, 6}, {3, 3, 3, 3, 42}},
+ // array of non-fundamental type
+ std::vector<std::array<std::string, 2>>{{"abcd", "efg"}, {"123hij", ""}});
static auto kStreamValueInitListParams = std::make_tuple(
std::initializer_list<char[12]>{"test string", "string test"},