[spirv-reader][ir] Correctly handle GLSL 450 SAbs

The SPIR-V `SAbs` treats all arguments as if they were an `i32`. The
argument type does not need to match the return type. This differs from
WGSL where the types must match.

Bug: 42250952
Change-Id: Ied37500361cdff635997dcb4313d90cc0614060c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/221735
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/builtin_fn.cc b/src/tint/lang/spirv/builtin_fn.cc
index 6016b84..462d047 100644
--- a/src/tint/lang/spirv/builtin_fn.cc
+++ b/src/tint/lang/spirv/builtin_fn.cc
@@ -114,6 +114,8 @@
             return "inverse";
         case BuiltinFn::kSign:
             return "sign";
+        case BuiltinFn::kAbs:
+            return "abs";
         case BuiltinFn::kSdot:
             return "sdot";
         case BuiltinFn::kUdot:
@@ -170,6 +172,7 @@
         case BuiltinFn::kNormalize:
         case BuiltinFn::kInverse:
         case BuiltinFn::kSign:
+        case BuiltinFn::kAbs:
             break;
     }
     return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.cc.tmpl b/src/tint/lang/spirv/builtin_fn.cc.tmpl
index 7ffcae8..d4dead77 100644
--- a/src/tint/lang/spirv/builtin_fn.cc.tmpl
+++ b/src/tint/lang/spirv/builtin_fn.cc.tmpl
@@ -75,6 +75,7 @@
         case BuiltinFn::kNormalize:
         case BuiltinFn::kInverse:
         case BuiltinFn::kSign:
+        case BuiltinFn::kAbs:
             break;
     }
     return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.h b/src/tint/lang/spirv/builtin_fn.h
index 55084ef..7b5b2c6 100644
--- a/src/tint/lang/spirv/builtin_fn.h
+++ b/src/tint/lang/spirv/builtin_fn.h
@@ -84,6 +84,7 @@
     kNormalize,
     kInverse,
     kSign,
+    kAbs,
     kSdot,
     kUdot,
     kNone,
diff --git a/src/tint/lang/spirv/intrinsic/data.cc b/src/tint/lang/spirv/intrinsic/data.cc
index 30b644d..deee63b 100644
--- a/src/tint/lang/spirv/intrinsic/data.cc
+++ b/src/tint/lang/spirv/intrinsic/data.cc
@@ -5503,12 +5503,19 @@
   },
   {
     /* [36] */
+    /* fn abs<R : iu32>[T : iu32](T) -> R */
+    /* fn abs<R : iu32>[T : iu32, N : num](vec<N, T>) -> vec<N, R> */
+    /* num overloads */ 2,
+    /* overloads */ OverloadIndex(154),
+  },
+  {
+    /* [37] */
     /* fn sdot(u32, u32, u32) -> i32 */
     /* num overloads */ 1,
     /* overloads */ OverloadIndex(167),
   },
   {
-    /* [37] */
+    /* [38] */
     /* fn udot(u32, u32, u32) -> u32 */
     /* num overloads */ 1,
     /* overloads */ OverloadIndex(168),
diff --git a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
index f2cc8f3..c69e6b3 100644
--- a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
@@ -696,8 +696,7 @@
 
 INSTANTIATE_TEST_SUITE_P(SpirvReader,
                          SpirvReaderTest_GlslStd450_Inting_Inting,
-                         ::testing::Values(GlslStd450Case{"SAbs", "abs"},
-                                           GlslStd450Case{"FindILsb", "firstTrailingBit"},
+                         ::testing::Values(GlslStd450Case{"FindILsb", "firstTrailingBit"},
                                            GlslStd450Case{"FindSMsb", "firstLeadingBit"},
                                            GlslStd450Case{"SSign", "sign"}));
 
@@ -932,26 +931,6 @@
 )");
 }
 
-// Check that we convert signedness of operands and result type.
-// This is needed for each of the integer-based extended instructions.
-
-TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_SAbs) {
-    EXPECT_IR(Preamble() + R"(
-     %1 = OpExtInst %uint %glsl SAbs %uint_10
-     %2 = OpExtInst %v2uint %glsl SAbs %v2uint_10_20
-     OpReturn
-     OpFunctionEnd
-  )",
-              R"(
-%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
-  $B1: {
-    let x_1 = bitcast<u32>(abs(bitcast<i32>(u1)));
-    let x_2 = bitcast<vec2u>(abs(bitcast<vec2i>(v2u1)));
-  }
-}
-)");
-}
-
 TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_SMax) {
     EXPECT_IR(Preamble() + R"(
      %1 = OpExtInst %uint %glsl SMax %uint_10 %uint_15
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index 4746914..1680866 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -70,6 +70,9 @@
                 case spirv::BuiltinFn::kSign:
                     Sign(builtin);
                     break;
+                case spirv::BuiltinFn::kAbs:
+                    Abs(builtin);
+                    break;
                 default:
                     TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
             }
@@ -101,6 +104,28 @@
         call->Destroy();
     }
 
+    void Abs(spirv::ir::BuiltinCall* call) {
+        auto* arg = call->Args()[0];
+
+        b.InsertBefore(call, [&] {
+            auto* result_ty = call->Result(0)->Type();
+            if (arg->Type()->IsUnsignedIntegerScalarOrVector()) {
+                arg = b.Bitcast(ty.MatchWidth(ty.i32(), result_ty), arg)->Result(0);
+            }
+            auto* new_call =
+                b.Call(result_ty, core::BuiltinFn::kAbs, Vector<core::ir::Value*, 1>{arg});
+
+            core::ir::Value* replacement = new_call->Result(0);
+            // If the call is a `u32` result type, we need to cast it to `i32`.
+            if (result_ty->DeepestElement() == ty.u32()) {
+                new_call->Result(0)->SetType(ty.MatchWidth(ty.i32(), result_ty));
+                replacement = b.Bitcast(result_ty, replacement)->Result(0);
+            }
+            call->Result(0)->ReplaceAllUsesWith(replacement);
+        });
+        call->Destroy();
+    }
+
     void Normalize(spirv::ir::BuiltinCall* call) {
         auto* arg = call->Args()[0];
 
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index d622215..c30e589 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -1003,5 +1003,161 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(SpirvParser_BuiltinsTest, SAbs_UnsignedToUnsigned) {
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.u32()}, 10_u);
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<u32>(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.u32()},
+                                               b.Splat(ty.vec2<u32>(), 10_u));
+        b.Return(ep);
+    });
+
+    auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = spirv.abs<u32> 10u
+    %3:vec2<u32> = spirv.abs<u32> vec2<u32>(10u)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = bitcast 10u
+    %3:i32 = abs %2
+    %4:u32 = bitcast %3
+    %5:vec2<i32> = bitcast vec2<u32>(10u)
+    %6:vec2<i32> = abs %5
+    %7:vec2<u32> = bitcast %6
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, SAbs_UnsignedToSigned) {
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.i32(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.i32()}, 10_u);
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<i32>(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.i32()},
+                                               b.Splat(ty.vec2<u32>(), 10_u));
+        b.Return(ep);
+    });
+
+    auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = spirv.abs<i32> 10u
+    %3:vec2<i32> = spirv.abs<i32> vec2<u32>(10u)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = bitcast 10u
+    %3:i32 = abs %2
+    %4:vec2<i32> = bitcast vec2<u32>(10u)
+    %5:vec2<i32> = abs %4
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, SAbs_SignedToSigned) {
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.i32(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.i32()}, 10_i);
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<i32>(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.i32()},
+                                               b.Splat(ty.vec2<i32>(), 10_i));
+        b.Return(ep);
+    });
+
+    auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = spirv.abs<i32> 10i
+    %3:vec2<i32> = spirv.abs<i32> vec2<i32>(10i)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = abs 10i
+    %3:vec2<i32> = abs vec2<i32>(10i)
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, SAbs_SignedToUnsigned) {
+    auto* ep = b.ComputeFunction("foo");
+
+    b.Append(ep->Block(), [&] {  //
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.u32()}, 10_i);
+        b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<u32>(), spirv::BuiltinFn::kAbs,
+                                               Vector<const core::type::Type*, 1>{ty.u32()},
+                                               b.Splat(ty.vec2<i32>(), 10_i));
+        b.Return(ep);
+    });
+
+    auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = spirv.abs<u32> 10i
+    %3:vec2<u32> = spirv.abs<u32> vec2<i32>(10i)
+    ret
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+    Run(Builtins);
+
+    auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = abs 10i
+    %3:u32 = bitcast %2
+    %4:vec2<i32> = abs vec2<i32>(10i)
+    %5:vec2<u32> = bitcast %4
+    ret
+  }
+}
+)";
+    EXPECT_EQ(expect, str());
+}
+
 }  // namespace
 }  // namespace tint::spirv::reader::lower
diff --git a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
index b69f6a7..26ee356 100644
--- a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
@@ -41,8 +41,12 @@
   %void = OpTypeVoid
   %voidfn = OpTypeFunction %void
 
+  %uint = OpTypeInt 32 0
+  %int = OpTypeInt 32 1
   %float = OpTypeFloat 32
 
+  %v2int = OpTypeVector %int 2
+  %v2uint = OpTypeVector %uint 2
   %v2float = OpTypeVector %float 2
   %v3float = OpTypeVector %float 3
   %v4float = OpTypeVector %float 4
@@ -50,10 +54,19 @@
   %mat3v3float = OpTypeMatrix %v3float 3
   %mat4v4float = OpTypeMatrix %v4float 4
 
+  %int_10 = OpConstant %int 10
+  %int_20 = OpConstant %int 20
+
+  %uint_10 = OpConstant %uint 10
+  %uint_20 = OpConstant %uint 20
+
   %float_50 = OpConstant %float 50
   %float_60 = OpConstant %float 60
   %float_70 = OpConstant %float 70
 
+  %v2int_10_20 = OpConstantComposite %v2int %int_10 %int_20
+  %v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
+
   %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
   %v3float_50_60_70 = OpConstantComposite %v3float %float_50 %float_60 %float_70
   %v4float_50_50_50_50 = OpConstantComposite %v4float %float_50 %float_50 %float_50 %float_50
@@ -143,5 +156,93 @@
 )");
 }
 
+TEST_F(SpirvParserTest, GlslStd450_SAbs_UnsignedToUnsigned) {
+    EXPECT_IR(Preamble() + R"(
+     %1 = OpExtInst %uint %glsl SAbs %uint_10
+     %2 = OpExtInst %v2uint %glsl SAbs %v2uint_10_20
+     %3 = OpCopyObject %uint %1
+     %4 = OpCopyObject %v2uint %2
+     OpReturn
+     OpFunctionEnd
+  )",
+              R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = spirv.abs<u32> 10u
+    %3:vec2<u32> = spirv.abs<u32> vec2<u32>(10u, 20u)
+    %4:u32 = let %2
+    %5:vec2<u32> = let %3
+    ret
+  }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, GlslStd450_SAbs_UnsignedToSigned) {
+    EXPECT_IR(Preamble() + R"(
+     %1 = OpExtInst %int %glsl SAbs %uint_10
+     %2 = OpExtInst %v2int %glsl SAbs %v2uint_10_20
+     %3 = OpCopyObject %int %1
+     %4 = OpCopyObject %v2int %2
+     OpReturn
+     OpFunctionEnd
+  )",
+              R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = spirv.abs<i32> 10u
+    %3:vec2<i32> = spirv.abs<i32> vec2<u32>(10u, 20u)
+    %4:i32 = let %2
+    %5:vec2<i32> = let %3
+    ret
+  }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, GlslStd450_SAbs_SignedToUnsigned) {
+    EXPECT_IR(Preamble() + R"(
+     %1 = OpExtInst %uint %glsl SAbs %int_10
+     %2 = OpExtInst %v2uint %glsl SAbs %v2int_10_20
+     %3 = OpCopyObject %uint %1
+     %4 = OpCopyObject %v2uint %2
+     OpReturn
+     OpFunctionEnd
+  )",
+              R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:u32 = spirv.abs<u32> 10i
+    %3:vec2<u32> = spirv.abs<u32> vec2<i32>(10i, 20i)
+    %4:u32 = let %2
+    %5:vec2<u32> = let %3
+    ret
+  }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, GlslStd450_SAbs_SignedToSigned) {
+    EXPECT_IR(Preamble() + R"(
+     %1 = OpExtInst %int %glsl SAbs %int_10
+     %2 = OpExtInst %v2int %glsl SAbs %v2int_10_20
+     %3 = OpCopyObject %int %1
+     %4 = OpCopyObject %v2int %2
+     OpReturn
+     OpFunctionEnd
+  )",
+              R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = spirv.abs<i32> 10i
+    %3:vec2<i32> = spirv.abs<i32> vec2<i32>(10i, 20i)
+    %4:i32 = let %2
+    %5:vec2<i32> = let %3
+    ret
+  }
+}
+)");
+}
+
 }  // namespace
 }  // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index db2d433..ea3ce72 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -664,7 +664,6 @@
             case GLSLstd450Exp2:
                 return core::BuiltinFn::kExp2;
             case GLSLstd450FAbs:
-            case GLSLstd450SAbs:
                 return core::BuiltinFn::kAbs;
             case GLSLstd450FSign:
                 return core::BuiltinFn::kSign;
@@ -758,6 +757,8 @@
 
     spirv::BuiltinFn GetGlslStd450SpirvEquivalentFuncName(uint32_t ext_opcode) {
         switch (ext_opcode) {
+            case GLSLstd450SAbs:
+                return spirv::BuiltinFn::kAbs;
             case GLSLstd450SSign:
                 return spirv::BuiltinFn::kSign;
             case GLSLstd450Normalize:
@@ -772,7 +773,7 @@
 
     Vector<const core::type::Type*, 1> GlslStd450ExplicitParams(uint32_t ext_opcode,
                                                                 const core::type::Type* result_ty) {
-        if (ext_opcode != GLSLstd450SSign) {
+        if (ext_opcode != GLSLstd450SSign && ext_opcode != GLSLstd450SAbs) {
             return {};
         }
         return {result_ty->DeepestElement()};
diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def
index 39aeff1..e3c6855 100644
--- a/src/tint/lang/spirv/spirv.def
+++ b/src/tint/lang/spirv/spirv.def
@@ -328,6 +328,9 @@
 implicit(T: iu32) fn sign<R: iu32>(T) -> R
 implicit(T: iu32, N: num) fn sign<R: iu32>(vec<N, T>) -> vec<N, R>
 
+implicit(T: iu32) fn abs<R: iu32>(T) -> R
+implicit(T: iu32, N: num) fn abs<R: iu32>(vec<N, T>) -> vec<N, R>
+
 ////////////////////////////////////////////////////////////////////////////////
 // SPV_KHR_integer_dot_product instructions
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index a37bccd..f701b80 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1286,6 +1286,9 @@
         };
 
         switch (builtin->Func()) {
+            case spirv::BuiltinFn::kAbs:
+                ext_inst(GLSLstd450SAbs);
+                break;
             case spirv::BuiltinFn::kArrayLength:
                 op = spv::Op::OpArrayLength;
                 break;