[msl] Avoid UB for left shift of negative integers
Cast to unsigned and then back again.
Enable and fixup the related unit tests.
Bug: 42251016
Change-Id: Ibcf481ee74bcefa224775bc93a345ff5357d1e3e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/205417
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/msl/writer/binary_test.cc b/src/tint/lang/msl/writer/binary_test.cc
index 196f8ca..f5ece1d 100644
--- a/src/tint/lang/msl/writer/binary_test.cc
+++ b/src/tint/lang/msl/writer/binary_test.cc
@@ -260,13 +260,13 @@
testing::ValuesIn(signed_overflow_defined_behaviour_cases));
using MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour = MslWriterTestWithParam<BinaryData>;
-TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour, DISABLED_Emit) {
+TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour, Emit) {
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
- auto* l = b.Let("a", b.Constant(1_i));
- auto* r = b.Let("b", b.Constant(2_u));
+ auto* l = b.Let("left", b.Constant(1_i));
+ auto* r = b.Let("right", b.Constant(2_u));
auto* bin = b.Binary(params.op, ty.i32(), l, r);
b.Let("val", bin);
b.Return(func);
@@ -284,26 +284,23 @@
}
constexpr BinaryData shift_signed_overflow_defined_behaviour_cases[] = {
- {"as_type<int>((as_type<uint>(left) << right))", core::BinaryOp::kShiftLeft},
- {"(left >> right)", core::BinaryOp::kShiftRight}};
+ {"as_type<int>((as_type<uint>(left) << (right & 31u)))", core::BinaryOp::kShiftLeft},
+ {"(left >> (right & 31u))", core::BinaryOp::kShiftRight}};
INSTANTIATE_TEST_SUITE_P(MslWriterTest,
MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour,
testing::ValuesIn(shift_signed_overflow_defined_behaviour_cases));
using MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained =
MslWriterTestWithParam<BinaryData>;
-TEST_P(MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained, DISABLED_Emit) {
+TEST_P(MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained, Emit) {
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
- auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, i32>());
- auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, i32>());
-
- auto* l = b.Load(left);
- auto* r = b.Load(right);
- auto* expr1 = b.Binary(params.op, ty.i32(), l, r);
- auto* expr2 = b.Binary(params.op, ty.i32(), expr1, r);
+ auto* left = b.Let("left", 1_i);
+ auto* right = b.Let("right", 2_i);
+ auto* expr1 = b.Binary(params.op, ty.i32(), left, right);
+ auto* expr2 = b.Binary(params.op, ty.i32(), expr1, right);
b.Let("val", expr2);
b.Return(func);
@@ -312,22 +309,19 @@
ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, MetalHeader() + R"(
void foo() {
- int const left = 0;
- int const right = 0;
- int const v = right;
+ int const left = 1;
+ int const right = 2;
int const val = )" + params.result +
R"(;
+}
)");
}
constexpr BinaryData signed_overflow_defined_behaviour_chained_cases[] = {
- {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) + as_type<uint>(right)))) +
- as_type<uint>(right))))",
+ {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) + as_type<uint>(right)))) + as_type<uint>(right))))",
core::BinaryOp::kAdd},
- {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) - as_type<uint>(right)))) -
- as_type<uint>(right))))",
+ {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) - as_type<uint>(right)))) - as_type<uint>(right))))",
core::BinaryOp::kSubtract},
- {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) * as_type<uint>(right)))) *
- as_type<uint>(right))))",
+ {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) * as_type<uint>(right)))) * as_type<uint>(right))))",
core::BinaryOp::kMultiply}};
INSTANTIATE_TEST_SUITE_P(MslWriterTest,
MslWriterBinaryTest_SignedOverflowDefinedBehaviour_Chained,
@@ -335,18 +329,15 @@
using MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained =
MslWriterTestWithParam<BinaryData>;
-TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained, DISABLED_Emit) {
+TEST_P(MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained, Emit) {
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
- auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, i32>());
- auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, u32>());
-
- auto* l = b.Load(left);
- auto* r = b.Load(right);
- auto* expr1 = b.Binary(params.op, ty.i32(), l, r);
- auto* expr2 = b.Binary(params.op, ty.i32(), expr1, r);
+ auto* left = b.Let("left", b.Constant(1_i));
+ auto* right = b.Let("right", b.Constant(2_u));
+ auto* expr1 = b.Binary(params.op, ty.i32(), left, right);
+ auto* expr2 = b.Binary(params.op, ty.i32(), expr1, right);
b.Let("val", expr2);
b.Return(func);
@@ -355,17 +346,17 @@
ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, MetalHeader() + R"(
void foo() {
- int left = 0;
- uint right = 0u;
- uint const v = right;
+ int const left = 1;
+ uint const right = 2u;
int const val = )" + params.result +
R"(;
+}
)");
}
constexpr BinaryData shift_signed_overflow_defined_behaviour_chained_cases[] = {
- {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) << right))) << right)))",
+ {R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) << (right & 31u)))) << (right & 31u))))",
core::BinaryOp::kShiftLeft},
- {R"(((left >> right) >> right))", core::BinaryOp::kShiftRight},
+ {R"(((left >> (right & 31u)) >> (right & 31u)))", core::BinaryOp::kShiftRight},
};
INSTANTIATE_TEST_SUITE_P(MslWriterTest,
MslWriterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained,
diff --git a/src/tint/lang/msl/writer/raise/binary_polyfill.cc b/src/tint/lang/msl/writer/raise/binary_polyfill.cc
index e2f8dc6..1e0240c 100644
--- a/src/tint/lang/msl/writer/raise/binary_polyfill.cc
+++ b/src/tint/lang/msl/writer/raise/binary_polyfill.cc
@@ -53,6 +53,7 @@
Vector<core::ir::CoreBinary*, 4> fmod_worklist;
Vector<core::ir::CoreBinary*, 4> logical_bool_worklist;
Vector<core::ir::CoreBinary*, 4> signed_integer_arithmetic_worklist;
+ Vector<core::ir::CoreBinary*, 4> signed_integer_leftshift_worklist;
for (auto* inst : ir.Instructions()) {
if (auto* binary = inst->As<core::ir::CoreBinary>()) {
auto op = binary->Op();
@@ -66,6 +67,9 @@
op == core::BinaryOp::kSubtract) &&
lhs_type->IsSignedIntegerScalarOrVector()) {
signed_integer_arithmetic_worklist.Push(binary);
+ } else if (op == core::BinaryOp::kShiftLeft &&
+ lhs_type->IsSignedIntegerScalarOrVector()) {
+ signed_integer_leftshift_worklist.Push(binary);
}
}
}
@@ -80,6 +84,9 @@
for (auto* signed_arith : signed_integer_arithmetic_worklist) {
SignedIntegerArithmetic(signed_arith);
}
+ for (auto* signed_shift_left : signed_integer_leftshift_worklist) {
+ SignedIntegerShiftLeft(signed_shift_left);
+ }
}
/// Replace a floating point modulo binary instruction with the equivalent MSL intrinsic.
@@ -127,6 +134,24 @@
});
binary->Destroy();
}
+
+ /// Replace a signed integer shift left instruction.
+ /// @param binary the signed integer shift left instruction
+ void SignedIntegerShiftLeft(core::ir::CoreBinary* binary) {
+ // Left-shifting a negative integer is undefined behavior in C++14 and therefore potentially
+ // in MSL too, so we bitcast to an unsigned integer, perform the shift, and bitcast the
+ // result back to a signed integer.
+ auto* signed_ty = binary->Result(0)->Type();
+ auto* unsigned_ty = ty.match_width(ty.u32(), signed_ty);
+ b.InsertBefore(binary, [&] {
+ auto* unsigned_lhs = b.Bitcast(unsigned_ty, binary->LHS());
+ auto* unsigned_binary =
+ b.Binary(binary->Op(), unsigned_ty, unsigned_lhs, binary->RHS());
+ auto* bitcast = b.Bitcast(signed_ty, unsigned_binary);
+ binary->Result(0)->ReplaceAllUsesWith(bitcast->Result(0));
+ });
+ binary->Destroy();
+ }
};
} // namespace
diff --git a/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc b/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc
index 8560849..5de21bf 100644
--- a/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc
+++ b/src/tint/lang/msl/writer/raise/binary_polyfill_test.cc
@@ -507,5 +507,77 @@
EXPECT_EQ(expect, str());
}
+TEST_F(MslWriter_BinaryPolyfillTest, IntShift_Scalar) {
+ auto* lhs = b.FunctionParam<i32>("lhs");
+ auto* rhs = b.FunctionParam<u32>("rhs");
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({lhs, rhs});
+ b.Append(func->Block(), [&] {
+ auto* result = b.ShiftLeft<i32>(lhs, rhs);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%lhs:i32, %rhs:u32):i32 {
+ $B1: {
+ %4:i32 = shl %lhs, %rhs
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%lhs:i32, %rhs:u32):i32 {
+ $B1: {
+ %4:u32 = bitcast %lhs
+ %5:u32 = shl %4, %rhs
+ %6:i32 = bitcast %5
+ ret %6
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_BinaryPolyfillTest, IntShift_Vector) {
+ auto* lhs = b.FunctionParam<vec4<i32>>("lhs");
+ auto* rhs = b.FunctionParam<vec4<u32>>("rhs");
+ auto* func = b.Function("foo", ty.vec4<i32>());
+ func->SetParams({lhs, rhs});
+ b.Append(func->Block(), [&] {
+ auto* result = b.ShiftLeft<vec4<i32>>(lhs, rhs);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%lhs:vec4<i32>, %rhs:vec4<u32>):vec4<i32> {
+ $B1: {
+ %4:vec4<i32> = shl %lhs, %rhs
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%lhs:vec4<i32>, %rhs:vec4<u32>):vec4<i32> {
+ $B1: {
+ %4:vec4<u32> = bitcast %lhs
+ %5:vec4<u32> = shl %4, %rhs
+ %6:vec4<i32> = bitcast %5
+ ret %6
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::msl::writer::raise
diff --git a/test/tint/bug/tint/1542.wgsl.expected.ir.msl b/test/tint/bug/tint/1542.wgsl.expected.ir.msl
index 31c432a..c9b2411 100644
--- a/test/tint/bug/tint/1542.wgsl.expected.ir.msl
+++ b/test/tint/bug/tint/1542.wgsl.expected.ir.msl
@@ -24,5 +24,5 @@
kernel void tint_symbol(const constant UniformBuffer_packed_vec3* u_input [[buffer(0)]]) {
tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.u_input=u_input};
- int3 const temp = (int3((*tint_module_vars.u_input).d) << (uint3(0u) & uint3(31u)));
+ int3 const temp = as_type<int3>((as_type<uint3>(int3((*tint_module_vars.u_input).d)) << (uint3(0u) & uint3(31u))));
}
diff --git a/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl b/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl
index 96b4b6f..387b721 100644
--- a/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl
+++ b/test/tint/expressions/binary/left-shift/scalar-scalar/i32.wgsl.expected.ir.msl
@@ -4,5 +4,5 @@
kernel void f() {
int const a = 1;
uint const b = 2u;
- int const r = (a << (b & 31u));
+ int const r = as_type<int>((as_type<uint>(a) << (b & 31u)));
}
diff --git a/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl b/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl
index d9345f3..a559846 100644
--- a/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl
+++ b/test/tint/expressions/binary/left-shift/vector-vector/i32.wgsl.expected.ir.msl
@@ -4,5 +4,5 @@
kernel void f() {
int3 const a = int3(1, 2, 3);
uint3 const b = uint3(4u, 5u, 6u);
- int3 const r = (a << (b & uint3(31u)));
+ int3 const r = as_type<int3>((as_type<uint3>(a) << (b & uint3(31u))));
}
diff --git a/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl b/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl
index a19d91c..c13ebf7 100644
--- a/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl
+++ b/test/tint/statements/compound_assign/scalar/shift_left.wgsl.expected.ir.msl
@@ -10,5 +10,5 @@
};
void foo(tint_module_vars_struct tint_module_vars) {
- (*tint_module_vars.v).a = ((*tint_module_vars.v).a << (2u & 31u));
+ (*tint_module_vars.v).a = as_type<int>((as_type<uint>((*tint_module_vars.v).a) << (2u & 31u)));
}
diff --git a/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl b/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl
index 5aeec41..f4b449f 100644
--- a/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl
+++ b/test/tint/statements/compound_assign/vector/shift_left.wgsl.expected.ir.msl
@@ -10,5 +10,5 @@
};
void foo(tint_module_vars_struct tint_module_vars) {
- (*tint_module_vars.v).a = ((*tint_module_vars.v).a << (uint4(2u) & uint4(31u)));
+ (*tint_module_vars.v).a = as_type<int4>((as_type<uint4>((*tint_module_vars.v).a) << (uint4(2u) & uint4(31u))));
}