tint/utils: Make Hashmap::Find() safer to use
Don't return a raw pointer to the map entry's value, instead return a new Reference which re-looks up the entry if the map is mutated.
Change-Id: I031749785faeac98e2a129a776493cb0371a5cb9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110540
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/clone_context.cc b/src/tint/clone_context.cc
index 457522b..fe94a14 100644
--- a/src/tint/clone_context.cc
+++ b/src/tint/clone_context.cc
@@ -73,7 +73,7 @@
}
// Was Replace() called for this object?
- if (auto* fn = replacements_.Find(object)) {
+ if (auto fn = replacements_.Find(object)) {
return (*fn)();
}
diff --git a/src/tint/clone_context.h b/src/tint/clone_context.h
index 05f4868..f1db617 100644
--- a/src/tint/clone_context.h
+++ b/src/tint/clone_context.h
@@ -208,7 +208,7 @@
to.Push(CheckedCast<T>(builder()));
}
for (auto& el : from) {
- if (auto* insert_before = transforms->insert_before_.Find(el)) {
+ if (auto insert_before = transforms->insert_before_.Find(el)) {
for (auto& builder : *insert_before) {
to.Push(CheckedCast<T>(builder()));
}
@@ -216,7 +216,7 @@
if (!transforms->remove_.Contains(el)) {
to.Push(Clone(el));
}
- if (auto* insert_after = transforms->insert_after_.Find(el)) {
+ if (auto insert_after = transforms->insert_after_.Find(el)) {
for (auto& builder : *insert_after) {
to.Push(CheckedCast<T>(builder()));
}
@@ -232,7 +232,7 @@
// Clone(el) may have updated the transformation list, adding an `insert_after`
// transform for `from`.
if (transforms) {
- if (auto* insert_after = transforms->insert_after_.Find(el)) {
+ if (auto insert_after = transforms->insert_after_.Find(el)) {
for (auto& builder : *insert_after) {
to.Push(CheckedCast<T>(builder()));
}
@@ -389,7 +389,7 @@
return *this;
}
- list_transforms_.Edit(&vector).remove_.Add(object);
+ list_transforms_.GetOrZero(&vector)->remove_.Add(object);
return *this;
}
@@ -411,7 +411,7 @@
/// @returns this CloneContext so calls can be chained
template <typename T, size_t N, typename BUILDER>
CloneContext& InsertFront(const utils::Vector<T, N>& vector, BUILDER&& builder) {
- list_transforms_.Edit(&vector).insert_front_.Push(std::forward<BUILDER>(builder));
+ list_transforms_.GetOrZero(&vector)->insert_front_.Push(std::forward<BUILDER>(builder));
return *this;
}
@@ -434,7 +434,7 @@
/// @returns this CloneContext so calls can be chained
template <typename T, size_t N, typename BUILDER>
CloneContext& InsertBack(const utils::Vector<T, N>& vector, BUILDER&& builder) {
- list_transforms_.Edit(&vector).insert_back_.Push(std::forward<BUILDER>(builder));
+ list_transforms_.GetOrZero(&vector)->insert_back_.Push(std::forward<BUILDER>(builder));
return *this;
}
@@ -456,7 +456,7 @@
return *this;
}
- list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
+ list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
[object] { return object; });
return *this;
}
@@ -475,7 +475,7 @@
CloneContext& InsertBefore(const utils::Vector<T, N>& vector,
const BEFORE* before,
BUILDER&& builder) {
- list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
+ list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
std::forward<BUILDER>(builder));
return *this;
}
@@ -498,7 +498,7 @@
return *this;
}
- list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
+ list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
[object] { return object; });
return *this;
}
@@ -517,7 +517,7 @@
CloneContext& InsertAfter(const utils::Vector<T, N>& vector,
const AFTER* after,
BUILDER&& builder) {
- list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
+ list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
std::forward<BUILDER>(builder));
return *this;
}
@@ -601,61 +601,6 @@
/// @returns the diagnostic list of #dst
diag::List& Diagnostics() const;
- /// VectorListTransforms is a map of utils::Vector pointer to transforms for that list
- struct VectorListTransforms {
- using Map = utils::Hashmap<const void*, ListTransforms, 4>;
-
- /// An accessor to the VectorListTransforms map.
- /// Index caches the last map lookup, and will only re-search the map if the transform map
- /// was modified since the last lookup.
- struct Index {
- /// @returns true if the map now holds a value for the index
- operator bool() {
- Update();
- return cached_;
- }
-
- /// @returns a pointer to the indexed map entry
- const ListTransforms* operator->() {
- Update();
- return cached_;
- }
-
- private:
- friend VectorListTransforms;
-
- Index(const void* list, Map* map)
- : list_(list),
- map_(map),
- generation_(map->Generation()),
- cached_(map_->Find(list)) {}
-
- void Update() {
- if (map_->Generation() != generation_) {
- cached_ = map_->Find(list_);
- generation_ = map_->Generation();
- }
- }
-
- const void* list_;
- Map* map_;
- uint64_t generation_;
- const ListTransforms* cached_;
- };
-
- /// Edit returns a reference to the ListTransforms for the given vector pointer and
- /// increments #list_transform_generation_ signalling that the list transforms have been
- /// modified.
- inline ListTransforms& Edit(const void* list) { return map_.GetOrZero(list); }
-
- /// @returns an Index to the transforms for the given list.
- inline Index Find(const void* list) { return Index{list, &map_}; }
-
- private:
- /// The map of vector pointer to ListTransforms
- Map map_;
- };
-
/// A map of object in #src to functions that create their replacement in #dst
utils::Hashmap<const Cloneable*, std::function<const Cloneable*()>, 8> replacements_;
@@ -666,7 +611,7 @@
utils::Vector<CloneableTransform, 8> transforms_;
/// Transformations to apply to vectors
- VectorListTransforms list_transforms_;
+ utils::Hashmap<const void*, ListTransforms, 4> list_transforms_;
/// Symbol transform registered with ReplaceAll()
SymbolTransform symbol_transform_;
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index f97d169..accd692 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -3522,7 +3522,7 @@
const auto phi_id = assignment.phi_id;
auto* const lhs_expr = builder_.Expr(namer_.Name(phi_id));
// If RHS value is actually a phi we just cpatured, then use it.
- auto* const copy_sym = copied_phis.Find(assignment.value_id);
+ auto copy_sym = copied_phis.Find(assignment.value_id);
auto* const rhs_expr =
copy_sym ? builder_.Expr(*copy_sym) : MakeExpression(assignment.value_id).expr;
AddStatement(builder_.Assign(lhs_expr, rhs_expr));
diff --git a/src/tint/reader/spirv/parser_impl.h b/src/tint/reader/spirv/parser_impl.h
index 1780ca3..9ff9033 100644
--- a/src/tint/reader/spirv/parser_impl.h
+++ b/src/tint/reader/spirv/parser_impl.h
@@ -666,7 +666,7 @@
/// @param id a SPIR-V ID
/// @returns the AST variable or null.
const ast::Var* GetModuleVariable(uint32_t id) {
- auto* entry = module_variable_.Find(id);
+ auto entry = module_variable_.Find(id);
return entry ? *entry : nullptr;
}
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 1fda54f..796412b 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -471,7 +471,7 @@
}
}
- if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
+ if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
DependencyInfo{from->source, action})) {
current_global_->deps.Push(*global);
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index 2cc4a3a..81e79ff 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -1128,8 +1128,8 @@
if (expect_pass) {
// Check that the use resolves to the declaration
- auto* resolved_symbol = graph.resolved_symbols.Find(use);
- ASSERT_NE(resolved_symbol, nullptr);
+ auto resolved_symbol = graph.resolved_symbols.Find(use);
+ ASSERT_TRUE(resolved_symbol);
EXPECT_EQ(*resolved_symbol, decl)
<< "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
<< "\n"
@@ -1179,8 +1179,8 @@
helper.Build();
auto shadows = Build().shadows;
- auto* shadow = shadows.Find(inner_var);
- ASSERT_NE(shadow, nullptr);
+ auto shadow = shadows.Find(inner_var);
+ ASSERT_TRUE(shadow);
EXPECT_EQ(*shadow, outer);
}
@@ -1310,8 +1310,8 @@
auto graph = Build();
for (auto use : symbol_uses) {
- auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
- ASSERT_NE(resolved_symbol, nullptr) << use.where;
+ auto resolved_symbol = graph.resolved_symbols.Find(use.use);
+ ASSERT_TRUE(resolved_symbol) << use.where;
EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
}
}
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 05f7675..0fdd78c 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -2481,7 +2481,7 @@
if (loop_block->FirstContinue()) {
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
- if (auto* decl = loop_block->Decls().Find(symbol)) {
+ if (auto decl = loop_block->Decls().Find(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
builder_->Symbols().NameFor(symbol) + "'",
diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h
index 12ef4a2..2e557d9 100644
--- a/src/tint/resolver/sem_helper.h
+++ b/src/tint/resolver/sem_helper.h
@@ -54,7 +54,7 @@
/// @param node the node to retrieve
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
- auto* resolved = dependencies_.resolved_symbols.Find(node);
+ auto resolved = dependencies_.resolved_symbols.Find(node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
}
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index afdbdfb..ffedaf6 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -600,7 +600,7 @@
}
// Add an edge from the variable's loop input node to its value at this point.
- auto** in_node = info.var_in_nodes.Find(var);
+ auto in_node = info.var_in_nodes.Find(var);
TINT_ASSERT(Resolver, in_node != nullptr);
auto* out_node = current_function_->variables.Get(var);
if (out_node != *in_node) {
@@ -1334,7 +1334,7 @@
[&](const sem::Function* func) {
// We must have already analyzed the user-defined function since we process
// functions in dependency order.
- auto* info = functions_.Find(func->Declaration());
+ auto info = functions_.Find(func->Declaration());
TINT_ASSERT(Resolver, info != nullptr);
callsite_tag = info->callsite_tag;
function_tag = info->function_tag;
@@ -1466,7 +1466,7 @@
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
- auto* target_info = functions_.Find(user->Declaration());
+ auto target_info = functions_.Find(user->Declaration());
for (auto* call_node : target_info->required_to_be_uniform->edges) {
if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
@@ -1636,7 +1636,7 @@
// If this is a call to a user-defined function, add a note to show the reason that the
// parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) {
- auto* next_function = functions_.Find(user->Declaration());
+ auto next_function = functions_.Find(user->Declaration());
Node* next_cause = next_function->parameters[cause->arg_index].init_value;
MakeError(*next_function, next_cause, true);
}
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index db93bbe..2fc8f21 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -719,7 +719,7 @@
return false;
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
- if (auto* found = atomic_composite_info.Find(type)) {
+ if (auto found = atomic_composite_info.Find(type)) {
if (address_space != ast::AddressSpace::kStorage &&
address_space != ast::AddressSpace::kWorkgroup) {
AddError("atomic variables must have <storage> or <workgroup> address space",
@@ -798,7 +798,7 @@
for (auto* attr : decl->attributes) {
if (attr->Is<ast::IdAttribute>()) {
auto id = v->OverrideId();
- if (auto* var = override_ids.Find(id); var && *var != v) {
+ if (auto var = override_ids.Find(id); var && *var != v) {
AddError("@id values must be unique", attr->source);
AddNote(
"a override with an ID of " + std::to_string(id.value) +
diff --git a/src/tint/scope_stack.h b/src/tint/scope_stack.h
index a2da4dd..75c50b4 100644
--- a/src/tint/scope_stack.h
+++ b/src/tint/scope_stack.h
@@ -44,7 +44,7 @@
/// stack, otherwise the zero initializer for type T.
V Set(const K& key, V val) {
auto& back = stack_.Back();
- if (auto* el = back.Find(key)) {
+ if (auto el = back.Find(key)) {
std::swap(val, *el);
return val;
}
@@ -57,7 +57,7 @@
/// @returns the value, or the zero initializer if the value was not found
V Get(const K& key) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
- if (auto* val = iter->Find(key)) {
+ if (auto val = iter->Find(key)) {
return *val;
}
}
diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc
index 5494ca2..b7fd7c2 100644
--- a/src/tint/transform/decompose_strided_matrix.cc
+++ b/src/tint/transform/decompose_strided_matrix.cc
@@ -129,7 +129,7 @@
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
- if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+ if (auto info = decomposed.Find(access->Member()->Declaration())) {
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
auto name =
b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
@@ -168,7 +168,7 @@
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
- if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+ if (auto info = decomposed.Find(access->Member()->Declaration())) {
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
auto name =
b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc
index b2b99ed..a0855b7 100644
--- a/src/tint/transform/simplify_pointers.cc
+++ b/src/tint/transform/simplify_pointers.cc
@@ -140,7 +140,7 @@
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
// Look to see if we need to swap this Expression with a saved variable.
- if (auto* saved_var = saved_vars.Find(expr)) {
+ if (auto saved_var = saved_vars.Find(expr)) {
return ctx.dst->Expr(*saved_var);
}
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index 2116b0f..8b566fe 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -401,7 +401,7 @@
return Switch(
ty, //
[&](const sem::Struct* str) -> const ast::Type* {
- if (auto* std140 = std140_structs.Find(str)) {
+ if (auto std140 = std140_structs.Find(str)) {
return b.create<ast::TypeName>(*std140);
}
return nullptr;
@@ -695,7 +695,7 @@
// call, or by reassembling a std140 matrix from column vector members.
utils::Vector<const ast::Expression*, 8> args;
for (auto* member : str->Members()) {
- if (auto* col_members = std140_mat_members.Find(member)) {
+ if (auto col_members = std140_mat_members.Find(member)) {
// std140 decomposed matrix. Reassemble.
auto* mat_ty = CreateASTTypeFor(ctx, member->Type());
auto mat_args =
diff --git a/src/tint/transform/truncate_interstage_variables.cc b/src/tint/transform/truncate_interstage_variables.cc
index a5e7256..30237bc 100644
--- a/src/tint/transform/truncate_interstage_variables.cc
+++ b/src/tint/transform/truncate_interstage_variables.cc
@@ -161,7 +161,7 @@
ctx.ReplaceAll(
[&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* {
auto* return_sem = sem.Get(return_statement);
- if (auto* mapping_fn_sym =
+ if (auto mapping_fn_sym =
entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
return b.Return(return_statement->source,
b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value)));
diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc
index 93ce595..8d2b876 100644
--- a/src/tint/transform/unshadow.cc
+++ b/src/tint/transform/unshadow.cc
@@ -97,7 +97,7 @@
ctx.ReplaceAll(
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
- if (auto* renamed = renamed_to.Find(user->Variable())) {
+ if (auto renamed = renamed_to.Find(user->Variable())) {
return b.Expr(*renamed);
}
}
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index d4db655..ede1986 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -135,7 +135,7 @@
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #for_loops map is mutated.
- LoopInfo& ForLoop(const sem::ForLoopStatement* for_loop) {
+ auto ForLoop(const sem::ForLoopStatement* for_loop) {
if (for_loops.IsEmpty()) {
RegisterForLoopTransform();
}
@@ -147,7 +147,7 @@
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #for_loops map is mutated.
- LoopInfo& WhileLoop(const sem::WhileStatement* while_loop) {
+ auto WhileLoop(const sem::WhileStatement* while_loop) {
if (while_loops.IsEmpty()) {
RegisterWhileLoopTransform();
}
@@ -159,7 +159,7 @@
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #else_ifs map is mutated.
- ElseIfInfo& ElseIf(const ast::IfStatement* else_if) {
+ auto ElseIf(const ast::IfStatement* else_if) {
if (else_ifs.IsEmpty()) {
RegisterElseIfTransform();
}
@@ -172,7 +172,7 @@
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
- if (auto* info = for_loops.Find(fl)) {
+ if (auto info = for_loops.Find(fl)) {
auto* for_loop = fl->Declaration();
// For-loop needs to be decomposed to a loop.
// Build the loop body's statements.
@@ -222,7 +222,7 @@
auto& sem = ctx.src->Sem();
if (auto* w = sem.Get(stmt)) {
- if (auto* info = while_loops.Find(w)) {
+ if (auto info = while_loops.Find(w)) {
auto* while_loop = w->Declaration();
// While needs to be decomposed to a loop.
// Build the loop body's statements.
@@ -259,7 +259,7 @@
void RegisterElseIfTransform() const {
// Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* {
- if (auto* info = else_ifs.Find(stmt)) {
+ if (auto info = else_ifs.Find(stmt)) {
// Build the else block's body statements, starting with let decls for the
// conditional expression.
auto body_stmts = Build(info->cond_decls);
@@ -291,10 +291,10 @@
if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
// Insertion point is an 'else if' condition.
// Need to convert 'else if' to 'else { if }'.
- auto& else_if_info = ElseIf(else_if->Declaration());
+ auto else_if_info = ElseIf(else_if->Declaration());
// Index the map to convert this else if, even if `stmt` is nullptr.
- auto& decls = else_if_info.cond_decls;
+ auto& decls = else_if_info->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@@ -306,7 +306,7 @@
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
- auto& decls = ForLoop(fl).cond_decls;
+ auto& decls = ForLoop(fl)->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@@ -318,7 +318,7 @@
// While needs to be decomposed to a loop.
// Index the map to convert this while, even if `stmt` is nullptr.
- auto& decls = WhileLoop(w).cond_decls;
+ auto& decls = WhileLoop(w)->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@@ -354,7 +354,7 @@
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
- auto& decls = ForLoop(fl).cont_decls;
+ auto& decls = ForLoop(fl)->cont_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h
index 1040e0c..19a5abe 100644
--- a/src/tint/utils/hashmap.h
+++ b/src/tint/utils/hashmap.h
@@ -47,6 +47,67 @@
/// Result of Add()
using AddResult = typename Base::PutResult;
+ /// Reference is returned by Hashmap::Find(), and performs dynamic Hashmap lookups.
+ /// The value returned by the Reference reflects the current state of the Hashmap, and so the
+ /// referenced value may change, or transition between valid or invalid based on the current
+ /// state of the Hashmap.
+ template <bool IS_CONST>
+ class ReferenceT {
+ /// `const Value` if IS_CONST, or `Value` if !IS_CONST
+ using T = std::conditional_t<IS_CONST, const Value, Value>;
+
+ /// `const Hashmap` if IS_CONST, or `Hashmap` if !IS_CONST
+ using Map = std::conditional_t<IS_CONST, const Hashmap, Hashmap>;
+
+ public:
+ /// @returns true if the reference is valid.
+ operator bool() const { return Get() != nullptr; }
+
+ /// @returns the pointer to the Value, or nullptr if the reference is invalid.
+ operator T*() const { return Get(); }
+
+ /// @returns the pointer to the Value
+ /// @warning if the Hashmap does not contain a value for the reference, then this will
+ /// trigger a TINT_ASSERT, or invalid pointer dereference.
+ T* operator->() const {
+ auto* hashmap_reference_lookup = Get();
+ TINT_ASSERT(Utils, hashmap_reference_lookup != nullptr);
+ return hashmap_reference_lookup;
+ }
+
+ /// @returns the pointer to the Value, or nullptr if the reference is invalid.
+ T* Get() const {
+ auto generation = map_.Generation();
+ if (generation_ != generation) {
+ cached_ = map_.Lookup(key_);
+ generation_ = generation;
+ }
+ return cached_;
+ }
+
+ private:
+ friend Hashmap;
+
+ /// Constructor
+ ReferenceT(Map& map, const Key& key)
+ : map_(map), key_(key), cached_(nullptr), generation_(map.Generation() - 1) {}
+
+ /// Constructor
+ ReferenceT(Map& map, const Key& key, T* value)
+ : map_(map), key_(key), cached_(value), generation_(map.Generation()) {}
+
+ Map& map_;
+ const Key key_;
+ mutable T* cached_ = nullptr;
+ mutable size_t generation_ = 0;
+ };
+
+ /// A mutable reference returned by Find()
+ using Reference = ReferenceT</*IS_CONST*/ false>;
+
+ /// An immutable reference returned by Find()
+ using ConstReference = ReferenceT</*IS_CONST*/ true>;
+
/// Adds a value to the map, if the map does not already contain an entry with the key @p key.
/// @param key the entry key.
/// @param value the value of the entry to add to the map.
@@ -108,25 +169,28 @@
/// @param key the entry's key value to search for.
/// @returns the value of the entry.
template <typename K>
- Value& GetOrZero(K&& key) {
+ Reference GetOrZero(K&& key) {
auto res = Add(std::forward<K>(key), Value{});
- return *res.value;
+ return Reference(*this, key, res.value);
}
/// @param key the key to search for.
- /// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
- /// not contain the given value.
- const Value* Find(const Key& key) const {
+ /// @returns a reference to the entry that is equal to the given value.
+ Reference Find(const Key& key) { return Reference(*this, key); }
+
+ /// @param key the key to search for.
+ /// @returns a reference to the entry that is equal to the given value.
+ ConstReference Find(const Key& key) const { return ConstReference(*this, key); }
+
+ private:
+ Value* Lookup(const Key& key) {
if (auto [found, index] = this->IndexOf(key); found) {
return &this->slots_[index].entry->value;
}
return nullptr;
}
- /// @param key the key to search for.
- /// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
- /// not contain the given value.
- Value* Find(const Key& key) {
+ const Value* Lookup(const Key& key) const {
if (auto [found, index] = this->IndexOf(key); found) {
return &this->slots_[index].entry->value;
}
diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc
index 77421cf..34a93e6 100644
--- a/src/tint/utils/hashmap_test.cc
+++ b/src/tint/utils/hashmap_test.cc
@@ -90,6 +90,71 @@
EXPECT_EQ(map.Generation(), 5u);
}
+TEST(Hashmap, Index) {
+ Hashmap<int, std::string, 4> map;
+ auto zero = map.Find(0);
+ EXPECT_FALSE(zero);
+
+ map.Add(3, "three");
+ auto three = map.Find(3);
+ map.Add(2, "two");
+ auto two = map.Find(2);
+ map.Add(4, "four");
+ auto four = map.Find(4);
+ map.Add(8, "eight");
+ auto eight = map.Find(8);
+
+ EXPECT_FALSE(zero);
+ ASSERT_TRUE(three);
+ ASSERT_TRUE(two);
+ ASSERT_TRUE(four);
+ ASSERT_TRUE(eight);
+
+ EXPECT_EQ(*three, "three");
+ EXPECT_EQ(*two, "two");
+ EXPECT_EQ(*four, "four");
+ EXPECT_EQ(*eight, "eight");
+
+ map.Add(0, "zero"); // Note: Find called before Add() is okay!
+
+ map.Add(5, "five");
+ auto five = map.Find(5);
+ map.Add(6, "six");
+ auto six = map.Find(6);
+ map.Add(1, "one");
+ auto one = map.Find(1);
+ map.Add(7, "seven");
+ auto seven = map.Find(7);
+
+ ASSERT_TRUE(zero);
+ ASSERT_TRUE(three);
+ ASSERT_TRUE(two);
+ ASSERT_TRUE(four);
+ ASSERT_TRUE(eight);
+ ASSERT_TRUE(five);
+ ASSERT_TRUE(six);
+ ASSERT_TRUE(one);
+ ASSERT_TRUE(seven);
+
+ EXPECT_EQ(*zero, "zero");
+ EXPECT_EQ(*three, "three");
+ EXPECT_EQ(*two, "two");
+ EXPECT_EQ(*four, "four");
+ EXPECT_EQ(*eight, "eight");
+ EXPECT_EQ(*five, "five");
+ EXPECT_EQ(*six, "six");
+ EXPECT_EQ(*one, "one");
+ EXPECT_EQ(*seven, "seven");
+
+ map.Remove(2);
+ map.Remove(8);
+ map.Remove(1);
+
+ EXPECT_FALSE(two);
+ EXPECT_FALSE(eight);
+ EXPECT_FALSE(one);
+}
+
TEST(Hashmap, Iterator) {
using Map = Hashmap<int, std::string, 8>;
using Entry = typename Map::Entry;