tint/utils: Make Hashmap::Find() safer to use

Don't return a raw pointer to the map entry's value, instead return a new Reference which re-looks up the entry if the map is mutated.

Change-Id: I031749785faeac98e2a129a776493cb0371a5cb9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110540
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/clone_context.cc b/src/tint/clone_context.cc
index 457522b..fe94a14 100644
--- a/src/tint/clone_context.cc
+++ b/src/tint/clone_context.cc
@@ -73,7 +73,7 @@
     }
 
     // Was Replace() called for this object?
-    if (auto* fn = replacements_.Find(object)) {
+    if (auto fn = replacements_.Find(object)) {
         return (*fn)();
     }
 
diff --git a/src/tint/clone_context.h b/src/tint/clone_context.h
index 05f4868..f1db617 100644
--- a/src/tint/clone_context.h
+++ b/src/tint/clone_context.h
@@ -208,7 +208,7 @@
                 to.Push(CheckedCast<T>(builder()));
             }
             for (auto& el : from) {
-                if (auto* insert_before = transforms->insert_before_.Find(el)) {
+                if (auto insert_before = transforms->insert_before_.Find(el)) {
                     for (auto& builder : *insert_before) {
                         to.Push(CheckedCast<T>(builder()));
                     }
@@ -216,7 +216,7 @@
                 if (!transforms->remove_.Contains(el)) {
                     to.Push(Clone(el));
                 }
-                if (auto* insert_after = transforms->insert_after_.Find(el)) {
+                if (auto insert_after = transforms->insert_after_.Find(el)) {
                     for (auto& builder : *insert_after) {
                         to.Push(CheckedCast<T>(builder()));
                     }
@@ -232,7 +232,7 @@
                 // Clone(el) may have updated the transformation list, adding an `insert_after`
                 // transform for `from`.
                 if (transforms) {
-                    if (auto* insert_after = transforms->insert_after_.Find(el)) {
+                    if (auto insert_after = transforms->insert_after_.Find(el)) {
                         for (auto& builder : *insert_after) {
                             to.Push(CheckedCast<T>(builder()));
                         }
@@ -389,7 +389,7 @@
             return *this;
         }
 
-        list_transforms_.Edit(&vector).remove_.Add(object);
+        list_transforms_.GetOrZero(&vector)->remove_.Add(object);
         return *this;
     }
 
@@ -411,7 +411,7 @@
     /// @returns this CloneContext so calls can be chained
     template <typename T, size_t N, typename BUILDER>
     CloneContext& InsertFront(const utils::Vector<T, N>& vector, BUILDER&& builder) {
-        list_transforms_.Edit(&vector).insert_front_.Push(std::forward<BUILDER>(builder));
+        list_transforms_.GetOrZero(&vector)->insert_front_.Push(std::forward<BUILDER>(builder));
         return *this;
     }
 
@@ -434,7 +434,7 @@
     /// @returns this CloneContext so calls can be chained
     template <typename T, size_t N, typename BUILDER>
     CloneContext& InsertBack(const utils::Vector<T, N>& vector, BUILDER&& builder) {
-        list_transforms_.Edit(&vector).insert_back_.Push(std::forward<BUILDER>(builder));
+        list_transforms_.GetOrZero(&vector)->insert_back_.Push(std::forward<BUILDER>(builder));
         return *this;
     }
 
@@ -456,7 +456,7 @@
             return *this;
         }
 
-        list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
+        list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
             [object] { return object; });
         return *this;
     }
@@ -475,7 +475,7 @@
     CloneContext& InsertBefore(const utils::Vector<T, N>& vector,
                                const BEFORE* before,
                                BUILDER&& builder) {
-        list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
+        list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
             std::forward<BUILDER>(builder));
         return *this;
     }
@@ -498,7 +498,7 @@
             return *this;
         }
 
-        list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
+        list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
             [object] { return object; });
         return *this;
     }
@@ -517,7 +517,7 @@
     CloneContext& InsertAfter(const utils::Vector<T, N>& vector,
                               const AFTER* after,
                               BUILDER&& builder) {
-        list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
+        list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
             std::forward<BUILDER>(builder));
         return *this;
     }
@@ -601,61 +601,6 @@
     /// @returns the diagnostic list of #dst
     diag::List& Diagnostics() const;
 
-    /// VectorListTransforms is a map of utils::Vector pointer to transforms for that list
-    struct VectorListTransforms {
-        using Map = utils::Hashmap<const void*, ListTransforms, 4>;
-
-        /// An accessor to the VectorListTransforms map.
-        /// Index caches the last map lookup, and will only re-search the map if the transform map
-        /// was modified since the last lookup.
-        struct Index {
-            /// @returns true if the map now holds a value for the index
-            operator bool() {
-                Update();
-                return cached_;
-            }
-
-            /// @returns a pointer to the indexed map entry
-            const ListTransforms* operator->() {
-                Update();
-                return cached_;
-            }
-
-          private:
-            friend VectorListTransforms;
-
-            Index(const void* list, Map* map)
-                : list_(list),
-                  map_(map),
-                  generation_(map->Generation()),
-                  cached_(map_->Find(list)) {}
-
-            void Update() {
-                if (map_->Generation() != generation_) {
-                    cached_ = map_->Find(list_);
-                    generation_ = map_->Generation();
-                }
-            }
-
-            const void* list_;
-            Map* map_;
-            uint64_t generation_;
-            const ListTransforms* cached_;
-        };
-
-        /// Edit returns a reference to the ListTransforms for the given vector pointer and
-        /// increments #list_transform_generation_ signalling that the list transforms have been
-        /// modified.
-        inline ListTransforms& Edit(const void* list) { return map_.GetOrZero(list); }
-
-        /// @returns an Index to the transforms for the given list.
-        inline Index Find(const void* list) { return Index{list, &map_}; }
-
-      private:
-        /// The map of vector pointer to ListTransforms
-        Map map_;
-    };
-
     /// A map of object in #src to functions that create their replacement in #dst
     utils::Hashmap<const Cloneable*, std::function<const Cloneable*()>, 8> replacements_;
 
@@ -666,7 +611,7 @@
     utils::Vector<CloneableTransform, 8> transforms_;
 
     /// Transformations to apply to vectors
-    VectorListTransforms list_transforms_;
+    utils::Hashmap<const void*, ListTransforms, 4> list_transforms_;
 
     /// Symbol transform registered with ReplaceAll()
     SymbolTransform symbol_transform_;
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index f97d169..accd692 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -3522,7 +3522,7 @@
             const auto phi_id = assignment.phi_id;
             auto* const lhs_expr = builder_.Expr(namer_.Name(phi_id));
             // If RHS value is actually a phi we just cpatured, then use it.
-            auto* const copy_sym = copied_phis.Find(assignment.value_id);
+            auto copy_sym = copied_phis.Find(assignment.value_id);
             auto* const rhs_expr =
                 copy_sym ? builder_.Expr(*copy_sym) : MakeExpression(assignment.value_id).expr;
             AddStatement(builder_.Assign(lhs_expr, rhs_expr));
diff --git a/src/tint/reader/spirv/parser_impl.h b/src/tint/reader/spirv/parser_impl.h
index 1780ca3..9ff9033 100644
--- a/src/tint/reader/spirv/parser_impl.h
+++ b/src/tint/reader/spirv/parser_impl.h
@@ -666,7 +666,7 @@
     /// @param id a SPIR-V ID
     /// @returns the AST variable or null.
     const ast::Var* GetModuleVariable(uint32_t id) {
-        auto* entry = module_variable_.Find(id);
+        auto entry = module_variable_.Find(id);
         return entry ? *entry : nullptr;
     }
 
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 1fda54f..796412b 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -471,7 +471,7 @@
             }
         }
 
-        if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
+        if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
             if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
                                       DependencyInfo{from->source, action})) {
                 current_global_->deps.Push(*global);
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index 2cc4a3a..81e79ff 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -1128,8 +1128,8 @@
 
     if (expect_pass) {
         // Check that the use resolves to the declaration
-        auto* resolved_symbol = graph.resolved_symbols.Find(use);
-        ASSERT_NE(resolved_symbol, nullptr);
+        auto resolved_symbol = graph.resolved_symbols.Find(use);
+        ASSERT_TRUE(resolved_symbol);
         EXPECT_EQ(*resolved_symbol, decl)
             << "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
             << "\n"
@@ -1179,8 +1179,8 @@
     helper.Build();
 
     auto shadows = Build().shadows;
-    auto* shadow = shadows.Find(inner_var);
-    ASSERT_NE(shadow, nullptr);
+    auto shadow = shadows.Find(inner_var);
+    ASSERT_TRUE(shadow);
     EXPECT_EQ(*shadow, outer);
 }
 
@@ -1310,8 +1310,8 @@
 
     auto graph = Build();
     for (auto use : symbol_uses) {
-        auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
-        ASSERT_NE(resolved_symbol, nullptr) << use.where;
+        auto resolved_symbol = graph.resolved_symbols.Find(use.use);
+        ASSERT_TRUE(resolved_symbol) << use.where;
         EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
     }
 }
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 05f7675..0fdd78c 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -2481,7 +2481,7 @@
                 if (loop_block->FirstContinue()) {
                     // If our identifier is in loop_block->decls, make sure its index is
                     // less than first_continue
-                    if (auto* decl = loop_block->Decls().Find(symbol)) {
+                    if (auto decl = loop_block->Decls().Find(symbol)) {
                         if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
                             AddError("continue statement bypasses declaration of '" +
                                          builder_->Symbols().NameFor(symbol) + "'",
diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h
index 12ef4a2..2e557d9 100644
--- a/src/tint/resolver/sem_helper.h
+++ b/src/tint/resolver/sem_helper.h
@@ -54,7 +54,7 @@
     /// @param node the node to retrieve
     template <typename SEM = sem::Node>
     SEM* ResolvedSymbol(const ast::Node* node) const {
-        auto* resolved = dependencies_.resolved_symbols.Find(node);
+        auto resolved = dependencies_.resolved_symbols.Find(node);
         return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
     }
 
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index afdbdfb..ffedaf6 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -600,7 +600,7 @@
                     }
 
                     // Add an edge from the variable's loop input node to its value at this point.
-                    auto** in_node = info.var_in_nodes.Find(var);
+                    auto in_node = info.var_in_nodes.Find(var);
                     TINT_ASSERT(Resolver, in_node != nullptr);
                     auto* out_node = current_function_->variables.Get(var);
                     if (out_node != *in_node) {
@@ -1334,7 +1334,7 @@
             [&](const sem::Function* func) {
                 // We must have already analyzed the user-defined function since we process
                 // functions in dependency order.
-                auto* info = functions_.Find(func->Declaration());
+                auto info = functions_.Find(func->Declaration());
                 TINT_ASSERT(Resolver, info != nullptr);
                 callsite_tag = info->callsite_tag;
                 function_tag = info->function_tag;
@@ -1466,7 +1466,7 @@
         } else if (auto* user = target->As<sem::Function>()) {
             // This is a call to a user-defined function, so inspect the functions called by that
             // function and look for one whose node has an edge from the RequiredToBeUniform node.
-            auto* target_info = functions_.Find(user->Declaration());
+            auto target_info = functions_.Find(user->Declaration());
             for (auto* call_node : target_info->required_to_be_uniform->edges) {
                 if (call_node->type == Node::kRegular) {
                     auto* child_call = call_node->ast->As<ast::CallExpression>();
@@ -1636,7 +1636,7 @@
             // If this is a call to a user-defined function, add a note to show the reason that the
             // parameter is required to be uniform.
             if (auto* user = target->As<sem::Function>()) {
-                auto* next_function = functions_.Find(user->Declaration());
+                auto next_function = functions_.Find(user->Declaration());
                 Node* next_cause = next_function->parameters[cause->arg_index].init_value;
                 MakeError(*next_function, next_cause, true);
             }
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index db93bbe..2fc8f21 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -719,7 +719,7 @@
             return false;
         }
     } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
-        if (auto* found = atomic_composite_info.Find(type)) {
+        if (auto found = atomic_composite_info.Find(type)) {
             if (address_space != ast::AddressSpace::kStorage &&
                 address_space != ast::AddressSpace::kWorkgroup) {
                 AddError("atomic variables must have <storage> or <workgroup> address space",
@@ -798,7 +798,7 @@
     for (auto* attr : decl->attributes) {
         if (attr->Is<ast::IdAttribute>()) {
             auto id = v->OverrideId();
-            if (auto* var = override_ids.Find(id); var && *var != v) {
+            if (auto var = override_ids.Find(id); var && *var != v) {
                 AddError("@id values must be unique", attr->source);
                 AddNote(
                     "a override with an ID of " + std::to_string(id.value) +
diff --git a/src/tint/scope_stack.h b/src/tint/scope_stack.h
index a2da4dd..75c50b4 100644
--- a/src/tint/scope_stack.h
+++ b/src/tint/scope_stack.h
@@ -44,7 +44,7 @@
     /// stack, otherwise the zero initializer for type T.
     V Set(const K& key, V val) {
         auto& back = stack_.Back();
-        if (auto* el = back.Find(key)) {
+        if (auto el = back.Find(key)) {
             std::swap(val, *el);
             return val;
         }
@@ -57,7 +57,7 @@
     /// @returns the value, or the zero initializer if the value was not found
     V Get(const K& key) const {
         for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
-            if (auto* val = iter->Find(key)) {
+            if (auto val = iter->Find(key)) {
                 return *val;
             }
         }
diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc
index 5494ca2..b7fd7c2 100644
--- a/src/tint/transform/decompose_strided_matrix.cc
+++ b/src/tint/transform/decompose_strided_matrix.cc
@@ -129,7 +129,7 @@
     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
     ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
         if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
-            if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+            if (auto info = decomposed.Find(access->Member()->Declaration())) {
                 auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
                     auto name =
                         b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
@@ -168,7 +168,7 @@
     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
     ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
         if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
-            if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+            if (auto info = decomposed.Find(access->Member()->Declaration())) {
                 auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
                     auto name =
                         b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc
index b2b99ed..a0855b7 100644
--- a/src/tint/transform/simplify_pointers.cc
+++ b/src/tint/transform/simplify_pointers.cc
@@ -140,7 +140,7 @@
         // variable identifier.
         ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
             // Look to see if we need to swap this Expression with a saved variable.
-            if (auto* saved_var = saved_vars.Find(expr)) {
+            if (auto saved_var = saved_vars.Find(expr)) {
                 return ctx.dst->Expr(*saved_var);
             }
 
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index 2116b0f..8b566fe 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -401,7 +401,7 @@
         return Switch(
             ty,  //
             [&](const sem::Struct* str) -> const ast::Type* {
-                if (auto* std140 = std140_structs.Find(str)) {
+                if (auto std140 = std140_structs.Find(str)) {
                     return b.create<ast::TypeName>(*std140);
                 }
                 return nullptr;
@@ -695,7 +695,7 @@
                     // call, or by reassembling a std140 matrix from column vector members.
                     utils::Vector<const ast::Expression*, 8> args;
                     for (auto* member : str->Members()) {
-                        if (auto* col_members = std140_mat_members.Find(member)) {
+                        if (auto col_members = std140_mat_members.Find(member)) {
                             // std140 decomposed matrix. Reassemble.
                             auto* mat_ty = CreateASTTypeFor(ctx, member->Type());
                             auto mat_args =
diff --git a/src/tint/transform/truncate_interstage_variables.cc b/src/tint/transform/truncate_interstage_variables.cc
index a5e7256..30237bc 100644
--- a/src/tint/transform/truncate_interstage_variables.cc
+++ b/src/tint/transform/truncate_interstage_variables.cc
@@ -161,7 +161,7 @@
     ctx.ReplaceAll(
         [&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* {
             auto* return_sem = sem.Get(return_statement);
-            if (auto* mapping_fn_sym =
+            if (auto mapping_fn_sym =
                     entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
                 return b.Return(return_statement->source,
                                 b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value)));
diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc
index 93ce595..8d2b876 100644
--- a/src/tint/transform/unshadow.cc
+++ b/src/tint/transform/unshadow.cc
@@ -97,7 +97,7 @@
         ctx.ReplaceAll(
             [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
                 if (auto* user = sem.Get<sem::VariableUser>(ident)) {
-                    if (auto* renamed = renamed_to.Find(user->Variable())) {
+                    if (auto renamed = renamed_to.Find(user->Variable())) {
                         return b.Expr(*renamed);
                     }
                 }
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index d4db655..ede1986 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -135,7 +135,7 @@
     /// automatically called.
     /// @warning the returned reference is invalid if this is called a second time, or the
     /// #for_loops map is mutated.
-    LoopInfo& ForLoop(const sem::ForLoopStatement* for_loop) {
+    auto ForLoop(const sem::ForLoopStatement* for_loop) {
         if (for_loops.IsEmpty()) {
             RegisterForLoopTransform();
         }
@@ -147,7 +147,7 @@
     /// automatically called.
     /// @warning the returned reference is invalid if this is called a second time, or the
     /// #for_loops map is mutated.
-    LoopInfo& WhileLoop(const sem::WhileStatement* while_loop) {
+    auto WhileLoop(const sem::WhileStatement* while_loop) {
         if (while_loops.IsEmpty()) {
             RegisterWhileLoopTransform();
         }
@@ -159,7 +159,7 @@
     /// automatically called.
     /// @warning the returned reference is invalid if this is called a second time, or the
     /// #else_ifs map is mutated.
-    ElseIfInfo& ElseIf(const ast::IfStatement* else_if) {
+    auto ElseIf(const ast::IfStatement* else_if) {
         if (else_ifs.IsEmpty()) {
             RegisterElseIfTransform();
         }
@@ -172,7 +172,7 @@
             auto& sem = ctx.src->Sem();
 
             if (auto* fl = sem.Get(stmt)) {
-                if (auto* info = for_loops.Find(fl)) {
+                if (auto info = for_loops.Find(fl)) {
                     auto* for_loop = fl->Declaration();
                     // For-loop needs to be decomposed to a loop.
                     // Build the loop body's statements.
@@ -222,7 +222,7 @@
             auto& sem = ctx.src->Sem();
 
             if (auto* w = sem.Get(stmt)) {
-                if (auto* info = while_loops.Find(w)) {
+                if (auto info = while_loops.Find(w)) {
                     auto* while_loop = w->Declaration();
                     // While needs to be decomposed to a loop.
                     // Build the loop body's statements.
@@ -259,7 +259,7 @@
     void RegisterElseIfTransform() const {
         // Decompose 'else-if' statements into 'else { if }' blocks.
         ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* {
-            if (auto* info = else_ifs.Find(stmt)) {
+            if (auto info = else_ifs.Find(stmt)) {
                 // Build the else block's body statements, starting with let decls for the
                 // conditional expression.
                 auto body_stmts = Build(info->cond_decls);
@@ -291,10 +291,10 @@
         if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
             // Insertion point is an 'else if' condition.
             // Need to convert 'else if' to 'else { if }'.
-            auto& else_if_info = ElseIf(else_if->Declaration());
+            auto else_if_info = ElseIf(else_if->Declaration());
 
             // Index the map to convert this else if, even if `stmt` is nullptr.
-            auto& decls = else_if_info.cond_decls;
+            auto& decls = else_if_info->cond_decls;
             if constexpr (!std::is_same_v<BUILDER, Decompose>) {
                 decls.Push(std::forward<BUILDER>(builder));
             }
@@ -306,7 +306,7 @@
             // For-loop needs to be decomposed to a loop.
 
             // Index the map to convert this for-loop, even if `stmt` is nullptr.
-            auto& decls = ForLoop(fl).cond_decls;
+            auto& decls = ForLoop(fl)->cond_decls;
             if constexpr (!std::is_same_v<BUILDER, Decompose>) {
                 decls.Push(std::forward<BUILDER>(builder));
             }
@@ -318,7 +318,7 @@
             // While needs to be decomposed to a loop.
 
             // Index the map to convert this while, even if `stmt` is nullptr.
-            auto& decls = WhileLoop(w).cond_decls;
+            auto& decls = WhileLoop(w)->cond_decls;
             if constexpr (!std::is_same_v<BUILDER, Decompose>) {
                 decls.Push(std::forward<BUILDER>(builder));
             }
@@ -354,7 +354,7 @@
                 // For-loop needs to be decomposed to a loop.
 
                 // Index the map to convert this for-loop, even if `stmt` is nullptr.
-                auto& decls = ForLoop(fl).cont_decls;
+                auto& decls = ForLoop(fl)->cont_decls;
                 if constexpr (!std::is_same_v<BUILDER, Decompose>) {
                     decls.Push(std::forward<BUILDER>(builder));
                 }
diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h
index 1040e0c..19a5abe 100644
--- a/src/tint/utils/hashmap.h
+++ b/src/tint/utils/hashmap.h
@@ -47,6 +47,67 @@
     /// Result of Add()
     using AddResult = typename Base::PutResult;
 
+    /// Reference is returned by Hashmap::Find(), and performs dynamic Hashmap lookups.
+    /// The value returned by the Reference reflects the current state of the Hashmap, and so the
+    /// referenced value may change, or transition between valid or invalid based on the current
+    /// state of the Hashmap.
+    template <bool IS_CONST>
+    class ReferenceT {
+        /// `const Value` if IS_CONST, or `Value` if !IS_CONST
+        using T = std::conditional_t<IS_CONST, const Value, Value>;
+
+        /// `const Hashmap` if IS_CONST, or `Hashmap` if !IS_CONST
+        using Map = std::conditional_t<IS_CONST, const Hashmap, Hashmap>;
+
+      public:
+        /// @returns true if the reference is valid.
+        operator bool() const { return Get() != nullptr; }
+
+        /// @returns the pointer to the Value, or nullptr if the reference is invalid.
+        operator T*() const { return Get(); }
+
+        /// @returns the pointer to the Value
+        /// @warning if the Hashmap does not contain a value for the reference, then this will
+        /// trigger a TINT_ASSERT, or invalid pointer dereference.
+        T* operator->() const {
+            auto* hashmap_reference_lookup = Get();
+            TINT_ASSERT(Utils, hashmap_reference_lookup != nullptr);
+            return hashmap_reference_lookup;
+        }
+
+        /// @returns the pointer to the Value, or nullptr if the reference is invalid.
+        T* Get() const {
+            auto generation = map_.Generation();
+            if (generation_ != generation) {
+                cached_ = map_.Lookup(key_);
+                generation_ = generation;
+            }
+            return cached_;
+        }
+
+      private:
+        friend Hashmap;
+
+        /// Constructor
+        ReferenceT(Map& map, const Key& key)
+            : map_(map), key_(key), cached_(nullptr), generation_(map.Generation() - 1) {}
+
+        /// Constructor
+        ReferenceT(Map& map, const Key& key, T* value)
+            : map_(map), key_(key), cached_(value), generation_(map.Generation()) {}
+
+        Map& map_;
+        const Key key_;
+        mutable T* cached_ = nullptr;
+        mutable size_t generation_ = 0;
+    };
+
+    /// A mutable reference returned by Find()
+    using Reference = ReferenceT</*IS_CONST*/ false>;
+
+    /// An immutable reference returned by Find()
+    using ConstReference = ReferenceT</*IS_CONST*/ true>;
+
     /// Adds a value to the map, if the map does not already contain an entry with the key @p key.
     /// @param key the entry key.
     /// @param value the value of the entry to add to the map.
@@ -108,25 +169,28 @@
     /// @param key the entry's key value to search for.
     /// @returns the value of the entry.
     template <typename K>
-    Value& GetOrZero(K&& key) {
+    Reference GetOrZero(K&& key) {
         auto res = Add(std::forward<K>(key), Value{});
-        return *res.value;
+        return Reference(*this, key, res.value);
     }
 
     /// @param key the key to search for.
-    /// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
-    ///          not contain the given value.
-    const Value* Find(const Key& key) const {
+    /// @returns a reference to the entry that is equal to the given value.
+    Reference Find(const Key& key) { return Reference(*this, key); }
+
+    /// @param key the key to search for.
+    /// @returns a reference to the entry that is equal to the given value.
+    ConstReference Find(const Key& key) const { return ConstReference(*this, key); }
+
+  private:
+    Value* Lookup(const Key& key) {
         if (auto [found, index] = this->IndexOf(key); found) {
             return &this->slots_[index].entry->value;
         }
         return nullptr;
     }
 
-    /// @param key the key to search for.
-    /// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
-    ///          not contain the given value.
-    Value* Find(const Key& key) {
+    const Value* Lookup(const Key& key) const {
         if (auto [found, index] = this->IndexOf(key); found) {
             return &this->slots_[index].entry->value;
         }
diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc
index 77421cf..34a93e6 100644
--- a/src/tint/utils/hashmap_test.cc
+++ b/src/tint/utils/hashmap_test.cc
@@ -90,6 +90,71 @@
     EXPECT_EQ(map.Generation(), 5u);
 }
 
+TEST(Hashmap, Index) {
+    Hashmap<int, std::string, 4> map;
+    auto zero = map.Find(0);
+    EXPECT_FALSE(zero);
+
+    map.Add(3, "three");
+    auto three = map.Find(3);
+    map.Add(2, "two");
+    auto two = map.Find(2);
+    map.Add(4, "four");
+    auto four = map.Find(4);
+    map.Add(8, "eight");
+    auto eight = map.Find(8);
+
+    EXPECT_FALSE(zero);
+    ASSERT_TRUE(three);
+    ASSERT_TRUE(two);
+    ASSERT_TRUE(four);
+    ASSERT_TRUE(eight);
+
+    EXPECT_EQ(*three, "three");
+    EXPECT_EQ(*two, "two");
+    EXPECT_EQ(*four, "four");
+    EXPECT_EQ(*eight, "eight");
+
+    map.Add(0, "zero");  // Note: Find called before Add() is okay!
+
+    map.Add(5, "five");
+    auto five = map.Find(5);
+    map.Add(6, "six");
+    auto six = map.Find(6);
+    map.Add(1, "one");
+    auto one = map.Find(1);
+    map.Add(7, "seven");
+    auto seven = map.Find(7);
+
+    ASSERT_TRUE(zero);
+    ASSERT_TRUE(three);
+    ASSERT_TRUE(two);
+    ASSERT_TRUE(four);
+    ASSERT_TRUE(eight);
+    ASSERT_TRUE(five);
+    ASSERT_TRUE(six);
+    ASSERT_TRUE(one);
+    ASSERT_TRUE(seven);
+
+    EXPECT_EQ(*zero, "zero");
+    EXPECT_EQ(*three, "three");
+    EXPECT_EQ(*two, "two");
+    EXPECT_EQ(*four, "four");
+    EXPECT_EQ(*eight, "eight");
+    EXPECT_EQ(*five, "five");
+    EXPECT_EQ(*six, "six");
+    EXPECT_EQ(*one, "one");
+    EXPECT_EQ(*seven, "seven");
+
+    map.Remove(2);
+    map.Remove(8);
+    map.Remove(1);
+
+    EXPECT_FALSE(two);
+    EXPECT_FALSE(eight);
+    EXPECT_FALSE(one);
+}
+
 TEST(Hashmap, Iterator) {
     using Map = Hashmap<int, std::string, 8>;
     using Entry = typename Map::Entry;