[msl] Avoid UB for left shift of negative integers Cast to unsigned and then back again. Enable and fixup the related unit tests. Bug: 42251016 Change-Id: Ibcf481ee74bcefa224775bc93a345ff5357d1e3e Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/205417 Commit-Queue: James Price <jrprice@google.com> Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/msl/writer/binary_test.cc b/src/tint/lang/msl/writer/binary_test.cc index 196f8ca..f5ece1d 100644 --- a/src/tint/lang/msl/writer/binary_test.cc +++ b/src/tint/lang/msl/writer/binary_test.cc
@@ -260,13 +260,13 @@ testing::ValuesIn(signed_overflow_defined_behaviour_cases)); using MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour = MslWriterTestWithParam<BinaryData>; -TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour, DISABLED_Emit) { +TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour, Emit) { auto params = GetParam(); auto* func = b.Function("foo", ty.void_()); b.Append(func->Block(), [&] { - auto* l = b.Let("a", b.Constant(1_i)); - auto* r = b.Let("b", b.Constant(2_u)); + auto* l = b.Let("left", b.Constant(1_i)); + auto* r = b.Let("right", b.Constant(2_u)); auto* bin = b.Binary(params.op, ty.i32(), l, r); b.Let("val", bin); b.Return(func); @@ -284,26 +284,23 @@ } constexpr BinaryData shift_signed_overflow_defined_behaviour_cases[] = { - {"as_type<int>((as_type<uint>(left) << right))", core::BinaryOp::kShiftLeft}, - {"(left >> right)", core::BinaryOp::kShiftRight}}; + {"as_type<int>((as_type<uint>(left) << (right & 31u)))", core::BinaryOp::kShiftLeft}, + {"(left >> (right & 31u))", core::BinaryOp::kShiftRight}}; INSTANTIATE_TEST_SUITE_P(MslWriterTest, MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour, testing::ValuesIn(shift_signed_overflow_defined_behaviour_cases)); using MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained = MslWriterTestWithParam<BinaryData>; -TEST_P(MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained, DISABLED_Emit) { +TEST_P(MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained, Emit) { auto params = GetParam(); auto* func = b.Function("foo", ty.void_()); b.Append(func->Block(), [&] { - auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, i32>()); - auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, i32>()); - - auto* l = b.Load(left); - auto* r = b.Load(right); - auto* expr1 = b.Binary(params.op, ty.i32(), l, r); - auto* expr2 = b.Binary(params.op, ty.i32(), expr1, r); + auto* left = b.Let("left", 1_i); + auto* right = b.Let("right", 2_i); + auto* expr1 = b.Binary(params.op, ty.i32(), left, right); + auto* expr2 = b.Binary(params.op, ty.i32(), expr1, right); b.Let("val", expr2); b.Return(func); @@ -312,22 +309,19 @@ ASSERT_TRUE(Generate()) << err_ << output_.msl; EXPECT_EQ(output_.msl, MetalHeader() + R"( void foo() { - int const left = 0; - int const right = 0; - int const v = right; + int const left = 1; + int const right = 2; int const val = )" + params.result + R"(; +} )"); } constexpr BinaryData signed_overflow_defined_behaviour_chained_cases[] = { - {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) + as_type<uint>(right)))) + - as_type<uint>(right))))", + {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) + as_type<uint>(right)))) + as_type<uint>(right))))", core::BinaryOp::kAdd}, - {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) - as_type<uint>(right)))) - - as_type<uint>(right))))", + {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) - as_type<uint>(right)))) - as_type<uint>(right))))", core::BinaryOp::kSubtract}, - {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) * as_type<uint>(right)))) * - as_type<uint>(right))))", + {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) * as_type<uint>(right)))) * as_type<uint>(right))))", core::BinaryOp::kMultiply}}; INSTANTIATE_TEST_SUITE_P(MslWriterTest, MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained, @@ -335,18 +329,15 @@ using MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained = MslWriterTestWithParam<BinaryData>; -TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained, DISABLED_Emit) { +TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained, Emit) { auto params = GetParam(); auto* func = b.Function("foo", ty.void_()); b.Append(func->Block(), [&] { - auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, i32>()); - auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, u32>()); - - auto* l = b.Load(left); - auto* r = b.Load(right); - auto* expr1 = b.Binary(params.op, ty.i32(), l, r); - auto* expr2 = b.Binary(params.op, ty.i32(), expr1, r); + auto* left = b.Let("left", b.Constant(1_i)); + auto* right = b.Let("right", b.Constant(2_u)); + auto* expr1 = b.Binary(params.op, ty.i32(), left, right); + auto* expr2 = b.Binary(params.op, ty.i32(), expr1, right); b.Let("val", expr2); b.Return(func); @@ -355,17 +346,17 @@ ASSERT_TRUE(Generate()) << err_ << output_.msl; EXPECT_EQ(output_.msl, MetalHeader() + R"( void foo() { - int left = 0; - uint right = 0u; - uint const v = right; + int const left = 1; + uint const right = 2u; int const val = )" + params.result + R"(; +} )"); } constexpr BinaryData shift_signed_overflow_defined_behaviour_chained_cases[] = { - {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) << right))) << right)))", + {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) << (right & 31u)))) << (right & 31u))))", core::BinaryOp::kShiftLeft}, - {R"(((left >> right) >> right))", core::BinaryOp::kShiftRight}, + {R"(((left >> (right & 31u)) >> (right & 31u)))", core::BinaryOp::kShiftRight}, }; INSTANTIATE_TEST_SUITE_P(MslWriterTest, MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained,
diff --git a/src/tint/lang/msl/writer/raise/binary_polyfill.cc b/src/tint/lang/msl/writer/raise/binary_polyfill.cc index e2f8dc6..1e0240c 100644 --- a/src/tint/lang/msl/writer/raise/binary_polyfill.cc +++ b/src/tint/lang/msl/writer/raise/binary_polyfill.cc
@@ -53,6 +53,7 @@ Vector<core::ir::CoreBinary*, 4> fmod_worklist; Vector<core::ir::CoreBinary*, 4> logical_bool_worklist; Vector<core::ir::CoreBinary*, 4> signed_integer_arithmetic_worklist; + Vector<core::ir::CoreBinary*, 4> signed_integer_leftshift_worklist; for (auto* inst : ir.Instructions()) { if (auto* binary = inst->As<core::ir::CoreBinary>()) { auto op = binary->Op(); @@ -66,6 +67,9 @@ op == core::BinaryOp::kSubtract) && lhs_type->IsSignedIntegerScalarOrVector()) { signed_integer_arithmetic_worklist.Push(binary); + } else if (op == core::BinaryOp::kShiftLeft && + lhs_type->IsSignedIntegerScalarOrVector()) { + signed_integer_leftshift_worklist.Push(binary); } } } @@ -80,6 +84,9 @@ for (auto* signed_arith : signed_integer_arithmetic_worklist) { SignedIntegerArithmetic(signed_arith); } + for (auto* signed_shift_left : signed_integer_leftshift_worklist) { + SignedIntegerShiftLeft(signed_shift_left); + } } /// Replace a floating point modulo binary instruction with the equivalent MSL intrinsic. @@ -127,6 +134,24 @@ }); binary->Destroy(); } + + /// Replace a signed integer shift left instruction. + /// @param binary the signed integer shift left instruction + void SignedIntegerShiftLeft(core::ir::CoreBinary* binary) { + // Left-shifting a negative integer is undefined behavior in C++14 and therefore potentially + // in MSL too, so we bitcast to an unsigned integer, perform the shift, and bitcast the + // result back to a signed integer. + auto* signed_ty = binary->Result(0)->Type(); + auto* unsigned_ty = ty.match_width(ty.u32(), signed_ty); + b.InsertBefore(binary, [&] { + auto* unsigned_lhs = b.Bitcast(unsigned_ty, binary->LHS()); + auto* unsigned_binary = + b.Binary(binary->Op(), unsigned_ty, unsigned_lhs, binary->RHS()); + auto* bitcast = b.Bitcast(signed_ty, unsigned_binary); + binary->Result(0)->ReplaceAllUsesWith(bitcast->Result(0)); + }); + binary->Destroy(); + } }; } // namespace
diff --git a/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc b/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc index 8560849..5de21bf 100644 --- a/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc +++ b/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc
@@ -507,5 +507,77 @@ EXPECT_EQ(expect, str()); } +TEST_F(MslWriter_BinaryPolyfillTest, IntShift_Scalar) { + auto* lhs = b.FunctionParam<i32>("lhs"); + auto* rhs = b.FunctionParam<u32>("rhs"); + auto* func = b.Function("foo", ty.i32()); + func->SetParams({lhs, rhs}); + b.Append(func->Block(), [&] { + auto* result = b.ShiftLeft<i32>(lhs, rhs); + b.Return(func, result); + }); + + auto* src = R"( +%foo = func(%lhs:i32, %rhs:u32):i32 { + $B1: { + %4:i32 = shl %lhs, %rhs + ret %4 + } +} +)"; + EXPECT_EQ(src, str()); + + auto* expect = R"( +%foo = func(%lhs:i32, %rhs:u32):i32 { + $B1: { + %4:u32 = bitcast %lhs + %5:u32 = shl %4, %rhs + %6:i32 = bitcast %5 + ret %6 + } +} +)"; + + Run(BinaryPolyfill); + + EXPECT_EQ(expect, str()); +} + +TEST_F(MslWriter_BinaryPolyfillTest, IntShift_Vector) { + auto* lhs = b.FunctionParam<vec4<i32>>("lhs"); + auto* rhs = b.FunctionParam<vec4<u32>>("rhs"); + auto* func = b.Function("foo", ty.vec4<i32>()); + func->SetParams({lhs, rhs}); + b.Append(func->Block(), [&] { + auto* result = b.ShiftLeft<vec4<i32>>(lhs, rhs); + b.Return(func, result); + }); + + auto* src = R"( +%foo = func(%lhs:vec4<i32>, %rhs:vec4<u32>):vec4<i32> { + $B1: { + %4:vec4<i32> = shl %lhs, %rhs + ret %4 + } +} +)"; + EXPECT_EQ(src, str()); + + auto* expect = R"( +%foo = func(%lhs:vec4<i32>, %rhs:vec4<u32>):vec4<i32> { + $B1: { + %4:vec4<u32> = bitcast %lhs + %5:vec4<u32> = shl %4, %rhs + %6:vec4<i32> = bitcast %5 + ret %6 + } +} +)"; + + Run(BinaryPolyfill); + + EXPECT_EQ(expect, str()); +} + } // namespace } // namespace tint::msl::writer::raise
diff --git a/test/tint/bug/tint/1542.wgsl.expected.ir.msl b/test/tint/bug/tint/1542.wgsl.expected.ir.msl index 31c432a..c9b2411 100644 --- a/test/tint/bug/tint/1542.wgsl.expected.ir.msl +++ b/test/tint/bug/tint/1542.wgsl.expected.ir.msl
@@ -24,5 +24,5 @@ kernel void tint_symbol(const constant UniformBuffer_packed_vec3* u_input [[buffer(0)]]) { tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.u_input=u_input}; - int3 const temp = (int3((*tint_module_vars.u_input).d) << (uint3(0u) & uint3(31u))); + int3 const temp = as_type<int3>((as_type<uint3>(int3((*tint_module_vars.u_input).d)) << (uint3(0u) & uint3(31u)))); }
diff --git a/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl b/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl index 96b4b6f..387b721 100644 --- a/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl +++ b/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl
@@ -4,5 +4,5 @@ kernel void f() { int const a = 1; uint const b = 2u; - int const r = (a << (b & 31u)); + int const r = as_type<int>((as_type<uint>(a) << (b & 31u))); }
diff --git a/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl b/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl index d9345f3..a559846 100644 --- a/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl +++ b/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl
@@ -4,5 +4,5 @@ kernel void f() { int3 const a = int3(1, 2, 3); uint3 const b = uint3(4u, 5u, 6u); - int3 const r = (a << (b & uint3(31u))); + int3 const r = as_type<int3>((as_type<uint3>(a) << (b & uint3(31u)))); }
diff --git a/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl b/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl index a19d91c..c13ebf7 100644 --- a/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl +++ b/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl
@@ -10,5 +10,5 @@ }; void foo(tint_module_vars_struct tint_module_vars) { - (*tint_module_vars.v).a = ((*tint_module_vars.v).a << (2u & 31u)); + (*tint_module_vars.v).a = as_type<int>((as_type<uint>((*tint_module_vars.v).a) << (2u & 31u))); }
diff --git a/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl b/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl index 5aeec41..f4b449f 100644 --- a/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl +++ b/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl
@@ -10,5 +10,5 @@ }; void foo(tint_module_vars_struct tint_module_vars) { - (*tint_module_vars.v).a = ((*tint_module_vars.v).a << (uint4(2u) & uint4(31u))); + (*tint_module_vars.v).a = as_type<int4>((as_type<uint4>((*tint_module_vars.v).a) << (uint4(2u) & uint4(31u)))); }