tint: const eval of binary right shift
Bug: tint:1581
Change-Id: I3f40454559c4fc36565de1a11a6e6c8c394fd0cc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112620
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def
index 179cf88..d9188ed 100644
--- a/src/tint/intrinsics.def
+++ b/src/tint/intrinsics.def
@@ -993,8 +993,8 @@
@const op << <T: ia_iu32>(T, u32) -> T
@const op << <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
-op >> <T: iu32>(T, u32) -> T
-op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
+@const op >> <T: ia_iu32>(T, u32) -> T
+@const op >> <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
////////////////////////////////////////////////////////////////////////////////
// Tint internal builtins //
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 8170adb..45ab2de 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -1950,6 +1950,70 @@
return TransformElements(builder, ty, transform, args[0], args[1]);
}
+ConstEval::Result ConstEval::OpShiftRight(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
+ auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto create = [&](auto e1, auto e2) -> ImplResult {
+ using NumberT = decltype(e1);
+ using T = UnwrapNumber<NumberT>;
+ using UT = std::make_unsigned_t<T>;
+ constexpr size_t bit_width = BitWidth<NumberT>;
+ const UT e1u = static_cast<UT>(e1);
+ const UT e2u = static_cast<UT>(e2);
+
+ auto signed_shift_right = [&] {
+ // In C++, right shift of a signed negative number is implementation-defined.
+ // Although most implementations sign-extend, we do it manually to ensure it works
+ // correctly on all implementations.
+ const UT msb = UT{1} << (bit_width - 1);
+ UT sign_ext = 0;
+ if (e1u & msb) {
+ // Set e2 + 1 bits to 1
+ UT num_shift_bits_mask = ((UT{1} << e2u) - UT{1});
+ sign_ext = (num_shift_bits_mask << (bit_width - e2u - UT{1})) | msb;
+ }
+ return static_cast<T>((e1u >> e2u) | sign_ext);
+ };
+
+ T result = 0;
+ if constexpr (IsAbstract<NumberT>) {
+ if (static_cast<size_t>(e2) >= bit_width) {
+ result = T{0};
+ } else {
+ result = signed_shift_right();
+ }
+ } else {
+ if (static_cast<size_t>(e2) >= bit_width) {
+ // At shader/pipeline-creation time, it is an error to shift by the bit width of
+ // the lhs or greater. NOTE: At runtime, we shift by e2 % (bit width of e1).
+ AddError(
+ "shift right value must be less than the bit width of the lhs, which is " +
+ std::to_string(bit_width),
+ source);
+ return utils::Failure;
+ }
+
+ if constexpr (std::is_signed_v<T>) {
+ result = signed_shift_right();
+ } else {
+ result = e1 >> e2;
+ }
+ }
+ return CreateElement(builder, source, sem::Type::DeepestElementOf(ty), NumberT{result});
+ };
+ return Dispatch_ia_iu32(create, c0, c1);
+ };
+
+ if (!sem::Type::DeepestElementOf(args[1]->Type())->Is<sem::U32>()) {
+ TINT_ICE(Resolver, builder.Diagnostics())
+ << "Element type of rhs of ShiftLeft must be a u32";
+ return utils::Failure;
+ }
+
+ return TransformElements(builder, ty, transform, args[0], args[1]);
+}
+
ConstEval::Result ConstEval::abs(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index 9905036..3e6c626 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -382,6 +382,15 @@
utils::VectorRef<const sem::Constant*> args,
const Source& source);
+ /// Bitwise shift right operator '<<'
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location
+ /// @return the result value, or null if the value cannot be calculated
+ Result OpShiftRight(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 2590466..a370285 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -917,8 +917,7 @@
template <typename T>
std::vector<Case> ShiftLeftCases() {
- // Shift type is u32 for non-abstract
- using ST = std::conditional_t<IsAbstract<T>, T, u32>;
+ using ST = u32; // Shift type is u32
using B = BitValues<T>;
auto r = std::vector<Case>{
C(T{0b1010}, ST{0}, T{0b0000'0000'1010}), //
@@ -1200,5 +1199,144 @@
ShiftLeftSignChangeErrorCases<AInt>(),
ShiftLeftSignChangeErrorCases<i32>())));
+template <typename T>
+std::vector<Case> ShiftRightCases() {
+ using B = BitValues<T>;
+ auto r = std::vector<Case>{
+ C(T{0b10101100}, u32{0}, T{0b10101100}), //
+ C(T{0b10101100}, u32{1}, T{0b01010110}), //
+ C(T{0b10101100}, u32{2}, T{0b00101011}), //
+ C(T{0b10101100}, u32{3}, T{0b00010101}), //
+ C(T{0b10101100}, u32{4}, T{0b00001010}), //
+ C(T{0b10101100}, u32{5}, T{0b00000101}), //
+ C(T{0b10101100}, u32{6}, T{0b00000010}), //
+ C(T{0b10101100}, u32{7}, T{0b00000001}), //
+ C(T{0b10101100}, u32{8}, T{0b00000000}), //
+ C(T{0b10101100}, u32{9}, T{0b00000000}), //
+ C(B::LeftMost, u32{0}, B::LeftMost), //
+ };
+
+ // msb not set, same for all types: inserted bit is 0
+ ConcatInto( //
+ r, std::vector<Case>{
+ C(T{0b01000000000000000000000010101100}, u32{0}, //
+ T{0b01000000000000000000000010101100}),
+ C(T{0b01000000000000000000000010101100}, u32{1}, //
+ T{0b00100000000000000000000001010110}),
+ C(T{0b01000000000000000000000010101100}, u32{2}, //
+ T{0b00010000000000000000000000101011}),
+ C(T{0b01000000000000000000000010101100}, u32{3}, //
+ T{0b00001000000000000000000000010101}),
+ C(T{0b01000000000000000000000010101100}, u32{4}, //
+ T{0b00000100000000000000000000001010}),
+ C(T{0b01000000000000000000000010101100}, u32{5}, //
+ T{0b00000010000000000000000000000101}),
+ C(T{0b01000000000000000000000010101100}, u32{6}, //
+ T{0b00000001000000000000000000000010}),
+ C(T{0b01000000000000000000000010101100}, u32{7}, //
+ T{0b00000000100000000000000000000001}),
+ C(T{0b01000000000000000000000010101100}, u32{8}, //
+ T{0b00000000010000000000000000000000}),
+ C(T{0b01000000000000000000000010101100}, u32{9}, //
+ T{0b00000000001000000000000000000000}),
+ });
+
+ // msb set, result differs for i32 and u32
+ if constexpr (std::is_same_v<T, u32>) {
+ // If unsigned, insert zero bits at the most significant positions.
+ ConcatInto( //
+ r, std::vector<Case>{
+ C(T{0b10000000000000000000000010101100}, u32{0},
+ T{0b10000000000000000000000010101100}),
+ C(T{0b10000000000000000000000010101100}, u32{1},
+ T{0b01000000000000000000000001010110}),
+ C(T{0b10000000000000000000000010101100}, u32{2},
+ T{0b00100000000000000000000000101011}),
+ C(T{0b10000000000000000000000010101100}, u32{3},
+ T{0b00010000000000000000000000010101}),
+ C(T{0b10000000000000000000000010101100}, u32{4},
+ T{0b00001000000000000000000000001010}),
+ C(T{0b10000000000000000000000010101100}, u32{5},
+ T{0b00000100000000000000000000000101}),
+ C(T{0b10000000000000000000000010101100}, u32{6},
+ T{0b00000010000000000000000000000010}),
+ C(T{0b10000000000000000000000010101100}, u32{7},
+ T{0b00000001000000000000000000000001}),
+ C(T{0b10000000000000000000000010101100}, u32{8},
+ T{0b00000000100000000000000000000000}),
+ C(T{0b10000000000000000000000010101100}, u32{9},
+ T{0b00000000010000000000000000000000}),
+ // msb shifted by bit width - 1
+ C(T{0b10000000000000000000000000000000}, u32{31},
+ T{0b00000000000000000000000000000001}),
+ });
+ } else if constexpr (std::is_same_v<T, i32>) {
+ // If signed, each inserted bit is 1, so the result is negative.
+ ConcatInto( //
+ r, std::vector<Case>{
+ C(T{0b10000000000000000000000010101100}, u32{0},
+ T{0b10000000000000000000000010101100}), //
+ C(T{0b10000000000000000000000010101100}, u32{1},
+ T{0b11000000000000000000000001010110}), //
+ C(T{0b10000000000000000000000010101100}, u32{2},
+ T{0b11100000000000000000000000101011}), //
+ C(T{0b10000000000000000000000010101100}, u32{3},
+ T{0b11110000000000000000000000010101}), //
+ C(T{0b10000000000000000000000010101100}, u32{4},
+ T{0b11111000000000000000000000001010}), //
+ C(T{0b10000000000000000000000010101100}, u32{5},
+ T{0b11111100000000000000000000000101}), //
+ C(T{0b10000000000000000000000010101100}, u32{6},
+ T{0b11111110000000000000000000000010}), //
+ C(T{0b10000000000000000000000010101100}, u32{7},
+ T{0b11111111000000000000000000000001}), //
+ C(T{0b10000000000000000000000010101100}, u32{8},
+ T{0b11111111100000000000000000000000}), //
+ C(T{0b10000000000000000000000010101100}, u32{9},
+ T{0b11111111110000000000000000000000}), //
+ // msb shifted by bit width - 1
+ C(T{0b10000000000000000000000000000000}, u32{31},
+ T{0b11111111111111111111111111111111}),
+ });
+ }
+
+ // Test shift right by bit width or more
+ if constexpr (IsAbstract<T>) {
+ // For abstract int, no error, result is 0
+ ConcatInto( //
+ r, std::vector<Case>{
+ C(T{0}, u32{B::NumBits}, T{0}),
+ C(T{0}, u32{B::NumBits + 1}, T{0}),
+ C(T{0}, u32{B::NumBits + 1000}, T{0}),
+ C(T{42}, u32{B::NumBits}, T{0}),
+ C(T{42}, u32{B::NumBits + 1}, T{0}),
+ C(T{42}, u32{B::NumBits + 1000}, T{0}),
+ });
+ } else {
+ // For concretes, error
+ const char* error_msg =
+ "12:34 error: shift right value must be less than the bit width of the lhs, which is "
+ "32";
+ ConcatInto( //
+ r, std::vector<Case>{
+ E(T{0}, u32{B::NumBits}, error_msg),
+ E(T{0}, u32{B::NumBits + 1}, error_msg),
+ E(T{0}, u32{B::NumBits + 1000}, error_msg),
+ E(T{42}, u32{B::NumBits}, error_msg),
+ E(T{42}, u32{B::NumBits + 1}, error_msg),
+ E(T{42}, u32{B::NumBits + 1000}, error_msg),
+ });
+ }
+
+ return r;
+}
+INSTANTIATE_TEST_SUITE_P(ShiftRight,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kShiftRight),
+ testing::ValuesIn(Concat(ShiftRightCases<AInt>(), //
+ ShiftRightCases<i32>(), //
+ ShiftRightCases<u32>()))));
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 922a1dc..ef40bff 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -890,10 +890,11 @@
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
// @compute @workgroup_size(1 << 2 + 4)
// fn main() {}
+ GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
});
EXPECT_FALSE(r()->Resolve());
@@ -905,10 +906,11 @@
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
// @compute @workgroup_size(1, 1 << 2 + 4)
// fn main() {}
+ GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
});
EXPECT_FALSE(r()->Resolve());
@@ -920,10 +922,11 @@
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
// @compute @workgroup_size(1, 1, 1 << 2 + 4)
// fn main() {}
+ GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
});
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl
index 00b7557..3fedcd0 100644
--- a/src/tint/resolver/intrinsic_table.inl
+++ b/src/tint/resolver/intrinsic_table.inl
@@ -13469,24 +13469,24 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[25],
+ /* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[778],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpShiftRight,
},
{
/* [432] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[25],
+ /* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[780],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpShiftRight,
},
{
/* [433] */
@@ -14975,8 +14975,8 @@
},
{
/* [17] */
- /* op >><T : iu32>(T, u32) -> T */
- /* op >><T : iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
+ /* op >><T : ia_iu32>(T, u32) -> T */
+ /* op >><T : ia_iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[431],
},