// Copyright 2021 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
//    contributors may be used to endorse or promote products derived from
//    this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "src/tint/lang/wgsl/resolver/dependency_graph.h"

#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "src/tint/lang/core/builtin_type.h"
#include "src/tint/lang/core/builtin_value.h"
#include "src/tint/lang/wgsl/ast/alias.h"
#include "src/tint/lang/wgsl/ast/assignment_statement.h"
#include "src/tint/lang/wgsl/ast/block_statement.h"
#include "src/tint/lang/wgsl/ast/break_if_statement.h"
#include "src/tint/lang/wgsl/ast/break_statement.h"
#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/lang/wgsl/ast/compound_assignment_statement.h"
#include "src/tint/lang/wgsl/ast/const.h"
#include "src/tint/lang/wgsl/ast/continue_statement.h"
#include "src/tint/lang/wgsl/ast/diagnostic_attribute.h"
#include "src/tint/lang/wgsl/ast/discard_statement.h"
#include "src/tint/lang/wgsl/ast/for_loop_statement.h"
#include "src/tint/lang/wgsl/ast/id_attribute.h"
#include "src/tint/lang/wgsl/ast/identifier.h"
#include "src/tint/lang/wgsl/ast/if_statement.h"
#include "src/tint/lang/wgsl/ast/increment_decrement_statement.h"
#include "src/tint/lang/wgsl/ast/index_attribute.h"
#include "src/tint/lang/wgsl/ast/internal_attribute.h"
#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
#include "src/tint/lang/wgsl/ast/let.h"
#include "src/tint/lang/wgsl/ast/location_attribute.h"
#include "src/tint/lang/wgsl/ast/loop_statement.h"
#include "src/tint/lang/wgsl/ast/must_use_attribute.h"
#include "src/tint/lang/wgsl/ast/override.h"
#include "src/tint/lang/wgsl/ast/return_statement.h"
#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "src/tint/lang/wgsl/ast/stride_attribute.h"
#include "src/tint/lang/wgsl/ast/struct.h"
#include "src/tint/lang/wgsl/ast/struct_member_align_attribute.h"
#include "src/tint/lang/wgsl/ast/struct_member_offset_attribute.h"
#include "src/tint/lang/wgsl/ast/struct_member_size_attribute.h"
#include "src/tint/lang/wgsl/ast/switch_statement.h"
#include "src/tint/lang/wgsl/ast/templated_identifier.h"
#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
#include "src/tint/lang/wgsl/ast/var.h"
#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/lang/wgsl/ast/while_statement.h"
#include "src/tint/lang/wgsl/ast/workgroup_attribute.h"
#include "src/tint/lang/wgsl/sem/builtin_fn.h"
#include "src/tint/utils/containers/map.h"
#include "src/tint/utils/containers/scope_stack.h"
#include "src/tint/utils/containers/unique_vector.h"
#include "src/tint/utils/macros/compiler.h"
#include "src/tint/utils/macros/defer.h"
#include "src/tint/utils/macros/scoped_assignment.h"
#include "src/tint/utils/memory/block_allocator.h"
#include "src/tint/utils/rtti/switch.h"
#include "src/tint/utils/text/string.h"
#include "src/tint/utils/text/string_stream.h"

#define TINT_DUMP_DEPENDENCY_GRAPH 0

namespace tint::resolver {
namespace {

// Forward declaration
struct Global;

/// DependencyInfo holds additional information about a single dependency edge, i.e. how one
/// global depends on another global.
struct DependencyInfo {
    /// The source location of the identifier reference that forms the dependency.
    /// Used to point at the offending reference when reporting cyclic dependencies.
    Source source;
};

/// DependencyEdge describes the two Globals used to define a dependency
/// relationship: #from depends on #to.
/// Members default to nullptr so that a default-constructed edge never holds
/// indeterminate pointers.
struct DependencyEdge {
    /// The Global that depends on #to
    const Global* from = nullptr;
    /// The Global that is depended on by #from
    const Global* to = nullptr;
};

/// DependencyEdgeCmp implements the contracts of std::equal_to<DependencyEdge>
/// and std::hash<DependencyEdge>.
struct DependencyEdgeCmp {
    /// Equality operator
    bool operator()(const DependencyEdge& lhs, const DependencyEdge& rhs) const {
        return lhs.from == rhs.from && lhs.to == rhs.to;
    }
    /// Hashing operator
    inline std::size_t operator()(const DependencyEdge& d) const { return Hash(d.from, d.to); }
};

/// A map of DependencyEdge to DependencyInfo
using DependencyEdges =
    Hashmap<DependencyEdge, DependencyInfo, 64, DependencyEdgeCmp, DependencyEdgeCmp>;

/// Global describes a module-scope variable, type or function.
struct Global {
    explicit Global(const ast::Node* n) : node(n) {}

    /// The declaration ast::Node
    const ast::Node* node;
    /// A list of dependencies that this global depends on
    Vector<Global*, 8> deps;
};

/// A map of global name to Global
using GlobalMap = Hashmap<Symbol, Global*, 16>;

/// Appends a resolver error diagnostic to @p diagnostics.
/// @param diagnostics the list the error is appended to
/// @param msg the error message text
/// @param source the source location the error refers to
void AddError(diag::List& diagnostics, const std::string& msg, const Source& source) {
    constexpr auto kSystem = diag::System::Resolver;
    diagnostics.add_error(kSystem, msg, source);
}

/// Appends a resolver note diagnostic to @p diagnostics.
/// @param diagnostics the list the note is appended to
/// @param msg the note message text
/// @param source the source location the note refers to
void AddNote(diag::List& diagnostics, const std::string& msg, const Source& source) {
    constexpr auto kSystem = diag::System::Resolver;
    diagnostics.add_note(kSystem, msg, source);
}

/// DependencyScanner is used to traverse a module to build the list of
/// global-to-global dependencies. While scanning it also records, in the DependencyGraph,
/// what each identifier resolves to and which declarations shadow others.
class DependencyScanner {
  public:
    /// Constructor
    /// @param globals_by_name map of global symbol to Global pointer
    /// @param diagnostics diagnostic messages, appended with any errors found
    /// @param graph the dependency graph to populate with resolved symbols
    /// @param edges the map of globals-to-global dependency edges, which will
    /// be populated by calls to Scan()
    DependencyScanner(const GlobalMap& globals_by_name,
                      diag::List& diagnostics,
                      DependencyGraph& graph,
                      DependencyEdges& edges)
        : globals_(globals_by_name),
          diagnostics_(diagnostics),
          graph_(graph),
          dependency_edges_(edges) {
        // Register all the globals at global-scope before any scanning, so that a global can be
        // resolved regardless of declaration order.
        for (auto it : globals_by_name) {
            scope_stack_.Set(it.key, it.value->node);
        }
    }

    /// Walks the global declarations, resolving symbols, and determining the
    /// dependencies of each global.
    /// @param global the global declaration to scan
    void Scan(Global* global) {
        TINT_SCOPED_ASSIGNMENT(current_global_, global);
        Switch(
            global->node,
            [&](const ast::Struct* str) {
                Declare(str->name->symbol, str);
                for (auto* member : str->members) {
                    TraverseAttributes(member->attributes);
                    TraverseExpression(member->type);
                }
            },
            [&](const ast::Alias* alias) {
                Declare(alias->name->symbol, alias);
                TraverseExpression(alias->type);
            },
            [&](const ast::Function* func) {
                Declare(func->name->symbol, func);
                TraverseFunction(func);
            },
            [&](const ast::Variable* v) {
                Declare(v->name->symbol, v);
                TraverseVariable(v);
            },
            [&](const ast::DiagnosticDirective*) {
                // Diagnostic directives do not affect the dependency graph.
            },
            [&](const ast::Enable*) {
                // Enable directives do not affect the dependency graph.
            },
            [&](const ast::Requires*) {
                // Requires directives do not affect the dependency graph.
            },
            [&](const ast::ConstAssert* assertion) {
                TraverseExpression(assertion->condition);
            },  //
            TINT_ICE_ON_NO_MATCH);
    }

  private:
    /// Traverses the variable, performing symbol resolution.
    /// @param v the variable to traverse
    void TraverseVariable(const ast::Variable* v) {
        if (auto* var = v->As<ast::Var>()) {
            TraverseExpression(var->declared_address_space);
            TraverseExpression(var->declared_access);
        }
        TraverseExpression(v->type);
        TraverseAttributes(v->attributes);
        TraverseExpression(v->initializer);
    }

    /// Traverses the function, performing symbol resolution and determining global dependencies.
    /// @param func the function to traverse
    void TraverseFunction(const ast::Function* func) {
        TraverseAttributes(func->attributes);
        TraverseAttributes(func->return_type_attributes);
        // Perform symbol resolution on all the parameter types before registering
        // the parameters themselves. This allows the case of declaring a parameter
        // with the same identifier as its type.
        for (auto* param : func->params) {
            TraverseAttributes(param->attributes);
            TraverseExpression(param->type);
        }
        // Resolve the return type
        TraverseExpression(func->return_type);

        // Push the scope stack for the parameters and function body.
        scope_stack_.Push();
        TINT_DEFER(scope_stack_.Pop());

        for (auto* param : func->params) {
            if (auto* shadows = scope_stack_.Get(param->name->symbol)) {
                graph_.shadows.Add(param, shadows);
            }
            Declare(param->name->symbol, param);
        }
        if (func->body) {
            // The body's statements share the parameter scope, so a body-level declaration
            // with the same name as a parameter is reported as a redeclaration.
            TraverseStatements(func->body->statements);
        }
    }

    /// Traverses the statements, performing symbol resolution and determining
    /// global dependencies.
    /// @param stmts the statements to traverse, in order
    void TraverseStatements(VectorRef<const ast::Statement*> stmts) {
        for (auto* s : stmts) {
            TraverseStatement(s);
        }
    }

    /// Traverses the statement, performing symbol resolution and determining
    /// global dependencies.
    /// @param stmt the statement to traverse. May be nullptr, in which case this is a no-op.
    void TraverseStatement(const ast::Statement* stmt) {
        if (!stmt) {
            return;
        }
        Switch(
            stmt,  //
            [&](const ast::AssignmentStatement* a) {
                TraverseExpression(a->lhs);
                TraverseExpression(a->rhs);
            },
            [&](const ast::BlockStatement* b) {
                scope_stack_.Push();
                TINT_DEFER(scope_stack_.Pop());
                TraverseStatements(b->statements);
            },
            [&](const ast::BreakIfStatement* b) { TraverseExpression(b->condition); },
            [&](const ast::CallStatement* r) { TraverseExpression(r->expr); },
            [&](const ast::CompoundAssignmentStatement* a) {
                TraverseExpression(a->lhs);
                TraverseExpression(a->rhs);
            },
            [&](const ast::ForLoopStatement* l) {
                // The initializer's declarations are scoped to the whole for-loop.
                scope_stack_.Push();
                TINT_DEFER(scope_stack_.Pop());
                TraverseStatement(l->initializer);
                TraverseExpression(l->condition);
                TraverseStatement(l->continuing);
                TraverseStatement(l->body);
            },
            [&](const ast::IncrementDecrementStatement* i) { TraverseExpression(i->lhs); },
            [&](const ast::LoopStatement* l) {
                // The body's statements are traversed in the same scope as the continuing
                // block, so declarations in the body are visible to the continuing block.
                scope_stack_.Push();
                TINT_DEFER(scope_stack_.Pop());
                TraverseStatements(l->body->statements);
                TraverseStatement(l->continuing);
            },
            [&](const ast::IfStatement* i) {
                TraverseExpression(i->condition);
                TraverseStatement(i->body);
                if (i->else_statement) {
                    TraverseStatement(i->else_statement);
                }
            },
            [&](const ast::ReturnStatement* r) { TraverseExpression(r->value); },
            [&](const ast::SwitchStatement* s) {
                TraverseExpression(s->condition);
                for (auto* c : s->body) {
                    for (auto* sel : c->selectors) {
                        TraverseExpression(sel->expr);
                    }
                    TraverseStatement(c->body);
                }
            },
            [&](const ast::VariableDeclStatement* v) {
                if (auto* shadows = scope_stack_.Get(v->variable->name->symbol)) {
                    graph_.shadows.Add(v->variable, shadows);
                }
                // Declare only after traversing the variable, so that its type and initializer
                // resolve identifiers against the enclosing scopes (e.g. `let x = x;`).
                TraverseVariable(v->variable);
                Declare(v->variable->name->symbol, v->variable);
            },
            [&](const ast::WhileStatement* w) {
                scope_stack_.Push();
                TINT_DEFER(scope_stack_.Pop());
                TraverseExpression(w->condition);
                TraverseStatement(w->body);
            },
            [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
            [&](const ast::BreakStatement*) {},     //
            [&](const ast::ContinueStatement*) {},  //
            [&](const ast::DiscardStatement*) {},   //
            TINT_ICE_ON_NO_MATCH);
    }

    /// Adds the symbol definition to the current scope, raising an error if two
    /// symbols collide within the same scope.
    /// Re-declaring the very same node is not an error. This happens for globals, which are
    /// pre-registered by the constructor and then declared again by Scan().
    /// @param symbol the declared symbol
    /// @param node the declaring AST node
    void Declare(Symbol symbol, const ast::Node* node) {
        auto* old = scope_stack_.Set(symbol, node);
        if (old != nullptr && node != old) {
            auto name = symbol.Name();
            AddError(diagnostics_, "redeclaration of '" + name + "'", node->source);
            AddNote(diagnostics_, "'" + name + "' previously declared here", old->source);
        }
    }

    /// Traverses the expression @p root_expr, performing symbol resolution and determining global
    /// dependencies.
    /// Template arguments of templated identifiers (e.g. the `S` in `array<S, 4>`) are pushed
    /// onto the `pending` worklist. Call targets and bitcast target types are traversed
    /// explicitly. This ensures that identifiers in those positions are resolved and recorded as
    /// dependencies. Traversing an identifier more than once is harmless, as resolutions and
    /// edges are only added once.
    /// @param root_expr the expression to traverse. May be nullptr, in which case this is a no-op.
    void TraverseExpression(const ast::Expression* root_expr) {
        if (!root_expr) {
            return;
        }

        Vector<const ast::Expression*, 8> pending{root_expr};
        while (!pending.IsEmpty()) {
            auto* next = pending.Pop();
            bool ok = ast::TraverseExpressions(next, [&](const ast::Expression* expr) {
                Switch(
                    expr,
                    [&](const ast::IdentifierExpression* e) {
                        AddDependency(e->identifier, e->identifier->symbol);
                        if (auto* tmpl_ident = e->identifier->As<ast::TemplatedIdentifier>()) {
                            for (auto* arg : tmpl_ident->arguments) {
                                pending.Push(arg);
                            }
                        }
                    },
                    [&](const ast::CallExpression* call) { TraverseExpression(call->target); },
                    [&](const ast::BitcastExpression* cast) { TraverseExpression(cast->type); });
                return ast::TraverseAction::Descend;
            });
            if (!ok) {
                AddError(diagnostics_, "TraverseExpressions failed", next->source);
                return;
            }
        }
    }

    /// Traverses the attribute list, performing symbol resolution and
    /// determining global dependencies.
    /// @param attrs the attributes to traverse
    void TraverseAttributes(VectorRef<const ast::Attribute*> attrs) {
        for (auto* attr : attrs) {
            TraverseAttribute(attr);
        }
    }

    /// Traverses the attribute, performing symbol resolution and determining
    /// global dependencies.
    /// @param attr the attribute to traverse
    void TraverseAttribute(const ast::Attribute* attr) {
        Switch(
            attr,  //
            [&](const ast::BindingAttribute* binding) { TraverseExpression(binding->expr); },
            [&](const ast::BuiltinAttribute* builtin) { TraverseExpression(builtin->builtin); },
            [&](const ast::GroupAttribute* group) { TraverseExpression(group->expr); },
            [&](const ast::IdAttribute* id) { TraverseExpression(id->expr); },
            [&](const ast::IndexAttribute* index) { TraverseExpression(index->expr); },
            [&](const ast::InterpolateAttribute* interpolate) {
                TraverseExpression(interpolate->type);
                TraverseExpression(interpolate->sampling);
            },
            [&](const ast::LocationAttribute* loc) { TraverseExpression(loc->expr); },
            [&](const ast::StructMemberAlignAttribute* align) { TraverseExpression(align->expr); },
            [&](const ast::StructMemberSizeAttribute* size) { TraverseExpression(size->expr); },
            [&](const ast::WorkgroupAttribute* wg) {
                TraverseExpression(wg->x);
                TraverseExpression(wg->y);
                TraverseExpression(wg->z);
            },
            [&](const ast::InternalAttribute* i) {
                for (auto* dep : i->dependencies) {
                    TraverseExpression(dep);
                }
            },
            [&](Default) {
                // Attributes that carry no expressions needing symbol resolution.
                if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::DiagnosticAttribute,
                                   ast::InterpolateAttribute, ast::InvariantAttribute,
                                   ast::MustUseAttribute, ast::StageAttribute, ast::StrideAttribute,
                                   ast::StructMemberOffsetAttribute>()) {
                    TINT_ICE() << "unhandled attribute type: " << attr->TypeInfo().name;
                }
            });
    }

    /// The type of builtin that a symbol could represent.
    enum class BuiltinType {
        /// No builtin matched
        kNone = 0,
        /// Builtin function
        kFunction,
        /// Builtin
        kBuiltin,
        /// Builtin value
        kBuiltinValue,
        /// Address space
        kAddressSpace,
        /// Texel format
        kTexelFormat,
        /// Access
        kAccess,
        /// Interpolation Type
        kInterpolationType,
        /// Interpolation Sampling
        kInterpolationSampling,
    };

    /// BuiltinInfo stores information about the builtin that a symbol represents.
    struct BuiltinInfo {
        /// @returns the builtin value, which must currently hold the alternative `T`
        template <typename T>
        T Value() const {
            return std::get<T>(value);
        }

        /// The kind of builtin, which selects the active alternative of #value
        BuiltinType type = BuiltinType::kNone;
        /// The parsed builtin enumerator
        std::variant<std::monostate,
                     wgsl::BuiltinFn,
                     core::BuiltinType,
                     core::BuiltinValue,
                     core::AddressSpace,
                     core::TexelFormat,
                     core::Access,
                     core::InterpolationType,
                     core::InterpolationSampling>
            value = {};
    };

    /// Get the builtin info for a given symbol. Results are cached per symbol.
    /// Parsers are tried in a fixed order, and the first match wins.
    /// @param symbol the symbol
    /// @returns the builtin info, with type BuiltinType::kNone if the symbol is not a builtin
    BuiltinInfo GetBuiltinInfo(Symbol symbol) {
        return builtin_info_map_.GetOrCreate(symbol, [&] {
            if (auto builtin_fn = wgsl::ParseBuiltinFn(symbol.NameView());
                builtin_fn != wgsl::BuiltinFn::kNone) {
                return BuiltinInfo{BuiltinType::kFunction, builtin_fn};
            }
            if (auto builtin_ty = core::ParseBuiltinType(symbol.NameView());
                builtin_ty != core::BuiltinType::kUndefined) {
                return BuiltinInfo{BuiltinType::kBuiltin, builtin_ty};
            }
            if (auto builtin_val = core::ParseBuiltinValue(symbol.NameView());
                builtin_val != core::BuiltinValue::kUndefined) {
                return BuiltinInfo{BuiltinType::kBuiltinValue, builtin_val};
            }
            if (auto addr = core::ParseAddressSpace(symbol.NameView());
                addr != core::AddressSpace::kUndefined) {
                return BuiltinInfo{BuiltinType::kAddressSpace, addr};
            }
            if (auto fmt = core::ParseTexelFormat(symbol.NameView());
                fmt != core::TexelFormat::kUndefined) {
                return BuiltinInfo{BuiltinType::kTexelFormat, fmt};
            }
            if (auto access = core::ParseAccess(symbol.NameView());
                access != core::Access::kUndefined) {
                return BuiltinInfo{BuiltinType::kAccess, access};
            }
            if (auto i_type = core::ParseInterpolationType(symbol.NameView());
                i_type != core::InterpolationType::kUndefined) {
                return BuiltinInfo{BuiltinType::kInterpolationType, i_type};
            }
            if (auto i_smpl = core::ParseInterpolationSampling(symbol.NameView());
                i_smpl != core::InterpolationSampling::kUndefined) {
                return BuiltinInfo{BuiltinType::kInterpolationSampling, i_smpl};
            }
            return BuiltinInfo{};
        });
    }

    /// Resolves the identifier @p from (whose symbol is @p to) and records the resolution in the
    /// graph. Identifiers not found in scope are resolved as builtins, or recorded as
    /// unresolved. If the identifier resolves to a global, a dependency edge from the current
    /// global is added.
    /// @param from the identifier being resolved
    /// @param to the symbol of the identifier
    void AddDependency(const ast::Identifier* from, Symbol to) {
        auto* resolved = scope_stack_.Get(to);
        if (!resolved) {
            auto builtin_info = GetBuiltinInfo(to);
            switch (builtin_info.type) {
                case BuiltinType::kNone:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier::UnresolvedIdentifier{to.Name()});
                    break;
                case BuiltinType::kFunction:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier(builtin_info.Value<wgsl::BuiltinFn>()));
                    break;
                case BuiltinType::kBuiltin:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier(builtin_info.Value<core::BuiltinType>()));
                    break;
                case BuiltinType::kBuiltinValue:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier(builtin_info.Value<core::BuiltinValue>()));
                    break;
                case BuiltinType::kAddressSpace:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier(builtin_info.Value<core::AddressSpace>()));
                    break;
                case BuiltinType::kTexelFormat:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier(builtin_info.Value<core::TexelFormat>()));
                    break;
                case BuiltinType::kAccess:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier(builtin_info.Value<core::Access>()));
                    break;
                case BuiltinType::kInterpolationType:
                    graph_.resolved_identifiers.Add(
                        from, ResolvedIdentifier(builtin_info.Value<core::InterpolationType>()));
                    break;
                case BuiltinType::kInterpolationSampling:
                    graph_.resolved_identifiers.Add(
                        from,
                        ResolvedIdentifier(builtin_info.Value<core::InterpolationSampling>()));
                    break;
            }
            return;
        }

        // Only add an edge if the symbol resolved to the global itself, and not to a local
        // declaration that shadows it.
        if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
            if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
                                      DependencyInfo{from->source})) {
                current_global_->deps.Push(*global);
            }
        }

        graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved));
    }

    /// Map of global symbol to Global
    const GlobalMap& globals_;
    /// Diagnostic list that errors are appended to
    diag::List& diagnostics_;
    /// The dependency graph being populated
    DependencyGraph& graph_;
    /// The global-to-global dependency edges being populated
    DependencyEdges& dependency_edges_;

    /// The lexical scopes, mapping symbols to their declaring nodes
    ScopeStack<Symbol, const ast::Node*> scope_stack_;
    /// The global currently being scanned by Scan()
    Global* current_global_ = nullptr;

    /// Cache of symbol to builtin info, populated by GetBuiltinInfo()
    Hashmap<Symbol, BuiltinInfo, 64> builtin_info_map_;
};

/// The global dependency analysis system
struct DependencyAnalysis {
  public:
    /// Constructor
    DependencyAnalysis(diag::List& diagnostics, DependencyGraph& graph)
        : diagnostics_(diagnostics), graph_(graph) {}

    /// Performs global dependency analysis on the module, emitting any errors to
    /// #diagnostics.
    /// @returns true if analysis found no errors, otherwise false.
    bool Run(const ast::Module& module) {
        // Reserve container memory
        graph_.resolved_identifiers.Reserve(module.GlobalDeclarations().Length());
        sorted_.Reserve(module.GlobalDeclarations().Length());

        // Collect all the named globals from the AST module
        GatherGlobals(module);

        // Traverse the named globals to build the dependency graph
        DetermineDependencies();

        // Sort the globals into dependency order
        SortGlobals();

        // Dump the dependency graph if TINT_DUMP_DEPENDENCY_GRAPH is non-zero
        DumpDependencyGraph();

        graph_.ordered_globals = sorted_.Release();

        return !diagnostics_.contains_errors();
    }

  private:
    /// @param node the ast::Node of the global declaration
    /// @returns the symbol of the global declaration node
    /// @note will raise an ICE if the node is not a type, function or variable
    /// declaration
    Symbol SymbolOf(const ast::Node* node) const {
        return Switch(
            node,  //
            [&](const ast::TypeDecl* td) { return td->name->symbol; },
            [&](const ast::Function* func) { return func->name->symbol; },
            [&](const ast::Variable* var) { return var->name->symbol; },
            [&](const ast::DiagnosticDirective*) { return Symbol(); },
            [&](const ast::Enable*) { return Symbol(); },
            [&](const ast::Requires*) { return Symbol(); },
            [&](const ast::ConstAssert*) { return Symbol(); },  //
            TINT_ICE_ON_NO_MATCH);
    }

    /// @param node the ast::Node of the global declaration
    /// @returns the name of the global declaration node
    /// @note will raise an ICE if the node is not a type, function or variable
    /// declaration
    std::string NameOf(const ast::Node* node) const { return SymbolOf(node).Name(); }

    /// @param node the ast::Node of the global declaration
    /// @returns a string representation of the global declaration kind
    /// @note will raise an ICE if the node is not a type, function or variable
    /// declaration
    std::string KindOf(const ast::Node* node) {
        return Switch(
            node,                                                     //
            [&](const ast::Struct*) { return "struct"; },             //
            [&](const ast::Alias*) { return "alias"; },               //
            [&](const ast::Function*) { return "function"; },         //
            [&](const ast::Variable* v) { return v->Kind(); },        //
            [&](const ast::ConstAssert*) { return "const_assert"; },  //
            TINT_ICE_ON_NO_MATCH);
    }

    /// Traverses `module`, collecting all the global declarations and populating
    /// the #globals and #declaration_order fields.
    void GatherGlobals(const ast::Module& module) {
        for (auto* node : module.GlobalDeclarations()) {
            auto* global = allocator_.Create(node);
            if (auto symbol = SymbolOf(node); symbol.IsValid()) {
                globals_.Add(symbol, global);
            }
            declaration_order_.Push(global);
        }
    }

    /// Walks the global declarations, determining the dependencies of each global
    /// and adding these to each global's Global::deps field.
    void DetermineDependencies() {
        DependencyScanner scanner(globals_, diagnostics_, graph_, dependency_edges_);
        for (auto* global : declaration_order_) {
            scanner.Scan(global);
        }
    }

    /// Performs a depth-first traversal of `root`'s dependencies, calling `enter`
    /// as the function decends into each dependency and `exit` when bubbling back
    /// up towards the root.
    /// @param enter is a function with the signature: `bool(Global*)`. The
    /// `enter` function returns true if TraverseDependencies() should traverse
    /// the dependency, otherwise it will be skipped.
    /// @param exit is a function with the signature: `void(Global*)`. The `exit`
    /// function is only called if the corresponding `enter` call returned true.
    template <typename ENTER, typename EXIT>
    void TraverseDependencies(const Global* root, ENTER&& enter, EXIT&& exit) {
        // Entry is a single entry in the traversal stack. Entry points to a
        // dep_idx'th dependency of Entry::global.
        struct Entry {
            const Global* global;  // The parent global
            size_t dep_idx;        // The dependency index in `global->deps`
        };

        if (!enter(root)) {
            return;
        }

        Vector<Entry, 16> stack{Entry{root, 0}};
        while (true) {
            auto& entry = stack.Back();
            // Have we exhausted the dependencies of entry.global?
            if (entry.dep_idx < entry.global->deps.Length()) {
                // No, there's more dependencies to traverse.
                auto& dep = entry.global->deps[entry.dep_idx];
                // Does the caller want to enter this dependency?
                if (enter(dep)) {               // Yes.
                    stack.Push(Entry{dep, 0});  // Enter the dependency.
                } else {
                    entry.dep_idx++;  // No. Skip this node.
                }
            } else {
                // Yes. Time to back up.
                // Exit this global, pop the stack, and if there's another parent node,
                // increment its dependency index, and loop again.
                exit(entry.global);
                stack.Pop();
                if (stack.IsEmpty()) {
                    return;  // All done.
                }
                stack.Back().dep_idx++;
            }
        }
    }

    /// SortGlobals sorts the globals into dependency order, erroring if cyclic
    /// dependencies are found. The sorted dependencies are assigned to #sorted.
    void SortGlobals() {
        if (diagnostics_.contains_errors()) {
            return;  // This code assumes there are no undeclared identifiers.
        }

        // Make sure all directives go before any other global declarations.
        for (auto* global : declaration_order_) {
            if (global->node->IsAnyOf<ast::DiagnosticDirective, ast::Enable, ast::Requires>()) {
                sorted_.Add(global->node);
            }
        }

        for (auto* global : declaration_order_) {
            if (global->node->IsAnyOf<ast::DiagnosticDirective, ast::Enable, ast::Requires>()) {
                // Skip directives here, as they are already added.
                continue;
            }
            UniqueVector<const Global*, 8> stack;
            TraverseDependencies(
                global,
                [&](const Global* g) {  // Enter
                    if (!stack.Add(g)) {
                        CyclicDependencyFound(g, stack.Release());
                        return false;
                    }
                    if (sorted_.Contains(g->node)) {
                        // Visited this global already.
                        // stack was pushed, but exit() will not be called when we return
                        // false, so pop here.
                        stack.Pop();
                        return false;
                    }
                    return true;
                },
                [&](const Global* g) {  // Exit. Only called if Enter returned true.
                    sorted_.Add(g->node);
                    stack.Pop();
                });

            sorted_.Add(global->node);

            if (TINT_UNLIKELY(!stack.IsEmpty())) {
                // Each stack.push() must have a corresponding stack.pop_back().
                TINT_ICE() << "stack not empty after returning from TraverseDependencies()";
            }
        }
    }

    /// DepInfoFor() looks up the global dependency information for the dependency
    /// of global `from` depending on `to`.
    /// @note will raise an ICE if the edge is not found.
    DependencyInfo DepInfoFor(const Global* from, const Global* to) const {
        auto info = dependency_edges_.Find(DependencyEdge{from, to});
        if (TINT_LIKELY(info)) {
            return *info;
        }
        TINT_ICE() << "failed to find dependency info for edge: '" << NameOf(from->node) << "' -> '"
                   << NameOf(to->node) << "'";
        return {};
    }

    /// CyclicDependencyFound() emits an error diagnostic for a cyclic dependency.
    /// @param root is the global that starts the cyclic dependency, which must be
    /// found in `stack`.
    /// @param stack is the global dependency stack that contains a loop.
    void CyclicDependencyFound(const Global* root, VectorRef<const Global*> stack) {
        StringStream msg;
        msg << "cyclic dependency found: ";
        constexpr size_t kLoopNotStarted = ~0u;
        size_t loop_start = kLoopNotStarted;
        for (size_t i = 0; i < stack.Length(); i++) {
            auto* e = stack[i];
            if (loop_start == kLoopNotStarted && e == root) {
                loop_start = i;
            }
            if (loop_start != kLoopNotStarted) {
                msg << "'" << NameOf(e->node) << "' -> ";
            }
        }
        msg << "'" << NameOf(root->node) << "'";
        AddError(diagnostics_, msg.str(), root->node->source);
        for (size_t i = loop_start; i < stack.Length(); i++) {
            auto* from = stack[i];
            auto* to = (i + 1 < stack.Length()) ? stack[i + 1] : stack[loop_start];
            auto info = DepInfoFor(from, to);
            AddNote(diagnostics_,
                    KindOf(from->node) + " '" + NameOf(from->node) + "' references " +
                        KindOf(to->node) + " '" + NameOf(to->node) + "' here",
                    info.source);
        }
    }

    void DumpDependencyGraph() {
#if TINT_DUMP_DEPENDENCY_GRAPH == 0
        if ((true)) {
            return;
        }
#endif  // TINT_DUMP_DEPENDENCY_GRAPH
        printf("=========================\n");
        printf("------ declaration ------ \n");
        for (auto* global : declaration_order_) {
            printf("%s\n", NameOf(global->node).c_str());
        }
        printf("------ dependencies ------ \n");
        for (auto* node : sorted_) {
            auto symbol = SymbolOf(node);
            auto* global = *globals_.Find(symbol);
            printf("%s depends on:\n", symbol.Name().c_str());
            for (auto* dep : global->deps) {
                printf("  %s\n", NameOf(dep->node).c_str());
            }
        }
        printf("=========================\n");
    }

    /// Program diagnostics. Errors and notes (e.g. cyclic dependency reports) are
    /// appended to this list.
    diag::List& diagnostics_;

    /// The resulting dependency graph. Presumably this is the `output` argument passed
    /// to DependencyGraph::Build(); confirm against the constructor.
    DependencyGraph& graph_;

    /// Allocator of Globals
    BlockAllocator<Global> allocator_;

    /// Global map, keyed by name. Populated by GatherGlobals().
    GlobalMap globals_;

    /// Map of DependencyEdge to DependencyInfo. Populated by DetermineDependencies().
    /// Queried by DepInfoFor().
    DependencyEdges dependency_edges_;

    /// Globals in declaration order. Populated by GatherGlobals().
    Vector<Global*, 64> declaration_order_;

    /// Globals in sorted dependency order, stored as their AST nodes.
    /// Populated by SortGlobals().
    UniqueVector<const ast::Node*, 64> sorted_;
};

}  // namespace

// DependencyGraph's constructor, move constructor and destructor are defaulted
// out of line in this file. NOTE(review): presumably this is so that member types
// only need to be complete here rather than in the header; confirm against
// dependency_graph.h.
DependencyGraph::DependencyGraph() = default;
DependencyGraph::DependencyGraph(DependencyGraph&&) = default;
DependencyGraph::~DependencyGraph() = default;

bool DependencyGraph::Build(const ast::Module& module,
                            diag::List& diagnostics,
                            DependencyGraph& output) {
    // A single-use analysis object: it fills in `output`, reports any problems to
    // `diagnostics`, and returns the result of DependencyAnalysis::Run().
    return DependencyAnalysis{diagnostics, output}.Run(module);
}

std::string ResolvedIdentifier::String() const {
    if (auto* node = Node()) {
        return Switch(
            node,
            [&](const ast::TypeDecl* n) {  //
                return "type '" + n->name->symbol.Name() + "'";
            },
            [&](const ast::Var* n) {  //
                return "var '" + n->name->symbol.Name() + "'";
            },
            [&](const ast::Let* n) {  //
                return "let '" + n->name->symbol.Name() + "'";
            },
            [&](const ast::Const* n) {  //
                return "const '" + n->name->symbol.Name() + "'";
            },
            [&](const ast::Override* n) {  //
                return "override '" + n->name->symbol.Name() + "'";
            },
            [&](const ast::Function* n) {  //
                return "function '" + n->name->symbol.Name() + "'";
            },
            [&](const ast::Parameter* n) {  //
                return "parameter '" + n->name->symbol.Name() + "'";
            },  //
            TINT_ICE_ON_NO_MATCH);
    }
    if (auto builtin_fn = BuiltinFn(); builtin_fn != wgsl::BuiltinFn::kNone) {
        return "builtin function '" + tint::ToString(builtin_fn) + "'";
    }
    if (auto builtin_ty = BuiltinType(); builtin_ty != core::BuiltinType::kUndefined) {
        return "builtin type '" + tint::ToString(builtin_ty) + "'";
    }
    if (auto builtin_val = BuiltinValue(); builtin_val != core::BuiltinValue::kUndefined) {
        return "builtin value '" + tint::ToString(builtin_val) + "'";
    }
    if (auto access = Access(); access != core::Access::kUndefined) {
        return "access '" + tint::ToString(access) + "'";
    }
    if (auto addr = AddressSpace(); addr != core::AddressSpace::kUndefined) {
        return "address space '" + tint::ToString(addr) + "'";
    }
    if (auto type = InterpolationType(); type != core::InterpolationType::kUndefined) {
        return "interpolation type '" + tint::ToString(type) + "'";
    }
    if (auto smpl = InterpolationSampling(); smpl != core::InterpolationSampling::kUndefined) {
        return "interpolation sampling '" + tint::ToString(smpl) + "'";
    }
    if (auto fmt = TexelFormat(); fmt != core::TexelFormat::kUndefined) {
        return "texel format '" + tint::ToString(fmt) + "'";
    }
    if (auto* unresolved = Unresolved()) {
        return "unresolved identifier '" + unresolved->name + "'";
    }

    TINT_UNREACHABLE() << "unhandled ResolvedIdentifier";
    return "<unknown>";
}

}  // namespace tint::resolver
