spirv-reader: support OpBitCount, OpBitReverse

Bug: tint:3
Change-Id: I81580136621ab51a9852e1d692ddad2457b9aab9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35340
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 484fbba..50bb0ab 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -465,6 +465,10 @@
 // given instruction, or ast::Intrinsic::kNone
 ast::Intrinsic GetIntrinsic(SpvOp opcode) {
   switch (opcode) {
+    case SpvOpBitCount:
+      return ast::Intrinsic::kCountOneBits;
+    case SpvOpBitReverse:
+      return ast::Intrinsic::kReverseBits;
     case SpvOpDot:
       return ast::Intrinsic::kDot;
     case SpvOpOuterProduct:
@@ -3726,8 +3730,13 @@
   ident->set_intrinsic(intrinsic);
 
   ast::ExpressionList params;
+  ast::type::Type* first_operand_type = nullptr;
   for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
-    params.emplace_back(MakeOperand(inst, iarg).expr);
+    TypedExpression operand = MakeOperand(inst, iarg);
+    if (first_operand_type == nullptr) {
+      first_operand_type = operand.type;
+    }
+    params.emplace_back(operand.expr);
   }
   auto* call_expr = create<ast::CallExpression>(ident, std::move(params));
   auto* result_type = parser_impl_.ConvertType(inst.type_id());
@@ -3736,7 +3745,8 @@
            << inst.PrettyPrint();
     return {};
   }
-  return {result_type, call_expr};
+  TypedExpression call{result_type, call_expr};
+  return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
 }
 
 TypedExpression FunctionEmitter::MakeSimpleSelect(
diff --git a/src/reader/spirv/function_bit_test.cc b/src/reader/spirv/function_bit_test.cc
index 3a03837..3ed4f4e 100644
--- a/src/reader/spirv/function_bit_test.cc
+++ b/src/reader/spirv/function_bit_test.cc
@@ -627,11 +627,499 @@
       << ToString(fe.ast_body());
 }
 
+std::string BitTestPreamble() {
+  return R"(
+  OpCapability Shader
+  %glsl = OpExtInstImport "GLSL.std.450"
+  OpMemoryModel Logical GLSL450
+  OpEntryPoint GLCompute %100 "main"
+  OpExecutionMode %100 LocalSize 1 1 1
+
+  OpName %u1 "u1"
+  OpName %i1 "i1"
+  OpName %v2u1 "v2u1"
+  OpName %v2i1 "v2i1"
+
+)" + CommonTypes() +
+         R"(
+
+  %100 = OpFunction %void None %voidfn
+  %entry = OpLabel
+
+  %u1 = OpCopyObject %uint %uint_10
+  %i1 = OpCopyObject %int %int_30
+  %v2u1 = OpCopyObject %v2uint %v2uint_10_20
+  %v2i1 = OpCopyObject %v2int %v2int_30_40
+)";
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Uint_Uint) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %uint %u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __u32
+    {
+      Call[not set]{
+        Identifier[not set]{countOneBits}
+        (
+          Identifier[not set]{u1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Uint_Int) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %uint %i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __u32
+    {
+      Bitcast[not set]<__u32>{
+        Call[not set]{
+          Identifier[not set]{countOneBits}
+          (
+            Identifier[not set]{i1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Int_Uint) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %int %u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __i32
+    {
+      Bitcast[not set]<__i32>{
+        Call[not set]{
+          Identifier[not set]{countOneBits}
+          (
+            Identifier[not set]{u1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_Int_Int) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %int %i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __i32
+    {
+      Call[not set]{
+        Identifier[not set]{countOneBits}
+        (
+          Identifier[not set]{i1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_UintVector_UintVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %v2uint %v2u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__u32
+    {
+      Call[not set]{
+        Identifier[not set]{countOneBits}
+        (
+          Identifier[not set]{v2u1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_UintVector_IntVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %v2uint %v2i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__u32
+    {
+      Bitcast[not set]<__vec_2__u32>{
+        Call[not set]{
+          Identifier[not set]{countOneBits}
+          (
+            Identifier[not set]{v2i1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_IntVector_UintVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %v2int %v2u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__i32
+    {
+      Bitcast[not set]<__vec_2__i32>{
+        Call[not set]{
+          Identifier[not set]{countOneBits}
+          (
+            Identifier[not set]{v2u1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitCount_IntVector_IntVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitCount %v2int %v2i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__i32
+    {
+      Call[not set]{
+        Identifier[not set]{countOneBits}
+        (
+          Identifier[not set]{v2i1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Uint_Uint) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %uint %u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __u32
+    {
+      Call[not set]{
+        Identifier[not set]{reverseBits}
+        (
+          Identifier[not set]{u1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Uint_Int) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %uint %i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __u32
+    {
+      Bitcast[not set]<__u32>{
+        Call[not set]{
+          Identifier[not set]{reverseBits}
+          (
+            Identifier[not set]{i1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Int_Uint) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %int %u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __i32
+    {
+      Bitcast[not set]<__i32>{
+        Call[not set]{
+          Identifier[not set]{reverseBits}
+          (
+            Identifier[not set]{u1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_Int_Int) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %int %i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __i32
+    {
+      Call[not set]{
+        Identifier[not set]{reverseBits}
+        (
+          Identifier[not set]{i1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_UintVector_UintVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %v2uint %v2u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__u32
+    {
+      Call[not set]{
+        Identifier[not set]{reverseBits}
+        (
+          Identifier[not set]{v2u1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_UintVector_IntVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %v2uint %v2i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__u32
+    {
+      Bitcast[not set]<__vec_2__u32>{
+        Call[not set]{
+          Identifier[not set]{reverseBits}
+          (
+            Identifier[not set]{v2i1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_IntVector_UintVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %v2int %v2u1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__i32
+    {
+      Bitcast[not set]<__vec_2__i32>{
+        Call[not set]{
+          Identifier[not set]{reverseBits}
+          (
+            Identifier[not set]{v2u1}
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
+TEST_F(SpvUnaryBitTest, BitReverse_IntVector_IntVector) {
+  const auto assembly = BitTestPreamble() + R"(
+     %1 = OpBitReverse %v2int %v2i1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __vec_2__i32
+    {
+      Call[not set]{
+        Identifier[not set]{reverseBits}
+        (
+          Identifier[not set]{v2i1}
+        )
+      }
+    }
+  })"))
+      << body;
+}
+
 // TODO(dneto): OpBitFieldInsert
 // TODO(dneto): OpBitFieldSExtract
 // TODO(dneto): OpBitFieldUExtract
-// TODO(dneto): OpBitReverse
-// TODO(dneto): OpBitCount
 
 }  // namespace
 }  // namespace spirv
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 71e9319..6c2bd66 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -209,10 +209,14 @@
   return false;
 }
 
-// Returns true if the operation is binary, and the WGSL operation requires
+// Returns true if the corresponding WGSL operation requires
 // the signedness of the result to match the signedness of the first operand.
-bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) {
+bool AssumesResultSignednessMatchesFirstOperand(SpvOp opcode) {
   switch (opcode) {
+    case SpvOpNot:
+    case SpvOpSNegate:
+    case SpvOpBitCount:
+    case SpvOpBitReverse:
     case SpvOpSDiv:
     case SpvOpSMod:
     case SpvOpSRem:
@@ -1501,14 +1505,7 @@
     const spvtools::opt::Instruction& inst,
     ast::type::Type* first_operand_type) {
   const auto opcode = inst.opcode();
-  if ((opcode == SpvOpSNegate) || (opcode == SpvOpNot)) {
-    // The unary operation cases that force the result type to match the
-    // first operand type.
-    return first_operand_type;
-  }
-  if (AssumesResultSignednessMatchesBinaryFirstOperand(opcode)) {
-    // The binary operation cases that force the result type to match
-    // the first operand type.
+  if (AssumesResultSignednessMatchesFirstOperand(opcode)) {
     return first_operand_type;
   }
   if (IsGlslExtendedInstruction(inst)) {