[msl] Avoid UB for left shift of negative integers

Cast to unsigned and then back again.

Enable and fixup the related unit tests.

Bug: 42251016
Change-Id: Ibcf481ee74bcefa224775bc93a345ff5357d1e3e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/205417
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/msl/writer/binary_test.cc b/src/tint/lang/msl/writer/binary_test.cc
index 196f8ca..f5ece1d 100644
--- a/src/tint/lang/msl/writer/binary_test.cc
+++ b/src/tint/lang/msl/writer/binary_test.cc
@@ -260,13 +260,13 @@
                          testing::ValuesIn(signed_overflow_defined_behaviour_cases));
 
 using MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour = MslWriterTestWithParam<BinaryData>;
-TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour, DISABLED_Emit) {
+TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour, Emit) {
     auto params = GetParam();
 
     auto* func = b.Function("foo", ty.void_());
     b.Append(func->Block(), [&] {
-        auto* l = b.Let("a", b.Constant(1_i));
-        auto* r = b.Let("b", b.Constant(2_u));
+        auto* l = b.Let("left", b.Constant(1_i));
+        auto* r = b.Let("right", b.Constant(2_u));
         auto* bin = b.Binary(params.op, ty.i32(), l, r);
         b.Let("val", bin);
         b.Return(func);
@@ -284,26 +284,23 @@
 }
 
 constexpr BinaryData shift_signed_overflow_defined_behaviour_cases[] = {
-    {"as_type<int>((as_type<uint>(left) << right))", core::BinaryOp::kShiftLeft},
-    {"(left >> right)", core::BinaryOp::kShiftRight}};
+    {"as_type<int>((as_type<uint>(left) << (right & 31u)))", core::BinaryOp::kShiftLeft},
+    {"(left >> (right & 31u))", core::BinaryOp::kShiftRight}};
 INSTANTIATE_TEST_SUITE_P(MslWriterTest,
                          MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour,
                          testing::ValuesIn(shift_signed_overflow_defined_behaviour_cases));
 
 using MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained =
     MslWriterTestWithParam<BinaryData>;
-TEST_P(MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained, DISABLED_Emit) {
+TEST_P(MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained, Emit) {
     auto params = GetParam();
 
     auto* func = b.Function("foo", ty.void_());
     b.Append(func->Block(), [&] {
-        auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, i32>());
-        auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, i32>());
-
-        auto* l = b.Load(left);
-        auto* r = b.Load(right);
-        auto* expr1 = b.Binary(params.op, ty.i32(), l, r);
-        auto* expr2 = b.Binary(params.op, ty.i32(), expr1, r);
+        auto* left = b.Let("left", 1_i);
+        auto* right = b.Let("right", 2_i);
+        auto* expr1 = b.Binary(params.op, ty.i32(), left, right);
+        auto* expr2 = b.Binary(params.op, ty.i32(), expr1, right);
 
         b.Let("val", expr2);
         b.Return(func);
@@ -312,22 +309,19 @@
     ASSERT_TRUE(Generate()) << err_ << output_.msl;
     EXPECT_EQ(output_.msl, MetalHeader() + R"(
 void foo() {
-  int const left = 0;
-  int const right = 0;
-  int const v = right;
+  int const left = 1;
+  int const right = 2;
   int const val = )" + params.result +
                                R"(;
+}
 )");
 }
 constexpr BinaryData signed_overflow_defined_behaviour_chained_cases[] = {
-    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) + as_type<uint>(right)))) +
-    as_type<uint>(right))))",
+    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) + as_type<uint>(right)))) + as_type<uint>(right))))",
      core::BinaryOp::kAdd},
-    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) - as_type<uint>(right)))) -
-    as_type<uint>(right))))",
+    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) - as_type<uint>(right)))) - as_type<uint>(right))))",
      core::BinaryOp::kSubtract},
-    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) * as_type<uint>(right)))) *
-    as_type<uint>(right))))",
+    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) * as_type<uint>(right)))) * as_type<uint>(right))))",
      core::BinaryOp::kMultiply}};
 INSTANTIATE_TEST_SUITE_P(MslWriterTest,
                          MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained,
@@ -335,18 +329,15 @@
 
 using MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained =
     MslWriterTestWithParam<BinaryData>;
-TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained, DISABLED_Emit) {
+TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained, Emit) {
     auto params = GetParam();
 
     auto* func = b.Function("foo", ty.void_());
     b.Append(func->Block(), [&] {
-        auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, i32>());
-        auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, u32>());
-
-        auto* l = b.Load(left);
-        auto* r = b.Load(right);
-        auto* expr1 = b.Binary(params.op, ty.i32(), l, r);
-        auto* expr2 = b.Binary(params.op, ty.i32(), expr1, r);
+        auto* left = b.Let("left", b.Constant(1_i));
+        auto* right = b.Let("right", b.Constant(2_u));
+        auto* expr1 = b.Binary(params.op, ty.i32(), left, right);
+        auto* expr2 = b.Binary(params.op, ty.i32(), expr1, right);
 
         b.Let("val", expr2);
         b.Return(func);
@@ -355,17 +346,17 @@
     ASSERT_TRUE(Generate()) << err_ << output_.msl;
     EXPECT_EQ(output_.msl, MetalHeader() + R"(
 void foo() {
-  int left = 0;
-  uint right = 0u;
-  uint const v = right;
+  int const left = 1;
+  uint const right = 2u;
   int const val = )" + params.result +
                                R"(;
+}
 )");
 }
 constexpr BinaryData shift_signed_overflow_defined_behaviour_chained_cases[] = {
-    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) << right))) << right)))",
+    {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) << (right & 31u)))) << (right & 31u))))",
      core::BinaryOp::kShiftLeft},
-    {R"(((left >> right) >> right))", core::BinaryOp::kShiftRight},
+    {R"(((left >> (right & 31u)) >> (right & 31u)))", core::BinaryOp::kShiftRight},
 };
 INSTANTIATE_TEST_SUITE_P(MslWriterTest,
                          MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained,
diff --git a/src/tint/lang/msl/writer/raise/binary_polyfill.cc b/src/tint/lang/msl/writer/raise/binary_polyfill.cc
index e2f8dc6..1e0240c 100644
--- a/src/tint/lang/msl/writer/raise/binary_polyfill.cc
+++ b/src/tint/lang/msl/writer/raise/binary_polyfill.cc
@@ -53,6 +53,7 @@
         Vector<core::ir::CoreBinary*, 4> fmod_worklist;
         Vector<core::ir::CoreBinary*, 4> logical_bool_worklist;
         Vector<core::ir::CoreBinary*, 4> signed_integer_arithmetic_worklist;
+        Vector<core::ir::CoreBinary*, 4> signed_integer_leftshift_worklist;
         for (auto* inst : ir.Instructions()) {
             if (auto* binary = inst->As<core::ir::CoreBinary>()) {
                 auto op = binary->Op();
@@ -66,6 +67,9 @@
                             op == core::BinaryOp::kSubtract) &&
                            lhs_type->IsSignedIntegerScalarOrVector()) {
                     signed_integer_arithmetic_worklist.Push(binary);
+                } else if (op == core::BinaryOp::kShiftLeft &&
+                           lhs_type->IsSignedIntegerScalarOrVector()) {
+                    signed_integer_leftshift_worklist.Push(binary);
                 }
             }
         }
@@ -80,6 +84,9 @@
         for (auto* signed_arith : signed_integer_arithmetic_worklist) {
             SignedIntegerArithmetic(signed_arith);
         }
+        for (auto* signed_shift_left : signed_integer_leftshift_worklist) {
+            SignedIntegerShiftLeft(signed_shift_left);
+        }
     }
 
     /// Replace a floating point modulo binary instruction with the equivalent MSL intrinsic.
@@ -127,6 +134,24 @@
         });
         binary->Destroy();
     }
+
+    /// Replace a signed integer shift left instruction.
+    /// @param binary the signed integer shift left instruction
+    void SignedIntegerShiftLeft(core::ir::CoreBinary* binary) {
+        // Left-shifting a negative integer is undefined behavior in C++14 and therefore potentially
+        // in MSL too, so we bitcast to an unsigned integer, perform the shift, and bitcast the
+        // result back to a signed integer.
+        auto* signed_ty = binary->Result(0)->Type();
+        auto* unsigned_ty = ty.match_width(ty.u32(), signed_ty);
+        b.InsertBefore(binary, [&] {
+            auto* unsigned_lhs = b.Bitcast(unsigned_ty, binary->LHS());
+            auto* unsigned_binary =
+                b.Binary(binary->Op(), unsigned_ty, unsigned_lhs, binary->RHS());
+            auto* bitcast = b.Bitcast(signed_ty, unsigned_binary);
+            binary->Result(0)->ReplaceAllUsesWith(bitcast->Result(0));
+        });
+        binary->Destroy();
+    }
 };
 
 }  // namespace
diff --git a/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc b/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc
index 8560849..5de21bf 100644
--- a/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc
+++ b/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc
@@ -507,5 +507,77 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(MslWriter_BinaryPolyfillTest, IntShift_Scalar) {
+    auto* lhs = b.FunctionParam<i32>("lhs");
+    auto* rhs = b.FunctionParam<u32>("rhs");
+    auto* func = b.Function("foo", ty.i32());
+    func->SetParams({lhs, rhs});
+    b.Append(func->Block(), [&] {
+        auto* result = b.ShiftLeft<i32>(lhs, rhs);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%lhs:i32, %rhs:u32):i32 {
+  $B1: {
+    %4:i32 = shl %lhs, %rhs
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%lhs:i32, %rhs:u32):i32 {
+  $B1: {
+    %4:u32 = bitcast %lhs
+    %5:u32 = shl %4, %rhs
+    %6:i32 = bitcast %5
+    ret %6
+  }
+}
+)";
+
+    Run(BinaryPolyfill);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_BinaryPolyfillTest, IntShift_Vector) {
+    auto* lhs = b.FunctionParam<vec4<i32>>("lhs");
+    auto* rhs = b.FunctionParam<vec4<u32>>("rhs");
+    auto* func = b.Function("foo", ty.vec4<i32>());
+    func->SetParams({lhs, rhs});
+    b.Append(func->Block(), [&] {
+        auto* result = b.ShiftLeft<vec4<i32>>(lhs, rhs);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%lhs:vec4<i32>, %rhs:vec4<u32>):vec4<i32> {
+  $B1: {
+    %4:vec4<i32> = shl %lhs, %rhs
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%lhs:vec4<i32>, %rhs:vec4<u32>):vec4<i32> {
+  $B1: {
+    %4:vec4<u32> = bitcast %lhs
+    %5:vec4<u32> = shl %4, %rhs
+    %6:vec4<i32> = bitcast %5
+    ret %6
+  }
+}
+)";
+
+    Run(BinaryPolyfill);
+
+    EXPECT_EQ(expect, str());
+}
+
 }  // namespace
 }  // namespace tint::msl::writer::raise
diff --git a/test/tint/bug/tint/1542.wgsl.expected.ir.msl b/test/tint/bug/tint/1542.wgsl.expected.ir.msl
index 31c432a..c9b2411 100644
--- a/test/tint/bug/tint/1542.wgsl.expected.ir.msl
+++ b/test/tint/bug/tint/1542.wgsl.expected.ir.msl
@@ -24,5 +24,5 @@
 
 kernel void tint_symbol(const constant UniformBuffer_packed_vec3* u_input [[buffer(0)]]) {
   tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.u_input=u_input};
-  int3 const temp = (int3((*tint_module_vars.u_input).d) << (uint3(0u) & uint3(31u)));
+  int3 const temp = as_type<int3>((as_type<uint3>(int3((*tint_module_vars.u_input).d)) << (uint3(0u) & uint3(31u))));
 }
diff --git a/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl b/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl
index 96b4b6f..387b721 100644
--- a/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl
+++ b/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl
@@ -4,5 +4,5 @@
 kernel void f() {
   int const a = 1;
   uint const b = 2u;
-  int const r = (a << (b & 31u));
+  int const r = as_type<int>((as_type<uint>(a) << (b & 31u)));
 }
diff --git a/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl b/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl
index d9345f3..a559846 100644
--- a/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl
+++ b/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl
@@ -4,5 +4,5 @@
 kernel void f() {
   int3 const a = int3(1, 2, 3);
   uint3 const b = uint3(4u, 5u, 6u);
-  int3 const r = (a << (b & uint3(31u)));
+  int3 const r = as_type<int3>((as_type<uint3>(a) << (b & uint3(31u))));
 }
diff --git a/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl b/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl
index a19d91c..c13ebf7 100644
--- a/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl
+++ b/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl
@@ -10,5 +10,5 @@
 };
 
 void foo(tint_module_vars_struct tint_module_vars) {
-  (*tint_module_vars.v).a = ((*tint_module_vars.v).a << (2u & 31u));
+  (*tint_module_vars.v).a = as_type<int>((as_type<uint>((*tint_module_vars.v).a) << (2u & 31u)));
 }
diff --git a/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl b/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl
index 5aeec41..f4b449f 100644
--- a/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl
+++ b/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl
@@ -10,5 +10,5 @@
 };
 
 void foo(tint_module_vars_struct tint_module_vars) {
-  (*tint_module_vars.v).a = ((*tint_module_vars.v).a << (uint4(2u) & uint4(31u)));
+  (*tint_module_vars.v).a = as_type<int4>((as_type<uint4>((*tint_module_vars.v).a) << (uint4(2u) & uint4(31u))));
 }