tint/spirv-reader: cast offset and count args to u32 for insertBits/extractBits

Bug: tint:1874
Change-Id: Ieadbfcb7fc61a0404dd988df42e0cfe0c8693b02
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124320
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 7ec46e7..ff0d721 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -3819,7 +3819,14 @@
 
     const auto builtin = GetBuiltin(op);
     if (builtin != builtin::Function::kNone) {
-        return MakeBuiltinCall(inst);
+        switch (builtin) {
+            case builtin::Function::kExtractBits:
+                return MakeExtractBitsCall(inst);
+            case builtin::Function::kInsertBits:
+                return MakeInsertBitsCall(inst);
+            default:
+                return MakeBuiltinCall(inst);
+        }
     }
 
     if (op == spv::Op::OpFMod) {
@@ -5274,6 +5281,42 @@
     return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
 }
 
+TypedExpression FunctionEmitter::MakeExtractBitsCall(const spvtools::opt::Instruction& inst) {
+    const auto builtin = GetBuiltin(opcode(inst));
+    auto* name = builtin::str(builtin);
+    auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
+    auto e = MakeOperand(inst, 0);
+    auto offset = ToU32(MakeOperand(inst, 1));
+    auto count = ToU32(MakeOperand(inst, 2));
+    auto* call_expr = builder_.Call(ident, ExpressionList{e.expr, offset.expr, count.expr});
+    auto* result_type = parser_impl_.ConvertType(inst.type_id());
+    if (!result_type) {
+        Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
+        return {};
+    }
+    TypedExpression call{result_type, call_expr};
+    return parser_impl_.RectifyForcedResultType(call, inst, e.type);
+}
+
+TypedExpression FunctionEmitter::MakeInsertBitsCall(const spvtools::opt::Instruction& inst) {
+    const auto builtin = GetBuiltin(opcode(inst));
+    auto* name = builtin::str(builtin);
+    auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
+    auto e = MakeOperand(inst, 0);
+    auto newbits = MakeOperand(inst, 1);
+    auto offset = ToU32(MakeOperand(inst, 2));
+    auto count = ToU32(MakeOperand(inst, 3));
+    auto* call_expr =
+        builder_.Call(ident, ExpressionList{e.expr, newbits.expr, offset.expr, count.expr});
+    auto* result_type = parser_impl_.ConvertType(inst.type_id());
+    if (!result_type) {
+        Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
+        return {};
+    }
+    TypedExpression call{result_type, call_expr};
+    return parser_impl_.RectifyForcedResultType(call, inst, e.type);
+}
+
 TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instruction& inst) {
     auto condition = MakeOperand(inst, 0);
     auto true_value = MakeOperand(inst, 1);
@@ -6053,6 +6096,13 @@
     return {ty_.I32(), builder_.Call(builder_.ty.i32(), utils::Vector{value.expr})};
 }
 
+TypedExpression FunctionEmitter::ToU32(TypedExpression value) {
+    if (!value || value.type->Is<U32>()) {
+        return value;
+    }
+    return {ty_.U32(), builder_.Call(builder_.ty.u32(), utils::Vector{value.expr})};
+}
+
 TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
     if (!value || !value.type->IsUnsignedScalarOrVector()) {
         return value;
diff --git a/src/tint/reader/spirv/function.h b/src/tint/reader/spirv/function.h
index 718e8e3..11fc92c 100644
--- a/src/tint/reader/spirv/function.h
+++ b/src/tint/reader/spirv/function.h
@@ -945,6 +945,12 @@
     /// @returns the value as an i32 value.
     TypedExpression ToI32(TypedExpression value);
 
+    /// Returns the given value as an u32. If it's already an u32 then simply returns @p value.
+    /// Otherwise, wrap the value in a TypeInitializer expression.
+    /// @param value the value to pass through or convert
+    /// @returns the value as an u32 value.
+    TypedExpression ToU32(TypedExpression value);
+
     /// Returns the given value as a signed integer type of the same shape if the value is unsigned
     /// scalar or vector, by wrapping the value with a TypeInitializer expression.  Returns the
     /// value itself if the value was already signed.
@@ -1035,6 +1041,18 @@
     /// @returns an expression
     TypedExpression MakeBuiltinCall(const spvtools::opt::Instruction& inst);
 
+    /// Returns an expression for a SPIR-V instruction that maps to the extractBits WGSL
+    /// builtin function call, with special handling to cast offset and count to u32, if needed.
+    /// @param inst the SPIR-V instruction
+    /// @returns an expression
+    TypedExpression MakeExtractBitsCall(const spvtools::opt::Instruction& inst);
+
+    /// Returns an expression for a SPIR-V instruction that maps to the insertBits WGSL
+    /// builtin function call, with special handling to cast offset and count to u32, if needed.
+    /// @param inst the SPIR-V instruction
+    /// @returns an expression
+    TypedExpression MakeInsertBitsCall(const spvtools::opt::Instruction& inst);
+
     /// Returns an expression for a SPIR-V OpArrayLength instruction.
     /// @param inst the SPIR-V instruction
     /// @returns an expression
diff --git a/src/tint/reader/spirv/function_bit_test.cc b/src/tint/reader/spirv/function_bit_test.cc
index 40f9162..2a12f01 100644
--- a/src/tint/reader/spirv/function_bit_test.cc
+++ b/src/tint/reader/spirv/function_bit_test.cc
@@ -33,6 +33,8 @@
 
   %uint_10 = OpConstant %uint 10
   %uint_20 = OpConstant %uint 20
+  %int_10 = OpConstant %int 10
+  %int_20 = OpConstant %int 20
   %int_30 = OpConstant %int 30
   %int_40 = OpConstant %int 40
   %float_50 = OpConstant %float 50
@@ -832,7 +834,7 @@
 
 TEST_F(SpvUnaryBitTest, InsertBits_Int) {
     const auto assembly = BitTestPreamble() + R"(
-     %1 = OpBitFieldInsert %v2int %int_30 %int_40 %uint_10 %uint_20
+     %1 = OpBitFieldInsert %int %int_30 %int_40 %uint_10 %uint_20
      OpReturn
      OpFunctionEnd
   )";
@@ -842,7 +844,23 @@
     EXPECT_TRUE(fe.EmitBody()) << p->error();
     auto ast_body = fe.ast_body();
     auto body = test::ToString(p->program(), ast_body);
-    EXPECT_THAT(body, HasSubstr("let x_1 : vec2<i32> = insertBits(30i, 40i, 10u, 20u);")) << body;
+    EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, InsertBits_Int_SignedOffsetAndCount) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldInsert %int %int_30 %int_40 %int_10 %int_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, u32(10i), u32(20i));"))
+        << body;
 }
 
 TEST_F(SpvUnaryBitTest, InsertBits_IntVector) {
@@ -864,9 +882,9 @@
         << body;
 }
 
-TEST_F(SpvUnaryBitTest, InsertBits_Uint) {
+TEST_F(SpvUnaryBitTest, InsertBits_IntVector_SignedOffsetAndCount) {
     const auto assembly = BitTestPreamble() + R"(
-     %1 = OpBitFieldInsert %v2uint %uint_20 %uint_10 %uint_10 %uint_20
+     %1 = OpBitFieldInsert %v2int %v2int_30_40 %v2int_40_30 %int_10 %int_20
      OpReturn
      OpFunctionEnd
   )";
@@ -876,7 +894,42 @@
     EXPECT_TRUE(fe.EmitBody()) << p->error();
     auto ast_body = fe.ast_body();
     auto body = test::ToString(p->program(), ast_body);
-    EXPECT_THAT(body, HasSubstr("let x_1 : vec2<u32> = insertBits(20u, 10u, 10u, 20u);")) << body;
+    EXPECT_THAT(
+        body,
+        HasSubstr(
+            R"(let x_1 : vec2<i32> = insertBits(vec2<i32>(30i, 40i), vec2<i32>(40i, 30i), u32(10i), u32(20i));)"))
+        << body;
+}
+
+TEST_F(SpvUnaryBitTest, InsertBits_Uint) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldInsert %uint %uint_20 %uint_10 %uint_10 %uint_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, InsertBits_Uint_SignedOffsetAndCount) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldInsert %uint %uint_20 %uint_10 %int_10 %int_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, u32(10i), u32(20i));"))
+        << body;
 }
 
 TEST_F(SpvUnaryBitTest, InsertBits_UintVector) {
@@ -898,9 +951,9 @@
         << body;
 }
 
-TEST_F(SpvUnaryBitTest, ExtractBits_Int) {
+TEST_F(SpvUnaryBitTest, InsertBits_UintVector_SignedOffsetAndCount) {
     const auto assembly = BitTestPreamble() + R"(
-     %1 = OpBitFieldSExtract %v2int %int_30 %uint_10 %uint_20
+     %1 = OpBitFieldInsert %v2uint %v2uint_10_20 %v2uint_20_10 %int_10 %int_20
      OpReturn
      OpFunctionEnd
   )";
@@ -910,7 +963,41 @@
     EXPECT_TRUE(fe.EmitBody()) << p->error();
     auto ast_body = fe.ast_body();
     auto body = test::ToString(p->program(), ast_body);
-    EXPECT_THAT(body, HasSubstr("let x_1 : vec2<i32> = extractBits(30i, 10u, 20u);")) << body;
+    EXPECT_THAT(
+        body,
+        HasSubstr(
+            R"(let x_1 : vec2<u32> = insertBits(vec2<u32>(10u, 20u), vec2<u32>(20u, 10u), u32(10i), u32(20i));)"))
+        << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Int) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldSExtract %int %int_30 %uint_10 %uint_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Int_SignedOffsetAndCount) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldSExtract %int %int_30 %int_10 %int_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, u32(10i), u32(20i));")) << body;
 }
 
 TEST_F(SpvUnaryBitTest, ExtractBits_IntVector) {
@@ -930,9 +1017,9 @@
         << body;
 }
 
-TEST_F(SpvUnaryBitTest, ExtractBits_Uint) {
+TEST_F(SpvUnaryBitTest, ExtractBits_IntVector_SignedOffsetAndCount) {
     const auto assembly = BitTestPreamble() + R"(
-     %1 = OpBitFieldUExtract %v2uint %uint_20 %uint_10 %uint_20
+     %1 = OpBitFieldSExtract %v2int %v2int_30_40 %int_10 %int_20
      OpReturn
      OpFunctionEnd
   )";
@@ -942,7 +1029,40 @@
     EXPECT_TRUE(fe.EmitBody()) << p->error();
     auto ast_body = fe.ast_body();
     auto body = test::ToString(p->program(), ast_body);
-    EXPECT_THAT(body, HasSubstr("let x_1 : vec2<u32> = extractBits(20u, 10u, 20u);")) << body;
+    EXPECT_THAT(
+        body,
+        HasSubstr("let x_1 : vec2<i32> = extractBits(vec2<i32>(30i, 40i), u32(10i), u32(20i));"))
+        << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Uint) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldUExtract %uint %uint_20 %uint_10 %uint_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Uint_SignedOffsetAndCount) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldUExtract %uint %uint_20 %int_10 %int_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, u32(10i), u32(20i));")) << body;
 }
 
 TEST_F(SpvUnaryBitTest, ExtractBits_UintVector) {
@@ -962,5 +1082,23 @@
         << body;
 }
 
+TEST_F(SpvUnaryBitTest, ExtractBits_UintVector_SignedOffsetAndCount) {
+    const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitFieldUExtract %v2uint %v2uint_10_20 %int_10 %int_20
+     OpReturn
+     OpFunctionEnd
+  )";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    auto body = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(
+        body,
+        HasSubstr("let x_1 : vec2<u32> = extractBits(vec2<u32>(10u, 20u), u32(10i), u32(20i));"))
+        << body;
+}
+
 }  // namespace
 }  // namespace tint::reader::spirv