Add ShaderModuleCompilationOptions

Allows configuring strictMath in shaders

Bug: b/332394417, dawn:2503
Change-Id: I5ee8f9820b3b49c3a1e0edc1238ea2d817adc15b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/183361
Reviewed-by: Loko Kung <lokokung@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/docs/dawn/features/shader_module_compilation_options.md b/docs/dawn/features/shader_module_compilation_options.md
new file mode 100644
index 0000000..ce7092a
--- /dev/null
+++ b/docs/dawn/features/shader_module_compilation_options.md
@@ -0,0 +1,14 @@
+# Shader Module Compilation Options
+
+Shader module compilation options may be specified to override Dawn's default compilation behavior.
+
+`wgpu::ShaderModuleCompilationOptions` may be chained on `wgpu::ShaderModuleDescriptor`. If it is,
+Dawn will use these compilation options instead of its defaults.
+
+### `wgpu::ShaderModuleCompilationOptions::strictMath`
+Enables or disables strict math. When strict math is disabled, generally the compiler will:
+- Assume no NaNs
+- Assume no Inf
+- Assume no signed 0
+- Use multiplication by reciprocal instead of division
+- Allow algebraic transformations according to associative and distributive properties.
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index 7d18e18..a68e7f9 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -2262,7 +2262,8 @@
             {"value": 1204, "name": "shared fence MTL shared event", "tags": ["dawn", "native"]},
             {"value": 1205, "name": "shared buffer memory D3D12 resource", "tags": ["dawn", "native"]},
             {"value": 1206, "name": "static samplers", "tags": ["dawn"]},
-            {"value": 1207, "name": "y cb cr vulkan samplers", "tags": ["dawn"]}
+            {"value": 1207, "name": "y cb cr vulkan samplers", "tags": ["dawn"]},
+            {"value": 1208, "name": "shader module compilation options", "tags": ["dawn"]}
         ]
     },
     "filter mode": {
@@ -3445,6 +3446,15 @@
             {"name": "allow non uniform derivatives", "type": "bool", "default": "false"}
         ]
     },
+    "shader module compilation options": {
+        "category": "structure",
+        "chained": "in",
+        "chain roots": ["shader module descriptor"],
+        "tags": ["dawn"],
+        "members": [
+            {"name": "strict math", "type": "bool"}
+        ]
+    },
     "shader stage": {
         "category": "bitmask",
         "values": [
@@ -3696,6 +3706,7 @@
             {"value": 1023, "name": "dawn wire WGSL control", "tags": ["dawn"]},
             {"value": 1024, "name": "dawn WGSL blocklist", "tags": ["dawn", "native"]},
             {"value": 1025, "name": "drm format capabilities", "tags": ["dawn"]},
+            {"value": 1026, "name": "shader module compilation options", "tags": ["dawn"]},
 
             {"value": 1100, "name": "shared texture memory vk image descriptor", "tags": ["dawn", "native"]},
             {"value": 1101, "name": "shared texture memory vk dedicated allocation descriptor", "tags": ["dawn", "native"]},
diff --git a/src/dawn/native/Features.cpp b/src/dawn/native/Features.cpp
index 400e37e..13b8dc5 100644
--- a/src/dawn/native/Features.cpp
+++ b/src/dawn/native/Features.cpp
@@ -340,6 +340,11 @@
       "https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/dawn/features/"
       "y_cb_cr_vulkan_samplers.md",
       FeatureInfo::FeatureState::Experimental}},
+    {Feature::ShaderModuleCompilationOptions,
+     {"Support overriding default shader module compilation options.",
+      "https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/dawn/features/"
+      "shader_module_compilation_options.md",
+      FeatureInfo::FeatureState::Experimental}},
 };
 
 }  // anonymous namespace
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 4e27d07..c15c52a 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -1052,14 +1052,15 @@
     // A WGSL (or SPIR-V, if enabled) subdescriptor is required, and a Dawn-specific SPIR-V options
 // descriptor is allowed when using SPIR-V.
 #if TINT_BUILD_SPV_READER
-    DAWN_TRY_ASSIGN(
-        moduleType,
-        (descriptor.ValidateBranches<
-            Branch<ShaderModuleWGSLDescriptor>,
-            Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor>>()));
+    DAWN_TRY_ASSIGN(moduleType,
+                    (descriptor.ValidateBranches<
+                        Branch<ShaderModuleWGSLDescriptor, ShaderModuleCompilationOptions>,
+                        Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor,
+                               ShaderModuleCompilationOptions>>()));
 #else
     DAWN_TRY_ASSIGN(moduleType,
-                    (descriptor.ValidateBranches<Branch<ShaderModuleWGSLDescriptor>>()));
+                    (descriptor.ValidateBranches<
+                        Branch<ShaderModuleWGSLDescriptor, ShaderModuleCompilationOptions>>()));
 #endif
     DAWN_ASSERT(moduleType != wgpu::SType::Invalid);
 
@@ -1104,6 +1105,11 @@
     }
     DAWN_ASSERT(wgslDesc != nullptr);
 
+    DAWN_INVALID_IF(descriptor.Get<ShaderModuleCompilationOptions>() != nullptr &&
+                        !device->HasFeature(Feature::ShaderModuleCompilationOptions),
+                    "Shader module compilation options used without %s enabled.",
+                    wgpu::FeatureName::ShaderModuleCompilationOptions);
+
     auto tintFile = std::make_unique<tint::Source::File>("", wgslDesc->code);
 
     if (device->IsToggleEnabled(Toggle::DumpShaders)) {
@@ -1281,6 +1287,10 @@
     } else {
         DAWN_ASSERT(false);
     }
+
+    if (const auto* compileOptions = descriptor.Get<ShaderModuleCompilationOptions>()) {
+        mStrictMath = compileOptions->strictMath;
+    }
 }
 
 ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
@@ -1324,6 +1334,10 @@
     return entryPoint;
 }
 
+std::optional<bool> ShaderModuleBase::GetStrictMath() const {
+    return mStrictMath;
+}
+
 const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint) const {
     DAWN_ASSERT(HasEntryPoint(entryPoint));
     return *mEntryPoints.at(entryPoint);
@@ -1334,12 +1348,14 @@
     recorder.Record(mType);
     recorder.Record(mOriginalSpirv);
     recorder.Record(mWgsl);
+    recorder.Record(mStrictMath);
     return recorder.GetContentHash();
 }
 
 bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
                                                 const ShaderModuleBase* b) const {
-    return a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv && a->mWgsl == b->mWgsl;
+    return a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv && a->mWgsl == b->mWgsl &&
+           a->mStrictMath == b->mStrictMath;
 }
 
 ShaderModuleBase::ScopedUseTintProgram ShaderModuleBase::UseTintProgram() {
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 82974432..de83e17 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -323,6 +323,8 @@
         bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
     };
 
+    std::optional<bool> GetStrictMath() const;
+
     using ScopedUseTintProgram = APIRef<ShaderModuleBase>;
     ScopedUseTintProgram UseTintProgram();
 
@@ -353,6 +355,10 @@
     std::vector<uint32_t> mOriginalSpirv;
     std::string mWgsl;
 
+    // TODO(dawn:2503): Remove the optional when Dawn has a consistent default across backends.
+    // Right now D3D uses strictness by default, and Vulkan/Metal use fast math by default.
+    std::optional<bool> mStrictMath;
+
     EntryPointMetadataTable mEntryPoints;
     PerStage<std::string> mDefaultEntryPointNames;
     PerStage<size_t> mEntryPointCounts;
diff --git a/src/dawn/native/d3d11/ComputePipelineD3D11.cpp b/src/dawn/native/d3d11/ComputePipelineD3D11.cpp
index 01b2df8..0490174 100644
--- a/src/dawn/native/d3d11/ComputePipelineD3D11.cpp
+++ b/src/dawn/native/d3d11/ComputePipelineD3D11.cpp
@@ -63,12 +63,12 @@
     // Tint does matrix multiplication expecting row major matrices
     compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
-    if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) {
+    const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Compute);
+    if (programmableStage.module->GetStrictMath().value_or(
+            !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) {
         compileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
     }
 
-    const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Compute);
-
     d3d::CompiledShader compiledShader;
     DAWN_TRY_ASSIGN(compiledShader, ToBackend(programmableStage.module)
                                         ->Compile(programmableStage, SingleShaderStage::Compute,
diff --git a/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp b/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp
index 1c33d0b..b29d4a1 100644
--- a/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp
@@ -194,6 +194,7 @@
     EnableFeature(Feature::AdapterPropertiesMemoryHeaps);
     EnableFeature(Feature::AdapterPropertiesD3D);
     EnableFeature(Feature::R8UnormStorage);
+    EnableFeature(Feature::ShaderModuleCompilationOptions);
 
     // Multi planar formats are always supported since Feature Level 11.0
     // https://learn.microsoft.com/en-us/windows/win32/direct3ddxgi/format-support-for-direct3d-11-0-feature-level-hardware
diff --git a/src/dawn/native/d3d11/RenderPipelineD3D11.cpp b/src/dawn/native/d3d11/RenderPipelineD3D11.cpp
index e222c66..7fb0cf4 100644
--- a/src/dawn/native/d3d11/RenderPipelineD3D11.cpp
+++ b/src/dawn/native/d3d11/RenderPipelineD3D11.cpp
@@ -441,10 +441,6 @@
     // Tint does matrix multiplication expecting row major matrices
     compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
-    if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) {
-        compileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
-    }
-
     PerStage<d3d::CompiledShader> compiledShader;
 
     std::optional<dawn::native::d3d::InterStageShaderVariablesMask> usedInterstageVariables;
@@ -459,11 +455,17 @@
 
     if (GetStageMask() & wgpu::ShaderStage::Vertex) {
         const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Vertex);
+        uint32_t additionalCompileFlags = 0;
+        if (programmableStage.module->GetStrictMath().value_or(
+                !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) {
+            additionalCompileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
+        }
+
         DAWN_TRY_ASSIGN(
             compiledShader[SingleShaderStage::Vertex],
             ToBackend(programmableStage.module)
                 ->Compile(programmableStage, SingleShaderStage::Vertex, ToBackend(GetLayout()),
-                          compileFlags, usedInterstageVariables));
+                          compileFlags | additionalCompileFlags, usedInterstageVariables));
         const Blob& shaderBlob = compiledShader[SingleShaderStage::Vertex].shaderBlob;
         DAWN_TRY(CheckHRESULT(device->GetD3D11Device()->CreateVertexShader(
                                   shaderBlob.Data(), shaderBlob.Size(), nullptr, &mVertexShader),
@@ -519,11 +521,17 @@
         }
 
         const ProgrammableStage& programmableStage = GetStage(SingleShaderStage::Fragment);
-        DAWN_TRY_ASSIGN(
-            compiledShader[SingleShaderStage::Fragment],
-            ToBackend(programmableStage.module)
-                ->Compile(programmableStage, SingleShaderStage::Fragment, ToBackend(GetLayout()),
-                          compileFlags, usedInterstageVariables, pixelLocalOptions));
+        uint32_t additionalCompileFlags = 0;
+        if (programmableStage.module->GetStrictMath().value_or(
+                !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) {
+            additionalCompileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
+        }
+
+        DAWN_TRY_ASSIGN(compiledShader[SingleShaderStage::Fragment],
+                        ToBackend(programmableStage.module)
+                            ->Compile(programmableStage, SingleShaderStage::Fragment,
+                                      ToBackend(GetLayout()), compileFlags | additionalCompileFlags,
+                                      usedInterstageVariables, pixelLocalOptions));
         DAWN_TRY(CheckHRESULT(device->GetD3D11Device()->CreatePixelShader(
                                   compiledShader[SingleShaderStage::Fragment].shaderBlob.Data(),
                                   compiledShader[SingleShaderStage::Fragment].shaderBlob.Size(),
diff --git a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
index a8a51e2..63a0b2e 100644
--- a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
@@ -71,13 +71,14 @@
     // Tint does matrix multiplication expecting row major matrices
     compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
-    if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) {
-        compileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
-    }
-
     const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
     ShaderModule* module = ToBackend(computeStage.module.Get());
 
+    if (module->GetStrictMath().value_or(
+            !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) {
+        compileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
+    }
+
     D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
     d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
 
diff --git a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
index cc9775d..8771248 100644
--- a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
@@ -159,6 +159,7 @@
     EnableFeature(Feature::MultiPlanarRenderTargets);
     EnableFeature(Feature::R8UnormStorage);
     EnableFeature(Feature::SharedBufferMemoryD3D12Resource);
+    EnableFeature(Feature::ShaderModuleCompilationOptions);
 
     if (AreTimestampQueriesSupported()) {
         EnableFeature(Feature::TimestampQuery);
diff --git a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
index 4b84e13..7b96274 100644
--- a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
@@ -356,10 +356,6 @@
     // Tint does matrix multiplication expecting row major matrices
     compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
-    if (!device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness)) {
-        compileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
-    }
-
     D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {};
 
     PerStage<D3D12_SHADER_BYTECODE*> shaders;
@@ -380,10 +376,16 @@
 
     for (auto stage : IterateStages(GetStageMask())) {
         const ProgrammableStage& programmableStage = GetStage(stage);
-        DAWN_TRY_ASSIGN(compiledShader[stage],
-                        ToBackend(programmableStage.module)
-                            ->Compile(programmableStage, stage, ToBackend(GetLayout()),
-                                      compileFlags, usedInterstageVariables));
+        uint32_t additionalCompileFlags = 0;
+        if (programmableStage.module->GetStrictMath().value_or(
+                !device->IsToggleEnabled(Toggle::D3DDisableIEEEStrictness))) {
+            additionalCompileFlags |= D3DCOMPILE_IEEE_STRICTNESS;
+        }
+        DAWN_TRY_ASSIGN(
+            compiledShader[stage],
+            ToBackend(programmableStage.module)
+                ->Compile(programmableStage, stage, ToBackend(GetLayout()),
+                          compileFlags | additionalCompileFlags, usedInterstageVariables));
         *shaders[stage] = {compiledShader[stage].shaderBlob.Data(),
                            compiledShader[stage].shaderBlob.Size()};
     }
diff --git a/src/dawn/native/metal/BackendMTL.mm b/src/dawn/native/metal/BackendMTL.mm
index 41f36ec..672d705 100644
--- a/src/dawn/native/metal/BackendMTL.mm
+++ b/src/dawn/native/metal/BackendMTL.mm
@@ -618,6 +618,7 @@
         EnableFeature(Feature::MSAARenderToSingleSampled);
         EnableFeature(Feature::DualSourceBlending);
         EnableFeature(Feature::R8UnormStorage);
+        EnableFeature(Feature::ShaderModuleCompilationOptions);
 
         // SIMD-scoped permute operations is supported by GPU family Metal3, Apple6, Apple7, Apple8,
         // and Mac2.
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index 1782553..f6e6890 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -425,7 +425,8 @@
             (*compileOptions).preserveInvariance = true;
         }
     }
-    (*compileOptions).fastMathEnabled = true;
+
+    (*compileOptions).fastMathEnabled = !GetStrictMath().value_or(false);
 
     auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
     NSError* error = nullptr;
diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
index 3f93e88..a998e92 100644
--- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -249,6 +249,19 @@
     ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc));
 }
 
+// Test that it is invalid to pass ShaderModuleCompilationOptions if the feature is not enabled.
+TEST_F(ShaderModuleValidationTest, ShaderModuleCompilationOptionsNoFeature) {
+    wgpu::ShaderModuleDescriptor desc = {};
+    wgpu::ShaderModuleWGSLDescriptor wgslDesc = {};
+    wgslDesc.code = "@compute @workgroup_size(1) fn main() {}";
+
+    wgpu::ShaderModuleCompilationOptions compilationOptions = {};
+    desc.nextInChain = &wgslDesc;
+    wgslDesc.nextInChain = &compilationOptions;
+    ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc),
+                        testing::HasSubstr("FeatureName::ShaderModuleCompilationOptions"));
+}
+
 // Tests that shader module compilation messages can be queried.
 TEST_F(ShaderModuleValidationTest, GetCompilationMessages) {
     // This test works assuming ShaderModule is backed by a native::ShaderModuleBase, which
@@ -927,5 +940,22 @@
     }
 }
 
+// Test it is valid to chain ShaderModuleCompilationOptions and pass true/false for strictMath.
+TEST_F(ShaderModuleExtensionValidationTestUnsafeAllFeatures, ShaderModuleCompilationOptions) {
+    wgpu::ShaderModuleDescriptor desc = {};
+    wgpu::ShaderModuleWGSLDescriptor wgslDesc = {};
+    wgslDesc.code = "@compute @workgroup_size(1) fn main() {}";
+
+    wgpu::ShaderModuleCompilationOptions compilationOptions = {};
+    desc.nextInChain = &wgslDesc;
+    wgslDesc.nextInChain = &compilationOptions;
+
+    compilationOptions.strictMath = false;
+    device.CreateShaderModule(&desc);
+
+    compilationOptions.strictMath = true;
+    device.CreateShaderModule(&desc);
+}
+
 }  // anonymous namespace
 }  // namespace dawn
diff --git a/src/dawn/wire/SupportedFeatures.cpp b/src/dawn/wire/SupportedFeatures.cpp
index f1e5b51..9e41fcd 100644
--- a/src/dawn/wire/SupportedFeatures.cpp
+++ b/src/dawn/wire/SupportedFeatures.cpp
@@ -103,6 +103,7 @@
         case WGPUFeatureName_R8UnormStorage:
         case WGPUFeatureName_StaticSamplers:
         case WGPUFeatureName_YCbCrVulkanSamplers:
+        case WGPUFeatureName_ShaderModuleCompilationOptions:
             return true;
     }