tint: Implement DP4a on HLSL writer

Bug: tint:1497
Test: tint_unittests
Change-Id: I29cc3e56949071230cdbd5afdc59eef076777149
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89706
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/sem/builtin.cc b/src/tint/sem/builtin.cc
index 0328a64..f102307 100644
--- a/src/tint/sem/builtin.cc
+++ b/src/tint/sem/builtin.cc
@@ -83,6 +83,10 @@
            i == sem::BuiltinType::kAtomicCompareExchangeWeak;
 }
 
+bool IsDP4aBuiltin(BuiltinType i) {
+    return i == sem::BuiltinType::kDot4I8Packed || i == sem::BuiltinType::kDot4U8Packed;
+}
+
 Builtin::Builtin(BuiltinType type,
                  const sem::Type* return_type,
                  std::vector<Parameter*> parameters,
@@ -135,6 +139,10 @@
     return IsAtomicBuiltin(type_);
 }
 
+bool Builtin::IsDP4a() const {
+    return IsDP4aBuiltin(type_);
+}
+
 bool Builtin::HasSideEffects() const {
     if (IsAtomic() && type_ != sem::BuiltinType::kAtomicLoad) {
         return true;
@@ -146,13 +154,10 @@
 }
 
 ast::Enable::ExtensionKind Builtin::RequiredExtension() const {
-    switch (type_) {
-        case sem::BuiltinType::kDot4I8Packed:
-        case sem::BuiltinType::kDot4U8Packed:
-            return ast::Enable::ExtensionKind::kChromiumExperimentalDP4a;
-        default:
-            return ast::Enable::ExtensionKind::kNotAnExtension;
+    if (IsDP4a()) {
+        return ast::Enable::ExtensionKind::kChromiumExperimentalDP4a;
     }
+    return ast::Enable::ExtensionKind::kNotAnExtension;
 }
 
 }  // namespace tint::sem
diff --git a/src/tint/sem/builtin.h b/src/tint/sem/builtin.h
index 8d3e2bd..5d340ce 100644
--- a/src/tint/sem/builtin.h
+++ b/src/tint/sem/builtin.h
@@ -70,6 +70,11 @@
 /// @returns true if the given `i` is a atomic builtin
 bool IsAtomicBuiltin(BuiltinType i);
 
+/// Determines if the given `i` is a DP4a builtin
+/// @param i the builtin
+/// @returns true if the given `i` is a DP4a builtin
+bool IsDP4aBuiltin(BuiltinType i);
+
 /// Builtin holds the semantic information for a builtin function.
 class Builtin final : public Castable<Builtin, CallTarget> {
   public:
@@ -130,6 +135,10 @@
     /// @returns true if builtin is a atomic builtin
     bool IsAtomic() const;
 
+    /// @returns true if builtin is a DP4a builtin (defined in the extension
+    /// chromium_experimental_dp4a)
+    bool IsDP4a() const;
+
     /// @returns true if intrinsic may have side-effects (i.e. writes to at least
     /// one of its inputs)
     bool HasSideEffects() const;
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 814d6a0..1a16079 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -1027,6 +1027,9 @@
     if (builtin->IsAtomic()) {
         return EmitWorkgroupAtomicCall(out, expr, builtin);
     }
+    if (builtin->IsDP4a()) {
+        return EmitDP4aCall(out, expr, builtin);
+    }
     auto name = generate_builtin_name(builtin);
     if (name.empty()) {
         return false;
@@ -2033,6 +2036,32 @@
         });
 }
 
+bool GeneratorImpl::EmitDP4aCall(std::ostream& out,
+                                 const ast::CallExpression* expr,
+                                 const sem::Builtin* builtin) {
+    // TODO(crbug.com/tint/1497): support the polyfill version of DP4a functions.
+    return CallBuiltinHelper(
+        out, expr, builtin, [&](TextBuffer* b, const std::vector<std::string>& params) {
+            std::string functionName;
+            switch (builtin->Type()) {
+                case sem::BuiltinType::kDot4I8Packed:
+                    functionName = "dot4add_i8packed";
+                    break;
+                case sem::BuiltinType::kDot4U8Packed:
+                    functionName = "dot4add_u8packed";
+                    break;
+                default:
+                    diagnostics_.add_error(diag::System::Writer,
+                                           "Internal error: unhandled DP4a builtin");
+                    return false;
+            }
+            line(b) << "return " << functionName << "(" << params[0] << ", " << params[1]
+                    << ", 0);";
+
+            return true;
+        });
+}
+
 bool GeneratorImpl::EmitBarrierCall(std::ostream& out, const sem::Builtin* builtin) {
     // TODO(crbug.com/tint/661): Combine sequential barriers to a single
     // instruction.
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index e329638..7cd5960 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -242,7 +242,7 @@
     /// Handles generating a call to data packing builtin
     /// @param out the output of the expression stream
     /// @param expr the call expression
-    /// @param builtin the semantic information for the texture builtin
+    /// @param builtin the semantic information for the builtin
     /// @returns true if the call expression is emitted
     bool EmitDataPackingCall(std::ostream& out,
                              const ast::CallExpression* expr,
@@ -250,11 +250,19 @@
     /// Handles generating a call to data unpacking builtin
     /// @param out the output of the expression stream
     /// @param expr the call expression
-    /// @param builtin the semantic information for the texture builtin
+    /// @param builtin the semantic information for the builtin
     /// @returns true if the call expression is emitted
     bool EmitDataUnpackingCall(std::ostream& out,
                                const ast::CallExpression* expr,
                                const sem::Builtin* builtin);
+    /// Handles generating a call to DP4a builtins (dot4I8Packed and dot4U8Packed)
+    /// @param out the output of the expression stream
+    /// @param expr the call expression
+    /// @param builtin the semantic information for the builtin
+    /// @returns true if the call expression is emitted
+    bool EmitDP4aCall(std::ostream& out,
+                      const ast::CallExpression* expr,
+                      const sem::Builtin* builtin);
     /// Handles a case statement
     /// @param s the switch statement
     /// @param case_idx the index of the switch case in the switch statement
diff --git a/src/tint/writer/hlsl/generator_impl_builtin_test.cc b/src/tint/writer/hlsl/generator_impl_builtin_test.cc
index e4e6ba9..64ef740 100644
--- a/src/tint/writer/hlsl/generator_impl_builtin_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_builtin_test.cc
@@ -726,5 +726,61 @@
 )");
 }
 
+TEST_F(HlslGeneratorImplTest_Builtin, Dot4I8Packed) {
+    auto* ext =
+        create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
+                            "chromium_experimental_dp4a");
+    AST().AddEnable(ext);
+
+    auto* val1 = Var("val1", ty.u32());
+    auto* val2 = Var("val2", ty.u32());
+    auto* call = Call("dot4I8Packed", val1, val2);
+    WrapInFunction(val1, val2, call);
+
+    GeneratorImpl& gen = SanitizeAndBuild();
+
+    ASSERT_TRUE(gen.Generate()) << gen.error();
+    EXPECT_EQ(gen.result(), R"(int tint_dot4I8Packed(uint param_0, uint param_1) {
+  return dot4add_i8packed(param_0, param_1, 0);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+  uint val1 = 0u;
+  uint val2 = 0u;
+  const int tint_symbol = tint_dot4I8Packed(val1, val2);
+  return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Dot4U8Packed) {
+    auto* ext =
+        create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
+                            "chromium_experimental_dp4a");
+    AST().AddEnable(ext);
+
+    auto* val1 = Var("val1", ty.u32());
+    auto* val2 = Var("val2", ty.u32());
+    auto* call = Call("dot4U8Packed", val1, val2);
+    WrapInFunction(val1, val2, call);
+
+    GeneratorImpl& gen = SanitizeAndBuild();
+
+    ASSERT_TRUE(gen.Generate()) << gen.error();
+    EXPECT_EQ(gen.result(), R"(uint tint_dot4U8Packed(uint param_0, uint param_1) {
+  return dot4add_u8packed(param_0, param_1, 0);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+  uint val1 = 0u;
+  uint val2 = 0u;
+  const uint tint_symbol = tint_dot4U8Packed(val1, val2);
+  return;
+}
+)");
+}
+
 }  // namespace
 }  // namespace tint::writer::hlsl