[ir][spirv-writer] Polyfill the `dot()` builtin

OpDot only accepts floating-point operands, so we have to polyfill the
integer overloads of `dot()`.

Bug: tint:1906
Change-Id: I6e94873bff169c7ec0d04e375b5193f462fd31ce
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/140890
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/ir/call.h b/src/tint/ir/call.h
index 041c012..2d1ff35 100644
--- a/src/tint/ir/call.h
+++ b/src/tint/ir/call.h
@@ -26,7 +26,7 @@
     ~Call() override;
 
     /// @returns the call arguments
-    virtual utils::Slice<Value* const> Args() { return operands_.Slice(); }
+    virtual utils::Slice<Value*> Args() { return operands_.Slice(); }
 
   protected:
     /// Constructor
diff --git a/src/tint/ir/intrinsic_call.cc b/src/tint/ir/intrinsic_call.cc
index 2629d8f..4063213 100644
--- a/src/tint/ir/intrinsic_call.cc
+++ b/src/tint/ir/intrinsic_call.cc
@@ -34,6 +34,9 @@
 
 utils::StringStream& operator<<(utils::StringStream& out, enum IntrinsicCall::Kind kind) {
     switch (kind) {
+        case IntrinsicCall::Kind::kSpirvDot:
+            out << "spirv.dot";
+            break;
         case IntrinsicCall::Kind::kSpirvSelect:
             out << "spirv.select";
             break;
diff --git a/src/tint/ir/intrinsic_call.h b/src/tint/ir/intrinsic_call.h
index a5ebef5..9bbf9db 100644
--- a/src/tint/ir/intrinsic_call.h
+++ b/src/tint/ir/intrinsic_call.h
@@ -29,6 +29,7 @@
     /// The kind of instruction.
     enum class Kind {
         // SPIR-V backend intrinsics.
+        kSpirvDot,
         kSpirvSelect,
     };
 
diff --git a/src/tint/ir/transform/builtin_polyfill_spirv.cc b/src/tint/ir/transform/builtin_polyfill_spirv.cc
index d9f960d..3e0521d 100644
--- a/src/tint/ir/transform/builtin_polyfill_spirv.cc
+++ b/src/tint/ir/transform/builtin_polyfill_spirv.cc
@@ -50,6 +50,7 @@
             }
             if (auto* builtin = inst->As<BuiltinCall>()) {
                 switch (builtin->Func()) {
+                    case builtin::Function::kDot:
                     case builtin::Function::kSelect:
                         worklist.Push(builtin);
                         break;
@@ -63,6 +64,9 @@
         for (auto* builtin : worklist) {
             Value* replacement = nullptr;
             switch (builtin->Func()) {
+                case builtin::Function::kDot:
+                    replacement = Dot(builtin);
+                    break;
                 case builtin::Function::kSelect:
                     replacement = Select(builtin);
                     break;
@@ -80,6 +84,44 @@
         }
     }
 
+    /// Handle a `dot()` builtin: expand integer dots inline, map all others to `spirv.dot`.
+    /// @param builtin the `dot()` builtin call; new instructions are inserted before it
+    /// @returns the replacement value (the final sum, or the `spirv.dot` call result)
+    Value* Dot(BuiltinCall* builtin) {
+        // OpDot only supports floating point operands, so we need to polyfill the integer case.
+        // TODO(crbug.com/tint/1267): If SPV_KHR_integer_dot_product is supported, use that instead.
+        if (builtin->Result()->Type()->is_integer_scalar()) {
+            Instruction* sum = nullptr;  // running total, seeded by the first product
+
+            auto* v1 = builtin->Args()[0];
+            auto* v2 = builtin->Args()[1];
+            auto* vec = v1->Type()->As<type::Vector>();  // NOTE(review): not null-checked
+            auto* elty = vec->type();
+            for (uint32_t i = 0; i < vec->Width(); i++) {  // one access/access/mul per component
+                auto* e1 = b.Access(elty, v1, u32(i));
+                e1->InsertBefore(builtin);
+                auto* e2 = b.Access(elty, v2, u32(i));
+                e2->InsertBefore(builtin);
+                auto* mul = b.Multiply(elty, e1, e2);
+                mul->InsertBefore(builtin);
+                if (sum) {  // later components: add the product to the running sum
+                    sum = b.Add(elty, sum, mul);
+                    sum->InsertBefore(builtin);
+                } else {
+                    sum = mul;  // first component: the product starts the sum
+                }
+            }
+            return sum->Result();  // NOTE(review): sum stays null if Width() were 0
+        }
+
+        // Non-integer result: replace the builtin call with a call to the spirv.dot intrinsic.
+        auto args = utils::Vector<Value*, 4>(builtin->Args());  // same arguments as the builtin
+        auto* call =
+            b.Call(builtin->Result()->Type(), IntrinsicCall::Kind::kSpirvDot, std::move(args));
+        call->InsertBefore(builtin);
+        return call->Result();
+    }
+
     /// Handle a `select()` builtin.
     /// @param builtin the builtin call instruction
     /// @returns the replacement value
diff --git a/src/tint/ir/transform/builtin_polyfill_spirv_test.cc b/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
index b7eaa07..007a702 100644
--- a/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
+++ b/src/tint/ir/transform/builtin_polyfill_spirv_test.cc
@@ -26,6 +26,131 @@
 
 using IR_BuiltinPolyfillSpirvTest = TransformTest;
 
+TEST_F(IR_BuiltinPolyfillSpirvTest, Dot_Vec4f) {  // f32 dot() becomes a spirv.dot call
+    auto* arg1 = b.FunctionParam("arg1", ty.vec4<f32>());
+    auto* arg2 = b.FunctionParam("arg2", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.f32());
+    func->SetParams({arg1, arg2});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Call(ty.f32(), builtin::Function::kDot, arg1, arg2);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%arg1:vec4<f32>, %arg2:vec4<f32>):f32 -> %b1 {
+  %b1 = block {
+    %4:f32 = dot %arg1, %arg2
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());  // check the IR as built, before the transform runs
+
+    auto* expect = R"(
+%foo = func(%arg1:vec4<f32>, %arg2:vec4<f32>):f32 -> %b1 {
+  %b1 = block {
+    %4:f32 = spirv.dot %arg1, %arg2
+    ret %4
+  }
+}
+)";
+
+    Run<BuiltinPolyfillSpirv>();
+
+    EXPECT_EQ(expect, str());  // only the callee differs: dot -> spirv.dot
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, Dot_Vec2i) {  // i32 dot() is expanded per component
+    auto* arg1 = b.FunctionParam("arg1", ty.vec2<i32>());
+    auto* arg2 = b.FunctionParam("arg2", ty.vec2<i32>());
+    auto* func = b.Function("foo", ty.i32());
+    func->SetParams({arg1, arg2});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Call(ty.i32(), builtin::Function::kDot, arg1, arg2);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%arg1:vec2<i32>, %arg2:vec2<i32>):i32 -> %b1 {
+  %b1 = block {
+    %4:i32 = dot %arg1, %arg2
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());  // check the IR as built, before the transform runs
+
+    auto* expect = R"(
+%foo = func(%arg1:vec2<i32>, %arg2:vec2<i32>):i32 -> %b1 {
+  %b1 = block {
+    %4:i32 = access %arg1, 0u
+    %5:i32 = access %arg2, 0u
+    %6:i32 = mul %4, %5
+    %7:i32 = access %arg1, 1u
+    %8:i32 = access %arg2, 1u
+    %9:i32 = mul %7, %8
+    %10:i32 = add %6, %9
+    ret %10
+  }
+}
+)";
+
+    Run<BuiltinPolyfillSpirv>();
+
+    EXPECT_EQ(expect, str());  // two access/access/mul groups joined by a single add
+}
+
+TEST_F(IR_BuiltinPolyfillSpirvTest, Dot_Vec4u) {  // u32 dot() is expanded per component
+    auto* arg1 = b.FunctionParam("arg1", ty.vec4<u32>());
+    auto* arg2 = b.FunctionParam("arg2", ty.vec4<u32>());
+    auto* func = b.Function("foo", ty.u32());
+    func->SetParams({arg1, arg2});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Call(ty.u32(), builtin::Function::kDot, arg1, arg2);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%arg1:vec4<u32>, %arg2:vec4<u32>):u32 -> %b1 {
+  %b1 = block {
+    %4:u32 = dot %arg1, %arg2
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());  // check the IR as built, before the transform runs
+
+    auto* expect = R"(
+%foo = func(%arg1:vec4<u32>, %arg2:vec4<u32>):u32 -> %b1 {
+  %b1 = block {
+    %4:u32 = access %arg1, 0u
+    %5:u32 = access %arg2, 0u
+    %6:u32 = mul %4, %5
+    %7:u32 = access %arg1, 1u
+    %8:u32 = access %arg2, 1u
+    %9:u32 = mul %7, %8
+    %10:u32 = add %6, %9
+    %11:u32 = access %arg1, 2u
+    %12:u32 = access %arg2, 2u
+    %13:u32 = mul %11, %12
+    %14:u32 = add %10, %13
+    %15:u32 = access %arg1, 3u
+    %16:u32 = access %arg2, 3u
+    %17:u32 = mul %15, %16
+    %18:u32 = add %14, %17
+    ret %18
+  }
+}
+)";
+
+    Run<BuiltinPolyfillSpirv>();
+
+    EXPECT_EQ(expect, str());  // four access/access/mul groups chained by three adds
+}
+
 TEST_F(IR_BuiltinPolyfillSpirvTest, Select_ScalarCondition_ScalarOperands) {
     auto* argf = b.FunctionParam("argf", ty.i32());
     auto* argt = b.FunctionParam("argt", ty.i32());
diff --git a/src/tint/ir/user_call.h b/src/tint/ir/user_call.h
index 9d3ea24..ffab14c 100644
--- a/src/tint/ir/user_call.h
+++ b/src/tint/ir/user_call.h
@@ -38,9 +38,7 @@
     ~UserCall() override;
 
     /// @returns the call arguments
-    utils::Slice<Value* const> Args() override {
-        return operands_.Slice().Offset(kArgsOperandOffset);
-    }
+    utils::Slice<Value*> Args() override { return operands_.Slice().Offset(kArgsOperandOffset); }
 
     /// @returns the called function name
     Function* Func() { return operands_[kFunctionOperandOffset]->As<ir::Function>(); }
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 6f825e3..4bf2293 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -1384,6 +1384,9 @@
 
     spv::Op op = spv::Op::Max;
     switch (call->Kind()) {
+        case ir::IntrinsicCall::Kind::kSpirvDot:
+            op = spv::Op::OpDot;
+            break;
         case ir::IntrinsicCall::Kind::kSpirvSelect:
             op = spv::Op::OpSelect;
             break;
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index 359aad8..83b6cb8 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -411,6 +411,75 @@
     EXPECT_INST("%result = OpExtInst %half %9 Distance %arg1 %arg2");
 }
 
+TEST_F(SpvGeneratorImplTest, Builtin_Dot_vec4f) {  // f32 dot() is emitted as OpDot
+    auto* arg1 = b.FunctionParam("arg1", ty.vec4<f32>());
+    auto* arg2 = b.FunctionParam("arg2", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.f32());
+    func->SetParams({arg1, arg2});
+    b.With(func->Block(), [&] {
+        auto* result = b.Call(ty.f32(), builtin::Function::kDot, arg1, arg2);
+        b.Return(func, result);
+        mod.SetName(result, "result");  // lets the expected SPIR-V refer to %result
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST("%result = OpDot %float %arg1 %arg2");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_Dot_vec2i) {  // i32 dot() is emitted without OpDot
+    auto* arg1 = b.FunctionParam("arg1", ty.vec2<i32>());
+    auto* arg2 = b.FunctionParam("arg2", ty.vec2<i32>());
+    auto* func = b.Function("foo", ty.i32());
+    func->SetParams({arg1, arg2});
+    b.With(func->Block(), [&] {
+        auto* result = b.Call(ty.i32(), builtin::Function::kDot, arg1, arg2);
+        b.Return(func, result);
+        mod.SetName(result, "result");  // lets the expected SPIR-V refer to %result
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST(R"(
+          %8 = OpCompositeExtract %int %arg1 0
+          %9 = OpCompositeExtract %int %arg2 0
+         %10 = OpIMul %int %8 %9
+         %11 = OpCompositeExtract %int %arg1 1
+         %12 = OpCompositeExtract %int %arg2 1
+         %13 = OpIMul %int %11 %12
+     %result = OpIAdd %int %10 %13
+)");  // per-component extract and multiply; the final add carries the result name
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_Dot_vec4u) {  // u32 dot() is emitted without OpDot
+    auto* arg1 = b.FunctionParam("arg1", ty.vec4<u32>());
+    auto* arg2 = b.FunctionParam("arg2", ty.vec4<u32>());
+    auto* func = b.Function("foo", ty.u32());
+    func->SetParams({arg1, arg2});
+    b.With(func->Block(), [&] {
+        auto* result = b.Call(ty.u32(), builtin::Function::kDot, arg1, arg2);
+        b.Return(func, result);
+        mod.SetName(result, "result");  // lets the expected SPIR-V refer to %result
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST(R"(
+          %8 = OpCompositeExtract %uint %arg1 0
+          %9 = OpCompositeExtract %uint %arg2 0
+         %10 = OpIMul %uint %8 %9
+         %11 = OpCompositeExtract %uint %arg1 1
+         %12 = OpCompositeExtract %uint %arg2 1
+         %13 = OpIMul %uint %11 %12
+         %14 = OpIAdd %uint %10 %13
+         %15 = OpCompositeExtract %uint %arg1 2
+         %16 = OpCompositeExtract %uint %arg2 2
+         %17 = OpIMul %uint %15 %16
+         %18 = OpIAdd %uint %14 %17
+         %19 = OpCompositeExtract %uint %arg1 3
+         %20 = OpCompositeExtract %uint %arg2 3
+         %21 = OpIMul %uint %19 %20
+     %result = OpIAdd %uint %18 %21
+)");  // per-component extract and multiply; the last add carries the result name
+}
+
 // Tests for builtins with the signature: T = func(T, T, T)
 using Builtin_3arg = SpvGeneratorImplTestWithParam<BuiltinTestCase>;
 TEST_P(Builtin_3arg, Scalar) {