spirv-reader: support OpBitCount, OpBitReverse
Bug: tint:3
Change-Id: I81580136621ab51a9852e1d692ddad2457b9aab9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35340
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 484fbba..50bb0ab 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -465,6 +465,10 @@
// given instruction, or ast::Intrinsic::kNone
ast::Intrinsic GetIntrinsic(SpvOp opcode) {
switch (opcode) {
+ case SpvOpBitCount:
+ return ast::Intrinsic::kCountOneBits;
+ case SpvOpBitReverse:
+ return ast::Intrinsic::kReverseBits;
case SpvOpDot:
return ast::Intrinsic::kDot;
case SpvOpOuterProduct:
@@ -3726,8 +3730,13 @@
ident->set_intrinsic(intrinsic);
ast::ExpressionList params;
+ ast::type::Type* first_operand_type = nullptr;
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
- params.emplace_back(MakeOperand(inst, iarg).expr);
+ TypedExpression operand = MakeOperand(inst, iarg);
+ if (first_operand_type == nullptr) {
+ first_operand_type = operand.type;
+ }
+ params.emplace_back(operand.expr);
}
auto* call_expr = create<ast::CallExpression>(ident, std::move(params));
auto* result_type = parser_impl_.ConvertType(inst.type_id());
@@ -3736,7 +3745,8 @@
<< inst.PrettyPrint();
return {};
}
- return {result_type, call_expr};
+ TypedExpression call{result_type, call_expr};
+ return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
}
TypedExpression FunctionEmitter::MakeSimpleSelect(
diff --git a/src/reader/spirv/function_bit_test.cc b/src/reader/spirv/function_bit_test.cc
index 3a03837..3ed4f4e 100644
--- a/src/reader/spirv/function_bit_test.cc
+++ b/src/reader/spirv/function_bit_test.cc
@@ -627,11 +627,499 @@
<< ToString(fe.ast_body());
}
+std::string BitTestPreamble() {
+ return R"(
+ OpCapability Shader
+ %glsl = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %100 "main"
+ OpExecutionMode %100 LocalSize 1 1 1
+
+ OpName %u1 "u1"
+ OpName %i1 "i1"
+ OpName %v2u1 "v2u1"
+ OpName %v2i1 "v2i1"
+
+)" + CommonTypes() +
+ R"(
+
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+
+ %u1 = OpCopyObject %uint %uint_10
+ %i1 = OpCopyObject %int %int_30
+ %v2u1 = OpCopyObject %v2uint %v2uint_10_20
+ %v2i1 = OpCopyObject %v2int %v2int_30_40
+)";
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Uint_Uint) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %uint %u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __u32
+ {
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{u1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Uint_Int) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %uint %i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __u32
+ {
+ Bitcast[not set]<__u32>{
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{i1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Int_Uint) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %int %u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __i32
+ {
+ Bitcast[not set]<__i32>{
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{u1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Int_Int) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %int %i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __i32
+ {
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{i1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_UintVector_UintVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %v2uint %v2u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__u32
+ {
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{v2u1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_UintVector_IntVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %v2uint %v2i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__u32
+ {
+ Bitcast[not set]<__vec_2__u32>{
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{v2i1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_IntVector_UintVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %v2int %v2u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__i32
+ {
+ Bitcast[not set]<__vec_2__i32>{
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{v2u1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_IntVector_IntVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitCount %v2int %v2i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__i32
+ {
+ Call[not set]{
+ Identifier[not set]{countOneBits}
+ (
+ Identifier[not set]{v2i1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Uint_Uint) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %uint %u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __u32
+ {
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{u1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Uint_Int) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %uint %i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __u32
+ {
+ Bitcast[not set]<__u32>{
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{i1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Int_Uint) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %int %u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __i32
+ {
+ Bitcast[not set]<__i32>{
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{u1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Int_Int) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %int %i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __i32
+ {
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{i1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_UintVector_UintVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %v2uint %v2u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__u32
+ {
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{v2u1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_UintVector_IntVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %v2uint %v2i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__u32
+ {
+ Bitcast[not set]<__vec_2__u32>{
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{v2i1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_IntVector_UintVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %v2int %v2u1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__i32
+ {
+ Bitcast[not set]<__vec_2__i32>{
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{v2u1}
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_IntVector_IntVector) {
+ const auto assembly = BitTestPreamble() + R"(
+ %1 = OpBitReverse %v2int %v2i1
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __vec_2__i32
+ {
+ Call[not set]{
+ Identifier[not set]{reverseBits}
+ (
+ Identifier[not set]{v2i1}
+ )
+ }
+ }
+ })"))
+ << body;
+}
+
// TODO(dneto): OpBitFieldInsert
// TODO(dneto): OpBitFieldSExtract
// TODO(dneto): OpBitFieldUExtract
-// TODO(dneto): OpBitReverse
-// TODO(dneto): OpBitCount
} // namespace
} // namespace spirv
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 71e9319..6c2bd66 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -209,10 +209,14 @@
return false;
}
-// Returns true if the operation is binary, and the WGSL operation requires
+// Returns true if the corresponding WGSL operation requires
// the signedness of the result to match the signedness of the first operand.
-bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) {
+bool AssumesResultSignednessMatchesFirstOperand(SpvOp opcode) {
switch (opcode) {
+ case SpvOpNot:
+ case SpvOpSNegate:
+ case SpvOpBitCount:
+ case SpvOpBitReverse:
case SpvOpSDiv:
case SpvOpSMod:
case SpvOpSRem:
@@ -1501,14 +1505,7 @@
const spvtools::opt::Instruction& inst,
ast::type::Type* first_operand_type) {
const auto opcode = inst.opcode();
- if ((opcode == SpvOpSNegate) || (opcode == SpvOpNot)) {
- // The unary operation cases that force the result type to match the
- // first operand type.
- return first_operand_type;
- }
- if (AssumesResultSignednessMatchesBinaryFirstOperand(opcode)) {
- // The binary operation cases that force the result type to match
- // the first operand type.
+ if (AssumesResultSignednessMatchesFirstOperand(opcode)) {
return first_operand_type;
}
if (IsGlslExtendedInstruction(inst)) {