[spirv-reader][ir] Support `OpBitCount`.
Adds support to translate `OpBitCount` into the WGSL `countOneBits`
method. In SPIR-V the result and base values do not have to have the
same type, so `bitcast` if needed.
Bug: 394878128
Change-Id: I04f971f1a4682fed2ba0f4ef6cc7a2dfb09b4006
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/225174
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/builtin_fn.cc b/src/tint/lang/spirv/builtin_fn.cc
index 723bfd8..df00a2d 100644
--- a/src/tint/lang/spirv/builtin_fn.cc
+++ b/src/tint/lang/spirv/builtin_fn.cc
@@ -146,6 +146,8 @@
return "modf";
case BuiltinFn::kFrexp:
return "frexp";
+ case BuiltinFn::kBitCount:
+ return "bit_count";
case BuiltinFn::kSdot:
return "sdot";
case BuiltinFn::kUdot:
@@ -227,6 +229,7 @@
case BuiltinFn::kFaceForward:
case BuiltinFn::kLdexp:
case BuiltinFn::kCooperativeMatrixMulAdd:
+ case BuiltinFn::kBitCount:
break;
}
return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.cc.tmpl b/src/tint/lang/spirv/builtin_fn.cc.tmpl
index 0fa98f9..272b6fb 100644
--- a/src/tint/lang/spirv/builtin_fn.cc.tmpl
+++ b/src/tint/lang/spirv/builtin_fn.cc.tmpl
@@ -94,6 +94,7 @@
case BuiltinFn::kFaceForward:
case BuiltinFn::kLdexp:
case BuiltinFn::kCooperativeMatrixMulAdd:
+ case BuiltinFn::kBitCount:
break;
}
return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.h b/src/tint/lang/spirv/builtin_fn.h
index 08adce5..b54c521 100644
--- a/src/tint/lang/spirv/builtin_fn.h
+++ b/src/tint/lang/spirv/builtin_fn.h
@@ -100,6 +100,7 @@
kLdexp,
kModf,
kFrexp,
+ kBitCount,
kSdot,
kUdot,
kCooperativeMatrixLoad,
diff --git a/src/tint/lang/spirv/intrinsic/data.cc b/src/tint/lang/spirv/intrinsic/data.cc
index e4c9599..901cac2 100644
--- a/src/tint/lang/spirv/intrinsic/data.cc
+++ b/src/tint/lang/spirv/intrinsic/data.cc
@@ -6383,30 +6383,37 @@
},
{
/* [52] */
+ /* fn bit_count<R : iu32>[T : iu32](T) -> R */
+ /* fn bit_count<R : iu32>[T : iu32, N : num](vec<N, T>) -> vec<N, R> */
+ /* num overloads */ 2,
+ /* overloads */ OverloadIndex(154),
+ },
+ {
+ /* [53] */
/* fn sdot(u32, u32, u32) -> i32 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(183),
},
{
- /* [53] */
+ /* [54] */
/* fn udot(u32, u32, u32) -> u32 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(184),
},
{
- /* [54] */
+ /* [55] */
/* fn cooperative_matrix_load<T : subgroup_matrix<K, S, C, R>>[K : subgroup_matrix_kind, S : fiu32_f16, C : num, R : num](ptr<workgroup_or_storage, S, readable>, u32, u32, u32) -> T */
/* num overloads */ 1,
/* overloads */ OverloadIndex(185),
},
{
- /* [55] */
+ /* [56] */
/* fn cooperative_matrix_store[K : subgroup_matrix_kind, S : fiu32_f16, C : num, R : num](ptr<workgroup_or_storage, S, writable>, subgroup_matrix<K, S, C, R>, u32, u32, u32) */
/* num overloads */ 1,
/* overloads */ OverloadIndex(186),
},
{
- /* [56] */
+ /* [57] */
/* fn cooperative_matrix_mul_add[S : subgroup_matrix_elements, C : num, R : num, K : num](subgroup_matrix<subgroup_matrix_kind_left, S, K, R>, subgroup_matrix<subgroup_matrix_kind_right, S, C, K>, subgroup_matrix<subgroup_matrix_kind_result, S, C, R>) -> subgroup_matrix<subgroup_matrix_kind_result, S, C, R> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(187),
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index dd5772f..912d1a7 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -119,6 +119,9 @@
case spirv::BuiltinFn::kFrexp:
Frexp(builtin);
break;
+ case spirv::BuiltinFn::kBitCount:
+ BitCount(builtin);
+ break;
default:
TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
}
@@ -355,6 +358,22 @@
call->Destroy();
}
+ void BitCount(spirv::ir::BuiltinCall* call) {
+ auto arg = call->Args()[0];
+
+ b.InsertBefore(call, [&] {
+ auto* res_ty = call->Result(0)->Type();
+ auto* arg_ty = arg->Type();
+
+ auto* bc = b.Call(arg_ty, core::BuiltinFn::kCountOneBits, arg)->Result(0);
+ if (res_ty != arg_ty) {
+ bc = b.Bitcast(res_ty, bc)->Result(0);
+ }
+ call->Result(0)->ReplaceAllUsesWith(bc);
+ });
+ call->Destroy();
+ }
+
void Inverse(spirv::ir::BuiltinCall* call) {
auto* arg = call->Args()[0];
auto* mat_ty = arg->Type()->As<core::type::Matrix>();
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index 65d7edc..ba717d1 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -2894,5 +2894,269 @@
EXPECT_EQ(expect, str());
}
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Scalar_UnsignedToUnsigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.u32()}, 10_u);
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.bit_count<u32> 10u
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = countOneBits 10u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Scalar_UnsignedToSigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.i32(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.i32()}, 10_u);
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.bit_count<i32> 10u
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = countOneBits 10u
+ %3:i32 = bitcast %2
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Scalar_SignedToUnsigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.u32()}, 10_i);
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.bit_count<u32> 10i
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = countOneBits 10i
+ %3:u32 = bitcast %2
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Scalar_SignedToSigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.i32(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.i32()}, 10_i);
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.bit_count<i32> 10i
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = countOneBits 10i
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Vector_UnsignedToUnsigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<u32>(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.u32()},
+ b.Splat(ty.vec2<u32>(), 10_u));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<u32> = spirv.bit_count<u32> vec2<u32>(10u)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<u32> = countOneBits vec2<u32>(10u)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Vector_UnsignedToSigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<i32>(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.i32()},
+ b.Splat(ty.vec2<u32>(), 10_u));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<i32> = spirv.bit_count<i32> vec2<u32>(10u)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<u32> = countOneBits vec2<u32>(10u)
+ %3:vec2<i32> = bitcast %2
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Vector_SignedToUnsigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<u32>(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.u32()},
+ b.Splat(ty.vec2<i32>(), 10_i));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<u32> = spirv.bit_count<u32> vec2<i32>(10i)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<i32> = countOneBits vec2<i32>(10i)
+ %3:vec2<u32> = bitcast %2
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, BitCount_Vector_SignedToSigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<i32>(), spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{ty.i32()},
+ b.Splat(ty.vec2<i32>(), 10_i));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<i32> = spirv.bit_count<i32> vec2<i32>(10i)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<i32> = countOneBits vec2<i32>(10i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::spirv::reader::lower
diff --git a/src/tint/lang/spirv/reader/parser/builtin_test.cc b/src/tint/lang/spirv/reader/parser/builtin_test.cc
index fb8cf46..e3a5465 100644
--- a/src/tint/lang/spirv/reader/parser/builtin_test.cc
+++ b/src/tint/lang/spirv/reader/parser/builtin_test.cc
@@ -59,5 +59,223 @@
)");
}
+TEST_F(SpirvParserTest, BitCount_Scalar_UnsignedToUnsigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %uint_10 = OpConstant %uint 10
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %uint %uint_10
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.bit_count<u32> 10u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, BitCount_Scalar_UnsignedToSigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %uint_10 = OpConstant %uint 10
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %int %uint_10
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.bit_count<i32> 10u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, BitCount_Scalar_SignedToUnsigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %int_20 = OpConstant %int 20
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %uint %int_20
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.bit_count<u32> 20i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, BitCount_Scalar_SignedToSigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %int = OpTypeInt 32 1
+ %int_20 = OpConstant %int 20
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %int %int_20
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.bit_count<i32> 20i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, BitCount_Vector_UnsignedToUnsigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %v2_uint = OpTypeVector %uint 2
+ %uint_10 = OpConstant %uint 10
+ %uint_20 = OpConstant %uint 20
+%v2_uint_10_20 = OpConstantComposite %v2_uint %uint_10 %uint_20
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %v2_uint %v2_uint_10_20
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<u32> = spirv.bit_count<u32> vec2<u32>(10u, 20u)
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, BitCount_Vector_UnsignedToSigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %int = OpTypeInt 32 1
+ %uint = OpTypeInt 32 0
+ %v2_int = OpTypeVector %int 2
+ %v2_uint = OpTypeVector %uint 2
+ %uint_10 = OpConstant %uint 10
+ %uint_20 = OpConstant %uint 20
+%v2_uint_10_20 = OpConstantComposite %v2_uint %uint_10 %uint_20
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %v2_int %v2_uint_10_20
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<i32> = spirv.bit_count<i32> vec2<u32>(10u, 20u)
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, BitCount_Vector_SignedToUnsigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %int = OpTypeInt 32 1
+ %uint = OpTypeInt 32 0
+ %v2_int = OpTypeVector %int 2
+ %v2_uint = OpTypeVector %uint 2
+ %int_10 = OpConstant %int 10
+ %int_20 = OpConstant %int 20
+%v2_int_10_20 = OpConstantComposite %v2_int %int_10 %int_20
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %v2_uint %v2_int_10_20
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<u32> = spirv.bit_count<u32> vec2<i32>(10i, 20i)
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, BitCount_Vector_SignedToSigned) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %int = OpTypeInt 32 1
+ %v2_int = OpTypeVector %int 2
+ %int_10 = OpConstant %int 10
+ %int_20 = OpConstant %int 20
+%v2_int_10_20 = OpConstantComposite %v2_int %int_10 %int_20
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpBitCount %v2_int %v2_int_10_20
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<i32> = spirv.bit_count<i32> vec2<i32>(10i, 20i)
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 3ab1ffc..b69dd45 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -625,6 +625,9 @@
Value(inst.GetSingleWordOperand(3))),
inst.result_id());
break;
+ case spv::Op::OpBitCount:
+ EmitBitCount(inst);
+ break;
default:
TINT_UNIMPLEMENTED()
<< "unhandled SPIR-V instruction: " << static_cast<uint32_t>(inst.opcode());
@@ -632,6 +635,15 @@
}
}
+ void EmitBitCount(const spvtools::opt::Instruction& inst) {
+ auto* res_ty = Type(inst.type_id());
+ Emit(b_.CallExplicit<spirv::ir::BuiltinCall>(
+ res_ty, spirv::BuiltinFn::kBitCount,
+ Vector<const core::type::Type*, 1>{res_ty->DeepestElement()},
+ Value(inst.GetSingleWordOperand(2))),
+ inst.result_id());
+ }
+
/// @param inst the SPIR-V instruction
/// Note: This isn't technically correct, but there is no `kill` equivalent in WGSL. The closets
/// we have is `discard` which maps to `OpDemoteToHelperInvocation` in SPIR-V.
diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def
index c94ab13..2d0eba6 100644
--- a/src/tint/lang/spirv/spirv.def
+++ b/src/tint/lang/spirv/spirv.def
@@ -390,6 +390,9 @@
N: num,
S: function_private_workgroup_storage) fn frexp(x: vec<N, T>, i: ptr<S, vec<N, R>, writable>) -> vec<N, T>
+implicit(T: iu32) fn bit_count<R: iu32>(T) -> R
+implicit(T: iu32, N: num) fn bit_count<R: iu32>(vec<N, T>) -> vec<N, R>
+
////////////////////////////////////////////////////////////////////////////////
// SPV_KHR_integer_dot_product instructions
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index a7de433..f39ccaa 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1491,11 +1491,14 @@
case spirv::BuiltinFn::kCooperativeMatrixLoad:
op = spv::Op::OpCooperativeMatrixLoadKHR;
break;
+ case spirv::BuiltinFn::kCooperativeMatrixMulAdd:
+ op = spv::Op::OpCooperativeMatrixMulAddKHR;
+ break;
case spirv::BuiltinFn::kCooperativeMatrixStore:
op = spv::Op::OpCooperativeMatrixStoreKHR;
break;
- case spirv::BuiltinFn::kCooperativeMatrixMulAdd:
- op = spv::Op::OpCooperativeMatrixMulAddKHR;
+ case spirv::BuiltinFn::kBitCount:
+ op = spv::Op::OpBitCount;
break;
case spirv::BuiltinFn::kNone:
TINT_ICE() << "undefined spirv ir function";