[tint][constant] Optimize Value::Equal() for splats
Element-wise comparisons are slow and unnecessary for large `constant::Splat` values.
Bug: chromium:1449538
Change-Id: I17dbfe1fffabbdb45f48920ba8de3af65ed24cae
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/135262
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index ce779bc..006ed50 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1699,6 +1699,7 @@
"constant/manager_test.cc",
"constant/scalar_test.cc",
"constant/splat_test.cc",
+ "constant/value_test.cc",
]
}
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 2ecb97a..474b56e 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -940,6 +940,7 @@
constant/manager_test.cc
constant/scalar_test.cc
constant/splat_test.cc
+ constant/value_test.cc
debug_test.cc
diagnostic/diagnostic_test.cc
diagnostic/formatter_test.cc
diff --git a/src/tint/constant/splat_test.cc b/src/tint/constant/splat_test.cc
index 0e2f21e..e310d63 100644
--- a/src/tint/constant/splat_test.cc
+++ b/src/tint/constant/splat_test.cc
@@ -31,9 +31,9 @@
auto* fNeg0 = constants.Get(-0_f);
auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
- auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
- auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
EXPECT_TRUE(SpfPos0->AllZero());
EXPECT_FALSE(SpfNeg0->AllZero());
@@ -47,9 +47,9 @@
auto* fNeg0 = constants.Get(-0_f);
auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
- auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
- auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
EXPECT_TRUE(SpfPos0->AnyZero());
EXPECT_FALSE(SpfNeg0->AnyZero());
diff --git a/src/tint/constant/value.cc b/src/tint/constant/value.cc
index 7545731..989f2ba 100644
--- a/src/tint/constant/value.cc
+++ b/src/tint/constant/value.cc
@@ -14,6 +14,7 @@
#include "src/tint/constant/value.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/switch.h"
#include "src/tint/type/array.h"
#include "src/tint/type/matrix.h"
@@ -30,51 +31,68 @@
/// Equal returns true if the constants `a` and `b` are of the same type and value.
bool Value::Equal(const constant::Value* b) const {
+ if (this == b) {
+ return true;
+ }
if (Hash() != b->Hash()) {
return false;
}
if (Type() != b->Type()) {
return false;
}
+
+ auto elements_equal = [&](size_t count) {  // shared by the vector/matrix/struct/array cases below
+ if (count == 0) {  // no elements to compare
+ return true;
+ }
+
+ // Fast path: when both constants are splats, only element 0 of each is compared
+ bool a_is_splat = Is<Splat>();
+ bool b_is_splat = b->Is<Splat>();
+ if (a_is_splat && b_is_splat) {
+ return Index(0)->Equal(b->Index(0));
+ }
+
+ if (a_is_splat) {  // only `this` is a splat: its element 0 is checked against each element of b
+ auto* el_a = Index(0);
+ for (size_t i = 0; i < count; i++) {
+ if (!el_a->Equal(b->Index(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ if (b_is_splat) {  // only `b` is a splat: each element of `this` is checked against b's element 0
+ auto* el_b = b->Index(0);
+ for (size_t i = 0; i < count; i++) {
+ if (!Index(i)->Equal(el_b)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Neither constant is a splat: compare element i of `this` with element i of `b`
+ for (size_t i = 0; i < count; i++) {
+ if (!Index(i)->Equal(b->Index(i))) {
+ return false;
+ }
+ }
+ return true;
+ };
+
return Switch(
Type(), //
- [&](const type::Vector* vec) {
- for (size_t i = 0; i < vec->Width(); i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
- },
- [&](const type::Matrix* mat) {
- for (size_t i = 0; i < mat->columns(); i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
- },
+ [&](const type::Vector* vec) { return elements_equal(vec->Width()); },
+ [&](const type::Matrix* mat) { return elements_equal(mat->columns()); },
+ [&](const type::Struct* str) { return elements_equal(str->Members().Length()); },
[&](const type::Array* arr) {
- if (auto count = arr->ConstantCount()) {
- for (size_t i = 0; i < count; i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
+ if (auto n = arr->ConstantCount()) {
+ return elements_equal(*n);
}
-
return false;
},
- [&](const type::Struct* str) {
- auto count = str->Members().Length();
- for (size_t i = 0; i < count; i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
- },
[&](Default) {
auto va = InternalValue();
auto vb = b->InternalValue();
diff --git a/src/tint/constant/value_test.cc b/src/tint/constant/value_test.cc
new file mode 100644
index 0000000..e715c21
--- /dev/null
+++ b/src/tint/constant/value_test.cc
@@ -0,0 +1,78 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/constant/splat.h"
+
+#include "src/tint/constant/scalar.h"
+#include "src/tint/constant/test_helper.h"
+
+namespace tint::constant {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using ConstantTest_Value = TestHelper;
+
+TEST_F(ConstantTest_Value, Equal_Scalar_Scalar) {  // scalar Equal() for `_i`, `_u` and `_f` literals
+ EXPECT_TRUE(constants.Get(10_i)->Equal(constants.Get(10_i)));   // same value
+ EXPECT_FALSE(constants.Get(10_i)->Equal(constants.Get(20_i)));  // different values
+ EXPECT_FALSE(constants.Get(20_i)->Equal(constants.Get(10_i)));  // operands swapped
+
+ EXPECT_TRUE(constants.Get(10_u)->Equal(constants.Get(10_u)));
+ EXPECT_FALSE(constants.Get(10_u)->Equal(constants.Get(20_u)));
+ EXPECT_FALSE(constants.Get(20_u)->Equal(constants.Get(10_u)));
+
+ EXPECT_TRUE(constants.Get(10_f)->Equal(constants.Get(10_f)));
+ EXPECT_FALSE(constants.Get(10_f)->Equal(constants.Get(20_f)));
+ EXPECT_FALSE(constants.Get(20_f)->Equal(constants.Get(10_f)));
+}
+
+TEST_F(ConstantTest_Value, Equal_Splat_Splat) {  // Equal() where both operands come from Splat()
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+
+ auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
+ auto* vec3f_2_2_2 = constants.Splat(vec3f, constants.Get(2_f), 3);
+
+ EXPECT_TRUE(vec3f_1_1_1->Equal(vec3f_1_1_1));   // NOTE(review): same pointer on both sides
+ EXPECT_FALSE(vec3f_2_2_2->Equal(vec3f_1_1_1));  // different splat element
+ EXPECT_FALSE(vec3f_1_1_1->Equal(vec3f_2_2_2));  // operands swapped
+}
+
+TEST_F(ConstantTest_Value, Equal_Composite_Composite) {  // both operands come from Composite()
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+
+ auto* vec3f_1_1_2 = constants.Composite(
+ vec3f, utils::Vector{constants.Get(1_f), constants.Get(1_f), constants.Get(2_f)});
+ auto* vec3f_1_2_1 = constants.Composite(
+ vec3f, utils::Vector{constants.Get(1_f), constants.Get(2_f), constants.Get(1_f)});
+
+ EXPECT_TRUE(vec3f_1_1_2->Equal(vec3f_1_1_2));   // NOTE(review): same pointer on both sides
+ EXPECT_FALSE(vec3f_1_2_1->Equal(vec3f_1_1_2));  // same element values, different order
+ EXPECT_FALSE(vec3f_1_1_2->Equal(vec3f_1_2_1));  // operands swapped
+}
+
+TEST_F(ConstantTest_Value, Equal_Splat_Composite) {  // one operand from Splat(), one from Composite()
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+
+ auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
+ auto* vec3f_1_2_1 = constants.Composite(
+ vec3f, utils::Vector{constants.Get(1_f), constants.Get(2_f), constants.Get(1_f)});
+
+ EXPECT_TRUE(vec3f_1_1_1->Equal(vec3f_1_1_1));   // NOTE(review): splat vs itself, not splat vs composite
+ EXPECT_FALSE(vec3f_1_2_1->Equal(vec3f_1_1_1));  // composite receiver, splat argument
+ EXPECT_FALSE(vec3f_1_1_1->Equal(vec3f_1_2_1));  // splat receiver, composite argument
+}
+
+} // namespace
+} // namespace tint::constant