[tint][constant] Optimize Value::Equal() for splats

Element-wise comparisons are slow and unnecessary for large `constant::Splat`
values: when both constants are splats, comparing one element of each is enough.

Bug: chromium:1449538
Change-Id: I17dbfe1fffabbdb45f48920ba8de3af65ed24cae
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/135262
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index ce779bc..006ed50 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1699,6 +1699,7 @@
       "constant/manager_test.cc",
       "constant/scalar_test.cc",
       "constant/splat_test.cc",
+      "constant/value_test.cc",
     ]
   }
 
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 2ecb97a..474b56e 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -940,6 +940,7 @@
     constant/manager_test.cc
     constant/scalar_test.cc
     constant/splat_test.cc
+    constant/value_test.cc
     debug_test.cc
     diagnostic/diagnostic_test.cc
     diagnostic/formatter_test.cc
diff --git a/src/tint/constant/splat_test.cc b/src/tint/constant/splat_test.cc
index 0e2f21e..e310d63 100644
--- a/src/tint/constant/splat_test.cc
+++ b/src/tint/constant/splat_test.cc
@@ -31,9 +31,9 @@
     auto* fNeg0 = constants.Get(-0_f);
     auto* fPos1 = constants.Get(1_f);
 
-    auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
-    auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
-    auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
+    auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
+    auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
+    auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
 
     EXPECT_TRUE(SpfPos0->AllZero());
     EXPECT_FALSE(SpfNeg0->AllZero());
@@ -47,9 +47,9 @@
     auto* fNeg0 = constants.Get(-0_f);
     auto* fPos1 = constants.Get(1_f);
 
-    auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
-    auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
-    auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
+    auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
+    auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
+    auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
 
     EXPECT_TRUE(SpfPos0->AnyZero());
     EXPECT_FALSE(SpfNeg0->AnyZero());
diff --git a/src/tint/constant/value.cc b/src/tint/constant/value.cc
index 7545731..989f2ba 100644
--- a/src/tint/constant/value.cc
+++ b/src/tint/constant/value.cc
@@ -14,6 +14,7 @@
 
 #include "src/tint/constant/value.h"
 
+#include "src/tint/constant/splat.h"
 #include "src/tint/switch.h"
 #include "src/tint/type/array.h"
 #include "src/tint/type/matrix.h"
@@ -30,51 +31,68 @@
 
 /// Equal returns true if the constants `a` and `b` are of the same type and value.
 bool Value::Equal(const constant::Value* b) const {
+    if (this == b) {
+        return true;
+    }
     if (Hash() != b->Hash()) {
         return false;
     }
     if (Type() != b->Type()) {
         return false;
     }
+
+    auto elements_equal = [&](size_t count) {  // compares `count` elements of this and b
+        if (count == 0) {
+            return true;
+        }
+        // Avoid per-element comparisons if the constants are splats: a splat operand is read
+        // only through Index(0), so two splats are settled by one Equal() on that element.
+        bool a_is_splat = Is<Splat>();
+        bool b_is_splat = b->Is<Splat>();
+        if (a_is_splat && b_is_splat) {
+            return Index(0)->Equal(b->Index(0));
+        }
+        // Only `this` is a splat: check its Index(0) against each of b's `count` elements.
+        if (a_is_splat) {
+            auto* el_a = Index(0);
+            for (size_t i = 0; i < count; i++) {
+                if (!el_a->Equal(b->Index(i))) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        // Only `b` is a splat: check each element of `this` against b's Index(0).
+        if (b_is_splat) {
+            auto* el_b = b->Index(0);
+            for (size_t i = 0; i < count; i++) {
+                if (!Index(i)->Equal(el_b)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        // Neither operand is a splat, so every element pair has to be checked.
+        // Per-element comparison; stops at the first pair that is not Equal().
+        for (size_t i = 0; i < count; i++) {
+            if (!Index(i)->Equal(b->Index(i))) {
+                return false;
+            }
+        }
+        return true;
+    };
+
     return Switch(
         Type(),  //
-        [&](const type::Vector* vec) {
-            for (size_t i = 0; i < vec->Width(); i++) {
-                if (!Index(i)->Equal(b->Index(i))) {
-                    return false;
-                }
-            }
-            return true;
-        },
-        [&](const type::Matrix* mat) {
-            for (size_t i = 0; i < mat->columns(); i++) {
-                if (!Index(i)->Equal(b->Index(i))) {
-                    return false;
-                }
-            }
-            return true;
-        },
+        [&](const type::Vector* vec) { return elements_equal(vec->Width()); },
+        [&](const type::Matrix* mat) { return elements_equal(mat->columns()); },
+        [&](const type::Struct* str) { return elements_equal(str->Members().Length()); },
         [&](const type::Array* arr) {
-            if (auto count = arr->ConstantCount()) {
-                for (size_t i = 0; i < count; i++) {
-                    if (!Index(i)->Equal(b->Index(i))) {
-                        return false;
-                    }
-                }
-                return true;
+            if (auto n = arr->ConstantCount()) {
+                return elements_equal(*n);
             }
-
             return false;
         },
-        [&](const type::Struct* str) {
-            auto count = str->Members().Length();
-            for (size_t i = 0; i < count; i++) {
-                if (!Index(i)->Equal(b->Index(i))) {
-                    return false;
-                }
-            }
-            return true;
-        },
         [&](Default) {
             auto va = InternalValue();
             auto vb = b->InternalValue();
diff --git a/src/tint/constant/value_test.cc b/src/tint/constant/value_test.cc
new file mode 100644
index 0000000..e715c21
--- /dev/null
+++ b/src/tint/constant/value_test.cc
@@ -0,0 +1,78 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/constant/splat.h"
+
+#include "src/tint/constant/scalar.h"
+#include "src/tint/constant/test_helper.h"
+
+namespace tint::constant {
+namespace {
+
+using namespace tint::number_suffixes;  // NOLINT
+
+using ConstantTest_Value = TestHelper;
+
+TEST_F(ConstantTest_Value, Equal_Scalar_Scalar) {  // Equal() on scalar constants
+    EXPECT_TRUE(constants.Get(10_i)->Equal(constants.Get(10_i)));   // `_i` literals: same value
+    EXPECT_FALSE(constants.Get(10_i)->Equal(constants.Get(20_i)));  // different values
+    EXPECT_FALSE(constants.Get(20_i)->Equal(constants.Get(10_i)));  // different values, swapped
+    // Same three checks for `_u` literals.
+    EXPECT_TRUE(constants.Get(10_u)->Equal(constants.Get(10_u)));
+    EXPECT_FALSE(constants.Get(10_u)->Equal(constants.Get(20_u)));
+    EXPECT_FALSE(constants.Get(20_u)->Equal(constants.Get(10_u)));
+    // Same three checks for `_f` literals.
+    EXPECT_TRUE(constants.Get(10_f)->Equal(constants.Get(10_f)));
+    EXPECT_FALSE(constants.Get(10_f)->Equal(constants.Get(20_f)));
+    EXPECT_FALSE(constants.Get(20_f)->Equal(constants.Get(10_f)));
+}
+
+TEST_F(ConstantTest_Value, Equal_Splat_Splat) {  // Equal() on two Splat constants
+    auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+    // Two splats of the same vector type, built from the scalars 1 and 2 respectively.
+    auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
+    auto* vec3f_2_2_2 = constants.Splat(vec3f, constants.Get(2_f), 3);
+    // A splat equals itself; splats of different scalars are unequal in either order.
+    EXPECT_TRUE(vec3f_1_1_1->Equal(vec3f_1_1_1));  // same pointer on both sides
+    EXPECT_FALSE(vec3f_2_2_2->Equal(vec3f_1_1_1));
+    EXPECT_FALSE(vec3f_1_1_1->Equal(vec3f_2_2_2));
+}
+
+TEST_F(ConstantTest_Value, Equal_Composite_Composite) {  // Equal() on two Composite constants
+    auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+    // Two composites of the same vector type holding {1, 1, 2} and {1, 2, 1}.
+    auto* vec3f_1_1_2 = constants.Composite(
+        vec3f, utils::Vector{constants.Get(1_f), constants.Get(1_f), constants.Get(2_f)});
+    auto* vec3f_1_2_1 = constants.Composite(
+        vec3f, utils::Vector{constants.Get(1_f), constants.Get(2_f), constants.Get(1_f)});
+    // A composite equals itself; the two composites are unequal in either order.
+    EXPECT_TRUE(vec3f_1_1_2->Equal(vec3f_1_1_2));  // same pointer on both sides
+    EXPECT_FALSE(vec3f_1_2_1->Equal(vec3f_1_1_2));
+    EXPECT_FALSE(vec3f_1_1_2->Equal(vec3f_1_2_1));
+}
+
+TEST_F(ConstantTest_Value, Equal_Splat_Composite) {  // Equal() on a Splat and a Composite
+    auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+    // A splat of the scalar 1 and a composite {1, 2, 1}, both of the same vector type.
+    auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
+    auto* vec3f_1_2_1 = constants.Composite(
+        vec3f, utils::Vector{constants.Get(1_f), constants.Get(2_f), constants.Get(1_f)});
+    // The splat equals itself; splat and composite are unequal in either order.
+    EXPECT_TRUE(vec3f_1_1_1->Equal(vec3f_1_1_1));  // same pointer on both sides
+    EXPECT_FALSE(vec3f_1_2_1->Equal(vec3f_1_1_1));
+    EXPECT_FALSE(vec3f_1_1_1->Equal(vec3f_1_2_1));
+}
+
+}  // namespace
+}  // namespace tint::constant