tint: const eval of binary right shift

Bug: tint:1581
Change-Id: I3f40454559c4fc36565de1a11a6e6c8c394fd0cc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112620
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def
index 179cf88..d9188ed 100644
--- a/src/tint/intrinsics.def
+++ b/src/tint/intrinsics.def
@@ -993,8 +993,8 @@
 @const op << <T: ia_iu32>(T, u32) -> T
 @const op << <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
 
-op >> <T: iu32>(T, u32) -> T
-op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
+@const op >> <T: ia_iu32>(T, u32) -> T
+@const op >> <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
 
 ////////////////////////////////////////////////////////////////////////////////
 // Tint internal builtins                                                     //
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 8170adb..45ab2de 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -1950,6 +1950,70 @@
     return TransformElements(builder, ty, transform, args[0], args[1]);
 }
 
+ConstEval::Result ConstEval::OpShiftRight(const sem::Type* ty,
+                                          utils::VectorRef<const sem::Constant*> args,
+                                          const Source& source) {
+    auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+        auto create = [&](auto e1, auto e2) -> ImplResult {
+            using NumberT = decltype(e1);
+            using T = UnwrapNumber<NumberT>;
+            using UT = std::make_unsigned_t<T>;
+            constexpr size_t bit_width = BitWidth<NumberT>;
+            const UT e1u = static_cast<UT>(e1);
+            const UT e2u = static_cast<UT>(e2);
+
+            auto signed_shift_right = [&] {
+                // In C++, right shift of a signed negative number is implementation-defined.
+                // Although most implementations sign-extend, we do it manually to ensure it works
+                // correctly on all implementations.
+                const UT msb = UT{1} << (bit_width - 1);
+                UT sign_ext = 0;
+                if (e1u & msb) {
+                    // Set e2 + 1 bits to 1
+                    UT num_shift_bits_mask = ((UT{1} << e2u) - UT{1});
+                    sign_ext = (num_shift_bits_mask << (bit_width - e2u - UT{1})) | msb;
+                }
+                return static_cast<T>((e1u >> e2u) | sign_ext);
+            };
+
+            T result = 0;
+            if constexpr (IsAbstract<NumberT>) {
+                if (static_cast<size_t>(e2) >= bit_width) {
+                    result = T{0};
+                } else {
+                    result = signed_shift_right();
+                }
+            } else {
+                if (static_cast<size_t>(e2) >= bit_width) {
+                    // At shader/pipeline-creation time, it is an error to shift by the bit width of
+                    // the lhs or greater. NOTE: At runtime, we shift by e2 % (bit width of e1).
+                    AddError(
+                        "shift right value must be less than the bit width of the lhs, which is " +
+                            std::to_string(bit_width),
+                        source);
+                    return utils::Failure;
+                }
+
+                if constexpr (std::is_signed_v<T>) {
+                    result = signed_shift_right();
+                } else {
+                    result = e1 >> e2;
+                }
+            }
+            return CreateElement(builder, source, sem::Type::DeepestElementOf(ty), NumberT{result});
+        };
+        return Dispatch_ia_iu32(create, c0, c1);
+    };
+
+    if (!sem::Type::DeepestElementOf(args[1]->Type())->Is<sem::U32>()) {
+        TINT_ICE(Resolver, builder.Diagnostics())
+            << "Element type of rhs of ShiftLeft must be a u32";
+        return utils::Failure;
+    }
+
+    return TransformElements(builder, ty, transform, args[0], args[1]);
+}
+
 ConstEval::Result ConstEval::abs(const sem::Type* ty,
                                  utils::VectorRef<const sem::Constant*> args,
                                  const Source& source) {
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index 9905036..3e6c626 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -382,6 +382,15 @@
                        utils::VectorRef<const sem::Constant*> args,
                        const Source& source);
 
+    /// Bitwise shift right operator '<<'
+    /// @param ty the expression type
+    /// @param args the input arguments
+    /// @param source the source location
+    /// @return the result value, or null if the value cannot be calculated
+    Result OpShiftRight(const sem::Type* ty,
+                        utils::VectorRef<const sem::Constant*> args,
+                        const Source& source);
+
     ////////////////////////////////////////////////////////////////////////////
     // Builtins
     ////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 2590466..a370285 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -917,8 +917,7 @@
 
 template <typename T>
 std::vector<Case> ShiftLeftCases() {
-    // Shift type is u32 for non-abstract
-    using ST = std::conditional_t<IsAbstract<T>, T, u32>;
+    using ST = u32;  // Shift type is u32
     using B = BitValues<T>;
     auto r = std::vector<Case>{
         C(T{0b1010}, ST{0}, T{0b0000'0000'1010}),  //
@@ -1200,5 +1199,144 @@
                              ShiftLeftSignChangeErrorCases<AInt>(),
                              ShiftLeftSignChangeErrorCases<i32>())));
 
+template <typename T>
+std::vector<Case> ShiftRightCases() {
+    using B = BitValues<T>;
+    auto r = std::vector<Case>{
+        C(T{0b10101100}, u32{0}, T{0b10101100}),  //
+        C(T{0b10101100}, u32{1}, T{0b01010110}),  //
+        C(T{0b10101100}, u32{2}, T{0b00101011}),  //
+        C(T{0b10101100}, u32{3}, T{0b00010101}),  //
+        C(T{0b10101100}, u32{4}, T{0b00001010}),  //
+        C(T{0b10101100}, u32{5}, T{0b00000101}),  //
+        C(T{0b10101100}, u32{6}, T{0b00000010}),  //
+        C(T{0b10101100}, u32{7}, T{0b00000001}),  //
+        C(T{0b10101100}, u32{8}, T{0b00000000}),  //
+        C(T{0b10101100}, u32{9}, T{0b00000000}),  //
+        C(B::LeftMost, u32{0}, B::LeftMost),      //
+    };
+
+    // msb not set, same for all types: inserted bit is 0
+    ConcatInto(  //
+        r, std::vector<Case>{
+               C(T{0b01000000000000000000000010101100}, u32{0},  //
+                 T{0b01000000000000000000000010101100}),
+               C(T{0b01000000000000000000000010101100}, u32{1},  //
+                 T{0b00100000000000000000000001010110}),
+               C(T{0b01000000000000000000000010101100}, u32{2},  //
+                 T{0b00010000000000000000000000101011}),
+               C(T{0b01000000000000000000000010101100}, u32{3},  //
+                 T{0b00001000000000000000000000010101}),
+               C(T{0b01000000000000000000000010101100}, u32{4},  //
+                 T{0b00000100000000000000000000001010}),
+               C(T{0b01000000000000000000000010101100}, u32{5},  //
+                 T{0b00000010000000000000000000000101}),
+               C(T{0b01000000000000000000000010101100}, u32{6},  //
+                 T{0b00000001000000000000000000000010}),
+               C(T{0b01000000000000000000000010101100}, u32{7},  //
+                 T{0b00000000100000000000000000000001}),
+               C(T{0b01000000000000000000000010101100}, u32{8},  //
+                 T{0b00000000010000000000000000000000}),
+               C(T{0b01000000000000000000000010101100}, u32{9},  //
+                 T{0b00000000001000000000000000000000}),
+           });
+
+    // msb set, result differs for i32 and u32
+    if constexpr (std::is_same_v<T, u32>) {
+        // If unsigned, insert zero bits at the most significant positions.
+        ConcatInto(  //
+            r, std::vector<Case>{
+                   C(T{0b10000000000000000000000010101100}, u32{0},
+                     T{0b10000000000000000000000010101100}),
+                   C(T{0b10000000000000000000000010101100}, u32{1},
+                     T{0b01000000000000000000000001010110}),
+                   C(T{0b10000000000000000000000010101100}, u32{2},
+                     T{0b00100000000000000000000000101011}),
+                   C(T{0b10000000000000000000000010101100}, u32{3},
+                     T{0b00010000000000000000000000010101}),
+                   C(T{0b10000000000000000000000010101100}, u32{4},
+                     T{0b00001000000000000000000000001010}),
+                   C(T{0b10000000000000000000000010101100}, u32{5},
+                     T{0b00000100000000000000000000000101}),
+                   C(T{0b10000000000000000000000010101100}, u32{6},
+                     T{0b00000010000000000000000000000010}),
+                   C(T{0b10000000000000000000000010101100}, u32{7},
+                     T{0b00000001000000000000000000000001}),
+                   C(T{0b10000000000000000000000010101100}, u32{8},
+                     T{0b00000000100000000000000000000000}),
+                   C(T{0b10000000000000000000000010101100}, u32{9},
+                     T{0b00000000010000000000000000000000}),
+                   // msb shifted by bit width - 1
+                   C(T{0b10000000000000000000000000000000}, u32{31},
+                     T{0b00000000000000000000000000000001}),
+               });
+    } else if constexpr (std::is_same_v<T, i32>) {
+        // If signed, each inserted bit is 1, so the result is negative.
+        ConcatInto(  //
+            r, std::vector<Case>{
+                   C(T{0b10000000000000000000000010101100}, u32{0},
+                     T{0b10000000000000000000000010101100}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{1},
+                     T{0b11000000000000000000000001010110}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{2},
+                     T{0b11100000000000000000000000101011}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{3},
+                     T{0b11110000000000000000000000010101}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{4},
+                     T{0b11111000000000000000000000001010}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{5},
+                     T{0b11111100000000000000000000000101}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{6},
+                     T{0b11111110000000000000000000000010}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{7},
+                     T{0b11111111000000000000000000000001}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{8},
+                     T{0b11111111100000000000000000000000}),  //
+                   C(T{0b10000000000000000000000010101100}, u32{9},
+                     T{0b11111111110000000000000000000000}),  //
+                   // msb shifted by bit width - 1
+                   C(T{0b10000000000000000000000000000000}, u32{31},
+                     T{0b11111111111111111111111111111111}),
+               });
+    }
+
+    // Test shift right by bit width or more
+    if constexpr (IsAbstract<T>) {
+        // For abstract int, no error, result is 0
+        ConcatInto(  //
+            r, std::vector<Case>{
+                   C(T{0}, u32{B::NumBits}, T{0}),
+                   C(T{0}, u32{B::NumBits + 1}, T{0}),
+                   C(T{0}, u32{B::NumBits + 1000}, T{0}),
+                   C(T{42}, u32{B::NumBits}, T{0}),
+                   C(T{42}, u32{B::NumBits + 1}, T{0}),
+                   C(T{42}, u32{B::NumBits + 1000}, T{0}),
+               });
+    } else {
+        // For concretes, error
+        const char* error_msg =
+            "12:34 error: shift right value must be less than the bit width of the lhs, which is "
+            "32";
+        ConcatInto(  //
+            r, std::vector<Case>{
+                   E(T{0}, u32{B::NumBits}, error_msg),
+                   E(T{0}, u32{B::NumBits + 1}, error_msg),
+                   E(T{0}, u32{B::NumBits + 1000}, error_msg),
+                   E(T{42}, u32{B::NumBits}, error_msg),
+                   E(T{42}, u32{B::NumBits + 1}, error_msg),
+                   E(T{42}, u32{B::NumBits + 1000}, error_msg),
+               });
+    }
+
+    return r;
+}
+INSTANTIATE_TEST_SUITE_P(ShiftRight,
+                         ResolverConstEvalBinaryOpTest,
+                         testing::Combine(  //
+                             testing::Values(ast::BinaryOp::kShiftRight),
+                             testing::ValuesIn(Concat(ShiftRightCases<AInt>(),  //
+                                                      ShiftRightCases<i32>(),   //
+                                                      ShiftRightCases<u32>()))));
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 922a1dc..ef40bff 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -890,10 +890,11 @@
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
     // @compute @workgroup_size(1 << 2 + 4)
     // fn main() {}
+    GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
     Func("main", utils::Empty, ty.void_(), utils::Empty,
          utils::Vector{
              Stage(ast::PipelineStage::kCompute),
-             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
          });
 
     EXPECT_FALSE(r()->Resolve());
@@ -905,10 +906,11 @@
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
     // @compute @workgroup_size(1, 1 << 2 + 4)
     // fn main() {}
+    GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
     Func("main", utils::Empty, ty.void_(), utils::Empty,
          utils::Vector{
              Stage(ast::PipelineStage::kCompute),
-             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
          });
 
     EXPECT_FALSE(r()->Resolve());
@@ -920,10 +922,11 @@
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
     // @compute @workgroup_size(1, 1, 1 << 2 + 4)
     // fn main() {}
+    GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
     Func("main", utils::Empty, ty.void_(), utils::Empty,
          utils::Vector{
              Stage(ast::PipelineStage::kCompute),
-             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
          });
 
     EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl
index 00b7557..3fedcd0 100644
--- a/src/tint/resolver/intrinsic_table.inl
+++ b/src/tint/resolver/intrinsic_table.inl
@@ -13469,24 +13469,24 @@
     /* num parameters */ 2,
     /* num template types */ 1,
     /* num template numbers */ 0,
-    /* template types */ &kTemplateTypes[25],
+    /* template types */ &kTemplateTypes[28],
     /* template numbers */ &kTemplateNumbers[10],
     /* parameters */ &kParameters[778],
     /* return matcher indices */ &kMatcherIndices[3],
     /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
-    /* const eval */ nullptr,
+    /* const eval */ &ConstEval::OpShiftRight,
   },
   {
     /* [432] */
     /* num parameters */ 2,
     /* num template types */ 1,
     /* num template numbers */ 1,
-    /* template types */ &kTemplateTypes[25],
+    /* template types */ &kTemplateTypes[28],
     /* template numbers */ &kTemplateNumbers[4],
     /* parameters */ &kParameters[780],
     /* return matcher indices */ &kMatcherIndices[30],
     /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
-    /* const eval */ nullptr,
+    /* const eval */ &ConstEval::OpShiftRight,
   },
   {
     /* [433] */
@@ -14975,8 +14975,8 @@
   },
   {
     /* [17] */
-    /* op >><T : iu32>(T, u32) -> T */
-    /* op >><T : iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
+    /* op >><T : ia_iu32>(T, u32) -> T */
+    /* op >><T : ia_iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
     /* num overloads */ 2,
     /* overloads */ &kOverloads[431],
   },