[tint][constant] Optimize Value::Equal() for splats
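Element-wise comparisons are slow and unnecessary for large `constant::Splat` values.
When both operands are splats, compare only their first element, and return early when a value is compared with itself.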
Bug: chromium:1449538
Change-Id: I17dbfe1fffabbdb45f48920ba8de3af65ed24cae
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/135262
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index ce779bc..006ed50 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1699,6 +1699,7 @@
"constant/manager_test.cc",
"constant/scalar_test.cc",
"constant/splat_test.cc",
+ "constant/value_test.cc",
]
}
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 2ecb97a..474b56e 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -940,6 +940,7 @@
constant/manager_test.cc
constant/scalar_test.cc
constant/splat_test.cc
+ constant/value_test.cc
debug_test.cc
diagnostic/diagnostic_test.cc
diagnostic/formatter_test.cc
diff --git a/src/tint/constant/splat_test.cc b/src/tint/constant/splat_test.cc
index 0e2f21e..e310d63 100644
--- a/src/tint/constant/splat_test.cc
+++ b/src/tint/constant/splat_test.cc
@@ -31,9 +31,9 @@
auto* fNeg0 = constants.Get(-0_f);
auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
- auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
- auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
EXPECT_TRUE(SpfPos0->AllZero());
EXPECT_FALSE(SpfNeg0->AllZero());
@@ -47,9 +47,9 @@
auto* fNeg0 = constants.Get(-0_f);
auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2);
- auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2);
- auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
EXPECT_TRUE(SpfPos0->AnyZero());
EXPECT_FALSE(SpfNeg0->AnyZero());
diff --git a/src/tint/constant/value.cc b/src/tint/constant/value.cc
index 7545731..989f2ba 100644
--- a/src/tint/constant/value.cc
+++ b/src/tint/constant/value.cc
@@ -14,6 +14,7 @@
#include "src/tint/constant/value.h"
+#include "src/tint/constant/splat.h"
#include "src/tint/switch.h"
#include "src/tint/type/array.h"
#include "src/tint/type/matrix.h"
@@ -30,51 +31,68 @@
/// Equal returns true if the constants `a` and `b` are of the same type and value.
bool Value::Equal(const constant::Value* b) const {
+ if (this == b) {
+ return true;
+ }
if (Hash() != b->Hash()) {
return false;
}
if (Type() != b->Type()) {
return false;
}
+
+ auto elements_equal = [&](size_t count) {
+ if (count == 0) {
+ return true;
+ }
+
+ // Avoid per-element comparisons if the constants are splats
+ bool a_is_splat = Is<Splat>();
+ bool b_is_splat = b->Is<Splat>();
+ if (a_is_splat && b_is_splat) {
+ return Index(0)->Equal(b->Index(0));
+ }
+
+ if (a_is_splat) {
+ auto* el_a = Index(0);
+ for (size_t i = 0; i < count; i++) {
+ if (!el_a->Equal(b->Index(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ if (b_is_splat) {
+ auto* el_b = b->Index(0);
+ for (size_t i = 0; i < count; i++) {
+ if (!Index(i)->Equal(el_b)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Per-element comparison
+ for (size_t i = 0; i < count; i++) {
+ if (!Index(i)->Equal(b->Index(i))) {
+ return false;
+ }
+ }
+ return true;
+ };
+
return Switch(
Type(), //
- [&](const type::Vector* vec) {
- for (size_t i = 0; i < vec->Width(); i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
- },
- [&](const type::Matrix* mat) {
- for (size_t i = 0; i < mat->columns(); i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
- },
+ [&](const type::Vector* vec) { return elements_equal(vec->Width()); },
+ [&](const type::Matrix* mat) { return elements_equal(mat->columns()); },
+ [&](const type::Struct* str) { return elements_equal(str->Members().Length()); },
[&](const type::Array* arr) {
- if (auto count = arr->ConstantCount()) {
- for (size_t i = 0; i < count; i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
+ if (auto n = arr->ConstantCount()) {
+ return elements_equal(*n);
}
-
return false;
},
- [&](const type::Struct* str) {
- auto count = str->Members().Length();
- for (size_t i = 0; i < count; i++) {
- if (!Index(i)->Equal(b->Index(i))) {
- return false;
- }
- }
- return true;
- },
[&](Default) {
auto va = InternalValue();
auto vb = b->InternalValue();
diff --git a/src/tint/constant/value_test.cc b/src/tint/constant/value_test.cc
new file mode 100644
index 0000000..e715c21
--- /dev/null
+++ b/src/tint/constant/value_test.cc
@@ -0,0 +1,78 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/constant/splat.h"
+
+#include "src/tint/constant/scalar.h"
+#include "src/tint/constant/test_helper.h"
+
+namespace tint::constant {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+
+using ConstantTest_Value = TestHelper;
+
+TEST_F(ConstantTest_Value, Equal_Scalar_Scalar) {
+ EXPECT_TRUE(constants.Get(10_i)->Equal(constants.Get(10_i)));
+ EXPECT_FALSE(constants.Get(10_i)->Equal(constants.Get(20_i)));
+ EXPECT_FALSE(constants.Get(20_i)->Equal(constants.Get(10_i)));
+
+ EXPECT_TRUE(constants.Get(10_u)->Equal(constants.Get(10_u)));
+ EXPECT_FALSE(constants.Get(10_u)->Equal(constants.Get(20_u)));
+ EXPECT_FALSE(constants.Get(20_u)->Equal(constants.Get(10_u)));
+
+ EXPECT_TRUE(constants.Get(10_f)->Equal(constants.Get(10_f)));
+ EXPECT_FALSE(constants.Get(10_f)->Equal(constants.Get(20_f)));
+ EXPECT_FALSE(constants.Get(20_f)->Equal(constants.Get(10_f)));
+}
+
+TEST_F(ConstantTest_Value, Equal_Splat_Splat) {
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+
+ auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
+ auto* vec3f_2_2_2 = constants.Splat(vec3f, constants.Get(2_f), 3);
+
+ EXPECT_TRUE(vec3f_1_1_1->Equal(vec3f_1_1_1));
+ EXPECT_FALSE(vec3f_2_2_2->Equal(vec3f_1_1_1));
+ EXPECT_FALSE(vec3f_1_1_1->Equal(vec3f_2_2_2));
+}
+
+TEST_F(ConstantTest_Value, Equal_Composite_Composite) {
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+
+ auto* vec3f_1_1_2 = constants.Composite(
+ vec3f, utils::Vector{constants.Get(1_f), constants.Get(1_f), constants.Get(2_f)});
+ auto* vec3f_1_2_1 = constants.Composite(
+ vec3f, utils::Vector{constants.Get(1_f), constants.Get(2_f), constants.Get(1_f)});
+
+ EXPECT_TRUE(vec3f_1_1_2->Equal(vec3f_1_1_2));
+ EXPECT_FALSE(vec3f_1_2_1->Equal(vec3f_1_1_2));
+ EXPECT_FALSE(vec3f_1_1_2->Equal(vec3f_1_2_1));
+}
+
+TEST_F(ConstantTest_Value, Equal_Splat_Composite) {
+ auto* vec3f = create<type::Vector>(create<type::F32>(), 3u);
+
+ auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
+ auto* vec3f_1_2_1 = constants.Composite(
+ vec3f, utils::Vector{constants.Get(1_f), constants.Get(2_f), constants.Get(1_f)});
+
+ EXPECT_TRUE(vec3f_1_1_1->Equal(vec3f_1_1_1));
+ EXPECT_FALSE(vec3f_1_2_1->Equal(vec3f_1_1_1));
+ EXPECT_FALSE(vec3f_1_1_1->Equal(vec3f_1_2_1));
+}
+
+} // namespace
+} // namespace tint::constant
diff --git a/test/tint/bug/chromium/1449538.wgsl b/test/tint/bug/chromium/1449538.wgsl
new file mode 100644
index 0000000..0e89f90
--- /dev/null
+++ b/test/tint/bug/chromium/1449538.wgsl
@@ -0,0 +1,10 @@
+fn f() {
+ for (var i0520 = array<i32, 636109182>()[ 0]; false;) {}
+ for (var i62 = array<i32, 656633 >()[0]; false;) {}
+ for (var i0520 = array<i32, 636109182>()[ 0]; false;) {}
+ for (var i62 = array<i32, 656633 >()[0]; false;) {}
+ for (var i62 = array<i32, 6566335>()[341]; false;) {}
+ for (var i60 = array<i32, 1>()[0]; false;) {}
+ for (var i62 = array<i32, 6566335>()[341]; false;) {}
+ for (var i60 = array<i32, 1>()[0]; false;) {}
+}
diff --git a/test/tint/bug/chromium/1449538.wgsl.expected.dxc.hlsl b/test/tint/bug/chromium/1449538.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..bb7f436
--- /dev/null
+++ b/test/tint/bug/chromium/1449538.wgsl.expected.dxc.hlsl
@@ -0,0 +1,39 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+void f() {
+ {
+ for(int i0520 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i0520 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i60 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i60 = 0; false; ) {
+ }
+ }
+}
diff --git a/test/tint/bug/chromium/1449538.wgsl.expected.fxc.hlsl b/test/tint/bug/chromium/1449538.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..bb7f436
--- /dev/null
+++ b/test/tint/bug/chromium/1449538.wgsl.expected.fxc.hlsl
@@ -0,0 +1,39 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+void f() {
+ {
+ for(int i0520 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i0520 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i60 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i60 = 0; false; ) {
+ }
+ }
+}
diff --git a/test/tint/bug/chromium/1449538.wgsl.expected.glsl b/test/tint/bug/chromium/1449538.wgsl.expected.glsl
new file mode 100644
index 0000000..e797aa3
--- /dev/null
+++ b/test/tint/bug/chromium/1449538.wgsl.expected.glsl
@@ -0,0 +1,41 @@
+#version 310 es
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+ return;
+}
+void f() {
+ {
+ for(int i0520 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i0520 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i60 = 0; false; ) {
+ }
+ }
+ {
+ for(int i62 = 0; false; ) {
+ }
+ }
+ {
+ for(int i60 = 0; false; ) {
+ }
+ }
+}
+
diff --git a/test/tint/bug/chromium/1449538.wgsl.expected.msl b/test/tint/bug/chromium/1449538.wgsl.expected.msl
new file mode 100644
index 0000000..81a4429
--- /dev/null
+++ b/test/tint/bug/chromium/1449538.wgsl.expected.msl
@@ -0,0 +1,22 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void f() {
+ for(int i0520 = 0; false; ) {
+ }
+ for(int i62 = 0; false; ) {
+ }
+ for(int i0520 = 0; false; ) {
+ }
+ for(int i62 = 0; false; ) {
+ }
+ for(int i62 = 0; false; ) {
+ }
+ for(int i60 = 0; false; ) {
+ }
+ for(int i62 = 0; false; ) {
+ }
+ for(int i60 = 0; false; ) {
+ }
+}
+
diff --git a/test/tint/bug/chromium/1449538.wgsl.expected.spvasm b/test/tint/bug/chromium/1449538.wgsl.expected.spvasm
new file mode 100644
index 0000000..d6c641b
--- /dev/null
+++ b/test/tint/bug/chromium/1449538.wgsl.expected.spvasm
@@ -0,0 +1,162 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 68
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+ OpExecutionMode %unused_entry_point LocalSize 1 1 1
+ OpName %unused_entry_point "unused_entry_point"
+ OpName %f "f"
+ OpName %i0520 "i0520"
+ OpName %i62 "i62"
+ OpName %i0520_0 "i0520"
+ OpName %i62_0 "i62"
+ OpName %i62_1 "i62"
+ OpName %i60 "i60"
+ OpName %i62_2 "i62"
+ OpName %i60_0 "i60"
+ %void = OpTypeVoid
+ %1 = OpTypeFunction %void
+ %int = OpTypeInt 32 1
+ %8 = OpConstantNull %int
+%_ptr_Function_int = OpTypePointer Function %int
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+%unused_entry_point = OpFunction %void None %1
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %f = OpFunction %void None %1
+ %6 = OpLabel
+ %i0520 = OpVariable %_ptr_Function_int Function %8
+ %i62 = OpVariable %_ptr_Function_int Function %8
+ %i0520_0 = OpVariable %_ptr_Function_int Function %8
+ %i62_0 = OpVariable %_ptr_Function_int Function %8
+ %i62_1 = OpVariable %_ptr_Function_int Function %8
+ %i60 = OpVariable %_ptr_Function_int Function %8
+ %i62_2 = OpVariable %_ptr_Function_int Function %8
+ %i60_0 = OpVariable %_ptr_Function_int Function %8
+ OpStore %i0520 %8
+ OpBranch %11
+ %11 = OpLabel
+ OpLoopMerge %12 %13 None
+ OpBranch %14
+ %14 = OpLabel
+ OpSelectionMerge %17 None
+ OpBranchConditional %true %18 %17
+ %18 = OpLabel
+ OpBranch %12
+ %17 = OpLabel
+ OpBranch %13
+ %13 = OpLabel
+ OpBranch %11
+ %12 = OpLabel
+ OpStore %i62 %8
+ OpBranch %20
+ %20 = OpLabel
+ OpLoopMerge %21 %22 None
+ OpBranch %23
+ %23 = OpLabel
+ OpSelectionMerge %24 None
+ OpBranchConditional %true %25 %24
+ %25 = OpLabel
+ OpBranch %21
+ %24 = OpLabel
+ OpBranch %22
+ %22 = OpLabel
+ OpBranch %20
+ %21 = OpLabel
+ OpStore %i0520_0 %8
+ OpBranch %27
+ %27 = OpLabel
+ OpLoopMerge %28 %29 None
+ OpBranch %30
+ %30 = OpLabel
+ OpSelectionMerge %31 None
+ OpBranchConditional %true %32 %31
+ %32 = OpLabel
+ OpBranch %28
+ %31 = OpLabel
+ OpBranch %29
+ %29 = OpLabel
+ OpBranch %27
+ %28 = OpLabel
+ OpStore %i62_0 %8
+ OpBranch %34
+ %34 = OpLabel
+ OpLoopMerge %35 %36 None
+ OpBranch %37
+ %37 = OpLabel
+ OpSelectionMerge %38 None
+ OpBranchConditional %true %39 %38
+ %39 = OpLabel
+ OpBranch %35
+ %38 = OpLabel
+ OpBranch %36
+ %36 = OpLabel
+ OpBranch %34
+ %35 = OpLabel
+ OpStore %i62_1 %8
+ OpBranch %41
+ %41 = OpLabel
+ OpLoopMerge %42 %43 None
+ OpBranch %44
+ %44 = OpLabel
+ OpSelectionMerge %45 None
+ OpBranchConditional %true %46 %45
+ %46 = OpLabel
+ OpBranch %42
+ %45 = OpLabel
+ OpBranch %43
+ %43 = OpLabel
+ OpBranch %41
+ %42 = OpLabel
+ OpStore %i60 %8
+ OpBranch %48
+ %48 = OpLabel
+ OpLoopMerge %49 %50 None
+ OpBranch %51
+ %51 = OpLabel
+ OpSelectionMerge %52 None
+ OpBranchConditional %true %53 %52
+ %53 = OpLabel
+ OpBranch %49
+ %52 = OpLabel
+ OpBranch %50
+ %50 = OpLabel
+ OpBranch %48
+ %49 = OpLabel
+ OpStore %i62_2 %8
+ OpBranch %55
+ %55 = OpLabel
+ OpLoopMerge %56 %57 None
+ OpBranch %58
+ %58 = OpLabel
+ OpSelectionMerge %59 None
+ OpBranchConditional %true %60 %59
+ %60 = OpLabel
+ OpBranch %56
+ %59 = OpLabel
+ OpBranch %57
+ %57 = OpLabel
+ OpBranch %55
+ %56 = OpLabel
+ OpStore %i60_0 %8
+ OpBranch %62
+ %62 = OpLabel
+ OpLoopMerge %63 %64 None
+ OpBranch %65
+ %65 = OpLabel
+ OpSelectionMerge %66 None
+ OpBranchConditional %true %67 %66
+ %67 = OpLabel
+ OpBranch %63
+ %66 = OpLabel
+ OpBranch %64
+ %64 = OpLabel
+ OpBranch %62
+ %63 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/bug/chromium/1449538.wgsl.expected.wgsl b/test/tint/bug/chromium/1449538.wgsl.expected.wgsl
new file mode 100644
index 0000000..9ebae82
--- /dev/null
+++ b/test/tint/bug/chromium/1449538.wgsl.expected.wgsl
@@ -0,0 +1,18 @@
+fn f() {
+ for(var i0520 = array<i32, 636109182>()[0]; false; ) {
+ }
+ for(var i62 = array<i32, 656633>()[0]; false; ) {
+ }
+ for(var i0520 = array<i32, 636109182>()[0]; false; ) {
+ }
+ for(var i62 = array<i32, 656633>()[0]; false; ) {
+ }
+ for(var i62 = array<i32, 6566335>()[341]; false; ) {
+ }
+ for(var i60 = array<i32, 1>()[0]; false; ) {
+ }
+ for(var i62 = array<i32, 6566335>()[341]; false; ) {
+ }
+ for(var i60 = array<i32, 1>()[0]; false; ) {
+ }
+}