[spirv-reader][ir] Correctly handle GLSL 450 UMax

The SPIR-V `UMax` method allows unsigned arguments and return types.
This is not permitted in WGSL. The SPIR-V spec states that the argument
is treated as a signed value, so bitcast the argument/result as needed.

Bug: 42250952
Change-Id: Iec13d8771c9f0f4b42610cf442a8bbf1b1bb9d0b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/222754
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/builtin_fn.cc b/src/tint/lang/spirv/builtin_fn.cc
index 8acd2d0..78d0254 100644
--- a/src/tint/lang/spirv/builtin_fn.cc
+++ b/src/tint/lang/spirv/builtin_fn.cc
@@ -122,6 +122,8 @@
             return "smin";
         case BuiltinFn::kSclamp:
             return "sclamp";
+        case BuiltinFn::kUmax:
+            return "umax";
         case BuiltinFn::kSdot:
             return "sdot";
         case BuiltinFn::kUdot:
@@ -182,6 +184,7 @@
         case BuiltinFn::kSmax:
         case BuiltinFn::kSmin:
         case BuiltinFn::kSclamp:
+        case BuiltinFn::kUmax:
             break;
     }
     return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.cc.tmpl b/src/tint/lang/spirv/builtin_fn.cc.tmpl
index 03ab304..c20bf32 100644
--- a/src/tint/lang/spirv/builtin_fn.cc.tmpl
+++ b/src/tint/lang/spirv/builtin_fn.cc.tmpl
@@ -79,6 +79,7 @@
         case BuiltinFn::kSmax:
         case BuiltinFn::kSmin:
         case BuiltinFn::kSclamp:
+        case BuiltinFn::kUmax:
             break;
     }
     return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.h b/src/tint/lang/spirv/builtin_fn.h
index 07afe1d..c34ecac 100644
--- a/src/tint/lang/spirv/builtin_fn.h
+++ b/src/tint/lang/spirv/builtin_fn.h
@@ -88,6 +88,7 @@
     kSmax,
     kSmin,
     kSclamp,
+    kUmax,
     kSdot,
     kUdot,
     kNone,
diff --git a/src/tint/lang/spirv/intrinsic/data.cc b/src/tint/lang/spirv/intrinsic/data.cc
index fd411aa..ff0dd41 100644
--- a/src/tint/lang/spirv/intrinsic/data.cc
+++ b/src/tint/lang/spirv/intrinsic/data.cc
@@ -5675,12 +5675,19 @@
   },
   {
     /* [40] */
+    /* fn umax<R : iu32>[T : iu32, U : iu32](T, U) -> R */
+    /* fn umax<R : iu32>[T : iu32, U : iu32, N : num](vec<N, T>, vec<N, U>) -> vec<N, R> */
+    /* num overloads */ 2,
+    /* overloads */ OverloadIndex(156),
+  },
+  {
+    /* [41] */
     /* fn sdot(u32, u32, u32) -> i32 */
     /* num overloads */ 1,
     /* overloads */ OverloadIndex(171),
   },
   {
-    /* [41] */
+    /* [42] */
     /* fn udot(u32, u32, u32) -> u32 */
     /* num overloads */ 1,
     /* overloads */ OverloadIndex(172),
diff --git a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
index f135bc1..0d67a1e 100644
--- a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
@@ -646,23 +646,6 @@
 )");
 }
 
-TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_UMax) {
-    EXPECT_IR(Preamble() + R"(
-     %1 = OpExtInst %int %glsl UMax %int_30 %int_35
-     %2 = OpExtInst %v2int %glsl UMax %v2int_30_40 %v2int_40_30
-     OpReturn
-     OpFunctionEnd
-  )",
-              R"(
-%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
-  $B1: {
-    let x_1 = bitcast<i32>(max(bitcast<u32>(i1), bitcast<u32>(i2)));
-    let x_2 = bitcast<vec2i>(max(bitcast<vec2u>(v2i1), bitcast<vec2u>(v2i2)));
-  }
-}
-)");
-}
-
 TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_UMin) {
     EXPECT_IR(Preamble() + R"(
      %1 = OpExtInst %int %glsl UMin %int_30 %int_35
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index c0dfd05..3cd419e 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -74,13 +74,16 @@
                     Abs(builtin);
                     break;
                 case spirv::BuiltinFn::kSmax:
-                    Max(builtin);
+                    SMax(builtin);
                     break;
                 case spirv::BuiltinFn::kSmin:
-                    Min(builtin);
+                    SMin(builtin);
                     break;
                 case spirv::BuiltinFn::kSclamp:
-                    Clamp(builtin);
+                    SClamp(builtin);
+                    break;
+                case spirv::BuiltinFn::kUmax:
+                    UMax(builtin);
                     break;
                 default:
                     TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
@@ -112,7 +115,6 @@
             auto* new_call = b.Call(result_ty, func, new_args);
 
             core::ir::Value* replacement = new_call->Result(0);
-
             if (result_ty->DeepestElement() == ty.u32()) {
                 new_call->Result(0)->SetType(ty.MatchWidth(ty.i32(), result_ty));
                 replacement = b.Bitcast(result_ty, replacement)->Result(0);
@@ -126,12 +128,49 @@
         WrapSignedSpirvMethods(call, core::BuiltinFn::kSign);
     }
     void Abs(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kAbs); }
-    void Max(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMax); }
-    void Min(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMin); }
-    void Clamp(spirv::ir::BuiltinCall* call) {
+    void SMax(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMax); }
+    void SMin(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMin); }
+    void SClamp(spirv::ir::BuiltinCall* call) {
         WrapSignedSpirvMethods(call, core::BuiltinFn::kClamp);
     }
 
+    // The SPIR-V Unsigned methods all interpret their arguments as unsigned (regardless of the type
+    // of the argument). In order to satisfy this, we must bitcast any signed argument to an
+    // unsigned type before calling the WGSL equivalent method.
+    //
+    // The result of the WGSL method will match the arguments, or in this case an unsigned value. If
+    // the SPIR-V instruction expected a signed result we must bitcast the WGSL result to the
+    // correct signed type.
+    void WrapUnsignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BuiltinFn func) {
+        auto args = call->Args();
+
+        b.InsertBefore(call, [&] {
+            auto* result_ty = call->Result(0)->Type();
+            Vector<core::ir::Value*, 2> new_args;
+
+            for (auto* arg : args) {
+                if (arg->Type()->IsSignedIntegerScalarOrVector()) {
+                    arg = b.Bitcast(ty.MatchWidth(ty.u32(), result_ty), arg)->Result(0);
+                }
+                new_args.Push(arg);
+            }
+
+            auto* new_call = b.Call(result_ty, func, new_args);
+
+            core::ir::Value* replacement = new_call->Result(0);
+            if (result_ty->DeepestElement() == ty.i32()) {
+                new_call->Result(0)->SetType(ty.MatchWidth(ty.u32(), result_ty));
+                replacement = b.Bitcast(result_ty, replacement)->Result(0);
+            }
+            call->Result(0)->ReplaceAllUsesWith(replacement);
+        });
+        call->Destroy();
+    }
+
+    void UMax(spirv::ir::BuiltinCall* call) {
+        WrapUnsignedSpirvMethods(call, core::BuiltinFn::kMax);
+    }
+
     void Normalize(spirv::ir::BuiltinCall* call) {
         auto* arg = call->Args()[0];
 
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index 2af04bc..6e87fd1 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -1170,10 +1170,10 @@
     return out;
 }
 
-using SpirvParser_BuiltinsTest_TwoParam =
+using SpirvParser_BuiltinsTest_TwoParamSigned =
     core::ir::transform::TransformTestWithParam<SpirvReaderParams>;
 
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, UnsignedToUnsigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, UnsignedToUnsigned) {
     auto& params = GetParam();
 
     auto* ep = b.ComputeFunction("foo");
@@ -1222,7 +1222,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, SignedToSigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, SignedToSigned) {
     auto params = GetParam();
 
     auto* ep = b.ComputeFunction("foo");
@@ -1265,7 +1265,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, MixedToUnsigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, MixedToUnsigned) {
     auto params = GetParam();
 
     auto* ep = b.ComputeFunction("foo");
@@ -1312,7 +1312,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, MixedToSigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, MixedToSigned) {
     auto params = GetParam();
 
     auto* ep = b.ComputeFunction("foo");
@@ -1358,10 +1358,201 @@
 }
 
 INSTANTIATE_TEST_SUITE_P(SpirvReader,
-                         SpirvParser_BuiltinsTest_TwoParam,
+                         SpirvParser_BuiltinsTest_TwoParamSigned,
                          ::testing::Values(SpirvReaderParams{spirv::BuiltinFn::kSmax, "max"},
                                            SpirvReaderParams{spirv::BuiltinFn::kSmin, "min"}));
 
+using SpirvParser_BuiltinsTest_TwoParamUnsigned =
+    core::ir::transform::TransformTestWithParam<SpirvReaderParams>;
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, UnsignedToUnsigned) {
+    auto& params = GetParam();
+
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.u32(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()}, 10_u, 15_u);
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.vec2<u32>(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()},
+            b.Splat(ty.vec2<u32>(), 10_u), b.Splat(ty.vec2<u32>(), 15_u));
+        b.Return(ep);
+    });
+
+    auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = spirv.u)" +
+               params.name + R"(<u32> 10u, 15u
+    %3:vec2<u32> = spirv.u)" +
+               params.name + R"(<u32> vec2<u32>(10u), vec2<u32>(15u)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = )" + params.name +
+                  R"( 10u, 15u
+    %3:vec2<u32> = )" +
+                  params.name + R"( vec2<u32>(10u), vec2<u32>(15u)
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, SignedToSigned) {
+    auto params = GetParam();
+
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.i32(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()}, 10_i, 15_i);
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.vec2<i32>(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()},
+            b.Splat(ty.vec2<i32>(), 10_i), b.Splat(ty.vec2<i32>(), 15_i));
+        b.Return(ep);
+    });
+
+    auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = spirv.u)" +
+               params.name + R"(<i32> 10i, 15i
+    %3:vec2<i32> = spirv.u)" +
+               params.name + R"(<i32> vec2<i32>(10i), vec2<i32>(15i)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = bitcast 10i
+    %3:u32 = bitcast 15i
+    %4:u32 = )" + params.name +
+                  R"( %2, %3
+    %5:i32 = bitcast %4
+    %6:vec2<u32> = bitcast vec2<i32>(10i)
+    %7:vec2<u32> = bitcast vec2<i32>(15i)
+    %8:vec2<u32> = )" +
+                  params.name + R"( %6, %7
+    %9:vec2<i32> = bitcast %8
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, MixedToUnsigned) {
+    auto params = GetParam();
+
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.u32(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()}, 10_i, 10_u);
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.vec2<u32>(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()},
+            b.Splat(ty.vec2<i32>(), 10_i), b.Splat(ty.vec2<u32>(), 10_u));
+        b.Return(ep);
+    });
+
+    auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = spirv.u)" +
+               params.name + R"(<u32> 10i, 10u
+    %3:vec2<u32> = spirv.u)" +
+               params.name + R"(<u32> vec2<i32>(10i), vec2<u32>(10u)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = bitcast 10i
+    %3:u32 = )" + params.name +
+                  R"( %2, 10u
+    %4:vec2<u32> = bitcast vec2<i32>(10i)
+    %5:vec2<u32> = )" +
+                  params.name + R"( %4, vec2<u32>(10u)
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, MixedToSigned) {
+    auto params = GetParam();
+
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.i32(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()}, 10_u, 10_i);
+        b.CallExplicit<spirv::ir::BuiltinCall>(
+            ty.vec2<i32>(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()},
+            b.Splat(ty.vec2<u32>(), 10_u), b.Splat(ty.vec2<i32>(), 10_i));
+        b.Return(ep);
+    });
+
+    auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = spirv.u)" +
+               params.name + R"(<i32> 10u, 10i
+    %3:vec2<i32> = spirv.u)" +
+               params.name + R"(<i32> vec2<u32>(10u), vec2<i32>(10i)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = bitcast 10i
+    %3:u32 = )" + params.name +
+                  R"( 10u, %2
+    %4:i32 = bitcast %3
+    %5:vec2<u32> = bitcast vec2<i32>(10i)
+    %6:vec2<u32> = )" +
+                  params.name + R"( vec2<u32>(10u), %5
+    %7:vec2<i32> = bitcast %6
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
+INSTANTIATE_TEST_SUITE_P(SpirvReader,
+                         SpirvParser_BuiltinsTest_TwoParamUnsigned,
+                         ::testing::Values(SpirvReaderParams{spirv::BuiltinFn::kUmax, "max"}));
+
 TEST_F(SpirvParser_BuiltinsTest, SClamp_UnsignedToUnsigned) {
     auto* ep = b.ComputeFunction("foo");
 
diff --git a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
index 4a4f22d..32010c6 100644
--- a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
@@ -373,7 +373,8 @@
 INSTANTIATE_TEST_SUITE_P(SpirvParser,
                          GlslStd450TwoParamTest,
                          ::testing::Values(GlslStd450TwoParams{"SMax", "smax"},
-                                           GlslStd450TwoParams{"SMin", "smin"}));
+                                           GlslStd450TwoParams{"SMin", "smin"},
+                                           GlslStd450TwoParams{"UMax", "umax"}));
 
 TEST_F(SpirvParserTest, GlslStd450_SClamp_UnsignedToUnsigned) {
     EXPECT_IR(Preamble() + R"(
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 15de75e..9258df4 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -699,7 +699,6 @@
             case GLSLstd450NMin:
             case GLSLstd450FMin:  // FMin is less prescriptive about NaN operands
                 return core::BuiltinFn::kMin;
-            case GLSLstd450UMax:
             case GLSLstd450NMax:
             case GLSLstd450FMax:  // FMax is less prescriptive about NaN operands
                 return core::BuiltinFn::kMax;
@@ -770,6 +769,8 @@
                 return spirv::BuiltinFn::kSmin;
             case GLSLstd450SClamp:
                 return spirv::BuiltinFn::kSclamp;
+            case GLSLstd450UMax:
+                return spirv::BuiltinFn::kUmax;
             default:
                 break;
         }
@@ -780,7 +781,7 @@
                                                                 const core::type::Type* result_ty) {
         if (ext_opcode == GLSLstd450SSign || ext_opcode == GLSLstd450SAbs ||
             ext_opcode == GLSLstd450SMax || ext_opcode == GLSLstd450SMin ||
-            ext_opcode == GLSLstd450SClamp) {
+            ext_opcode == GLSLstd450SClamp || ext_opcode == GLSLstd450UMax) {
             return {result_ty->DeepestElement()};
         }
         return {};
diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def
index 5efb7b5..1202a8f 100644
--- a/src/tint/lang/spirv/spirv.def
+++ b/src/tint/lang/spirv/spirv.def
@@ -338,6 +338,9 @@
 implicit(T: iu32, U: iu32, V: iu32) fn sclamp<R: iu32>(T, U, V) -> R
 implicit(T: iu32, U: iu32, V: iu32, N: num) fn sclamp<R: iu32>(vec<N, T>, vec<N, U>,  vec<N, V>) -> vec<N, R>
 
+implicit(T: iu32, U: iu32) fn umax<R: iu32>(T, U) -> R
+implicit(T: iu32, U: iu32, N: num) fn umax<R: iu32>(vec<N, T>, vec<N, U>) -> vec<N, R>
+
 ////////////////////////////////////////////////////////////////////////////////
 // SPV_KHR_integer_dot_product instructions
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 402695a..e768da5 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1387,6 +1387,9 @@
             case spirv::BuiltinFn::kSclamp:
                 ext_inst(GLSLstd450SClamp);
                 break;
+            case spirv::BuiltinFn::kUmax:
+                ext_inst(GLSLstd450UMax);
+                break;
             case spirv::BuiltinFn::kNormalize:
                 ext_inst(GLSLstd450Normalize);
                 break;