[dawn] MacOS use relaxed math where possible

Fast math has potential downsides due to its treatment of inf/nan. We
can use relaxed math as a pragma in MacOSx 15+ versions.

Bug: 425650181
Change-Id: I30794e90928f8841879f528e2598aaf885af22b7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/249035
Commit-Queue: Peter McNeeley <petermcneeley@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index d5af501..1f9dc46 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -65,6 +65,7 @@
     X(uint32_t, maxSubgroupSize)                                                     \
     X(std::string, entryPointName)                                                   \
     X(bool, usesSubgroupMatrix)                                                      \
+    X(bool, useStrictMath)                                                           \
     X(bool, disableSymbolRenaming)                                                   \
     X(tint::msl::writer::Options, tintOptions)                                       \
     X(UnsafeUnserializedValue<dawn::platform::Platform*>, platform)
@@ -221,7 +222,8 @@
     ShaderModule::MetalFunctionData* out,
     uint32_t sampleMask,
     const RenderPipeline* renderPipeline,
-    const BindingInfoArray& moduleBindingInfo) {
+    const BindingInfoArray& moduleBindingInfo,
+    bool useStrictMath) {
     std::ostringstream errorStream;
     errorStream << "Tint MSL failure:\n";
 
@@ -278,6 +280,7 @@
     req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
     req.usesSubgroupMatrix = programmableStage.metadata->usesSubgroupMatrix;
     req.platform = UnsafeUnserializedValue(device->GetPlatform());
+    req.useStrictMath = useStrictMath;
 
     req.tintOptions.strip_all_names = !req.disableSymbolRenaming;
     req.tintOptions.remapped_entry_point_name = device->GetIsolatedEntryPointName();
@@ -374,17 +377,28 @@
                                     r.adapterSupportedLimits.UnsafeGetValue()));
             }
 
+            auto msl = std::move(result->msl);
+            // Metal supports math_mode as both compiler option and as a pragma. We add the
+            // math_mode here as a string conditional on OSx version as the compiler option only
+            // exists for MacOS after 15. See the Metal 4 spec for more information.
+            // Note: this math_mode takes precedence over global flags provide to the compiler
+            // (including the deprecated fastMathEnabled compiler option).
+            std::string math_mode_heading;
+            if (@available(macOS 15.0, iOS 18.0, *)) {
+                math_mode_heading = "\n#pragma METAL fp math_mode(";
+                math_mode_heading += r.useStrictMath ? "safe" : "relaxed";
+                math_mode_heading += +")\n";
+            }
             // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
             // category. -Wunused-variable in particular comes up a lot in generated code, and
             // some (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError
             // instead of a warning.
-            auto msl = std::move(result->msl);
             msl = R"(
                     #ifdef __clang__
                     #pragma clang diagnostic ignored "-Wall"
                     #endif
                 )" +
-                  msl;
+                  math_mode_heading + msl;
 
             return MslCompilation{{
                 std::move(msl),
@@ -430,7 +444,8 @@
     CacheResult<MslCompilation> mslCompilation;
     DAWN_TRY_ASSIGN(mslCompilation,
                     TranslateToMSL(GetDevice(), programmableStage, stage, layout, out, sampleMask,
-                                   renderPipeline, GetEntryPoint(entryPointName).bindings));
+                                   renderPipeline, GetEntryPoint(entryPointName).bindings,
+                                   GetStrictMath().value_or(false)));
 
     out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength;
     out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations);
@@ -446,7 +461,11 @@
         (*compileOptions).preserveInvariance = true;
     }
 
-    (*compileOptions).fastMathEnabled = !GetStrictMath().value_or(false);
+    // If possible we will use relaxed math as a pragma in the source rather than this fast math
+    // global compiler option. See crbug.com/425650181
+    if (!@available(macOS 15.0, iOS 18.0, *)) {
+        (*compileOptions).fastMathEnabled = !GetStrictMath().value_or(false);
+    }
 
     auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
     NSError* error = nullptr;
diff --git a/src/tint/lang/msl/validate/validate_metal.mm b/src/tint/lang/msl/validate/validate_metal.mm
index d548da8..b458509 100644
--- a/src/tint/lang/msl/validate/validate_metal.mm
+++ b/src/tint/lang/msl/validate/validate_metal.mm
@@ -33,7 +33,7 @@
 
 namespace tint::msl::validate {
 
-Result ValidateUsingMetal(const std::string& src, MslVersion version) {
+Result ValidateUsingMetal(const std::string& src_original, MslVersion version) {
     Result result;
 
     NSError* error = nil;
@@ -45,10 +45,18 @@
         return result;
     }
 
-    NSString* source = [NSString stringWithCString:src.c_str() encoding:NSUTF8StringEncoding];
-
+    std::string src_modified = src_original;
     MTLCompileOptions* compileOptions = [MTLCompileOptions new];
-    compileOptions.fastMathEnabled = true;
+    if (@available(macOS 15.0, iOS 18.0, *)) {
+        // Use relaxed math where possible.
+        // See crbug.com/425650181
+        std::string("\n#pragma METAL fp math_mode(relaxed)\n") + src_original;
+    } else {
+        compileOptions.fastMathEnabled = true;
+    }
+    NSString* source = [NSString stringWithCString:src_modified.c_str()
+                                          encoding:NSUTF8StringEncoding];
+
     switch (version) {
         case MslVersion::kMsl_2_3:
             compileOptions.languageVersion = MTLLanguageVersion2_3;