Import Tint changes from Dawn
Changes:
- f8865ecde66b5bc462c86e17afd8ec7c33d0e926 [tint][wgsl] Improve 'cannot take address of' diagnostic by Ben Clayton <bclayton@google.com>
- dd852a700405828f7339a89b89ea807f96500c87 [tint] Add f16 overload of subgroupBroadcast() by James Price <jrprice@google.com>
- 0f8c1cabc4d0b1b98a2eb173c83bef00b7e002a9 Fix workgroup_storage_size computation by Gregg Tavares <gman@chromium.org>
- 01fa7c86fdcfbbf81de4bccf1144b1c6d2f710a6 Various fixes for Win32 CTS coverage by Ben Clayton <bclayton@google.com>
- a0b1a226c1e1d64e947e1626f0a32b4ddb900523 [tint][utils] Replace std::unordered_map with Hashset in ... by Ben Clayton <bclayton@google.com>
- f70336708248f8e2b77187f552c220c473e2372c [tint][core] Remove dead code by Ben Clayton <bclayton@google.com>
- ce5c2650c889910ec154401f7079473f7d90b97a [tint][utils] Simplify pointer hashing by Ben Clayton <bclayton@google.com>
- 7a27d6cfab9a161d8af3d0946b36970cf6694938 [tint][utils] Reimplement Hashmap / Hashset. by Ben Clayton <bclayton@google.com>
- 3d49551542867b3d9d3c02c9dadf61ffdcc0074d [tint][core] Fix tests by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: f8865ecde66b5bc462c86e17afd8ec7c33d0e926
Change-Id: Iae6f9d13e5b3974f1bada4a4eb1e486a3af9998d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/172820
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/cmd/fuzz/ir/fuzz.h b/src/tint/cmd/fuzz/ir/fuzz.h
index 183ebe1..533cac4 100644
--- a/src/tint/cmd/fuzz/ir/fuzz.h
+++ b/src/tint/cmd/fuzz/ir/fuzz.h
@@ -81,7 +81,7 @@
/// Registers the fuzzer function with the IR fuzzer executable.
/// @param fuzzer the fuzzer
-void Register(const IRFuzzer& fuzzer);
+void Register([[maybe_unused]] const IRFuzzer& fuzzer);
/// TINT_IR_MODULE_FUZZER registers the fuzzer function.
#define TINT_IR_MODULE_FUZZER(FUNCTION) \
diff --git a/src/tint/cmd/tint/main.cc b/src/tint/cmd/tint/main.cc
index 9797feb..22add27 100644
--- a/src/tint/cmd/tint/main.cc
+++ b/src/tint/cmd/tint/main.cc
@@ -1184,8 +1184,8 @@
std::unordered_map<tint::OverrideId, double> values;
values.reserve(options.overrides.Count());
- for (auto override : options.overrides) {
- const auto& name = override.key;
+ for (auto& override : options.overrides) {
+ const auto& name = override.key.Value();
const auto& value = override.value;
if (name.empty()) {
std::cerr << "empty override name\n";
diff --git a/src/tint/lang/core/constant/composite_test.cc b/src/tint/lang/core/constant/composite_test.cc
index 1e0e550..f03f323 100644
--- a/src/tint/lang/core/constant/composite_test.cc
+++ b/src/tint/lang/core/constant/composite_test.cc
@@ -47,8 +47,8 @@
auto* fPos1 = constants.Get(1_f);
auto* fNeg1 = constants.Get(-1_f);
- auto* compositePosZeros = constants.Composite(vec3f, Vector{fPos0, fPos0});
- auto* compositeNegZeros = constants.Composite(vec3f, Vector{fNeg0, fNeg0});
+ auto* compositePosZeros = constants.Composite(vec3f, Vector{fPos0, fPos0, fPos0});
+ auto* compositeNegZeros = constants.Composite(vec3f, Vector{fNeg0, fNeg0, fNeg0});
auto* compositeMixed = constants.Composite(vec3f, Vector{fNeg0, fPos1, fPos0});
auto* compositePosNeg = constants.Composite(vec3f, Vector{fNeg1, fPos1, fPos1});
@@ -66,8 +66,8 @@
auto* fPos1 = constants.Get(1_f);
auto* fNeg1 = constants.Get(-1_f);
- auto* compositePosZeros = constants.Composite(vec3f, Vector{fPos0, fPos0});
- auto* compositeNegZeros = constants.Composite(vec3f, Vector{fNeg0, fNeg0});
+ auto* compositePosZeros = constants.Composite(vec3f, Vector{fPos0, fPos0, fPos0});
+ auto* compositeNegZeros = constants.Composite(vec3f, Vector{fNeg0, fNeg0, fNeg0});
auto* compositeMixed = constants.Composite(vec3f, Vector{fNeg0, fPos1, fPos0});
auto* compositePosNeg = constants.Composite(vec3f, Vector{fNeg1, fPos1, fPos1});
@@ -78,12 +78,12 @@
}
TEST_F(ConstantTest_Composite, Index) {
- auto* vec3f = create<core::type::Vector>(create<core::type::F32>(), 3u);
+ auto* vec2f = create<core::type::Vector>(create<core::type::F32>(), 2u);
auto* fPos0 = constants.Get(0_f);
auto* fPos1 = constants.Get(1_f);
- auto* composite = constants.Composite(vec3f, Vector{fPos1, fPos0});
+ auto* composite = constants.Composite(vec2f, Vector{fPos1, fPos0});
ASSERT_NE(composite->Index(0), nullptr);
ASSERT_NE(composite->Index(1), nullptr);
@@ -96,12 +96,12 @@
}
TEST_F(ConstantTest_Composite, Clone) {
- auto* vec3f = create<core::type::Vector>(create<core::type::F32>(), 3u);
+ auto* vec2f = create<core::type::Vector>(create<core::type::F32>(), 2u);
auto* fPos0 = constants.Get(0_f);
auto* fPos1 = constants.Get(1_f);
- auto* composite = constants.Composite(vec3f, Vector{fPos1, fPos0});
+ auto* composite = constants.Composite(vec2f, Vector{fPos1, fPos0});
constant::Manager mgr;
constant::CloneContext ctx{core::type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr};
diff --git a/src/tint/lang/core/constant/eval_runtime_semantics_test.cc b/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
index 95798e9..d48c8c9 100644
--- a/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
+++ b/src/tint/lang/core/constant/eval_runtime_semantics_test.cc
@@ -479,8 +479,9 @@
}
TEST_F(ConstEvalRuntimeSemanticsTest, Unpack2x16Float_OutOfRange) {
+ auto* vec2f = create<core::type::Vector>(create<core::type::F32>(), 2u);
auto* a = constants.Get(u32(0x51437C00));
- auto result = eval.unpack2x16float(create<core::type::U32>(), Vector{a}, {});
+ auto result = eval.unpack2x16float(vec2f, Vector{a}, {});
ASSERT_EQ(result, Success);
EXPECT_FLOAT_EQ(result.Get()->Index(0)->ValueAs<f32>(), 0.f);
EXPECT_FLOAT_EQ(result.Get()->Index(1)->ValueAs<f32>(), 42.09375f);
diff --git a/src/tint/lang/core/constant/manager.cc b/src/tint/lang/core/constant/manager.cc
index ddd19d8..69feff2 100644
--- a/src/tint/lang/core/constant/manager.cc
+++ b/src/tint/lang/core/constant/manager.cc
@@ -144,7 +144,7 @@
zeros.Reserve(s->Members().Length());
for (auto* member : s->Members()) {
auto* zero =
- zero_by_type.GetOrCreate(member->Type(), [&] { return Zero(member->Type()); });
+ zero_by_type.GetOrAdd(member->Type(), [&] { return Zero(member->Type()); });
if (!zero) {
return nullptr;
}
diff --git a/src/tint/lang/core/core.def b/src/tint/lang/core/core.def
index 4a5ed76..1e57b92 100644
--- a/src/tint/lang/core/core.def
+++ b/src/tint/lang/core/core.def
@@ -851,7 +851,7 @@
@stage("fragment", "compute") fn atomicCompareExchangeWeak<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T, T) -> __atomic_compare_exchange_result<T>
@must_use @stage("compute") fn subgroupBallot() -> vec4<u32>
-@must_use @stage("compute") fn subgroupBroadcast<T: fiu32>(value: T, @const sourceLaneIndex: u32) -> T
+@must_use @stage("compute") fn subgroupBroadcast<T: fiu32_f16>(value: T, @const sourceLaneIndex: u32) -> T
////////////////////////////////////////////////////////////////////////////////
// Value constructors //
diff --git a/src/tint/lang/core/intrinsic/data.cc b/src/tint/lang/core/intrinsic/data.cc
index 97482d5..a127690 100644
--- a/src/tint/lang/core/intrinsic/data.cc
+++ b/src/tint/lang/core/intrinsic/data.cc
@@ -4686,41 +4686,46 @@
{
/* [29] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(65),
+ /* matcher_index */ TypeMatcherIndex(67),
},
{
/* [30] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(57),
+ /* matcher_index */ TypeMatcherIndex(65),
},
{
/* [31] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(58),
+ /* matcher_index */ TypeMatcherIndex(57),
},
{
/* [32] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(55),
+ /* matcher_index */ TypeMatcherIndex(58),
},
{
/* [33] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(56),
+ /* matcher_index */ TypeMatcherIndex(55),
},
{
/* [34] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(59),
+ /* matcher_index */ TypeMatcherIndex(56),
},
{
/* [35] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(54),
+ /* matcher_index */ TypeMatcherIndex(59),
},
{
/* [36] */
/* name */ "T",
+ /* matcher_index */ TypeMatcherIndex(54),
+ },
+ {
+ /* [37] */
+ /* name */ "T",
/* matcher_index */ TypeMatcherIndex(71),
},
};
@@ -5540,7 +5545,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(52),
@@ -6437,7 +6442,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(10),
@@ -6723,7 +6728,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(72),
@@ -7399,7 +7404,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(102),
@@ -7412,7 +7417,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(385),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(102),
@@ -7477,7 +7482,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(108),
@@ -7490,7 +7495,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(388),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(108),
@@ -7555,7 +7560,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(114),
@@ -7568,7 +7573,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(391),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(114),
@@ -7633,7 +7638,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(120),
@@ -7646,7 +7651,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(394),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(120),
@@ -7711,7 +7716,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(126),
@@ -7724,7 +7729,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(397),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(126),
@@ -7789,7 +7794,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(132),
@@ -7802,7 +7807,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(400),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(132),
@@ -7867,7 +7872,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(138),
@@ -7880,7 +7885,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(403),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(138),
@@ -7945,7 +7950,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(144),
@@ -7958,7 +7963,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(406),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(144),
@@ -8023,7 +8028,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(150),
@@ -8036,7 +8041,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(409),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(150),
@@ -8426,7 +8431,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -8439,7 +8444,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -8478,7 +8483,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -8491,7 +8496,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -8608,7 +8613,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(30),
+ /* template_types */ TemplateTypeIndex(31),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(31),
@@ -8647,7 +8652,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(31),
+ /* template_types */ TemplateTypeIndex(32),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(33),
@@ -8686,7 +8691,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(32),
+ /* template_types */ TemplateTypeIndex(33),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(15),
@@ -8725,7 +8730,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(33),
+ /* template_types */ TemplateTypeIndex(34),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(85),
@@ -8764,7 +8769,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(34),
+ /* template_types */ TemplateTypeIndex(35),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(9),
@@ -10233,7 +10238,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10246,7 +10251,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10285,7 +10290,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10298,7 +10303,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10467,7 +10472,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(16),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10480,7 +10485,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(351),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10493,7 +10498,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(16),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10506,7 +10511,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(351),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10948,7 +10953,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(0),
+ /* template_types */ TemplateTypeIndex(29),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(348),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10987,7 +10992,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(213),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(156),
@@ -11926,7 +11931,7 @@
},
{
/* [120] */
- /* fn subgroupBroadcast<T : fiu32>(value: T, @const sourceLaneIndex: u32) -> T */
+ /* fn subgroupBroadcast<T : fiu32_f16>(value: T, @const sourceLaneIndex: u32) -> T */
/* num overloads */ 1,
/* overloads */ OverloadIndex(465),
},
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index ed03faa..ef364ce 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -175,7 +175,7 @@
if (block_in == nullptr) {
return 0;
}
- return blocks_.GetOrCreate(block_in, [&]() -> uint32_t {
+ return blocks_.GetOrAdd(block_in, [&]() -> uint32_t {
auto& block_out = *mod_out_.add_blocks();
auto id = static_cast<uint32_t>(blocks_.Count());
for (auto* inst : *block_in) {
@@ -351,7 +351,7 @@
if (type_in == nullptr) {
return 0;
}
- return types_.GetOrCreate(type_in, [&]() -> uint32_t {
+ return types_.GetOrAdd(type_in, [&]() -> uint32_t {
pb::Type type_out;
tint::Switch(
type_in, //
@@ -498,7 +498,7 @@
if (!value_in) {
return 0;
}
- return values_.GetOrCreate(value_in, [&] {
+ return values_.GetOrAdd(value_in, [&] {
auto& value_out = *mod_out_.add_values();
auto id = static_cast<uint32_t>(mod_out_.values().size());
@@ -563,7 +563,7 @@
if (!constant_in) {
return 0;
}
- return constant_values_.GetOrCreate(constant_in, [&] {
+ return constant_values_.GetOrAdd(constant_in, [&] {
pb::ConstantValue constant_out;
tint::Switch(
constant_in, //
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index e427f34..3640439 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -298,7 +298,7 @@
/// @param val the constant value
/// @returns the new constant
ir::Constant* Constant(const core::constant::Value* val) {
- return ir.constants.GetOrCreate(val, [&] { return ir.values.Create<ir::Constant>(val); });
+ return ir.constants.GetOrAdd(val, [&] { return ir.values.Create<ir::Constant>(val); });
}
/// Creates a ir::Constant for an i32 Scalar
diff --git a/src/tint/lang/core/ir/disassembler.cc b/src/tint/lang/core/ir/disassembler.cc
index 6397f8a..2e95c2c 100644
--- a/src/tint/lang/core/ir/disassembler.cc
+++ b/src/tint/lang/core/ir/disassembler.cc
@@ -110,12 +110,12 @@
size_t Disassembler::IdOf(const Block* node) {
TINT_ASSERT(node);
- return block_ids_.GetOrCreate(node, [&] { return block_ids_.Count(); });
+ return block_ids_.GetOrAdd(node, [&] { return block_ids_.Count(); });
}
std::string Disassembler::IdOf(const Value* value) {
TINT_ASSERT(value);
- return value_ids_.GetOrCreate(value, [&] {
+ return value_ids_.GetOrAdd(value, [&] {
if (auto sym = mod_.NameOf(value)) {
if (ids_.Add(sym.Name())) {
return sym.Name();
@@ -137,7 +137,7 @@
return "undef";
}
- return if_names_.GetOrCreate(inst, [&] { return "if_" + std::to_string(if_names_.Count()); });
+ return if_names_.GetOrAdd(inst, [&] { return "if_" + std::to_string(if_names_.Count()); });
}
std::string Disassembler::NameOf(const Loop* inst) {
@@ -145,8 +145,8 @@
return "undef";
}
- return loop_names_.GetOrCreate(inst,
- [&] { return "loop_" + std::to_string(loop_names_.Count()); });
+ return loop_names_.GetOrAdd(inst,
+ [&] { return "loop_" + std::to_string(loop_names_.Count()); });
}
std::string Disassembler::NameOf(const Switch* inst) {
@@ -154,7 +154,7 @@
return "undef";
}
- return switch_names_.GetOrCreate(
+ return switch_names_.GetOrAdd(
inst, [&] { return "switch_" + std::to_string(switch_names_.Count()); });
}
diff --git a/src/tint/lang/core/ir/disassembler.h b/src/tint/lang/core/ir/disassembler.h
index 8f15e95..c585fc4 100644
--- a/src/tint/lang/core/ir/disassembler.h
+++ b/src/tint/lang/core/ir/disassembler.h
@@ -89,24 +89,20 @@
/// @param inst the instruction to retrieve
/// @returns the source for the instruction
Source InstructionSource(const Instruction* inst) {
- return instruction_to_src_.Get(inst).value_or(Source{});
+ return instruction_to_src_.GetOr(inst, Source{});
}
/// @param operand the operand to retrieve
/// @returns the source for the operand
- Source OperandSource(IndexedValue operand) {
- return operand_to_src_.Get(operand).value_or(Source{});
- }
+ Source OperandSource(IndexedValue operand) { return operand_to_src_.GetOr(operand, Source{}); }
/// @param result the result to retrieve
/// @returns the source for the result
- Source ResultSource(IndexedValue result) {
- return result_to_src_.Get(result).value_or(Source{});
- }
+ Source ResultSource(IndexedValue result) { return result_to_src_.GetOr(result, Source{}); }
/// @param blk teh block to retrieve
/// @returns the source for the block
- Source BlockSource(const Block* blk) { return block_to_src_.Get(blk).value_or(Source{}); }
+ Source BlockSource(const Block* blk) { return block_to_src_.GetOr(blk, Source{}); }
/// Stores the given @p src location for @p inst instruction
/// @param inst the instruction to store
diff --git a/src/tint/lang/core/ir/module.cc b/src/tint/lang/core/ir/module.cc
index 4738c14..cbc747b 100644
--- a/src/tint/lang/core/ir/module.cc
+++ b/src/tint/lang/core/ir/module.cc
@@ -49,7 +49,7 @@
}
Symbol Module::NameOf(const Value* value) const {
- return value_to_name_.Get(value).value_or(Symbol{});
+ return value_to_name_.GetOr(value, Symbol{});
}
void Module::SetName(Instruction* inst, std::string_view name) {
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill.cc b/src/tint/lang/core/ir/transform/binary_polyfill.cc
index 83fb233..27d3b13 100644
--- a/src/tint/lang/core/ir/transform/binary_polyfill.cc
+++ b/src/tint/lang/core/ir/transform/binary_polyfill.cc
@@ -155,7 +155,7 @@
bool is_signed = result_ty->is_signed_integer_scalar_or_vector();
auto& helpers = is_div ? int_div_helpers : int_mod_helpers;
- auto* helper = helpers.GetOrCreate(result_ty, [&] {
+ auto* helper = helpers.GetOrAdd(result_ty, [&] {
// Generate a name for the helper function.
StringStream name;
name << "tint_" << (is_div ? "div_" : "mod_");
diff --git a/src/tint/lang/core/ir/transform/conversion_polyfill.cc b/src/tint/lang/core/ir/transform/conversion_polyfill.cc
index 705d946..5d4fdbc 100644
--- a/src/tint/lang/core/ir/transform/conversion_polyfill.cc
+++ b/src/tint/lang/core/ir/transform/conversion_polyfill.cc
@@ -106,7 +106,7 @@
auto* src_el_ty = src_ty->DeepestElement();
auto& helpers = src_el_ty->Is<type::F32>() ? f32toi_helpers : f16toi_helpers;
- auto* helper = helpers.GetOrCreate(res_ty, [&] {
+ auto* helper = helpers.GetOrAdd(res_ty, [&] {
// Generate a name for the helper function.
StringStream name;
name << "tint_";
diff --git a/src/tint/lang/core/ir/transform/demote_to_helper.cc b/src/tint/lang/core/ir/transform/demote_to_helper.cc
index 6c61cbc..a460c41 100644
--- a/src/tint/lang/core/ir/transform/demote_to_helper.cc
+++ b/src/tint/lang/core/ir/transform/demote_to_helper.cc
@@ -91,7 +91,7 @@
/// @param func the function to check
/// @returns true if @p func contains a discard instruction
bool HasDiscard(Function* func) {
- return function_discard_status.GetOrCreate(func, [&] { return HasDiscard(func->Block()); });
+ return function_discard_status.GetOrAdd(func, [&] { return HasDiscard(func->Block()); });
}
/// Check if a block (transitively) contains a discard instruction.
diff --git a/src/tint/lang/core/ir/transform/direct_variable_access.cc b/src/tint/lang/core/ir/transform/direct_variable_access.cc
index 95a601e..5297c81 100644
--- a/src/tint/lang/core/ir/transform/direct_variable_access.cc
+++ b/src/tint/lang/core/ir/transform/direct_variable_access.cc
@@ -356,7 +356,7 @@
}
// Look to see if this callee signature already has a variant created.
- auto* new_target = (*target_info)->variants_by_sig.GetOrCreate(signature, [&] {
+ auto* new_target = (*target_info)->variants_by_sig.GetOrAdd(signature, [&] {
// New signature.
// Clone the original function to seed the new variant.
diff --git a/src/tint/lang/core/ir/transform/preserve_padding.cc b/src/tint/lang/core/ir/transform/preserve_padding.cc
index 49da61f..79e6588 100644
--- a/src/tint/lang/core/ir/transform/preserve_padding.cc
+++ b/src/tint/lang/core/ir/transform/preserve_padding.cc
@@ -124,7 +124,7 @@
}
// The type contains padding bytes, so call a helper function that decomposes the accesses.
- auto* helper = helpers.GetOrCreate(store_type, [&] {
+ auto* helper = helpers.GetOrAdd(store_type, [&] {
auto* func = b.Function("tint_store_and_preserve_padding", ty.void_());
auto* target = b.FunctionParam("target", ty.ptr(storage, store_type));
auto* value_param = b.FunctionParam("value_param", store_type);
diff --git a/src/tint/lang/core/ir/transform/shader_io.cc b/src/tint/lang/core/ir/transform/shader_io.cc
index a0c6d78..9586785 100644
--- a/src/tint/lang/core/ir/transform/shader_io.cc
+++ b/src/tint/lang/core/ir/transform/shader_io.cc
@@ -228,9 +228,9 @@
/// Finalize any state needed to complete the transform.
void Finalize() {
// Remove IO attributes from all structure members that had them prior to this transform.
- for (auto* member : members_to_strip) {
+ for (auto& member : members_to_strip) {
// TODO(crbug.com/tint/745): Remove the const_cast.
- const_cast<core::type::StructMember*>(member)->SetAttributes({});
+ const_cast<core::type::StructMember*>(member.Value())->SetAttributes({});
}
}
};
diff --git a/src/tint/lang/core/ir/transform/std140.cc b/src/tint/lang/core/ir/transform/std140.cc
index ba436eb..f2caaf5 100644
--- a/src/tint/lang/core/ir/transform/std140.cc
+++ b/src/tint/lang/core/ir/transform/std140.cc
@@ -125,7 +125,7 @@
/// @param type the type to rewrite
/// @returns the new type
const core::type::Type* RewriteType(const core::type::Type* type) {
- return rewritten_types.GetOrCreate(type, [&]() -> const core::type::Type* {
+ return rewritten_types.GetOrAdd(type, [&]() -> const core::type::Type* {
return tint::Switch(
type,
[&](const core::type::Array* arr) -> const core::type::Type* {
@@ -219,7 +219,7 @@
orig_ty, //
[&](const core::type::Struct* str) -> Value* {
// Create a helper function that converts the struct to the original type.
- auto* helper = convert_helpers.GetOrCreate(str, [&] {
+ auto* helper = convert_helpers.GetOrAdd(str, [&] {
auto* input_str = source->Type()->As<core::type::Struct>();
auto* func = b.Function("convert_" + str->FriendlyName(), str);
auto* input = b.FunctionParam("input", input_str);
diff --git a/src/tint/lang/core/ir/transform/value_to_let.cc b/src/tint/lang/core/ir/transform/value_to_let.cc
index bbdcab0..ac51672 100644
--- a/src/tint/lang/core/ir/transform/value_to_let.cc
+++ b/src/tint/lang/core/ir/transform/value_to_let.cc
@@ -84,7 +84,7 @@
Access pending_access = Access::kLoad;
auto put_pending_in_lets = [&] {
- for (auto* pending : pending_resolution) {
+ for (auto& pending : pending_resolution) {
PutInLet(pending);
}
pending_resolution.Clear();
@@ -97,7 +97,7 @@
case 0: // No usage
break;
case 1: { // Single usage
- auto* usage = (*usages.begin()).instruction;
+ auto usage = (*usages.begin())->instruction;
if (usage->Block() == inst->Block()) {
// Usage in same block. Assign to pending_resolution, as we don't
// know whether its safe to inline yet.
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
index db902b4..eeef8b2 100644
--- a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
@@ -105,7 +105,7 @@
if (ptr && ptr->AddressSpace() == core::AddressSpace::kWorkgroup) {
// Record the usage of the variable for each block that references it.
var->Result(0)->ForEachUse([&](const Usage& use) {
- block_to_direct_vars.GetOrZero(use.instruction->Block())->Add(var);
+ block_to_direct_vars.GetOrAddZero(use.instruction->Block()).Add(var);
});
var_to_id.Add(var, next_id++);
}
@@ -184,7 +184,7 @@
/// @param func the function
/// @returns the set of transitively referenced workgroup variables
VarSet GetReferencedVars(Function* func) {
- return function_to_transitive_vars.GetOrCreate(func, [&] {
+ return function_to_transitive_vars.GetOrAdd(func, [&] {
VarSet vars;
GetReferencedVars(func->Block(), vars);
return vars;
@@ -196,8 +196,8 @@
/// @param vars the set of transitively referenced workgroup variables to populate
void GetReferencedVars(Block* block, VarSet& vars) {
// Add directly referenced vars.
- if (auto itr = block_to_direct_vars.Find(block)) {
- for (auto* var : *itr) {
+ if (auto itr = block_to_direct_vars.Get(block)) {
+ for (auto& var : *itr) {
vars.Add(var);
}
}
@@ -209,7 +209,7 @@
[&](UserCall* call) {
// Get variables referenced by a function called from this block.
auto callee_vars = GetReferencedVars(call->Target());
- for (auto* var : callee_vars) {
+ for (auto& var : callee_vars) {
vars.Add(var);
}
},
@@ -234,7 +234,7 @@
StoreMap& stores) {
// If this type can be trivially zeroed, store to the whole element.
if (CanTriviallyZero(type)) {
- stores.GetOrZero(iteration_count)->Push(Store{var, type, indices});
+ stores.GetOrAddZero(iteration_count).Push(Store{var, type, indices});
return;
}
@@ -253,7 +253,7 @@
PrepareStores(var, arr->ElemType(), iteration_count * count, new_indices, stores);
},
[&](const type::Atomic*) {
- stores.GetOrZero(iteration_count)->Push(Store{var, type, indices});
+ stores.GetOrAddZero(iteration_count).Push(Store{var, type, indices});
},
[&](const type::Struct* str) {
for (auto* member : str->Members()) {
diff --git a/src/tint/lang/core/ir/value.cc b/src/tint/lang/core/ir/value.cc
index e062a62..c1a66e8 100644
--- a/src/tint/lang/core/ir/value.cc
+++ b/src/tint/lang/core/ir/value.cc
@@ -55,14 +55,14 @@
while (!uses_.IsEmpty()) {
auto& use = *uses_.begin();
auto* replacement = replacer(use);
- use.instruction->SetOperand(use.operand_index, replacement);
+ use->instruction->SetOperand(use->operand_index, replacement);
}
}
void Value::ReplaceAllUsesWith(Value* replacement) {
while (!uses_.IsEmpty()) {
auto& use = *uses_.begin();
- use.instruction->SetOperand(use.operand_index, replacement);
+ use->instruction->SetOperand(use->operand_index, replacement);
}
}
diff --git a/src/tint/lang/core/ir/value.h b/src/tint/lang/core/ir/value.h
index 55f7463..e2819bc 100644
--- a/src/tint/lang/core/ir/value.h
+++ b/src/tint/lang/core/ir/value.h
@@ -100,7 +100,7 @@
/// @param instruction the instruction
/// @param operand_index the in
bool HasUsage(const Instruction* instruction, size_t operand_index) const {
- return uses_.Contains({const_cast<Instruction*>(instruction), operand_index});
+ return uses_.Contains(Usage{const_cast<Instruction*>(instruction), operand_index});
}
/// Apply a function to all uses of the value that exist prior to calling this method.
diff --git a/src/tint/lang/core/ir/var.cc b/src/tint/lang/core/ir/var.cc
index 11a9287..92f2446 100644
--- a/src/tint/lang/core/ir/var.cc
+++ b/src/tint/lang/core/ir/var.cc
@@ -78,7 +78,7 @@
if (result->Usages().All([](const Usage& u) { return u.instruction->Is<ir::Store>(); })) {
while (!result->Usages().IsEmpty()) {
auto& usage = *result->Usages().begin();
- usage.instruction->Destroy();
+ usage->instruction->Destroy();
}
Destroy();
}
diff --git a/src/tint/lang/core/type/manager.h b/src/tint/lang/core/type/manager.h
index f02290b..0a9a201 100644
--- a/src/tint/lang/core/type/manager.h
+++ b/src/tint/lang/core/type/manager.h
@@ -172,7 +172,7 @@
typename _ = std::enable_if<tint::traits::IsTypeOrDerived<TYPE, Type>>,
typename... ARGS>
auto* Find(ARGS&&... args) const {
- return types_.Find<ToType<TYPE>>(std::forward<ARGS>(args)...);
+ return types_.Find<TYPE>(std::forward<ARGS>(args)...);
}
/// @returns a void type
@@ -501,16 +501,6 @@
TypeIterator end() const { return types_.end(); }
private:
- /// ToType<T> is specialized for various `T` types and each specialization contains a single
- /// `type` alias to the corresponding type deriving from `core::type::Type`.
- template <typename T>
- struct ToTypeImpl {
- using type = T;
- };
-
- template <typename T>
- using ToType = typename ToTypeImpl<T>::type;
-
/// Unique types owned by the manager
UniqueAllocator<Type> types_;
/// Unique nodes (excluding types) owned by the manager
diff --git a/src/tint/lang/core/type/struct.h b/src/tint/lang/core/type/struct.h
index 539c27b..1ea2b0b 100644
--- a/src/tint/lang/core/type/struct.h
+++ b/src/tint/lang/core/type/struct.h
@@ -148,7 +148,7 @@
/// @returns true iff this structure has been used by address space that's
/// host-shareable.
bool IsHostShareable() const {
- for (auto sc : address_space_usage_) {
+ for (auto& sc : address_space_usage_) {
if (core::IsHostShareable(sc)) {
return true;
}
diff --git a/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
index 1a66806..bd3deb4 100644
--- a/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
@@ -398,8 +398,8 @@
auto* src_vec = src_type->As<core::type::Vector>();
TINT_ASSERT(src_vec);
TINT_ASSERT(((src_vec->Width() == 2u) || (src_vec->Width() == 4u)));
- std::string fn = GetOrCreate(
- bitcast_funcs_, BinaryOperandType{{src_type, dst_type}}, [&]() -> std::string {
+ std::string fn =
+ GetOrAdd(bitcast_funcs_, BinaryOperandType{{src_type, dst_type}}, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -452,8 +452,8 @@
auto* dst_vec = dst_type->As<core::type::Vector>();
TINT_ASSERT(dst_vec);
TINT_ASSERT(((dst_vec->Width() == 2u) || (dst_vec->Width() == 4u)));
- std::string fn = GetOrCreate(
- bitcast_funcs_, BinaryOperandType{{src_type, dst_type}}, [&]() -> std::string {
+ std::string fn =
+ GetOrAdd(bitcast_funcs_, BinaryOperandType{{src_type, dst_type}}, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -613,37 +613,37 @@
auto* ret_ty = TypeOf(expr)->UnwrapRef();
auto* lhs_ty = TypeOf(expr->lhs)->UnwrapRef();
auto* rhs_ty = TypeOf(expr->rhs)->UnwrapRef();
- fn = tint::GetOrCreate(float_modulo_funcs_, BinaryOperandType{{lhs_ty, rhs_ty}},
- [&]() -> std::string {
- TextBuffer b;
- TINT_DEFER(helpers_.Append(b));
+ fn = tint::GetOrAdd(float_modulo_funcs_, BinaryOperandType{{lhs_ty, rhs_ty}},
+ [&]() -> std::string {
+ TextBuffer b;
+ TINT_DEFER(helpers_.Append(b));
- auto fn_name = UniqueIdentifier("tint_float_modulo");
- std::vector<std::string> parameter_names;
- {
- auto decl = Line(&b);
- EmitTypeAndName(decl, ret_ty, core::AddressSpace::kUndefined,
- core::Access::kUndefined, fn_name);
- {
- ScopedParen sp(decl);
- const auto* ty = TypeOf(expr->lhs)->UnwrapRef();
- EmitTypeAndName(decl, ty, core::AddressSpace::kUndefined,
- core::Access::kUndefined, "lhs");
- decl << ", ";
- ty = TypeOf(expr->rhs)->UnwrapRef();
- EmitTypeAndName(decl, ty, core::AddressSpace::kUndefined,
- core::Access::kUndefined, "rhs");
- }
- decl << " {";
- }
- {
- ScopedIndent si(&b);
- Line(&b) << "return (lhs - rhs * trunc(lhs / rhs));";
- }
- Line(&b) << "}";
- Line(&b);
- return fn_name;
- });
+ auto fn_name = UniqueIdentifier("tint_float_modulo");
+ std::vector<std::string> parameter_names;
+ {
+ auto decl = Line(&b);
+ EmitTypeAndName(decl, ret_ty, core::AddressSpace::kUndefined,
+ core::Access::kUndefined, fn_name);
+ {
+ ScopedParen sp(decl);
+ const auto* ty = TypeOf(expr->lhs)->UnwrapRef();
+ EmitTypeAndName(decl, ty, core::AddressSpace::kUndefined,
+ core::Access::kUndefined, "lhs");
+ decl << ", ";
+ ty = TypeOf(expr->rhs)->UnwrapRef();
+ EmitTypeAndName(decl, ty, core::AddressSpace::kUndefined,
+ core::Access::kUndefined, "rhs");
+ }
+ decl << " {";
+ }
+ {
+ ScopedIndent si(&b);
+ Line(&b) << "return (lhs - rhs * trunc(lhs / rhs));";
+ }
+ Line(&b) << "}";
+ Line(&b);
+ return fn_name;
+ });
// Call the helper
out << fn;
@@ -1145,7 +1145,7 @@
if (vec_ty->type()->is_integer_scalar()) {
// GLSL does not have a builtin for dot() with integer vector types.
// Generate the helper function if it hasn't been created already
- fn = tint::GetOrCreate(int_dot_funcs_, vec_ty, [&]() -> std::string {
+ fn = tint::GetOrAdd(int_dot_funcs_, vec_ty, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -2942,7 +2942,7 @@
const sem::BuiltinFn* builtin,
F&& build) {
// Generate the helper function if it hasn't been created already
- auto fn = tint::GetOrCreate(builtins_, builtin, [&]() -> std::string {
+ auto fn = tint::GetOrAdd(builtins_, builtin, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -3009,8 +3009,8 @@
std::string ASTPrinter::StructName(const core::type::Struct* s) {
auto name = s->Name().Name();
if (HasPrefix(name, "__")) {
- name = tint::GetOrCreate(builtin_struct_names_, s,
- [&] { return UniqueIdentifier(name.substr(2)); });
+ name = tint::GetOrAdd(builtin_struct_names_, s,
+ [&] { return UniqueIdentifier(name.substr(2)); });
}
return name;
}
diff --git a/src/tint/lang/glsl/writer/ast_raise/combine_samplers.cc b/src/tint/lang/glsl/writer/ast_raise/combine_samplers.cc
index 819241a..6762ee6 100644
--- a/src/tint/lang/glsl/writer/ast_raise/combine_samplers.cc
+++ b/src/tint/lang/glsl/writer/ast_raise/combine_samplers.cc
@@ -191,8 +191,8 @@
if (IsGlobal(pair)) {
// Both texture and sampler are global; add a new global variable
// to represent the combined sampler (if not already created).
- GetOrCreate(global_combined_texture_samplers_, pair,
- [&] { return CreateCombinedGlobal(texture_var, sampler_var, name); });
+ GetOrAdd(global_combined_texture_samplers_, pair,
+ [&] { return CreateCombinedGlobal(texture_var, sampler_var, name); });
} else {
// Either texture or sampler (or both) is a function parameter;
// add a new function parameter to represent the combined sampler.
diff --git a/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc b/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc
index 5ac9fec..2b31e07 100644
--- a/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc
+++ b/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc
@@ -139,7 +139,7 @@
TINT_ICE_ON_NO_MATCH);
},
[&](const sem::Function* user_fn) {
- auto user_param_to_info = fn_to_data.Find(user_fn);
+ auto user_param_to_info = fn_to_data.Get(user_fn);
if (!user_param_to_info) {
// Uninterested function not calling texture builtins with function
// texture param.
@@ -149,11 +149,10 @@
user_fn->Declaration()->params.Length());
for (size_t i = 0; i < call->Arguments().Length(); i++) {
auto param = user_fn->Declaration()->params[i];
- auto info = user_param_to_info->Get(param);
- if (info.has_value()) {
+ if (auto info = user_param_to_info->Get(param)) {
auto* arg = call->Arguments()[i];
auto* texture_sem = arg->RootIdentifier();
- auto& args = call_to_data.GetOrCreate(call_expr, [&] {
+ auto& args = call_to_data.GetOrAdd(call_expr, [&] {
return Vector<
std::variant<BindingPoint, const ast::Parameter*>, 4>();
});
@@ -183,8 +182,8 @@
// If any functions need extra params, add them now.
if (!fn_to_data.IsEmpty()) {
- for (auto pair : fn_to_data) {
- auto* fn = pair.key;
+ for (auto& pair : fn_to_data) {
+ auto* fn = pair.key.Value();
// Reorder the param to a vector to make sure params are in the correct order.
Vector<const ast::Parameter*, 4> extra_params_in_order;
@@ -201,8 +200,8 @@
}
// Replace all interested texture builtin calls.
- for (auto pair : builtin_to_replace) {
- auto call = pair.key;
+ for (auto& pair : builtin_to_replace) {
+ auto call = pair.key.Value();
if (std::holds_alternative<BindingPoint>(pair.value)) {
// This texture is a global variable with binding point.
// Read builtin value from uniform buffer.
@@ -216,8 +215,8 @@
}
// Insert all extra args to interested function calls.
- for (auto pair : call_to_data) {
- auto call = pair.key;
+ for (auto& pair : call_to_data) {
+ auto call = pair.key.Value();
for (auto new_arg_info : pair.value) {
if (std::holds_alternative<BindingPoint>(new_arg_info)) {
// This texture is a global variable with binding point.
@@ -363,7 +362,7 @@
// Load the builtin value from the UBO.
auto member_sym = bindpoint_to_syms.Get(binding);
- TINT_ASSERT(member_sym.has_value());
+ TINT_ASSERT(member_sym);
return b.MemberAccessor(ubo_sym, *member_sym);
}
@@ -383,7 +382,7 @@
/// @returns the new u32 function parameter.
const ast::Parameter* GetAndRecordFunctionParameter(const sem::Function* fn,
const sem::Variable* var) {
- auto& param_to_info = fn_to_data.GetOrCreate(
+ auto& param_to_info = fn_to_data.GetOrAdd(
fn, [&] { return Hashmap<const ast::Parameter*, FunctionExtraParamInfo, 4>(); });
const ast::Parameter* param = nullptr;
@@ -397,7 +396,7 @@
// Get or record a new u32 param to this function if first visited.
auto entry = param_to_info.Get(param);
- if (entry.has_value()) {
+ if (entry) {
return entry->param;
}
diff --git a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
index 9070b70..f2a8f82 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
@@ -460,7 +460,7 @@
bool ASTPrinter::EmitDynamicVectorAssignment(const ast::AssignmentStatement* stmt,
const core::type::Vector* vec) {
- auto name = tint::GetOrCreate(dynamic_vector_write_, vec, [&]() -> std::string {
+ auto name = tint::GetOrAdd(dynamic_vector_write_, vec, [&]() -> std::string {
std::string fn = UniqueIdentifier("set_vector_element");
{
auto out = Line(&helpers_);
@@ -525,7 +525,7 @@
bool ASTPrinter::EmitDynamicMatrixVectorAssignment(const ast::AssignmentStatement* stmt,
const core::type::Matrix* mat) {
- auto name = tint::GetOrCreate(dynamic_matrix_vector_write_, mat, [&]() -> std::string {
+ auto name = tint::GetOrAdd(dynamic_matrix_vector_write_, mat, [&]() -> std::string {
std::string fn = UniqueIdentifier("set_matrix_column");
{
auto out = Line(&helpers_);
@@ -586,7 +586,7 @@
auto* lhs_row_access = stmt->lhs->As<ast::IndexAccessorExpression>();
auto* lhs_col_access = lhs_row_access->object->As<ast::IndexAccessorExpression>();
- auto name = tint::GetOrCreate(dynamic_matrix_scalar_write_, mat, [&]() -> std::string {
+ auto name = tint::GetOrAdd(dynamic_matrix_scalar_write_, mat, [&]() -> std::string {
std::string fn = UniqueIdentifier("set_matrix_scalar");
{
auto out = Line(&helpers_);
@@ -722,7 +722,7 @@
// f32tof16 to get the bits. This should be safe, because the convertion is precise
// for finite and infinite f16 value as they are exactly representable by f32, and
// WGSL spec allow any result if f16 value is NaN.
- return tint::GetOrCreate(
+ return tint::GetOrAdd(
bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -794,7 +794,7 @@
// convertion is precise for finite and infinite f16 result value as they are
// exactly representable by f32, and WGSL spec allow any result if f16 result value
// would be NaN.
- return tint::GetOrCreate(
+ return tint::GetOrAdd(
bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -4689,7 +4689,7 @@
const sem::BuiltinFn* builtin,
F&& build) {
// Generate the helper function if it hasn't been created already
- auto fn = tint::GetOrCreate(builtins_, builtin, [&]() -> std::string {
+ auto fn = tint::GetOrAdd(builtins_, builtin, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -4758,8 +4758,8 @@
std::string ASTPrinter::StructName(const core::type::Struct* s) {
auto name = s->Name().Name();
if (HasPrefix(name, "__")) {
- name = tint::GetOrCreate(builtin_struct_names_, s,
- [&] { return UniqueIdentifier(name.substr(2)); });
+ name = tint::GetOrAdd(builtin_struct_names_, s,
+ [&] { return UniqueIdentifier(name.substr(2)); });
}
return name;
}
diff --git a/src/tint/lang/hlsl/writer/ast_raise/calculate_array_length.cc b/src/tint/lang/hlsl/writer/ast_raise/calculate_array_length.cc
index 3dd8c8a..25a23ea 100644
--- a/src/tint/lang/hlsl/writer/ast_raise/calculate_array_length.cc
+++ b/src/tint/lang/hlsl/writer/ast_raise/calculate_array_length.cc
@@ -116,7 +116,7 @@
// [RW]ByteAddressBuffer.GetDimensions().
std::unordered_map<const core::type::Reference*, Symbol> buffer_size_intrinsics;
auto get_buffer_size_intrinsic = [&](const core::type::Reference* buffer_type) {
- return tint::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
+ return tint::GetOrAdd(buffer_size_intrinsics, buffer_type, [&] {
auto name = b.Sym();
auto type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
@@ -195,7 +195,7 @@
auto* block = call->Stmt()->Block()->Declaration();
auto array_length =
- tint::GetOrCreate(array_length_by_usage, {block, storage_buffer_var}, [&] {
+ tint::GetOrAdd(array_length_by_usage, {block, storage_buffer_var}, [&] {
// First time this array length is used for this block.
// Let's calculate it.
diff --git a/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc
index a6e5955..91e67b5 100644
--- a/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc
@@ -482,7 +482,7 @@
Symbol LoadFunc(const core::type::Type* el_ty,
core::AddressSpace address_space,
const Symbol& buffer) {
- return tint::GetOrCreate(load_funcs, LoadStoreKey{el_ty, buffer}, [&] {
+ return tint::GetOrAdd(load_funcs, LoadStoreKey{el_ty, buffer}, [&] {
Vector params{b.Param("offset", b.ty.u32())};
auto name = b.Symbols().New(buffer.Name() + "_load");
@@ -562,7 +562,7 @@
/// @param buffer the symbol of the storage buffer variable, owned by the target ProgramBuilder.
/// @return the name of the function that performs the store
Symbol StoreFunc(const core::type::Type* el_ty, const Symbol& buffer) {
- return tint::GetOrCreate(store_funcs, LoadStoreKey{el_ty, buffer}, [&] {
+ return tint::GetOrAdd(store_funcs, LoadStoreKey{el_ty, buffer}, [&] {
Vector params{
b.Param("offset", b.ty.u32()),
b.Param("value", CreateASTTypeFor(ctx, el_ty)),
@@ -653,7 +653,7 @@
const sem::BuiltinFn* builtin,
const Symbol& buffer) {
auto fn = builtin->Fn();
- return tint::GetOrCreate(atomic_funcs, AtomicKey{el_ty, fn, buffer}, [&] {
+ return tint::GetOrAdd(atomic_funcs, AtomicKey{el_ty, fn, buffer}, [&] {
// The first parameter to all WGSL atomics is the expression to the
// atomic. This is replaced with two parameters: the buffer and offset.
Vector params{b.Param("offset", b.ty.u32())};
diff --git a/src/tint/lang/hlsl/writer/ast_raise/remove_continue_in_switch.cc b/src/tint/lang/hlsl/writer/ast_raise/remove_continue_in_switch.cc
index db94a0f..899ea10 100644
--- a/src/tint/lang/hlsl/writer/ast_raise/remove_continue_in_switch.cc
+++ b/src/tint/lang/hlsl/writer/ast_raise/remove_continue_in_switch.cc
@@ -68,7 +68,7 @@
continue;
}
- auto& info = switch_infos.GetOrCreate(switch_stmt, [&] {
+ auto& info = switch_infos.GetOrAdd(switch_stmt, [&] {
switch_stmts.Push(switch_stmt);
auto* block = sem.Get(switch_stmt)->FindFirstParent<sem::LoopBlockStatement>();
return SwitchInfo{/* loop_block */ block, /* continues */ Empty};
@@ -89,7 +89,7 @@
for (auto* switch_stmt : switch_stmts) {
const auto& info = switch_infos.Get(switch_stmt);
- auto var_name = loop_to_var.GetOrCreate(info->loop_block, [&] {
+ auto var_name = loop_to_var.GetOrAdd(info->loop_block, [&] {
// Create and insert 'var tint_continue : bool;' before loop
auto var = b.Symbols().New("tint_continue");
auto* decl = b.Decl(b.Var(var, b.ty.bool_()));
diff --git a/src/tint/lang/hlsl/writer/ast_raise/truncate_interstage_variables.cc b/src/tint/lang/hlsl/writer/ast_raise/truncate_interstage_variables.cc
index 187834e..5d38ff8 100644
--- a/src/tint/lang/hlsl/writer/ast_raise/truncate_interstage_variables.cc
+++ b/src/tint/lang/hlsl/writer/ast_raise/truncate_interstage_variables.cc
@@ -129,34 +129,33 @@
// Get or create a new truncated struct/truncate function for the interstage inputs &
// outputs.
- auto entry =
- old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] {
- auto new_struct_sym = b.Symbols().New();
+ auto entry = old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrAdd(str, [&] {
+ auto new_struct_sym = b.Symbols().New();
- Vector<const ast::StructMember*, 20> truncated_members;
- Vector<const ast::Expression*, 20> initializer_exprs;
+ Vector<const ast::StructMember*, 20> truncated_members;
+ Vector<const ast::Expression*, 20> initializer_exprs;
- for (auto* member : str->Members()) {
- if (omit_members.Contains(member)) {
- continue;
- }
-
- truncated_members.Push(ctx.Clone(member->Declaration()));
- initializer_exprs.Push(b.MemberAccessor("io", ctx.Clone(member->Name())));
+ for (auto* member : str->Members()) {
+ if (omit_members.Contains(member)) {
+ continue;
}
- // Create the new shader io struct.
- b.Structure(new_struct_sym, std::move(truncated_members));
+ truncated_members.Push(ctx.Clone(member->Declaration()));
+ initializer_exprs.Push(b.MemberAccessor("io", ctx.Clone(member->Name())));
+ }
- // Create the mapping function to truncate the shader io.
- auto mapping_fn_sym = b.Symbols().New("truncate_shader_output");
- b.Func(mapping_fn_sym, Vector{b.Param("io", ctx.Clone(func_ast->return_type))},
- b.ty(new_struct_sym),
- Vector{
- b.Return(b.Call(new_struct_sym, std::move(initializer_exprs))),
- });
- return TruncatedStructAndConverter{new_struct_sym, mapping_fn_sym};
- });
+ // Create the new shader io struct.
+ b.Structure(new_struct_sym, std::move(truncated_members));
+
+ // Create the mapping function to truncate the shader io.
+ auto mapping_fn_sym = b.Symbols().New("truncate_shader_output");
+ b.Func(mapping_fn_sym, Vector{b.Param("io", ctx.Clone(func_ast->return_type))},
+ b.ty(new_struct_sym),
+ Vector{
+ b.Return(b.Call(new_struct_sym, std::move(initializer_exprs))),
+ });
+ return TruncatedStructAndConverter{new_struct_sym, mapping_fn_sym};
+ });
ctx.Replace(func_ast->return_type.expr, b.Expr(entry.truncated_struct));
@@ -172,7 +171,7 @@
[&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* {
auto* return_sem = sem.Get(return_statement);
if (auto mapping_fn_sym =
- entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
+ entry_point_functions_to_truncate_functions.Get(return_sem->Function())) {
return b.Return(return_statement->source,
b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value)));
}
@@ -181,7 +180,7 @@
// Remove IO attributes from old shader IO struct which is not used as entry point output
// anymore.
- for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) {
+ for (auto& it : old_shader_io_structs_to_new_struct_and_truncate_functions) {
const ast::Struct* struct_ty = it.key->Declaration();
for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) {
diff --git a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
index d7d744c..4c987c9 100644
--- a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
@@ -947,7 +947,7 @@
auto sc = ptr_ty->AddressSpace();
auto* str = builtin->ReturnType()->As<core::type::Struct>();
- auto func = tint::GetOrCreate(
+ auto func = tint::GetOrAdd(
atomicCompareExchangeWeak_, ACEWKeyType{{sc, str}}, [&]() -> std::string {
if (!EmitStructType(&helpers_,
builtin->ReturnType()->As<core::type::Struct>())) {
@@ -1342,7 +1342,7 @@
if (vec_ty->type()->is_integer_scalar()) {
// MSL does not have a builtin for dot() with integer vector types.
// Generate the helper function if it hasn't been created already
- fn = tint::GetOrCreate(int_dot_funcs_, vec_ty->Width(), [&]() -> std::string {
+ fn = tint::GetOrAdd(int_dot_funcs_, vec_ty->Width(), [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -2872,7 +2872,7 @@
// largest negative value, it returns `e`.
auto* expr_type = TypeOf(expr->expr)->UnwrapRef();
if (expr->op == core::UnaryOp::kNegation && expr_type->is_signed_integer_scalar_or_vector()) {
- auto fn = tint::GetOrCreate(unary_minus_funcs_, expr_type, [&]() -> std::string {
+ auto fn = tint::GetOrAdd(unary_minus_funcs_, expr_type, [&]() -> std::string {
// e.g.:
// int tint_unary_minus(const int v) {
// return (v == -2147483648) ? v : -v;
@@ -3039,7 +3039,7 @@
const sem::BuiltinFn* builtin,
F&& build) {
// Generate the helper function if it hasn't been created already
- auto fn = tint::GetOrCreate(builtins_, builtin, [&]() -> std::string {
+ auto fn = tint::GetOrAdd(builtins_, builtin, [&]() -> std::string {
TextBuffer b;
TINT_DEFER(helpers_.Append(b));
@@ -3122,8 +3122,8 @@
std::string ASTPrinter::StructName(const core::type::Struct* s) {
auto name = s->Name().Name();
if (HasPrefix(name, "__")) {
- name = tint::GetOrCreate(builtin_struct_names_, s,
- [&] { return UniqueIdentifier(name.substr(2)); });
+ name = tint::GetOrAdd(builtin_struct_names_, s,
+ [&] { return UniqueIdentifier(name.substr(2)); });
}
return name;
}
diff --git a/src/tint/lang/msl/writer/ast_raise/packed_vec3.cc b/src/tint/lang/msl/writer/ast_raise/packed_vec3.cc
index a61e3f8..a1f2321 100644
--- a/src/tint/lang/msl/writer/ast_raise/packed_vec3.cc
+++ b/src/tint/lang/msl/writer/ast_raise/packed_vec3.cc
@@ -139,7 +139,7 @@
// Create a struct with a single `__packed_vec3` member.
// Give the struct member the same alignment as the original unpacked vec3
// type, to avoid changing the array element stride.
- return b.ty(packed_vec3_wrapper_struct_names.GetOrCreate(vec, [&] {
+ return b.ty(packed_vec3_wrapper_struct_names.GetOrAdd(vec, [&] {
auto name = b.Symbols().New(
"tint_packed_vec3_" + vec->type()->FriendlyName() +
(array_element ? "_array_element" : "_struct_member"));
@@ -181,7 +181,7 @@
},
[&](const core::type::Struct* str) -> ast::Type {
if (ContainsVec3(str)) {
- auto name = rewritten_structs.GetOrCreate(str, [&] {
+ auto name = rewritten_structs.GetOrAdd(str, [&] {
tint::Vector<const ast::StructMember*, 4> members;
for (auto* member : str->Members()) {
// If the member type contains a vec3, rewrite it.
@@ -317,7 +317,7 @@
/// @returns an expression that holds the unpacked value
const ast::Expression* UnpackComposite(const ast::Expression* expr,
const core::type::Type* ty) {
- auto helper = unpack_helpers.GetOrCreate(ty, [&] {
+ auto helper = unpack_helpers.GetOrAdd(ty, [&] {
return MakePackUnpackHelper(
"tint_unpack_vec3_in_composite", ty,
[&](const ast::Expression* element,
@@ -345,7 +345,7 @@
/// @param ty the unpacked type
/// @returns an expression that holds the packed value
const ast::Expression* PackComposite(const ast::Expression* expr, const core::type::Type* ty) {
- auto helper = pack_helpers.GetOrCreate(ty, [&] {
+ auto helper = pack_helpers.GetOrAdd(ty, [&] {
return MakePackUnpackHelper(
"tint_pack_vec3_in_composite", ty,
[&](const ast::Expression* element,
diff --git a/src/tint/lang/msl/writer/common/option_helpers.cc b/src/tint/lang/msl/writer/common/option_helpers.cc
index 2561e9f..b655301 100644
--- a/src/tint/lang/msl/writer/common/option_helpers.cc
+++ b/src/tint/lang/msl/writer/common/option_helpers.cc
@@ -52,7 +52,7 @@
auto wgsl_seen = [&diagnostics, &seen_wgsl_bindings](const tint::BindingPoint& src,
const binding::BindingInfo& dst) -> bool {
- if (auto binding = seen_wgsl_bindings.Find(src)) {
+ if (auto binding = seen_wgsl_bindings.Get(src)) {
if (*binding != dst) {
std::stringstream str;
str << "found duplicate WGSL binding point: " << src;
@@ -67,7 +67,7 @@
auto msl_seen = [&diagnostics](InfoToPointMap& map, const binding::BindingInfo& src,
const tint::BindingPoint& dst) -> bool {
- if (auto binding = map.Find(src)) {
+ if (auto binding = map.Get(src)) {
if (*binding != dst) {
std::stringstream str;
str << "found duplicate MSL binding point: [binding: " << src.binding << "]";
diff --git a/src/tint/lang/msl/writer/helpers/generate_bindings.cc b/src/tint/lang/msl/writer/helpers/generate_bindings.cc
index a0eaf7e..71a5803 100644
--- a/src/tint/lang/msl/writer/helpers/generate_bindings.cc
+++ b/src/tint/lang/msl/writer/helpers/generate_bindings.cc
@@ -59,7 +59,7 @@
for (auto* var : program.AST().GlobalVariables()) {
if (auto* sem_var = program.Sem().Get(var)->As<sem::GlobalVariable>()) {
if (auto bp = sem_var->Attributes().binding_point) {
- if (auto val = group_to_next_binding_number.Find(bp->group)) {
+ if (auto val = group_to_next_binding_number.Get(bp->group)) {
*val = std::max(*val, bp->binding + 1);
} else {
group_to_next_binding_number.Add(bp->group, bp->binding + 1);
@@ -109,14 +109,12 @@
for (auto bp : ext_tex_bps) {
uint32_t g = bp.group;
- uint32_t next_num = *(group_to_next_binding_number.GetOrZero(g));
+ uint32_t& next_num = group_to_next_binding_number.GetOrAddZero(g);
binding::BindingInfo plane0{bp.binding};
binding::BindingInfo plane1{next_num++};
binding::BindingInfo metadata{next_num++};
- group_to_next_binding_number.Replace(g, next_num);
-
bindings.external_texture.emplace(bp, binding::ExternalTexture{metadata, plane0, plane1});
}
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 0e8988e..3c34b32 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -1432,8 +1432,8 @@
std::string StructName(const core::type::Struct* s) {
auto name = s->Name().Name();
if (HasPrefix(name, "__")) {
- name = tint::GetOrCreate(builtin_struct_names_, s,
- [&] { return UniqueIdentifier(name.substr(2)); });
+ name = tint::GetOrAdd(builtin_struct_names_, s,
+ [&] { return UniqueIdentifier(name.substr(2)); });
}
return name;
}
@@ -1442,7 +1442,7 @@
/// @returns the name of the given value, creating a new unique name if the value is unnamed in
/// the module.
std::string NameOf(const core::ir::Value* value) {
- return names_.GetOrCreate(value, [&] {
+ return names_.GetOrAdd(value, [&] {
if (auto sym = ir_.NameOf(value); sym.IsValid()) {
return sym.Name();
}
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_array.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_array.cc
index 47e242b..bba5bb1 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_array.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_array.cc
@@ -107,7 +107,7 @@
return nullptr;
}
if (!arr->IsStrideImplicit()) {
- auto el_ty = tint::GetOrCreate(decomposed, arr, [&] {
+ auto el_ty = tint::GetOrAdd(decomposed, arr, [&] {
auto name = b.Symbols().New("strided_arr");
auto* member_ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>());
auto* member = b.Member(kMemberName, ast::Type{member_ty},
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
index 9521329..e28ff41 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
@@ -150,8 +150,8 @@
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
if (auto* access = src.Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
- if (auto info = decomposed.Find(access->Member())) {
- auto fn = tint::GetOrCreate(mat_to_arr, *info, [&] {
+ if (auto info = decomposed.Get(access->Member())) {
+ auto fn = tint::GetOrAdd(mat_to_arr, *info, [&] {
auto name =
b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
std::to_string(info->matrix->rows()) + "_stride_" +
@@ -189,8 +189,8 @@
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
if (auto* access = src.Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
- if (auto info = decomposed.Find(access->Member())) {
- auto fn = tint::GetOrCreate(arr_to_mat, *info, [&] {
+ if (auto info = decomposed.Get(access->Member())) {
+ auto fn = tint::GetOrAdd(arr_to_mat, *info, [&] {
auto name =
b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
"x" + std::to_string(info->matrix->rows()) + "_stride_" +
diff --git a/src/tint/lang/spirv/reader/ast_lower/fold_trivial_lets.cc b/src/tint/lang/spirv/reader/ast_lower/fold_trivial_lets.cc
index 56fa298..01084c4 100644
--- a/src/tint/lang/spirv/reader/ast_lower/fold_trivial_lets.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/fold_trivial_lets.cc
@@ -73,8 +73,7 @@
auto fold_lets = [&](const ast::Expression* expr) {
ast::TraverseExpressions(expr, [&](const ast::IdentifierExpression* ident) {
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
- auto itr = pending_lets.Find(user->Variable());
- if (itr) {
+ if (auto itr = pending_lets.Get(user->Variable())) {
TINT_ASSERT(itr->remaining_uses > 0);
// We found a reference to a pending let, so replace it with the inlined
@@ -83,7 +82,7 @@
// Decrement the remaining uses count and remove the let declaration if this
// was the last remaining use.
- if (--itr->remaining_uses == 0) {
+ if (--(itr->remaining_uses) == 0) {
ctx.Remove(block->statements, itr->decl);
}
}
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
index cfaf300..6bfb652 100644
--- a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
@@ -133,7 +133,7 @@
/// @param type the type of the parameter
/// @returns the name of the parameter
Symbol GetParameter(const ast::Function* func, const ast::Type& type) {
- return func_to_param.GetOrCreate(func, [&] {
+ return func_to_param.GetOrAdd(func, [&] {
// Append a new parameter to the function.
auto name = b.Symbols().New("tint_wgid");
ctx.InsertBack(func->params, b.Param(name, ctx.Clone(type)));
diff --git a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
index e8b5aeb..3b835c3 100644
--- a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
@@ -1528,7 +1528,7 @@
// TODO(dneto): initializers (a.k.a. initializer expression)
if (ast_var) {
builder_.AST().AddGlobalVariable(ast_var);
- module_variable_.GetOrCreate(var.result_id(), [&] {
+ module_variable_.GetOrAdd(var.result_id(), [&] {
return ModuleVariable{ast_var, ast_address_space, ast_access};
});
}
@@ -1564,7 +1564,7 @@
storage_type, ast_initializer, {});
builder_.AST().AddGlobalVariable(ast_var);
- module_variable_.GetOrCreate(builtin_position_.per_vertex_var_id, [&] {
+ module_variable_.GetOrAdd(builtin_position_.per_vertex_var_id, [&] {
return ModuleVariable{ast_var, ast_address_space};
});
}
diff --git a/src/tint/lang/spirv/reader/ast_parser/ast_parser.h b/src/tint/lang/spirv/reader/ast_parser/ast_parser.h
index ddd8ded..86f60e7 100644
--- a/src/tint/lang/spirv/reader/ast_parser/ast_parser.h
+++ b/src/tint/lang/spirv/reader/ast_parser/ast_parser.h
@@ -695,8 +695,7 @@
/// @param id a SPIR-V ID
/// @returns the AST variable or null.
ModuleVariable GetModuleVariable(uint32_t id) {
- auto entry = module_variable_.Find(id);
- return entry ? *entry : ModuleVariable{};
+ return module_variable_.GetOr(id, ModuleVariable{});
}
/// Returns the channel component type corresponding to the given image
diff --git a/src/tint/lang/spirv/reader/ast_parser/function.cc b/src/tint/lang/spirv/reader/ast_parser/function.cc
index 2ba5673..73d5578 100644
--- a/src/tint/lang/spirv/reader/ast_parser/function.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/function.cc
@@ -3428,7 +3428,7 @@
auto copy_name = namer_.MakeDerivedName(namer_.Name(phi_id) + "_c" +
std::to_string(block_info.id));
auto copy_sym = builder_.Symbols().Register(copy_name);
- copied_phis.GetOrCreate(phi_id, [copy_sym] { return copy_sym; });
+ copied_phis.GetOrAdd(phi_id, [copy_sym] { return copy_sym; });
AddStatement(builder_.WrapInStatement(
builder_.Let(copy_sym, builder_.Expr(namer_.Name(phi_id)))));
}
@@ -3439,7 +3439,7 @@
const auto phi_id = assignment.phi_id;
auto* const lhs_expr = builder_.Expr(namer_.Name(phi_id));
// If RHS value is actually a phi we just cpatured, then use it.
- auto copy_sym = copied_phis.Find(assignment.value_id);
+ auto copy_sym = copied_phis.Get(assignment.value_id);
auto* const rhs_expr =
copy_sym ? builder_.Expr(*copy_sym) : MakeExpression(assignment.value_id).expr;
AddStatement(builder_.Assign(lhs_expr, rhs_expr));
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index cfbc8cc..1976c38 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -165,7 +165,7 @@
/// @returns a Tint type object
const core::type::Type* Type(const spvtools::opt::analysis::Type* type,
core::Access access_mode = core::Access::kUndefined) {
- return types_.GetOrCreate(TypeKey{type, access_mode}, [&]() -> const core::type::Type* {
+ return types_.GetOrAdd(TypeKey{type, access_mode}, [&]() -> const core::type::Type* {
switch (type->kind()) {
case spvtools::opt::analysis::Type::kVoid:
return ty_.void_();
@@ -324,7 +324,7 @@
/// @param id a SPIR-V result ID for a function declaration instruction
/// @returns a Tint function object
core::ir::Function* Function(uint32_t id) {
- return functions_.GetOrCreate(id, [&] {
+ return functions_.GetOrAdd(id, [&] {
return b_.Function(ty_.void_(), core::ir::Function::PipelineStage::kUndefined,
std::nullopt);
});
@@ -333,7 +333,7 @@
/// @param id a SPIR-V result ID
/// @returns a Tint value object
core::ir::Value* Value(uint32_t id) {
- return values_.GetOrCreate(id, [&]() -> core::ir::Value* {
+ return values_.GetOrAdd(id, [&]() -> core::ir::Value* {
if (auto* c = spirv_context_->get_constant_mgr()->FindDeclaredConstant(id)) {
return b_.Constant(Constant(c));
}
@@ -477,7 +477,7 @@
// Handle OpExecutionMode declarations.
for (auto& execution_mode : spirv_context_->module()->execution_modes()) {
- auto* func = functions_.Get(execution_mode.GetSingleWordInOperand(0)).value_or(nullptr);
+ auto* func = functions_.GetOr(execution_mode.GetSingleWordInOperand(0), nullptr);
auto mode = execution_mode.GetSingleWordInOperand(1);
TINT_ASSERT_OR_RETURN(func);
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.cc b/src/tint/lang/spirv/writer/ast_printer/builder.cc
index 143a7ab..bcd175e 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.cc
@@ -622,7 +622,7 @@
}
uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
- return tint::GetOrCreate(func_sig_to_id_, func->Signature(), [&]() -> uint32_t {
+ return tint::GetOrAdd(func_sig_to_id_, func->Signature(), [&]() -> uint32_t {
auto func_op = result_op();
auto func_type_id = std::get<uint32_t>(func_op);
@@ -1400,7 +1400,7 @@
? scope_stack_[0] // Global scope
: scope_stack_.back(); // Lexical scope
- return tint::GetOrCreate(stack.type_init_to_id_, OperandListKey{ops}, [&]() -> uint32_t {
+ return tint::GetOrAdd(stack.type_init_to_id_, OperandListKey{ops}, [&]() -> uint32_t {
auto result = result_op();
ops[kOpsResultIdx] = result;
@@ -1640,13 +1640,13 @@
}
auto& global_scope = scope_stack_[0];
- return tint::GetOrCreate(global_scope.type_init_to_id_, OperandListKey{ops},
- [&]() -> uint32_t {
- auto result = result_op();
- ops[kOpsResultIdx] = result;
- module_.PushType(spv::Op::OpConstantComposite, std::move(ops));
- return std::get<uint32_t>(result);
- });
+ return tint::GetOrAdd(global_scope.type_init_to_id_, OperandListKey{ops},
+ [&]() -> uint32_t {
+ auto result = result_op();
+ ops[kOpsResultIdx] = result;
+ module_.PushType(spv::Op::OpConstantComposite, std::move(ops));
+ return std::get<uint32_t>(result);
+ });
};
return Switch(
@@ -1765,7 +1765,7 @@
return 0;
}
- return tint::GetOrCreate(const_null_to_id_, type, [&] {
+ return tint::GetOrAdd(const_null_to_id_, type, [&] {
auto result = result_op();
module_.PushType(spv::Op::OpConstantNull, {Operand(type_id), result});
@@ -1782,7 +1782,7 @@
}
uint64_t key = (static_cast<uint64_t>(type->Width()) << 32) + value_id;
- return tint::GetOrCreate(const_splat_to_id_, key, [&] {
+ return tint::GetOrAdd(const_splat_to_id_, key, [&] {
auto result = result_op();
auto result_id = std::get<uint32_t>(result);
@@ -3286,7 +3286,7 @@
}
uint32_t sampled_image_type_id =
- tint::GetOrCreate(texture_type_to_sampled_image_type_id_, texture_type, [&] {
+ tint::GetOrAdd(texture_type_to_sampled_image_type_id_, texture_type, [&] {
// We need to create the sampled image type and cache the result.
auto sampled_image_type = result_op();
auto texture_type_id = GenerateTypeIfNeeded(texture_type);
@@ -3667,7 +3667,7 @@
core::Access::kReadWrite);
}
- return tint::GetOrCreate(type_to_id_, type, [&]() -> uint32_t {
+ return tint::GetOrAdd(type_to_id_, type, [&]() -> uint32_t {
auto result = result_op();
auto id = std::get<uint32_t>(result);
bool ok = Switch(
diff --git a/src/tint/lang/spirv/writer/ast_raise/vectorize_matrix_conversions.cc b/src/tint/lang/spirv/writer/ast_raise/vectorize_matrix_conversions.cc
index d443f3f..31505a0 100644
--- a/src/tint/lang/spirv/writer/ast_raise/vectorize_matrix_conversions.cc
+++ b/src/tint/lang/spirv/writer/ast_raise/vectorize_matrix_conversions.cc
@@ -137,7 +137,7 @@
});
} else {
// If has side effects, use a helper function.
- auto fn = tint::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] {
+ auto fn = tint::GetOrAdd(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] {
auto name = b.Symbols().New("convert_mat" + std::to_string(src_type->columns()) +
"x" + std::to_string(src_type->rows()) + "_" +
src_type->type()->FriendlyName() + "_" +
diff --git a/src/tint/lang/spirv/writer/common/option_helper.cc b/src/tint/lang/spirv/writer/common/option_helper.cc
index aec904a..c1a9d10 100644
--- a/src/tint/lang/spirv/writer/common/option_helper.cc
+++ b/src/tint/lang/spirv/writer/common/option_helper.cc
@@ -46,7 +46,7 @@
auto wgsl_seen = [&diagnostics, &seen_wgsl_bindings](const tint::BindingPoint& src,
const binding::BindingInfo& dst) -> bool {
- if (auto binding = seen_wgsl_bindings.Find(src)) {
+ if (auto binding = seen_wgsl_bindings.Get(src)) {
if (*binding != dst) {
std::stringstream str;
str << "found duplicate WGSL binding point: " << src;
@@ -61,7 +61,7 @@
auto spirv_seen = [&diagnostics, &seen_spirv_bindings](const binding::BindingInfo& src,
const tint::BindingPoint& dst) -> bool {
- if (auto binding = seen_spirv_bindings.Find(src)) {
+ if (auto binding = seen_spirv_bindings.Get(src)) {
if (*binding != dst) {
std::stringstream str;
str << "found duplicate SPIR-V binding point: [group: " << src.group
diff --git a/src/tint/lang/spirv/writer/helpers/ast_generate_bindings.cc b/src/tint/lang/spirv/writer/helpers/ast_generate_bindings.cc
index cd40f27..0659c48 100644
--- a/src/tint/lang/spirv/writer/helpers/ast_generate_bindings.cc
+++ b/src/tint/lang/spirv/writer/helpers/ast_generate_bindings.cc
@@ -56,7 +56,7 @@
for (auto* var : program.AST().GlobalVariables()) {
if (auto* sem_var = program.Sem().Get(var)->As<sem::GlobalVariable>()) {
if (auto bp = sem_var->Attributes().binding_point) {
- if (auto val = group_to_next_binding_number.Find(bp->group)) {
+ if (auto val = group_to_next_binding_number.Get(bp->group)) {
*val = std::max(*val, bp->binding + 1);
} else {
group_to_next_binding_number.Add(bp->group, bp->binding + 1);
@@ -106,14 +106,12 @@
for (auto bp : ext_tex_bps) {
uint32_t g = bp.group;
- uint32_t next_num = *(group_to_next_binding_number.GetOrZero(g));
+ uint32_t& next_num = group_to_next_binding_number.GetOrAddZero(g);
binding::BindingInfo plane0{bp.group, bp.binding};
binding::BindingInfo plane1{g, next_num++};
binding::BindingInfo metadata{g, next_num++};
- group_to_next_binding_number.Replace(g, next_num);
-
bindings.external_texture.emplace(bp, binding::ExternalTexture{metadata, plane0, plane1});
}
diff --git a/src/tint/lang/spirv/writer/helpers/generate_bindings.cc b/src/tint/lang/spirv/writer/helpers/generate_bindings.cc
index 1cf3b6a..ad38631 100644
--- a/src/tint/lang/spirv/writer/helpers/generate_bindings.cc
+++ b/src/tint/lang/spirv/writer/helpers/generate_bindings.cc
@@ -56,7 +56,7 @@
}
auto* var = inst->As<core::ir::Var>();
if (auto bp = var->BindingPoint()) {
- if (auto val = group_to_next_binding_number.Find(bp->group)) {
+ if (auto val = group_to_next_binding_number.Get(bp->group)) {
*val = std::max(*val, bp->binding + 1);
} else {
group_to_next_binding_number.Add(bp->group, bp->binding + 1);
@@ -104,14 +104,12 @@
for (auto bp : ext_tex_bps) {
uint32_t g = bp.group;
- uint32_t next_num = *(group_to_next_binding_number.GetOrZero(g));
+ uint32_t& next_num = group_to_next_binding_number.GetOrAddZero(g);
binding::BindingInfo plane0{bp.group, bp.binding};
binding::BindingInfo plane1{g, next_num++};
binding::BindingInfo metadata{g, next_num++};
- group_to_next_binding_number.Replace(g, next_num);
-
bindings.external_texture.emplace(bp, binding::ExternalTexture{metadata, plane0, plane1});
}
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index cc4e35c..9f243f4 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -386,7 +386,7 @@
/// @param constant the constant to get the ID for
/// @returns the result ID of the constant
uint32_t Constant(const core::constant::Value* constant) {
- return constants_.GetOrCreate(constant, [&] {
+ return constants_.GetOrAdd(constant, [&] {
auto* ty = constant->Type();
// Use OpConstantNull for zero-valued composite constants.
@@ -455,7 +455,7 @@
/// @param type the type to get the ID for
/// @returns the result ID of the OpConstantNull instruction
uint32_t ConstantNull(const core::type::Type* type) {
- return constant_nulls_.GetOrCreate(type, [&] {
+ return constant_nulls_.GetOrAdd(type, [&] {
auto id = module_.NextId();
module_.PushType(spv::Op::OpConstantNull, {Type(type), id});
return id;
@@ -466,7 +466,7 @@
/// @param type the type of the undef value
/// @returns the result ID of the instruction
uint32_t Undef(const core::type::Type* type) {
- return undef_values_.GetOrCreate(type, [&] {
+ return undef_values_.GetOrAdd(type, [&] {
auto id = module_.NextId();
module_.PushType(spv::Op::OpUndef, {Type(type), id});
return id;
@@ -478,7 +478,7 @@
/// @returns the result ID of the type
uint32_t Type(const core::type::Type* ty) {
ty = DedupType(ty, ir_.Types());
- return types_.GetOrCreate(ty, [&] {
+ return types_.GetOrAdd(ty, [&] {
auto id = module_.NextId();
Switch(
ty, //
@@ -548,7 +548,7 @@
value, //
[&](core::ir::Constant* constant) { return Constant(constant); },
[&](core::ir::Value*) {
- return values_.GetOrCreate(value, [&] { return module_.NextId(); });
+ return values_.GetOrAdd(value, [&] { return module_.NextId(); });
});
}
@@ -556,7 +556,7 @@
/// @param block the block to get the label ID for
/// @returns the ID of the block's label
uint32_t Label(const core::ir::Block* block) {
- return block_labels_.GetOrCreate(block, [&] { return module_.NextId(); });
+ return block_labels_.GetOrAdd(block, [&] { return module_.NextId(); });
}
/// Emit a struct type.
@@ -717,7 +717,7 @@
}
// Get the ID for the function type (creating it if needed).
- auto function_type_id = function_types_.GetOrCreate(function_type, [&] {
+ auto function_type_id = function_types_.GetOrAdd(function_type, [&] {
auto func_ty_id = module_.NextId();
OperandList operands = {func_ty_id, return_type_id};
operands.insert(operands.end(), function_type.param_type_ids.begin(),
@@ -790,7 +790,7 @@
// Determine if this IO variable is used by the entry point.
bool used = false;
for (const auto& use : var->Result(0)->Usages()) {
- auto* block = use.instruction->Block();
+ auto* block = use->instruction->Block();
while (block->Parent()) {
block = block->Parent()->Block();
}
@@ -1385,7 +1385,7 @@
auto glsl_ext_inst = [&](enum GLSLstd450 inst) {
constexpr const char* kGLSLstd450 = "GLSL.std.450";
op = spv::Op::OpExtInst;
- operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&] {
+ operands.push_back(imports_.GetOrAdd(kGLSLstd450, [&] {
// Import the instruction set the first time it is requested.
auto import = module_.NextId();
module_.PushExtImport(spv::Op::OpExtInstImport, {import, Operand(kGLSLstd450)});
@@ -2197,7 +2197,7 @@
/// @param ci the control instruction to get the merge label for
/// @returns the label ID
uint32_t GetMergeLabel(core::ir::ControlInstruction* ci) {
- return merge_block_labels_.GetOrCreate(ci, [&] { return module_.NextId(); });
+ return merge_block_labels_.GetOrAdd(ci, [&] { return module_.NextId(); });
}
/// Get the ID of the label of the block that will contain a terminator instruction.
diff --git a/src/tint/lang/spirv/writer/raise/merge_return.cc b/src/tint/lang/spirv/writer/raise/merge_return.cc
index 41bafb9..c954430 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return.cc
@@ -73,7 +73,7 @@
void Process(core::ir::Function* fn) {
// Find all of the nested return instructions in the function.
for (const auto& usage : fn->Usages()) {
- if (auto* ret = usage.instruction->As<core::ir::Return>()) {
+ if (auto* ret = usage->instruction->As<core::ir::Return>()) {
TransitivelyMarkAsReturning(ret->Block()->Parent());
}
}
diff --git a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc
index 32f0d1b..19037f7 100644
--- a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc
+++ b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index.cc
@@ -149,7 +149,7 @@
if (to_replace.first_dynamic_index > 0) {
PartialAccess partial_access = {
access->Object(), access->Indices().Truncate(to_replace.first_dynamic_index)};
- source_object = source_object_to_value.GetOrCreate(partial_access, [&] {
+ source_object = source_object_to_value.GetOrAdd(partial_access, [&] {
auto* intermediate_source = builder.Access(to_replace.dynamic_index_source_type,
source_object, partial_access.indices);
intermediate_source->InsertBefore(access);
@@ -158,7 +158,7 @@
}
// Declare a local variable and copy the source object to it.
- auto* local = object_to_local.GetOrCreate(source_object, [&] {
+ auto* local = object_to_local.GetOrAdd(source_object, [&] {
auto* decl = builder.Var(ir.Types().ptr(
core::AddressSpace::kFunction, source_object->Type(), core::Access::kReadWrite));
decl->SetInitializer(source_object);
diff --git a/src/tint/lang/wgsl/ast/blend_src_attribute.cc b/src/tint/lang/wgsl/ast/blend_src_attribute.cc
index 089ce1e..102c0ac 100644
--- a/src/tint/lang/wgsl/ast/blend_src_attribute.cc
+++ b/src/tint/lang/wgsl/ast/blend_src_attribute.cc
@@ -37,9 +37,9 @@
namespace tint::ast {
BlendSrcAttribute::BlendSrcAttribute(GenerationID pid,
- NodeID nid,
- const Source& src,
- const Expression* exp)
+ NodeID nid,
+ const Source& src,
+ const Expression* exp)
: Base(pid, nid, src), expr(exp) {}
BlendSrcAttribute::~BlendSrcAttribute() = default;
diff --git a/src/tint/lang/wgsl/ast/clone_context.cc b/src/tint/lang/wgsl/ast/clone_context.cc
index 1159c91..3d6c69f 100644
--- a/src/tint/lang/wgsl/ast/clone_context.cc
+++ b/src/tint/lang/wgsl/ast/clone_context.cc
@@ -39,7 +39,7 @@
CloneContext::~CloneContext() = default;
Symbol CloneContext::Clone(Symbol s) {
- return cloned_symbols_.GetOrCreate(s, [&]() -> Symbol {
+ return cloned_symbols_.GetOrAdd(s, [&]() -> Symbol {
if (symbol_transform_) {
return symbol_transform_(s);
}
@@ -67,7 +67,7 @@
}
// Was Replace() called for this node?
- if (auto fn = replacements_.Find(node)) {
+ if (auto fn = replacements_.Get(node)) {
return (*fn)();
}
diff --git a/src/tint/lang/wgsl/ast/clone_context.h b/src/tint/lang/wgsl/ast/clone_context.h
index 5115e17..b99f774 100644
--- a/src/tint/lang/wgsl/ast/clone_context.h
+++ b/src/tint/lang/wgsl/ast/clone_context.h
@@ -186,14 +186,14 @@
void Clone(tint::Vector<T*, N>& to, const tint::Vector<T*, N>& from) {
to.Reserve(from.Length());
- auto transforms = list_transforms_.Find(&from);
+ auto transforms = list_transforms_.Get(&from);
if (transforms) {
for (auto& builder : transforms->insert_front_) {
to.Push(CheckedCast<T>(builder()));
}
for (auto& el : from) {
- if (auto insert_before = transforms->insert_before_.Find(el)) {
+ if (auto insert_before = transforms->insert_before_.Get(el)) {
for (auto& builder : *insert_before) {
to.Push(CheckedCast<T>(builder()));
}
@@ -201,7 +201,7 @@
if (!transforms->remove_.Contains(el)) {
to.Push(Clone(el));
}
- if (auto insert_after = transforms->insert_after_.Find(el)) {
+ if (auto insert_after = transforms->insert_after_.Get(el)) {
for (auto& builder : *insert_after) {
to.Push(CheckedCast<T>(builder()));
}
@@ -214,10 +214,12 @@
for (auto& el : from) {
to.Push(Clone(el));
- // Clone(el) may have updated the transformation list, adding an `insert_after`
- // transform for `from`.
+ if (!transforms) {
+ // Clone(el) may have create a transformation list
+ transforms = list_transforms_.Get(&from);
+ }
if (transforms) {
- if (auto insert_after = transforms->insert_after_.Find(el)) {
+ if (auto insert_after = transforms->insert_after_.Get(el)) {
for (auto& builder : *insert_after) {
to.Push(CheckedCast<T>(builder()));
}
@@ -225,8 +227,10 @@
}
}
- // Clone(el) may have updated the transformation list, adding an `insert_back_`
- // transform for `from`.
+ if (!transforms) {
+ // Clone(el) may have create a transformation list
+ transforms = list_transforms_.Get(&from);
+ }
if (transforms) {
for (auto& builder : transforms->insert_back_) {
to.Push(CheckedCast<T>(builder()));
@@ -374,7 +378,7 @@
return *this;
}
- list_transforms_.GetOrZero(&vector)->remove_.Add(object);
+ list_transforms_.GetOrAddZero(&vector).remove_.Add(object);
return *this;
}
@@ -396,7 +400,7 @@
/// @returns this CloneContext so calls can be chained
template <typename T, size_t N, typename BUILDER>
CloneContext& InsertFront(const tint::Vector<T, N>& vector, BUILDER&& builder) {
- list_transforms_.GetOrZero(&vector)->insert_front_.Push(std::forward<BUILDER>(builder));
+ list_transforms_.GetOrAddZero(&vector).insert_front_.Push(std::forward<BUILDER>(builder));
return *this;
}
@@ -419,7 +423,7 @@
/// @returns this CloneContext so calls can be chained
template <typename T, size_t N, typename BUILDER>
CloneContext& InsertBack(const tint::Vector<T, N>& vector, BUILDER&& builder) {
- list_transforms_.GetOrZero(&vector)->insert_back_.Push(std::forward<BUILDER>(builder));
+ list_transforms_.GetOrAddZero(&vector).insert_back_.Push(std::forward<BUILDER>(builder));
return *this;
}
@@ -440,7 +444,7 @@
return *this;
}
- list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
+ list_transforms_.GetOrAddZero(&vector).insert_before_.GetOrAddZero(before).Push(
[object] { return object; });
return *this;
}
@@ -459,7 +463,7 @@
CloneContext& InsertBefore(const tint::Vector<T, N>& vector,
const BEFORE* before,
BUILDER&& builder) {
- list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
+ list_transforms_.GetOrAddZero(&vector).insert_before_.GetOrAddZero(before).Push(
std::forward<BUILDER>(builder));
return *this;
}
@@ -481,7 +485,7 @@
return *this;
}
- list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
+ list_transforms_.GetOrAddZero(&vector).insert_after_.GetOrAddZero(after).Push(
[object] { return object; });
return *this;
}
@@ -500,7 +504,7 @@
CloneContext& InsertAfter(const tint::Vector<T, N>& vector,
const AFTER* after,
BUILDER&& builder) {
- list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
+ list_transforms_.GetOrAddZero(&vector).insert_after_.GetOrAddZero(after).Push(
std::forward<BUILDER>(builder));
return *this;
}
diff --git a/src/tint/lang/wgsl/ast/transform/add_block_attribute.cc b/src/tint/lang/wgsl/ast/transform/add_block_attribute.cc
index 82e0716..631f1bd 100644
--- a/src/tint/lang/wgsl/ast/transform/add_block_attribute.cc
+++ b/src/tint/lang/wgsl/ast/transform/add_block_attribute.cc
@@ -83,7 +83,7 @@
if (needs_wrapping) {
const char* kMemberName = "inner";
- auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
+ auto* wrapper = wrapper_structs.GetOrAdd(ty, [&] {
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
auto wrapper_name = global->name->symbol.Name() + "_block";
auto* ret =
diff --git a/src/tint/lang/wgsl/ast/transform/builtin_polyfill.cc b/src/tint/lang/wgsl/ast/transform/builtin_polyfill.cc
index b3a8ece..b1f36fc 100644
--- a/src/tint/lang/wgsl/ast/transform/builtin_polyfill.cc
+++ b/src/tint/lang/wgsl/ast/transform/builtin_polyfill.cc
@@ -1117,7 +1117,7 @@
auto* lhs_ty = src.TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src.TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
- auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
+ auto fn = binary_op_polyfills.GetOrAdd(sig, [&] {
const bool is_div = bin_op->op == core::BinaryOp::kDivide;
const auto [lhs_el_ty, lhs_width] = lhs_ty->Elements(lhs_ty, 1);
@@ -1210,7 +1210,7 @@
auto* lhs_ty = src.TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src.TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
- auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
+ auto fn = binary_op_polyfills.GetOrAdd(sig, [&] {
const auto [lhs_el_ty, lhs_width] = lhs_ty->Elements(lhs_ty, 1);
const auto [rhs_el_ty, rhs_width] = rhs_ty->Elements(rhs_ty, 1);
@@ -1295,21 +1295,21 @@
switch (builtin->Fn()) {
case wgsl::BuiltinFn::kAcosh:
if (cfg.builtins.acosh != Level::kNone) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return acosh(builtin->ReturnType()); });
}
return Symbol{};
case wgsl::BuiltinFn::kAsinh:
if (cfg.builtins.asinh) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return asinh(builtin->ReturnType()); });
}
return Symbol{};
case wgsl::BuiltinFn::kAtanh:
if (cfg.builtins.atanh != Level::kNone) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return atanh(builtin->ReturnType()); });
}
return Symbol{};
@@ -1318,7 +1318,7 @@
if (cfg.builtins.clamp_int) {
auto& sig = builtin->Signature();
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return clampInteger(builtin->ReturnType()); });
}
}
@@ -1326,42 +1326,42 @@
case wgsl::BuiltinFn::kCountLeadingZeros:
if (cfg.builtins.count_leading_zeros) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return countLeadingZeros(builtin->ReturnType()); });
}
return Symbol{};
case wgsl::BuiltinFn::kCountTrailingZeros:
if (cfg.builtins.count_trailing_zeros) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return countTrailingZeros(builtin->ReturnType()); });
}
return Symbol{};
case wgsl::BuiltinFn::kExtractBits:
if (cfg.builtins.extract_bits != Level::kNone) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return extractBits(builtin->ReturnType()); });
}
return Symbol{};
case wgsl::BuiltinFn::kFirstLeadingBit:
if (cfg.builtins.first_leading_bit) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return firstLeadingBit(builtin->ReturnType()); });
}
return Symbol{};
case wgsl::BuiltinFn::kFirstTrailingBit:
if (cfg.builtins.first_trailing_bit) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return firstTrailingBit(builtin->ReturnType()); });
}
return Symbol{};
case wgsl::BuiltinFn::kInsertBits:
if (cfg.builtins.insert_bits != Level::kNone) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return insertBits(builtin->ReturnType()); });
}
return Symbol{};
@@ -1373,7 +1373,7 @@
auto& sig = builtin->Signature();
auto* vec = sig.return_type->As<core::type::Vector>();
if (vec && vec->Width() == 2 && vec->type()->Is<core::type::F32>()) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return reflect(builtin->ReturnType()); });
}
}
@@ -1381,7 +1381,7 @@
case wgsl::BuiltinFn::kSaturate:
if (cfg.builtins.saturate) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return saturate(builtin->ReturnType()); });
}
return Symbol{};
@@ -1390,8 +1390,8 @@
if (cfg.builtins.sign_int) {
auto* ty = builtin->ReturnType();
if (ty->is_signed_integer_scalar_or_vector()) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return sign_int(ty); });
+ return builtin_polyfills.GetOrAdd(builtin,
+ [&] { return sign_int(ty); });
}
}
return Symbol{};
@@ -1418,7 +1418,7 @@
auto* tex = sig.Parameter(core::ParameterUsage::kTexture);
if (auto* stex = tex->Type()->As<core::type::SampledTexture>()) {
if (stex->type()->Is<core::type::F32>()) {
- return builtin_polyfills.GetOrCreate(builtin, [&] {
+ return builtin_polyfills.GetOrAdd(builtin, [&] {
return textureSampleBaseClampToEdge_2d_f32();
});
}
@@ -1456,7 +1456,7 @@
case wgsl::BuiltinFn::kQuantizeToF16:
if (cfg.builtins.quantize_to_vec_f16) {
if (auto* vec = builtin->ReturnType()->As<core::type::Vector>()) {
- return builtin_polyfills.GetOrCreate(
+ return builtin_polyfills.GetOrAdd(
builtin, [&] { return quantizeToF16(vec); });
}
}
@@ -1464,7 +1464,7 @@
case wgsl::BuiltinFn::kWorkgroupUniformLoad:
if (cfg.builtins.workgroup_uniform_load) {
- return builtin_polyfills.GetOrCreate(builtin, [&] {
+ return builtin_polyfills.GetOrAdd(builtin, [&] {
return workgroupUniformLoad(builtin->ReturnType());
});
}
@@ -1472,64 +1472,62 @@
case wgsl::BuiltinFn::kDot4I8Packed: {
if (cfg.builtins.dot_4x8_packed) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Dot4I8Packed(); });
+ return builtin_polyfills.GetOrAdd(builtin,
+ [&] { return Dot4I8Packed(); });
}
return Symbol{};
}
case wgsl::BuiltinFn::kDot4U8Packed: {
if (cfg.builtins.dot_4x8_packed) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Dot4U8Packed(); });
+ return builtin_polyfills.GetOrAdd(builtin,
+ [&] { return Dot4U8Packed(); });
}
return Symbol{};
}
case wgsl::BuiltinFn::kPack4XI8: {
if (cfg.builtins.pack_unpack_4x8) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Pack4xI8(); });
+ return builtin_polyfills.GetOrAdd(builtin, [&] { return Pack4xI8(); });
}
return Symbol{};
}
case wgsl::BuiltinFn::kPack4XU8: {
if (cfg.builtins.pack_unpack_4x8) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Pack4xU8(); });
+ return builtin_polyfills.GetOrAdd(builtin, [&] { return Pack4xU8(); });
}
return Symbol{};
}
case wgsl::BuiltinFn::kPack4XI8Clamp: {
if (cfg.builtins.pack_unpack_4x8) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Pack4xI8Clamp(); });
+ return builtin_polyfills.GetOrAdd(builtin,
+ [&] { return Pack4xI8Clamp(); });
}
return Symbol{};
}
case wgsl::BuiltinFn::kPack4XU8Clamp: {
if (cfg.builtins.pack_4xu8_clamp) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Pack4xU8Clamp(); });
+ return builtin_polyfills.GetOrAdd(builtin,
+ [&] { return Pack4xU8Clamp(); });
}
return Symbol{};
}
case wgsl::BuiltinFn::kUnpack4XI8: {
if (cfg.builtins.pack_unpack_4x8) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Unpack4xI8(); });
+ return builtin_polyfills.GetOrAdd(builtin,
+ [&] { return Unpack4xI8(); });
}
return Symbol{};
}
case wgsl::BuiltinFn::kUnpack4XU8: {
if (cfg.builtins.pack_unpack_4x8) {
- return builtin_polyfills.GetOrCreate(builtin,
- [&] { return Unpack4xU8(); });
+ return builtin_polyfills.GetOrAdd(builtin,
+ [&] { return Unpack4xU8(); });
}
return Symbol{};
}
@@ -1545,7 +1543,7 @@
auto* dst_ty = conv->Target();
if (tint::IsAnyOf<core::type::I32, core::type::U32>(
dst_ty->Elements(dst_ty).type)) {
- return f32_conv_polyfills.GetOrCreate(dst_ty, [&] { //
+ return f32_conv_polyfills.GetOrAdd(dst_ty, [&] { //
return ConvF32ToIU32(src_ty, dst_ty);
});
}
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
index b056a6a..dcea861 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
@@ -265,7 +265,7 @@
// Get or create the intrinsic function.
auto builtin = BuiltinOf(attrs);
- auto intrinsic = wave_intrinsics.GetOrCreate(builtin, [&] {
+ auto intrinsic = wave_intrinsics.GetOrAdd(builtin, [&] {
if (builtin == core::BuiltinValue::kSubgroupInvocationId) {
return make_intrinsic("__WaveGetLaneIndex",
HLSLWaveIntrinsic::Op::kWaveGetLaneIndex);
diff --git a/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.cc b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.cc
index c87fa15..59e6dd6 100644
--- a/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.cc
+++ b/src/tint/lang/wgsl/ast/transform/clamp_frag_depth.cc
@@ -139,7 +139,7 @@
// return S(s.first, s.second, clamp_frag_depth(s.frag_depth), s.last);
// }
auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()->Declaration();
- auto helper = io_structs_clamp_helpers.GetOrCreate(struct_ty, [&] {
+ auto helper = io_structs_clamp_helpers.GetOrAdd(struct_ty, [&] {
auto return_ty = fn->return_type;
auto fn_sym =
b.Symbols().New("clamp_frag_depth_" + struct_ty->name->symbol.Name());
diff --git a/src/tint/lang/wgsl/ast/transform/demote_to_helper.cc b/src/tint/lang/wgsl/ast/transform/demote_to_helper.cc
index 93a539c..28e776b 100644
--- a/src/tint/lang/wgsl/ast/transform/demote_to_helper.cc
+++ b/src/tint/lang/wgsl/ast/transform/demote_to_helper.cc
@@ -202,7 +202,7 @@
auto* result_struct = sem_call->Type()->As<core::type::Struct>();
auto* atomic_ty = result_struct->Members()[0]->Type();
result_ty =
- b.ty(tint::GetOrCreate(atomic_cmpxchg_result_types, atomic_ty, [&] {
+ b.ty(tint::GetOrAdd(atomic_cmpxchg_result_types, atomic_ty, [&] {
auto name = b.Sym();
b.Structure(
name,
diff --git a/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc b/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
index 4ecea23..7cdee65 100644
--- a/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
+++ b/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
@@ -358,8 +358,8 @@
tint::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> SortedVariants() {
tint::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> out;
out.Reserve(variants.Count());
- for (auto it : variants) {
- out.Push({&it.key, &it.value});
+ for (auto& it : variants) {
+ out.Push({&it.key.Value(), &it.value});
}
out.Sort([&](auto& va, auto& vb) { return va.second->order < vb.second->order; });
return out;
@@ -539,7 +539,7 @@
// The dynamic index needs to be hoisted (if it hasn't been already).
auto fn = FnInfoFor(idx->Stmt()->Function());
- fn->hoisted_exprs.GetOrCreate(idx, [=] {
+ fn->hoisted_exprs.GetOrAdd(idx, [=] {
// Create a name for the new 'let'
auto name = b.Symbols().New("ptr_index_save");
// Insert a new 'let' just above the dynamic index statement.
@@ -678,7 +678,7 @@
// Construct a new FnVariant if this is the first caller of the target signature
auto* target_info = FnInfoFor(target);
- auto& target_variant = target_info->variants.GetOrCreate(target_signature, [&] {
+ auto& target_variant = target_info->variants.GetOrAdd(target_signature, [&] {
if (target_signature.IsEmpty()) {
// Call target does not require any argument changes.
FnVariant variant;
@@ -693,15 +693,15 @@
StringStream ss;
ss << target->Declaration()->name->symbol.Name();
for (auto* param : target->Parameters()) {
- if (auto indices = target_signature.Find(param)) {
+ if (auto indices = target_signature.Get(param)) {
ss << "_" << AccessShapeName(*indices);
}
}
// Build the pointer parameter symbols.
Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;
- for (auto param_it : target_signature) {
- auto* param = param_it.key;
+ for (auto& param_it : target_signature) {
+ auto* param = param_it.key.Value();
auto& shape = param_it.value;
// Parameter needs replacing with either zero, one or two parameters:
@@ -770,7 +770,7 @@
/// @returns the AccessChain for the expression @p expr, or nullptr if the expression does
/// not hold an access chain.
AccessChain* AccessChainFor(const sem::ValueExpression* expr) const {
- if (auto chain = access_chains.Find(expr)) {
+ if (auto chain = access_chains.Get(expr)) {
return *chain;
}
return nullptr;
@@ -781,7 +781,7 @@
AccessShape AbsoluteAccessShape(const FnVariant::Signature& signature,
const AccessShape& shape) const {
if (auto* root_param = shape.root.variable->As<sem::Parameter>()) {
- if (auto incoming_chain = signature.Find(root_param)) {
+ if (auto incoming_chain = signature.Get(root_param)) {
// Access chain originates from a parameter, which will be transformed into an array
// of dynamic indices. Concatenate the signature's AccessShape for the parameter
// to the chain's indices, skipping over the chain's initial parameter index.
@@ -831,8 +831,8 @@
// dynamic indices).
tint::Vector<const Parameter*, 8> params;
for (auto* param : fn->Parameters()) {
- if (auto incoming_shape = variant_sig.Find(param)) {
- auto& symbols = *variant.ptr_param_symbols.Find(param);
+ if (auto incoming_shape = variant_sig.Get(param)) {
+ auto& symbols = *variant.ptr_param_symbols.Get(param);
if (symbols.base_ptr.IsValid()) {
auto base_ptr_ty =
b.ty.ptr(incoming_shape->root.address_space,
@@ -872,7 +872,7 @@
void TransformCall(const sem::Call* call) {
// Register a custom handler for the specific call expression
ctx.Replace(call->Declaration(), [this, call] {
- auto target_variant = clone_state->current_variant->calls.Find(call);
+ auto target_variant = clone_state->current_variant->calls.Get(call);
if (!target_variant) {
// The current variant does not need to transform this call.
return ctx.CloneWithoutTransform(call->Declaration());
@@ -921,9 +921,9 @@
if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
// Access chain originates from a pointer parameter.
if (auto incoming_chain =
- clone_state->current_variant_sig->Find(root_param)) {
+ clone_state->current_variant_sig->Get(root_param)) {
auto indices =
- clone_state->current_variant->ptr_param_symbols.Find(root_param)
+ clone_state->current_variant->ptr_param_symbols.Get(root_param)
->indices;
// This pointer parameter will have been replaced with a array<u32, N>
@@ -1001,7 +1001,7 @@
// If the expression has been hoisted to a 'let', then replace the expression with an
// identifier to the hoisted let.
- if (auto hoisted = clone_state->current_function->hoisted_exprs.Find(expr)) {
+ if (auto hoisted = clone_state->current_function->hoisted_exprs.Get(expr)) {
return b.Expr(*hoisted);
}
@@ -1018,7 +1018,7 @@
return nullptr; // Just clone the expression.
}
- auto incoming_shape = clone_state->current_variant_sig->Find(root_param);
+ auto incoming_shape = clone_state->current_variant_sig->Get(root_param);
if (!incoming_shape) {
// The root parameter of the access chain is not part of the variant's signature.
return nullptr; // Just clone the expression.
@@ -1033,8 +1033,8 @@
// Replace this with the variant's incoming shape. This will bring the expression up to
// the incoming pointer.
size_t next_dyn_idx_from_indices = 0;
- auto indices =
- clone_state->current_variant->ptr_param_symbols.Find(root_param)->indices;
+ auto& indices =
+ clone_state->current_variant->ptr_param_symbols.Get(root_param)->indices;
for (auto param_access : incoming_shape->ops) {
chain_expr = BuildAccessExpr(chain_expr, param_access, [&] {
return b.IndexAccessor(indices, AInt(next_dyn_idx_from_indices++));
@@ -1065,13 +1065,13 @@
/// @returns the FnInfo for the given function, constructing a new FnInfo if @p fn doesn't
/// already have one.
FnInfo* FnInfoFor(const sem::Function* fn) {
- return fns.GetOrCreate(fn, [this] { return fn_info_allocator.Create(); });
+ return fns.GetOrAdd(fn, [this] { return fn_info_allocator.Create(); });
}
/// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias
/// if this is the first call for the given shape.
Type DynamicIndexArrayType(const AccessShape& shape) {
- auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] {
+ auto name = dynamic_index_array_aliases.GetOrAdd(shape, [&] {
// Count the number of dynamic indices
uint32_t num_dyn_indices = shape.NumDynamicIndices();
if (num_dyn_indices == 0) {
@@ -1120,7 +1120,7 @@
/// @param deref if true, the returned expression will always be a reference type.
const Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
if (auto* param = root.variable->As<sem::Parameter>()) {
- if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
+ if (auto symbols = clone_state->current_variant->ptr_param_symbols.Get(param)) {
if (deref) {
return b.Deref(b.Expr(symbols->base_ptr));
}
@@ -1165,7 +1165,7 @@
/// underscore and number, if the symbol is already taken.
Symbol UniqueSymbolWithSuffix(Symbol symbol, const std::string& suffix) {
auto str = symbol.Name() + suffix;
- return unique_symbols.GetOrCreate(str, [&] { return b.Symbols().New(str); });
+ return unique_symbols.GetOrAdd(str, [&] { return b.Symbols().New(str); });
}
/// @returns true if the function @p fn has at least one pointer parameter.
diff --git a/src/tint/lang/wgsl/ast/transform/hoist_to_decl_before.cc b/src/tint/lang/wgsl/ast/transform/hoist_to_decl_before.cc
index 3f5fdbb..f1fa3d8 100644
--- a/src/tint/lang/wgsl/ast/transform/hoist_to_decl_before.cc
+++ b/src/tint/lang/wgsl/ast/transform/hoist_to_decl_before.cc
@@ -171,11 +171,11 @@
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #for_loops map is mutated.
- auto ForLoop(const sem::ForLoopStatement* for_loop) {
+ LoopInfo& ForLoop(const sem::ForLoopStatement* for_loop) {
if (for_loops.IsEmpty()) {
RegisterForLoopTransform();
}
- return for_loops.GetOrZero(for_loop);
+ return for_loops.GetOrAddZero(for_loop);
}
/// @returns a new LoopInfo reference for the given @p while_loop.
@@ -183,11 +183,11 @@
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #for_loops map is mutated.
- auto WhileLoop(const sem::WhileStatement* while_loop) {
+ LoopInfo& WhileLoop(const sem::WhileStatement* while_loop) {
if (while_loops.IsEmpty()) {
RegisterWhileLoopTransform();
}
- return while_loops.GetOrZero(while_loop);
+ return while_loops.GetOrAddZero(while_loop);
}
/// @returns a new ElseIfInfo reference for the given @p else_if.
@@ -195,11 +195,11 @@
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #else_ifs map is mutated.
- auto ElseIf(const IfStatement* else_if) {
+ ElseIfInfo& ElseIf(const IfStatement* else_if) {
if (else_ifs.IsEmpty()) {
RegisterElseIfTransform();
}
- return else_ifs.GetOrZero(else_if);
+ return else_ifs.GetOrAddZero(else_if);
}
/// Registers the handler for transforming for-loops based on the content of the #for_loops map.
@@ -208,7 +208,7 @@
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
- if (auto info = for_loops.Find(fl)) {
+ if (auto info = for_loops.Get(fl)) {
auto* for_loop = fl->Declaration();
// For-loop needs to be decomposed to a loop.
// Build the loop body's statements.
@@ -267,7 +267,7 @@
auto& sem = ctx.src->Sem();
if (auto* w = sem.Get(stmt)) {
- if (auto info = while_loops.Find(w)) {
+ if (auto info = while_loops.Get(w)) {
auto* while_loop = w->Declaration();
// While needs to be decomposed to a loop.
// Build the loop body's statements.
@@ -304,7 +304,7 @@
void RegisterElseIfTransform() const {
// Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const IfStatement* stmt) -> const Statement* {
- if (auto info = else_ifs.Find(stmt)) {
+ if (auto info = else_ifs.Get(stmt)) {
// Build the else block's body statements, starting with let decls for the
// conditional expression.
auto body_stmts = Build(info->cond_decls);
@@ -336,10 +336,10 @@
if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
// Insertion point is an 'else if' condition.
// Need to convert 'else if' to 'else { if }'.
- auto else_if_info = ElseIf(else_if->Declaration());
+ auto& else_if_info = ElseIf(else_if->Declaration());
// Index the map to decompose this else if, even if `stmt` is nullptr.
- auto& decls = else_if_info->cond_decls;
+ auto& decls = else_if_info.cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@@ -351,7 +351,7 @@
// For-loop needs to be decomposed to a loop.
// Index the map to decompose this for-loop, even if `stmt` is nullptr.
- auto& decls = ForLoop(fl)->cond_decls;
+ auto& decls = ForLoop(fl).cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@@ -363,7 +363,7 @@
// While needs to be decomposed to a loop.
// Index the map to decompose this while, even if `stmt` is nullptr.
- auto& decls = WhileLoop(w)->cond_decls;
+ auto& decls = WhileLoop(w).cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@@ -390,7 +390,7 @@
// For-loop needs to be decomposed to a loop.
// Index the map to decompose this for-loop, even if `stmt` is nullptr.
- auto& decls = ForLoop(fl)->init_decls;
+ auto& decls = ForLoop(fl).init_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@@ -403,7 +403,7 @@
// For-loop needs to be decomposed to a loop.
// Index the map to decompose this for-loop, even if `stmt` is nullptr.
- auto& decls = ForLoop(fl)->cont_decls;
+ auto& decls = ForLoop(fl).cont_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
diff --git a/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.cc b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.cc
index bdcc60b..d32b3b7 100644
--- a/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.cc
+++ b/src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.cc
@@ -497,7 +497,7 @@
createGammaCorrectionFn();
}
- auto texture_load_external_sym = texture_load_external_fns.GetOrCreate(call->Target(), [&] {
+ auto texture_load_external_sym = texture_load_external_fns.GetOrAdd(call->Target(), [&] {
auto& sig = call->Target()->Signature();
auto* coord_ty = sig.Parameter(core::ParameterUsage::kCoords)->Type();
diff --git a/src/tint/lang/wgsl/ast/transform/preserve_padding.cc b/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
index d6935a9..5e28fed 100644
--- a/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
+++ b/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
@@ -123,7 +123,7 @@
const char* kDestParamName = "dest";
const char* kValueParamName = "value";
auto call_helper = [&](auto&& body) {
- auto helper = helpers.GetOrCreate(ty, [&] {
+ auto helper = helpers.GetOrAdd(ty, [&] {
auto helper_name = b.Symbols().New("assign_and_preserve_padding");
tint::Vector<const Parameter*, 2> params = {
b.Param(kDestParamName,
diff --git a/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.cc b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.cc
index 6400a5e..5094e8d 100644
--- a/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.cc
+++ b/src/tint/lang/wgsl/ast/transform/promote_initializers_to_let.cc
@@ -135,8 +135,8 @@
// After walking the full AST, const_chains only contains the outer-most constant expressions.
// Check if any of these need hoisting, and append those to to_hoist.
- for (auto* expr : const_chains) {
- if (auto* sem = src.Sem().GetVal(expr); should_hoist(sem)) {
+ for (auto& expr : const_chains) {
+ if (auto* sem = src.Sem().GetVal(expr.Value()); should_hoist(sem)) {
to_hoist.Push(sem);
}
}
diff --git a/src/tint/lang/wgsl/ast/transform/remove_phonies.cc b/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
index a6b1123..ff33316c 100644
--- a/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
+++ b/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
@@ -130,7 +130,7 @@
for (auto* arg : side_effects) {
sig.push_back(sem.GetVal(arg)->Type()->UnwrapRef());
}
- auto sink = sinks.GetOrCreate(sig, [&] {
+ auto sink = sinks.GetOrAdd(sig, [&] {
auto name = b.Symbols().New("phony_sink");
tint::Vector<const Parameter*, 8> params;
for (auto* ty : sig) {
diff --git a/src/tint/lang/wgsl/ast/transform/renamer.cc b/src/tint/lang/wgsl/ast/transform/renamer.cc
index 7233498..487697d 100644
--- a/src/tint/lang/wgsl/ast/transform/renamer.cc
+++ b/src/tint/lang/wgsl/ast/transform/renamer.cc
@@ -1398,7 +1398,7 @@
}
// Create a replacement for this symbol, if we haven't already.
- auto replacement = remappings.GetOrCreate(symbol, [&] {
+ auto replacement = remappings.GetOrAdd(symbol, [&] {
if (requested_names) {
auto iter = requested_names->find(symbol.Name());
if (iter != requested_names->end()) {
@@ -1422,8 +1422,8 @@
ctx.Clone();
Remappings out;
- for (auto it : remappings) {
- out[it.key.Name()] = it.value.Name();
+ for (auto& it : remappings) {
+ out[it.key->Name()] = it.value.Name();
}
outputs.Add<Data>(std::move(out));
diff --git a/src/tint/lang/wgsl/ast/transform/robustness.cc b/src/tint/lang/wgsl/ast/transform/robustness.cc
index 8169d28..676a544 100644
--- a/src/tint/lang/wgsl/ast/transform/robustness.cc
+++ b/src/tint/lang/wgsl/ast/transform/robustness.cc
@@ -335,7 +335,7 @@
}
auto* stmt = expr->Stmt();
- auto obj_pred = *predicates.GetOrZero(obj);
+ auto& obj_pred = predicates.GetOrAddZero(obj);
auto idx_let = b.Symbols().New("index");
auto pred = b.Symbols().New("predicate");
diff --git a/src/tint/lang/wgsl/ast/transform/simplify_pointers.cc b/src/tint/lang/wgsl/ast/transform/simplify_pointers.cc
index 3095252..ad4d713 100644
--- a/src/tint/lang/wgsl/ast/transform/simplify_pointers.cc
+++ b/src/tint/lang/wgsl/ast/transform/simplify_pointers.cc
@@ -258,7 +258,7 @@
// variable identifier.
ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
// Look to see if we need to swap this Expression with a saved variable.
- if (auto saved_var = saved_vars.Find(expr)) {
+ if (auto saved_var = saved_vars.Get(expr)) {
return ctx.dst->Expr(*saved_var);
}
diff --git a/src/tint/lang/wgsl/ast/transform/std140.cc b/src/tint/lang/wgsl/ast/transform/std140.cc
index b28e5a7..037d716 100644
--- a/src/tint/lang/wgsl/ast/transform/std140.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140.cc
@@ -403,14 +403,14 @@
return Switch(
ty, //
[&](const core::type::Struct* str) {
- if (auto std140 = std140_structs.Find(str)) {
+ if (auto std140 = std140_structs.Get(str)) {
return b.ty(*std140);
}
return Type{};
},
[&](const core::type::Matrix* mat) {
if (MatrixNeedsDecomposing(mat)) {
- auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
+ auto std140_mat = std140_mats.GetOrAdd(mat, [&] {
auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
std::to_string(mat->rows()) + "_" +
mat->type()->FriendlyName());
@@ -671,7 +671,7 @@
/// @returns the converted value expression.
const Expression* Convert(const core::type::Type* ty, const Expression* expr) {
// Get an existing, or create a new function for converting the std140 type to ty.
- auto fn = conv_fns.GetOrCreate(ty, [&] {
+ auto fn = conv_fns.GetOrAdd(ty, [&] {
auto std140_ty = Std140Type(ty);
if (!std140_ty) {
// ty was not forked for std140.
@@ -690,7 +690,7 @@
// call, or by reassembling a std140 matrix from column vector members.
tint::Vector<const Expression*, 8> args;
for (auto* member : str->Members()) {
- if (auto col_members = std140_mat_members.Find(member)) {
+ if (auto col_members = std140_mat_members.Get(member)) {
// std140 decomposed matrix. Reassemble.
auto mat_ty = CreateASTTypeFor(ctx, member->Type());
auto mat_args =
@@ -772,7 +772,7 @@
const Expression* LoadMatrixWithFn(const AccessChain& access) {
// Get an existing, or create a new function for loading the uniform buffer value.
// This function is keyed off the uniform buffer variable and the access chain.
- auto fn = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] {
+ auto fn = load_fns.GetOrAdd(LoadFnKey{access.var, access.indices}, [&] {
if (access.IsMatrixSubset()) {
// Access chain passes through the matrix, but ends either at a column vector,
// column swizzle, or element.
diff --git a/src/tint/lang/wgsl/ast/transform/unshadow.cc b/src/tint/lang/wgsl/ast/transform/unshadow.cc
index eab34b0..8c61526 100644
--- a/src/tint/lang/wgsl/ast/transform/unshadow.cc
+++ b/src/tint/lang/wgsl/ast/transform/unshadow.cc
@@ -117,7 +117,7 @@
ctx.ReplaceAll([&](const IdentifierExpression* ident) -> const IdentifierExpression* {
if (auto* sem_ident = sem.GetVal(ident)) {
if (auto* user = sem_ident->Unwrap()->As<sem::VariableUser>()) {
- if (auto renamed = renamed_to.Find(user->Variable())) {
+ if (auto renamed = renamed_to.Get(user->Variable())) {
return b.Expr(*renamed);
}
}
diff --git a/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.cc b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.cc
index 45647df..e75deaa 100644
--- a/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.cc
+++ b/src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.cc
@@ -125,7 +125,7 @@
// Generate a helper function for constructing the matrix.
// This is done to ensure that the single argument value is only evaluated once, and
// with the correct expression evaluation order.
- auto fn = tint::GetOrCreate(scalar_inits, mat_type, [&] {
+ auto fn = tint::GetOrAdd(scalar_inits, mat_type, [&] {
auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) +
"x" + std::to_string(mat_type->rows()));
b.Func(name,
diff --git a/src/tint/lang/wgsl/ast/transform/vertex_pulling.cc b/src/tint/lang/wgsl/ast/transform/vertex_pulling.cc
index b3e4aa6..81ec237 100644
--- a/src/tint/lang/wgsl/ast/transform/vertex_pulling.cc
+++ b/src/tint/lang/wgsl/ast/transform/vertex_pulling.cc
@@ -318,7 +318,7 @@
/// Generate the vertex buffer binding name
/// @param index index to append to buffer name
Symbol GetVertexBufferName(uint32_t index) {
- return tint::GetOrCreate(vertex_buffer_names, index, [&] {
+ return tint::GetOrAdd(vertex_buffer_names, index, [&] {
static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_";
return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
});
diff --git a/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
index 3dba65f..235693e 100644
--- a/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
@@ -369,8 +369,8 @@
}
auto array_indices = a.array_indices;
array_indices.Add(ArrayIndex{modulo, division});
- auto index = tint::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
- [&] { return b.Symbols().New("i"); });
+ auto index = tint::GetOrAdd(array_index_names, ArrayIndex{modulo, division},
+ [&] { return b.Symbols().New("i"); });
return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
};
return BuildZeroingStatements(arr->ElemType(), get_el);
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index 6c5d935..e6c2b44 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -920,7 +920,7 @@
// turn specified as an upper bound for Vulkan layout sizing. Since D3D
// and Metal are even less specific, we assume Vulkan behavior as a
// good-enough approximation everywhere.
- total_size += tint::RoundUp(align, size);
+ total_size += tint::RoundUp(16u, tint::RoundUp(align, size));
}
}
@@ -1034,14 +1034,7 @@
auto record_function_param = [&fn_to_data](const sem::Function* func,
const ast::Parameter* param, TextureQueryType type) {
- auto& param_to_type = *fn_to_data.GetOrZero(func);
-
- auto entry = param_to_type.Get(param);
- if (entry.has_value()) {
- return;
- }
-
- param_to_type.Add(param, type);
+ fn_to_data.GetOrAddZero(func).Add(param, type);
};
auto save_if_needed = [&res, &seen](const sem::GlobalVariable* global, TextureQueryType type) {
@@ -1110,7 +1103,7 @@
// A function call, check to see if any params needed to be tracked back to a
// global texture.
- auto param_to_type = fn_to_data.Find(func);
+ auto param_to_type = fn_to_data.Get(func);
if (!param_to_type) {
return;
}
@@ -1121,7 +1114,7 @@
// Determine if this had a texture we cared about
auto type = param_to_type->Get(param);
- if (!type.has_value()) {
+ if (!type) {
continue;
}
@@ -1131,10 +1124,10 @@
tint::Switch(
texture_sem,
[&](const sem::GlobalVariable* global) {
- save_if_needed(global, type.value());
+ save_if_needed(global, *type);
},
[&](const sem::Parameter* p) {
- record_function_param(fn, p->Declaration(), type.value());
+ record_function_param(fn, p->Declaration(), *type);
},
TINT_ICE_ON_NO_MATCH);
}
diff --git a/src/tint/lang/wgsl/inspector/inspector_test.cc b/src/tint/lang/wgsl/inspector/inspector_test.cc
index 76d514e..e3d97bb 100644
--- a/src/tint/lang/wgsl/inspector/inspector_test.cc
+++ b/src/tint/lang/wgsl/inspector/inspector_test.cc
@@ -305,7 +305,7 @@
ASSERT_FALSE(inspector.has_error()) << inspector.error();
ASSERT_EQ(1u, result.size());
- EXPECT_EQ(4u, result[0].workgroup_storage_size);
+ EXPECT_EQ(16u, result[0].workgroup_storage_size);
}
TEST_F(InspectorGetEntryPointTest, WorkgroupStorageSizeCompoundTypes) {
@@ -338,7 +338,7 @@
ASSERT_FALSE(inspector.has_error()) << inspector.error();
ASSERT_EQ(1u, result.size());
- EXPECT_EQ(72u, result[0].workgroup_storage_size);
+ EXPECT_EQ(96u, result[0].workgroup_storage_size);
}
TEST_F(InspectorGetEntryPointTest, WorkgroupStorageSizeAlignmentPadding) {
diff --git a/src/tint/lang/wgsl/intrinsic/data.cc b/src/tint/lang/wgsl/intrinsic/data.cc
index 73b6fa4..8266687 100644
--- a/src/tint/lang/wgsl/intrinsic/data.cc
+++ b/src/tint/lang/wgsl/intrinsic/data.cc
@@ -4698,41 +4698,46 @@
{
/* [29] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(65),
+ /* matcher_index */ TypeMatcherIndex(67),
},
{
/* [30] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(57),
+ /* matcher_index */ TypeMatcherIndex(65),
},
{
/* [31] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(58),
+ /* matcher_index */ TypeMatcherIndex(57),
},
{
/* [32] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(55),
+ /* matcher_index */ TypeMatcherIndex(58),
},
{
/* [33] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(56),
+ /* matcher_index */ TypeMatcherIndex(55),
},
{
/* [34] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(59),
+ /* matcher_index */ TypeMatcherIndex(56),
},
{
/* [35] */
/* name */ "T",
- /* matcher_index */ TypeMatcherIndex(54),
+ /* matcher_index */ TypeMatcherIndex(59),
},
{
/* [36] */
/* name */ "T",
+ /* matcher_index */ TypeMatcherIndex(54),
+ },
+ {
+ /* [37] */
+ /* name */ "T",
/* matcher_index */ TypeMatcherIndex(71),
},
};
@@ -5554,7 +5559,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(54),
@@ -6451,7 +6456,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(10),
@@ -6737,7 +6742,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(74),
@@ -7413,7 +7418,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(104),
@@ -7426,7 +7431,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(386),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(104),
@@ -7491,7 +7496,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(110),
@@ -7504,7 +7509,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(389),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(110),
@@ -7569,7 +7574,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(116),
@@ -7582,7 +7587,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(392),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(116),
@@ -7647,7 +7652,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(122),
@@ -7660,7 +7665,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(395),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(122),
@@ -7725,7 +7730,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(128),
@@ -7738,7 +7743,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(398),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(128),
@@ -7803,7 +7808,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(134),
@@ -7816,7 +7821,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(401),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(134),
@@ -7881,7 +7886,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(140),
@@ -7894,7 +7899,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(404),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(140),
@@ -7959,7 +7964,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(146),
@@ -7972,7 +7977,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(407),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(146),
@@ -8037,7 +8042,7 @@
/* num_parameters */ 0,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(/* invalid */),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(152),
@@ -8050,7 +8055,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(36),
+ /* template_types */ TemplateTypeIndex(37),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(410),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(152),
@@ -8440,7 +8445,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -8453,7 +8458,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -8492,7 +8497,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -8505,7 +8510,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -8622,7 +8627,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(30),
+ /* template_types */ TemplateTypeIndex(31),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(31),
@@ -8661,7 +8666,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(31),
+ /* template_types */ TemplateTypeIndex(32),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(33),
@@ -8700,7 +8705,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(32),
+ /* template_types */ TemplateTypeIndex(33),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(15),
@@ -8739,7 +8744,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(33),
+ /* template_types */ TemplateTypeIndex(34),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(87),
@@ -8778,7 +8783,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(34),
+ /* template_types */ TemplateTypeIndex(35),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(9),
@@ -10247,7 +10252,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10260,7 +10265,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10299,7 +10304,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(1),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10312,7 +10317,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(149),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10481,7 +10486,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(16),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10494,7 +10499,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(351),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10507,7 +10512,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(16),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -10520,7 +10525,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 1,
- /* template_types */ TemplateTypeIndex(29),
+ /* template_types */ TemplateTypeIndex(30),
/* template_numbers */ TemplateNumberIndex(4),
/* parameters */ ParameterIndex(351),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(6),
@@ -10975,7 +10980,7 @@
/* num_parameters */ 2,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(0),
+ /* template_types */ TemplateTypeIndex(29),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(348),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
@@ -11027,7 +11032,7 @@
/* num_parameters */ 1,
/* num_template_types */ 1,
/* num_template_numbers */ 0,
- /* template_types */ TemplateTypeIndex(35),
+ /* template_types */ TemplateTypeIndex(36),
/* template_numbers */ TemplateNumberIndex(/* invalid */),
/* parameters */ ParameterIndex(213),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(158),
@@ -11972,7 +11977,7 @@
},
{
/* [121] */
- /* fn subgroupBroadcast<T : fiu32>(value: T, @const sourceLaneIndex: u32) -> T */
+ /* fn subgroupBroadcast<T : fiu32_f16>(value: T, @const sourceLaneIndex: u32) -> T */
/* num overloads */ 1,
/* overloads */ OverloadIndex(466),
},
diff --git a/src/tint/lang/wgsl/resolver/assignment_validation_test.cc b/src/tint/lang/wgsl/resolver/assignment_validation_test.cc
index aeb4d61..826fbd0 100644
--- a/src/tint/lang/wgsl/resolver/assignment_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/assignment_validation_test.cc
@@ -222,7 +222,7 @@
Assign(Expr(Source{{12, 34}}, 1_i), "my_var"));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: cannot assign to value expression of type 'i32')");
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot assign to value of type 'i32')");
}
TEST_F(ResolverAssignmentValidationTest, AssignToOverride_Fail) {
@@ -292,7 +292,7 @@
Assign(MemberAccessor(Source{{12, 34}}, Expr(Source{{56, 78}}, "a"), "i"), 2_i));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: cannot assign to value expression of type 'i32'
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot assign to value of type 'i32'
56:78 note: 'let' variables are immutable
98:76 note: let 'a' declared here)");
}
diff --git a/src/tint/lang/wgsl/resolver/call_validation_test.cc b/src/tint/lang/wgsl/resolver/call_validation_test.cc
index 9a0d7c1..d7bdc4e 100644
--- a/src/tint/lang/wgsl/resolver/call_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/call_validation_test.cc
@@ -143,7 +143,7 @@
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of let 'z'");
}
TEST_F(ResolverCallValidationTest,
@@ -214,7 +214,7 @@
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+ EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of value of type 'i32'");
}
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
diff --git a/src/tint/lang/wgsl/resolver/compound_assignment_validation_test.cc b/src/tint/lang/wgsl/resolver/compound_assignment_validation_test.cc
index ada0fb4..578cd22 100644
--- a/src/tint/lang/wgsl/resolver/compound_assignment_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/compound_assignment_validation_test.cc
@@ -276,7 +276,7 @@
WrapInFunction(CompoundAssign(Expr(Source{{56, 78}}, 1_i), 1_i, core::BinaryOp::kAdd));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(56:78 error: cannot assign to value expression of type 'i32')");
+ EXPECT_EQ(r()->error(), R"(56:78 error: cannot assign to value of type 'i32')");
}
TEST_F(ResolverCompoundAssignmentValidationTest, LhsAtomic) {
diff --git a/src/tint/lang/wgsl/resolver/dependency_graph.cc b/src/tint/lang/wgsl/resolver/dependency_graph.cc
index 890d024..f8b5425 100644
--- a/src/tint/lang/wgsl/resolver/dependency_graph.cc
+++ b/src/tint/lang/wgsl/resolver/dependency_graph.cc
@@ -166,7 +166,7 @@
graph_(graph),
dependency_edges_(edges) {
// Register all the globals at global-scope
- for (auto it : globals_by_name) {
+ for (auto& it : globals_by_name) {
scope_stack_.Set(it.key, it.value->node);
}
}
@@ -242,7 +242,7 @@
TINT_DEFER(scope_stack_.Pop());
for (auto* param : func->params) {
- if (auto* shadows = scope_stack_.Get(param->name->symbol)) {
+ if (auto shadows = scope_stack_.Get(param->name->symbol)) {
graph_.shadows.Add(param, shadows);
}
Declare(param->name->symbol, param);
@@ -465,7 +465,7 @@
/// @param symbol the symbol
/// @returns the builtin info
DependencyScanner::BuiltinInfo GetBuiltinInfo(Symbol symbol) {
- return builtin_info_map.GetOrCreate(symbol, [&] {
+ return builtin_info_map.GetOrAdd(symbol, [&] {
if (auto builtin_fn = wgsl::ParseBuiltinFn(symbol.NameView());
builtin_fn != wgsl::BuiltinFn::kNone) {
return BuiltinInfo{BuiltinType::kFunction, builtin_fn};
@@ -549,7 +549,7 @@
return;
}
- if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
+ if (auto global = globals_.Get(to); global && (*global)->node == resolved) {
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
DependencyInfo{from->source})) {
current_global_->deps.Push(*global);
@@ -765,7 +765,7 @@
/// of global `from` depending on `to`.
/// @note will raise an ICE if the edge is not found.
DependencyInfo DepInfoFor(const Global* from, const Global* to) const {
- auto info = dependency_edges_.Find(DependencyEdge{from, to});
+ auto info = dependency_edges_.Get(DependencyEdge{from, to});
if (TINT_LIKELY(info)) {
return *info;
}
@@ -819,7 +819,7 @@
printf("------ dependencies ------ \n");
for (auto* node : sorted_) {
auto symbol = SymbolOf(node);
- auto* global = *globals_.Find(symbol);
+ auto* global = *globals_.Get(symbol);
printf("%s depends on:\n", symbol.Name().c_str());
for (auto* dep : global->deps) {
printf(" %s\n", NameOf(dep->node).c_str());
diff --git a/src/tint/lang/wgsl/resolver/dependency_graph_test.cc b/src/tint/lang/wgsl/resolver/dependency_graph_test.cc
index 360ffda..5e79a3e 100644
--- a/src/tint/lang/wgsl/resolver/dependency_graph_test.cc
+++ b/src/tint/lang/wgsl/resolver/dependency_graph_test.cc
@@ -757,8 +757,8 @@
auto graph = Build();
- auto resolved_identifier = graph.resolved_identifiers.Find(ident);
- ASSERT_NE(resolved_identifier, nullptr);
+ auto resolved_identifier = graph.resolved_identifiers.Get(ident);
+ ASSERT_TRUE(resolved_identifier);
auto* unresolved = resolved_identifier->Unresolved();
ASSERT_NE(unresolved, nullptr);
EXPECT_EQ(unresolved->name, "SYMBOL");
@@ -805,7 +805,7 @@
WrapInFunction(Decl(Var(symbol, ty.i32(), Mul(Expr(ident), 123_i))));
auto graph = Build();
- auto resolved_identifier = graph.resolved_identifiers.Find(ident);
+ auto resolved_identifier = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved_identifier);
auto* unresolved = resolved_identifier->Unresolved();
ASSERT_NE(unresolved, nullptr);
@@ -818,7 +818,7 @@
WrapInFunction(Decl(Let(symbol, ty.i32(), Mul(Expr(ident), 123_i))));
auto graph = Build();
- auto resolved_identifier = graph.resolved_identifiers.Find(ident);
+ auto resolved_identifier = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved_identifier);
auto* unresolved = resolved_identifier->Unresolved();
ASSERT_NE(unresolved, nullptr);
@@ -1155,7 +1155,7 @@
bool expect_resolved = ScopeDepth(decl_kind) <= ScopeDepth(use_kind);
auto graph = Build();
- auto resolved_identifier = graph.resolved_identifiers.Find(use);
+ auto resolved_identifier = graph.resolved_identifiers.Get(use);
ASSERT_TRUE(resolved_identifier);
if (expect_resolved) {
@@ -1205,7 +1205,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->BuiltinFn(), builtin) << resolved->String();
}
@@ -1244,7 +1245,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->BuiltinType(), core::ParseBuiltinType(name)) << resolved->String();
}
@@ -1283,7 +1285,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Access(), core::ParseAccess(name)) << resolved->String();
}
@@ -1322,7 +1325,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->AddressSpace(), core::ParseAddressSpace(name)) << resolved->String();
}
@@ -1361,7 +1365,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->BuiltinValue(), core::ParseBuiltinValue(name)) << resolved->String();
}
@@ -1400,7 +1405,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->InterpolationSampling(), core::ParseInterpolationSampling(name))
<< resolved->String();
@@ -1440,7 +1446,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->InterpolationType(), core::ParseInterpolationType(name))
<< resolved->String();
@@ -1480,7 +1487,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->TexelFormat(), core::ParseTexelFormat(name)) << resolved->String();
}
@@ -1527,7 +1535,7 @@
helper.Build();
auto shadows = Build().shadows;
- auto shadow = shadows.Find(inner_var);
+ auto shadow = shadows.Get(inner_var);
ASSERT_TRUE(shadow);
EXPECT_EQ(*shadow, outer);
}
@@ -1561,7 +1569,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << resolved->String();
}
@@ -1578,7 +1587,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << resolved->String();
}
@@ -1593,7 +1603,8 @@
auto* ident = helper.Add(use, symbol);
helper.Build();
- auto resolved = Build().resolved_identifiers.Get(ident);
+ auto graph = Build();
+ auto resolved = graph.resolved_identifiers.Get(ident);
ASSERT_TRUE(resolved);
EXPECT_EQ(resolved->Node(), decl) << resolved->String();
}
@@ -1760,7 +1771,7 @@
auto graph = Build();
for (auto use : symbol_uses) {
- auto resolved_identifier = graph.resolved_identifiers.Find(use.use);
+ auto resolved_identifier = graph.resolved_identifiers.Get(use.use);
ASSERT_TRUE(resolved_identifier) << use.where;
EXPECT_EQ(*resolved_identifier, use.decl) << use.where;
}
diff --git a/src/tint/lang/wgsl/resolver/ptr_ref_validation_test.cc b/src/tint/lang/wgsl/resolver/ptr_ref_validation_test.cc
index cc61eb3..488b5c6 100644
--- a/src/tint/lang/wgsl/resolver/ptr_ref_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/ptr_ref_validation_test.cc
@@ -49,7 +49,7 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot take the address of value of type 'i32')");
}
TEST_F(ResolverPtrRefValidationTest, AddressOfLet) {
@@ -62,7 +62,45 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot take the address of let 'l')");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfConst) {
+ // const c : i32 = 1;
+ // &c
+ auto* l = Const("c", ty.i32(), Expr(1_i));
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "c"));
+
+ WrapInFunction(l, expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot take the address of const 'c')");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfOverride) {
+ // override c : i32;
+ // &o
+ Override("o", ty.i32());
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "o"));
+
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot take the address of override 'o')");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfParameter) {
+ // fn F(p : i32) { _ = &p }
+ // &F
+ Func("F", Vector{Param("p", ty.i32())}, ty.void_(),
+ Vector{
+ Assign(Phony(), AddressOf(Expr(Source{{12, 34}}, "p"))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot take the address of parameter 'p')");
}
TEST_F(ResolverPtrRefValidationTest, AddressOfHandle) {
@@ -75,8 +113,111 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: cannot take the address of expression in handle "
- "address space");
+ R"(12:34 error: cannot take the address of var 't' in handle address space)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfFunction) {
+ // fn F() {}
+ // &F
+ Func("F", Empty, ty.void_(), Empty);
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "F"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cannot use function 'F' as value
+note: function 'F' declared here
+12:34 note: are you missing '()'?)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfBuiltinFunction) {
+ // &max
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "max"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cannot use builtin function 'max' as value
+12:34 note: are you missing '()'?)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfType) {
+ // &i32
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "i32"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cannot use type 'i32' as value
+12:34 note: are you missing '()'?)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfTypeAlias) {
+ // alias T = i32
+ // &T
+ Alias("T", ty.i32());
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "T"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cannot use type 'i32' as value
+12:34 note: are you missing '()'?)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfAccess) {
+ // &read_write
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "read_write"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot use access 'read_write' as value)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfAddressSpace) {
+ // &handle
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "uniform"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot use address space 'uniform' as value)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfBuiltinValue) {
+ // &position
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "position"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot use builtin value 'position' as value)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfInterpolationSampling) {
+ // ¢roid
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "centroid"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: cannot use interpolation sampling 'centroid' as value)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfInterpolationType) {
+ // &perspective
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "perspective"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot use interpolation type 'perspective' as value)");
+}
+
+TEST_F(ResolverPtrRefValidationTest, AddressOfTexelFormat) {
+ // &rgba8snorm
+ auto* expr = AddressOf(Expr(Source{{12, 34}}, "rgba8snorm"));
+ WrapInFunction(expr);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot use texel format 'rgba8snorm' as value)");
}
TEST_F(ResolverPtrRefValidationTest, AddressOfVectorComponent_MemberAccessor) {
@@ -89,7 +230,7 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of a vector component");
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot take the address of a vector component)");
}
TEST_F(ResolverPtrRefValidationTest, AddressOfVectorComponent_IndexAccessor) {
@@ -102,7 +243,7 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of a vector component");
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot take the address of a vector component)");
}
TEST_F(ResolverPtrRefValidationTest, IndirectOfAddressOfHandle) {
@@ -115,8 +256,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: cannot take the address of expression in handle "
- "address space");
+ R"(12:34 error: cannot take the address of var 't' in handle address space)");
}
TEST_F(ResolverPtrRefValidationTest, DerefOfLiteral) {
@@ -128,7 +268,7 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot dereference expression of type 'i32'");
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot dereference expression of type 'i32')");
}
TEST_F(ResolverPtrRefValidationTest, DerefOfVar) {
@@ -141,7 +281,7 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot dereference expression of type 'i32'");
+ EXPECT_EQ(r()->error(), R"(12:34 error: cannot dereference expression of type 'i32')");
}
TEST_F(ResolverPtrRefValidationTest, InferredPtrAccessMismatch) {
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index e8b9bcc..3e717d4 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -911,7 +911,7 @@
}
void Resolver::SetShadows() {
- for (auto it : dependencies_.shadows) {
+ for (auto& it : dependencies_.shadows) {
CastableBase* shadowed = sem_.Get(it.value);
if (TINT_UNLIKELY(!shadowed)) {
StringStream err;
@@ -921,7 +921,7 @@
}
Switch(
- sem_.Get(it.key), //
+ sem_.Get(it.key.Value()), //
[&](sem::LocalVariable* local) { local->SetShadows(shadowed); },
[&](sem::Parameter* param) { param->SetShadows(shadowed); });
}
@@ -1019,7 +1019,7 @@
if (auto added = parameter_names.Add(param->name->symbol, param->source); !added) {
auto name = param->name->symbol.Name();
AddError("redefinition of parameter '" + name + "'", param->source);
- AddNote("previous definition is here", *added.value);
+ AddNote("previous definition is here", added.value);
return nullptr;
}
}
@@ -1560,7 +1560,7 @@
// If we just processed the lhs of a constexpr logical binary expression, mark the rhs for
// short-circuiting.
if (val && val->ConstantValue()) {
- if (auto binary = logical_binary_lhs_to_parent_.Find(expr)) {
+ if (auto binary = logical_binary_lhs_to_parent_.Get(expr)) {
const bool lhs_is_true = val->ConstantValue()->ValueAs<bool>();
if (((*binary)->IsLogicalAnd() && !lhs_is_true) ||
((*binary)->IsLogicalOr() && lhs_is_true)) {
@@ -2131,7 +2131,7 @@
// Is this overload a constructor or conversion?
if (match->info->flags.Contains(OverloadFlag::kIsConstructor)) {
// Type constructor
- target_sem = constructors_.GetOrCreate(match.Get(), [&] {
+ target_sem = constructors_.GetOrAdd(match.Get(), [&] {
auto params = Transform(match->parameters, [&](auto& p, size_t i) {
return b.create<sem::Parameter>(nullptr, static_cast<uint32_t>(i), p.type,
p.usage);
@@ -2141,7 +2141,7 @@
});
} else {
// Type conversion
- target_sem = converters_.GetOrCreate(match.Get(), [&] {
+ target_sem = converters_.GetOrAdd(match.Get(), [&] {
auto* param = b.create<sem::Parameter>(nullptr, 0u, match->parameters[0].type,
match->parameters[0].usage);
return b.create<sem::ValueConversion>(match->return_type, param, overload_stage);
@@ -2233,7 +2233,7 @@
m->type());
},
[&](const sem::Array* arr) -> sem::Call* {
- auto* call_target = array_ctors_.GetOrCreate(
+ auto* call_target = array_ctors_.GetOrAdd(
ArrayConstructorSig{{arr, args.Length(), args_stage}},
[&]() -> sem::ValueConstructor* {
auto params = tint::Transform(args, [&](auto, size_t i) {
@@ -2255,7 +2255,7 @@
return arr_or_str_init(arr, call_target);
},
[&](const core::type::Struct* str) -> sem::Call* {
- auto* call_target = struct_ctors_.GetOrCreate(
+ auto* call_target = struct_ctors_.GetOrAdd(
StructConstructorSig{{str, args.Length(), args_stage}},
[&]() -> sem::ValueConstructor* {
Vector<sem::Parameter*, 8> params;
@@ -2398,7 +2398,7 @@
}
// De-duplicate builtins that are identical.
- auto* target = builtins_.GetOrCreate(std::make_pair(overload.Get(), fn), [&] {
+ auto* target = builtins_.GetOrAdd(std::make_pair(overload.Get(), fn), [&] {
auto params = Transform(overload->parameters, [&](auto& p, size_t i) {
return b.create<sem::Parameter>(nullptr, static_cast<uint32_t>(i), p.type, p.usage);
});
@@ -3312,7 +3312,7 @@
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
auto symbol = ident->symbol;
- if (auto decl = loop_block->Decls().Find(symbol)) {
+ if (auto decl = loop_block->Decls().Get(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
symbol.Name() + "'",
@@ -3694,7 +3694,8 @@
case core::UnaryOp::kAddressOf:
if (auto* ref = expr_ty->As<core::type::Reference>()) {
if (ref->StoreType()->UnwrapRef()->is_handle()) {
- AddError("cannot take the address of expression in handle address space",
+ AddError("cannot take the address of " + sem_.Describe(expr) +
+ " in handle address space",
unary->expr->source);
return nullptr;
}
@@ -3713,7 +3714,7 @@
root_ident = expr->RootIdentifier();
} else {
- AddError("cannot take the address of expression", unary->expr->source);
+ AddError("cannot take the address of " + sem_.Describe(expr), unary->expr->source);
return nullptr;
}
break;
@@ -4314,7 +4315,7 @@
Mark(member->name);
if (auto added = member_map.Add(member->name->symbol, member); !added) {
AddError("redefinition of '" + member->name->symbol.Name() + "'", member->source);
- AddNote("previous definition is here", (*added.value)->source);
+ AddNote("previous definition is here", added.value->source);
return nullptr;
}
diff --git a/src/tint/lang/wgsl/resolver/resolver_test.cc b/src/tint/lang/wgsl/resolver/resolver_test.cc
index 8eb0457..fe4027e 100644
--- a/src/tint/lang/wgsl/resolver/resolver_test.cc
+++ b/src/tint/lang/wgsl/resolver/resolver_test.cc
@@ -138,7 +138,7 @@
WrapInFunction(cond_var, Switch("i", Case(CaseSelector(AddressOf(1_a)), Block())));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "error: cannot take the address of expression");
+ EXPECT_EQ(r()->error(), "error: cannot take the address of value of type 'abstract-int'");
}
TEST_F(ResolverTest, Stmt_Block) {
@@ -2411,7 +2411,7 @@
});
ASSERT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "error: cannot take the address of expression in handle address space");
+ EXPECT_EQ(r()->error(), "error: cannot take the address of var 's' in handle address space");
}
TEST_F(ResolverTest, ModuleDependencyOrderedDeclarations) {
diff --git a/src/tint/lang/wgsl/resolver/sem_helper.cc b/src/tint/lang/wgsl/resolver/sem_helper.cc
index a68734a..d6e8b89 100644
--- a/src/tint/lang/wgsl/resolver/sem_helper.cc
+++ b/src/tint/lang/wgsl/resolver/sem_helper.cc
@@ -94,7 +94,7 @@
},
[&](const sem::ValueExpression* val_expr) {
auto type = val_expr->Type()->FriendlyName();
- return "value expression of type '" + type + "'";
+ return "value of type '" + type + "'";
},
[&](const sem::TypeExpression* ty_expr) {
auto name = ty_expr->Type()->FriendlyName();
diff --git a/src/tint/lang/wgsl/resolver/uniformity.cc b/src/tint/lang/wgsl/resolver/uniformity.cc
index 709da02..cf03bc7 100644
--- a/src/tint/lang/wgsl/resolver/uniformity.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity.cc
@@ -294,8 +294,8 @@
/// @returns a LoopSwitchInfo for the given statement, allocating the LoopSwitchInfo if this is
/// the first call with the given statement.
LoopSwitchInfo& LoopSwitchInfoFor(const sem::Statement* stmt) {
- return *loop_switch_infos.GetOrCreate(stmt,
- [&] { return loop_switch_info_allocator.Create(); });
+ return *loop_switch_infos.GetOrAdd(stmt,
+ [&] { return loop_switch_info_allocator.Create(); });
}
/// Disassociates the LoopSwitchInfo for the given statement.
@@ -431,7 +431,7 @@
/// @param func the function to process
/// @returns true if there are no uniformity issues, false otherwise
bool ProcessFunction(const ast::Function* func) {
- current_function_ = functions_.Add(func, FunctionInfo(func, b)).value;
+ current_function_ = &functions_.Add(func, FunctionInfo(func, b)).value;
// Process function body.
if (func->body) {
@@ -627,7 +627,7 @@
// 'Next'.
auto& behaviors = sem->Behaviors();
if (behaviors.Contains(sem::Behavior::kNext)) {
- for (auto var : scoped_assignments) {
+ for (auto& var : scoped_assignments) {
current_function_->variables.Set(var.key, var.value);
}
}
@@ -649,7 +649,7 @@
auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate variable values to the loop/switch exit nodes.
- for (auto* var : current_function_->local_var_decls) {
+ for (auto& var : current_function_->local_var_decls) {
// Skip variables that were declared inside this loop/switch.
if (auto* lv = var->As<sem::LocalVariable>();
lv &&
@@ -658,7 +658,7 @@
}
// Add an edge from the variable exit node to its value at this point.
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
+ auto* exit_node = info.var_exit_nodes.GetOrAdd(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -686,7 +686,7 @@
auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate variable values to the loop exit nodes.
- for (auto* var : current_function_->local_var_decls) {
+ for (auto& var : current_function_->local_var_decls) {
// Skip variables that were declared inside this loop.
if (auto* lv = var->As<sem::LocalVariable>();
lv && lv->Statement()->FindFirstParent(
@@ -695,7 +695,7 @@
}
// Add an edge from the variable exit node to its value at this point.
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
+ auto* exit_node = info.var_exit_nodes.GetOrAdd(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -780,7 +780,7 @@
info.type = "forloop";
// Create input nodes for any variables declared before this loop.
- for (auto* v : current_function_->local_var_decls) {
+ for (auto& v : current_function_->local_var_decls) {
auto* in_node = CreateNode({NameFor(v), "_value_forloop_in"});
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
@@ -796,8 +796,8 @@
cf_start = cf_condition_end;
// Propagate assignments to the loop exit nodes.
- for (auto* var : current_function_->local_var_decls) {
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
+ for (auto& var : current_function_->local_var_decls) {
+ auto* exit_node = info.var_exit_nodes.GetOrAdd(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -816,7 +816,7 @@
cfx->AddEdge(cf);
// Add edges from variable loop input nodes to their values at the end of the loop.
- for (auto v : info.var_in_nodes) {
+ for (auto& v : info.var_in_nodes) {
auto* in_node = v.value;
auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) {
@@ -825,7 +825,7 @@
}
// Set each variable's exit node as its value in the outer scope.
- for (auto v : info.var_exit_nodes) {
+ for (auto& v : info.var_exit_nodes) {
current_function_->variables.Set(v.key, v.value);
}
@@ -855,7 +855,7 @@
info.type = "whileloop";
// Create input nodes for any variables declared before this loop.
- for (auto* v : current_function_->local_var_decls) {
+ for (auto& v : current_function_->local_var_decls) {
auto* in_node = CreateNode({NameFor(v), "_value_forloop_in"});
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
@@ -872,8 +872,8 @@
}
// Propagate assignments to the loop exit nodes.
- for (auto* var : current_function_->local_var_decls) {
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
+ for (auto& var : current_function_->local_var_decls) {
+ auto* exit_node = info.var_exit_nodes.GetOrAdd(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -951,7 +951,7 @@
}
// Update values for any variables assigned in the if or else blocks.
- for (auto* var : current_function_->local_var_decls) {
+ for (auto& var : current_function_->local_var_decls) {
// Skip variables not assigned in either block.
if (!true_vars.Contains(var) && !false_vars.Contains(var)) {
continue;
@@ -964,14 +964,14 @@
// Only add edges if the behavior for that block contains 'Next'.
if (true_has_next) {
if (true_vars.Contains(var)) {
- out_node->AddEdge(*true_vars.Find(var));
+ out_node->AddEdge(*true_vars.Get(var));
} else {
out_node->AddEdge(current_function_->variables.Get(var));
}
}
if (false_has_next) {
if (false_vars.Contains(var)) {
- out_node->AddEdge(*false_vars.Find(var));
+ out_node->AddEdge(*false_vars.Get(var));
} else {
out_node->AddEdge(current_function_->variables.Get(var));
}
@@ -1022,7 +1022,7 @@
info.type = "loop";
// Create input nodes for any variables declared before this loop.
- for (auto* v : current_function_->local_var_decls) {
+ for (auto& v : current_function_->local_var_decls) {
auto name = NameFor(v);
auto* in_node = CreateNode({name, "_value_loop_in"}, v->Declaration());
in_node->AddEdge(current_function_->variables.Get(v));
@@ -1112,7 +1112,7 @@
if (sem_case->Behaviors().Contains(sem::Behavior::kNext)) {
// Propagate variable values to the switch exit nodes.
- for (auto* var : current_function_->local_var_decls) {
+ for (auto& var : current_function_->local_var_decls) {
// Skip variables that were declared inside the switch.
if (auto* lv = var->As<sem::LocalVariable>();
lv && lv->Statement()->FindFirstParent(
@@ -1121,7 +1121,7 @@
}
// Add an edge from the variable exit node to its new value.
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
+ auto* exit_node = info.var_exit_nodes.GetOrAdd(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -1637,11 +1637,11 @@
[&](const sem::Function* func) {
// We must have already analyzed the user-defined function since we process
// functions in dependency order.
- auto info = functions_.Find(func->Declaration());
- TINT_ASSERT(info != nullptr);
+ auto info = functions_.Get(func->Declaration());
+ TINT_ASSERT(info);
callsite_tag = info->callsite_tag;
function_tag = info->function_tag;
- func_info = info;
+ func_info = info.value;
},
[&](const sem::ValueConstructor*) {
callsite_tag = {CallSiteTag::CallSiteNoRestriction};
@@ -1798,7 +1798,7 @@
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
- auto target_info = functions_.Find(user->Declaration());
+ auto target_info = functions_.Get(user->Declaration());
for (auto* call_node : target_info->RequiredToBeUniform(severity)->edges) {
if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
@@ -1983,7 +1983,7 @@
auto* user_func = target->As<sem::Function>();
if (user_func) {
// Recurse into the called function to show the reason for the requirement.
- auto next_function = functions_.Find(user_func->Declaration());
+ auto next_function = functions_.Get(user_func->Declaration());
auto& param_info = next_function->parameters[cause->arg_index];
MakeError(*next_function,
is_value ? param_info.value : param_info.ptr_input_contents, severity);
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 0201bb4..6e0413c 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -802,7 +802,7 @@
}
if (auto id = v->Attributes().override_id) {
- if (auto var = override_ids.Find(*id); var && *var != v) {
+ if (auto var = override_ids.Get(*id); var && *var != v) {
auto* attr = ast::GetAttribute<ast::IdAttribute>(v->Declaration()->attributes);
AddError("@id values must be unique", attr->source);
AddNote("a override with an ID of " + std::to_string(id->value) +
@@ -1485,7 +1485,7 @@
!added &&
IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kBindingPointCollision) &&
- IsValidationEnabled((*added.value)->attributes,
+ IsValidationEnabled(added.value->attributes,
ast::DisabledValidation::kBindingPointCollision)) {
// https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
// Bindings must not alias within a shader stage: two different variables in the
@@ -1497,7 +1497,7 @@
"' references multiple variables that use the same resource binding @group(" +
std::to_string(bp->group) + "), @binding(" + std::to_string(bp->binding) + ")",
var_decl->source);
- AddNote("first resource binding usage declared here", (*added.value)->source);
+ AddNote("first resource binding usage declared here", added.value->source);
return false;
}
}
@@ -2524,7 +2524,7 @@
: std::to_string(value)) +
"'",
selector->Declaration()->source);
- AddNote("previous case declared here", *added.value);
+ AddNote("previous case declared here", added.value);
return false;
}
}
@@ -2686,7 +2686,7 @@
auto added = seen.Add(&d->TypeInfo(), d->source);
if (!added && !d->Is<ast::InternalAttribute>()) {
AddError("duplicate " + d->Name() + " attribute", d->source);
- AddNote("first attribute declared here", *added.value);
+ AddNote("first attribute declared here", added.value);
return false;
}
}
@@ -2704,7 +2704,7 @@
auto name = dc->rule_name->name->symbol;
auto diag_added = diagnostics.Add(std::make_pair(category, name), dc);
- if (!diag_added && (*diag_added.value)->severity != dc->severity) {
+ if (!diag_added && diag_added.value->severity != dc->severity) {
{
StringStream ss;
ss << "conflicting diagnostic " << use;
@@ -2714,7 +2714,7 @@
StringStream ss;
ss << "severity of '" << dc->rule_name->String() << "' set to '" << dc->severity
<< "' here";
- AddNote(ss.str(), (*diag_added.value)->rule_name->source);
+ AddNote(ss.str(), diag_added.value->rule_name->source);
}
return false;
}
diff --git a/src/tint/lang/wgsl/wgsl.def b/src/tint/lang/wgsl/wgsl.def
index e23de8a..3e3282c 100644
--- a/src/tint/lang/wgsl/wgsl.def
+++ b/src/tint/lang/wgsl/wgsl.def
@@ -736,7 +736,7 @@
@stage("fragment", "compute") fn atomicCompareExchangeWeak<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T, T) -> __atomic_compare_exchange_result<T>
@must_use @stage("compute") fn subgroupBallot() -> vec4<u32>
-@must_use @stage("compute") fn subgroupBroadcast<T: fiu32>(value: T, @const sourceLaneIndex: u32) -> T
+@must_use @stage("compute") fn subgroupBroadcast<T: fiu32_f16>(value: T, @const sourceLaneIndex: u32) -> T
////////////////////////////////////////////////////////////////////////////////
// Value constructors //
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
index c6498a1..028fee4 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
@@ -886,7 +886,7 @@
return {Constant(c), PtrKind::kRef};
},
[&](Default) -> ExprAndPtrKind {
- auto lookup = bindings_.Find(value);
+ auto lookup = bindings_.Get(value);
if (TINT_UNLIKELY(!lookup)) {
TINT_ICE() << "Expr(" << (value ? value->TypeInfo().name : "null")
<< ") value has no expression";
@@ -1064,7 +1064,7 @@
}
ast::Type Struct(const core::type::Struct* s) {
- auto n = structs_.GetOrCreate(s, [&] {
+ auto n = structs_.GetOrAdd(s, [&] {
auto members = tint::Transform<8>(s->Members(), [&](const core::type::StructMember* m) {
auto ty = Type(m->Type());
const auto& ir_attrs = m->Attributes();
@@ -1125,7 +1125,7 @@
/// @returns the AST name for the given value, creating and returning a new name on the first
/// call.
Symbol NameFor(const core::ir::Value* value, std::string_view suggested = {}) {
- return names_.GetOrCreate(value, [&] {
+ return names_.GetOrAdd(value, [&] {
if (!suggested.empty()) {
return b.Symbols().Register(suggested);
}
diff --git a/src/tint/lang/wgsl/writer/raise/rename_conflicts.cc b/src/tint/lang/wgsl/writer/raise/rename_conflicts.cc
index f54efd1..a6072ed 100644
--- a/src/tint/lang/wgsl/writer/raise/rename_conflicts.cc
+++ b/src/tint/lang/wgsl/writer/raise/rename_conflicts.cc
@@ -241,13 +241,13 @@
void EnsureResolvesTo(std::string_view identifier, const CastableBase* thing) {
for (auto& scope : tint::Reverse(scopes)) {
if (auto decl = scope.Get(identifier)) {
- if (decl.value() == thing) {
+ if (*decl == thing) {
return; // Resolved to the right thing.
}
// Operand is shadowed
scope.Remove(identifier);
- Rename(decl.value(), identifier);
+ Rename(*decl, identifier);
}
}
}
@@ -257,7 +257,7 @@
/// renamed.
void Declare(Scope& scope, CastableBase* thing, std::string_view name) {
auto add = scope.Add(name, thing);
- if (!add && *add.value != thing) {
+ if (!add && add.value != thing) {
// Multiple declarations with the same name in the same scope.
// Rename the later declaration.
Rename(thing, name);
diff --git a/src/tint/utils/cli/cli.cc b/src/tint/utils/cli/cli.cc
index 557d8bc..b6dc911 100644
--- a/src/tint/utils/cli/cli.cc
+++ b/src/tint/utils/cli/cli.cc
@@ -169,7 +169,7 @@
unconsumed.Push(arg);
continue;
}
- if (auto opt = options_by_name.Find(name)) {
+ if (auto opt = options_by_name.Get(name)) {
if (auto err = (*opt)->Parse(arguments); !err.empty()) {
return Failure{err};
}
diff --git a/src/tint/utils/containers/hashmap.h b/src/tint/utils/containers/hashmap.h
index 1b1a465..58f653d 100644
--- a/src/tint/utils/containers/hashmap.h
+++ b/src/tint/utils/containers/hashmap.h
@@ -39,128 +39,215 @@
namespace tint {
-/// An unordered map that uses a robin-hood hashing algorithm.
+/// HashmapEntry is a key-value pair used by Hashmap as the Entry datatype.
+template <typename KEY, typename VALUE>
+struct HashmapEntry {
+ /// The key type
+ using Key = KEY;
+ /// The value type
+ using Value = VALUE;
+
+ /// @param entry a HashmapEntry
+ /// @return `entry.key`
+ static const Key& KeyOf(const HashmapEntry& entry) { return entry.key; }
+
+ /// The key
+ Key key;
+
+ /// The value
+ Value value;
+};
+
+/// Equality operator for HashmapEntry
+/// @param lhs the LHS HashmapEntry
+/// @param rhs the RHS HashmapEntry
+/// @return true if both entries have equal keys and values.
+template <typename K1, typename V1, typename K2, typename V2>
+inline static bool operator==(const HashmapEntry<K1, V1>& lhs, const HashmapEntry<K2, V2>& rhs) {
+ return lhs.key == rhs.key && lhs.value == rhs.value;
+}
+
+/// Writes the HashmapEntry to the stream.
+/// @param out the stream to write to
+/// @param key_value the HashmapEntry to write
+/// @returns out so calls can be chained
+template <typename STREAM,
+ typename KEY,
+ typename VALUE,
+ typename = traits::EnableIfIsOStream<STREAM>>
+auto& operator<<(STREAM& out, const HashmapEntry<KEY, VALUE>& key_value) {
+ return out << "[" << key_value.key << ": " << key_value.value << "]";
+}
+
+/// The return value of Hashmap::Get(Key).
+/// GetResult supports equality operators and acts similarly to std::optional, but does not make a
+/// copy.
+template <typename T>
+struct GetResult {
+ /// The value found in the map, or null if the entry was not found.
+ /// This pointer is guaranteed to be valid until the owning entry is removed, the map is
+ /// cleared, or the map is destructed.
+ T* value = nullptr;
+
+ /// @returns `true` if #value is not null.
+ operator bool() const { return value; }
+
+ /// @returns the dereferenced value, which must not be null.
+ T& operator*() const { return *value; }
+
+ /// @returns the pointer to the value, which must not be null.
+ T* operator->() const { return value; }
+
+ /// @param other the value to compare against the object that #value points to.
+ /// @returns true if #value is not null and the object that #value points to is equal to @p
+ /// other.
+ template <typename O>
+ bool operator==(const O& other) const {
+ return value && *value == other;
+ }
+
+ /// @param other the value to compare against the object that #value points to.
+ /// @returns true if #value is null or the object that #value points to is not equal to @p
+ /// other.
+ template <typename O>
+ bool operator!=(const O& other) const {
+ return !value || *value != other;
+ }
+};
+
+/// The return value of Hashmap::Add(Key, Value)
+template <typename T>
+struct AddResult {
+ /// A reference to the value of the entry with the given key.
+ /// If an existing entry was found with the key, then this is the value of the existing entry,
+ /// otherwise the value of the newly inserted entry.
+ /// This reference is guaranteed to be valid until the owning entry is removed, the map is
+ /// cleared, or the map is destructed.
+ T& value;
+
+ /// True if an entry did not already exist in the map with the given key.
+ bool added = false;
+
+ /// @returns #added
+ operator bool() const { return added; }
+};
+
+/// An unordered hashmap, with a fixed-size capacity that avoids heap allocations.
template <typename KEY,
typename VALUE,
size_t N,
typename HASH = Hasher<KEY>,
typename EQUAL = EqualTo<KEY>>
-class Hashmap : public HashmapBase<KEY, VALUE, N, HASH, EQUAL> {
- using Base = HashmapBase<KEY, VALUE, N, HASH, EQUAL>;
- using PutMode = typename Base::PutMode;
-
- template <typename T>
- using ReferenceKeyType = traits::CharArrayToCharPtr<std::remove_reference_t<T>>;
+class Hashmap : public HashmapBase<HashmapEntry<HashmapKey<KEY, HASH, EQUAL>, VALUE>, N> {
+ using Base = HashmapBase<HashmapEntry<HashmapKey<KEY, HASH, EQUAL>, VALUE>, N>;
public:
/// The key type
- using Key = KEY;
+ using Key = typename Base::Key;
/// The value type
using Value = VALUE;
/// The key-value type for a map entry
- using Entry = KeyValue<Key, Value>;
+ using Entry = HashmapEntry<Key, Value>;
- /// Result of Add()
- using AddResult = typename Base::PutResult;
-
- /// Reference is returned by Hashmap::Find(), and performs dynamic Hashmap lookups.
- /// The value returned by the Reference reflects the current state of the Hashmap, and so the
- /// referenced value may change, or transition between valid or invalid based on the current
- /// state of the Hashmap.
- template <bool IS_CONST, typename K>
- class ReferenceT {
- /// `const Value` if IS_CONST, or `Value` if !IS_CONST
- using T = std::conditional_t<IS_CONST, const Value, Value>;
-
- /// `const Hashmap` if IS_CONST, or `Hashmap` if !IS_CONST
- using Map = std::conditional_t<IS_CONST, const Hashmap, Hashmap>;
-
- public:
- /// @returns true if the reference is valid.
- operator bool() const { return Get() != nullptr; }
-
- /// @returns the pointer to the Value, or nullptr if the reference is invalid.
- operator T*() const { return Get(); }
-
- /// @returns the pointer to the Value
- /// @warning if the Hashmap does not contain a value for the reference, then this will
- /// trigger a TINT_ASSERT, or invalid pointer dereference.
- T* operator->() const {
- auto* hashmap_reference_lookup = Get();
- TINT_ASSERT(hashmap_reference_lookup != nullptr);
- return hashmap_reference_lookup;
- }
-
- /// @returns the pointer to the Value, or nullptr if the reference is invalid.
- T* Get() const {
- auto generation = map_.Generation();
- if (generation_ != generation) {
- cached_ = map_.Lookup(key_);
- generation_ = generation;
- }
- return cached_;
- }
-
- private:
- friend Hashmap;
-
- /// Constructor
- template <typename K_ARG>
- ReferenceT(Map& map, K_ARG&& key)
- : map_(map),
- key_(std::forward<K_ARG>(key)),
- cached_(nullptr),
- generation_(map.Generation() - 1) {}
-
- /// Constructor
- template <typename K_ARG>
- ReferenceT(Map& map, K_ARG&& key, T* value)
- : map_(map),
- key_(std::forward<K_ARG>(key)),
- cached_(value),
- generation_(map.Generation()) {}
-
- Map& map_;
- const K key_;
- mutable T* cached_ = nullptr;
- mutable size_t generation_ = 0;
- };
-
- /// A mutable reference returned by Find()
- template <typename K>
- using Reference = ReferenceT</*IS_CONST*/ false, K>;
-
- /// An immutable reference returned by Find()
- template <typename K>
- using ConstReference = ReferenceT</*IS_CONST*/ true, K>;
-
- /// Adds a value to the map, if the map does not already contain an entry with the key @p key.
- /// @param key the entry key.
- /// @param value the value of the entry to add to the map.
- /// @returns A AddResult describing the result of the add
+ /// Add attempts to insert a new entry into the map.
+ /// If an existing entry exists with the given key, then the entry is not replaced.
+ /// @param key the new entry's key
+ /// @param value the new entry's value
+ /// @return an AddResult.
template <typename K, typename V>
- AddResult Add(K&& key, V&& value) {
- return this->template Put<PutMode::kAdd>(std::forward<K>(key), std::forward<V>(value));
+ AddResult<Value> Add(K&& key, V&& value) {
+ if (auto idx = this->EditAt(key); idx.entry) {
+ return {idx.entry->value, /* added */ false};
+ } else {
+ idx.Insert(std::forward<K>(key), std::forward<V>(value));
+ return {idx.entry->value, /* added */ true};
+ }
}
- /// Adds a new entry to the map, replacing any entry that has a key equal to @p key.
- /// @param key the entry key.
- /// @param value the value of the entry to add to the map.
- /// @returns A AddResult describing the result of the replace
+ /// Inserts a new entry into the map or updates an existing entry.
+ /// @param key the new entry's key
+ /// @param value the new entry's value
template <typename K, typename V>
- AddResult Replace(K&& key, V&& value) {
- return this->template Put<PutMode::kReplace>(std::forward<K>(key), std::forward<V>(value));
+ void Replace(K&& key, V&& value) {
+ if (auto idx = this->EditAt(key); idx.entry) {
+ idx.Replace(std::forward<K>(key), std::forward<V>(value));
+ } else {
+ idx.Insert(std::forward<K>(key), std::forward<V>(value));
+ }
}
- /// @param key the key to search for.
- /// @returns the value of the entry that is equal to `value`, or no value if the entry was not
- /// found.
+ /// Looks up an entry with the given key.
+ /// @param key the entry's key to search for.
+ /// @returns a GetResult holding the found entry's value, or null if the entry was not found.
template <typename K>
- std::optional<Value> Get(K&& key) const {
- if (auto [found, index] = this->IndexOf(key); found) {
- return this->slots_[index].entry->value;
+ GetResult<Value> Get(K&& key) {
+ if (auto* entry = this->GetEntry(key)) {
+ return {&entry->value};
}
- return std::nullopt;
+ return {nullptr};
+ }
+
+ /// Looks up an entry with the given key.
+ /// @param key the entry's key to search for.
+ /// @returns a GetResult holding the found entry's value, or null if the entry was not found.
+ template <typename K>
+ GetResult<const Value> Get(K&& key) const {
+ if (auto* entry = this->GetEntry(key)) {
+ return {&entry->value};
+ }
+ return {nullptr};
+ }
+
+ /// Searches for an entry with the given key value returning that value if found, otherwise
+ /// returns @p not_found.
+ /// @param key the entry's key value to search for.
+ /// @param not_found the value to return if a node is not found.
+ /// @returns the a reference to the value of the entry, if found otherwise @p not_found.
+ /// @note The returned reference is guaranteed to be valid until the owning entry is removed,
+ /// the map is cleared, or the map is destructed.
+ template <typename K>
+ const Value& GetOr(K&& key, const Value& not_found) const {
+ if (auto* entry = this->GetEntry(key)) {
+ return entry->value;
+ }
+ return not_found;
+ }
+
+ /// Searches for an entry with the given key, returning a reference to that entry if found,
+ /// otherwise a reference to a newly added entry with the key @p key and the value from calling
+ /// @p create.
+ /// @note: Before calling `create`, the map will insert a zero-initialized value for the given
+ /// key, which will be replaced with the value returned by @p create. If @p create adds an entry
+ /// with @p key to this map, it will be replaced.
+ /// @param key the entry's key value to search for.
+ /// @param create the create function to call if the map does not contain the key. Must have the
+ /// signature `Key()`.
+ /// @returns a reference to the existing entry, or the newly added entry.
+ /// @note The returned reference is guaranteed to be valid until the owning entry is removed,
+ /// the map is cleared, or the map is destructed.
+ template <typename K, typename CREATE>
+ Entry& GetOrAddEntry(K&& key, CREATE&& create) {
+ auto idx = this->EditAt(key);
+ if (!idx.entry) {
+ idx.Insert(std::forward<K>(key), Value{});
+ idx.entry->value = create();
+ }
+ return *idx.entry;
+ }
+
+ /// Searches for an entry with the given key, returning a reference to that entry if found,
+ /// otherwise a reference to a newly added entry with the key @p key and a zero value.
+ /// @param key the entry's key value to search for.
+ /// @returns a reference to the existing entry, or the newly added entry.
+ /// @note The returned reference is guaranteed to be valid until the owning entry is removed,
+ /// the map is cleared, or the map is destructed.
+ template <typename K>
+ Entry& GetOrAddZeroEntry(K&& key) {
+ auto idx = this->EditAt(key);
+ if (!idx.entry) {
+ idx.Insert(std::forward<K>(key), Value{});
+ }
+ return *idx.entry;
}
/// Searches for an entry with the given key, adding and returning the result of calling
@@ -171,58 +258,32 @@
/// @param key the entry's key value to search for.
/// @param create the create function to call if the map does not contain the key.
/// @returns the value of the entry.
+ /// @note The returned reference is guaranteed to be valid until the owning entry is removed,
+ /// the map is cleared, or the map is destructed.
template <typename K, typename CREATE>
- Value& GetOrCreate(K&& key, CREATE&& create) {
- auto res = Add(std::forward<K>(key), Value{});
- if (res.action == MapAction::kAdded) {
- // Store the map generation before calling create()
- auto generation = this->Generation();
- // Call create(), which might modify this map.
- auto value = create();
- // Was this map mutated?
- if (this->Generation() == generation) {
- // Calling create() did not touch the map. No need to lookup again.
- *res.value = std::move(value);
- } else {
- // Calling create() modified the map. Need to insert again.
- res = Replace(key, std::move(value));
- }
- }
- return *res.value;
+ Value& GetOrAdd(K&& key, CREATE&& create) {
+ return GetOrAddEntry(std::forward<K>(key), std::forward<CREATE>(create)).value;
}
/// Searches for an entry with the given key value, adding and returning a newly created
/// zero-initialized value if the entry was not found.
/// @param key the entry's key value to search for.
/// @returns the value of the entry.
+ /// @note The returned reference is guaranteed to be valid until the owning entry is removed,
+ /// the map is cleared, or the map is destructed.
template <typename K>
- auto GetOrZero(K&& key) {
- auto res = Add(std::forward<K>(key), Value{});
- return Reference<ReferenceKeyType<K>>(*this, key, res.value);
- }
-
- /// @param key the key to search for.
- /// @returns a reference to the entry that is equal to the given value.
- template <typename K>
- auto Find(K&& key) {
- return Reference<ReferenceKeyType<K>>(*this, std::forward<K>(key));
- }
-
- /// @param key the key to search for.
- /// @returns a reference to the entry that is equal to the given value.
- template <typename K>
- auto Find(K&& key) const {
- return ConstReference<ReferenceKeyType<K>>(*this, std::forward<K>(key));
+ Value& GetOrAddZero(K&& key) {
+ return GetOrAddZeroEntry(std::forward<K>(key)).value;
}
/// @returns the keys of the map as a vector.
/// @note the order of the returned vector is non-deterministic between compilers.
template <size_t N2 = N>
- Vector<Key, N2> Keys() const {
- Vector<Key, N2> out;
+ Vector<KEY, N2> Keys() const {
+ Vector<KEY, N2> out;
out.Reserve(this->Count());
- for (auto it : *this) {
- out.Push(it.key);
+ for (auto& it : *this) {
+ out.Push(it.key.Value());
}
return out;
}
@@ -233,7 +294,7 @@
Vector<Value, N2> Values() const {
Vector<Value, N2> out;
out.Reserve(this->Count());
- for (auto it : *this) {
+ for (auto& it : *this) {
out.Push(it.value);
}
return out;
@@ -247,8 +308,8 @@
if (this->Count() != other.Count()) {
return false;
}
- for (auto it : *this) {
- auto other_val = other.Find(it.key);
+ for (auto& it : *this) {
+ auto other_val = other.Get(it.key.Value());
if (!other_val || it.value != *other_val) {
return false;
}
@@ -263,23 +324,6 @@
bool operator!=(const Hashmap<K, V, N2>& other) const {
return !(*this == other);
}
-
- private:
- template <typename K>
- Value* Lookup(K&& key) {
- if (auto [found, index] = this->IndexOf(key); found) {
- return &this->slots_[index].entry->value;
- }
- return nullptr;
- }
-
- template <typename K>
- const Value* Lookup(K&& key) const {
- if (auto [found, index] = this->IndexOf(key); found) {
- return &this->slots_[index].entry->value;
- }
- return nullptr;
- }
};
/// Hasher specialization for Hashmap
@@ -292,12 +336,37 @@
for (auto it : map) {
// Use an XOR to ensure that the non-deterministic ordering of the map still produces
// the same hash value for the same entries.
- hash ^= Hash(it.key, it.value);
+ hash ^= Hash(it.key.Value(), it.value);
}
return hash;
}
};
+/// Writes the Hashmap to the stream.
+/// @param out the stream to write to
+/// @param map the Hashmap to write
+/// @returns out so calls can be chained
+template <typename STREAM,
+ typename KEY,
+ typename VALUE,
+ size_t N,
+ typename HASH,
+ typename EQUAL,
+ typename = traits::EnableIfIsOStream<STREAM>>
+auto& operator<<(STREAM& out, const Hashmap<KEY, VALUE, N, HASH, EQUAL>& map) {
+ out << "Hashmap{";
+ bool first = true;
+ for (auto it : map) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+ out << it;
+ }
+ out << "}";
+ return out;
+}
+
} // namespace tint
#endif // SRC_TINT_UTILS_CONTAINERS_HASHMAP_H_
diff --git a/src/tint/utils/containers/hashmap_base.h b/src/tint/utils/containers/hashmap_base.h
index 81141ba..1697959 100644
--- a/src/tint/utils/containers/hashmap_base.h
+++ b/src/tint/utils/containers/hashmap_base.h
@@ -37,271 +37,361 @@
#include "src/tint/utils/containers/vector.h"
#include "src/tint/utils/ice/ice.h"
#include "src/tint/utils/math/hash.h"
+#include "src/tint/utils/math/math.h"
#include "src/tint/utils/traits/traits.h"
-#define TINT_ASSERT_ITERATORS_NOT_INVALIDATED
-
namespace tint {
-/// Action taken by a map mutation
-enum class MapAction {
- /// A new entry was added to the map
- kAdded,
- /// A existing entry in the map was replaced
- kReplaced,
- /// No action was taken as the map already contained an entry with the given key
- kKeptExisting,
-};
+/// HashmapKey wraps the comparator type for a Hashmap and Hashset.
+/// HashmapKey acts like a read-only `T`, but can be reassigned so long as the value is equivalent.
+/// @tparam T the key comparator type.
+/// @tparam HASH the hash function for the key type.
+/// @tparam EQUAL the equality function for the key type.
+template <typename T, typename HASH = Hasher<T>, typename EQUAL = std::equal_to<T>>
+class HashmapKey {
+ T value_;
-/// KeyValue is a key-value pair.
-template <typename KEY, typename VALUE>
-struct KeyValue {
- /// The key type
- using Key = KEY;
- /// The value type
- using Value = VALUE;
+ public:
+ /// Key is an alias to this templated class.
+ using Key = HashmapKey<T, HASH, EQUAL>;
+ /// Hash is an alias to the hash function for the key type.
+ using Hash = HASH;
+ /// Equal is an alias to the equality function for the key type.
+ using Equal = EQUAL;
- /// The key
- Key key;
+ /// KeyOf() returns @p key, so a HashmapKey can be used as the entry type for a Hashset.
+ /// @param key the HashmapKey
+ /// @return @p key
+ static const Key& KeyOf(const Key& key) { return key; }
- /// The value
- Value value;
+ /// Constructor using copied value.
+ /// @param value the key value.
+ HashmapKey(const T& value) : value_(value), hash(HASH{}(value_)) {} // NOLINT
+
+ /// Constructor using moved value.
+ /// @param value the key value.
+ HashmapKey(T&& value) : value_(std::forward<T>(value)), hash(HASH{}(value_)) {} // NOLINT
+
+ /// Constructor using pre-computed hash and copied value.
+ /// @param hash_ the precomputed hash of @p value
+ /// @param value the key value
+ HashmapKey(size_t hash_, const T& value) : value_(value), hash(hash_) {}
+
+ /// Constructor using pre-computed hash and moved value.
+ /// @param hash_ the precomputed hash of @p value
+ /// @param value the key value
+ HashmapKey(size_t hash_, T&& value) : value_(std::forward<T>(value)), hash(hash_) {}
+
+ /// Copy constructor
+ HashmapKey(const HashmapKey&) = default;
+
+ /// Move constructor
+ HashmapKey(HashmapKey&&) = default;
+
+ /// Destructor
+ ~HashmapKey() = default;
+
+ /// Copy-assignment operator.
+ /// @note As a hashmap uses the HashmapKey for indexing, the new value *must* have the same hash
+ /// value and be equal to this key.
+ /// @param other the key to copy to this key.
+ /// @return this HashmapKey.
+ HashmapKey& operator=(const HashmapKey& other) {
+ TINT_ASSERT(*this == other);
+ value_ = other.Value();
+ return *this;
+ }
+
+ /// Move-assignment operator.
+ /// @note As a hashmap uses the HashmapKey for indexing, the new value *must* have the same hash
+ /// value and be equal to this key.
+ /// @param other the key to move to this key.
+ /// @return this HashmapKey.
+ HashmapKey& operator=(HashmapKey&& other) {
+ TINT_ASSERT(*this == other);
+ value_ = std::move(other.Value());
+ return *this;
+ }
/// Equality operator
- /// @param other the RHS of the operator
- /// @returns true if both the key and value of this KeyValue are equal to the key and value
- /// of @p other
- template <typename K, typename V>
- bool operator==(const KeyValue<K, V>& other) const {
- return key == other.key && value == other.value;
+ /// @param other the other key.
+ /// @return true if the hash and value of @p other are equal to this key.
+ bool operator==(const HashmapKey& other) const {
+ return hash == other.hash && EQUAL{}(value_, other.Value());
}
- /// Inequality operator
- /// @param other the RHS of the operator
- /// @returns true if either the key and value of this KeyValue are not equal to the key and
- /// value of @p other
- template <typename K, typename V>
- bool operator!=(const KeyValue<K, V>& other) const {
- return *this != other;
+ /// Equality operator
+ /// @param other the other key.
+ /// @return true if the hash of other and value of @p other are equal to this key.
+ template <typename RHS>
+ bool operator==(const RHS& other) const {
+ return hash == HASH{}(other) && EQUAL{}(value_, other);
}
+
+ /// @returns the value of the key
+ const T& Value() const { return value_; }
+
+ /// @returns the value of the key
+ operator const T&() const { return value_; }
+
+ /// @returns the pointer to the value, or the value itself if T is a pointer.
+ auto operator->() const {
+ if constexpr (std::is_pointer_v<T>) {
+ // operator-> is useless if the T is a pointer, so automatically unwrap a pointer.
+ return value_;
+ } else {
+ return &value_;
+ }
+ }
+
+ /// The hash of value
+ const size_t hash;
};
-/// KeyValueRef is a pair of references to a key and value.
-/// #key is always a const reference.
-/// #value is always a const reference if @tparam VALUE_IS_CONST is true, otherwise a non-const
-/// reference.
-template <typename KEY, typename VALUE, bool VALUE_IS_CONST>
-struct KeyValueRef {
- /// The reference to key type
- using KeyRef = const KEY&;
- /// The reference to value type
- using ValueRef = std::conditional_t<VALUE_IS_CONST, const VALUE&, VALUE&>;
-
- /// The reference to the key
- KeyRef key;
-
- /// The reference to the value
- ValueRef value;
-
- /// @returns a KeyValue<KEY, VALUE> with the referenced key and value
- operator KeyValue<KEY, VALUE>() const { return {key, value}; }
-};
-
-/// Writes the KeyValue to the stream.
+/// Writes the HashmapKey to the stream.
/// @param out the stream to write to
-/// @param key_value the KeyValue to write
+/// @param key the HashmapKey to write
/// @returns out so calls can be chained
-template <typename STREAM,
- typename KEY,
- typename VALUE,
- typename = traits::EnableIfIsOStream<STREAM>>
-auto& operator<<(STREAM& out, const KeyValue<KEY, VALUE>& key_value) {
- return out << "[" << key_value.key << ": " << key_value.value << "]";
+template <typename STREAM, typename T, typename = traits::EnableIfIsOStream<STREAM>>
+auto& operator<<(STREAM& out, const HashmapKey<T>& key) {
+ if constexpr (traits::HasOperatorShiftLeft<STREAM, T>) {
+ return out << key.Value();
+ } else {
+ return out << "<hashmap-key>";
+ }
}
-/// A base class for Hashmap and Hashset that uses a robin-hood hashing algorithm.
-/// @see the fantastic tutorial: https://programming.guide/robin-hood-hashing.html
-template <typename KEY,
- typename VALUE,
- size_t N,
- typename HASH = Hasher<KEY>,
- typename EQUAL = EqualTo<KEY>>
+/// HashmapBase is the base class for Hashmap and Hashset.
+/// @tparam ENTRY is the single record in the map. The entry type must alias 'Key' to the HashmapKey
+/// type, and implement the method `static HashmapKey<...> KeyOf(ENTRY)` to return the key for the
+/// entry.
+template <typename ENTRY, size_t N>
class HashmapBase {
- static constexpr bool ValueIsVoid = std::is_same_v<VALUE, void>;
+ protected:
+ struct Node;
+ struct Slot;
public:
- /// The key type
- using Key = KEY;
- /// The value type
- using Value = VALUE;
- /// The entry type for the map.
- /// This is:
- /// - Key when Value is void (used by Hashset)
- /// - KeyValue<Key, Value> when Value is not void (used by Hashmap)
- using Entry = std::conditional_t<ValueIsVoid, Key, KeyValue<Key, Value>>;
+ /// Entry is the type of a single record in the hashmap.
+ using Entry = ENTRY;
+ /// Key is the HashmapKey type used to find entries.
+ using Key = typename Entry::Key;
+ /// Hash is the
+ using Hash = typename Key::Hash;
+ /// Equal is the
+ using Equal = typename Key::Equal;
- /// A reference to an entry in the map.
- /// This is:
- /// - const Key& when Value is void (used by Hashset)
- /// - KeyValueRef<Key, Value> when Value is not void (used by Hashmap)
- template <bool IS_CONST>
- using EntryRef = std::conditional_t<
- ValueIsVoid,
- const Key&,
- KeyValueRef<Key, std::conditional_t<ValueIsVoid, bool, Value>, IS_CONST>>;
+ /// The minimum capacity of the map.
+ static constexpr size_t kMinCapacity = std::max<size_t>(N, 8);
- /// STL-friendly alias to Entry. Used by gmock.
- using value_type = Entry;
+ /// The target number of slots, expressed as a fractional percentage of the map capacity.
+ /// e.g. a kLoadFactor of 75, would mean a target slots count of (0.75 * capacity).
+ static constexpr size_t kLoadFactor = 75;
- private:
- /// @returns the key from an entry
- static const Key& KeyOf(const Entry& entry) {
- if constexpr (ValueIsVoid) {
- return entry;
- } else {
- return entry.key;
+ /// @param capacity the capacity of the map, as total number of entries.
+ /// @returns the target slot vector size to hold @p capacity map entries.
+ static constexpr size_t NumSlots(size_t capacity) {
+ return (std::max<size_t>(capacity, kMinCapacity) * kLoadFactor) / 100;
+ }
+
+ /// Constructor.
+ /// Constructs an empty map.
+ HashmapBase() {
+ slots_.Resize(slots_.Capacity());
+ for (auto& node : fixed_) {
+ free_.Add(&node);
}
}
- /// @returns a pointer to the value from an entry.
- static Value* ValueOf(Entry& entry) {
- if constexpr (ValueIsVoid) {
- return nullptr; // Hashset only has keys
- } else {
- return &entry.value;
+ /// Copy constructor.
+ /// Constructs a map with a copy of @p other.
+ /// @param other the map to copy.
+ HashmapBase(const HashmapBase& other) : HashmapBase() {
+ if (&other != this) {
+ Copy(other);
}
}
- /// A slot is a single entry in the underlying vector.
- /// A slot can either be empty or filled with a value. If the slot is empty, #hash and #distance
- /// will be zero.
- struct Slot {
- template <typename K>
- bool Equals(size_t key_hash, K&& key) const {
- return key_hash == hash && EQUAL()(std::forward<K>(key), KeyOf(*entry));
+ /// Move constructor.
+ /// Constructs a map with the moved entries of @p other.
+ /// @param other the map to move.
+ HashmapBase(HashmapBase&& other) : HashmapBase() {
+ if (&other != this) {
+ Move(std::move(other));
}
+ }
- /// The slot value. If this does not contain a value, then the slot is vacant.
- std::optional<Entry> entry;
- /// The precomputed hash of value.
- size_t hash = 0;
- size_t distance = 0;
- };
-
- /// The target length of the underlying vector length in relation to the number of entries in
- /// the map, expressed as a percentage. For example a value of `150` would mean there would be
- /// at least 50% more slots than the number of map entries.
- static constexpr size_t kRehashFactor = 150;
-
- /// @returns the target slot vector size to hold `n` map entries.
- static constexpr size_t NumSlots(size_t count) { return (count * kRehashFactor) / 100; }
-
- /// The fixed-size slot vector length, based on N and kRehashFactor.
- static constexpr size_t kNumFixedSlots = NumSlots(N);
-
- /// The minimum number of slots for the map.
- static constexpr size_t kMinSlots = std::max<size_t>(kNumFixedSlots, 4);
-
- public:
- /// Iterator for entries in the map.
- /// Iterators are invalidated if the map is modified.
- template <bool IS_CONST>
- class IteratorT {
- public:
- /// @returns the value pointed to by this iterator
- EntryRef<IS_CONST> operator->() const {
-#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
- TINT_ASSERT(map.Generation() == initial_generation &&
- "iterator invalidated by container modification");
-#endif
- return *this;
- }
-
- /// @returns a reference to the value at the iterator
- EntryRef<IS_CONST> operator*() const {
-#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
- TINT_ASSERT(map.Generation() == initial_generation &&
- "iterator invalidated by container modification");
-#endif
- auto& ref = current->entry.value();
- if constexpr (ValueIsVoid) {
- return ref;
- } else {
- return {ref.key, ref.value};
+ /// Destructor.
+ ~HashmapBase() {
+ // Call the destructor on all entries in the map.
+ for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) {
+ auto* node = slots_[slot_idx].nodes;
+ while (node) {
+ auto next = node->next;
+ node->Destroy();
+ node = next;
}
}
+ }
+
+ /// Assignment operator.
+ /// Clears this map, and populates this map with a copy of @p other.
+ /// @param other the map to copy.
+ /// @returns this HashmapBase
+ HashmapBase& operator=(const HashmapBase& other) {
+ if (&other != this) {
+ Clear();
+ Copy(other);
+ }
+ return *this;
+ }
+
+ /// Move-assignment operator.
+ /// Clears this map, and populates this map with the moved entries of @p other.
+ /// @param other the map to move.
+ /// @returns this HashmapBase
+ HashmapBase& operator=(HashmapBase&& other) {
+ if (&other != this) {
+ Clear();
+ Move(std::move(other));
+ }
+ return *this;
+ }
+
+ /// @returns the number of entries in the map.
+ size_t Count() const { return count_; }
+
+ /// @returns true if the map holds no entries.
+ bool IsEmpty() const { return count_ == 0; }
+
+ /// Removes all the entries from the map.
+ /// @note the map's capacity is not reduced, as it is assumed that a reused map will likely fill
+ /// to a similar size as before.
+ void Clear() {
+ for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) {
+ auto* node = slots_[slot_idx].nodes;
+ while (node) {
+ auto next = node->next;
+ node->Destroy();
+ free_.Add(node);
+ node = next;
+ }
+ slots_[slot_idx].nodes = nullptr;
+ }
+ }
+
+ /// Ensures that the map can hold @p n entries without heap reallocation or rehashing.
+ /// @param n the number of entries to ensure can fit in the map without reallocation or
+ /// rehashing.
+ void Reserve(size_t n) {
+ if (n > capacity_) {
+ size_t count = n - capacity_;
+ free_.Allocate(count);
+ capacity_ += count;
+ }
+ }
+
+ /// Looks up an entry with the given key.
+ /// @param key the entry's key to search for.
+ /// @returns a pointer to the matching entry, or null if no entry was found.
+ /// @note The returned pointer is guaranteed to be valid until the owning entry is removed,
+ /// the map is cleared, or the map is destructed.
+ template <typename K>
+ Entry* GetEntry(K&& key) {
+ size_t hash = Hash{}(key);
+ auto& slot = slots_[hash % slots_.Length()];
+ return slot.Find(hash, key);
+ }
+
+ /// Looks up an entry with the given key.
+ /// @param key the entry's key to search for.
+ /// @returns a pointer to the matching entry, or null if no entry was found.
+ /// @note The returned pointer is guaranteed to be valid until the owning entry is removed,
+ /// the map is cleared, or the map is destructed.
+ template <typename K>
+ const Entry* GetEntry(K&& key) const {
+ size_t hash = Hash{}(key);
+ auto& slot = slots_[hash % slots_.Length()];
+ return slot.Find(hash, key);
+ }
+
+ /// @returns true if the map contains an entry with a key that matches @p key.
+ /// @param key the key to look for.
+ template <typename K = Key>
+ bool Contains(K&& key) const {
+ return GetEntry(key) != nullptr;
+ }
+
+ /// Removes an entry from the map that has a key which matches @p key.
+ /// @returns true if the entry was found and removed, otherwise false.
+ /// @param key the key to look for.
+ template <typename K = Key>
+ bool Remove(K&& key) {
+ size_t hash = Hash{}(key);
+ auto& slot = slots_[hash % slots_.Length()];
+ Node** edge = &slot.nodes;
+ for (auto* node = *edge; node; node = node->next) {
+ if (node->Equals(hash, key)) {
+ *edge = node->next;
+ node->Destroy();
+ free_.Add(node);
+ count_--;
+ return true;
+ }
+ edge = &node->next;
+ }
+ return false;
+ }
+
+ /// Iterator for entries in the map.
+ template <bool IS_CONST>
+ class IteratorT {
+ private:
+ using MAP = std::conditional_t<IS_CONST, const HashmapBase, HashmapBase>;
+ using NODE = std::conditional_t<IS_CONST, const Node, Node>;
+
+ public:
+ /// @returns the entry pointed to by this iterator
+ auto& operator->() { return node_->Entry(); }
+
+ /// @returns a reference to the entry at the iterator
+ auto& operator*() { return node_->Entry(); }
/// Increments the iterator
/// @returns this iterator
IteratorT& operator++() {
-#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
- TINT_ASSERT(map.Generation() == initial_generation &&
- "iterator invalidated by container modification");
-#endif
- if (current == end) {
- return *this;
- }
- ++current;
- SkipToNextValue();
+ node_ = node_->next;
+ SkipEmptySlots();
return *this;
}
/// Equality operator
/// @param other the other iterator to compare this iterator to
/// @returns true if this iterator is equal to other
- bool operator==(const IteratorT& other) const {
-#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
- TINT_ASSERT(map.Generation() == initial_generation &&
- "iterator invalidated by container modification");
-#endif
- return current == other.current;
- }
+ bool operator==(const IteratorT& other) const { return node_ == other.node_; }
/// Inequality operator
/// @param other the other iterator to compare this iterator to
/// @returns true if this iterator is not equal to other
- bool operator!=(const IteratorT& other) const {
-#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
- TINT_ASSERT(map.Generation() == initial_generation &&
- "iterator invalidated by container modification");
-#endif
- return current != other.current;
- }
+ bool operator!=(const IteratorT& other) const { return node_ != other.node_; }
private:
/// Friend class
friend class HashmapBase;
- using SLOT = std::conditional_t<IS_CONST, const Slot, Slot>;
-
- IteratorT(VectorIterator<SLOT> c,
- VectorIterator<SLOT> e,
- [[maybe_unused]] const HashmapBase& m)
- : current(std::move(c)),
- end(std::move(e))
-#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
- ,
- map(m),
- initial_generation(m.Generation())
-#endif
- {
- SkipToNextValue();
+ IteratorT(MAP& map, size_t slot, NODE* node) : map_(map), slot_(slot), node_(node) {
+ SkipEmptySlots();
}
- /// Moves the iterator forward, stopping at the next slot that is not empty.
- void SkipToNextValue() {
- while (current != end && !current->entry.has_value()) {
- ++current;
+ void SkipEmptySlots() {
+ while (!node_ && slot_ + 1 < map_.slots_.Length()) {
+ node_ = map_.slots_[++slot_].nodes;
}
}
- VectorIterator<SLOT> current; /// The slot the iterator is pointing to
- VectorIterator<SLOT> end; /// One past the last slot in the map
-
-#ifdef TINT_ASSERT_ITERATORS_NOT_INVALIDATED
- const HashmapBase& map; /// The hashmap that is being iterated over.
- size_t initial_generation; /// The generation ID when the iterator was created.
-#endif
+ MAP& map_;
+ size_t slot_ = 0;
+ NODE* node_ = nullptr;
};
/// An immutable key and mutable value iterator
@@ -310,342 +400,289 @@
/// An immutable key and value iterator
using ConstIterator = IteratorT</*IS_CONST*/ true>;
- /// Constructor
- HashmapBase() { slots_.Resize(kMinSlots); }
-
- /// Copy constructor
- /// @param other the other HashmapBase to copy
- HashmapBase(const HashmapBase& other) = default;
-
- /// Move constructor
- /// @param other the other HashmapBase to move
- HashmapBase(HashmapBase&& other) = default;
-
- /// Destructor
- ~HashmapBase() { Clear(); }
-
- /// Copy-assignment operator
- /// @param other the other HashmapBase to copy
- /// @returns this so calls can be chained
- HashmapBase& operator=(const HashmapBase& other) = default;
-
- /// Move-assignment operator
- /// @param other the other HashmapBase to move
- /// @returns this so calls can be chained
- HashmapBase& operator=(HashmapBase&& other) = default;
-
- /// Removes all entries from the map.
- void Clear() {
- slots_.Clear(); // Destructs all entries
- slots_.Resize(kMinSlots);
- count_ = 0;
- generation_++;
- }
-
- /// Removes an entry from the map.
- /// @param key the entry key.
- /// @returns true if an entry was removed.
- bool Remove(const Key& key) {
- const auto [found, start] = IndexOf(key);
- if (!found) {
- return false;
- }
-
- // Shuffle the entries backwards until we either find a free slot, or a slot that has zero
- // distance.
- Slot* prev = nullptr;
-
- const auto count = slots_.Length();
- for (size_t distance = 0, index = start; distance < count; distance++) {
- auto& slot = slots_[index];
- if (prev) {
- // note: `distance == 0` also includes empty slots.
- if (slot.distance == 0) {
- // Clear the previous slot, and stop shuffling.
- *prev = {};
- break;
- }
- // Shuffle the slot backwards.
- prev->entry = std::move(slot.entry);
- prev->hash = slot.hash;
- prev->distance = slot.distance - 1;
- }
- prev = &slot;
-
- index = (index == count - 1) ? 0 : index + 1;
- }
-
- // Entry was removed.
- count_--;
- generation_++;
-
- return true;
- }
-
- /// Checks whether an entry exists in the map
- /// @param key the key to search for.
- /// @returns true if the map contains an entry with the given value.
- bool Contains(const Key& key) const {
- const auto [found, _] = IndexOf(key);
- return found;
- }
-
- /// Pre-allocates memory so that the map can hold at least `capacity` entries.
- /// @param capacity the new capacity of the map.
- void Reserve(size_t capacity) {
- // Calculate the number of slots required to hold `capacity` entries.
- const size_t num_slots = std::max(NumSlots(capacity), kMinSlots);
- if (slots_.Length() >= num_slots) {
- // Already have enough slots.
- return;
- }
-
- // Move all the values out of the map and into a vector.
- Vector<Entry, N> entries;
- entries.Reserve(count_);
- for (auto& slot : slots_) {
- if (slot.entry.has_value()) {
- entries.Push(std::move(slot.entry.value()));
- }
- }
-
- // Clear the map, grow the number of slots.
- Clear();
- slots_.Resize(num_slots);
-
- // As the number of slots has grown, the slot indices will have changed from before, so
- // re-add all the entries back into the map.
- for (auto& entry : entries) {
- if constexpr (ValueIsVoid) {
- struct NoValue {};
- Put<PutMode::kAdd>(std::move(entry), NoValue{});
- } else {
- Put<PutMode::kAdd>(std::move(entry.key), std::move(entry.value));
- }
- }
- }
-
- /// @returns the number of entries in the map.
- size_t Count() const { return count_; }
-
- /// @returns true if the map contains no entries.
- bool IsEmpty() const { return count_ == 0; }
-
- /// @returns a monotonic counter which is incremented whenever the map is mutated.
- size_t Generation() const { return generation_; }
-
/// @returns an immutable iterator to the start of the map.
- ConstIterator begin() const { return ConstIterator{slots_.begin(), slots_.end(), *this}; }
+ ConstIterator begin() const { return ConstIterator{*this, 0, slots_.Front().nodes}; }
/// @returns an immutable iterator to the end of the map.
- ConstIterator end() const { return ConstIterator{slots_.end(), slots_.end(), *this}; }
+ ConstIterator end() const { return ConstIterator{*this, slots_.Length(), nullptr}; }
/// @returns an iterator to the start of the map.
- Iterator begin() { return Iterator{slots_.begin(), slots_.end(), *this}; }
+ Iterator begin() { return Iterator{*this, 0, slots_.Front().nodes}; }
/// @returns an iterator to the end of the map.
- Iterator end() { return Iterator{slots_.end(), slots_.end(), *this}; }
+ Iterator end() { return Iterator{*this, slots_.Length(), nullptr}; }
- /// A debug function for checking that the map is in good health.
- /// Asserts if the map is corrupted.
- void ValidateIntegrity() const {
- size_t num_alive = 0;
- for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) {
- const auto& slot = slots_[slot_idx];
- if (slot.entry.has_value()) {
- num_alive++;
- auto const [index, hash] = Hash(KeyOf(*slot.entry));
- TINT_ASSERT(hash == slot.hash);
- TINT_ASSERT(slot_idx == Wrap(index + slot.distance));
- }
- }
- TINT_ASSERT(num_alive == count_);
- }
+ /// STL-friendly alias to Entry. Used by gmock.
+ using value_type = const Entry&;
protected:
- /// The behaviour of Put() when an entry already exists with the given key.
- enum class PutMode {
- /// Do not replace existing entries with the new value.
- kAdd,
- /// Replace existing entries with the new value.
- kReplace,
- };
-
- /// Result of Put()
- struct PutResult {
- /// Whether the insert replaced or added a new entry to the map.
- MapAction action = MapAction::kAdded;
- /// A pointer to the inserted entry value.
- Value* value = nullptr;
-
- /// @returns true if the entry was added to the map, or an existing entry was replaced.
- operator bool() const { return action != MapAction::kKeptExisting; }
- };
-
- /// The common implementation for Add() and Replace()
- /// @param key the key of the entry to add to the map.
- /// @param value the value of the entry to add to the map.
- /// @returns A PutResult describing the result of the insertion
- template <PutMode MODE, typename K, typename V>
- PutResult Put(K&& key, V&& value) {
- // Ensure the map can fit a new entry
- if (ShouldRehash(count_ + 1)) {
- Reserve((count_ + 1) * 2);
- }
-
- const auto hash = Hash(key);
-
- auto make_entry = [&] {
- if constexpr (ValueIsVoid) {
- return std::forward<K>(key);
- } else {
- return Entry{std::forward<K>(key), std::forward<V>(value)};
- }
+ /// Node holds an Entry in a linked list.
+ struct Node {
+ /// A structure that has the same size and alignment as Entry.
+ /// Replacement for std::aligned_storage as this is broken on earlier versions of MSVC.
+ struct alignas(alignof(ENTRY)) Storage {
+ /// Byte array of length sizeof(ENTRY)
+ uint8_t data[sizeof(ENTRY)];
};
- const auto count = slots_.Length();
- for (size_t distance = 0, index = hash.scan_start; distance < count; distance++) {
- auto& slot = slots_[index];
- if (!slot.entry.has_value()) {
- // Found an empty slot.
- // Place value directly into the slot, and we're done.
- slot.entry.emplace(make_entry());
- slot.hash = hash.code;
- slot.distance = distance;
- count_++;
- generation_++;
- return PutResult{MapAction::kAdded, ValueOf(*slot.entry)};
- }
+ /// Destructs the entry.
+ void Destroy() { Entry().~ENTRY(); }
- // Slot has an entry
+ /// @returns the storage reinterpreted as an `Entry&`
+ ENTRY& Entry() { return *Bitcast<ENTRY*>(&storage.data[0]); }
- if (slot.Equals(hash.code, key)) {
- // Slot is equal to value. Replace or preserve?
- if constexpr (MODE == PutMode::kReplace) {
- slot.entry = make_entry();
- generation_++;
- return PutResult{MapAction::kReplaced, ValueOf(*slot.entry)};
- } else {
- return PutResult{MapAction::kKeptExisting, ValueOf(*slot.entry)};
- }
- }
+ /// @returns the storage reinterpreted as a `const Entry&`
+ const ENTRY& Entry() const { return *Bitcast<const ENTRY*>(&storage.data[0]); }
- if (slot.distance < distance) {
- // Existing slot has a closer distance than the value we're attempting to insert.
- // Steal from the rich!
- // Move the current slot to a temporary (evicted), and put the value into the slot.
- Slot evicted{make_entry(), hash.code, distance};
- std::swap(evicted, slot);
+ /// @returns a reference to the Entry's HashmapKey
+ const HashmapBase::Key& Key() const { return HashmapBase::Entry::KeyOf(Entry()); }
- // Find a new home for the evicted slot.
- evicted.distance++; // We've already swapped at index.
- InsertShuffle(Wrap(index + 1), std::move(evicted));
-
- count_++;
- generation_++;
- return PutResult{MapAction::kAdded, ValueOf(*slot.entry)};
- }
-
- index = (index == count - 1) ? 0 : index + 1;
+ /// @param hash the hash value to compare against the Entry's key hash value
+ /// @param value the value to compare against the Entry's key
+ /// @returns true if the Entry's hash is equal to @p hash, and the Entry's key is equal to
+ /// @p value.
+ template <typename T>
+ bool Equals(size_t hash, T&& value) const {
+ auto& key = Key();
+ return key.hash == hash && HashmapBase::Equal{}(key.Value(), value);
}
- TINT_ICE() << "HashmapBase::Put() looped entire map without finding a slot";
- return PutResult{};
- }
+ /// storage is a buffer that has the same size and alignment as Entry.
+ /// The storage holds a constructed Entry when linked in the slots, and is destructed when
+ /// removed from slots.
+ Storage storage;
- /// HashResult is the return value of Hash()
- struct HashResult {
- /// The target (zero-distance) slot index for the key.
- size_t scan_start;
- /// The calculated hash code of the key.
- size_t code;
+ /// next is the next Node in the slot, or in the free list.
+ Node* next;
};
- /// @param key the key to hash
- /// @returns a tuple holding the target slot index for the given value, and the hash of the
- /// value, respectively.
- template <typename K>
- HashResult Hash(K&& key) const {
- size_t hash = HASH()(std::forward<K>(key));
- size_t index = Wrap(hash);
- return {index, hash};
+ /// Copies the hashmap @p other into this empty hashmap.
+ /// @note This hashmap must be empty before calling
+ /// @param other the hashmap to copy
+ void Copy(const HashmapBase& other) {
+ Reserve(other.capacity_);
+ slots_.Resize(other.slots_.Length());
+ for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) {
+ for (auto* o = other.slots_[slot_idx].nodes; o; o = o->next) {
+ auto* node = free_.Take();
+ new (&node->Entry()) Entry{o->Entry()};
+ slots_[slot_idx].Add(node);
+ }
+ }
+ count_ = other.count_;
}
- /// Looks for the key in the map.
- /// @param key the key to search for.
- /// @returns a tuple holding a boolean representing whether the key was found in the map, and
- /// if found, the index of the slot that holds the key.
- template <typename K>
- std::tuple<bool, size_t> IndexOf(K&& key) const {
- const auto hash = Hash(key);
- const auto count = slots_.Length();
- for (size_t distance = 0, index = hash.scan_start; distance < count; distance++) {
- auto& slot = slots_[index];
- if (!slot.entry.has_value()) {
- return {/* found */ false, /* index */ 0};
+ /// Moves the the hashmap @p other into this empty hashmap.
+ /// @note This hashmap must be empty before calling
+ /// @param other the hashmap to move
+ void Move(HashmapBase&& other) {
+ Reserve(other.capacity_);
+ slots_.Resize(other.slots_.Length());
+ for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) {
+ for (auto* o = other.slots_[slot_idx].nodes; o; o = o->next) {
+ auto* node = free_.Take();
+ new (&node->Entry()) Entry{std::move(o->Entry())};
+ slots_[slot_idx].Add(node);
}
- if (slot.Equals(hash.code, key)) {
- return {/* found */ true, index};
- }
- if (slot.distance < distance) {
- // If the slot distance is less than the current probe distance, then the slot
- // must be for entry that has an index that comes after key. In this situation,
- // we know that the map does not contain the key, as it would have been found
- // before this slot. The "Lookup" section of
- // https://programming.guide/robin-hood-hashing.html suggests that the condition
- // should inverted, but this is wrong.
- return {/* found */ false, /* index */ 0};
- }
- index = (index == count - 1) ? 0 : index + 1;
+ }
+ count_ = other.count_;
+ other.Clear();
+ }
+
+ /// EditIndex is the structure returned by EditAt(), used to simplify entry replacement and
+ /// insertion.
+ struct EditIndex {
+ /// The HashmapBase that created this EditIndex
+ HashmapBase& map;
+ /// The slot that will hold the edit.
+ Slot& slot;
+ /// The hash of the key, passed to EditAt().
+ size_t hash;
+ /// The resolved node entry, or nullptr if EditAt() did not resolve to an existing entry.
+ Entry* entry = nullptr;
+
+ /// Replace will replace the entry with a new Entry built from @p key and @p values.
+ /// @note #entry must not be null before calling.
+ /// @note the new key must have equality to the old key.
+ /// @param key the key value (inner value of a HashmapKey).
+ /// @param values optional additional values to pass to the Entry constructor.
+ template <typename K, typename... V>
+ void Replace(K&& key, V&&... values) {
+ *entry = Entry{Key{hash, std::forward<K>(key)}, std::forward<V>(values)...};
}
- TINT_ICE() << "HashmapBase::IndexOf() looped entire map without finding a slot";
- return {/* found */ false, /* index */ 0};
+ /// Insert will create a new entry using @p key and @p values and insert it into the slot.
+ /// The created entry will be assigned to #entry before returning.
+ /// @note #entry must be null before calling.
+ /// @note the key must not already exist in the map.
+ /// @param key the key value (inner value of a HashmapKey).
+ /// @param values optional additional values to pass to the Entry constructor.
+ template <typename K, typename... V>
+ void Insert(K&& key, V&&... values) {
+ auto* node = map.free_.Take();
+ slot.Add(node);
+ map.count_++;
+ entry = &node->Entry();
+ new (entry) Entry{Key{hash, std::forward<K>(key)}, std::forward<V>(values)...};
+ }
+ };
+
+ /// EditAt is a helper for map entry replacement and entry insertion.
+ /// Before indexing, EditAt will ensure there's at least one free node available, potentially
+ /// allocating and rehashing if there's no free nodes available.
+ /// @param key the key used to compute the hash, look up the slot and search for the existing
+ /// node.
+ /// @returns a EditIndex used to modify or insert a new entry into the map with the given key.
+ template <typename K>
+ EditIndex EditAt(K&& key) {
+ if (!free_.nodes_) {
+ free_.Allocate(capacity_);
+ capacity_ += capacity_;
+ Rehash();
+ }
+ size_t hash = Hash{}(key);
+ auto& slot = slots_[hash % slots_.Length()];
+ auto* entry = slot.Find(hash, key);
+ return {*this, slot, hash, entry};
}
- /// Shuffles slots for an insertion that has been placed one slot before `start`.
- /// @param start the index of the first slot to start shuffling.
- /// @param evicted the slot content that was evicted for the insertion.
- void InsertShuffle(size_t start, Slot&& evicted) {
- const auto count = slots_.Length();
- for (size_t distance = 0, index = start; distance < count; distance++) {
- auto& slot = slots_[index];
-
- if (!slot.entry.has_value()) {
- // Empty slot found for evicted.
- slot = std::move(evicted);
- return; // We're done.
+ /// Rehash resizes the slots vector proportionally to the map capacity, and then reinserts the
+ /// nodes so they're linked in the correct slots linked lists.
+ void Rehash() {
+ size_t num_slots = NumSlots(capacity_);
+ decltype(slots_) old_slots;
+ std::swap(slots_, old_slots);
+ slots_.Resize(num_slots);
+ for (size_t old_slot_idx = 0; old_slot_idx < old_slots.Length(); old_slot_idx++) {
+ auto* node = old_slots[old_slot_idx].nodes;
+ while (node) {
+ auto next = node->next;
+ size_t new_slot_idx = node->Key().hash % num_slots;
+ slots_[new_slot_idx].Add(node);
+ node = next;
}
-
- if (slot.distance < evicted.distance) {
- // Occupied slot has shorter distance to evicted.
- // Swap slot and evicted.
- std::swap(slot, evicted);
- }
-
- // evicted moves further from the target slot...
- evicted.distance++;
-
- index = (index == count - 1) ? 0 : index + 1;
}
}
- /// @param count the number of new entries in the map
- /// @returns true if the map should grow the slot vector, and rehash the items.
- bool ShouldRehash(size_t count) const { return NumSlots(count) > slots_.Length(); }
+ /// Slot holds a linked list of nodes. Nodes are assigned to the slot list by calculating the
+ /// modulo of the entry's hash with the slot_ vector length.
+ struct Slot {
+ /// The linked list of nodes in this slot.
+ Node* nodes = nullptr;
- /// @param index an input value
- /// @returns the input value modulo the number of slots.
- size_t Wrap(size_t index) const { return index % slots_.Length(); }
+ /// Add adds the node @p node to this slot.
+ /// @note The node must be unlinked from any existing list before calling.
+ /// @param node the node to add.
+ void Add(Node* node) {
+ node->next = nodes;
+ nodes = node;
+ }
- /// The vector of slots. The vector length is equal to its capacity.
- Vector<Slot, kMinSlots> slots_;
+ /// @returns the node in the slot with the given hash and key.
+ /// @param hash the key hash to search for.
+ /// @param key the key value to search for.
+ template <typename K>
+ const Entry* Find(size_t hash, K&& key) const {
+ for (auto* node = nodes; node; node = node->next) {
+ if (node->Equals(hash, key)) {
+ return &node->Entry();
+ }
+ }
+ return nullptr;
+ }
- /// The number of entries in the map.
+ /// @returns the node in the slot with the given hash and key.
+ /// @param hash the key hash to search for.
+ /// @param key the key value to search for.
+ template <typename K>
+ Entry* Find(size_t hash, K&& key) {
+ for (auto* node = nodes; node; node = node->next) {
+ if (node->Equals(hash, key)) {
+ return &node->Entry();
+ }
+ }
+ return nullptr;
+ }
+ };
+
+ /// Free holds a linked list of nodes which are currently not used by entries in the map, and a
+ /// linked list of node allocations.
+ struct FreeNodes {
+ /// Allocation is the header of a block of memory that holds Nodes.
+ struct Allocation {
+ /// The linked list of allocations.
+ Allocation* next = nullptr;
+ // Node[] array follows this structure.
+ };
+
+ /// The linked list of free nodes.
+ Node* nodes_ = nullptr;
+
+ /// The linked list of allocations.
+ Allocation* allocations_ = nullptr;
+
+ /// Destructor.
+ /// Frees all the allocations made.
+ ~FreeNodes() {
+ auto* allocation = allocations_;
+ while (allocation) {
+ auto* next = allocation->next;
+ free(allocation);
+ allocation = next;
+ }
+ }
+
+ /// @returns the next free node in the list
+ Node* Take() {
+ auto* node = nodes_;
+ nodes_ = node->next;
+ node->next = nullptr;
+ return node;
+ }
+
+ /// Add adds the node @p node to the list of free nodes.
+ /// @note The node must be unlinked from any existing list before calling.
+ /// @param node the node to add.
+ void Add(Node* node) {
+ node->next = nodes_;
+ nodes_ = node;
+ }
+
+ /// Allocate allocates an additional @p count nodes and adds them to the free node list.
+ /// @param count the number of new nodes to allocate.
+ /// @note callers must remember to increment HashmapBase::capacity_ by the same amount.
+ void Allocate(size_t count) {
+ static_assert(std::is_trivial_v<Node>,
+ "Node is not trivial, and will require construction / destruction");
+ constexpr size_t kAllocationSize = RoundUp(alignof(Node), sizeof(Allocation));
+ auto* memory =
+ reinterpret_cast<std::byte*>(malloc(kAllocationSize + sizeof(Node) * count));
+ if (TINT_UNLIKELY(!memory)) {
+ TINT_ICE() << "out of memory";
+ return;
+ }
+ auto* nodes_allocation = Bitcast<Allocation*>(memory);
+ nodes_allocation->next = allocations_;
+ allocations_ = nodes_allocation;
+
+ auto* nodes = Bitcast<Node*>(memory + kAllocationSize);
+ for (size_t i = 0; i < count; i++) {
+ Add(&nodes[i]);
+ }
+ }
+ };
+
+ /// The fixed-size array of nodes, used for the first kMinCapacity entries of the map, before
+ /// allocating from the heap.
+ std::array<Node, kMinCapacity> fixed_;
+ /// The vector of slots. Each slot holds a linked list of nodes which hold entries in the map.
+ Vector<Slot, NumSlots(N)> slots_;
+ /// The linked list of free nodes, and node allocations from the heap.
+ FreeNodes free_;
+ /// The total number of nodes, including free nodes (kMinCapacity + heap-allocated)
+ size_t capacity_ = kMinCapacity;
+ /// The total number of nodes that currently hold map entries.
size_t count_ = 0;
-
- /// Counter that's incremented with each modification to the map.
- size_t generation_ = 0;
};
} // namespace tint
diff --git a/src/tint/utils/containers/hashmap_test.cc b/src/tint/utils/containers/hashmap_test.cc
index af3efee..af926a2 100644
--- a/src/tint/utils/containers/hashmap_test.cc
+++ b/src/tint/utils/containers/hashmap_test.cc
@@ -82,40 +82,19 @@
EXPECT_FALSE(map.Contains("world"));
}
-TEST(Hashmap, Generation) {
- Hashmap<int, std::string, 8> map;
- EXPECT_EQ(map.Generation(), 0u);
- map.Add(1, "one");
- EXPECT_EQ(map.Generation(), 1u);
- map.Add(1, "uno");
- EXPECT_EQ(map.Generation(), 1u);
- map.Replace(1, "une");
- EXPECT_EQ(map.Generation(), 2u);
- map.Add(2, "dos");
- EXPECT_EQ(map.Generation(), 3u);
- map.Remove(1);
- EXPECT_EQ(map.Generation(), 4u);
- map.Clear();
- EXPECT_EQ(map.Generation(), 5u);
- map.Find(2);
- EXPECT_EQ(map.Generation(), 5u);
- map.Get(2);
- EXPECT_EQ(map.Generation(), 5u);
-}
-
TEST(Hashmap, Index) {
Hashmap<int, std::string, 4> map;
- auto zero = map.Find(0);
+ auto zero = map.Get(0);
EXPECT_FALSE(zero);
map.Add(3, "three");
- auto three = map.Find(3);
+ auto three = map.Get(3);
map.Add(2, "two");
- auto two = map.Find(2);
+ auto two = map.Get(2);
map.Add(4, "four");
- auto four = map.Find(4);
+ auto four = map.Get(4);
map.Add(8, "eight");
- auto eight = map.Find(8);
+ auto eight = map.Get(8);
EXPECT_FALSE(zero);
ASSERT_TRUE(three);
@@ -123,23 +102,23 @@
ASSERT_TRUE(four);
ASSERT_TRUE(eight);
- EXPECT_EQ(*three, "three");
- EXPECT_EQ(*two, "two");
- EXPECT_EQ(*four, "four");
- EXPECT_EQ(*eight, "eight");
+ EXPECT_EQ(three, "three");
+ EXPECT_EQ(two, "two");
+ EXPECT_EQ(four, "four");
+ EXPECT_EQ(eight, "eight");
- map.Add(0, "zero"); // Note: Find called before Add() is okay!
+ map.Add(0, "zero");
+ EXPECT_FALSE(zero);
map.Add(5, "five");
- auto five = map.Find(5);
+ auto five = map.Get(5);
map.Add(6, "six");
- auto six = map.Find(6);
+ auto six = map.Get(6);
map.Add(1, "one");
- auto one = map.Find(1);
+ auto one = map.Get(1);
map.Add(7, "seven");
- auto seven = map.Find(7);
+ auto seven = map.Get(7);
- ASSERT_TRUE(zero);
ASSERT_TRUE(three);
ASSERT_TRUE(two);
ASSERT_TRUE(four);
@@ -149,47 +128,38 @@
ASSERT_TRUE(one);
ASSERT_TRUE(seven);
- EXPECT_EQ(*zero, "zero");
- EXPECT_EQ(*three, "three");
- EXPECT_EQ(*two, "two");
- EXPECT_EQ(*four, "four");
- EXPECT_EQ(*eight, "eight");
- EXPECT_EQ(*five, "five");
- EXPECT_EQ(*six, "six");
- EXPECT_EQ(*one, "one");
- EXPECT_EQ(*seven, "seven");
-
- map.Remove(2);
- map.Remove(8);
- map.Remove(1);
-
- EXPECT_FALSE(two);
- EXPECT_FALSE(eight);
- EXPECT_FALSE(one);
+ EXPECT_EQ(three, "three");
+ EXPECT_EQ(two, "two");
+ EXPECT_EQ(four, "four");
+ EXPECT_EQ(eight, "eight");
+ EXPECT_EQ(five, "five");
+ EXPECT_EQ(six, "six");
+ EXPECT_EQ(one, "one");
+ EXPECT_EQ(seven, "seven");
}
TEST(Hashmap, StringKeys) {
Hashmap<std::string, int, 4> map;
- EXPECT_FALSE(map.Find("zero"));
- EXPECT_FALSE(map.Find(std::string("zero")));
- EXPECT_FALSE(map.Find(std::string_view("zero")));
+ EXPECT_FALSE(map.Get("zero"));
+ EXPECT_FALSE(map.Get(std::string("zero")));
+ EXPECT_FALSE(map.Get(std::string_view("zero")));
map.Add("three", 3);
- auto three_cstr = map.Find("three");
- auto three_str = map.Find(std::string("three"));
- auto three_sv = map.Find(std::string_view("three"));
+ auto three_cstr = map.Get("three");
+ auto three_str = map.Get(std::string("three"));
+ auto three_sv = map.Get(std::string_view("three"));
map.Add(std::string("two"), 2);
- auto two_cstr = map.Find("two");
- auto two_str = map.Find(std::string("two"));
- auto two_sv = map.Find(std::string_view("two"));
+ auto two_cstr = map.Get("two");
+ auto two_str = map.Get(std::string("two"));
+ auto two_sv = map.Get(std::string_view("two"));
map.Add("four", 4);
- auto four_cstr = map.Find("four");
- auto four_str = map.Find(std::string("four"));
- auto four_sv = map.Find(std::string_view("four"));
+ auto four_cstr = map.Get("four");
+ auto four_str = map.Get(std::string("four"));
+ auto four_sv = map.Get(std::string_view("four"));
map.Add(std::string("eight"), 8);
- auto eight_cstr = map.Find("eight");
- auto eight_str = map.Find(std::string("eight"));
- auto eight_sv = map.Find(std::string_view("eight"));
+ auto eight_cstr = map.Get("eight");
+ auto eight_str = map.Get(std::string("eight"));
+ auto eight_sv = map.Get(std::string_view("eight"));
ASSERT_TRUE(three_cstr);
ASSERT_TRUE(three_str);
@@ -204,40 +174,40 @@
ASSERT_TRUE(eight_str);
ASSERT_TRUE(eight_sv);
- EXPECT_EQ(*three_cstr, 3);
- EXPECT_EQ(*three_str, 3);
- EXPECT_EQ(*three_sv, 3);
- EXPECT_EQ(*two_cstr, 2);
- EXPECT_EQ(*two_str, 2);
- EXPECT_EQ(*two_sv, 2);
- EXPECT_EQ(*four_cstr, 4);
- EXPECT_EQ(*four_str, 4);
- EXPECT_EQ(*four_sv, 4);
- EXPECT_EQ(*eight_cstr, 8);
- EXPECT_EQ(*eight_str, 8);
- EXPECT_EQ(*eight_sv, 8);
+ EXPECT_EQ(three_cstr, 3);
+ EXPECT_EQ(three_str, 3);
+ EXPECT_EQ(three_sv, 3);
+ EXPECT_EQ(two_cstr, 2);
+ EXPECT_EQ(two_str, 2);
+ EXPECT_EQ(two_sv, 2);
+ EXPECT_EQ(four_cstr, 4);
+ EXPECT_EQ(four_str, 4);
+ EXPECT_EQ(four_sv, 4);
+ EXPECT_EQ(eight_cstr, 8);
+ EXPECT_EQ(eight_str, 8);
+ EXPECT_EQ(eight_sv, 8);
- map.Add("zero", 0); // Note: Find called before Add() is okay!
- auto zero_cstr = map.Find("zero");
- auto zero_str = map.Find(std::string("zero"));
- auto zero_sv = map.Find(std::string_view("zero"));
+ map.Add("zero", 0);
+ auto zero_cstr = map.Get("zero");
+ auto zero_str = map.Get(std::string("zero"));
+ auto zero_sv = map.Get(std::string_view("zero"));
map.Add(std::string("five"), 5);
- auto five_cstr = map.Find("five");
- auto five_str = map.Find(std::string("five"));
- auto five_sv = map.Find(std::string_view("five"));
+ auto five_cstr = map.Get("five");
+ auto five_str = map.Get(std::string("five"));
+ auto five_sv = map.Get(std::string_view("five"));
map.Add("six", 6);
- auto six_cstr = map.Find("six");
- auto six_str = map.Find(std::string("six"));
- auto six_sv = map.Find(std::string_view("six"));
+ auto six_cstr = map.Get("six");
+ auto six_str = map.Get(std::string("six"));
+ auto six_sv = map.Get(std::string_view("six"));
map.Add("one", 1);
- auto one_cstr = map.Find("one");
- auto one_str = map.Find(std::string("one"));
- auto one_sv = map.Find(std::string_view("one"));
+ auto one_cstr = map.Get("one");
+ auto one_str = map.Get(std::string("one"));
+ auto one_sv = map.Get(std::string_view("one"));
map.Add(std::string("seven"), 7);
- auto seven_cstr = map.Find("seven");
- auto seven_str = map.Find(std::string("seven"));
- auto seven_sv = map.Find(std::string_view("seven"));
+ auto seven_cstr = map.Get("seven");
+ auto seven_str = map.Get(std::string("seven"));
+ auto seven_sv = map.Get(std::string_view("seven"));
ASSERT_TRUE(zero_cstr);
ASSERT_TRUE(zero_str);
@@ -267,33 +237,33 @@
ASSERT_TRUE(seven_str);
ASSERT_TRUE(seven_sv);
- EXPECT_EQ(*zero_cstr, 0);
- EXPECT_EQ(*zero_str, 0);
- EXPECT_EQ(*zero_sv, 0);
- EXPECT_EQ(*three_cstr, 3);
- EXPECT_EQ(*three_str, 3);
- EXPECT_EQ(*three_sv, 3);
- EXPECT_EQ(*two_cstr, 2);
- EXPECT_EQ(*two_str, 2);
- EXPECT_EQ(*two_sv, 2);
- EXPECT_EQ(*four_cstr, 4);
- EXPECT_EQ(*four_str, 4);
- EXPECT_EQ(*four_sv, 4);
- EXPECT_EQ(*eight_cstr, 8);
- EXPECT_EQ(*eight_str, 8);
- EXPECT_EQ(*eight_sv, 8);
- EXPECT_EQ(*five_cstr, 5);
- EXPECT_EQ(*five_str, 5);
- EXPECT_EQ(*five_sv, 5);
- EXPECT_EQ(*six_cstr, 6);
- EXPECT_EQ(*six_str, 6);
- EXPECT_EQ(*six_sv, 6);
- EXPECT_EQ(*one_cstr, 1);
- EXPECT_EQ(*one_str, 1);
- EXPECT_EQ(*one_sv, 1);
- EXPECT_EQ(*seven_cstr, 7);
- EXPECT_EQ(*seven_str, 7);
- EXPECT_EQ(*seven_sv, 7);
+ EXPECT_EQ(zero_cstr, 0);
+ EXPECT_EQ(zero_str, 0);
+ EXPECT_EQ(zero_sv, 0);
+ EXPECT_EQ(three_cstr, 3);
+ EXPECT_EQ(three_str, 3);
+ EXPECT_EQ(three_sv, 3);
+ EXPECT_EQ(two_cstr, 2);
+ EXPECT_EQ(two_str, 2);
+ EXPECT_EQ(two_sv, 2);
+ EXPECT_EQ(four_cstr, 4);
+ EXPECT_EQ(four_str, 4);
+ EXPECT_EQ(four_sv, 4);
+ EXPECT_EQ(eight_cstr, 8);
+ EXPECT_EQ(eight_str, 8);
+ EXPECT_EQ(eight_sv, 8);
+ EXPECT_EQ(five_cstr, 5);
+ EXPECT_EQ(five_str, 5);
+ EXPECT_EQ(five_sv, 5);
+ EXPECT_EQ(six_cstr, 6);
+ EXPECT_EQ(six_str, 6);
+ EXPECT_EQ(six_sv, 6);
+ EXPECT_EQ(one_cstr, 1);
+ EXPECT_EQ(one_str, 1);
+ EXPECT_EQ(one_sv, 1);
+ EXPECT_EQ(seven_cstr, 7);
+ EXPECT_EQ(seven_str, 7);
+ EXPECT_EQ(seven_sv, 7);
}
TEST(Hashmap, Iterator) {
@@ -316,7 +286,7 @@
map.Add(4, "four");
map.Add(3, "three");
map.Add(2, "two");
- for (auto pair : map) {
+ for (auto& pair : map) {
pair.value += "!";
}
EXPECT_THAT(map, testing::UnorderedElementsAre(Entry{1, "one!"}, Entry{2, "two!"},
@@ -349,44 +319,46 @@
}
}
-TEST(Hashmap, GetOrCreate) {
+TEST(Hashmap, GetOrAdd) {
Hashmap<int, std::string, 8> map;
std::optional<std::string> value_of_key_0_at_create;
- EXPECT_EQ(map.GetOrCreate(0,
- [&] {
- value_of_key_0_at_create = map.Get(0);
- return "zero";
- }),
+ EXPECT_EQ(map.GetOrAdd(0,
+ [&] {
+ if (auto existing = map.Get(0)) {
+ value_of_key_0_at_create = *existing;
+ }
+ return "zero";
+ }),
"zero");
EXPECT_EQ(map.Count(), 1u);
EXPECT_EQ(map.Get(0), "zero");
EXPECT_EQ(value_of_key_0_at_create, "");
bool create_called = false;
- EXPECT_EQ(map.GetOrCreate(0,
- [&] {
- create_called = true;
- return "oh noes";
- }),
+ EXPECT_EQ(map.GetOrAdd(0,
+ [&] {
+ create_called = true;
+ return "oh noes";
+ }),
"zero");
EXPECT_FALSE(create_called);
EXPECT_EQ(map.Count(), 1u);
EXPECT_EQ(map.Get(0), "zero");
- EXPECT_EQ(map.GetOrCreate(1, [&] { return "one"; }), "one");
+ EXPECT_EQ(map.GetOrAdd(1, [&] { return "one"; }), "one");
EXPECT_EQ(map.Count(), 2u);
EXPECT_EQ(map.Get(1), "one");
}
-TEST(Hashmap, GetOrCreate_CreateModifiesMap) {
+TEST(Hashmap, GetOrAdd_CreateModifiesMap) {
Hashmap<int, std::string, 8> map;
- EXPECT_EQ(map.GetOrCreate(0,
- [&] {
- map.Add(3, "three");
- map.Add(1, "one");
- map.Add(2, "two");
- return "zero";
- }),
+ EXPECT_EQ(map.GetOrAdd(0,
+ [&] {
+ map.Add(3, "three");
+ map.Add(1, "one");
+ map.Add(2, "two");
+ return "zero";
+ }),
"zero");
EXPECT_EQ(map.Count(), 4u);
EXPECT_EQ(map.Get(0), "zero");
@@ -395,11 +367,11 @@
EXPECT_EQ(map.Get(3), "three");
bool create_called = false;
- EXPECT_EQ(map.GetOrCreate(0,
- [&] {
- create_called = true;
- return "oh noes";
- }),
+ EXPECT_EQ(map.GetOrAdd(0,
+ [&] {
+ create_called = true;
+ return "oh noes";
+ }),
"zero");
EXPECT_FALSE(create_called);
EXPECT_EQ(map.Count(), 4u);
@@ -408,13 +380,13 @@
EXPECT_EQ(map.Get(2), "two");
EXPECT_EQ(map.Get(3), "three");
- EXPECT_EQ(map.GetOrCreate(4,
- [&] {
- map.Add(6, "six");
- map.Add(5, "five");
- map.Add(7, "seven");
- return "four";
- }),
+ EXPECT_EQ(map.GetOrAdd(4,
+ [&] {
+ map.Add(6, "six");
+ map.Add(5, "five");
+ map.Add(7, "seven");
+ return "four";
+ }),
"four");
EXPECT_EQ(map.Count(), 8u);
EXPECT_EQ(map.Get(0), "zero");
@@ -427,13 +399,13 @@
EXPECT_EQ(map.Get(7), "seven");
}
-TEST(Hashmap, GetOrCreate_CreateAddsSameKeyedValue) {
+TEST(Hashmap, GetOrAdd_CreateAddsSameKeyedValue) {
Hashmap<int, std::string, 8> map;
- EXPECT_EQ(map.GetOrCreate(42,
- [&] {
- map.Add(42, "should-be-replaced");
- return "expected-value";
- }),
+ EXPECT_EQ(map.GetOrAdd(42,
+ [&] {
+ map.Add(42, "should-be-replaced");
+ return "expected-value";
+ }),
"expected-value");
EXPECT_EQ(map.Count(), 1u);
EXPECT_EQ(map.Get(42), "expected-value");
@@ -464,7 +436,7 @@
case 2: { // Remove
auto expected = reference.erase(key) != 0;
EXPECT_EQ(map.Remove(key), expected) << "i:" << i;
- EXPECT_FALSE(map.Get(key).has_value()) << "i:" << i;
+ EXPECT_FALSE(map.Get(key)) << "i:" << i;
EXPECT_FALSE(map.Contains(key)) << "i:" << i;
break;
}
@@ -478,7 +450,7 @@
auto expected = reference[key];
EXPECT_EQ(map.Get(key), expected) << "i:" << i;
} else {
- EXPECT_FALSE(map.Get(key).has_value()) << "i:" << i;
+ EXPECT_FALSE(map.Get(key)) << "i:" << i;
}
break;
}
diff --git a/src/tint/utils/containers/hashset.h b/src/tint/utils/containers/hashset.h
index da07e43..9bdc7ef 100644
--- a/src/tint/utils/containers/hashset.h
+++ b/src/tint/utils/containers/hashset.h
@@ -40,11 +40,10 @@
namespace tint {
-/// An unordered set that uses a robin-hood hashing algorithm.
+/// An unordered hashset, with a fixed-size capacity that avoids heap allocations.
template <typename KEY, size_t N, typename HASH = Hasher<KEY>, typename EQUAL = std::equal_to<KEY>>
-class Hashset : public HashmapBase<KEY, void, N, HASH, EQUAL> {
- using Base = HashmapBase<KEY, void, N, HASH, EQUAL>;
- using PutMode = typename Base::PutMode;
+class Hashset : public HashmapBase<HashmapKey<KEY, HASH, EQUAL>, N> {
+ using Base = HashmapBase<HashmapKey<KEY, HASH, EQUAL>, N>;
public:
using Base::Base;
@@ -63,8 +62,12 @@
/// @returns true if the value was added, false if there was an existing value in the set.
template <typename V>
bool Add(V&& value) {
- struct NoValue {};
- return this->template Put<PutMode::kAdd>(std::forward<V>(value), NoValue{});
+ auto idx = this->EditAt(value);
+ if (idx.entry) {
+ return false; // Entry already exists
+ }
+ idx.Insert(std::forward<V>(value));
+ return true;
}
/// @returns the set entries of the map as a vector
@@ -73,8 +76,8 @@
tint::Vector<KEY, N2> Vector() const {
tint::Vector<KEY, N2> out;
out.Reserve(this->Count());
- for (auto& value : *this) {
- out.Push(value);
+ for (auto& key : *this) {
+ out.Push(key.Value());
}
return out;
}
@@ -83,8 +86,8 @@
/// @param pred a function-like with the signature `bool(T)`
template <typename PREDICATE>
bool Any(PREDICATE&& pred) const {
- for (const auto& it : *this) {
- if (pred(it)) {
+ for (const auto& key : *this) {
+ if (pred(key.Value())) {
return true;
}
}
@@ -95,13 +98,23 @@
/// @param pred a function-like with the signature `bool(T)`
template <typename PREDICATE>
bool All(PREDICATE&& pred) const {
- for (const auto& it : *this) {
- if (!pred(it)) {
+ for (const auto& key : *this) {
+ if (!pred(key.Value())) {
return false;
}
}
return true;
}
+
+ /// Looks up an entry in the set that is equal to @p key
+ /// @param key the key to search for.
+ /// @returns the entry that is equal to @p key
+ std::optional<KEY> Get(const KEY& key) const {
+ if (auto [found, index] = this->IndexOf(key); found) {
+ return this->slots_[index].entry;
+ }
+ return std::nullopt;
+ }
};
} // namespace tint
diff --git a/src/tint/utils/containers/hashset_test.cc b/src/tint/utils/containers/hashset_test.cc
index 96e22bc..e9ae39e 100644
--- a/src/tint/utils/containers/hashset_test.cc
+++ b/src/tint/utils/containers/hashset_test.cc
@@ -55,11 +55,11 @@
TEST(Hashset, InitializerConstructor) {
Hashset<int, 8> set{1, 5, 7};
EXPECT_EQ(set.Count(), 3u);
- EXPECT_TRUE(set.Contains(1u));
- EXPECT_FALSE(set.Contains(3u));
- EXPECT_TRUE(set.Contains(5u));
- EXPECT_TRUE(set.Contains(7u));
- EXPECT_FALSE(set.Contains(9u));
+ EXPECT_TRUE(set.Contains(1));
+ EXPECT_FALSE(set.Contains(3));
+ EXPECT_TRUE(set.Contains(5));
+ EXPECT_TRUE(set.Contains(7));
+ EXPECT_FALSE(set.Contains(9));
}
TEST(Hashset, AddRemove) {
@@ -83,7 +83,6 @@
ASSERT_TRUE(set.Add(prime)) << "i: " << i;
ASSERT_FALSE(set.Add(prime)) << "i: " << i;
ASSERT_EQ(set.Count(), i + 1);
- set.ValidateIntegrity();
}
ASSERT_EQ(set.Count(), kPrimes.size());
for (int prime : kPrimes) {
@@ -91,21 +90,6 @@
}
}
-TEST(Hashset, Generation) {
- Hashset<int, 8> set;
- EXPECT_EQ(set.Generation(), 0u);
- set.Add(1);
- EXPECT_EQ(set.Generation(), 1u);
- set.Add(1);
- EXPECT_EQ(set.Generation(), 1u);
- set.Add(2);
- EXPECT_EQ(set.Generation(), 2u);
- set.Remove(1);
- EXPECT_EQ(set.Generation(), 3u);
- set.Clear();
- EXPECT_EQ(set.Generation(), 4u);
-}
-
TEST(Hashset, Iterator) {
Hashset<std::string, 8> set;
set.Add("one");
@@ -160,7 +144,6 @@
break;
}
}
- set.ValidateIntegrity();
}
}
diff --git a/src/tint/utils/containers/map.h b/src/tint/utils/containers/map.h
index 3a2830c..5de5538 100644
--- a/src/tint/utils/containers/map.h
+++ b/src/tint/utils/containers/map.h
@@ -46,7 +46,7 @@
return it != map.end() ? it->second : if_missing;
}
-/// GetOrCreate is a utility function for lazily adding to an unordered map.
+/// GetOrAdd is a utility function for lazily adding to an unordered map.
/// If the map already contains the key `key` then this is returned, otherwise
/// `create()` is called and the result is added to the map and is returned.
/// @param map the unordered_map
@@ -54,7 +54,7 @@
/// @param create a callable function-like object with the signature `V()`
/// @return the value of the item with the given key, or the newly created item
template <typename K, typename V, typename H, typename C, typename CREATE>
-V GetOrCreate(std::unordered_map<K, V, H, C>& map, const K& key, CREATE&& create) {
+V GetOrAdd(std::unordered_map<K, V, H, C>& map, const K& key, CREATE&& create) {
auto it = map.find(key);
if (it != map.end()) {
return it->second;
diff --git a/src/tint/utils/containers/map_test.cc b/src/tint/utils/containers/map_test.cc
index 8b66b6d..34e2d1f 100644
--- a/src/tint/utils/containers/map_test.cc
+++ b/src/tint/utils/containers/map_test.cc
@@ -43,22 +43,22 @@
EXPECT_EQ(Lookup(map, 20), 0); // missing, without if_missing
}
-TEST(GetOrCreateTest, NewKey) {
+TEST(GetOrAddTest, NewKey) {
std::unordered_map<int, int> map;
- EXPECT_EQ(GetOrCreate(map, 1, [&] { return 2; }), 2);
+ EXPECT_EQ(GetOrAdd(map, 1, [&] { return 2; }), 2);
EXPECT_EQ(map.size(), 1u);
EXPECT_EQ(map[1], 2);
}
-TEST(GetOrCreateTest, ExistingKey) {
+TEST(GetOrAddTest, ExistingKey) {
std::unordered_map<int, int> map;
map[1] = 2;
bool called = false;
- EXPECT_EQ(GetOrCreate(map, 1,
- [&] {
- called = true;
- return -2;
- }),
+ EXPECT_EQ(GetOrAdd(map, 1,
+ [&] {
+ called = true;
+ return -2;
+ }),
2);
EXPECT_EQ(called, false);
EXPECT_EQ(map.size(), 1u);
diff --git a/src/tint/utils/containers/scope_stack.h b/src/tint/utils/containers/scope_stack.h
index e87bff5..066d367 100644
--- a/src/tint/utils/containers/scope_stack.h
+++ b/src/tint/utils/containers/scope_stack.h
@@ -65,7 +65,7 @@
/// @returns the old value if there was an existing key at the top of the
/// stack, otherwise the zero initializer for type T.
V Set(const K& key, V val) {
- if (auto el = Top().Find(key)) {
+ if (auto el = Top().Get(key)) {
std::swap(val, *el);
return val;
}
@@ -78,7 +78,7 @@
/// @returns the value, or the zero initializer if the value was not found
V Get(const K& key) const {
for (size_t i = depth_; i > 0; i--) {
- if (auto val = stack_[i - 1].Find(key)) {
+ if (auto val = stack_[i - 1].Get(key)) {
return *val;
}
}
diff --git a/src/tint/utils/containers/unique_allocator.h b/src/tint/utils/containers/unique_allocator.h
index 124adc4..8d88623 100644
--- a/src/tint/utils/containers/unique_allocator.h
+++ b/src/tint/utils/containers/unique_allocator.h
@@ -29,15 +29,14 @@
#define SRC_TINT_UTILS_CONTAINERS_UNIQUE_ALLOCATOR_H_
#include <functional>
-#include <unordered_set>
#include <utility>
+#include "src/tint/utils/containers/hashmap_base.h"
#include "src/tint/utils/memory/block_allocator.h"
namespace tint {
-/// UniqueAllocator is used to allocate unique instances of the template type
-/// `T`.
+/// UniqueAllocator is used to allocate unique instances of the template type `T`.
template <typename T, typename HASH = std::hash<T>, typename EQUAL = std::equal_to<T>>
class UniqueAllocator {
public:
@@ -50,19 +49,15 @@
/// pointer is returned.
template <typename TYPE = T, typename... ARGS>
TYPE* Get(ARGS&&... args) {
- // Create a temporary T instance on the stack so that we can hash it, and
- // use it for equality lookup for the std::unordered_set. If the item is not
- // found in the set, then we create the persisted instance with the
- // allocator.
- TYPE key{args...};
- auto hash = Hasher{}(key);
- auto it = items.find(Entry{hash, &key});
- if (it != items.end()) {
- return static_cast<TYPE*>(it->ptr);
+ // Create a temporary T instance on the stack so that we can hash it, and use it for
+ // equality lookup for the Set. If the item is not found in the set, then we create the
+ // persisted instance with the allocator.
+ TYPE prototype{args...};
+ Key& key = items.Add(&prototype);
+ if (key.Value() == &prototype) {
+ key = allocator.template Create<TYPE>(std::forward<ARGS>(args)...);
}
- auto* ptr = allocator.template Create<TYPE>(std::forward<ARGS>(args)...);
- items.emplace_hint(it, Entry{hash, ptr});
- return ptr;
+ return static_cast<TYPE*>(key.Value());
}
/// @param args the arguments used to create the temporary used for the search.
@@ -70,15 +65,10 @@
/// was not found.
template <typename TYPE = T, typename... ARGS>
TYPE* Find(ARGS&&... args) const {
- // Create a temporary T instance on the stack so that we can hash it, and
- // use it for equality lookup for the std::unordered_set.
- TYPE key{args...};
- auto hash = Hasher{}(key);
- auto it = items.find(Entry{hash, &key});
- if (it != items.end()) {
- return static_cast<TYPE*>(it->ptr);
- }
- return nullptr;
+ // Create a temporary T instance on the stack so that we can hash it, and use it for
+ // equality lookup for the Set.
+ TYPE prototype{std::forward<ARGS>(args)...};
+ return static_cast<TYPE*>(items.Get(&prototype));
}
/// Wrap sets this allocator to the objects created with the content of `inner`.
@@ -95,36 +85,52 @@
Iterator end() const { return allocator.Objects().end(); }
private:
- /// The hash function
- using Hasher = HASH;
- /// The equality function
- using Equality = EQUAL;
-
- /// Entry is used as the entry to the unordered_set
- struct Entry {
- /// The pre-calculated hash of the entry
- size_t hash;
- /// The pointer to the unique object
- T* ptr;
- };
- /// Comparator is the hashing and equality function used by the unordered_set
- struct Comparator {
+ /// Comparator is the hashing function used by the Hashset
+ struct Hasher {
/// Hashing function
/// @param e the entry
/// @returns the hash of the entry
- size_t operator()(Entry e) const { return e.hash; }
+ size_t operator()(T* e) const { return HASH{}(*e); }
+ };
+ /// Equality is the equality function used by the Hashset
+ struct Equality {
/// Equality function
/// @param a the first entry to compare
/// @param b the second entry to compare
/// @returns true if the two entries are equal
- bool operator()(Entry a, Entry b) const { return EQUAL{}(*a.ptr, *b.ptr); }
+ bool operator()(T* a, T* b) const { return EQUAL{}(*a, *b); }
+ };
+
+ /// The maximum capacity of Set before it performs heap allocations.
+ static constexpr size_t kFixedSize = 32;
+
+ /// The key type of Set.
+ using Key = HashmapKey<T*, Hasher, Equality>;
+
+ /// A custom Hashset implementation that allows keys to be modified.
+ class Set : public HashmapBase<Key, kFixedSize> {
+ public:
+ Key& Add(T* key) {
+ auto idx = this->EditAt(key);
+ if (!idx.entry) {
+ idx.Insert(std::forward<T*>(key));
+ }
+ return *idx.entry;
+ }
+
+ T* Get(T* key) const {
+ if (auto* entry = this->GetEntry(key)) {
+ return *entry;
+ }
+ return nullptr;
+ }
};
/// The block allocator used to allocate the unique objects
BlockAllocator<T> allocator;
- /// The unordered_set of unique item entries
- std::unordered_set<Entry, Comparator, Comparator> items;
+ /// The set of unique item entries
+ Set items;
};
} // namespace tint
diff --git a/src/tint/utils/math/hash.h b/src/tint/utils/math/hash.h
index e143d07..43c97f4 100644
--- a/src/tint/utils/math/hash.h
+++ b/src/tint/utils/math/hash.h
@@ -122,11 +122,11 @@
/// @param ptr the pointer to hash
/// @returns a hash of the pointer
size_t operator()(T* ptr) const {
- auto hash = std::hash<T*>()(ptr);
+ auto hash = static_cast<size_t>(reinterpret_cast<uintptr_t>(ptr));
#ifdef TINT_HASH_SEED
hash ^= static_cast<uint32_t>(TINT_HASH_SEED);
#endif
- return hash ^ (hash >> 4);
+ return hash >> 4;
}
};
diff --git a/src/tint/utils/symbol/symbol_table.cc b/src/tint/utils/symbol/symbol_table.cc
index e2ffae0..6685517 100644
--- a/src/tint/utils/symbol/symbol_table.cc
+++ b/src/tint/utils/symbol/symbol_table.cc
@@ -42,33 +42,22 @@
Symbol SymbolTable::Register(std::string_view name) {
TINT_ASSERT(!name.empty());
- auto it = name_to_symbol_.Find(name);
- if (it) {
- return *it;
- }
- return RegisterInternal(name);
-}
-
-Symbol SymbolTable::RegisterInternal(std::string_view name) {
- char* name_mem = Bitcast<char*>(name_allocator_.Allocate(name.length() + 1));
- if (name_mem == nullptr) {
- TINT_ICE() << "failed to allocate memory for symbol's string";
- return Symbol();
+ auto& it = name_to_symbol_.GetOrAddZeroEntry(name);
+ if (it.value) {
+ return Symbol{it.value, generation_id_, it.key};
}
- memcpy(name_mem, name.data(), name.length() + 1);
- std::string_view nv(name_mem, name.length());
-
- Symbol sym(next_symbol_, generation_id_, nv);
- ++next_symbol_;
- name_to_symbol_.Add(sym.NameView(), sym);
-
- return sym;
+ auto view = Allocate(name);
+ it.key = view;
+ it.value = next_symbol_++;
+ return Symbol{it.value, generation_id_, view};
}
Symbol SymbolTable::Get(std::string_view name) const {
- auto it = name_to_symbol_.Find(name);
- return it ? *it : Symbol();
+ if (auto* entry = name_to_symbol_.GetEntry(name)) {
+ return Symbol{entry->value, generation_id_, entry->key};
+ }
+ return Symbol{};
}
Symbol SymbolTable::New(std::string_view prefix_view /* = "" */) {
@@ -79,30 +68,37 @@
prefix = std::string(prefix_view);
}
- auto it = name_to_symbol_.Find(prefix);
- if (!it) {
- return RegisterInternal(prefix);
+ auto& it = name_to_symbol_.GetOrAddZeroEntry(prefix);
+ if (it.value == 0) {
+ // prefix is a unique name
+ auto view = Allocate(prefix);
+ it.key = view;
+ it.value = next_symbol_++;
+ return Symbol{it.value, generation_id_, view};
}
- size_t i = 0;
- auto last_prefix = last_prefix_to_index_.Find(prefix);
- if (last_prefix) {
- i = *last_prefix;
- }
-
+ size_t& i = last_prefix_to_index_.GetOrAddZero(prefix);
std::string name;
do {
++i;
name = prefix + "_" + std::to_string(i);
} while (name_to_symbol_.Contains(name));
- auto sym = RegisterInternal(name);
- if (last_prefix) {
- *last_prefix = i;
- } else {
- last_prefix_to_index_.Add(prefix, i);
+ auto view = Allocate(name);
+ auto id = name_to_symbol_.Add(view, next_symbol_++);
+ return Symbol{id.value, generation_id_, view};
+}
+
+std::string_view SymbolTable::Allocate(std::string_view name) {
+ static_assert(sizeof(char) == 1);
+ char* name_mem = Bitcast<char*>(name_allocator_.Allocate(name.length() + 1));
+ if (name_mem == nullptr) {
+ TINT_ICE() << "failed to allocate memory for symbol's string";
+ return {};
}
- return sym;
+
+ memcpy(name_mem, name.data(), name.length() + 1);
+ return {name_mem, name.length()};
}
} // namespace tint
diff --git a/src/tint/utils/symbol/symbol_table.h b/src/tint/utils/symbol/symbol_table.h
index 7c39512..0fefa67 100644
--- a/src/tint/utils/symbol/symbol_table.h
+++ b/src/tint/utils/symbol/symbol_table.h
@@ -91,8 +91,8 @@
/// signature: `void(Symbol)`
template <typename F>
void Foreach(F&& callback) const {
- for (auto it : name_to_symbol_) {
- callback(it.value);
+ for (auto& it : name_to_symbol_) {
+ callback(Symbol{it.value, generation_id_, it.key});
}
}
@@ -103,12 +103,12 @@
SymbolTable(const SymbolTable&) = delete;
SymbolTable& operator=(const SymbolTable& other) = delete;
- Symbol RegisterInternal(std::string_view name);
+ std::string_view Allocate(std::string_view name);
// The value to be associated to the next registered symbol table entry.
uint32_t next_symbol_ = 1;
- Hashmap<std::string_view, Symbol, 0> name_to_symbol_;
+ Hashmap<std::string_view, uint32_t, 0> name_to_symbol_;
Hashmap<std::string, size_t, 0> last_prefix_to_index_;
tint::GenerationID generation_id_;