[ir][spirv-writer] Polyfill the `dot()` builtin
OpDot only accepts floating-point operands, so we have to polyfill the
integer overloads by expanding them into per-component multiplies and adds.
Bug: tint:1906
Change-Id: I6e94873bff169c7ec0d04e375b5193f462fd31ce
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/140890
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/ir/call.h b/src/tint/ir/call.h
index 041c012..2d1ff35 100644
--- a/src/tint/ir/call.h
+++ b/src/tint/ir/call.h
@@ -26,7 +26,7 @@
~Call() override;
/// @returns the call arguments
- virtual utils::Slice<Value* const> Args() { return operands_.Slice(); }
+ virtual utils::Slice<Value*> Args() { return operands_.Slice(); }
protected:
/// Constructor
diff --git a/src/tint/ir/intrinsic_call.cc b/src/tint/ir/intrinsic_call.cc
index 2629d8f..4063213 100644
--- a/src/tint/ir/intrinsic_call.cc
+++ b/src/tint/ir/intrinsic_call.cc
@@ -34,6 +34,9 @@
utils::StringStream& operator<<(utils::StringStream& out, enum IntrinsicCall::Kind kind) {
switch (kind) {
+ case IntrinsicCall::Kind::kSpirvDot:
+ out << "spirv.dot";
+ break;
case IntrinsicCall::Kind::kSpirvSelect:
out << "spirv.select";
break;
diff --git a/src/tint/ir/intrinsic_call.h b/src/tint/ir/intrinsic_call.h
index a5ebef5..9bbf9db 100644
--- a/src/tint/ir/intrinsic_call.h
+++ b/src/tint/ir/intrinsic_call.h
@@ -29,6 +29,7 @@
/// The kind of instruction.
enum class Kind {
// SPIR-V backend intrinsics.
+ kSpirvDot,
kSpirvSelect,
};
diff --git a/src/tint/ir/transform/builtin_polyfill_spirv.cc b/src/tint/ir/transform/builtin_polyfill_spirv.cc
index d9f960d..3e0521d 100644
--- a/src/tint/ir/transform/builtin_polyfill_spirv.cc
+++ b/src/tint/ir/transform/builtin_polyfill_spirv.cc
@@ -50,6 +50,7 @@
}
if (auto* builtin = inst->As<BuiltinCall>()) {
switch (builtin->Func()) {
+ case builtin::Function::kDot:
case builtin::Function::kSelect:
worklist.Push(builtin);
break;
@@ -63,6 +64,9 @@
for (auto* builtin : worklist) {
Value* replacement = nullptr;
switch (builtin->Func()) {
+ case builtin::Function::kDot:
+ replacement = Dot(builtin);
+ break;
case builtin::Function::kSelect:
replacement = Select(builtin);
break;
@@ -80,6 +84,44 @@
}
}
+    /// Lower a `dot()` builtin call into a form the SPIR-V writer can emit.
+    /// @param builtin the builtin call instruction
+    /// @returns the replacement value
+    Value* Dot(BuiltinCall* builtin) {
+        auto* result_ty = builtin->Result()->Type();
+
+        if (!result_ty->is_integer_scalar()) {
+            // Replace the builtin call with a call to the spirv.dot intrinsic.
+            auto* intrinsic = b.Call(result_ty, IntrinsicCall::Kind::kSpirvDot,
+                                     utils::Vector<Value*, 4>(builtin->Args()));
+            intrinsic->InsertBefore(builtin);
+            return intrinsic->Result();
+        }
+
+        // OpDot only supports floating point operands, so the integer case is expanded into a
+        // chain of per-component multiplies that are summed together.
+        // TODO(crbug.com/tint/1267): If SPV_KHR_integer_dot_product is supported, use that instead.
+        auto* lhs = builtin->Args()[0];
+        auto* rhs = builtin->Args()[1];
+        auto* vec_ty = lhs->Type()->As<type::Vector>();
+        auto* scalar_ty = vec_ty->type();
+        const auto width = vec_ty->Width();
+
+        // Running total; stays null until the first product has been created.
+        Instruction* accum = nullptr;
+        for (uint32_t idx = 0; idx < width; ++idx) {
+            auto* lhs_el = b.Access(scalar_ty, lhs, u32(idx));
+            lhs_el->InsertBefore(builtin);
+            auto* rhs_el = b.Access(scalar_ty, rhs, u32(idx));
+            rhs_el->InsertBefore(builtin);
+            auto* product = b.Multiply(scalar_ty, lhs_el, rhs_el);
+            product->InsertBefore(builtin);
+
+            if (!accum) {
+                // The first product seeds the sum, so no add is needed for component 0.
+                accum = product;
+                continue;
+            }
+            accum = b.Add(scalar_ty, accum, product);
+            accum->InsertBefore(builtin);
+        }
+        return accum->Result();
+    }
+
/// Handle a `select()` builtin.
/// @param builtin the builtin call instruction
/// @returns the replacement value
diff --git a/src/tint/ir/transform/builtin_polyfill_spirv_test.cc b/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
index b7eaa07..007a702 100644
--- a/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
+++ b/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
@@ -26,6 +26,131 @@
using IR_BuiltinPolyfillSpirvTest = TransformTest;
+// Checks that a `dot` call on vec4<f32> operands is rewritten by BuiltinPolyfillSpirv into a
+// call to the `spirv.dot` intrinsic.
+TEST_F(IR_BuiltinPolyfillSpirvTest, Dot_Vec4f) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec4<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({arg1, arg2});
+
+ // Body of `foo`: return dot(arg1, arg2).
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.f32(), builtin::Function::kDot, arg1, arg2);
+ b.Return(func, result);
+ });
+
+ // IR expected before the transform runs: the `dot` builtin call is still present.
+ auto* src = R"(
+%foo = func(%arg1:vec4<f32>, %arg2:vec4<f32>):f32 -> %b1 {
+ %b1 = block {
+ %4:f32 = dot %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ // IR expected after the transform: same shape, with `dot` replaced by `spirv.dot`.
+ auto* expect = R"(
+%foo = func(%arg1:vec4<f32>, %arg2:vec4<f32>):f32 -> %b1 {
+ %b1 = block {
+ %4:f32 = spirv.dot %arg1, %arg2
+ ret %4
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+// Checks that a `dot` call on vec2<i32> operands is expanded by BuiltinPolyfillSpirv into
+// per-component access/mul instructions whose products are added together.
+TEST_F(IR_BuiltinPolyfillSpirvTest, Dot_Vec2i) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec2<i32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec2<i32>());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arg1, arg2});
+
+ // Body of `foo`: return dot(arg1, arg2).
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.i32(), builtin::Function::kDot, arg1, arg2);
+ b.Return(func, result);
+ });
+
+ // IR expected before the transform runs: the `dot` builtin call is still present.
+ auto* src = R"(
+%foo = func(%arg1:vec2<i32>, %arg2:vec2<i32>):i32 -> %b1 {
+ %b1 = block {
+ %4:i32 = dot %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ // IR expected after the transform: two component products and a single add.
+ auto* expect = R"(
+%foo = func(%arg1:vec2<i32>, %arg2:vec2<i32>):i32 -> %b1 {
+ %b1 = block {
+ %4:i32 = access %arg1, 0u
+ %5:i32 = access %arg2, 0u
+ %6:i32 = mul %4, %5
+ %7:i32 = access %arg1, 1u
+ %8:i32 = access %arg2, 1u
+ %9:i32 = mul %7, %8
+ %10:i32 = add %6, %9
+ ret %10
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
+// Checks that a `dot` call on vec4<u32> operands is expanded by BuiltinPolyfillSpirv into
+// four per-component products accumulated through a chain of three adds.
+TEST_F(IR_BuiltinPolyfillSpirvTest, Dot_Vec4u) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec4<u32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec4<u32>());
+ auto* func = b.Function("foo", ty.u32());
+ func->SetParams({arg1, arg2});
+
+ // Body of `foo`: return dot(arg1, arg2).
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.u32(), builtin::Function::kDot, arg1, arg2);
+ b.Return(func, result);
+ });
+
+ // IR expected before the transform runs: the `dot` builtin call is still present.
+ auto* src = R"(
+%foo = func(%arg1:vec4<u32>, %arg2:vec4<u32>):u32 -> %b1 {
+ %b1 = block {
+ %4:u32 = dot %arg1, %arg2
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ // IR expected after the transform: each add consumes the previous sum and the next product.
+ auto* expect = R"(
+%foo = func(%arg1:vec4<u32>, %arg2:vec4<u32>):u32 -> %b1 {
+ %b1 = block {
+ %4:u32 = access %arg1, 0u
+ %5:u32 = access %arg2, 0u
+ %6:u32 = mul %4, %5
+ %7:u32 = access %arg1, 1u
+ %8:u32 = access %arg2, 1u
+ %9:u32 = mul %7, %8
+ %10:u32 = add %6, %9
+ %11:u32 = access %arg1, 2u
+ %12:u32 = access %arg2, 2u
+ %13:u32 = mul %11, %12
+ %14:u32 = add %10, %13
+ %15:u32 = access %arg1, 3u
+ %16:u32 = access %arg2, 3u
+ %17:u32 = mul %15, %16
+ %18:u32 = add %14, %17
+ ret %18
+ }
+}
+)";
+
+ Run<BuiltinPolyfillSpirv>();
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(IR_BuiltinPolyfillSpirvTest, Select_ScalarCondition_ScalarOperands) {
auto* argf = b.FunctionParam("argf", ty.i32());
auto* argt = b.FunctionParam("argt", ty.i32());
diff --git a/src/tint/ir/user_call.h b/src/tint/ir/user_call.h
index 9d3ea24..ffab14c 100644
--- a/src/tint/ir/user_call.h
+++ b/src/tint/ir/user_call.h
@@ -38,9 +38,7 @@
~UserCall() override;
/// @returns the call arguments
- utils::Slice<Value* const> Args() override {
- return operands_.Slice().Offset(kArgsOperandOffset);
- }
+ utils::Slice<Value*> Args() override { return operands_.Slice().Offset(kArgsOperandOffset); }
/// @returns the called function name
Function* Func() { return operands_[kFunctionOperandOffset]->As<ir::Function>(); }
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 6f825e3..4bf2293 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -1384,6 +1384,9 @@
spv::Op op = spv::Op::Max;
switch (call->Kind()) {
+ case ir::IntrinsicCall::Kind::kSpirvDot:
+ op = spv::Op::OpDot;
+ break;
case ir::IntrinsicCall::Kind::kSpirvSelect:
op = spv::Op::OpSelect;
break;
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index 359aad8..83b6cb8 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -411,6 +411,75 @@
EXPECT_INST("%result = OpExtInst %half %9 Distance %arg1 %arg2");
}
+// Checks that `dot` on vec4<f32> operands is generated as a single OpDot instruction.
+TEST_F(SpvGeneratorImplTest, Builtin_Dot_vec4f) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec4<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({arg1, arg2});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.f32(), builtin::Function::kDot, arg1, arg2);
+ b.Return(func, result);
+ // Name the call result; the expected instruction below refers to it as %result.
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpDot %float %arg1 %arg2");
+}
+
+// Checks that `dot` on vec2<i32> operands is generated as per-component
+// OpCompositeExtract/OpIMul instructions combined with OpIAdd, rather than OpDot.
+TEST_F(SpvGeneratorImplTest, Builtin_Dot_vec2i) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec2<i32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec2<i32>());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arg1, arg2});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.i32(), builtin::Function::kDot, arg1, arg2);
+ b.Return(func, result);
+ // Name the call result; the expected output below refers to the final add as %result.
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %8 = OpCompositeExtract %int %arg1 0
+ %9 = OpCompositeExtract %int %arg2 0
+ %10 = OpIMul %int %8 %9
+ %11 = OpCompositeExtract %int %arg1 1
+ %12 = OpCompositeExtract %int %arg2 1
+ %13 = OpIMul %int %11 %12
+ %result = OpIAdd %int %10 %13
+)");
+}
+
+// Checks that `dot` on vec4<u32> operands is generated as four per-component
+// OpCompositeExtract/OpIMul groups accumulated through a chain of OpIAdd instructions.
+TEST_F(SpvGeneratorImplTest, Builtin_Dot_vec4u) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec4<u32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec4<u32>());
+ auto* func = b.Function("foo", ty.u32());
+ func->SetParams({arg1, arg2});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.u32(), builtin::Function::kDot, arg1, arg2);
+ b.Return(func, result);
+ // Name the call result; the expected output below refers to the final add as %result.
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %8 = OpCompositeExtract %uint %arg1 0
+ %9 = OpCompositeExtract %uint %arg2 0
+ %10 = OpIMul %uint %8 %9
+ %11 = OpCompositeExtract %uint %arg1 1
+ %12 = OpCompositeExtract %uint %arg2 1
+ %13 = OpIMul %uint %11 %12
+ %14 = OpIAdd %uint %10 %13
+ %15 = OpCompositeExtract %uint %arg1 2
+ %16 = OpCompositeExtract %uint %arg2 2
+ %17 = OpIMul %uint %15 %16
+ %18 = OpIAdd %uint %14 %17
+ %19 = OpCompositeExtract %uint %arg1 3
+ %20 = OpCompositeExtract %uint %arg2 3
+ %21 = OpIMul %uint %19 %20
+ %result = OpIAdd %uint %18 %21
+)");
+}
+
// Tests for builtins with the signature: T = func(T, T, T)
using Builtin_3arg = SpvGeneratorImplTestWithParam<BuiltinTestCase>;
TEST_P(Builtin_3arg, Scalar) {