Reland "Dawn: Use no Tint Program at all in shader compilation cache checking" This reverts commit 22ee238376b8b5be900c802598e6e2514547ad37. Reason for revert: The right-shifting rhs not less than bits of lhs issue that cause the roll into Skia is fixed in this CL. Bug: 402772740, 402772408 Original change's description: > Revert "Dawn: Use no Tint Program at all in shader compilation cache checking" > > This reverts commit 7fcdc9036915542a08d5b213bb717fea33a125ef. > > Reason for revert: This broke the roll into Skia due to UB caused by a right-shift of 64 in the Rotl function: > https://logs.chromium.org/logs/skia/705dbe4282a5c611/+/steps/dm/0/stdout > > Bug: 402772740, 402772408 > Original change's description: > > Dawn: Use no Tint Program at all in shader compilation cache checking > > > > This CL remove Tint Program from all backends' shader compilation cache > > key and replace it with SHA3-512 hash of shader module. With this CL, > > Tint Program might not be need until actual cache miss in backend shader > > compilation or front end WGSL parsing cache miss. > > > > Bug: 402772740, 402772408 > > Change-Id: I76ec38c03b15cd3dc4ba8c294dafb3ce9cc61dce > > Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/232375 > > Reviewed-by: Corentin Wallez <cwallez@chromium.org> > > Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com> > > Auto-Submit: Zhaoming Jiang <zhaoming.jiang@microsoft.com> > > TBR=cwallez@chromium.org,geofflang@chromium.org,dawn-scoped@luci-project-accounts.iam.gserviceaccount.com,zhaoming.jiang@microsoft.com > > No-Presubmit: true > No-Tree-Checks: true > No-Try: true > Bug: 402772740, 402772408 > Change-Id: I39f25c1a45b02eb42d10374a67a94e53dfb09300 > Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/237734 > Reviewed-by: James Price <jrprice@google.com> > Reviewed-by: dan sinclair <dsinclair@chromium.org> > Commit-Queue: James Price <jrprice@google.com> Bug: 402772740, 402772408 Change-Id: Ia98ced1fdbbcd7822b656886d9ea660bb2899c1c Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/237894 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Auto-Submit: Zhaoming Jiang <zhaoming.jiang@microsoft.com> Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
diff --git a/src/dawn/common/Sha3.cpp b/src/dawn/common/Sha3.cpp index fd15e81..f727e45 100644 --- a/src/dawn/common/Sha3.cpp +++ b/src/dawn/common/Sha3.cpp
@@ -53,6 +53,10 @@ // The rotation of a lane by `offset` bits. Sha3Lane Rotl(Sha3Lane l, size_t offset) { DAWN_ASSERT(offset < kLaneBitWidth); + // Offset should not be 0, as the expected result should be just identical and + // right-shifting (kLaneBitWidth - offset) == kLaneBitWidth bits on Sha3Lane (having + // kLaneBitWidth bits) results in undefined behavior. + DAWN_ASSERT(offset > 0); return (l << offset) | (l >> (kLaneBitWidth - offset)); } @@ -105,7 +109,9 @@ }(); void Rho(Sha3State& a) { - for (uint32_t i = 0; i < 25; i++) { + // Rotating starts from i = 1, as kRhoOffsets[0] = 0 and the lane 0 is not rotated. + static_assert(kRhoOffsets[0] == 0); + for (uint32_t i = 1; i < 25; i++) { a[i] = Rotl(a[i], kRhoOffsets[i]); } }
diff --git a/src/dawn/common/Sha3.h b/src/dawn/common/Sha3.h index f54f056..1f226dd 100644 --- a/src/dawn/common/Sha3.h +++ b/src/dawn/common/Sha3.h
@@ -31,6 +31,7 @@ #include <array> #include <cstddef> #include <cstdint> +#include <type_traits> namespace dawn { @@ -53,6 +54,14 @@ // APIs to stream data into the hash function chunk by chunk by calling Update repeatedly. // After Finalize is called, it is no longer valid to use this SHA3 object. void Update(const void* data, size_t size); + + template <typename T, typename std::enable_if_t<std::is_trivially_copyable_v<T>, bool> = true> + void Update(const T& data) { + const uint8_t* dataAsBytes = reinterpret_cast<const uint8_t*>(&data); + size_t size = sizeof(T); + Update(dataAsBytes, size); + } + Output Finalize(); // Helper function to compute the hash directly.
diff --git a/src/dawn/native/ObjectContentHasher.h b/src/dawn/native/ObjectContentHasher.h index 993fdf2..a838998 100644 --- a/src/dawn/native/ObjectContentHasher.h +++ b/src/dawn/native/ObjectContentHasher.h
@@ -75,6 +75,13 @@ } }; + template <typename T, size_t N> + struct RecordImpl<std::array<T, N>> { + static constexpr void Call(ObjectContentHasher* recorder, const std::array<T, N>& array) { + recorder->RecordIterable<std::array<T, N>>(array); + } + }; + template <typename T, typename E> struct RecordImpl<std::map<T, E>> { static constexpr void Call(ObjectContentHasher* recorder, const std::map<T, E>& map) {
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index f7c7a90..a61d41a 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp
@@ -35,6 +35,7 @@ #include "dawn/common/Constants.h" #include "dawn/common/MatchVariant.h" +#include "dawn/common/Sha3.h" #include "dawn/native/BindGroupLayoutInternal.h" #include "dawn/native/ChainUtils.h" #include "dawn/native/CompilationMessages.h" @@ -1462,7 +1463,6 @@ } // ShaderModuleBase - ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const UnpackedPtr<ShaderModuleDescriptor>& descriptor, std::vector<tint::wgsl::Extension> internalExtensions, @@ -1470,12 +1470,19 @@ : Base(device, descriptor->label), mType(Type::Undefined), mInternalExtensions(std::move(internalExtensions)) { + size_t shaderCodeByteSize = 0; + uint8_t* shaderCode = nullptr; + if (auto* spirvDesc = descriptor.Get<ShaderSourceSPIRV>()) { mType = Type::Spirv; mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize); + shaderCodeByteSize = mOriginalSpirv.size() * sizeof(decltype(mOriginalSpirv)::value_type); + shaderCode = reinterpret_cast<uint8_t*>(mOriginalSpirv.data()); } else if (auto* wgslDesc = descriptor.Get<ShaderSourceWGSL>()) { mType = Type::Wgsl; mWgsl = std::string(wgslDesc->code); + shaderCodeByteSize = mWgsl.size() * sizeof(decltype(mWgsl)::value_type); + shaderCode = reinterpret_cast<uint8_t*>(mWgsl.data()); } else { DAWN_ASSERT(false); } @@ -1483,6 +1490,26 @@ if (const auto* compileOptions = descriptor.Get<ShaderModuleCompilationOptions>()) { mStrictMath = compileOptions->strictMath; } + + ShaderModuleHasher hasher; + // Hash the metadata. + hasher.Update(mType); + // mStrictMath is a std::optional<bool>, and the bool value might not get initialized by default + // constructor and thus contains dirty data. + bool strictMathAssigned = mStrictMath.has_value(); + bool strictMathValue = mStrictMath.value_or(false); + hasher.Update(strictMathAssigned); + hasher.Update(strictMathValue); + // mInternalExtensions is a length-variable vector, so we need to hash its size and its content + // if any. + hasher.Update(mInternalExtensions.size()); + hasher.Update(mInternalExtensions.data(), + mInternalExtensions.size() * sizeof(decltype(mInternalExtensions)::value_type)); + // Hash the shader code and its size. + hasher.Update(shaderCodeByteSize); + hasher.Update(shaderCode, shaderCodeByteSize); + + mHash = hasher.Finalize(); } ShaderModuleBase::ShaderModuleBase(DeviceBase* device, @@ -1547,17 +1574,22 @@ size_t ShaderModuleBase::ComputeContentHash() { ObjectContentHasher recorder; - recorder.Record(mType); - recorder.Record(mOriginalSpirv); - recorder.Record(mWgsl); - recorder.Record(mStrictMath); + // Use mHash to represent the source content, which includes shader source and metadata. + recorder.Record(mHash); return recorder.GetContentHash(); } bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const { - return a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv && a->mWgsl == b->mWgsl && - a->mStrictMath == b->mStrictMath; + bool membersEq = a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv && + a->mWgsl == b->mWgsl && a->mStrictMath == b->mStrictMath; + // Assert that the hash is equal if and only if the members are equal. + DAWN_ASSERT(membersEq == (a->mHash == b->mHash)); + return membersEq; +} + +const ShaderModuleBase::ShaderModuleHash& ShaderModuleBase::GetHash() const { + return mHash; } ShaderModuleBase::ScopedUseTintProgram ShaderModuleBase::UseTintProgram() {
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index 6579c89..992a1e7 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h
@@ -44,6 +44,7 @@ #include "dawn/common/ContentLessObjectCacheable.h" #include "dawn/common/MutexProtected.h" #include "dawn/common/RefCountedWithExternalCount.h" +#include "dawn/common/Sha3.h" #include "dawn/common/ityp_array.h" #include "dawn/native/BindingInfo.h" #include "dawn/native/CachedObject.h" @@ -370,6 +371,10 @@ std::optional<bool> GetStrictMath() const; + using ShaderModuleHasher = Sha3_512; + using ShaderModuleHash = ShaderModuleHasher::Output; + const ShaderModuleHash& GetHash() const; + using ScopedUseTintProgram = APIRef<ShaderModuleBase>; ScopedUseTintProgram UseTintProgram(); @@ -406,6 +411,10 @@ std::vector<uint32_t> mOriginalSpirv; std::string mWgsl; + // Secure hash computed from shader code and other metadata to be used as a cache key + // representing the shader module. + ShaderModuleHash mHash; + // TODO(dawn:2503): Remove the optional when Dawn can has a consistent default across backends. // Right now D3D uses strictness by default, and Vulkan/Metal use fast math by default. std::optional<bool> mStrictMath;
diff --git a/src/dawn/native/d3d/D3DCompilationRequest.h b/src/dawn/native/d3d/D3DCompilationRequest.h index 69bbbf2..a8d04d8 100644 --- a/src/dawn/native/d3d/D3DCompilationRequest.h +++ b/src/dawn/native/d3d/D3DCompilationRequest.h
@@ -60,40 +60,38 @@ using InterStageShaderVariablesMask = std::bitset<tint::hlsl::writer::kMaxInterStageLocations>; using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>; -#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(std::string_view, entryPointName) \ - X(SingleShaderStage, stage) \ - X(uint32_t, shaderModel) \ - X(uint32_t, compileFlags) \ - X(Compiler, compiler) \ - X(uint64_t, compilerVersion) \ - X(std::wstring_view, dxcShaderProfile) \ - X(std::string_view, fxcShaderProfile) \ - X(pD3DCompile, d3dCompile) \ - X(IDxcLibrary*, dxcLibrary) \ - X(IDxcCompiler3*, dxcCompiler) \ - X(uint32_t, firstIndexOffsetShaderRegister) \ - X(uint32_t, firstIndexOffsetRegisterSpace) \ - X(tint::hlsl::writer::Options, tintOptions) \ - X(SubstituteOverrideConfig, substituteOverrideConfig) \ - X(LimitsForCompilationRequest, limits) \ - X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ - X(uint32_t, maxSubgroupSize) \ - X(bool, disableSymbolRenaming) \ - X(bool, dumpShaders) \ +#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \ + X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \ + X(std::string_view, entryPointName) \ + X(SingleShaderStage, stage) \ + X(uint32_t, shaderModel) \ + X(uint32_t, compileFlags) \ + X(Compiler, compiler) \ + X(uint64_t, compilerVersion) \ + X(std::wstring_view, dxcShaderProfile) \ + X(std::string_view, fxcShaderProfile) \ + X(uint32_t, firstIndexOffsetShaderRegister) \ + X(uint32_t, firstIndexOffsetRegisterSpace) \ + X(tint::hlsl::writer::Options, tintOptions) \ + X(SubstituteOverrideConfig, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ + X(uint32_t, maxSubgroupSize) \ + X(bool, disableSymbolRenaming) \ + X(bool, dumpShaders) \ X(bool, useTintIR) -#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \ - X(bool, hasShaderF16Feature) \ - X(uint32_t, compileFlags) \ - X(Compiler, compiler) \ - X(uint64_t, compilerVersion) \ - X(std::wstring_view, dxcShaderProfile) \ - X(std::string_view, fxcShaderProfile) \ - X(pD3DCompile, d3dCompile) \ - X(IDxcLibrary*, dxcLibrary) \ - X(IDxcCompiler3*, dxcCompiler) +#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \ + X(bool, hasShaderF16Feature) \ + X(uint32_t, compileFlags) \ + X(Compiler, compiler) \ + X(uint64_t, compilerVersion) \ + X(std::wstring_view, dxcShaderProfile) \ + X(std::string_view, fxcShaderProfile) \ + X(CacheKey::UnsafeUnkeyedValue<pD3DCompile>, d3dCompile) \ + X(CacheKey::UnsafeUnkeyedValue<IDxcLibrary*>, dxcLibrary) \ + X(CacheKey::UnsafeUnkeyedValue<IDxcCompiler3*>, dxcCompiler) DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){}; #undef HLSL_COMPILATION_REQUEST_MEMBERS
diff --git a/src/dawn/native/d3d/ShaderUtils.cpp b/src/dawn/native/d3d/ShaderUtils.cpp index db35157..1c1a4ab 100644 --- a/src/dawn/native/d3d/ShaderUtils.cpp +++ b/src/dawn/native/d3d/ShaderUtils.cpp
@@ -178,9 +178,10 @@ // pointers in this vector don't have static lifetime. std::vector<const wchar_t*> arguments = GetDXCArguments(entryPointW, r); ComPtr<IDxcResult> result; - DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(&dxcBuffer, arguments.data(), arguments.size(), - nullptr, IID_PPV_ARGS(&result)), - "DXC compile")); + DAWN_TRY(CheckHRESULT( + r.dxcCompiler.UnsafeGetValue()->Compile(&dxcBuffer, arguments.data(), arguments.size(), + nullptr, IID_PPV_ARGS(&result)), + "DXC compile")); HRESULT hr; DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status")); @@ -210,9 +211,9 @@ ComPtr<ID3DBlob> compiledShader; ComPtr<ID3DBlob> errors; - auto result = r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr, - entryPointName.c_str(), r.fxcShaderProfile.data(), r.compileFlags, 0, - &compiledShader, &errors); + auto result = r.d3dCompile.UnsafeGetValue()( + hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr, entryPointName.c_str(), + r.fxcShaderProfile.data(), r.compileFlags, 0, &compiledShader, &errors); if (FAILED(result)) { const char* resultAsString = HRESULTAsString(result); @@ -269,12 +270,15 @@ transformInputs.Add<tint::ast::transform::SubstituteOverride::Config>(cfg); } + // Requires Tint Program here right before actual using. + auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram(); + const tint::Program* tintInputProgram = &(inputProgram->program); tint::Program transformedProgram; tint::ast::transform::DataMap transformOutputs; if (!r.useTintIR) { TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "RunTransforms"); DAWN_TRY_ASSIGN(transformedProgram, - RunTransforms(&transformManager, r.inputProgram, transformInputs, + RunTransforms(&transformManager, tintInputProgram, transformInputs, &transformOutputs, nullptr)); } @@ -293,7 +297,7 @@ tint::Result<tint::hlsl::writer::Output> result; if (r.useTintIR) { // Convert the AST program to an IR module. - auto ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram); + auto ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram); DAWN_INVALID_IF(ir != tint::Success, "An error occurred while generating Tint IR\n%s", ir.Failure().reason);
diff --git a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp index 57f5a3a..20ae98c 100644 --- a/src/dawn/native/d3d11/ShaderModuleD3D11.cpp +++ b/src/dawn/native/d3d11/ShaderModuleD3D11.cpp
@@ -101,7 +101,7 @@ // D3D11 only supports FXC. req.bytecode.compiler = d3d::Compiler::FXC; - req.bytecode.d3dCompile = device->GetFunctions()->d3dCompile; + req.bytecode.d3dCompile = std::move(pD3DCompile{device->GetFunctions()->d3dCompile}); req.bytecode.compilerVersion = D3D_COMPILER_VERSION; DAWN_ASSERT(device->GetDeviceInfo().shaderModel == 50); switch (stage) { @@ -188,8 +188,8 @@ } } - auto tintProgram = GetTintProgram(); - req.hlsl.inputProgram = &(tintProgram->program); + req.hlsl.shaderModuleHash = GetHash(); + req.hlsl.inputProgram = UseTintProgram(); req.hlsl.entryPointName = programmableStage.entryPoint.c_str(); req.hlsl.stage = stage; // Put the firstIndex into the internally reserved group and binding to avoid conflicting with
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index c5ffcca..4adf5d5 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -159,7 +159,7 @@ req.bytecode.dxcShaderProfile = device->GetDxcShaderProfiles()[stage]; } else { req.bytecode.compiler = d3d::Compiler::FXC; - req.bytecode.d3dCompile = device->GetFunctions()->d3dCompile; + req.bytecode.d3dCompile = std::move(pD3DCompile{device->GetFunctions()->d3dCompile}); req.bytecode.compilerVersion = D3D_COMPILER_VERSION; switch (stage) { case SingleShaderStage::Vertex: @@ -320,8 +320,8 @@ } } - auto tintProgram = GetTintProgram(); - req.hlsl.inputProgram = &(tintProgram->program); + req.hlsl.shaderModuleHash = GetHash(); + req.hlsl.inputProgram = UseTintProgram(); req.hlsl.entryPointName = programmableStage.entryPoint.c_str(); req.hlsl.stage = stage; if (!useTintIR) {
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index d8c0361..8047db9 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -58,17 +58,18 @@ using OptionalVertexPullingTransformConfig = std::optional<tint::VertexPullingConfig>; using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>; -#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ - X(SingleShaderStage, stage) \ - X(const tint::Program*, inputProgram) \ - X(SubstituteOverrideConfig, substituteOverrideConfig) \ - X(LimitsForCompilationRequest, limits) \ - X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ - X(uint32_t, maxSubgroupSize) \ - X(std::string, entryPointName) \ - X(bool, usesSubgroupMatrix) \ - X(bool, disableSymbolRenaming) \ - X(tint::msl::writer::Options, tintOptions) \ +#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(SingleShaderStage, stage) \ + X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \ + X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \ + X(SubstituteOverrideConfig, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ + X(uint32_t, maxSubgroupSize) \ + X(std::string, entryPointName) \ + X(bool, usesSubgroupMatrix) \ + X(bool, disableSymbolRenaming) \ + X(tint::msl::writer::Options, tintOptions) \ X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform) DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS); @@ -266,8 +267,8 @@ MslCompilationRequest req = {}; req.stage = stage; - auto tintProgram = programmableStage.module->GetTintProgram(); - req.inputProgram = &(tintProgram->program); + req.shaderModuleHash = programmableStage.module->GetHash(); + req.inputProgram = programmableStage.module->UseTintProgram(); req.substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); req.entryPointName = programmableStage.entryPoint.c_str(); req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming); @@ -302,8 +303,11 @@ mslCompilation, device, std::move(req), MslCompilation::FromBlob, [](MslCompilationRequest r) -> ResultOrError<MslCompilation> { TRACE_EVENT0(r.platform.UnsafeGetValue(), General, "tint::msl::writer::Generate"); + // Requires Tint Program here right before actual using. + auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram(); + const tint::Program* tintInputProgram = &(inputProgram->program); // Convert the AST program to an IR module. - auto ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram); + auto ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram); DAWN_INVALID_IF(ir != tint::Success, "An error occurred while generating Tint IR\n%s", ir.Failure().reason);
diff --git a/src/dawn/native/opengl/ShaderModuleGL.cpp b/src/dawn/native/opengl/ShaderModuleGL.cpp index ee9adbd..5727c93 100644 --- a/src/dawn/native/opengl/ShaderModuleGL.cpp +++ b/src/dawn/native/opengl/ShaderModuleGL.cpp
@@ -92,16 +92,17 @@ using InterstageLocationAndName = std::pair<uint32_t, std::string>; using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>; -#define GLSL_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(std::string, entryPointName) \ - X(SingleShaderStage, stage) \ - X(SubstituteOverrideConfig, substituteOverrideConfig) \ - X(LimitsForCompilationRequest, limits) \ - X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ - X(bool, disableSymbolRenaming) \ - X(std::vector<InterstageLocationAndName>, interstageVariables) \ - X(tint::glsl::writer::Options, tintOptions) \ +#define GLSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \ + X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \ + X(std::string, entryPointName) \ + X(SingleShaderStage, stage) \ + X(SubstituteOverrideConfig, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ + X(bool, disableSymbolRenaming) \ + X(std::vector<InterstageLocationAndName>, interstageVariables) \ + X(tint::glsl::writer::Options, tintOptions) \ X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform) DAWN_MAKE_CACHE_REQUEST(GLSLCompilationRequest, GLSL_COMPILATION_REQUEST_MEMBERS); @@ -464,8 +465,8 @@ GLSLCompilationRequest req = {}; - auto tintProgram = GetTintProgram(); - req.inputProgram = &(tintProgram->program); + req.shaderModuleHash = GetHash(); + req.inputProgram = UseTintProgram(); // Since (non-Vulkan) GLSL does not support descriptor sets, generate a // mapping from the original group/binding pair to a binding-only @@ -562,8 +563,11 @@ DAWN_TRY_LOAD_OR_RUN( compilationResult, GetDevice(), std::move(req), GLSLCompilation::FromBlob, [](GLSLCompilationRequest r) -> ResultOrError<GLSLCompilation> { + // Requires Tint Program here right before actual using. + auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram(); + const tint::Program* tintInputProgram = &(inputProgram->program); // Convert the AST program to an IR module. - auto ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram); + auto ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram); DAWN_INVALID_IF(ir != tint::Success, "An error occurred while generating Tint IR\n%s", ir.Failure().reason);
diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h index 8dde2be..77106f9 100644 --- a/src/dawn/native/stream/Stream.h +++ b/src/dawn/native/stream/Stream.h
@@ -304,6 +304,44 @@ } }; +// Stream specialization for std::array<T, Size> of fundamental types T. +template <typename T, size_t Size> +class Stream<std::array<T, Size>, std::enable_if_t<std::is_fundamental_v<T>>> { + public: + static void Write(Sink* s, const std::array<T, Size>& t) { + static_assert(Size > 0); + memcpy(s->GetSpace(sizeof(t)), t.data(), sizeof(t)); + } + + static MaybeError Read(Source* s, std::array<T, Size>* t) { + static_assert(Size > 0); + const void* ptr; + DAWN_TRY(s->Read(&ptr, sizeof(*t))); + memcpy(t->data(), ptr, sizeof(*t)); + return {}; + } +}; + +// Stream specialization for std::array<T, Size> of non-fundamental types T. +template <typename T, size_t Size> +class Stream<std::array<T, Size>, std::enable_if_t<!std::is_fundamental_v<T>>> { + public: + static void Write(Sink* s, const std::array<T, Size>& v) { + static_assert(Size > 0); + for (const T& it : v) { + StreamIn(s, it); + } + } + + static MaybeError Read(Source* s, std::array<T, Size>* v) { + static_assert(Size > 0); + for (auto& el : *v) { + DAWN_TRY(StreamOut(s, el)); + } + return {}; + } +}; + // Stream specialization for std::pair. template <typename A, typename B> class Stream<std::pair<A, B>> {
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index 5792978..b7ca147 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -203,16 +203,17 @@ using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>; -#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \ - X(SingleShaderStage, stage) \ - X(const tint::Program*, inputProgram) \ - X(SubstituteOverrideConfig, substituteOverrideConfig) \ - X(LimitsForCompilationRequest, limits) \ - X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ - X(uint32_t, maxSubgroupSize) \ - X(std::string_view, entryPointName) \ - X(bool, usesSubgroupMatrix) \ - X(tint::spirv::writer::Options, tintOptions) \ +#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \ + X(SingleShaderStage, stage) \ + X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \ + X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \ + X(SubstituteOverrideConfig, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \ + X(uint32_t, maxSubgroupSize) \ + X(std::string_view, entryPointName) \ + X(bool, usesSubgroupMatrix) \ + X(tint::spirv::writer::Options, tintOptions) \ X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform) DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS); @@ -341,8 +342,8 @@ SpirvCompilationRequest req = {}; req.stage = stage; - auto tintProgram = GetTintProgram(); - req.inputProgram = &(tintProgram->program); + req.shaderModuleHash = GetHash(); + req.inputProgram = UseTintProgram(); req.entryPointName = programmableStage.entryPoint; req.platform = UnsafeUnkeyedValue(GetDevice()->GetPlatform()); req.substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); @@ -408,12 +409,15 @@ [](SpirvCompilationRequest r) -> ResultOrError<CompiledSpirv> { TRACE_EVENT0(r.platform.UnsafeGetValue(), General, "tint::spirv::writer::Generate()"); + // Requires Tint Program here right before actual using. + auto inputProgram = r.inputProgram.UnsafeGetValue()->GetTintProgram(); + const tint::Program* tintInputProgram = &(inputProgram->program); // Convert the AST program to an IR module. tint::Result<tint::core::ir::Module> ir; { SCOPED_DAWN_HISTOGRAM_TIMER_MICROS(r.platform.UnsafeGetValue(), "ShaderModuleProgramToIR"); - ir = tint::wgsl::reader::ProgramToLoweredIR(*r.inputProgram); + ir = tint::wgsl::reader::ProgramToLoweredIR(*tintInputProgram); DAWN_INVALID_IF(ir != tint::Success, "An error occurred while generating Tint IR\n%s", ir.Failure().reason);
diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp index c100f39..4bc3a82 100644 --- a/src/dawn/tests/unittests/native/StreamTests.cpp +++ b/src/dawn/tests/unittests/native/StreamTests.cpp
@@ -25,6 +25,7 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include <array> #include <cstring> #include <iomanip> #include <limits> @@ -441,7 +442,12 @@ // Test unordered_sets. std::vector<std::unordered_set<int>>{{}, {4, 6, 99, 0}, {100, 300, 300}}, // Test vectors. - std::vector<std::vector<int>>{{}, {1, 5, 2, 7, 4}, {3, 3, 3, 3, 3, 3, 3}}); + std::vector<std::vector<int>>{{}, {1, 5, 2, 7, 4}, {3, 3, 3, 3, 3, 3, 3}}, + // Test different size of arrays. + std::vector<std::array<int, 3>>{{1, 5, 2}, {-3, -3, -3}}, + std::vector<std::array<uint8_t, 5>>{{5, 2, 7, 9, 6}, {3, 3, 3, 3, 42}}, + // array of non-fundamental type + std::vector<std::array<std::string, 2>>{{"abcd", "efg"}, {"123hij", ""}}); static auto kStreamValueInitListParams = std::make_tuple( std::initializer_list<char[12]>{"test string", "string test"},