Import Tint changes from Dawn
Changes:
- fbb339ff97405b49762883d5e3155f83527ef417 tint: Add sem::Load node by Ben Clayton <bclayton@google.com>
- bed9c98a07fd83802d04cff5bdb11b66835401a6 tint: const eval of fract by Antonio Maiorano <amaiorano@google.com>
- 60eac7250428232c5020001b6f578370b58826be Move constants into a Castable hierarchy. by dan sinclair <dsinclair@chromium.org>
- 1c953fa86362195082fd96341037666ea2cb0905 tint: fix build by Antonio Maiorano <amaiorano@google.com>
- 51d88ebf30dd8b40010c3553c79cfb7c270a9ec5 tint/utils: Reduce cost of HashCombine by Ben Clayton <bclayton@google.com>
- be9696777842d0de4d6639a9214774331b17c287 tint: const eval of pow builtin by Antonio Maiorano <amaiorano@google.com>
- 7f5b9d0b6f0d5f43fa89b961ebfa9c379bdd4e8e tint: add CheckedPow by Antonio Maiorano <amaiorano@google.com>
- e3f3de773a19824f381c0c462f0de7864f76a704 tint/resolver: Fix const-eval Equal() by Ben Clayton <bclayton@google.com>
- f2b86aaffbaef0563a89e1f187e3e5bbef4c2d25 tint: Add hash randomization by Ben Clayton <bclayton@google.com>
- 10182c46d95ea0ba2de20e39e3f1ee177da2f185 Move sem::Constant to constant::Constant by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: fbb339ff97405b49762883d5e3155f83527ef417
Change-Id: I992b659ca76dee32ef4c2b68fa0ab599b81d88e5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/114281
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 930958b..45a965d 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -366,6 +366,8 @@
"castable.h",
"clone_context.cc",
"clone_context.h",
+ "constant/constant.h",
+ "constant/node.h",
"debug.cc",
"debug.h",
"demangler.cc",
@@ -421,13 +423,13 @@
"sem/builtin_type.h",
"sem/call.h",
"sem/call_target.h",
- "sem/constant.h",
"sem/evaluation_stage.h",
"sem/expression.h",
"sem/for_loop_statement.h",
"sem/if_statement.h",
"sem/index_accessor_expression.h",
"sem/info.h",
+ "sem/load.h",
"sem/loop_statement.h",
"sem/materialize.h",
"sem/module.h",
@@ -649,8 +651,6 @@
"sem/call.h",
"sem/call_target.cc",
"sem/call_target.h",
- "sem/constant.cc",
- "sem/constant.h",
"sem/evaluation_stage.h",
"sem/expression.cc",
"sem/expression.h",
@@ -663,6 +663,8 @@
"sem/index_accessor_expression.h",
"sem/info.cc",
"sem/info.h",
+ "sem/load.cc",
+ "sem/load.h",
"sem/loop_statement.cc",
"sem/loop_statement.h",
"sem/materialize.cc",
@@ -758,8 +760,19 @@
public_deps = [ ":libtint_core_all_src" ]
}
+libtint_source_set("libtint_constant_src") {
+ sources = [
+ "constant/constant.cc",
+ "constant/constant.h",
+ "constant/node.cc",
+ "constant/node.h",
+ ]
+ public_deps = [ ":libtint_core_all_src" ]
+}
+
libtint_source_set("libtint_core_src") {
public_deps = [
+ ":libtint_constant_src",
":libtint_core_all_src",
":libtint_sem_src",
":libtint_type_src",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 7dd09c2..630f998 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -254,6 +254,10 @@
castable.h
clone_context.cc
clone_context.h
+ constant/constant.cc
+ constant/constant.h
+ constant/node.cc
+ constant/node.h
demangler.cc
demangler.h
inspector/entry_point.cc
@@ -306,8 +310,6 @@
sem/call_target.h
sem/call.cc
sem/call.h
- sem/constant.cc
- sem/constant.h
sem/evaluation_stage.h
sem/expression.cc
sem/expression.h
@@ -320,6 +322,8 @@
sem/index_accessor_expression.h
sem/info.cc
sem/info.h
+ sem/load.cc
+ sem/load.h
sem/loop_statement.cc
sem/loop_statement.h
sem/materialize.cc
diff --git a/src/tint/sem/constant.cc b/src/tint/constant/constant.cc
similarity index 81%
rename from src/tint/sem/constant.cc
rename to src/tint/constant/constant.cc
index 70bb08c..a5b0caf 100644
--- a/src/tint/sem/constant.cc
+++ b/src/tint/constant/constant.cc
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/sem/constant.h"
+#include "src/tint/constant/constant.h"
-namespace tint::sem {
+TINT_INSTANTIATE_TYPEINFO(tint::constant::Constant);
+
+namespace tint::constant {
Constant::Constant() = default;
Constant::~Constant() = default;
-} // namespace tint::sem
+} // namespace tint::constant
diff --git a/src/tint/sem/constant.h b/src/tint/constant/constant.h
similarity index 87%
rename from src/tint/sem/constant.h
rename to src/tint/constant/constant.h
index fcf45c7..bd60f49 100644
--- a/src/tint/sem/constant.h
+++ b/src/tint/constant/constant.h
@@ -1,4 +1,4 @@
-// Copyright 2021 The Tint Authors.
+// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0(the "License");
// you may not use this file except in compliance with the License.
@@ -12,24 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SRC_TINT_SEM_CONSTANT_H_
-#define SRC_TINT_SEM_CONSTANT_H_
+#ifndef SRC_TINT_CONSTANT_CONSTANT_H_
+#define SRC_TINT_CONSTANT_CONSTANT_H_
#include <variant>
+#include "src/tint/castable.h"
+#include "src/tint/constant/node.h"
#include "src/tint/number.h"
#include "src/tint/type/type.h"
-namespace tint::sem {
+namespace tint::constant {
/// Constant is the interface to a compile-time evaluated expression value.
-class Constant {
+class Constant : public Castable<Constant, Node> {
public:
/// Constructor
Constant();
/// Destructor
- virtual ~Constant();
+ ~Constant() override;
/// @returns the type of the constant
virtual const type::Type* Type() const = 0;
@@ -73,6 +75,6 @@
}
};
-} // namespace tint::sem
+} // namespace tint::constant
-#endif // SRC_TINT_SEM_CONSTANT_H_
+#endif // SRC_TINT_CONSTANT_CONSTANT_H_
diff --git a/src/tint/constant/node.cc b/src/tint/constant/node.cc
new file mode 100644
index 0000000..8f18ba6
--- /dev/null
+++ b/src/tint/constant/node.cc
@@ -0,0 +1,27 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/constant/node.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::constant::Node);
+
+namespace tint::constant {
+
+Node::Node() = default;
+
+Node::Node(const Node&) = default;
+
+Node::~Node() = default;
+
+} // namespace tint::constant
diff --git a/src/tint/constant/node.h b/src/tint/constant/node.h
new file mode 100644
index 0000000..41d00e0
--- /dev/null
+++ b/src/tint/constant/node.h
@@ -0,0 +1,37 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_CONSTANT_NODE_H_
+#define SRC_TINT_CONSTANT_NODE_H_
+
+#include "src/tint/castable.h"
+
+namespace tint::constant {
+
+/// Node is the base class for all constant nodes
+class Node : public Castable<Node> {
+ public:
+ /// Constructor
+ Node();
+
+ /// Copy constructor
+ Node(const Node&);
+
+ /// Destructor
+ ~Node() override;
+};
+
+} // namespace tint::constant
+
+#endif // SRC_TINT_CONSTANT_NODE_H_
diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def
index 7e36750..db2dd09 100644
--- a/src/tint/intrinsics.def
+++ b/src/tint/intrinsics.def
@@ -494,8 +494,8 @@
@const fn floor<N: num, T: fa_f32_f16>(@test_value(1.5) vec<N, T>) -> vec<N, T>
@const fn fma<T: fa_f32_f16>(T, T, T) -> T
@const fn fma<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
-fn fract<T: f32_f16>(T) -> T
-fn fract<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
+@const fn fract<T: fa_f32_f16>(@test_value(1.25) T) -> T
+@const fn fract<N: num, T: fa_f32_f16>(@test_value(1.25) vec<N, T>) -> vec<N, T>
@const fn frexp<T: fa_f32_f16>(T) -> __frexp_result<T>
@const fn frexp<N: num, T: fa_f32_f16>(vec<N, T>) -> __frexp_result_vec<N, T>
@stage("fragment") fn fwidth(f32) -> f32
@@ -531,8 +531,8 @@
@const fn pack2x16unorm(vec2<f32>) -> u32
@const fn pack4x8snorm(vec4<f32>) -> u32
@const fn pack4x8unorm(vec4<f32>) -> u32
-fn pow<T: f32_f16>(T, T) -> T
-fn pow<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
+@const fn pow<T: fa_f32_f16>(T, T) -> T
+@const fn pow<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
@const fn quantizeToF16(f32) -> f32
@const fn quantizeToF16<N: num>(vec<N, f32>) -> vec<N, f32>
@const fn radians<T: fa_f32_f16>(T) -> T
diff --git a/src/tint/number.h b/src/tint/number.h
index 975ce79..82a9963 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -412,6 +412,9 @@
#endif
#endif
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
+TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+
/// @returns a + b, or an empty optional if the resulting value overflowed the AInt
inline std::optional<AInt> CheckedAdd(AInt a, AInt b) {
int64_t result;
@@ -582,17 +585,29 @@
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
- // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
- TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);
-
if (auto mul = CheckedMul(a, b)) {
return CheckedAdd(mul.value(), c);
}
return {};
-
- TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);
}
+/// @returns the value of `base` raised to the power `exp`, or an empty optional if the operation
+/// cannot be performed.
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedPow(FloatingPointT base, FloatingPointT exp) {
+ static_assert(IsNumber<FloatingPointT>);
+ if ((base < 0) || (base == 0 && exp <= 0)) {
+ return {};
+ }
+ auto result = FloatingPointT{std::pow(base.value, exp.value)};
+ if (!std::isfinite(result.value)) {
+ return {};
+ }
+ return result;
+}
+
+TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+
} // namespace tint
namespace tint::number_suffixes {
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index a795840..fe03663 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -399,12 +399,32 @@
std::variant<std::optional<AFloat>, std::optional<f32>, std::optional<f16>>;
using BinaryCheckedCase_Float = std::tuple<FloatExpectedTypes, FloatInputTypes, FloatInputTypes>;
+/// Validates that result is equal to expect. If `float_comp` is true, uses EXPECT_FLOAT_EQ to
+/// compare the values.
+template <typename T>
+void ValidateResult(std::optional<T> result, std::optional<T> expect, bool float_comp = false) {
+ if (!expect) {
+ EXPECT_TRUE(!result) << *result;
+ } else {
+ ASSERT_TRUE(result);
+ if constexpr (IsIntegral<T>) {
+ EXPECT_EQ(*result, *expect);
+ } else {
+ if (float_comp) {
+ EXPECT_FLOAT_EQ(*result, *expect);
+ } else {
+ EXPECT_EQ(*result, *expect);
+ }
+ }
+ }
+}
+
TEST_P(CheckedAddTest_AInt, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedAdd(a, b) == expect) << std::hex << "0x" << a << " + 0x" << b;
- EXPECT_TRUE(CheckedAdd(b, a) == expect) << std::hex << "0x" << a << " + 0x" << b;
+ ValidateResult(CheckedAdd(a, b), expect);
+ ValidateResult(CheckedAdd(b, a), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedAddTest_AInt,
@@ -477,7 +497,7 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedSub(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
+ ValidateResult(CheckedSub(a, b), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedSubTest_AInt,
@@ -514,8 +534,7 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedSub(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " - 0x" << rhs;
+ ValidateResult(CheckedSub(lhs, rhs), expect);
},
std::get<1>(p));
}
@@ -546,8 +565,8 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedMul(a, b) == expect) << std::hex << "0x" << a << " * 0x" << b;
- EXPECT_TRUE(CheckedMul(b, a) == expect) << std::hex << "0x" << a << " * 0x" << b;
+ ValidateResult(CheckedMul(a, b), expect);
+ ValidateResult(CheckedMul(b, a), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedMulTest_AInt,
@@ -595,10 +614,8 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedMul(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " * 0x" << rhs;
- EXPECT_TRUE(CheckedMul(rhs, lhs) == expect)
- << std::hex << "0x" << lhs << " * 0x" << rhs;
+ ValidateResult(CheckedMul(lhs, rhs), expect);
+ ValidateResult(CheckedMul(rhs, lhs), expect);
},
std::get<1>(p));
}
@@ -628,7 +645,7 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedDiv(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
+ ValidateResult(CheckedDiv(a, b), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedDivTest_AInt,
@@ -657,8 +674,7 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedDiv(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " / 0x" << rhs;
+ ValidateResult(CheckedDiv(lhs, rhs), expect);
},
std::get<1>(p));
}
@@ -692,7 +708,7 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedMod(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
+ EXPECT_TRUE(CheckedMod(a, b) == expect) << std::hex << "0x" << a << " % 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedModTest_AInt,
@@ -726,8 +742,7 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedMod(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " / 0x" << rhs;
+ ValidateResult(CheckedMod(lhs, rhs), expect);
},
std::get<1>(p));
}
@@ -759,6 +774,51 @@
CheckedModTest_FloatCases<f32>(),
CheckedModTest_FloatCases<f16>())));
+using CheckedPowTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedPowTest_Float, Test) {
+ auto& p = GetParam();
+ std::visit(
+ [&](auto&& lhs) {
+ using T = std::decay_t<decltype(lhs)>;
+ auto rhs = std::get<T>(std::get<2>(p));
+ auto expect = std::get<std::optional<T>>(std::get<0>(p));
+ ValidateResult(CheckedPow(lhs, rhs), expect, /* float_comp */ true);
+ },
+ std::get<1>(p));
+}
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedPowTest_FloatCases() {
+ return {
+ {T(0), T(0), T(1)}, //
+ {T(0), T(0), T::Highest()}, //
+ {T(1), T(1), T(1)}, //
+ {T(1), T(1), T::Lowest()}, //
+ {T(4), T(2), T(2)}, //
+ {T(8), T(2), T(3)}, //
+ {T(1), T(1), T::Highest()}, //
+ {T(1), T(1), -T(1)}, //
+ {T(0.25), T(2), -T(2)}, //
+ {T(0.125), T(2), -T(3)}, //
+ {T(15.625), T(2.5), T(3)}, //
+ {T(11.313708498), T(2), T(3.5)}, //
+ {T(24.705294220), T(2.5), T(3.5)}, //
+ {T(0.0883883476), T(2), -T(3.5)}, //
+ {Overflow<T>, -T(1), T(1)}, //
+ {Overflow<T>, -T(1), T::Highest()}, //
+ {Overflow<T>, T::Lowest(), T(1)}, //
+ {Overflow<T>, T::Lowest(), T::Highest()}, //
+ {Overflow<T>, T::Lowest(), T::Lowest()}, //
+ {Overflow<T>, T(0), T(0)}, //
+ {Overflow<T>, T(0), -T(1)}, //
+ {Overflow<T>, T(0), T::Lowest()}, //
+ };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedPowTest_Float,
+ CheckedPowTest_Float,
+ testing::ValuesIn(Concat(CheckedPowTest_FloatCases<AFloat>(),
+ CheckedPowTest_FloatCases<f32>(),
+ CheckedPowTest_FloatCases<f16>())));
+
using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;
using CheckedMaddTest_AInt = testing::TestWithParam<TernaryCheckedCase>;
@@ -767,10 +827,8 @@
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
auto c = std::get<3>(GetParam());
- EXPECT_EQ(CheckedMadd(a, b, c), expect)
- << std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
- EXPECT_EQ(CheckedMadd(b, a, c), expect)
- << std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
+ ValidateResult(CheckedMadd(a, b, c), expect);
+ ValidateResult(CheckedMadd(b, a, c), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedMaddTest_AInt,
diff --git a/src/tint/program.h b/src/tint/program.h
index 2b10f3b..a906041 100644
--- a/src/tint/program.h
+++ b/src/tint/program.h
@@ -19,8 +19,8 @@
#include <unordered_set>
#include "src/tint/ast/function.h"
+#include "src/tint/constant/constant.h"
#include "src/tint/program_id.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/info.h"
#include "src/tint/symbol_table.h"
#include "src/tint/type/manager.h"
@@ -44,8 +44,8 @@
/// SemNodeAllocator is an alias to BlockAllocator<sem::Node>
using SemNodeAllocator = utils::BlockAllocator<sem::Node>;
- /// ConstantAllocator is an alias to BlockAllocator<sem::Constant>
- using ConstantAllocator = utils::BlockAllocator<sem::Constant>;
+ /// ConstantAllocator is an alias to BlockAllocator<constant::Constant>
+ using ConstantAllocator = utils::BlockAllocator<constant::Constant>;
/// Constructor
Program();
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 81495f7..92fdf31 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -87,11 +87,11 @@
#include "src/tint/ast/void.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/constant/constant.h"
#include "src/tint/number.h"
#include "src/tint/program.h"
#include "src/tint/program_id.h"
#include "src/tint/sem/array_count.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/struct.h"
#include "src/tint/type/array.h"
#include "src/tint/type/bool.h"
@@ -265,8 +265,8 @@
/// SemNodeAllocator is an alias to BlockAllocator<sem::Node>
using SemNodeAllocator = utils::BlockAllocator<sem::Node>;
- /// ConstantAllocator is an alias to BlockAllocator<sem::Constant>
- using ConstantAllocator = utils::BlockAllocator<sem::Constant>;
+ /// ConstantAllocator is an alias to BlockAllocator<constant::Constant>
+ using ConstantAllocator = utils::BlockAllocator<constant::Constant>;
/// Constructor
ProgramBuilder();
@@ -465,12 +465,12 @@
return sem_nodes_.Create<T>(std::forward<ARGS>(args)...);
}
- /// Creates a new sem::Constant owned by the ProgramBuilder.
+ /// Creates a new constant::Constant owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, the sem::Node will also be destructed.
/// @param args the arguments to pass to the constructor
/// @returns the node pointer
template <typename T, typename... ARGS>
- traits::EnableIf<traits::IsTypeOrDerived<T, sem::Constant>, T>* create(ARGS&&... args) {
+ traits::EnableIf<traits::IsTypeOrDerived<T, constant::Constant>, T>* create(ARGS&&... args) {
AssertNotMoved();
return constant_nodes_.Create<T>(std::forward<ARGS>(args)...);
}
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index da73215..5d13abf 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -22,8 +22,9 @@
#include <type_traits>
#include <utility>
+#include "src/tint/constant/constant.h"
+#include "src/tint/number.h"
#include "src/tint/program_builder.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/type_initializer.h"
#include "src/tint/type/abstract_float.h"
@@ -204,10 +205,10 @@
}
template <typename NumberT>
-std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
+std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) {
std::stringstream ss;
ss << std::setprecision(20);
- ss << base << "^" << value << " cannot be represented as "
+ ss << base << "^" << exp << " cannot be represented as "
<< "'" << FriendlyName<NumberT>() << "'";
return ss.str();
}
@@ -246,8 +247,13 @@
return count;
}
-/// ImplConstant inherits from sem::Constant to add an private implementation method for conversion.
-struct ImplConstant : public sem::Constant {
+/// ImplConstant inherits from constant::Constant to add an private implementation method for
+/// conversion.
+class ImplConstant : public Castable<ImplConstant, constant::Constant> {
+ public:
+ ImplConstant() = default;
+ ~ImplConstant() override = default;
+
/// Convert attempts to convert the constant value to the given type. On error, Convert()
/// creates a new diagnostic message and returns a Failure.
virtual utils::Result<const ImplConstant*> Convert(ProgramBuilder& builder,
@@ -261,12 +267,13 @@
// Forward declaration
const ImplConstant* CreateComposite(ProgramBuilder& builder,
const type::Type* type,
- utils::VectorRef<const sem::Constant*> elements);
+ utils::VectorRef<const constant::Constant*> elements);
/// Element holds a single scalar or abstract-numeric value.
/// Element implements the Constant interface.
template <typename T>
-struct Element : ImplConstant {
+class Element : public Castable<Element<T>, ImplConstant> {
+ public:
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
"T must be a Number or bool");
@@ -284,7 +291,7 @@
return static_cast<AInt>(value);
}
}
- const sem::Constant* Index(size_t) const override { return nullptr; }
+ const constant::Constant* Index(size_t) const override { return nullptr; }
bool AllZero() const override { return IsPositiveZero(value); }
bool AnyZero() const override { return IsPositiveZero(value); }
bool AllEqual() const override { return true; }
@@ -354,12 +361,13 @@
/// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is
/// identical. Splat may be of a vector, matrix or array type.
/// Splat implements the Constant interface.
-struct Splat : ImplConstant {
- Splat(const type::Type* t, const sem::Constant* e, size_t n) : type(t), el(e), count(n) {}
+class Splat : public Castable<Splat, ImplConstant> {
+ public:
+ Splat(const type::Type* t, const constant::Constant* e, size_t n) : type(t), el(e), count(n) {}
~Splat() override = default;
const type::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
- const sem::Constant* Index(size_t i) const override { return i < count ? el : nullptr; }
+ const constant::Constant* Index(size_t i) const override { return i < count ? el : nullptr; }
bool AllZero() const override { return el->AllZero(); }
bool AnyZero() const override { return el->AnyZero(); }
bool AllEqual() const override { return true; }
@@ -369,8 +377,8 @@
const type::Type* target_ty,
const Source& source) const override {
// Convert the single splatted element type.
- // Note: This file is the only place where `sem::Constant`s are created, so this static_cast
- // is safe.
+ // Note: This file is the only place where `constant::Constant`s are created, so this
+ // static_cast is safe.
auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
builder, type::Type::ElementOf(target_ty), source);
if (!conv_el) {
@@ -383,7 +391,7 @@
}
type::Type const* const type;
- const sem::Constant* el;
+ const constant::Constant* el;
const size_t count;
};
@@ -392,16 +400,17 @@
/// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate Constant type.
/// Composite implements the Constant interface.
-struct Composite : ImplConstant {
+class Composite : public Castable<Composite, ImplConstant> {
+ public:
Composite(const type::Type* t,
- utils::VectorRef<const sem::Constant*> els,
+ utils::VectorRef<const constant::Constant*> els,
bool all_0,
bool any_0)
: type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {}
~Composite() override = default;
const type::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
- const sem::Constant* Index(size_t i) const override {
+ const constant::Constant* Index(size_t i) const override {
return i < elements.Length() ? elements[i] : nullptr;
}
bool AllZero() const override { return all_zero; }
@@ -413,7 +422,7 @@
const type::Type* target_ty,
const Source& source) const override {
// Convert each of the composite element types.
- utils::Vector<const sem::Constant*, 4> conv_els;
+ utils::Vector<const constant::Constant*, 4> conv_els;
conv_els.Reserve(elements.Length());
std::function<const type::Type*(size_t idx)> target_el_ty;
if (auto* str = target_ty->As<type::Struct>()) {
@@ -429,7 +438,7 @@
}
for (auto* el : elements) {
- // Note: This file is the only place where `sem::Constant`s are created, so the
+ // Note: This file is the only place where `constant::Constant`s are created, so the
// static_cast is safe.
auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
builder, target_el_ty(conv_els.Length()), source);
@@ -453,15 +462,33 @@
}
type::Type const* const type;
- const utils::Vector<const sem::Constant*, 8> elements;
+ const utils::Vector<const constant::Constant*, 8> elements;
const bool all_zero;
const bool any_zero;
const size_t hash;
};
+} // namespace
+} // namespace tint::resolver
+
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::ImplConstant);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::AInt>);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::AFloat>);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::i32>);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::u32>);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::f16>);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::f32>);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<bool>);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Splat);
+TINT_INSTANTIATE_TYPEINFO(tint::resolver::Composite);
+
+namespace tint::resolver {
+namespace {
+
/// CreateElement constructs and returns an Element<T>.
template <typename T>
ImplResult CreateElement(ProgramBuilder& builder, const Source& source, const type::Type* t, T v) {
+ static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool");
TINT_ASSERT(Resolver, t->is_scalar());
if constexpr (IsFloatingPoint<T>) {
@@ -496,7 +523,7 @@
},
[&](const type::Struct* s) -> const ImplConstant* {
utils::Hashmap<const type::Type*, const ImplConstant*, 8> zero_by_type;
- utils::Vector<const sem::Constant*, 4> zeros;
+ utils::Vector<const constant::Constant*, 4> zeros;
zeros.Reserve(s->Members().Length());
for (auto* member : s->Members()) {
auto* zero = zero_by_type.GetOrCreate(
@@ -522,7 +549,7 @@
}
/// Equal returns true if the constants `a` and `b` are of the same type and value.
-bool Equal(const sem::Constant* a, const sem::Constant* b) {
+bool Equal(const constant::Constant* a, const constant::Constant* b) {
if (a->Hash() != b->Hash()) {
return false;
}
@@ -559,7 +586,22 @@
return false;
},
- [&](Default) { return a->Value() == b->Value(); });
+ [&](const type::Struct* str) {
+ auto count = str->Members().Length();
+ for (size_t i = 0; i < count; i++) {
+ if (!Equal(a->Index(i), b->Index(i))) {
+ return false;
+ }
+ }
+ return true;
+ },
+ [&](Default) {
+ auto va = a->Value();
+ auto vb = b->Value();
+ TINT_ASSERT(Resolver, !std::holds_alternative<std::monostate>(va));
+ TINT_ASSERT(Resolver, !std::holds_alternative<std::monostate>(vb));
+ return va == vb;
+ });
}
/// CreateComposite is used to construct a constant of a vector, matrix or array type.
@@ -567,7 +609,7 @@
/// depending on the element types and values.
const ImplConstant* CreateComposite(ProgramBuilder& builder,
const type::Type* type,
- utils::VectorRef<const sem::Constant*> elements) {
+ utils::VectorRef<const constant::Constant*> elements) {
if (elements.IsEmpty()) {
return nullptr;
}
@@ -617,7 +659,7 @@
return f(cs...);
}
}
- utils::Vector<const sem::Constant*, 8> els;
+ utils::Vector<const constant::Constant*, 8> els;
els.Reserve(n);
for (uint32_t i = 0; i < n; i++) {
if (auto el = detail::TransformElements(builder, type::Type::ElementOf(composite_ty),
@@ -653,8 +695,8 @@
ImplResult TransformBinaryElements(ProgramBuilder& builder,
const type::Type* composite_ty,
F&& f,
- const sem::Constant* c0,
- const sem::Constant* c1) {
+ const constant::Constant* c0,
+ const constant::Constant* c1) {
uint32_t n0 = 0;
type::Type::ElementOf(c0->Type(), &n0);
uint32_t n1 = 0;
@@ -665,7 +707,7 @@
return f(c0, c1);
}
- utils::Vector<const sem::Constant*, 8> els;
+ utils::Vector<const constant::Constant*, 8> els;
els.Reserve(max_n);
for (uint32_t i = 0; i < max_n; i++) {
auto nested_or_self = [&](auto* c, uint32_t num_elems) {
@@ -1180,8 +1222,8 @@
}
ConstEval::Result ConstEval::Dot(const Source& source,
- const sem::Constant* v1,
- const sem::Constant* v2) {
+ const constant::Constant* v1,
+ const constant::Constant* v2) {
auto* vec_ty = v1->Type()->As<type::Vector>();
TINT_ASSERT(Resolver, vec_ty);
auto* elem_ty = vec_ty->type();
@@ -1208,7 +1250,7 @@
ConstEval::Result ConstEval::Length(const Source& source,
const type::Type* ty,
- const sem::Constant* c0) {
+ const constant::Constant* c0) {
auto* vec_ty = c0->Type()->As<type::Vector>();
// Evaluates to the absolute value of e if T is scalar.
if (vec_ty == nullptr) {
@@ -1229,9 +1271,9 @@
ConstEval::Result ConstEval::Mul(const Source& source,
const type::Type* ty,
- const sem::Constant* v1,
- const sem::Constant* v2) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ const constant::Constant* v1,
+ const constant::Constant* v2) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
return Dispatch_fia_fiu32_f16(MulFunc(source, c0->Type()), c0, c1);
};
return TransformBinaryElements(builder, ty, transform, v1, v2);
@@ -1239,9 +1281,9 @@
ConstEval::Result ConstEval::Sub(const Source& source,
const type::Type* ty,
- const sem::Constant* v1,
- const sem::Constant* v2) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ const constant::Constant* v1,
+ const constant::Constant* v2) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
return Dispatch_fia_fiu32_f16(SubFunc(source, c0->Type()), c0, c1);
};
return TransformBinaryElements(builder, ty, transform, v1, v2);
@@ -1319,7 +1361,7 @@
}
// Multiple arguments. Must be a type initializer.
- utils::Vector<const sem::Constant*, 4> els;
+ utils::Vector<const constant::Constant*, 4> els;
els.Reserve(args.Length());
for (auto* arg : args) {
els.Push(arg->ConstantValue());
@@ -1328,7 +1370,7 @@
}
ConstEval::Result ConstEval::Conv(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
uint32_t el_count = 0;
auto* el_ty = type::Type::ElementOf(ty, &el_count);
@@ -1344,19 +1386,19 @@
}
ConstEval::Result ConstEval::Zero(const type::Type* ty,
- utils::VectorRef<const sem::Constant*>,
+ utils::VectorRef<const constant::Constant*>,
const Source&) {
return ZeroValue(builder, ty);
}
ConstEval::Result ConstEval::Identity(const type::Type*,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source&) {
return args[0];
}
ConstEval::Result ConstEval::VecSplat(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source&) {
if (auto* arg = args[0]) {
return builder.create<Splat>(ty, arg, static_cast<const type::Vector*>(ty)->Width());
@@ -1365,15 +1407,15 @@
}
ConstEval::Result ConstEval::VecInitS(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source&) {
return CreateComposite(builder, ty, args);
}
ConstEval::Result ConstEval::VecInitM(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source&) {
- utils::Vector<const sem::Constant*, 4> els;
+ utils::Vector<const constant::Constant*, 4> els;
for (auto* arg : args) {
auto* val = arg;
if (!val) {
@@ -1397,13 +1439,13 @@
}
ConstEval::Result ConstEval::MatInitS(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source&) {
auto* m = static_cast<const type::Matrix*>(ty);
- utils::Vector<const sem::Constant*, 4> els;
+ utils::Vector<const constant::Constant*, 4> els;
for (uint32_t c = 0; c < m->columns(); c++) {
- utils::Vector<const sem::Constant*, 4> column;
+ utils::Vector<const constant::Constant*, 4> column;
for (uint32_t r = 0; r < m->rows(); r++) {
auto i = r + c * m->rows();
column.Push(args[i]);
@@ -1414,7 +1456,7 @@
}
ConstEval::Result ConstEval::MatInitV(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source&) {
return CreateComposite(builder, ty, args);
}
@@ -1478,9 +1520,9 @@
}
ConstEval::Result ConstEval::OpComplement(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c) {
+ auto transform = [&](const constant::Constant* c) {
auto create = [&](auto i) {
return CreateElement(builder, source, c->Type(), decltype(i)(~i.value));
};
@@ -1490,9 +1532,9 @@
}
ConstEval::Result ConstEval::OpUnaryMinus(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c) {
+ auto transform = [&](const constant::Constant* c) {
auto create = [&](auto i) {
// For signed integrals, avoid C++ UB by not negating the
// smallest negative number. In WGSL, this operation is well
@@ -1515,9 +1557,9 @@
}
ConstEval::Result ConstEval::OpNot(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c) {
+ auto transform = [&](const constant::Constant* c) {
auto create = [&](auto i) {
return CreateElement(builder, source, c->Type(), decltype(i)(!i));
};
@@ -1527,9 +1569,9 @@
}
ConstEval::Result ConstEval::OpPlus(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
return Dispatch_fia_fiu32_f16(AddFunc(source, c0->Type()), c0, c1);
};
@@ -1537,25 +1579,25 @@
}
ConstEval::Result ConstEval::OpMinus(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
return Sub(source, ty, args[0], args[1]);
}
ConstEval::Result ConstEval::OpMultiply(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
return Mul(source, ty, args[0], args[1]);
}
ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* mat_ty = args[0]->Type()->As<type::Matrix>();
auto* vec_ty = args[1]->Type()->As<type::Vector>();
auto* elem_ty = vec_ty->type();
- auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) {
+ auto dot = [&](const constant::Constant* m, size_t row, const constant::Constant* v) {
ImplResult result;
switch (mat_ty->columns()) {
case 2:
@@ -1588,7 +1630,7 @@
return result;
};
- utils::Vector<const sem::Constant*, 4> result;
+ utils::Vector<const constant::Constant*, 4> result;
for (size_t i = 0; i < mat_ty->rows(); ++i) {
auto r = dot(args[0], i, args[1]); // matrix row i * vector
if (!r) {
@@ -1599,13 +1641,13 @@
return CreateComposite(builder, ty, result);
}
ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* vec_ty = args[0]->Type()->As<type::Vector>();
auto* mat_ty = args[1]->Type()->As<type::Matrix>();
auto* elem_ty = vec_ty->type();
- auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) {
+ auto dot = [&](const constant::Constant* v, const constant::Constant* m, size_t col) {
ImplResult result;
switch (mat_ty->rows()) {
case 2:
@@ -1638,7 +1680,7 @@
return result;
};
- utils::Vector<const sem::Constant*, 4> result;
+ utils::Vector<const constant::Constant*, 4> result;
for (size_t i = 0; i < mat_ty->columns(); ++i) {
auto r = dot(args[0], args[1], i); // vector * matrix col i
if (!r) {
@@ -1650,7 +1692,7 @@
}
ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* mat1 = args[0];
auto* mat2 = args[1];
@@ -1658,7 +1700,8 @@
auto* mat2_ty = mat2->Type()->As<type::Matrix>();
auto* elem_ty = mat1_ty->type();
- auto dot = [&](const sem::Constant* m1, size_t row, const sem::Constant* m2, size_t col) {
+ auto dot = [&](const constant::Constant* m1, size_t row, const constant::Constant* m2,
+ size_t col) {
auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); };
auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); };
@@ -1695,9 +1738,9 @@
return result;
};
- utils::Vector<const sem::Constant*, 4> result_mat;
+ utils::Vector<const constant::Constant*, 4> result_mat;
for (size_t c = 0; c < mat2_ty->columns(); ++c) {
- utils::Vector<const sem::Constant*, 4> col_vec;
+ utils::Vector<const constant::Constant*, 4> col_vec;
for (size_t r = 0; r < mat1_ty->rows(); ++r) {
auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c
if (!v) {
@@ -1714,9 +1757,9 @@
}
ConstEval::Result ConstEval::OpDivide(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
return Dispatch_fia_fiu32_f16(DivFunc(source, c0->Type()), c0, c1);
};
@@ -1724,9 +1767,9 @@
}
ConstEval::Result ConstEval::OpModulo(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
return Dispatch_fia_fiu32_f16(ModFunc(source, c0->Type()), c0, c1);
};
@@ -1734,9 +1777,9 @@
}
ConstEval::Result ConstEval::OpEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i == j);
};
@@ -1747,9 +1790,9 @@
}
ConstEval::Result ConstEval::OpNotEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i != j);
};
@@ -1760,9 +1803,9 @@
}
ConstEval::Result ConstEval::OpLessThan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i < j);
};
@@ -1773,9 +1816,9 @@
}
ConstEval::Result ConstEval::OpGreaterThan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i > j);
};
@@ -1786,9 +1829,9 @@
}
ConstEval::Result ConstEval::OpLessThanEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i <= j);
};
@@ -1799,9 +1842,9 @@
}
ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i >= j);
};
@@ -1812,7 +1855,7 @@
}
ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
// Note: Due to short-circuiting, this function is only called if lhs is true, so we could
// technically only return the value of the rhs.
@@ -1820,7 +1863,7 @@
}
ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
// Note: Due to short-circuiting, this function is only called if lhs is false, so we could
// technically only return the value of the rhs.
@@ -1828,9 +1871,9 @@
}
ConstEval::Result ConstEval::OpAnd(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
using T = decltype(i);
T result;
@@ -1848,9 +1891,9 @@
}
ConstEval::Result ConstEval::OpOr(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
using T = decltype(i);
T result;
@@ -1868,9 +1911,9 @@
}
ConstEval::Result ConstEval::OpXor(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty),
decltype(i){i ^ j});
@@ -1882,9 +1925,9 @@
}
ConstEval::Result ConstEval::OpShiftLeft(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e1, auto e2) -> ImplResult {
using NumberT = decltype(e1);
using T = UnwrapNumber<NumberT>;
@@ -1969,9 +2012,9 @@
}
ConstEval::Result ConstEval::OpShiftRight(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e1, auto e2) -> ImplResult {
using NumberT = decltype(e1);
using T = UnwrapNumber<NumberT>;
@@ -2034,9 +2077,9 @@
}
ConstEval::Result ConstEval::abs(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
NumberT result;
@@ -2059,9 +2102,9 @@
}
ConstEval::Result ConstEval::acos(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
if (i < NumberT(-1.0) || i > NumberT(1.0)) {
@@ -2077,9 +2120,9 @@
}
ConstEval::Result ConstEval::acosh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
if (i < NumberT(1.0)) {
@@ -2095,21 +2138,21 @@
}
ConstEval::Result ConstEval::all(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
return CreateElement(builder, source, ty, !args[0]->AnyZero());
}
ConstEval::Result ConstEval::any(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
return CreateElement(builder, source, ty, !args[0]->AllZero());
}
ConstEval::Result ConstEval::asin(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
if (i < NumberT(-1.0) || i > NumberT(1.0)) {
@@ -2125,9 +2168,9 @@
}
ConstEval::Result ConstEval::asinh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) {
return CreateElement(builder, source, c0->Type(), decltype(i)(std::asinh(i.value)));
};
@@ -2138,9 +2181,9 @@
}
ConstEval::Result ConstEval::atan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) {
return CreateElement(builder, source, c0->Type(), decltype(i)(std::atan(i.value)));
};
@@ -2150,9 +2193,9 @@
}
ConstEval::Result ConstEval::atanh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
if (i <= NumberT(-1.0) || i >= NumberT(1.0)) {
@@ -2169,9 +2212,9 @@
}
ConstEval::Result ConstEval::atan2(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) {
return CreateElement(builder, source, c0->Type(),
decltype(i)(std::atan2(i.value, j.value)));
@@ -2182,9 +2225,9 @@
}
ConstEval::Result ConstEval::ceil(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
return CreateElement(builder, source, c0->Type(), decltype(e)(std::ceil(e)));
};
@@ -2194,19 +2237,19 @@
}
ConstEval::Result ConstEval::clamp(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
- const sem::Constant* c2) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1,
+ const constant::Constant* c2) {
return Dispatch_fia_fiu32_f16(ClampFunc(source, c0->Type()), c0, c1, c2);
};
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
}
ConstEval::Result ConstEval::cos(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::cos(i.value)));
@@ -2217,9 +2260,9 @@
}
ConstEval::Result ConstEval::cosh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::cosh(i.value)));
@@ -2230,9 +2273,9 @@
}
ConstEval::Result ConstEval::countLeadingZeros(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -2245,9 +2288,9 @@
}
ConstEval::Result ConstEval::countOneBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -2269,9 +2312,9 @@
}
ConstEval::Result ConstEval::countTrailingZeros(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -2284,7 +2327,7 @@
}
ConstEval::Result ConstEval::cross(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* u = args[0];
auto* v = args[1];
@@ -2323,13 +2366,13 @@
}
return CreateComposite(builder, ty,
- utils::Vector<const sem::Constant*, 3>{x.Get(), y.Get(), z.Get()});
+ utils::Vector<const constant::Constant*, 3>{x.Get(), y.Get(), z.Get()});
}
ConstEval::Result ConstEval::degrees(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -2353,7 +2396,7 @@
}
ConstEval::Result ConstEval::determinant(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto calculate = [&]() -> ConstEval::Result {
auto* m = args[0];
@@ -2389,7 +2432,7 @@
}
ConstEval::Result ConstEval::distance(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto err = [&]() -> ImplResult {
AddNote("when calculating distance", source);
@@ -2409,7 +2452,7 @@
}
ConstEval::Result ConstEval::dot(const type::Type*,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto r = Dot(source, args[0], args[1]);
if (!r) {
@@ -2419,9 +2462,9 @@
}
ConstEval::Result ConstEval::exp(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e0) -> ImplResult {
using NumberT = decltype(e0);
auto val = NumberT(std::exp(e0));
@@ -2437,9 +2480,9 @@
}
ConstEval::Result ConstEval::exp2(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e0) -> ImplResult {
using NumberT = decltype(e0);
auto val = NumberT(std::exp2(e0));
@@ -2455,9 +2498,9 @@
}
ConstEval::Result ConstEval::extractBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto in_e) -> ImplResult {
using NumberT = decltype(in_e);
using T = UnwrapNumber<NumberT>;
@@ -2510,7 +2553,7 @@
}
ConstEval::Result ConstEval::faceForward(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
// Returns e1 if dot(e2, e3) is negative, and -e1 otherwise.
auto* e1 = args[0];
@@ -2529,9 +2572,9 @@
}
ConstEval::Result ConstEval::firstLeadingBit(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -2573,9 +2616,9 @@
}
ConstEval::Result ConstEval::firstTrailingBit(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -2599,9 +2642,9 @@
}
ConstEval::Result ConstEval::floor(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
return CreateElement(builder, source, c0->Type(), decltype(e)(std::floor(e)));
};
@@ -2611,10 +2654,10 @@
}
ConstEval::Result ConstEval::fma(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c1, const sem::Constant* c2,
- const sem::Constant* c3) {
+ auto transform = [&](const constant::Constant* c1, const constant::Constant* c2,
+ const constant::Constant* c3) {
auto create = [&](auto e1, auto e2, auto e3) -> ImplResult {
auto err_msg = [&] {
AddNote("when calculating fma", source);
@@ -2637,8 +2680,22 @@
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
}
+ConstEval::Result ConstEval::fract(const type::Type* ty,
+ utils::VectorRef<const constant::Constant*> args,
+ const Source& source) {
+ auto transform = [&](const constant::Constant* c1) {
+ auto create = [&](auto e) -> ImplResult {
+ using NumberT = decltype(e);
+ auto r = e - std::floor(e);
+ return CreateElement(builder, source, c1->Type(), NumberT{r});
+ };
+ return Dispatch_fa_f32_f16(create, c1);
+ };
+ return TransformElements(builder, ty, transform, args[0]);
+}
+
ConstEval::Result ConstEval::frexp(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* arg = args[0];
@@ -2647,7 +2704,7 @@
ImplResult exp;
};
- auto scalar = [&](const sem::Constant* s) {
+ auto scalar = [&](const constant::Constant* s) {
int exp = 0;
double fract = std::frexp(s->As<AFloat>(), &exp);
return Switch(
@@ -2680,8 +2737,8 @@
};
if (auto* vec = arg->Type()->As<type::Vector>()) {
- utils::Vector<const sem::Constant*, 4> fract_els;
- utils::Vector<const sem::Constant*, 4> exp_els;
+ utils::Vector<const constant::Constant*, 4> fract_els;
+ utils::Vector<const constant::Constant*, 4> exp_els;
for (uint32_t i = 0; i < vec->Width(); i++) {
auto fe = scalar(arg->Index(i));
if (!fe.fract || !fe.exp) {
@@ -2693,7 +2750,7 @@
auto fract_ty = builder.create<type::Vector>(fract_els[0]->Type(), vec->Width());
auto exp_ty = builder.create<type::Vector>(exp_els[0]->Type(), vec->Width());
return CreateComposite(builder, ty,
- utils::Vector<const sem::Constant*, 2>{
+ utils::Vector<const constant::Constant*, 2>{
CreateComposite(builder, fract_ty, std::move(fract_els)),
CreateComposite(builder, exp_ty, std::move(exp_els)),
});
@@ -2703,7 +2760,7 @@
return utils::Failure;
}
return CreateComposite(builder, ty,
- utils::Vector<const sem::Constant*, 2>{
+ utils::Vector<const constant::Constant*, 2>{
fe.fract.Get(),
fe.exp.Get(),
});
@@ -2711,9 +2768,9 @@
}
ConstEval::Result ConstEval::insertBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto in_e, auto in_newbits) -> ImplResult {
using NumberT = decltype(in_e);
using T = UnwrapNumber<NumberT>;
@@ -2763,9 +2820,9 @@
}
ConstEval::Result ConstEval::inverseSqrt(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult {
using NumberT = decltype(e);
@@ -2797,7 +2854,7 @@
}
ConstEval::Result ConstEval::length(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto r = Length(source, ty, args[0]);
if (!r) {
@@ -2807,9 +2864,9 @@
}
ConstEval::Result ConstEval::log(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto v) -> ImplResult {
using NumberT = decltype(v);
if (v <= NumberT(0)) {
@@ -2824,9 +2881,9 @@
}
ConstEval::Result ConstEval::log2(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto v) -> ImplResult {
using NumberT = decltype(v);
if (v <= NumberT(0)) {
@@ -2841,9 +2898,9 @@
}
ConstEval::Result ConstEval::max(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e0, auto e1) {
return CreateElement(builder, source, c0->Type(), decltype(e0)(std::max(e0, e1)));
};
@@ -2853,9 +2910,9 @@
}
ConstEval::Result ConstEval::min(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e0, auto e1) {
return CreateElement(builder, source, c0->Type(), decltype(e0)(std::min(e0, e1)));
};
@@ -2865,9 +2922,9 @@
}
ConstEval::Result ConstEval::mix(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, size_t index) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) {
auto create = [&](auto e1, auto e2) -> ImplResult {
using NumberT = decltype(e1);
// e3 is either a vector or a scalar
@@ -2908,23 +2965,23 @@
}
ConstEval::Result ConstEval::modf(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform_fract = [&](const sem::Constant* c) {
+ auto transform_fract = [&](const constant::Constant* c) {
auto create = [&](auto e) {
return CreateElement(builder, source, c->Type(),
decltype(e)(e.value - std::trunc(e.value)));
};
return Dispatch_fa_f32_f16(create, c);
};
- auto transform_whole = [&](const sem::Constant* c) {
+ auto transform_whole = [&](const constant::Constant* c) {
auto create = [&](auto e) {
return CreateElement(builder, source, c->Type(), decltype(e)(std::trunc(e.value)));
};
return Dispatch_fa_f32_f16(create, c);
};
- utils::Vector<const sem::Constant*, 2> fields;
+ utils::Vector<const constant::Constant*, 2> fields;
if (auto fract = TransformElements(builder, args[0]->Type(), transform_fract, args[0])) {
fields.Push(fract.Get());
@@ -2942,7 +2999,7 @@
}
ConstEval::Result ConstEval::normalize(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* len_ty = type::Type::DeepestElementOf(ty);
auto len = Length(source, len_ty, args[0]);
@@ -2959,7 +3016,7 @@
}
ConstEval::Result ConstEval::pack2x16float(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto convert = [&](f32 val) -> utils::Result<uint32_t> {
auto conv = CheckedConvert<f16>(val);
@@ -2987,7 +3044,7 @@
}
ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto calc = [&](f32 val) -> u32 {
auto clamped = Clamp(source, val, f32(-1.0f), f32(1.0f)).Get();
@@ -3004,7 +3061,7 @@
}
ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto calc = [&](f32 val) -> u32 {
auto clamped = Clamp(source, val, f32(0.0f), f32(1.0f)).Get();
@@ -3020,7 +3077,7 @@
}
ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto calc = [&](f32 val) -> u32 {
auto clamped = Clamp(source, val, f32(-1.0f), f32(1.0f)).Get();
@@ -3040,7 +3097,7 @@
}
ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto calc = [&](f32 val) -> u32 {
auto clamped = Clamp(source, val, f32(0.0f), f32(1.0f)).Get();
@@ -3058,10 +3115,27 @@
return CreateElement(builder, source, ty, ret);
}
+ConstEval::Result ConstEval::pow(const type::Type* ty,
+ utils::VectorRef<const constant::Constant*> args,
+ const Source& source) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
+ auto create = [&](auto e1, auto e2) -> ImplResult {
+ auto r = CheckedPow(e1, e2);
+ if (!r) {
+ AddError(OverflowErrorMessage(e1, "^", e2), source);
+ return utils::Failure;
+ }
+ return CreateElement(builder, source, c0->Type(), *r);
+ };
+ return Dispatch_fa_f32_f16(create, c0, c1);
+ };
+ return TransformElements(builder, ty, transform, args[0], args[1]);
+}
+
ConstEval::Result ConstEval::radians(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -3085,7 +3159,7 @@
}
ConstEval::Result ConstEval::reflect(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto calculate = [&]() -> ConstEval::Result {
// For the incident vector e1 and surface orientation e2, returns the reflection direction
@@ -3128,7 +3202,7 @@
}
ConstEval::Result ConstEval::refract(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* vec_ty = ty->As<type::Vector>();
auto* el_ty = vec_ty->type();
@@ -3226,9 +3300,9 @@
}
ConstEval::Result ConstEval::reverseBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto in_e) -> ImplResult {
using NumberT = decltype(in_e);
using T = UnwrapNumber<NumberT>;
@@ -3253,9 +3327,9 @@
}
ConstEval::Result ConstEval::round(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
@@ -3289,9 +3363,9 @@
}
ConstEval::Result ConstEval::saturate(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
return CreateElement(builder, source, c0->Type(),
@@ -3303,10 +3377,10 @@
}
ConstEval::Result ConstEval::select_bool(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto cond = args[2]->As<bool>();
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto f, auto t) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f);
};
@@ -3317,9 +3391,9 @@
}
ConstEval::Result ConstEval::select_boolvec(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, size_t index) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) {
auto create = [&](auto f, auto t) -> ImplResult {
// Get corresponding bool value at the current vector value index
auto cond = args[2]->Index(index)->As<bool>();
@@ -3332,9 +3406,9 @@
}
ConstEval::Result ConstEval::sign(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult {
using NumberT = decltype(e);
NumberT result;
@@ -3354,9 +3428,9 @@
}
ConstEval::Result ConstEval::sin(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::sin(i.value)));
@@ -3367,9 +3441,9 @@
}
ConstEval::Result ConstEval::sinh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::sinh(i.value)));
@@ -3380,10 +3454,10 @@
}
ConstEval::Result ConstEval::smoothstep(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
- const sem::Constant* c2) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1,
+ const constant::Constant* c2) {
auto create = [&](auto low, auto high, auto x) -> ImplResult {
using NumberT = decltype(low);
@@ -3431,9 +3505,9 @@
}
ConstEval::Result ConstEval::step(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto edge, auto x) -> ImplResult {
using NumberT = decltype(edge);
NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0);
@@ -3445,9 +3519,9 @@
}
ConstEval::Result ConstEval::sqrt(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
return Dispatch_fa_f32_f16(SqrtFunc(source, c0->Type()), c0);
};
@@ -3455,9 +3529,9 @@
}
ConstEval::Result ConstEval::tan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::tan(i.value)));
@@ -3468,9 +3542,9 @@
}
ConstEval::Result ConstEval::tanh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::tanh(i.value)));
@@ -3481,7 +3555,7 @@
}
ConstEval::Result ConstEval::transpose(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source&) {
auto* m = args[0];
auto* mat_ty = m->Type()->As<type::Matrix>();
@@ -3489,9 +3563,9 @@
auto* result_mat_ty = ty->As<type::Matrix>();
// Produce column vectors from each row
- utils::Vector<const sem::Constant*, 4> result_mat;
+ utils::Vector<const constant::Constant*, 4> result_mat;
for (size_t r = 0; r < mat_ty->rows(); ++r) {
- utils::Vector<const sem::Constant*, 4> new_col_vec;
+ utils::Vector<const constant::Constant*, 4> new_col_vec;
for (size_t c = 0; c < mat_ty->columns(); ++c) {
new_col_vec.Push(me(r, c));
}
@@ -3501,9 +3575,9 @@
}
ConstEval::Result ConstEval::trunc(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0) {
+ auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) {
return CreateElement(builder, source, c0->Type(), decltype(i)(std::trunc(i.value)));
};
@@ -3513,12 +3587,12 @@
}
ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* inner_ty = type::Type::DeepestElementOf(ty);
auto e = args[0]->As<u32>().value;
- utils::Vector<const sem::Constant*, 2> els;
+ utils::Vector<const constant::Constant*, 2> els;
els.Reserve(2);
for (size_t i = 0; i < 2; ++i) {
auto in = f16::FromBits(uint16_t((e >> (16 * i)) & 0x0000'ffff));
@@ -3537,12 +3611,12 @@
}
ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* inner_ty = type::Type::DeepestElementOf(ty);
auto e = args[0]->As<u32>().value;
- utils::Vector<const sem::Constant*, 2> els;
+ utils::Vector<const constant::Constant*, 2> els;
els.Reserve(2);
for (size_t i = 0; i < 2; ++i) {
auto val = f32(
@@ -3557,12 +3631,12 @@
}
ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* inner_ty = type::Type::DeepestElementOf(ty);
auto e = args[0]->As<u32>().value;
- utils::Vector<const sem::Constant*, 2> els;
+ utils::Vector<const constant::Constant*, 2> els;
els.Reserve(2);
for (size_t i = 0; i < 2; ++i) {
auto val = f32(static_cast<float>(uint16_t((e >> (16 * i)) & 0x0000'ffff)) / 65535.f);
@@ -3576,12 +3650,12 @@
}
ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* inner_ty = type::Type::DeepestElementOf(ty);
auto e = args[0]->As<u32>().value;
- utils::Vector<const sem::Constant*, 4> els;
+ utils::Vector<const constant::Constant*, 4> els;
els.Reserve(4);
for (size_t i = 0; i < 4; ++i) {
auto val =
@@ -3596,12 +3670,12 @@
}
ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
auto* inner_ty = type::Type::DeepestElementOf(ty);
auto e = args[0]->As<u32>().value;
- utils::Vector<const sem::Constant*, 4> els;
+ utils::Vector<const constant::Constant*, 4> els;
els.Reserve(4);
for (size_t i = 0; i < 4; ++i) {
auto val = f32(static_cast<float>(uint8_t((e >> (8 * i)) & 0x0000'00ff)) / 255.f);
@@ -3615,9 +3689,9 @@
}
ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c) -> ImplResult {
+ auto transform = [&](const constant::Constant* c) -> ImplResult {
auto value = c->As<f32>();
auto conv = CheckedConvert<f32>(f16(value));
if (!conv) {
@@ -3630,7 +3704,7 @@
}
ConstEval::Result ConstEval::Convert(const type::Type* target_ty,
- const sem::Constant* value,
+ const constant::Constant* value,
const Source& source) {
if (value->Type() == target_ty) {
return value;
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index fa5c0ee..6caac07 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -30,8 +30,10 @@
namespace tint::ast {
class LiteralExpression;
} // namespace tint::ast
-namespace tint::sem {
+namespace tint::constant {
class Constant;
+} // namespace tint::constant
+namespace tint::sem {
class Expression;
} // namespace tint::sem
namespace tint::type {
@@ -48,18 +50,20 @@
public:
/// The result type of a method that may raise a diagnostic error and the caller should abort
/// resolving. Can be one of three distinct values:
- /// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
+ /// * A non-null constant::Constant pointer. Returned when a expression resolves to a creation
+ /// time
/// value.
- /// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time
+ /// * A null constant::Constant pointer. Returned when a expression cannot resolve to a creation
+ /// time
/// value, but is otherwise legal.
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
/// will have already reported a diagnostic error message, and the caller should abort
/// resolving.
- using Result = utils::Result<const sem::Constant*>;
+ using Result = utils::Result<const constant::Constant*>;
/// Typedef for a constant evaluation function
using Function = Result (ConstEval::*)(const type::Type* result_ty,
- utils::VectorRef<const sem::Constant*>,
+ utils::VectorRef<const constant::Constant*>,
const Source&);
/// Constructor
@@ -109,7 +113,7 @@
/// @param value the value being converted
/// @param source the source location
/// @return the converted value, or null if the value cannot be calculated
- Result Convert(const type::Type* ty, const sem::Constant* value, const Source& source);
+ Result Convert(const type::Type* ty, const constant::Constant* value, const Source& source);
////////////////////////////////////////////////////////////////////////////////////////////////
// Constant value evaluation methods, to be indirectly called via the intrinsic table
@@ -121,7 +125,7 @@
/// @param source the source location
/// @return the converted value, or null if the value cannot be calculated
Result Conv(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Zero value type initializer
@@ -130,7 +134,7 @@
/// @param source the source location
/// @return the constructed value, or null if the value cannot be calculated
Result Zero(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Identity value type initializer
@@ -139,7 +143,7 @@
/// @param source the source location
/// @return the constructed value, or null if the value cannot be calculated
Result Identity(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Vector splat initializer
@@ -148,7 +152,7 @@
/// @param source the source location
/// @return the constructed value, or null if the value cannot be calculated
Result VecSplat(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Vector initializer using scalars
@@ -157,7 +161,7 @@
/// @param source the source location
/// @return the constructed value, or null if the value cannot be calculated
Result VecInitS(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Vector initializer using a mix of scalars and smaller vectors
@@ -166,7 +170,7 @@
/// @param source the source location
/// @return the constructed value, or null if the value cannot be calculated
Result VecInitM(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Matrix initializer using scalar values
@@ -175,7 +179,7 @@
/// @param source the source location
/// @return the constructed value, or null if the value cannot be calculated
Result MatInitS(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Matrix initializer using column vectors
@@ -184,7 +188,7 @@
/// @param source the source location
/// @return the constructed value, or null if the value cannot be calculated
Result MatInitV(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
////////////////////////////////////////////////////////////////////////////
@@ -197,7 +201,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpComplement(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Unary minus operator '-'
@@ -206,7 +210,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpUnaryMinus(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Unary not operator '!'
@@ -215,7 +219,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpNot(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
////////////////////////////////////////////////////////////////////////////
@@ -228,7 +232,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpPlus(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Minus operator '-'
@@ -237,7 +241,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpMinus(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Multiply operator '*' for the same type on the LHS and RHS
@@ -246,7 +250,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpMultiply(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Multiply operator '*' for matCxR<T> * vecC<T>
@@ -255,7 +259,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpMultiplyMatVec(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Multiply operator '*' for vecR<T> * matCxR<T>
@@ -264,7 +268,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpMultiplyVecMat(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Multiply operator '*' for matKxR<T> * matCxK<T>
@@ -273,7 +277,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpMultiplyMatMat(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Divide operator '/'
@@ -282,7 +286,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpDivide(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Modulo operator '%'
@@ -291,7 +295,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpModulo(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Equality operator '=='
@@ -300,7 +304,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Inequality operator '!='
@@ -309,7 +313,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpNotEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Less than operator '<'
@@ -318,7 +322,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpLessThan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Greater than operator '>'
@@ -327,7 +331,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpGreaterThan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Less than or equal operator '<='
@@ -336,7 +340,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpLessThanEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Greater than or equal operator '>='
@@ -345,7 +349,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpGreaterThanEqual(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Logical and operator '&&'
@@ -354,7 +358,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpLogicalAnd(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Logical or operator '||'
@@ -363,7 +367,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpLogicalOr(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Bitwise and operator '&'
@@ -372,7 +376,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpAnd(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Bitwise or operator '|'
@@ -381,7 +385,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpOr(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Bitwise xor operator '^'
@@ -390,7 +394,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpXor(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Bitwise shift left operator '<<'
@@ -399,7 +403,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpShiftLeft(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// Bitwise shift right operator '<<'
@@ -408,7 +412,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpShiftRight(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
////////////////////////////////////////////////////////////////////////////
@@ -421,7 +425,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result abs(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// acos builtin
@@ -430,7 +434,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result acos(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// acosh builtin
@@ -439,7 +443,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result acosh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// all builtin
@@ -448,7 +452,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result all(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// any builtin
@@ -457,7 +461,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result any(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// asin builtin
@@ -466,7 +470,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result asin(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// asinh builtin
@@ -475,7 +479,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result asinh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// atan builtin
@@ -484,7 +488,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result atan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// atanh builtin
@@ -493,7 +497,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result atanh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// atan2 builtin
@@ -502,7 +506,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result atan2(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// ceil builtin
@@ -511,7 +515,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result ceil(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// clamp builtin
@@ -520,7 +524,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result clamp(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// cos builtin
@@ -529,7 +533,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result cos(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// cosh builtin
@@ -538,7 +542,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result cosh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// countLeadingZeros builtin
@@ -547,7 +551,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result countLeadingZeros(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// countOneBits builtin
@@ -556,7 +560,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result countOneBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// countTrailingZeros builtin
@@ -565,7 +569,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result countTrailingZeros(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// cross builtin
@@ -574,7 +578,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result cross(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// degrees builtin
@@ -583,7 +587,7 @@
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result degrees(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// determinant builtin
@@ -592,7 +596,7 @@
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result determinant(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// distance builtin
@@ -601,7 +605,7 @@
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result distance(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// dot builtin
@@ -610,7 +614,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result dot(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// exp builtin
@@ -619,7 +623,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result exp(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// exp2 builtin
@@ -628,7 +632,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result exp2(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// extractBits builtin
@@ -637,7 +641,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result extractBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// faceForward builtin
@@ -646,7 +650,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result faceForward(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// firstLeadingBit builtin
@@ -655,7 +659,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result firstLeadingBit(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// firstTrailingBit builtin
@@ -664,7 +668,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result firstTrailingBit(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// floor builtin
@@ -673,7 +677,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result floor(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// fma builtin
@@ -682,16 +686,25 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result fma(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
+ /// fract builtin
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location
+ /// @return the result value, or null if the value cannot be calculated
+ Result fract(const type::Type* ty,
+ utils::VectorRef<const constant::Constant*> args,
+ const Source& source);
+
/// frexp builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result frexp(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// insertBits builtin
@@ -700,7 +713,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result insertBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// inverseSqrt builtin
@@ -709,7 +722,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result inverseSqrt(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// length builtin
@@ -718,7 +731,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result length(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// log builtin
@@ -727,7 +740,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result log(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// log2 builtin
@@ -736,7 +749,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result log2(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// max builtin
@@ -745,7 +758,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result max(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// min builtin
@@ -754,7 +767,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result min(const type::Type* ty, // NOLINT(build/include_what_you_use) -- confused by min
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// mix builtin
@@ -763,7 +776,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result mix(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// modf builtin
@@ -772,7 +785,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result modf(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// normalize builtin
@@ -781,7 +794,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result normalize(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// pack2x16float builtin
@@ -790,7 +803,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result pack2x16float(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// pack2x16snorm builtin
@@ -799,7 +812,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result pack2x16snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// pack2x16unorm builtin
@@ -808,7 +821,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result pack2x16unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// pack4x8snorm builtin
@@ -817,7 +830,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result pack4x8snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// pack4x8unorm builtin
@@ -826,16 +839,25 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result pack4x8unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
+ /// pow builtin
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location
+ /// @return the result value, or null if the value cannot be calculated
+ Result pow(const type::Type* ty,
+ utils::VectorRef<const constant::Constant*> args,
+ const Source& source);
+
/// radians builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result radians(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// reflect builtin
@@ -844,7 +866,7 @@
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result reflect(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// refract builtin
@@ -853,7 +875,7 @@
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result refract(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// reverseBits builtin
@@ -862,7 +884,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result reverseBits(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// round builtin
@@ -871,7 +893,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result round(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// saturate builtin
@@ -880,7 +902,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result saturate(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// select builtin with single bool third arg
@@ -889,7 +911,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result select_bool(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// select builtin with vector of bool third arg
@@ -898,7 +920,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result select_boolvec(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// sign builtin
@@ -907,7 +929,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result sign(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// sin builtin
@@ -916,7 +938,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result sin(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// sinh builtin
@@ -925,7 +947,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result sinh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// smoothstep builtin
@@ -934,7 +956,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result smoothstep(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// step builtin
@@ -943,7 +965,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result step(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// sqrt builtin
@@ -952,7 +974,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result sqrt(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// tan builtin
@@ -961,7 +983,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result tan(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// tanh builtin
@@ -970,7 +992,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result tanh(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// transpose builtin
@@ -979,7 +1001,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result transpose(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// trunc builtin
@@ -988,7 +1010,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result trunc(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// unpack2x16float builtin
@@ -997,7 +1019,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result unpack2x16float(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// unpack2x16snorm builtin
@@ -1006,7 +1028,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result unpack2x16snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// unpack2x16unorm builtin
@@ -1015,7 +1037,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result unpack2x16unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// unpack4x8snorm builtin
@@ -1024,7 +1046,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result unpack4x8snorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// unpack4x8unorm builtin
@@ -1033,7 +1055,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result unpack4x8unorm(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// quantizeToF16 builtin
@@ -1042,7 +1064,7 @@
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result quantizeToF16(const type::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
+ utils::VectorRef<const constant::Constant*> args,
const Source& source);
private:
@@ -1339,14 +1361,14 @@
/// @param v1 the first vector
/// @param v2 the second vector
/// @returns the dot product
- Result Dot(const Source& source, const sem::Constant* v1, const sem::Constant* v2);
+ Result Dot(const Source& source, const constant::Constant* v1, const constant::Constant* v2);
/// Returns the length of c0
/// @param source the source location
/// @param ty the return type
/// @param c0 the constant to calculate the length of
/// @returns the length of c0
- Result Length(const Source& source, const type::Type* ty, const sem::Constant* c0);
+ Result Length(const Source& source, const type::Type* ty, const constant::Constant* c0);
/// Returns the product of v1 and v2
/// @param source the source location
@@ -1356,8 +1378,8 @@
/// @returns the product of v1 and v2
Result Mul(const Source& source,
const type::Type* ty,
- const sem::Constant* v1,
- const sem::Constant* v2);
+ const constant::Constant* v1,
+ const constant::Constant* v2);
/// Returns the difference between v2 and v1
/// @param source the source location
@@ -1367,8 +1389,8 @@
/// @returns the difference between v2 and v1
Result Sub(const Source& source,
const type::Type* ty,
- const sem::Constant* v1,
- const sem::Constant* v2);
+ const constant::Constant* v1,
+ const constant::Constant* v2);
ProgramBuilder& builder;
};
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index b3fc800..d626afc 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -99,7 +99,7 @@
auto& expected = expected_case.value;
auto* sem = Sem().Get(expr);
- const sem::Constant* value = sem->ConstantValue();
+ const constant::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
@@ -892,19 +892,20 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
- const sem::Constant* value = sem->ConstantValue();
+ const constant::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
auto* expected_sem = Sem().Get(expected_expr);
- const sem::Constant* expected_value = expected_sem->ConstantValue();
+ const constant::Constant* expected_value = expected_sem->ConstantValue();
ASSERT_NE(expected_value, nullptr);
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
- ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
- EXPECT_EQ(a->As<bool>(), b->As<bool>());
- return HasFailure() ? Action::kStop : Action::kContinue;
- });
+ ForEachElemPair(value, expected_value,
+ [&](const constant::Constant* a, const constant::Constant* b) {
+ EXPECT_EQ(a->As<bool>(), b->As<bool>());
+ return HasFailure() ? Action::kStop : Action::kContinue;
+ });
}
template <typename T>
diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc
index 4abf09e..25adcef 100644
--- a/src/tint/resolver/const_eval_builtin_test.cc
+++ b/src/tint/resolver/const_eval_builtin_test.cc
@@ -162,7 +162,7 @@
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
- const sem::Constant* value = sem->ConstantValue();
+ const constant::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
@@ -1173,6 +1173,38 @@
FmaCases<f16>()))));
template <typename T>
+std::vector<Case> FractCases() {
+ auto r = std::vector<Case>{
+ C({T(0)}, T(0)),
+ C({T(0.1)}, T(0.1)),
+ C({T(-0.1)}, T(0.9)),
+ C({T(0.0000001)}, T(0.0000001)),
+ C({T(-0.0000001)}, T(0.9999999)),
+ C({T(12.34567)}, T(0.34567)).FloatComp(0.002),
+ C({T(-12.34567)}, T(0.65433)).FloatComp(0.002),
+ C({T::Lowest()}, T(0)),
+ C({T::Highest()}, T(0)),
+ // Vector tests
+ C({Vec(T(0.1), T(-0.1), T(-0.0000001))}, Vec(T(0.1), T(0.9), T(0.9999999))),
+ };
+ // Note: Valid results are in the closed interval [0, 1.0]. For example, if e is a very small
+ // negative number, then fract(e) may be 1.0.
+ ConcatIntoIf<!std::is_same_v<T, f16>>( //
+ r, std::vector<Case>{
+ C({T(-0.000000000000000001)}, T(1)),
+ });
+
+ return r;
+}
+INSTANTIATE_TEST_SUITE_P( //
+ Fract,
+ ResolverConstEvalBuiltinTest,
+ testing::Combine(testing::Values(sem::BuiltinType::kFract),
+ testing::ValuesIn(Concat(FractCases<AFloat>(), //
+ FractCases<f32>(),
+ FractCases<f16>()))));
+
+template <typename T>
std::vector<Case> FrexpCases() {
using F = T; // fract type
using E = std::conditional_t<std::is_same_v<T, AFloat>, AInt, i32>; // exp type
@@ -1972,6 +2004,54 @@
testing::ValuesIn(Pack2x16unormCases())));
template <typename T>
+std::vector<Case> PowCases() {
+ auto error_msg = [](auto base, auto exp) {
+ return "12:34 error: " + OverflowErrorMessage(base, "^", exp);
+ };
+ return {
+ C({T(0), T(1)}, T(0)), //
+ C({T(0), T::Highest()}, T(0)), //
+ C({T(1), T(1)}, T(1)), //
+ C({T(1), T::Lowest()}, T(1)), //
+ C({T(2), T(2)}, T(4)), //
+ C({T(2), T(3)}, T(8)), //
+ // Positive base, negative exponent
+ C({T(1), T::Highest()}, T(1)), //
+ C({T(1), -T(1)}, T(1)), //
+ C({T(2), -T(2)}, T(0.25)), //
+ C({T(2), -T(3)}, T(0.125)), //
+ // Decimal values
+ C({T(2.5), T(3)}, T(15.625)), //
+ C({T(2), T(3.5)}, T(11.313708498)).FloatComp(), //
+ C({T(2.5), T(3.5)}, T(24.705294220)).FloatComp(), //
+ C({T(2), -T(3.5)}, T(0.0883883476)).FloatComp(), //
+
+ // Vector tests
+ C({Vec(T(0), T(1), T(2)), Vec(T(2), T(2), T(2))}, Vec(T(0), T(1), T(4))),
+ C({Vec(T(2), T(2), T(2)), Vec(T(2), T(3), T(4))}, Vec(T(4), T(8), T(16))),
+
+ // Error if base < 0
+ E({-T(1), T(1)}, error_msg(-T(1), T(1))),
+ E({-T(1), T::Highest()}, error_msg(-T(1), T::Highest())),
+ E({T::Lowest(), T(1)}, error_msg(T::Lowest(), T(1))),
+ E({T::Lowest(), T::Highest()}, error_msg(T::Lowest(), T::Highest())),
+ E({T::Lowest(), T::Lowest()}, error_msg(T::Lowest(), T::Lowest())),
+
+ // Error if base == 0 and exp <= 0
+ E({T(0), T(0)}, error_msg(T(0), T(0))),
+ E({T(0), -T(1)}, error_msg(T(0), -T(1))),
+ E({T(0), T::Lowest()}, error_msg(T(0), T::Lowest())),
+ };
+}
+INSTANTIATE_TEST_SUITE_P( //
+ Pow,
+ ResolverConstEvalBuiltinTest,
+ testing::Combine(testing::Values(sem::BuiltinType::kPow),
+ testing::ValuesIn(Concat(PowCases<AFloat>(), //
+ PowCases<f32>(), //
+ PowCases<f16>()))));
+
+template <typename T>
std::vector<Case> ReverseBitsCases() {
using B = BitValues<T>;
return {
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index a8858f4..0429e8f 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -36,9 +36,10 @@
template <typename T>
inline const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846));
-/// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args
+/// Walks the constant::Constant @p c, accumulating all the inner-most scalar values into @p args
template <size_t N>
-inline void CollectScalars(const sem::Constant* c, utils::Vector<builder::Scalar, N>& scalars) {
+inline void CollectScalars(const constant::Constant* c,
+ utils::Vector<builder::Scalar, N>& scalars) {
Switch(
c->Type(), //
[&](const type::AbstractInt*) { scalars.Push(c->As<AInt>()); },
@@ -56,8 +57,8 @@
});
}
-/// Walks the sem::Constant @p c, returning all the inner-most scalar values.
-inline utils::Vector<builder::Scalar, 16> ScalarsFrom(const sem::Constant* c) {
+/// Walks the constant::Constant @p c, returning all the inner-most scalar values.
+inline utils::Vector<builder::Scalar, 16> ScalarsFrom(const constant::Constant* c) {
utils::Vector<builder::Scalar, 16> out;
CollectScalars(c, out);
return out;
@@ -88,7 +89,7 @@
/// @param got_constant the constant value evaluated by the resolver
/// @param expected_value the expected value for the test
/// @param flags optional flags for controlling the comparisons
-inline void CheckConstant(const sem::Constant* got_constant,
+inline void CheckConstant(const constant::Constant* got_constant,
const builder::Value& expected_value,
CheckConstantFlags flags = {}) {
auto values_flat = ScalarsFrom(got_constant);
@@ -237,10 +238,10 @@
/// Returns the overflow error message for exponentiation
template <typename NumberT>
-std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
+std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) {
std::stringstream ss;
ss << std::setprecision(20);
- ss << base << "^" << value << " cannot be represented as "
+ ss << base << "^" << exp << " cannot be represented as "
<< "'" << FriendlyName<NumberT>() << "'";
return ss.str();
}
@@ -257,7 +258,7 @@
// TODO(amaiorano): Move to Constant.h?
enum class Action { kStop, kContinue };
template <typename Func>
-inline Action ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
+inline Action ForEachElemPair(const constant::Constant* a, const constant::Constant* b, Func&& f) {
EXPECT_EQ(a->Type(), b->Type());
size_t i = 0;
while (true) {
diff --git a/src/tint/resolver/const_eval_unary_op_test.cc b/src/tint/resolver/const_eval_unary_op_test.cc
index f7f5928..cb342d1 100644
--- a/src/tint/resolver/const_eval_unary_op_test.cc
+++ b/src/tint/resolver/const_eval_unary_op_test.cc
@@ -57,7 +57,7 @@
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
- const sem::Constant* value = sem->ConstantValue();
+ const constant::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl
index bda6238..926f6b5 100644
--- a/src/tint/resolver/intrinsic_table.inl
+++ b/src/tint/resolver/intrinsic_table.inl
@@ -12461,24 +12461,24 @@
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[26],
+ /* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[856],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::fract,
},
{
/* [348] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[26],
+ /* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[857],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::fract,
},
{
/* [349] */
@@ -12797,24 +12797,24 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[26],
+ /* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[622],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::pow,
},
{
/* [376] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[26],
+ /* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[624],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::pow,
},
{
/* [377] */
@@ -14203,8 +14203,8 @@
},
{
/* [39] */
- /* fn fract<T : f32_f16>(T) -> T */
- /* fn fract<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
+ /* fn fract<T : fa_f32_f16>(@test_value(1.25) T) -> T */
+ /* fn fract<N : num, T : fa_f32_f16>(@test_value(1.25) vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[347],
},
@@ -14345,8 +14345,8 @@
},
{
/* [60] */
- /* fn pow<T : f32_f16>(T, T) -> T */
- /* fn pow<N : num, T : f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
+ /* fn pow<T : fa_f32_f16>(T, T) -> T */
+ /* fn pow<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[375],
},
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index ce18dc4..281fa94 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1285,7 +1285,7 @@
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- const sem::Constant* const_value = nullptr;
+ const constant::Constant* const_value = nullptr;
if (!sel->IsDefault()) {
// The sem statement was created in the switch when attempting to determine the
// common type.
@@ -1804,7 +1804,7 @@
return nullptr;
}
- const sem::Constant* materialized_val = nullptr;
+ const constant::Constant* materialized_val = nullptr;
if (!skip_const_eval_.Contains(decl)) {
auto expr_val = expr->ConstantValue();
if (!expr_val) {
@@ -1856,7 +1856,9 @@
return param_el_ty && !param_el_ty->Is<type::AbstractNumeric>();
}
-bool Resolver::Convert(const sem::Constant*& c, const type::Type* target_ty, const Source& source) {
+bool Resolver::Convert(const constant::Constant*& c,
+ const type::Type* target_ty,
+ const Source& source) {
auto r = const_eval_.Convert(target_ty, c, source);
if (!r) {
return false;
@@ -1866,7 +1868,7 @@
}
template <size_t N>
-utils::Result<utils::Vector<const sem::Constant*, N>> Resolver::ConvertArguments(
+utils::Result<utils::Vector<const constant::Constant*, N>> Resolver::ConvertArguments(
const utils::Vector<const sem::Expression*, N>& args,
const sem::CallTarget* target) {
auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); });
@@ -1924,7 +1926,7 @@
ty = builder_->create<type::Reference>(ty, ref->AddressSpace(), ref->Access());
}
- const sem::Constant* val = nullptr;
+ const constant::Constant* val = nullptr;
auto stage = sem::EarliestStage(obj->Stage(), idx->Stage());
if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
@@ -1955,7 +1957,7 @@
RegisterLoadIfNeeded(inner);
- const sem::Constant* val = nullptr;
+ const constant::Constant* val = nullptr;
// TODO(crbug.com/tint/1582): short circuit 'expr' once const eval of Bitcast is implemented.
if (auto r = const_eval_.Bitcast(ty, inner)) {
val = r.Get();
@@ -2017,7 +2019,7 @@
return nullptr;
}
- const sem::Constant* value = nullptr;
+ const constant::Constant* value = nullptr;
auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage);
if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
@@ -2047,7 +2049,7 @@
}
auto stage = args_stage; // The evaluation stage of the call
- const sem::Constant* value = nullptr; // The constant value for the call
+ const constant::Constant* value = nullptr; // The constant value for the call
if (stage == sem::EvaluationStage::kConstant) {
if (auto r = const_eval_.ArrayOrStructInit(ty, args)) {
value = r.Get();
@@ -2341,7 +2343,7 @@
// If the builtin is @const, and all arguments have constant values, evaluate the builtin
// now.
- const sem::Constant* value = nullptr;
+ const constant::Constant* value = nullptr;
auto stage = sem::EarliestStage(arg_stage, builtin.sem->Stage());
if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
@@ -2559,7 +2561,7 @@
return nullptr;
}
- const sem::Constant* val = nullptr;
+ const constant::Constant* val = nullptr;
if (auto r = const_eval_.Literal(ty, literal)) {
val = r.Get();
} else {
@@ -2827,7 +2829,7 @@
RegisterLoadIfNeeded(lhs);
RegisterLoadIfNeeded(rhs);
- const sem::Constant* value = nullptr;
+ const constant::Constant* value = nullptr;
if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) {
if (skip_const_eval_.Contains(expr)) {
@@ -2872,7 +2874,7 @@
const type::Type* ty = nullptr;
const sem::Variable* root_ident = nullptr;
- const sem::Constant* value = nullptr;
+ const constant::Constant* value = nullptr;
auto stage = sem::EvaluationStage::kRuntime;
switch (unary->op) {
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 451de0b..3e160b9 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -23,6 +23,7 @@
#include <utility>
#include <vector>
+#include "src/tint/constant/constant.h"
#include "src/tint/program_builder.h"
#include "src/tint/resolver/const_eval.h"
#include "src/tint/resolver/dependency_graph.h"
@@ -32,7 +33,6 @@
#include "src/tint/scope_stack.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/sem/block_statement.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/bitset.h"
@@ -197,13 +197,13 @@
/// Converts `c` to `target_ty`
/// @returns true on success, false on failure.
- bool Convert(const sem::Constant*& c, const type::Type* target_ty, const Source& source);
+ bool Convert(const constant::Constant*& c, const type::Type* target_ty, const Source& source);
/// Transforms `args` to a vector of constants, and converts each constant to the call target's
/// parameter type.
/// @returns the vector of constants, `utils::Failure` on failure.
template <size_t N>
- utils::Result<utils::Vector<const sem::Constant*, N>> ConvertArguments(
+ utils::Result<utils::Vector<const constant::Constant*, N>> ConvertArguments(
const utils::Vector<const sem::Expression*, N>& args,
const sem::CallTarget* target);
diff --git a/src/tint/sem/call.cc b/src/tint/sem/call.cc
index e89bc5f..0ed2a4a 100644
--- a/src/tint/sem/call.cc
+++ b/src/tint/sem/call.cc
@@ -26,7 +26,7 @@
EvaluationStage stage,
utils::VectorRef<const sem::Expression*> arguments,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
bool has_side_effects)
: Base(declaration, target->ReturnType(), stage, statement, constant, has_side_effects),
target_(target),
diff --git a/src/tint/sem/call.h b/src/tint/sem/call.h
index 1213c6d..152ebb9 100644
--- a/src/tint/sem/call.h
+++ b/src/tint/sem/call.h
@@ -41,7 +41,7 @@
EvaluationStage stage,
utils::VectorRef<const sem::Expression*> arguments,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
bool has_side_effects);
/// Destructor
diff --git a/src/tint/sem/expression.cc b/src/tint/sem/expression.cc
index 4136852..7c59111 100644
--- a/src/tint/sem/expression.cc
+++ b/src/tint/sem/expression.cc
@@ -26,7 +26,7 @@
const type::Type* type,
EvaluationStage stage,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
bool has_side_effects,
const Variable* root_ident /* = nullptr */)
: declaration_(declaration),
diff --git a/src/tint/sem/expression.h b/src/tint/sem/expression.h
index 1b3e102..7127a1e 100644
--- a/src/tint/sem/expression.h
+++ b/src/tint/sem/expression.h
@@ -16,8 +16,8 @@
#define SRC_TINT_SEM_EXPRESSION_H_
#include "src/tint/ast/expression.h"
+#include "src/tint/constant/constant.h"
#include "src/tint/sem/behavior.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/evaluation_stage.h"
#include "src/tint/sem/node.h"
@@ -44,7 +44,7 @@
const type::Type* type,
EvaluationStage stage,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
bool has_side_effects,
const Variable* root_ident = nullptr);
@@ -64,7 +64,7 @@
const Statement* Stmt() const { return statement_; }
/// @return the constant value of this expression
- const Constant* ConstantValue() const { return constant_; }
+ const constant::Constant* ConstantValue() const { return constant_; }
/// Returns the variable or parameter that this expression derives from.
/// For reference and pointer expressions, this will either be the originating
@@ -95,7 +95,7 @@
const type::Type* const type_;
const EvaluationStage stage_;
const Statement* const statement_;
- const Constant* const constant_;
+ const constant::Constant* const constant_;
sem::Behaviors behaviors_{sem::Behavior::kNext};
const bool has_side_effects_;
};
diff --git a/src/tint/sem/expression_test.cc b/src/tint/sem/expression_test.cc
index aa3dc00..1913f56 100644
--- a/src/tint/sem/expression_test.cc
+++ b/src/tint/sem/expression_test.cc
@@ -23,13 +23,13 @@
namespace tint::sem {
namespace {
-class MockConstant : public sem::Constant {
+class MockConstant : public constant::Constant {
public:
explicit MockConstant(const type::Type* ty) : type(ty) {}
~MockConstant() override {}
const type::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
- const Constant* Index(size_t) const override { return {}; }
+ const constant::Constant* Index(size_t) const override { return {}; }
bool AllZero() const override { return {}; }
bool AnyZero() const override { return {}; }
bool AllEqual() const override { return {}; }
diff --git a/src/tint/sem/index_accessor_expression.cc b/src/tint/sem/index_accessor_expression.cc
index 43e6cdd..f8a3fdb 100644
--- a/src/tint/sem/index_accessor_expression.cc
+++ b/src/tint/sem/index_accessor_expression.cc
@@ -28,7 +28,7 @@
const Expression* object,
const Expression* index,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
bool has_side_effects,
const Variable* root_ident /* = nullptr */)
: Base(declaration, type, stage, statement, constant, has_side_effects, root_ident),
diff --git a/src/tint/sem/index_accessor_expression.h b/src/tint/sem/index_accessor_expression.h
index 62ecba6..0e7586d 100644
--- a/src/tint/sem/index_accessor_expression.h
+++ b/src/tint/sem/index_accessor_expression.h
@@ -45,7 +45,7 @@
const Expression* object,
const Expression* index,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
bool has_side_effects,
const Variable* root_ident = nullptr);
diff --git a/src/tint/sem/load.cc b/src/tint/sem/load.cc
new file mode 100644
index 0000000..58f7068
--- /dev/null
+++ b/src/tint/sem/load.cc
@@ -0,0 +1,37 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0(the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/sem/load.h"
+
+#include "src/tint/debug.h"
+#include "src/tint/type/reference.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::sem::Load);
+
+namespace tint::sem {
+Load::Load(const Expression* ref, const Statement* statement)
+ : Base(/* declaration */ ref->Declaration(),
+ /* type */ ref->Type()->UnwrapRef(),
+ /* stage */ EvaluationStage::kRuntime, // Loads can only be runtime
+ /* statement */ statement,
+ /* constant */ nullptr,
+ /* has_side_effects */ ref->HasSideEffects(),
+ /* root_ident */ ref->RootIdentifier()),
+ reference_(ref) {
+ TINT_ASSERT(Semantic, ref->Type()->Is<type::Reference>());
+}
+
+Load::~Load() = default;
+
+} // namespace tint::sem
diff --git a/src/tint/sem/load.h b/src/tint/sem/load.h
new file mode 100644
index 0000000..1a63266
--- /dev/null
+++ b/src/tint/sem/load.h
@@ -0,0 +1,50 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0(the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_SEM_LOAD_H_
+#define SRC_TINT_SEM_LOAD_H_
+
+#include "src/tint/sem/expression.h"
+#include "src/tint/type/reference.h"
+
+namespace tint::sem {
+
+/// Load is a semantic expression which represents the load of a reference to a non-reference value.
+/// Loads from reference types are implicit in WGSL, so the Load semantic node shares the same AST
+/// node as the inner semantic node.
+class Load final : public Castable<Load, Expression> {
+ public:
+ /// Constructor
+ /// @param reference the reference expression being loaded
+ /// @param statement the statement that owns this expression
+ Load(const Expression* reference, const Statement* statement);
+
+ /// Destructor
+ ~Load() override;
+
+ /// @return the reference being loaded
+ const Expression* Reference() const { return reference_; }
+
+ /// @returns the type of the loaded reference.
+ const type::Reference* ReferenceType() const {
+ return static_cast<const type::Reference*>(reference_->Type());
+ }
+
+ private:
+ Expression const* const reference_;
+};
+
+} // namespace tint::sem
+
+#endif // SRC_TINT_SEM_LOAD_H_
diff --git a/src/tint/sem/materialize.cc b/src/tint/sem/materialize.cc
index b682463..90c8056 100644
--- a/src/tint/sem/materialize.cc
+++ b/src/tint/sem/materialize.cc
@@ -20,7 +20,7 @@
Materialize::Materialize(const Expression* expr,
const Statement* statement,
const type::Type* type,
- const Constant* constant)
+ const constant::Constant* constant)
: Base(/* declaration */ expr->Declaration(),
/* type */ type,
/* stage */ constant ? EvaluationStage::kConstant : EvaluationStage::kNotEvaluated,
diff --git a/src/tint/sem/materialize.h b/src/tint/sem/materialize.h
index 99fee52..8a65a03 100644
--- a/src/tint/sem/materialize.h
+++ b/src/tint/sem/materialize.h
@@ -35,12 +35,12 @@
Materialize(const Expression* expr,
const Statement* statement,
const type::Type* type,
- const Constant* constant);
+ const constant::Constant* constant);
/// Destructor
~Materialize() override;
- /// @return the target of the call
+ /// @return the expression being materialized
const Expression* Expr() const { return expr_; }
private:
diff --git a/src/tint/sem/member_accessor_expression.cc b/src/tint/sem/member_accessor_expression.cc
index 9daa946..be77b3f 100644
--- a/src/tint/sem/member_accessor_expression.cc
+++ b/src/tint/sem/member_accessor_expression.cc
@@ -27,7 +27,7 @@
const type::Type* type,
EvaluationStage stage,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
const Expression* object,
bool has_side_effects,
const Variable* root_ident /* = nullptr */)
@@ -39,7 +39,7 @@
StructMemberAccess::StructMemberAccess(const ast::MemberAccessorExpression* declaration,
const type::Type* type,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
const Expression* object,
const StructMember* member,
bool has_side_effects,
@@ -59,7 +59,7 @@
Swizzle::Swizzle(const ast::MemberAccessorExpression* declaration,
const type::Type* type,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
const Expression* object,
utils::VectorRef<uint32_t> indices,
bool has_side_effects,
diff --git a/src/tint/sem/member_accessor_expression.h b/src/tint/sem/member_accessor_expression.h
index a56c61f..aaad5ca 100644
--- a/src/tint/sem/member_accessor_expression.h
+++ b/src/tint/sem/member_accessor_expression.h
@@ -52,7 +52,7 @@
const type::Type* type,
EvaluationStage stage,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
const Expression* object,
bool has_side_effects,
const Variable* root_ident = nullptr);
@@ -78,7 +78,7 @@
StructMemberAccess(const ast::MemberAccessorExpression* declaration,
const type::Type* type,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
const Expression* object,
const StructMember* member,
bool has_side_effects,
@@ -110,7 +110,7 @@
Swizzle(const ast::MemberAccessorExpression* declaration,
const type::Type* type,
const Statement* statement,
- const Constant* constant,
+ const constant::Constant* constant,
const Expression* object,
utils::VectorRef<uint32_t> indices,
bool has_side_effects,
diff --git a/src/tint/sem/struct_test.cc b/src/tint/sem/struct_test.cc
index 8d16c2d..b009d55 100644
--- a/src/tint/sem/struct_test.cc
+++ b/src/tint/sem/struct_test.cc
@@ -34,17 +34,6 @@
EXPECT_EQ(s->SizeNoPadding(), 16u);
}
-TEST_F(SemStructTest, Hash) {
- auto* a_impl = create<ast::Struct>(Sym("a"), utils::Empty, utils::Empty);
- auto* a = create<sem::Struct>(a_impl, a_impl->source, a_impl->name, utils::Empty,
- 4u /* align */, 4u /* size */, 4u /* size_no_padding */);
- auto* b_impl = create<ast::Struct>(Sym("b"), utils::Empty, utils::Empty);
- auto* b = create<sem::Struct>(b_impl, b_impl->source, b_impl->name, utils::Empty,
- 4u /* align */, 4u /* size */, 4u /* size_no_padding */);
-
- EXPECT_NE(a->Hash(), b->Hash());
-}
-
TEST_F(SemStructTest, Equals) {
auto* a_impl = create<ast::Struct>(Sym("a"), utils::Empty, utils::Empty);
auto* a = create<sem::Struct>(a_impl, a_impl->source, a_impl->name, utils::Empty,
diff --git a/src/tint/sem/switch_statement.cc b/src/tint/sem/switch_statement.cc
index 5eb2f09..4cc857d 100644
--- a/src/tint/sem/switch_statement.cc
+++ b/src/tint/sem/switch_statement.cc
@@ -49,7 +49,7 @@
return static_cast<const ast::CaseStatement*>(Base::Declaration());
}
-CaseSelector::CaseSelector(const ast::CaseSelector* decl, const Constant* val)
+CaseSelector::CaseSelector(const ast::CaseSelector* decl, const constant::Constant* val)
: Base(), decl_(decl), val_(val) {}
CaseSelector::~CaseSelector() = default;
diff --git a/src/tint/sem/switch_statement.h b/src/tint/sem/switch_statement.h
index 929f8cf..4d18210 100644
--- a/src/tint/sem/switch_statement.h
+++ b/src/tint/sem/switch_statement.h
@@ -25,10 +25,12 @@
class CaseSelector;
class SwitchStatement;
} // namespace tint::ast
+namespace tint::constant {
+class Constant;
+} // namespace tint::constant
namespace tint::sem {
class CaseStatement;
class CaseSelector;
-class Constant;
class Expression;
} // namespace tint::sem
@@ -101,7 +103,7 @@
/// Constructor
/// @param decl the selector declaration
/// @param val the case selector value, nullptr for a default selector
- explicit CaseSelector(const ast::CaseSelector* decl, const Constant* val = nullptr);
+ explicit CaseSelector(const ast::CaseSelector* decl, const constant::Constant* val = nullptr);
/// Destructor
~CaseSelector() override;
@@ -113,11 +115,11 @@
const ast::CaseSelector* Declaration() const;
/// @returns the selector constant value, or nullptr if this is the default selector
- const Constant* Value() const { return val_; }
+ const constant::Constant* Value() const { return val_; }
private:
const ast::CaseSelector* const decl_;
- const Constant* const val_;
+ const constant::Constant* const val_;
};
} // namespace tint::sem
diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc
index 0dc279e..41bff70 100644
--- a/src/tint/sem/variable.cc
+++ b/src/tint/sem/variable.cc
@@ -33,7 +33,7 @@
EvaluationStage stage,
ast::AddressSpace address_space,
ast::Access access,
- const Constant* constant_value)
+ const constant::Constant* constant_value)
: declaration_(declaration),
type_(type),
stage_(stage),
@@ -49,7 +49,7 @@
ast::AddressSpace address_space,
ast::Access access,
const sem::Statement* statement,
- const Constant* constant_value)
+ const constant::Constant* constant_value)
: Base(declaration, type, stage, address_space, access, constant_value),
statement_(statement) {}
@@ -60,7 +60,7 @@
EvaluationStage stage,
ast::AddressSpace address_space,
ast::Access access,
- const Constant* constant_value,
+ const constant::Constant* constant_value,
sem::BindingPoint binding_point,
std::optional<uint32_t> location)
: Base(declaration, type, stage, address_space, access, constant_value),
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index bd6282e..0ff1c2c 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -59,7 +59,7 @@
EvaluationStage stage,
ast::AddressSpace address_space,
ast::Access access,
- const Constant* constant_value);
+ const constant::Constant* constant_value);
/// Destructor
~Variable() override;
@@ -80,7 +80,7 @@
ast::Access Access() const { return access_; }
/// @return the constant value of this expression
- const Constant* ConstantValue() const { return constant_value_; }
+ const constant::Constant* ConstantValue() const { return constant_value_; }
/// @returns the variable initializer expression, or nullptr if the variable
/// does not have one.
@@ -102,7 +102,7 @@
const EvaluationStage stage_;
const ast::AddressSpace address_space_;
const ast::Access access_;
- const Constant* constant_value_;
+ const constant::Constant* constant_value_;
const Expression* initializer_ = nullptr;
std::vector<const VariableUser*> users_;
};
@@ -124,7 +124,7 @@
ast::AddressSpace address_space,
ast::Access access,
const sem::Statement* statement,
- const Constant* constant_value);
+ const constant::Constant* constant_value);
/// Destructor
~LocalVariable() override;
@@ -164,7 +164,7 @@
EvaluationStage stage,
ast::AddressSpace address_space,
ast::Access access,
- const Constant* constant_value,
+ const constant::Constant* constant_value,
sem::BindingPoint binding_point = {},
std::optional<uint32_t> location = std::nullopt);
diff --git a/src/tint/tint.natvis b/src/tint/tint.natvis
index 2296189..fccb23a 100644
--- a/src/tint/tint.natvis
+++ b/src/tint/tint.natvis
@@ -256,7 +256,7 @@
<DisplayString>vec{width_}<{*subtype_}></DisplayString>
</Type>
- <Type Name="tint::sem::Constant">
+ <Type Name="tint::constant::Constant">
<DisplayString>Type={*Type()} Value={Value()}</DisplayString>
</Type>
diff --git a/src/tint/type/array_test.cc b/src/tint/type/array_test.cc
index 367aa4b..cd4d5f1 100644
--- a/src/tint/type/array_test.cc
+++ b/src/tint/type/array_test.cc
@@ -74,18 +74,8 @@
TEST_F(ArrayTest, Hash) {
auto* a = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
- auto* c = create<Array>(create<U32>(), create<ConstantArrayCount>(3u), 4u, 8u, 32u, 16u);
- auto* d = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 5u, 8u, 32u, 16u);
- auto* e = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 9u, 32u, 16u);
- auto* f = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 16u);
- auto* g = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 17u);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
- EXPECT_NE(a->Hash(), e->Hash());
- EXPECT_NE(a->Hash(), f->Hash());
- EXPECT_NE(a->Hash(), g->Hash());
}
TEST_F(ArrayTest, Equals) {
diff --git a/src/tint/type/atomic_test.cc b/src/tint/type/atomic_test.cc
index b67e118..740ce79 100644
--- a/src/tint/type/atomic_test.cc
+++ b/src/tint/type/atomic_test.cc
@@ -33,9 +33,7 @@
TEST_F(AtomicTest, Hash) {
auto* a = create<Atomic>(create<I32>());
auto* b = create<Atomic>(create<I32>());
- auto* c = create<Atomic>(create<U32>());
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
}
TEST_F(AtomicTest, Equals) {
diff --git a/src/tint/type/depth_texture_test.cc b/src/tint/type/depth_texture_test.cc
index b729dc8..3b73af4 100644
--- a/src/tint/type/depth_texture_test.cc
+++ b/src/tint/type/depth_texture_test.cc
@@ -36,10 +36,8 @@
TEST_F(DepthTextureTest, Hash) {
auto* a = create<DepthTexture>(ast::TextureDimension::k2d);
auto* b = create<DepthTexture>(ast::TextureDimension::k2d);
- auto* c = create<DepthTexture>(ast::TextureDimension::k2dArray);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
}
TEST_F(DepthTextureTest, Equals) {
diff --git a/src/tint/type/matrix_test.cc b/src/tint/type/matrix_test.cc
index aff6924..7d88c6a 100644
--- a/src/tint/type/matrix_test.cc
+++ b/src/tint/type/matrix_test.cc
@@ -40,14 +40,8 @@
TEST_F(MatrixTest, Hash) {
auto* a = create<Matrix>(create<Vector>(create<I32>(), 3u), 4u);
auto* b = create<Matrix>(create<Vector>(create<I32>(), 3u), 4u);
- auto* c = create<Matrix>(create<Vector>(create<F32>(), 3u), 4u);
- auto* d = create<Matrix>(create<Vector>(create<I32>(), 2u), 4u);
- auto* e = create<Matrix>(create<Vector>(create<I32>(), 3u), 2u);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
- EXPECT_NE(a->Hash(), e->Hash());
}
TEST_F(MatrixTest, Equals) {
diff --git a/src/tint/type/multisampled_texture_test.cc b/src/tint/type/multisampled_texture_test.cc
index cf3d28a..8eac7c6 100644
--- a/src/tint/type/multisampled_texture_test.cc
+++ b/src/tint/type/multisampled_texture_test.cc
@@ -38,11 +38,7 @@
TEST_F(MultisampledTextureTest, Hash) {
auto* a = create<MultisampledTexture>(ast::TextureDimension::k2d, create<F32>());
auto* b = create<MultisampledTexture>(ast::TextureDimension::k2d, create<F32>());
- auto* c = create<MultisampledTexture>(ast::TextureDimension::k3d, create<F32>());
- auto* d = create<MultisampledTexture>(ast::TextureDimension::k2d, create<I32>());
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
}
TEST_F(MultisampledTextureTest, Equals) {
diff --git a/src/tint/type/pointer_test.cc b/src/tint/type/pointer_test.cc
index 4f8033f..ab51f7c 100644
--- a/src/tint/type/pointer_test.cc
+++ b/src/tint/type/pointer_test.cc
@@ -40,14 +40,8 @@
TEST_F(PointerTest, Hash) {
auto* a = create<Pointer>(create<I32>(), ast::AddressSpace::kStorage, ast::Access::kReadWrite);
auto* b = create<Pointer>(create<I32>(), ast::AddressSpace::kStorage, ast::Access::kReadWrite);
- auto* c = create<Pointer>(create<F32>(), ast::AddressSpace::kStorage, ast::Access::kReadWrite);
- auto* d = create<Pointer>(create<I32>(), ast::AddressSpace::kPrivate, ast::Access::kReadWrite);
- auto* e = create<Pointer>(create<I32>(), ast::AddressSpace::kStorage, ast::Access::kRead);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
- EXPECT_NE(a->Hash(), e->Hash());
}
TEST_F(PointerTest, Equals) {
diff --git a/src/tint/type/reference_test.cc b/src/tint/type/reference_test.cc
index 3d50997..094770e 100644
--- a/src/tint/type/reference_test.cc
+++ b/src/tint/type/reference_test.cc
@@ -46,16 +46,8 @@
create<Reference>(create<I32>(), ast::AddressSpace::kStorage, ast::Access::kReadWrite);
auto* b =
create<Reference>(create<I32>(), ast::AddressSpace::kStorage, ast::Access::kReadWrite);
- auto* c =
- create<Reference>(create<F32>(), ast::AddressSpace::kStorage, ast::Access::kReadWrite);
- auto* d =
- create<Reference>(create<I32>(), ast::AddressSpace::kPrivate, ast::Access::kReadWrite);
- auto* e = create<Reference>(create<I32>(), ast::AddressSpace::kStorage, ast::Access::kRead);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
- EXPECT_NE(a->Hash(), e->Hash());
}
TEST_F(ReferenceTest, Equals) {
diff --git a/src/tint/type/sampled_texture_test.cc b/src/tint/type/sampled_texture_test.cc
index ab0c74d..f9cb5be 100644
--- a/src/tint/type/sampled_texture_test.cc
+++ b/src/tint/type/sampled_texture_test.cc
@@ -41,12 +41,8 @@
TEST_F(SampledTextureTest, Hash) {
auto* a = create<SampledTexture>(ast::TextureDimension::kCube, create<F32>());
auto* b = create<SampledTexture>(ast::TextureDimension::kCube, create<F32>());
- auto* c = create<SampledTexture>(ast::TextureDimension::k2d, create<F32>());
- auto* d = create<SampledTexture>(ast::TextureDimension::kCube, create<I32>());
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
}
TEST_F(SampledTextureTest, Equals) {
diff --git a/src/tint/type/sampler_test.cc b/src/tint/type/sampler_test.cc
index 65a2370..ba30541 100644
--- a/src/tint/type/sampler_test.cc
+++ b/src/tint/type/sampler_test.cc
@@ -39,10 +39,8 @@
TEST_F(SamplerTest, Hash) {
auto* a = create<Sampler>(ast::SamplerKind::kSampler);
auto* b = create<Sampler>(ast::SamplerKind::kSampler);
- auto* c = create<Sampler>(ast::SamplerKind::kComparisonSampler);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
}
TEST_F(SamplerTest, Equals) {
diff --git a/src/tint/type/storage_texture_test.cc b/src/tint/type/storage_texture_test.cc
index b375207..0c86d89 100644
--- a/src/tint/type/storage_texture_test.cc
+++ b/src/tint/type/storage_texture_test.cc
@@ -55,17 +55,8 @@
ast::Access::kReadWrite);
auto* b = Create(ast::TextureDimension::kCube, ast::TexelFormat::kRgba32Float,
ast::Access::kReadWrite);
- auto* c =
- Create(ast::TextureDimension::k2d, ast::TexelFormat::kRgba32Float, ast::Access::kReadWrite);
- auto* d =
- Create(ast::TextureDimension::kCube, ast::TexelFormat::kR32Float, ast::Access::kReadWrite);
- auto* e =
- Create(ast::TextureDimension::kCube, ast::TexelFormat::kRgba32Float, ast::Access::kRead);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
- EXPECT_NE(a->Hash(), e->Hash());
}
TEST_F(StorageTextureTest, Equals) {
diff --git a/src/tint/type/struct_test.cc b/src/tint/type/struct_test.cc
index ae9a249..8bacfca 100644
--- a/src/tint/type/struct_test.cc
+++ b/src/tint/type/struct_test.cc
@@ -31,15 +31,6 @@
EXPECT_EQ(s->SizeNoPadding(), 16u);
}
-TEST_F(TypeStructTest, Hash) {
- auto* a = create<Struct>(Source{}, Sym("a"), utils::Empty, 4u /* align */, 4u /* size */,
- 4u /* size_no_padding */);
- auto* b = create<Struct>(Source{}, Sym("b"), utils::Empty, 4u /* align */, 4u /* size */,
- 4u /* size_no_padding */);
-
- EXPECT_NE(a->Hash(), b->Hash());
-}
-
TEST_F(TypeStructTest, Equals) {
auto* a = create<Struct>(Source{}, Sym("a"), utils::Empty, 4u /* align */, 4u /* size */,
4u /* size_no_padding */);
diff --git a/src/tint/type/vector_test.cc b/src/tint/type/vector_test.cc
index aed294a..65d7b1b 100644
--- a/src/tint/type/vector_test.cc
+++ b/src/tint/type/vector_test.cc
@@ -37,12 +37,8 @@
TEST_F(VectorTest, Hash) {
auto* a = create<Vector>(create<I32>(), 2u);
auto* b = create<Vector>(create<I32>(), 2u);
- auto* c = create<Vector>(create<F32>(), 2u);
- auto* d = create<Vector>(create<F32>(), 3u);
EXPECT_EQ(a->Hash(), b->Hash());
- EXPECT_NE(a->Hash(), c->Hash());
- EXPECT_NE(a->Hash(), d->Hash());
}
TEST_F(VectorTest, Equals) {
diff --git a/src/tint/utils/enum_set_test.cc b/src/tint/utils/enum_set_test.cc
index e469650..4cdeb63 100644
--- a/src/tint/utils/enum_set_test.cc
+++ b/src/tint/utils/enum_set_test.cc
@@ -192,7 +192,6 @@
TEST(EnumSetTest, Hash) {
auto hash = [&](EnumSet<E> s) { return std::hash<EnumSet<E>>()(s); };
EXPECT_EQ(hash(EnumSet<E>(E::A, E::B)), hash(EnumSet<E>(E::A, E::B)));
- EXPECT_NE(hash(EnumSet<E>(E::A, E::B)), hash(EnumSet<E>(E::A, E::C)));
}
TEST(EnumSetTest, Value) {
diff --git a/src/tint/utils/hash.h b/src/tint/utils/hash.h
index 89cf0f0..8c9dd25 100644
--- a/src/tint/utils/hash.h
+++ b/src/tint/utils/hash.h
@@ -23,6 +23,7 @@
#include <variant>
#include <vector>
+#include "src/tint/utils/crc32.h"
#include "src/tint/utils/vector.h"
namespace tint::utils {
@@ -37,14 +38,26 @@
template <>
struct HashCombineOffset<4> {
/// @returns the seed bias value for HashCombine()
- static constexpr inline uint32_t value() { return 0x7f4a7c16; }
+ static constexpr inline uint32_t value() {
+ constexpr uint32_t base = 0x7f4a7c16;
+#ifdef TINT_HASH_SEED
+ return base ^ static_cast<uint32_t>(TINT_HASH_SEED);
+#endif
+ return base;
+ }
};
/// Specialization of HashCombineOffset for size_t == 8.
template <>
struct HashCombineOffset<8> {
/// @returns the seed bias value for HashCombine()
- static constexpr inline uint64_t value() { return 0x9e3779b97f4a7c16; }
+ static constexpr inline uint64_t value() {
+ constexpr uint64_t base = 0x9e3779b97f4a7c16;
+#ifdef TINT_HASH_SEED
+ return base ^ static_cast<uint64_t>(TINT_HASH_SEED);
+#endif
+ return base;
+ }
};
} // namespace detail
@@ -76,6 +89,9 @@
/// @returns a hash of the pointer
size_t operator()(T* ptr) const {
auto hash = std::hash<T*>()(ptr);
+#ifdef TINT_HASH_SEED
+ hash ^= static_cast<uint32_t>(TINT_HASH_SEED);
+#endif
return hash ^ (hash >> 4);
}
};
@@ -148,7 +164,7 @@
template <typename... ARGS>
size_t HashCombine(size_t hash, const ARGS&... values) {
constexpr size_t offset = detail::HashCombineOffset<sizeof(size_t)>::value();
- ((hash ^= Hash(values) + offset + (hash << 6) + (hash >> 2)), ...);
+ ((hash ^= Hash(values) + (offset ^ (hash >> 2))), ...);
return hash;
}
diff --git a/src/tint/utils/hash_test.cc b/src/tint/utils/hash_test.cc
index 6ce6820..cdceab4 100644
--- a/src/tint/utils/hash_test.cc
+++ b/src/tint/utils/hash_test.cc
@@ -26,28 +26,19 @@
TEST(HashTests, Basic) {
EXPECT_EQ(Hash(123), Hash(123));
- EXPECT_NE(Hash(123), Hash(321));
EXPECT_EQ(Hash(123, 456), Hash(123, 456));
- EXPECT_NE(Hash(123, 456), Hash(456, 123));
- EXPECT_NE(Hash(123, 456), Hash(123));
EXPECT_EQ(Hash(123, 456, false), Hash(123, 456, false));
- EXPECT_NE(Hash(123, 456, false), Hash(123, 456));
EXPECT_EQ(Hash(std::string("hello")), Hash(std::string("hello")));
- EXPECT_NE(Hash(std::string("hello")), Hash(std::string("world")));
}
TEST(HashTests, StdVector) {
EXPECT_EQ(Hash(std::vector<int>({})), Hash(std::vector<int>({})));
EXPECT_EQ(Hash(std::vector<int>({1, 2, 3})), Hash(std::vector<int>({1, 2, 3})));
- EXPECT_NE(Hash(std::vector<int>({1, 2, 3})), Hash(std::vector<int>({1, 2, 4})));
- EXPECT_NE(Hash(std::vector<int>({1, 2, 3})), Hash(std::vector<int>({1, 2, 3, 4})));
}
TEST(HashTests, TintVector) {
EXPECT_EQ(Hash(Vector<int, 0>({})), Hash(Vector<int, 0>({})));
EXPECT_EQ(Hash(Vector<int, 0>({1, 2, 3})), Hash(Vector<int, 0>({1, 2, 3})));
- EXPECT_NE(Hash(Vector<int, 0>({1, 2, 3})), Hash(Vector<int, 0>({1, 2, 4})));
- EXPECT_NE(Hash(Vector<int, 0>({1, 2, 3})), Hash(Vector<int, 0>({1, 2, 3, 4})));
EXPECT_EQ(Hash(Vector<int, 3>({1, 2, 3})), Hash(Vector<int, 4>({1, 2, 3})));
EXPECT_EQ(Hash(Vector<int, 3>({1, 2, 3})), Hash(Vector<int, 2>({1, 2, 3})));
}
@@ -55,8 +46,6 @@
TEST(HashTests, Tuple) {
EXPECT_EQ(Hash(std::make_tuple(1)), Hash(std::make_tuple(1)));
EXPECT_EQ(Hash(std::make_tuple(1, 2, 3)), Hash(std::make_tuple(1, 2, 3)));
- EXPECT_NE(Hash(std::make_tuple(1, 2, 3)), Hash(std::make_tuple(1, 2, 4)));
- EXPECT_NE(Hash(std::make_tuple(1, 2, 3)), Hash(std::make_tuple(1, 2, 3, 4)));
}
TEST(HashTests, UnorderedKeyWrapper) {
diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc
index 5e97aef..6d993e2 100644
--- a/src/tint/utils/hashmap_test.cc
+++ b/src/tint/utils/hashmap_test.cc
@@ -388,11 +388,8 @@
Hashmap<int, std::string, 8> b;
EXPECT_EQ(Hash(a), Hash(b));
a.Add(1, "one");
- EXPECT_NE(Hash(a), Hash(b));
b.Add(2, "two");
- EXPECT_NE(Hash(a), Hash(b));
a.Add(2, "two");
- EXPECT_NE(Hash(a), Hash(b));
b.Add(1, "one");
EXPECT_EQ(Hash(a), Hash(b));
}
@@ -402,11 +399,8 @@
Hashmap<int, std::string, 4> b;
EXPECT_EQ(Hash(a), Hash(b));
a.Add(1, "one");
- EXPECT_NE(Hash(a), Hash(b));
b.Add(2, "two");
- EXPECT_NE(Hash(a), Hash(b));
a.Add(2, "two");
- EXPECT_NE(Hash(a), Hash(b));
b.Add(1, "one");
EXPECT_EQ(Hash(a), Hash(b));
}
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index d5f8b9c..6f97cc6 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -26,10 +26,10 @@
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/constant/constant.h"
#include "src/tint/debug.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/module.h"
@@ -2294,7 +2294,7 @@
return true;
}
-bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constant) {
+bool GeneratorImpl::EmitConstant(std::ostream& out, const constant::Constant* constant) {
return Switch(
constant->Type(), //
[&](const type::Bool*) {
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index a0bf3a5..2d5d015 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -363,7 +363,7 @@
/// @param out the output stream
/// @param constant the constant value to emit
/// @returns true if the constant value was successfully emitted
- bool EmitConstant(std::ostream& out, const sem::Constant* constant);
+ bool EmitConstant(std::ostream& out, const constant::Constant* constant);
/// Handles a literal
/// @param out the output stream
/// @param lit the literal to emit
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 92df2fe..d62c668 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -27,10 +27,10 @@
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/variable_decl_statement.h"
+#include "src/tint/constant/constant.h"
#include "src/tint/debug.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/module.h"
@@ -3263,7 +3263,7 @@
}
bool GeneratorImpl::EmitConstant(std::ostream& out,
- const sem::Constant* constant,
+ const constant::Constant* constant,
bool is_variable_initializer) {
return Switch(
constant->Type(), //
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index 50e8114..caf3816 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -352,7 +352,7 @@
/// initializer
/// @returns true if the constant value was successfully emitted
bool EmitConstant(std::ostream& out,
- const sem::Constant* constant,
+ const constant::Constant* constant,
bool is_variable_initializer);
/// Handles a literal
/// @param out the output stream
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 6db24ca..6b47d8c 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -31,8 +31,8 @@
#include "src/tint/ast/module.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/void.h"
+#include "src/tint/constant/constant.h"
#include "src/tint/sem/call.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/module.h"
@@ -1658,7 +1658,7 @@
});
}
-bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constant) {
+bool GeneratorImpl::EmitConstant(std::ostream& out, const constant::Constant* constant) {
return Switch(
constant->Type(), //
[&](const type::Bool*) {
diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h
index d2eea40..53fff20 100644
--- a/src/tint/writer/msl/generator_impl.h
+++ b/src/tint/writer/msl/generator_impl.h
@@ -260,7 +260,7 @@
/// @param out the output stream
/// @param constant the constant value to emit
/// @returns true if the constant value was successfully emitted
- bool EmitConstant(std::ostream& out, const sem::Constant* constant);
+ bool EmitConstant(std::ostream& out, const constant::Constant* constant);
/// Handles a literal
/// @param out the output of the expression stream
/// @param lit the literal to emit
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 14cebf1..cc04df1 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -22,9 +22,9 @@
#include "src/tint/ast/id_attribute.h"
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/ast/traverse_expressions.h"
+#include "src/tint/constant/constant.h"
#include "src/tint/sem/builtin.h"
#include "src/tint/sem/call.h"
-#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/materialize.h"
#include "src/tint/sem/member_accessor_expression.h"
@@ -1641,7 +1641,7 @@
return GenerateConstantIfNeeded(constant);
}
-uint32_t Builder::GenerateConstantIfNeeded(const sem::Constant* constant) {
+uint32_t Builder::GenerateConstantIfNeeded(const constant::Constant* constant) {
if (constant->AllZero()) {
return GenerateConstantNullIfNeeded(constant->Type());
}
diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h
index e257680..f3af99a 100644
--- a/src/tint/writer/spirv/builder.h
+++ b/src/tint/writer/spirv/builder.h
@@ -559,7 +559,7 @@
/// Generates a constant value if needed
/// @param constant the constant to generate.
/// @returns the ID on success or 0 on failure
- uint32_t GenerateConstantIfNeeded(const sem::Constant* constant);
+ uint32_t GenerateConstantIfNeeded(const constant::Constant* constant);
/// Generates a scalar constant if needed
/// @param constant the constant to generate.