tint/spirv-reader: cast offset and count args to u32 for insertBits/extractBits
Bug: tint:1874
Change-Id: Ieadbfcb7fc61a0404dd988df42e0cfe0c8693b02
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124320
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 7ec46e7..ff0d721 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -3819,7 +3819,14 @@
const auto builtin = GetBuiltin(op);
if (builtin != builtin::Function::kNone) {
- return MakeBuiltinCall(inst);
+ switch (builtin) {
+ case builtin::Function::kExtractBits:
+ return MakeExtractBitsCall(inst);
+ case builtin::Function::kInsertBits:
+ return MakeInsertBitsCall(inst);
+ default:
+ return MakeBuiltinCall(inst);
+ }
}
if (op == spv::Op::OpFMod) {
@@ -5274,6 +5281,42 @@
return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
}
+TypedExpression FunctionEmitter::MakeExtractBitsCall(const spvtools::opt::Instruction& inst) {
+ const auto builtin = GetBuiltin(opcode(inst));
+ auto* name = builtin::str(builtin);
+ auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
+ auto e = MakeOperand(inst, 0);
+ auto offset = ToU32(MakeOperand(inst, 1));
+ auto count = ToU32(MakeOperand(inst, 2));
+ auto* call_expr = builder_.Call(ident, ExpressionList{e.expr, offset.expr, count.expr});
+ auto* result_type = parser_impl_.ConvertType(inst.type_id());
+ if (!result_type) {
+ Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
+ return {};
+ }
+ TypedExpression call{result_type, call_expr};
+ return parser_impl_.RectifyForcedResultType(call, inst, e.type);
+}
+
+TypedExpression FunctionEmitter::MakeInsertBitsCall(const spvtools::opt::Instruction& inst) {
+ const auto builtin = GetBuiltin(opcode(inst));
+ auto* name = builtin::str(builtin);
+ auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
+ auto e = MakeOperand(inst, 0);
+ auto newbits = MakeOperand(inst, 1);
+ auto offset = ToU32(MakeOperand(inst, 2));
+ auto count = ToU32(MakeOperand(inst, 3));
+ auto* call_expr =
+ builder_.Call(ident, ExpressionList{e.expr, newbits.expr, offset.expr, count.expr});
+ auto* result_type = parser_impl_.ConvertType(inst.type_id());
+ if (!result_type) {
+ Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
+ return {};
+ }
+ TypedExpression call{result_type, call_expr};
+ return parser_impl_.RectifyForcedResultType(call, inst, e.type);
+}
+
TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instruction& inst) {
auto condition = MakeOperand(inst, 0);
auto true_value = MakeOperand(inst, 1);
@@ -6053,6 +6096,13 @@
return {ty_.I32(), builder_.Call(builder_.ty.i32(), utils::Vector{value.expr})};
}
+TypedExpression FunctionEmitter::ToU32(TypedExpression value) {
+ if (!value || value.type->Is<U32>()) {
+ return value;
+ }
+ return {ty_.U32(), builder_.Call(builder_.ty.u32(), utils::Vector{value.expr})};
+}
+
TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
if (!value || !value.type->IsUnsignedScalarOrVector()) {
return value;
diff --git a/src/tint/reader/spirv/function.h b/src/tint/reader/spirv/function.h
index 718e8e3..11fc92c 100644
--- a/src/tint/reader/spirv/function.h
+++ b/src/tint/reader/spirv/function.h
@@ -945,6 +945,12 @@
/// @returns the value as an i32 value.
TypedExpression ToI32(TypedExpression value);
+ /// Returns the given value as an u32. If it's already an u32 then simply returns @p value.
+ /// Otherwise, wrap the value in a TypeInitializer expression.
+ /// @param value the value to pass through or convert
+ /// @returns the value as an u32 value.
+ TypedExpression ToU32(TypedExpression value);
+
/// Returns the given value as a signed integer type of the same shape if the value is unsigned
/// scalar or vector, by wrapping the value with a TypeInitializer expression. Returns the
/// value itself if the value was already signed.
@@ -1035,6 +1041,18 @@
/// @returns an expression
TypedExpression MakeBuiltinCall(const spvtools::opt::Instruction& inst);
+ /// Returns an expression for a SPIR-V instruction that maps to the extractBits WGSL
+ /// builtin function call, with special handling to cast offset and count to u32, if needed.
+ /// @param inst the SPIR-V instruction
+ /// @returns an expression
+ TypedExpression MakeExtractBitsCall(const spvtools::opt::Instruction& inst);
+
+ /// Returns an expression for a SPIR-V instruction that maps to the insertBits WGSL
+ /// builtin function call, with special handling to cast offset and count to u32, if needed.
+ /// @param inst the SPIR-V instruction
+ /// @returns an expression
+ TypedExpression MakeInsertBitsCall(const spvtools::opt::Instruction& inst);
+
/// Returns an expression for a SPIR-V OpArrayLength instruction.
/// @param inst the SPIR-V instruction
/// @returns an expression
diff --git a/src/tint/reader/spirv/function_bit_test.cc b/src/tint/reader/spirv/function_bit_test.cc
index 40f9162..2a12f01 100644
--- a/src/tint/reader/spirv/function_bit_test.cc
+++ b/src/tint/reader/spirv/function_bit_test.cc
@@ -33,6 +33,8 @@
%uint_10 = OpConstant %uint 10
%uint_20 = OpConstant %uint 20
+ %int_10 = OpConstant %int 10
+ %int_20 = OpConstant %int 20
%int_30 = OpConstant %int 30
%int_40 = OpConstant %int 40
%float_50 = OpConstant %float 50
@@ -832,7 +834,7 @@
TEST_F(SpvUnaryBitTest, InsertBits_Int) {
const auto assembly = BitTestPreamble() + R"(
- %1 = OpBitFieldInsert %v2int %int_30 %int_40 %uint_10 %uint_20
+ %1 = OpBitFieldInsert %int %int_30 %int_40 %uint_10 %uint_20
OpReturn
OpFunctionEnd
)";
@@ -842,7 +844,23 @@
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
- EXPECT_THAT(body, HasSubstr("let x_1 : vec2<i32> = insertBits(30i, 40i, 10u, 20u);")) << body;
+ EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, InsertBits_Int_SignedOffsetAndCount) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldInsert %int %int_30 %int_40 %int_10 %int_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, u32(10i), u32(20i));"))
+ << body;
}
TEST_F(SpvUnaryBitTest, InsertBits_IntVector) {
@@ -864,9 +882,9 @@
<< body;
}
-TEST_F(SpvUnaryBitTest, InsertBits_Uint) {
+TEST_F(SpvUnaryBitTest, InsertBits_IntVector_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
- %1 = OpBitFieldInsert %v2uint %uint_20 %uint_10 %uint_10 %uint_20
+ %1 = OpBitFieldInsert %v2int %v2int_30_40 %v2int_40_30 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
@@ -876,7 +894,42 @@
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
- EXPECT_THAT(body, HasSubstr("let x_1 : vec2<u32> = insertBits(20u, 10u, 10u, 20u);")) << body;
+ EXPECT_THAT(
+ body,
+ HasSubstr(
+ R"(let x_1 : vec2<i32> = insertBits(vec2<i32>(30i, 40i), vec2<i32>(40i, 30i), u32(10i), u32(20i));)"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, InsertBits_Uint) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldInsert %uint %uint_20 %uint_10 %uint_10 %uint_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, InsertBits_Uint_SignedOffsetAndCount) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldInsert %uint %uint_20 %uint_10 %int_10 %int_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, u32(10i), u32(20i));"))
+ << body;
}
TEST_F(SpvUnaryBitTest, InsertBits_UintVector) {
@@ -898,9 +951,9 @@
<< body;
}
-TEST_F(SpvUnaryBitTest, ExtractBits_Int) {
+TEST_F(SpvUnaryBitTest, InsertBits_UintVector_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
- %1 = OpBitFieldSExtract %v2int %int_30 %uint_10 %uint_20
+ %1 = OpBitFieldInsert %v2uint %v2uint_10_20 %v2uint_20_10 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
@@ -910,7 +963,41 @@
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
- EXPECT_THAT(body, HasSubstr("let x_1 : vec2<i32> = extractBits(30i, 10u, 20u);")) << body;
+ EXPECT_THAT(
+ body,
+ HasSubstr(
+ R"(let x_1 : vec2<u32> = insertBits(vec2<u32>(10u, 20u), vec2<u32>(20u, 10u), u32(10i), u32(20i));)"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Int) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldSExtract %int %int_30 %uint_10 %uint_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Int_SignedOffsetAndCount) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldSExtract %int %int_30 %int_10 %int_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, u32(10i), u32(20i));")) << body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_IntVector) {
@@ -930,9 +1017,9 @@
<< body;
}
-TEST_F(SpvUnaryBitTest, ExtractBits_Uint) {
+TEST_F(SpvUnaryBitTest, ExtractBits_IntVector_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
- %1 = OpBitFieldUExtract %v2uint %uint_20 %uint_10 %uint_20
+ %1 = OpBitFieldSExtract %v2int %v2int_30_40 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
@@ -942,7 +1029,40 @@
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
- EXPECT_THAT(body, HasSubstr("let x_1 : vec2<u32> = extractBits(20u, 10u, 20u);")) << body;
+ EXPECT_THAT(
+ body,
+ HasSubstr("let x_1 : vec2<i32> = extractBits(vec2<i32>(30i, 40i), u32(10i), u32(20i));"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Uint) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldUExtract %uint %uint_20 %uint_10 %uint_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, 10u, 20u);")) << body;
+}
+
+TEST_F(SpvUnaryBitTest, ExtractBits_Uint_SignedOffsetAndCount) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldUExtract %uint %uint_20 %int_10 %int_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, u32(10i), u32(20i));")) << body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_UintVector) {
@@ -962,5 +1082,23 @@
<< body;
}
+TEST_F(SpvUnaryBitTest, ExtractBits_UintVector_SignedOffsetAndCount) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitFieldUExtract %v2uint %v2uint_10_20 %int_10 %int_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto ast_body = fe.ast_body();
+ auto body = test::ToString(p->program(), ast_body);
+ EXPECT_THAT(
+ body,
+ HasSubstr("let x_1 : vec2<u32> = extractBits(vec2<u32>(10u, 20u), u32(10i), u32(20i));"))
+ << body;
+}
+
} // namespace
} // namespace tint::reader::spirv