Add ShaderModuleCompilationOptions Allows configuring strictMath in shaders Bug: b/332394417, dawn:2503 Change-Id: I5ee8f9820b3b49c3a1e0edc1238ea2d817adc15b Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/183361 Reviewed-by: Loko Kung <lokokung@google.com> Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/docs/dawn/features/shader_module_compilation_options.md b/docs/dawn/features/shader_module_compilation_options.md new file mode 100644 index 0000000..ce7092a --- /dev/null +++ b/docs/dawn/features/shader_module_compilation_options.md
@@ -0,0 +1,14 @@ +# Shader Module Compilation Options + +Shader module compilation options may be specified to override Dawn's default compilation behavior. + +`wgpu::ShaderModuleCompilationOptions` may be chained on `wgpu::ShaderModuleDescriptor`. If it is, +Dawn will use these compilation options instead of its defaults. + +### `wgpu::ShaderModuleCompilationOptions::strictMath` +Enables or disables strict math. When strict math is disabled, generally the compiler will: +- Assume no NaNs +- Assume no Inf +- Assume no signed 0 +- Use multiplication by reciprocal instead of division +- Allow algebraic transformations according to associative and distributive properties.
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json index 7d18e18..a68e7f9 100644 --- a/src/dawn/dawn.json +++ b/src/dawn/dawn.json
@@ -2262,7 +2262,8 @@ {"value": 1204, "name": "shared fence MTL shared event", "tags": ["dawn", "native"]}, {"value": 1205, "name": "shared buffer memory D3D12 resource", "tags": ["dawn", "native"]}, {"value": 1206, "name": "static samplers", "tags": ["dawn"]}, - {"value": 1207, "name": "y cb cr vulkan samplers", "tags": ["dawn"]} + {"value": 1207, "name": "y cb cr vulkan samplers", "tags": ["dawn"]}, + {"value": 1208, "name": "shader module compilation options", "tags": ["dawn"]} ] }, "filter mode": { @@ -3445,6 +3446,15 @@ {"name": "allow non uniform derivatives", "type": "bool", "default": "false"} ] }, + "shader module compilation options": { + "category": "structure", + "chained": "in", + "chain roots": ["shader module descriptor"], + "tags": ["dawn"], + "members": [ + {"name": "strict math", "type": "bool"} + ] + }, "shader stage": { "category": "bitmask", "values": [ @@ -3696,6 +3706,7 @@ {"value": 1023, "name": "dawn wire WGSL control", "tags": ["dawn"]}, {"value": 1024, "name": "dawn WGSL blocklist", "tags": ["dawn", "native"]}, {"value": 1025, "name": "drm format capabilities", "tags": ["dawn"]}, + {"value": 1026, "name": "shader module compilation options", "tags": ["dawn"]}, {"value": 1100, "name": "shared texture memory vk image descriptor", "tags": ["dawn", "native"]}, {"value": 1101, "name": "shared texture memory vk dedicated allocation descriptor", "tags": ["dawn", "native"]},
diff --git a/src/dawn/native/Features.cpp b/src/dawn/native/Features.cpp index 400e37e..13b8dc5 100644 --- a/src/dawn/native/Features.cpp +++ b/src/dawn/native/Features.cpp
@@ -340,6 +340,11 @@ "https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/dawn/features/" "y_cb_cr_vulkan_samplers.md", FeatureInfo::FeatureState::Experimental}}, + {Feature::ShaderModuleCompilationOptions, + {"Support overriding default shader module compilation options.", + "https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/dawn/features/" + "shader_module_compilation_options.md", + FeatureInfo::FeatureState::Experimental}}, }; } // anonymous namespace
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index 4e27d07..c15c52a 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp
@@ -1052,14 +1052,15 @@ // A WGSL (or SPIR-V, if enabled) subdescriptor is required, and a Dawn-specific SPIR-V options // descriptor is allowed when using SPIR-V. #if TINT_BUILD_SPV_READER - DAWN_TRY_ASSIGN( - moduleType, - (descriptor.ValidateBranches< - Branch<ShaderModuleWGSLDescriptor>, - Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor>>())); + DAWN_TRY_ASSIGN(moduleType, + (descriptor.ValidateBranches< + Branch<ShaderModuleWGSLDescriptor, ShaderModuleCompilationOptions>, + Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor, + ShaderModuleCompilationOptions>>())); #else DAWN_TRY_ASSIGN(moduleType, - (descriptor.ValidateBranches<Branch<ShaderModuleWGSLDescriptor>>())); + (descriptor.ValidateBranches< + Branch<ShaderModuleWGSLDescriptor, ShaderModuleCompilationOptions>>())); #endif DAWN_ASSERT(moduleType != wgpu::SType::Invalid); @@ -1104,6 +1105,11 @@ } DAWN_ASSERT(wgslDesc != nullptr); + DAWN_INVALID_IF(descriptor.Get<ShaderModuleCompilationOptions>() != nullptr && + !device->HasFeature(Feature::ShaderModuleCompilationOptions), + "Shader module compilation options used without %s enabled.", + wgpu::FeatureName::ShaderModuleCompilationOptions); + auto tintFile = std::make_unique<tint::Source::File>("", wgslDesc->code); if (device->IsToggleEnabled(Toggle::DumpShaders)) { @@ -1281,6 +1287,10 @@ } else { DAWN_ASSERT(false); } + + if (const auto* compileOptions = descriptor.Get<ShaderModuleCompilationOptions>()) { + mStrictMath = compileOptions->strictMath; + } } ShaderModuleBase::ShaderModuleBase(DeviceBase* device, @@ -1324,6 +1334,10 @@ return entryPoint; } +std::optional<bool> ShaderModuleBase::GetStrictMath() const { + return mStrictMath; +} + const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint) const { DAWN_ASSERT(HasEntryPoint(entryPoint)); return *mEntryPoints.at(entryPoint); @@ -1334,12 +1348,14 @@ recorder.Record(mType); recorder.Record(mOriginalSpirv); 
recorder.Record(mWgsl); + recorder.Record(mStrictMath); return recorder.GetContentHash(); } bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const { - return a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv && a->mWgsl == b->mWgsl; + return a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv && a->mWgsl == b->mWgsl && + a->mStrictMath == b->mStrictMath; } ShaderModuleBase::ScopedUseTintProgram ShaderModuleBase::UseTintProgram() {
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index 82974432..de83e17 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h
@@ -323,6 +323,8 @@ bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const; }; + std::optional<bool> GetStrictMath() const; + using ScopedUseTintProgram = APIRef<ShaderModuleBase>; ScopedUseTintProgram UseTintProgram(); @@ -353,6 +355,10 @@ std::vector<uint32_t> mOriginalSpirv; std::string mWgsl; + // TODO(dawn:2503): Remove the optional when Dawn has a consistent default across backends. + // Right now D3D uses strictness by default, and Vulkan/Metal use fast math by default. + std::optional<bool> mStrictMath; + EntryPointMetadataTable mEntryPoints; PerStage<std::string> mDefaultEntryPointNames; PerStage<size_t> mEntryPointCounts;
diff --git a/src/dawn/native/d3d11/ComputePipelineD3D11.cpp b/src/dawn/native/d3d11/ComputePipelineD3D11.cpp index 01b2df8..0490174 100644 --- a/src/dawn/native/d3d11/ComputePipelineD3D11.cpp +++ b/src/dawn/native/d3d11/ComputePipelineD3D11.cpp
@@ -63,12 +63,12 @@ // Tint does matrix multiplication expecting row major matrices compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; - if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) { + const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Compute); + if (programmableStage.module->GetStrictMath().value_or( + !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) { compileFlags |= D3DCOMPILE_IEEE_STRICTNESS; } - const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Compute); - d3d::CompiledShader compiledShader; DAWN_TRY_ASSIGN(compiledShader, ToBackend(programmableStage.module) ->Compile(programmableStage, SingleShaderStage::Compute,
diff --git a/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp b/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp index 1c33d0b..b29d4a1 100644 --- a/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp +++ b/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp
@@ -194,6 +194,7 @@ EnableFeature(Feature::AdapterPropertiesMemoryHeaps); EnableFeature(Feature::AdapterPropertiesD3D); EnableFeature(Feature::R8UnormStorage); + EnableFeature(Feature::ShaderModuleCompilationOptions); // Multi planar formats are always supported since Feature Level 11.0 // https://learn.microsoft.com/en-us/windows/win32/direct3ddxgi/format-support-for-direct3d-11-0-feature-level-hardware
diff --git a/src/dawn/native/d3d11/RenderPipelineD3D11.cpp b/src/dawn/native/d3d11/RenderPipelineD3D11.cpp index e222c66..7fb0cf4 100644 --- a/src/dawn/native/d3d11/RenderPipelineD3D11.cpp +++ b/src/dawn/native/d3d11/RenderPipelineD3D11.cpp
@@ -441,10 +441,6 @@ // Tint does matrix multiplication expecting row major matrices compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; - if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) { - compileFlags |= D3DCOMPILE_IEEE_STRICTNESS; - } - PerStage<d3d::CompiledShader> compiledShader; std::optional<dawn::native::d3d::InterStageShaderVariablesMask> usedInterstageVariables; @@ -459,11 +455,17 @@ if (GetStageMask() & wgpu::ShaderStage::Vertex) { const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Vertex); + uint32_t additionalCompileFlags = 0; + if (programmableStage.module->GetStrictMath().value_or( + !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) { + additionalCompileFlags |= D3DCOMPILE_IEEE_STRICTNESS; + } + DAWN_TRY_ASSIGN( compiledShader[SingleShaderStage::Vertex], ToBackend(programmableStage.module) ->Compile(programmableStage, SingleShaderStage::Vertex, ToBackend(GetLayout()), - compileFlags, usedInterstageVariables)); + compileFlags | additionalCompileFlags, usedInterstageVariables)); const Blob& shaderBlob = compiledShader[SingleShaderStage::Vertex].shaderBlob; DAWN_TRY(CheckHRESULT(device->GetD3D11Device()->CreateVertexShader( shaderBlob.Data(), shaderBlob.Size(), nullptr, &mVertexShader), @@ -519,11 +521,17 @@ } const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Fragment); - DAWN_TRY_ASSIGN( - compiledShader[SingleShaderStage::Fragment], - ToBackend(programmableStage.module) - ->Compile(programmableStage, SingleShaderStage::Fragment, ToBackend(GetLayout()), - compileFlags, usedInterstageVariables, pixelLocalOptions)); + uint32_t additionalCompileFlags = 0; + if (programmableStage.module->GetStrictMath().value_or( + !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) { + additionalCompileFlags |= D3DCOMPILE_IEEE_STRICTNESS; + } + + DAWN_TRY_ASSIGN(compiledShader[SingleShaderStage::Fragment], + ToBackend(programmableStage.module) + ->Compile(programmableStage, 
SingleShaderStage::Fragment, + ToBackend(GetLayout()), compileFlags | additionalCompileFlags, + usedInterstageVariables, pixelLocalOptions)); DAWN_TRY(CheckHRESULT(device->GetD3D11Device()->CreatePixelShader( compiledShader[SingleShaderStage::Fragment].shaderBlob.Data(), compiledShader[SingleShaderStage::Fragment].shaderBlob.Size(),
diff --git a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp index a8a51e2..63a0b2e 100644 --- a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp +++ b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
@@ -71,13 +71,14 @@ // Tint does matrix multiplication expecting row major matrices compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; - if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) { - compileFlags |= D3DCOMPILE_IEEE_STRICTNESS; - } - const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); ShaderModule* module = ToBackend(computeStage.module.Get()); + if (module->GetStrictMath().value_or( + !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) { + compileFlags |= D3DCOMPILE_IEEE_STRICTNESS; + } + D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {}; d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
diff --git a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp index cc9775d..8771248 100644 --- a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp +++ b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
@@ -159,6 +159,7 @@ EnableFeature(Feature::MultiPlanarRenderTargets); EnableFeature(Feature::R8UnormStorage); EnableFeature(Feature::SharedBufferMemoryD3D12Resource); + EnableFeature(Feature::ShaderModuleCompilationOptions); if (AreTimestampQueriesSupported()) { EnableFeature(Feature::TimestampQuery);
diff --git a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp index 4b84e13..7b96274 100644 --- a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp +++ b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
@@ -356,10 +356,6 @@ // Tint does matrix multiplication expecting row major matrices compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; - if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) { - compileFlags |= D3DCOMPILE_IEEE_STRICTNESS; - } - D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {}; PerStage<D3D12_SHADER_BYTECODE*> shaders; @@ -380,10 +376,16 @@ for (auto stage : IterateStages(GetStageMask())) { const ProgrammableStage& programmableStage = GetStage(stage); - DAWN_TRY_ASSIGN(compiledShader[stage], - ToBackend(programmableStage.module) - ->Compile(programmableStage, stage, ToBackend(GetLayout()), - compileFlags, usedInterstageVariables)); + uint32_t additionalCompileFlags = 0; + if (programmableStage.module->GetStrictMath().value_or( + !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) { + additionalCompileFlags |= D3DCOMPILE_IEEE_STRICTNESS; + } + DAWN_TRY_ASSIGN( + compiledShader[stage], + ToBackend(programmableStage.module) + ->Compile(programmableStage, stage, ToBackend(GetLayout()), + compileFlags | additionalCompileFlags, usedInterstageVariables)); *shaders[stage] = {compiledShader[stage].shaderBlob.Data(), compiledShader[stage].shaderBlob.Size()}; }
diff --git a/src/dawn/native/metal/BackendMTL.mm b/src/dawn/native/metal/BackendMTL.mm index 41f36ec..672d705 100644 --- a/src/dawn/native/metal/BackendMTL.mm +++ b/src/dawn/native/metal/BackendMTL.mm
@@ -618,6 +618,7 @@ EnableFeature(Feature::MSAARenderToSingleSampled); EnableFeature(Feature::DualSourceBlending); EnableFeature(Feature::R8UnormStorage); + EnableFeature(Feature::ShaderModuleCompilationOptions); // SIMD-scoped permute operations is supported by GPU family Metal3, Apple6, Apple7, Apple8, // and Mac2.
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index 1782553..f6e6890 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -425,7 +425,8 @@ (*compileOptions).preserveInvariance = true; } } - (*compileOptions).fastMathEnabled = true; + + (*compileOptions).fastMathEnabled = !GetStrictMath().value_or(false); auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice(); NSError* error = nullptr;
diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp index 3f93e88..a998e92 100644 --- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -249,6 +249,19 @@ ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc)); } +// Test that it is invalid to pass ShaderModuleCompilationOptions if the feature is not enabled. +TEST_F(ShaderModuleValidationTest, ShaderModuleCompilationOptionsNoFeature) { + wgpu::ShaderModuleDescriptor desc = {}; + wgpu::ShaderModuleWGSLDescriptor wgslDesc = {}; + wgslDesc.code = "@compute @workgroup_size(1) fn main() {}"; + + wgpu::ShaderModuleCompilationOptions compilationOptions = {}; + desc.nextInChain = &wgslDesc; + wgslDesc.nextInChain = &compilationOptions; + ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc), + testing::HasSubstr("FeatureName::ShaderModuleCompilationOptions")); +} + // Tests that shader module compilation messages can be queried. TEST_F(ShaderModuleValidationTest, GetCompilationMessages) { // This test works assuming ShaderModule is backed by a native::ShaderModuleBase, which @@ -927,5 +940,22 @@ } } +// Test it is valid to chain ShaderModuleCompilationOptions and pass true/false for strictMath. +TEST_F(ShaderModuleExtensionValidationTestUnsafeAllFeatures, ShaderModuleCompilationOptions) { + wgpu::ShaderModuleDescriptor desc = {}; + wgpu::ShaderModuleWGSLDescriptor wgslDesc = {}; + wgslDesc.code = "@compute @workgroup_size(1) fn main() {}"; + + wgpu::ShaderModuleCompilationOptions compilationOptions = {}; + desc.nextInChain = &wgslDesc; + wgslDesc.nextInChain = &compilationOptions; + + compilationOptions.strictMath = false; + device.CreateShaderModule(&desc); + + compilationOptions.strictMath = true; + device.CreateShaderModule(&desc); +} + } // anonymous namespace } // namespace dawn
diff --git a/src/dawn/wire/SupportedFeatures.cpp b/src/dawn/wire/SupportedFeatures.cpp index f1e5b51..9e41fcd 100644 --- a/src/dawn/wire/SupportedFeatures.cpp +++ b/src/dawn/wire/SupportedFeatures.cpp
@@ -103,6 +103,7 @@ case WGPUFeatureName_R8UnormStorage: case WGPUFeatureName_StaticSamplers: case WGPUFeatureName_YCbCrVulkanSamplers: + case WGPUFeatureName_ShaderModuleCompilationOptions: return true; }