spirv-reader: fix signedness for shifts

Fixed: tint:675
Change-Id: Ib754191284a62b9f4be56dc8d38e5319345ab9bf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49824
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Alan Baker <alanbaker@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index b4151aa..dd3974d 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -210,11 +210,6 @@
     case SpvOpSMod:
     case SpvOpFMod:
       return ast::BinaryOp::kModulo;
-    case SpvOpShiftLeftLogical:
-      return ast::BinaryOp::kShiftLeft;
-    case SpvOpShiftRightLogical:
-    case SpvOpShiftRightArithmetic:
-      return ast::BinaryOp::kShiftRight;
     case SpvOpLogicalEqual:
     case SpvOpIEqual:
     case SpvOpFOrdEqual:
@@ -3345,6 +3340,34 @@
                           Source{}, ast_type, MakeOperand(inst, 0).expr)};
   }
 
+  if (opcode == SpvOpShiftLeftLogical || opcode == SpvOpShiftRightLogical ||
+      opcode == SpvOpShiftRightArithmetic) {
+    auto arg0 = MakeOperand(inst, 0);
+    // The second operand must be unsigned. It's ok to wrap the shift amount
+    // since the shift is modulo the bit width of the first operand.
+    auto arg1 = parser_impl_.AsUnsigned(MakeOperand(inst, 1));
+
+    switch (opcode) {
+      case SpvOpShiftLeftLogical:
+        binary_op = ast::BinaryOp::kShiftLeft;
+        break;
+      case SpvOpShiftRightLogical:
+        arg0 = parser_impl_.AsUnsigned(arg0);
+        binary_op = ast::BinaryOp::kShiftRight;
+        break;
+      case SpvOpShiftRightArithmetic:
+        arg0 = parser_impl_.AsSigned(arg0);
+        binary_op = ast::BinaryOp::kShiftRight;
+        break;
+      default:
+        break;
+    }
+    TypedExpression result{
+        ast_type, create<ast::BinaryExpression>(Source{}, binary_op, arg0.expr,
+                                                arg1.expr)};
+    return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
+  }
+
   auto negated_op = NegatedFloatCompare(opcode);
   if (negated_op != ast::BinaryOp::kNone) {
     auto arg0 = MakeOperand(inst, 0);
diff --git a/src/reader/spirv/function_bit_test.cc b/src/reader/spirv/function_bit_test.cc
index cb3abe2..4632c19 100644
--- a/src/reader/spirv/function_bit_test.cc
+++ b/src/reader/spirv/function_bit_test.cc
@@ -203,7 +203,7 @@
       << p->error() << "\n"
       << assembly;
   auto fe = p->function_emitter(100);
-  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_TRUE(fe.EmitBody()) << p->error() << assembly;
   std::ostringstream ss;
   ss << R"(VariableConst{
     x_1
@@ -215,115 +215,536 @@
 }
 
 INSTANTIATE_TEST_SUITE_P(
-    SpvParserTest_ShiftLeftLogical,
+    SpvParserTest_ShiftLeftLogical_Arg2Unsigned,
     SpvBinaryBitTest,
     ::testing::Values(
-        // Both uint
+        // uint uint -> uint
         BinaryData{"uint", "uint_10", "OpShiftLeftLogical", "uint_20", "__u32",
                    "ScalarConstructor[not set]{10u}", "shift_left",
                    "ScalarConstructor[not set]{20u}"},
-        // Both int
-        BinaryData{"int", "int_30", "OpShiftLeftLogical", "int_40", "__i32",
+        // int, uint -> int
+        BinaryData{"int", "int_30", "OpShiftLeftLogical", "uint_20", "__i32",
                    "ScalarConstructor[not set]{30}", "shift_left",
-                   "ScalarConstructor[not set]{40}"},
-        // Mixed, returning uint
-        BinaryData{"uint", "int_30", "OpShiftLeftLogical", "uint_10", "__u32",
-                   "ScalarConstructor[not set]{30}", "shift_left",
-                   "ScalarConstructor[not set]{10u}"},
-        // Mixed, returning int
-        BinaryData{"int", "int_30", "OpShiftLeftLogical", "uint_10", "__i32",
-                   "ScalarConstructor[not set]{30}", "shift_left",
-                   "ScalarConstructor[not set]{10u}"},
-        // Both v2uint
+                   "ScalarConstructor[not set]{20u}"},
+        // v2uint v2uint -> v2uint
         BinaryData{"v2uint", "v2uint_10_20", "OpShiftLeftLogical",
                    "v2uint_20_10", "__vec_2__u32", AstFor("v2uint_10_20"),
                    "shift_left", AstFor("v2uint_20_10")},
-        // Both v2int
-        BinaryData{"v2int", "v2int_30_40", "OpShiftLeftLogical", "v2int_40_30",
+        // v2int, v2uint -> v2int
+        BinaryData{"v2int", "v2int_30_40", "OpShiftLeftLogical", "v2uint_20_10",
                    "__vec_2__i32", AstFor("v2int_30_40"), "shift_left",
-                   AstFor("v2int_40_30")},
-        // Mixed, returning v2uint
-        BinaryData{"v2uint", "v2int_30_40", "OpShiftLeftLogical",
-                   "v2uint_10_20", "__vec_2__u32", AstFor("v2int_30_40"),
-                   "shift_left", AstFor("v2uint_10_20")},
-        // Mixed, returning v2int
-        BinaryData{"v2int", "v2int_40_30", "OpShiftLeftLogical", "v2uint_20_10",
-                   "__vec_2__i32", AstFor("v2int_40_30"), "shift_left",
                    AstFor("v2uint_20_10")}));
 
 INSTANTIATE_TEST_SUITE_P(
-    SpvParserTest_ShiftRightLogical,
-    SpvBinaryBitTest,
+    // WGSL requires second operand to be unsigned, so insert bitcasts
+    SpvParserTest_ShiftLeftLogical_Arg2Signed,
+    SpvBinaryBitGeneralTest,
     ::testing::Values(
-        // Both uint
-        BinaryData{"uint", "uint_10", "OpShiftRightLogical", "uint_20", "__u32",
-                   "ScalarConstructor[not set]{10u}", "shift_right",
-                   "ScalarConstructor[not set]{20u}"},
-        // Both int
-        BinaryData{"int", "int_30", "OpShiftRightLogical", "int_40", "__i32",
-                   "ScalarConstructor[not set]{30}", "shift_right",
-                   "ScalarConstructor[not set]{40}"},
-        // Mixed, returning uint
-        BinaryData{"uint", "int_30", "OpShiftRightLogical", "uint_10", "__u32",
-                   "ScalarConstructor[not set]{30}", "shift_right",
-                   "ScalarConstructor[not set]{10u}"},
-        // Mixed, returning int
-        BinaryData{"int", "int_30", "OpShiftRightLogical", "uint_10", "__i32",
-                   "ScalarConstructor[not set]{30}", "shift_right",
-                   "ScalarConstructor[not set]{10u}"},
-        // Both v2uint
-        BinaryData{"v2uint", "v2uint_10_20", "OpShiftRightLogical",
-                   "v2uint_20_10", "__vec_2__u32", AstFor("v2uint_10_20"),
-                   "shift_right", AstFor("v2uint_20_10")},
-        // Both v2int
-        BinaryData{"v2int", "v2int_30_40", "OpShiftRightLogical", "v2int_40_30",
-                   "__vec_2__i32", AstFor("v2int_30_40"), "shift_right",
-                   AstFor("v2int_40_30")},
-        // Mixed, returning v2uint
-        BinaryData{"v2uint", "v2int_30_40", "OpShiftRightLogical",
-                   "v2uint_10_20", "__vec_2__u32", AstFor("v2int_30_40"),
-                   "shift_right", AstFor("v2uint_10_20")},
-        // Mixed, returning v2int
-        BinaryData{"v2int", "v2int_40_30", "OpShiftRightLogical",
-                   "v2uint_20_10", "__vec_2__i32", AstFor("v2int_40_30"),
-                   "shift_right", AstFor("v2uint_20_10")}));
+        // int, int -> int
+        BinaryDataGeneral{"int", "int_30", "OpShiftLeftLogical", "int_40",
+                          R"(__i32
+    {
+      Binary[not set]{
+        ScalarConstructor[not set]{30}
+        shift_left
+        Bitcast[not set]<__u32>{
+          ScalarConstructor[not set]{40}
+        }
+      }
+    }
+)"},
+        // uint, int -> uint
+        BinaryDataGeneral{"uint", "uint_10", "OpShiftLeftLogical", "int_40",
+                          R"(__u32
+    {
+      Binary[not set]{
+        ScalarConstructor[not set]{10u}
+        shift_left
+        Bitcast[not set]<__u32>{
+          ScalarConstructor[not set]{40}
+        }
+      }
+    }
+)"},
+        // v2uint, v2int -> v2uint
+        BinaryDataGeneral{"v2uint", "v2uint_10_20", "OpShiftLeftLogical",
+                          "v2uint_20_10",
+                          R"(__vec_2__u32
+    {
+      Binary[not set]{
+        TypeConstructor[not set]{
+          __vec_2__u32
+          ScalarConstructor[not set]{10u}
+          ScalarConstructor[not set]{20u}
+        }
+        shift_left
+        TypeConstructor[not set]{
+          __vec_2__u32
+          ScalarConstructor[not set]{20u}
+          ScalarConstructor[not set]{10u}
+        }
+      }
+    }
+)"},
+        // v2int, v2int -> v2int
+        BinaryDataGeneral{"v2int", "v2int_30_40", "OpShiftLeftLogical",
+                          "v2int_40_30",
+                          R"(__vec_2__i32
+    {
+      Binary[not set]{
+        TypeConstructor[not set]{
+          __vec_2__i32
+          ScalarConstructor[not set]{30}
+          ScalarConstructor[not set]{40}
+        }
+        shift_left
+        Bitcast[not set]<__vec_2__u32>{
+          TypeConstructor[not set]{
+            __vec_2__i32
+            ScalarConstructor[not set]{40}
+            ScalarConstructor[not set]{30}
+          }
+        }
+      }
+    }
+)"}));
 
 INSTANTIATE_TEST_SUITE_P(
-    SpvParserTest_ShiftRightArithmetic,
-    SpvBinaryBitTest,
+    SpvParserTest_ShiftLeftLogical_BitcastResult,
+    SpvBinaryBitGeneralTest,
     ::testing::Values(
-        // Both uint
-        BinaryData{"uint", "uint_10", "OpShiftRightArithmetic", "uint_20",
-                   "__u32", "ScalarConstructor[not set]{10u}", "shift_right",
-                   "ScalarConstructor[not set]{20u}"},
-        // Both int
-        BinaryData{"int", "int_30", "OpShiftRightArithmetic", "int_40", "__i32",
-                   "ScalarConstructor[not set]{30}", "shift_right",
-                   "ScalarConstructor[not set]{40}"},
-        // Mixed, returning uint
-        BinaryData{"uint", "int_30", "OpShiftRightArithmetic", "uint_10",
-                   "__u32", "ScalarConstructor[not set]{30}", "shift_right",
-                   "ScalarConstructor[not set]{10u}"},
-        // Mixed, returning int
-        BinaryData{"int", "int_30", "OpShiftRightArithmetic", "uint_10",
-                   "__i32", "ScalarConstructor[not set]{30}", "shift_right",
-                   "ScalarConstructor[not set]{10u}"},
-        // Both v2uint
-        BinaryData{"v2uint", "v2uint_10_20", "OpShiftRightArithmetic",
-                   "v2uint_20_10", "__vec_2__u32", AstFor("v2uint_10_20"),
-                   "shift_right", AstFor("v2uint_20_10")},
-        // Both v2int
-        BinaryData{"v2int", "v2int_30_40", "OpShiftRightArithmetic",
-                   "v2int_40_30", "__vec_2__i32", AstFor("v2int_30_40"),
-                   "shift_right", AstFor("v2int_40_30")},
-        // Mixed, returning v2uint
-        BinaryData{"v2uint", "v2int_30_40", "OpShiftRightArithmetic",
-                   "v2uint_10_20", "__vec_2__u32", AstFor("v2int_30_40"),
-                   "shift_right", AstFor("v2uint_10_20")},
-        // Mixed, returning v2int
-        BinaryData{"v2int", "v2int_40_30", "OpShiftRightArithmetic",
-                   "v2uint_20_10", "__vec_2__i32", AstFor("v2int_40_30"),
-                   "shift_right", AstFor("v2uint_20_10")}));
+        // int, int -> uint
+        BinaryDataGeneral{"uint", "int_30", "OpShiftLeftLogical", "uint_10",
+                          R"(__u32
+    {
+      Bitcast[not set]<__u32>{
+        Binary[not set]{
+          ScalarConstructor[not set]{30}
+          shift_left
+          ScalarConstructor[not set]{10u}
+        }
+      }
+    }
+)"},
+        // v2uint, v2int -> v2uint
+        BinaryDataGeneral{"v2uint", "v2int_30_40", "OpShiftLeftLogical",
+                          "v2uint_20_10",
+                          R"(__vec_2__u32
+    {
+      Bitcast[not set]<__vec_2__u32>{
+        Binary[not set]{
+          TypeConstructor[not set]{
+            __vec_2__i32
+            ScalarConstructor[not set]{30}
+            ScalarConstructor[not set]{40}
+          }
+          shift_left
+          TypeConstructor[not set]{
+            __vec_2__u32
+            ScalarConstructor[not set]{20u}
+            ScalarConstructor[not set]{10u}
+          }
+        }
+      }
+    }
+)"}));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpvParserTest_ShiftRightLogical_Arg2Unsigned,
+    SpvBinaryBitGeneralTest,
+    ::testing::Values(
+        // uint, uint -> uint
+        BinaryDataGeneral{"uint", "uint_10", "OpShiftRightLogical", "uint_20",
+                          R"(__u32
+    {
+      Binary[not set]{
+        ScalarConstructor[not set]{10u}
+        shift_right
+        ScalarConstructor[not set]{20u}
+      }
+    }
+)"},
+        // int, uint -> int
+        BinaryDataGeneral{"int", "int_30", "OpShiftRightLogical", "uint_20",
+                          R"(__i32
+    {
+      Bitcast[not set]<__i32>{
+        Binary[not set]{
+          Bitcast[not set]<__u32>{
+            ScalarConstructor[not set]{30}
+          }
+          shift_right
+          ScalarConstructor[not set]{20u}
+        }
+      }
+    }
+)"},
+        // v2uint, v2uint -> v2uint
+        BinaryDataGeneral{"v2uint", "v2uint_10_20", "OpShiftRightLogical",
+                          "v2uint_20_10",
+                          R"(__vec_2__u32
+    {
+      Binary[not set]{
+        TypeConstructor[not set]{
+          __vec_2__u32
+          ScalarConstructor[not set]{10u}
+          ScalarConstructor[not set]{20u}
+        }
+        shift_right
+        TypeConstructor[not set]{
+          __vec_2__u32
+          ScalarConstructor[not set]{20u}
+          ScalarConstructor[not set]{10u}
+        }
+      }
+    }
+)"},
+        // v2int, v2uint -> v2int
+        BinaryDataGeneral{"v2int", "v2int_30_40", "OpShiftRightLogical",
+                          "v2uint_10_20",
+                          R"(__vec_2__i32
+    {
+      Bitcast[not set]<__vec_2__i32>{
+        Binary[not set]{
+          Bitcast[not set]<__vec_2__u32>{
+            TypeConstructor[not set]{
+              __vec_2__i32
+              ScalarConstructor[not set]{30}
+              ScalarConstructor[not set]{40}
+            }
+          }
+          shift_right
+          TypeConstructor[not set]{
+            __vec_2__u32
+            ScalarConstructor[not set]{10u}
+            ScalarConstructor[not set]{20u}
+          }
+        }
+      }
+    }
+)"}));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpvParserTest_ShiftRightLogical_Arg2Signed,
+    SpvBinaryBitGeneralTest,
+    ::testing::Values(
+        // uint, int -> uint
+        BinaryDataGeneral{"uint", "uint_10", "OpShiftRightLogical", "int_30",
+                          R"(__u32
+    {
+      Binary[not set]{
+        ScalarConstructor[not set]{10u}
+        shift_right
+        Bitcast[not set]<__u32>{
+          ScalarConstructor[not set]{30}
+        }
+      }
+    }
+)"},
+        // int, int -> int
+        BinaryDataGeneral{"int", "int_30", "OpShiftRightLogical", "int_40",
+                          R"(__i32
+    {
+      Bitcast[not set]<__i32>{
+        Binary[not set]{
+          Bitcast[not set]<__u32>{
+            ScalarConstructor[not set]{30}
+          }
+          shift_right
+          Bitcast[not set]<__u32>{
+            ScalarConstructor[not set]{40}
+          }
+        }
+      }
+    }
+)"},
+        // v2uint, v2int -> v2uint
+        BinaryDataGeneral{"v2uint", "v2uint_10_20", "OpShiftRightLogical",
+                          "v2int_30_40",
+                          R"(__vec_2__u32
+    {
+      Binary[not set]{
+        TypeConstructor[not set]{
+          __vec_2__u32
+          ScalarConstructor[not set]{10u}
+          ScalarConstructor[not set]{20u}
+        }
+        shift_right
+        Bitcast[not set]<__vec_2__u32>{
+          TypeConstructor[not set]{
+            __vec_2__i32
+            ScalarConstructor[not set]{30}
+            ScalarConstructor[not set]{40}
+          }
+        }
+      }
+    }
+)"},
+        // v2int, v2int -> v2int
+        BinaryDataGeneral{"v2int", "v2int_40_30", "OpShiftRightLogical",
+                          "v2int_30_40",
+                          R"(__vec_2__i32
+    {
+      Bitcast[not set]<__vec_2__i32>{
+        Binary[not set]{
+          Bitcast[not set]<__vec_2__u32>{
+            TypeConstructor[not set]{
+              __vec_2__i32
+              ScalarConstructor[not set]{40}
+              ScalarConstructor[not set]{30}
+            }
+          }
+          shift_right
+          Bitcast[not set]<__vec_2__u32>{
+            TypeConstructor[not set]{
+              __vec_2__i32
+              ScalarConstructor[not set]{30}
+              ScalarConstructor[not set]{40}
+            }
+          }
+        }
+      }
+    }
+)"}));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpvParserTest_ShiftRightLogical_BitcastResult,
+    SpvBinaryBitGeneralTest,
+    ::testing::Values(
+        // uint, uint -> int
+        BinaryDataGeneral{"int", "uint_20", "OpShiftRightLogical", "uint_10",
+                          R"(__i32
+    {
+      Bitcast[not set]<__i32>{
+        Binary[not set]{
+          ScalarConstructor[not set]{20u}
+          shift_right
+          ScalarConstructor[not set]{10u}
+        }
+      }
+    }
+)"},
+        // v2uint, v2uint -> v2int
+        BinaryDataGeneral{"v2int", "v2uint_10_20", "OpShiftRightLogical",
+                          "v2uint_20_10",
+                          R"(__vec_2__i32
+    {
+      Bitcast[not set]<__vec_2__i32>{
+        Binary[not set]{
+          TypeConstructor[not set]{
+            __vec_2__u32
+            ScalarConstructor[not set]{10u}
+            ScalarConstructor[not set]{20u}
+          }
+          shift_right
+          TypeConstructor[not set]{
+            __vec_2__u32
+            ScalarConstructor[not set]{20u}
+            ScalarConstructor[not set]{10u}
+          }
+        }
+      }
+    }
+)"}));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpvParserTest_ShiftRightArithmetic_Arg2Unsigned,
+    SpvBinaryBitGeneralTest,
+    ::testing::Values(
+        // uint, uint -> uint
+        BinaryDataGeneral{"uint", "uint_10", "OpShiftRightArithmetic",
+                          "uint_20",
+                          R"(__u32
+    {
+      Bitcast[not set]<__u32>{
+        Binary[not set]{
+          Bitcast[not set]<__i32>{
+            ScalarConstructor[not set]{10u}
+          }
+          shift_right
+          ScalarConstructor[not set]{20u}
+        }
+      }
+    }
+)"},
+        // int, uint -> int
+        BinaryDataGeneral{"int", "int_30", "OpShiftRightArithmetic", "uint_10",
+                          R"(__i32
+    {
+      Binary[not set]{
+        ScalarConstructor[not set]{30}
+        shift_right
+        ScalarConstructor[not set]{10u}
+      }
+    }
+)"},
+        // v2uint, v2uint -> v2uint
+        BinaryDataGeneral{"v2uint", "v2uint_10_20", "OpShiftRightArithmetic",
+                          "v2uint_20_10",
+                          R"(__vec_2__u32
+    {
+      Bitcast[not set]<__vec_2__u32>{
+        Binary[not set]{
+          Bitcast[not set]<__vec_2__i32>{
+            TypeConstructor[not set]{
+              __vec_2__u32
+              ScalarConstructor[not set]{10u}
+              ScalarConstructor[not set]{20u}
+            }
+          }
+          shift_right
+          TypeConstructor[not set]{
+            __vec_2__u32
+            ScalarConstructor[not set]{20u}
+            ScalarConstructor[not set]{10u}
+          }
+        }
+      }
+    }
+)"},
+        // v2int, v2uint -> v2int
+        BinaryDataGeneral{"v2int", "v2int_40_30", "OpShiftRightArithmetic",
+                          "v2uint_20_10",
+                          R"(__vec_2__i32
+    {
+      Binary[not set]{
+        TypeConstructor[not set]{
+          __vec_2__i32
+          ScalarConstructor[not set]{40}
+          ScalarConstructor[not set]{30}
+        }
+        shift_right
+        TypeConstructor[not set]{
+          __vec_2__u32
+          ScalarConstructor[not set]{20u}
+          ScalarConstructor[not set]{10u}
+        }
+      }
+    }
+)"}));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpvParserTest_ShiftRightArithmetic_Arg2Signed,
+    SpvBinaryBitGeneralTest,
+    ::testing::Values(
+        // uint, int -> uint
+        BinaryDataGeneral{"uint", "uint_10", "OpShiftRightArithmetic", "int_30",
+                          R"(__u32
+    {
+      Bitcast[not set]<__u32>{
+        Binary[not set]{
+          Bitcast[not set]<__i32>{
+            ScalarConstructor[not set]{10u}
+          }
+          shift_right
+          Bitcast[not set]<__u32>{
+            ScalarConstructor[not set]{30}
+          }
+        }
+      }
+    }
+)"},
+        // int, int -> int
+        BinaryDataGeneral{"int", "int_30", "OpShiftRightArithmetic", "int_40",
+                          R"(__i32
+    {
+      Binary[not set]{
+        ScalarConstructor[not set]{30}
+        shift_right
+        Bitcast[not set]<__u32>{
+          ScalarConstructor[not set]{40}
+        }
+      }
+    }
+)"},
+        // v2uint, v2int -> v2uint
+        BinaryDataGeneral{"v2uint", "v2uint_10_20", "OpShiftRightArithmetic",
+                          "v2int_30_40",
+                          R"(__vec_2__u32
+    {
+      Bitcast[not set]<__vec_2__u32>{
+        Binary[not set]{
+          Bitcast[not set]<__vec_2__i32>{
+            TypeConstructor[not set]{
+              __vec_2__u32
+              ScalarConstructor[not set]{10u}
+              ScalarConstructor[not set]{20u}
+            }
+          }
+          shift_right
+          Bitcast[not set]<__vec_2__u32>{
+            TypeConstructor[not set]{
+              __vec_2__i32
+              ScalarConstructor[not set]{30}
+              ScalarConstructor[not set]{40}
+            }
+          }
+        }
+      }
+    }
+)"},
+        // v2int, v2int -> v2int
+        BinaryDataGeneral{"v2int", "v2int_40_30", "OpShiftRightArithmetic",
+                          "v2int_30_40",
+                          R"(__vec_2__i32
+    {
+      Binary[not set]{
+        TypeConstructor[not set]{
+          __vec_2__i32
+          ScalarConstructor[not set]{40}
+          ScalarConstructor[not set]{30}
+        }
+        shift_right
+        Bitcast[not set]<__vec_2__u32>{
+          TypeConstructor[not set]{
+            __vec_2__i32
+            ScalarConstructor[not set]{30}
+            ScalarConstructor[not set]{40}
+          }
+        }
+      }
+    }
+)"}));
+
+INSTANTIATE_TEST_SUITE_P(
+    SpvParserTest_ShiftRightArithmetic_BitcastResult,
+    SpvBinaryBitGeneralTest,
+    ::testing::Values(
+        // int, uint -> uint
+        BinaryDataGeneral{"uint", "int_30", "OpShiftRightArithmetic", "uint_10",
+                          R"(__u32
+    {
+      Bitcast[not set]<__u32>{
+        Binary[not set]{
+          ScalarConstructor[not set]{30}
+          shift_right
+          ScalarConstructor[not set]{10u}
+        }
+      }
+    }
+)"},
+        // v2int, v2uint -> v2uint
+        BinaryDataGeneral{"v2uint", "v2int_30_40", "OpShiftRightArithmetic",
+                          "v2uint_20_10",
+                          R"(__vec_2__u32
+    {
+      Bitcast[not set]<__vec_2__u32>{
+        Binary[not set]{
+          TypeConstructor[not set]{
+            __vec_2__i32
+            ScalarConstructor[not set]{30}
+            ScalarConstructor[not set]{40}
+          }
+          shift_right
+          TypeConstructor[not set]{
+            __vec_2__u32
+            ScalarConstructor[not set]{20u}
+            ScalarConstructor[not set]{10u}
+          }
+        }
+      }
+    }
+)"}));
 
 INSTANTIATE_TEST_SUITE_P(
     SpvParserTest_BitwiseAnd,
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 2077927..78eb256 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -199,6 +199,9 @@
     case SpvOpBitwiseAnd:
     case SpvOpBitwiseOr:
     case SpvOpBitwiseXor:
+    case SpvOpShiftLeftLogical:
+    case SpvOpShiftRightLogical:
+    case SpvOpShiftRightArithmetic:
       return true;
     default:
       break;
@@ -1771,6 +1774,24 @@
           create<ast::BitcastExpression>(Source{}, expr.type, expr.expr)};
 }
 
+TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) {
+  if (expr.type && expr.type->is_signed_scalar_or_vector()) {
+    auto new_type = GetUnsignedIntMatchingShape(expr.type);
+    return {new_type,
+            create<ast::BitcastExpression>(Source{}, new_type, expr.expr)};
+  }
+  return expr;
+}
+
+TypedExpression ParserImpl::AsSigned(TypedExpression expr) {
+  if (expr.type && expr.type->is_unsigned_scalar_or_vector()) {
+    auto new_type = GetSignedIntMatchingShape(expr.type);
+    return {new_type,
+            create<ast::BitcastExpression>(Source{}, new_type, expr.expr)};
+  }
+  return expr;
+}
+
 bool ParserImpl::EmitFunctions() {
   if (!success_) {
     return false;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index be6ea4b..253b202 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -354,7 +354,7 @@
       const spvtools::opt::Instruction& inst,
       TypedExpression&& expr);
 
-  /// Conversts a second operand to the signedness of the first operand
+  /// Converts a second operand to the signedness of the first operand
   /// of a binary operator, if the WGSL operator requires they be the same.
   /// Returns the converted expression, or the original expression if the
   /// conversion is not needed.
@@ -407,6 +407,18 @@
       const spvtools::opt::Instruction& inst,
       typ::Type first_operand_type);
 
+  /// Returns the given expression, but ensuring it's an unsigned type of the
+  /// same shape as the operand. Wraps the expresison with a bitcast if needed.
+  /// Assumes the given expresion is a integer scalar or vector.
+  /// @param expr an integer scalar or integer vector expression.
+  TypedExpression AsUnsigned(TypedExpression expr);
+
+  /// Returns the given expression, but ensuring it's a signed type of the
+  /// same shape as the operand. Wraps the expresison with a bitcast if needed.
+  /// Assumes the given expresion is a integer scalar or vector.
+  /// @param expr an integer scalar or integer vector expression.
+  TypedExpression AsSigned(TypedExpression expr);
+
   /// Bookkeeping used for tracking the "position" builtin variable.
   struct BuiltInPositionInfo {
     /// The ID for the gl_PerVertex struct containing the Position builtin.