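tint/resolver: Move from STL to tint::utils containers

Replace std::unordered_map, std::unordered_set and std::vector in the
resolver with utils::Hashmap, utils::Hashset and utils::Vector.

uniformity.cc: LoopSwitchInfos are now allocated from a BlockAllocator
and the map holds pointers to them, so growing the map does not move
the LoopSwitchInfo objects.

dependency_graph_test.cc now looks up entries with Find() and asserts
that they exist, instead of using operator[], which silently inserted
missing keys.

Raise the tint_unittests stack size for MSVC debug builds from 2MB to
4MB, using IS_DEBUG_BUILD and the /STACK:<bytes> linker syntax.
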
Change-Id: I883168a1a84457138de85decb921c5c430c32bd8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108702
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 21da5e8..cbf23db 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1357,9 +1357,8 @@
# overflows when resolving deeply nested expression chains or statements.
# Production builds neither use MSVC nor debug, so just bump the stack size
# for this build combination.
- string(TOUPPER "${CMAKE_BUILD_TYPE}" build_type)
- if ((NOT ${build_type} STREQUAL "RELEASE") AND (NOT ${build_type} STREQUAL "RELWITHDEBINFO"))
- target_link_options(tint_unittests PRIVATE "/STACK 2097152") # 2MB, default is 1MB
+ if (IS_DEBUG_BUILD)
+ target_link_options(tint_unittests PRIVATE "/STACK:4194304") # 4MB, default is 1MB
endif()
else()
target_compile_options(tint_unittests PRIVATE
diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc
index b67a9ee..7340387 100644
--- a/src/tint/resolver/builtin_validation_test.cc
+++ b/src/tint/resolver/builtin_validation_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <unordered_set>
+
#include "src/tint/ast/builtin_texture_helper_test.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/type_initializer.h"
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 43ccb8d..10b73bb 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -20,7 +20,6 @@
#include <optional>
#include <string>
#include <type_traits>
-#include <unordered_map>
#include <utility>
#include "src/tint/program_builder.h"
@@ -463,18 +462,18 @@
return nullptr;
},
[&](const sem::Struct* s) -> const ImplConstant* {
- std::unordered_map<const sem::Type*, const ImplConstant*> zero_by_type;
+ utils::Hashmap<const sem::Type*, const ImplConstant*, 8> zero_by_type;
utils::Vector<const sem::Constant*, 4> zeros;
zeros.Reserve(s->Members().size());
for (auto* member : s->Members()) {
- auto* zero = utils::GetOrCreate(zero_by_type, member->Type(),
- [&] { return ZeroValue(builder, member->Type()); });
+ auto* zero = zero_by_type.GetOrCreate(
+ member->Type(), [&] { return ZeroValue(builder, member->Type()); });
if (!zero) {
return nullptr;
}
zeros.Push(zero);
}
- if (zero_by_type.size() == 1) {
+ if (zero_by_type.Count() == 1) {
// All members were of the same type, so the zero value is the same for all members.
return builder.create<Splat>(type, zeros[0], s->Members().size());
}
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 3238c8b..4dbf1c0 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -15,7 +15,6 @@
#include "src/tint/resolver/dependency_graph.h"
#include <string>
-#include <unordered_set>
#include <utility>
#include <vector>
@@ -117,7 +116,7 @@
/// A map of DependencyEdge to DependencyInfo
using DependencyEdges =
- std::unordered_map<DependencyEdge, DependencyInfo, DependencyEdgeCmp, DependencyEdgeCmp>;
+ utils::Hashmap<DependencyEdge, DependencyInfo, 64, DependencyEdgeCmp, DependencyEdgeCmp>;
/// Global describes a module-scope variable, type or function.
struct Global {
@@ -126,11 +125,11 @@
/// The declaration ast::Node
const ast::Node* node;
/// A list of dependencies that this global depends on
- std::vector<Global*> deps;
+ utils::Vector<Global*, 8> deps;
};
/// A map of global name to Global
-using GlobalMap = std::unordered_map<Symbol, Global*>;
+using GlobalMap = utils::Hashmap<Symbol, Global*, 16>;
/// Raises an ICE that a global ast::Node type was not handled by this system.
void UnhandledNode(diag::List& diagnostics, const ast::Node* node) {
@@ -170,7 +169,7 @@
dependency_edges_(edges) {
// Register all the globals at global-scope
for (auto it : globals_by_name) {
- scope_stack_.Set(it.first, it.second->node);
+ scope_stack_.Set(it.key, it.value->node);
}
}
@@ -232,7 +231,7 @@
for (auto* param : func->params) {
if (auto* shadows = scope_stack_.Get(param->symbol)) {
- graph_.shadows.emplace(param, shadows);
+ graph_.shadows.Add(param, shadows);
}
Declare(param->symbol, param);
}
@@ -306,7 +305,7 @@
},
[&](const ast::VariableDeclStatement* v) {
if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
- graph_.shadows.emplace(v->variable, shadows);
+ graph_.shadows.Add(v->variable, shadows);
}
TraverseType(v->variable->type);
TraverseExpression(v->variable->initializer);
@@ -473,16 +472,14 @@
}
}
- if (auto* global = utils::Lookup(globals_, to); global && global->node == resolved) {
- if (dependency_edges_
- .emplace(DependencyEdge{current_global_, global},
- DependencyInfo{from->source, action})
- .second) {
- current_global_->deps.emplace_back(global);
+ if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
+ if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
+ DependencyInfo{from->source, action})) {
+ current_global_->deps.Push(*global);
}
}
- graph_.resolved_symbols.emplace(from, resolved);
+ graph_.resolved_symbols.Add(from, resolved);
}
/// @returns true if `name` is the name of a builtin function
@@ -497,7 +494,7 @@
source);
}
- using VariableMap = std::unordered_map<Symbol, const ast::Variable*>;
+ using VariableMap = utils::Hashmap<Symbol, const ast::Variable*, 32>;
const SymbolTable& symbols_;
const GlobalMap& globals_;
diag::List& diagnostics_;
@@ -520,7 +517,7 @@
/// @returns true if analysis found no errors, otherwise false.
bool Run(const ast::Module& module) {
// Reserve container memory
- graph_.resolved_symbols.reserve(module.GlobalDeclarations().Length());
+ graph_.resolved_symbols.Reserve(module.GlobalDeclarations().Length());
sorted_.Reserve(module.GlobalDeclarations().Length());
// Collect all the named globals from the AST module
@@ -589,9 +586,9 @@
for (auto* node : module.GlobalDeclarations()) {
auto* global = allocator_.Create(node);
if (auto symbol = SymbolOf(node); symbol.IsValid()) {
- globals_.emplace(symbol, global);
+ globals_.Add(symbol, global);
}
- declaration_order_.emplace_back(global);
+ declaration_order_.Push(global);
}
}
@@ -625,16 +622,16 @@
return;
}
- std::vector<Entry> stack{Entry{root, 0}};
+ utils::Vector<Entry, 16> stack{Entry{root, 0}};
while (true) {
- auto& entry = stack.back();
+ auto& entry = stack.Back();
// Have we exhausted the dependencies of entry.global?
- if (entry.dep_idx < entry.global->deps.size()) {
+ if (entry.dep_idx < entry.global->deps.Length()) {
// No, there's more dependencies to traverse.
auto& dep = entry.global->deps[entry.dep_idx];
// Does the caller want to enter this dependency?
- if (enter(dep)) { // Yes.
- stack.push_back(Entry{dep, 0}); // Enter the dependency.
+ if (enter(dep)) { // Yes.
+ stack.Push(Entry{dep, 0}); // Enter the dependency.
} else {
entry.dep_idx++; // No. Skip this node.
}
@@ -643,11 +640,11 @@
// Exit this global, pop the stack, and if there's another parent node,
// increment its dependency index, and loop again.
exit(entry.global);
- stack.pop_back();
- if (stack.empty()) {
+ stack.Pop();
+ if (stack.IsEmpty()) {
return; // All done.
}
- stack.back().dep_idx++;
+ stack.Back().dep_idx++;
}
}
}
@@ -707,9 +704,8 @@
/// of global `from` depending on `to`.
/// @note will raise an ICE if the edge is not found.
DependencyInfo DepInfoFor(const Global* from, const Global* to) const {
- auto it = dependency_edges_.find(DependencyEdge{from, to});
- if (it != dependency_edges_.end()) {
- return it->second;
+ if (auto info = dependency_edges_.Find(DependencyEdge{from, to})) {
+ return *info;
}
TINT_ICE(Resolver, diagnostics_)
<< "failed to find dependency info for edge: '" << NameOf(from->node) << "' -> '"
@@ -762,7 +758,7 @@
printf("------ dependencies ------ \n");
for (auto* node : sorted_) {
auto symbol = SymbolOf(node);
- auto* global = globals_.at(symbol);
+ auto* global = *globals_.Find(symbol);
printf("%s depends on:\n", symbols_.NameFor(symbol).c_str());
for (auto* dep : global->deps) {
printf(" %s\n", NameOf(dep->node).c_str());
@@ -791,7 +787,7 @@
DependencyEdges dependency_edges_;
/// Globals in declaration order. Populated by GatherGlobals().
- std::vector<Global*> declaration_order_;
+ utils::Vector<Global*, 64> declaration_order_;
/// Globals in sorted dependency order. Populated by SortGlobals().
utils::UniqueVector<const ast::Node*, 64> sorted_;
diff --git a/src/tint/resolver/dependency_graph.h b/src/tint/resolver/dependency_graph.h
index 9f5ddc5..bc849a0 100644
--- a/src/tint/resolver/dependency_graph.h
+++ b/src/tint/resolver/dependency_graph.h
@@ -15,11 +15,11 @@
#ifndef SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
#define SRC_TINT_RESOLVER_DEPENDENCY_GRAPH_H_
-#include <unordered_map>
#include <vector>
#include "src/tint/ast/module.h"
#include "src/tint/diagnostic/diagnostic.h"
+#include "src/tint/utils/hashmap.h"
namespace tint::resolver {
@@ -50,13 +50,13 @@
/// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or
/// variable that declares the symbol.
- std::unordered_map<const ast::Node*, const ast::Node*> resolved_symbols;
+ utils::Hashmap<const ast::Node*, const ast::Node*, 64> resolved_symbols;
/// Map of ast::Variable to a type, function, or variable that is shadowed by
/// the variable key. A declaration (X) shadows another (Y) if X and Y use
/// the same symbol, and X is declared in a sub-scope of the scope that
/// declares Y.
- std::unordered_map<const ast::Variable*, const ast::Node*> shadows;
+ utils::Hashmap<const ast::Variable*, const ast::Node*, 16> shadows;
};
} // namespace tint::resolver
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index 272357b..c02a561 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -1128,9 +1128,10 @@
if (expect_pass) {
// Check that the use resolves to the declaration
- auto* resolved_symbol = graph.resolved_symbols[use];
- EXPECT_EQ(resolved_symbol, decl)
- << "resolved: " << (resolved_symbol ? resolved_symbol->TypeInfo().name : "<null>")
+ auto* resolved_symbol = graph.resolved_symbols.Find(use);
+ ASSERT_NE(resolved_symbol, nullptr);
+ EXPECT_EQ(*resolved_symbol, decl)
+ << "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
<< "\n"
<< "decl: " << decl->TypeInfo().name;
}
@@ -1177,7 +1178,10 @@
: helper.parameters[0];
helper.Build();
- EXPECT_EQ(Build().shadows[inner_var], outer);
+ auto shadows = Build().shadows;
+ auto* shadow = shadows.Find(inner_var);
+ ASSERT_NE(shadow, nullptr);
+ EXPECT_EQ(*shadow, outer);
}
INSTANTIATE_TEST_SUITE_P(LocalShadowGlobal,
@@ -1308,8 +1312,9 @@
auto graph = Build();
for (auto use : symbol_uses) {
- auto* resolved_symbol = graph.resolved_symbols[use.use];
- EXPECT_EQ(resolved_symbol, use.decl) << use.where;
+ auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
+ ASSERT_NE(resolved_symbol, nullptr) << use.where;
+ EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
}
}
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index 7545f48..970b294 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -16,7 +16,6 @@
#include <algorithm>
#include <limits>
-#include <unordered_map>
#include <utility>
#include "src/tint/ast/binary_expression.h"
@@ -36,7 +35,7 @@
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/type_initializer.h"
#include "src/tint/utils/hash.h"
-#include "src/tint/utils/map.h"
+#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/math.h"
#include "src/tint/utils/scoped_assignment.h"
@@ -1114,10 +1113,10 @@
ProgramBuilder& builder;
Matchers matchers;
- std::unordered_map<IntrinsicPrototype, sem::Builtin*, IntrinsicPrototype::Hasher> builtins;
- std::unordered_map<IntrinsicPrototype, sem::TypeInitializer*, IntrinsicPrototype::Hasher>
+ utils::Hashmap<IntrinsicPrototype, sem::Builtin*, 64, IntrinsicPrototype::Hasher> builtins;
+ utils::Hashmap<IntrinsicPrototype, sem::TypeInitializer*, 16, IntrinsicPrototype::Hasher>
initializers;
- std::unordered_map<IntrinsicPrototype, sem::TypeConversion*, IntrinsicPrototype::Hasher>
+ utils::Hashmap<IntrinsicPrototype, sem::TypeConversion*, 16, IntrinsicPrototype::Hasher>
converters;
};
@@ -1185,7 +1184,7 @@
}
// De-duplicate builtins that are identical.
- auto* sem = utils::GetOrCreate(builtins, match, [&] {
+ auto* sem = builtins.GetOrCreate(match, [&] {
utils::Vector<sem::Parameter*, kNumFixedParams> params;
params.Reserve(match.parameters.Length());
for (auto& p : match.parameters) {
@@ -1396,7 +1395,7 @@
}
auto eval_stage = match.overload->const_eval_fn ? sem::EvaluationStage::kConstant
: sem::EvaluationStage::kRuntime;
- auto* target = utils::GetOrCreate(initializers, match, [&]() {
+ auto* target = initializers.GetOrCreate(match, [&]() {
return builder.create<sem::TypeInitializer>(match.return_type, std::move(params),
eval_stage);
});
@@ -1404,7 +1403,7 @@
}
// Conversion.
- auto* target = utils::GetOrCreate(converters, match, [&]() {
+ auto* target = converters.GetOrCreate(match, [&]() {
auto param = builder.create<sem::Parameter>(
nullptr, 0u, match.parameters[0].type, ast::AddressSpace::kNone,
ast::Access::kUndefined, match.parameters[0].usage);
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index d30d671..880d595 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -482,7 +482,7 @@
sem->SetOverrideId(o);
// Track the constant IDs that are specified in the shader.
- override_ids_.emplace(o, sem);
+ override_ids_.Add(o, sem);
}
builder_->Sem().Add(v, sem);
@@ -842,7 +842,7 @@
id = builder_->Sem().Get<sem::GlobalVariable>(override)->OverrideId();
} else {
// No ID was specified, so allocate the next available ID.
- while (!ids_exhausted && override_ids_.count(next_id)) {
+ while (!ids_exhausted && override_ids_.Contains(next_id)) {
increment_next_id();
}
if (ids_exhausted) {
@@ -864,9 +864,9 @@
void Resolver::SetShadows() {
for (auto it : dependencies_.shadows) {
Switch(
- sem_.Get(it.first), //
- [&](sem::LocalVariable* local) { local->SetShadows(sem_.Get(it.second)); },
- [&](sem::Parameter* param) { param->SetShadows(sem_.Get(it.second)); });
+ sem_.Get(it.key), //
+ [&](sem::LocalVariable* local) { local->SetShadows(sem_.Get(it.value)); },
+ [&](sem::Parameter* param) { param->SetShadows(sem_.Get(it.value)); });
}
}
@@ -923,7 +923,7 @@
sem::Function* Resolver::Function(const ast::Function* decl) {
uint32_t parameter_index = 0;
- std::unordered_map<Symbol, Source> parameter_names;
+ utils::Hashmap<Symbol, Source, 8> parameter_names;
utils::Vector<sem::Parameter*, 8> parameters;
// Resolve all the parameters
@@ -931,11 +931,10 @@
Mark(param);
{ // Check the parameter name is unique for the function
- auto emplaced = parameter_names.emplace(param->symbol, param->source);
- if (!emplaced.second) {
+ if (auto added = parameter_names.Add(param->symbol, param->source); !added) {
auto name = builder_->Symbols().NameFor(param->symbol);
AddError("redefinition of parameter '" + name + "'", param->source);
- AddNote("previous definition is here", emplaced.first->second);
+ AddNote("previous definition is here", *added.value);
return nullptr;
}
}
@@ -1031,7 +1030,7 @@
}
if (decl->IsEntryPoint()) {
- entry_points_.emplace_back(func);
+ entry_points_.Push(func);
}
if (decl->body) {
@@ -1850,8 +1849,8 @@
[&](const sem::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); },
[&](const sem::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); },
[&](const sem::Array* arr) -> sem::Call* {
- auto* call_target = utils::GetOrCreate(
- array_inits_, ArrayInitializerSig{{arr, args.Length(), args_stage}},
+ auto* call_target = array_inits_.GetOrCreate(
+ ArrayInitializerSig{{arr, args.Length(), args_stage}},
[&]() -> sem::TypeInitializer* {
auto params = utils::Transform(args, [&](auto, size_t i) {
return builder_->create<sem::Parameter>(
@@ -1877,8 +1876,8 @@
return call;
},
[&](const sem::Struct* str) -> sem::Call* {
- auto* call_target = utils::GetOrCreate(
- struct_inits_, StructInitializerSig{{str, args.Length(), args_stage}},
+ auto* call_target = struct_inits_.GetOrCreate(
+ StructInitializerSig{{str, args.Length(), args_stage}},
[&]() -> sem::TypeInitializer* {
utils::Vector<const sem::Parameter*, 8> params;
params.Resize(std::min(args.Length(), str->Members().size()));
@@ -1981,9 +1980,9 @@
AddError(
"cannot infer common array element type from initializer arguments",
expr->source);
- std::unordered_set<const sem::Type*> types;
+ utils::Hashset<const sem::Type*, 8> types;
for (size_t i = 0; i < args.Length(); i++) {
- if (types.emplace(args[i]->Type()).second) {
+ if (types.Add(args[i]->Type())) {
AddNote("argument " + std::to_string(i) + " is of type '" +
sem_.TypeNameOf(args[i]->Type()) + "'",
args[i]->Declaration()->source);
@@ -2687,11 +2686,10 @@
}
if (el_ty->Is<sem::Atomic>()) {
- atomic_composite_info_.emplace(out, arr->type->source);
+ atomic_composite_info_.Add(out, &arr->type->source);
} else {
- auto found = atomic_composite_info_.find(el_ty);
- if (found != atomic_composite_info_.end()) {
- atomic_composite_info_.emplace(out, found->second);
+ if (auto* found = atomic_composite_info_.Find(el_ty)) {
+ atomic_composite_info_.Add(out, *found);
}
}
@@ -2832,15 +2830,14 @@
// validation.
uint64_t struct_size = 0;
uint64_t struct_align = 1;
- std::unordered_map<Symbol, const ast::StructMember*> member_map;
+ utils::Hashmap<Symbol, const ast::StructMember*, 8> member_map;
for (auto* member : str->members) {
Mark(member);
- auto result = member_map.emplace(member->symbol, member);
- if (!result.second) {
+ if (auto added = member_map.Add(member->symbol, member); !added) {
AddError("redefinition of '" + builder_->Symbols().NameFor(member->symbol) + "'",
member->source);
- AddNote("previous definition is here", result.first->second->source);
+ AddNote("previous definition is here", (*added.value)->source);
return nullptr;
}
@@ -3027,12 +3024,11 @@
for (size_t i = 0; i < sem_members.size(); i++) {
auto* mem_type = sem_members[i]->Type();
if (mem_type->Is<sem::Atomic>()) {
- atomic_composite_info_.emplace(out, sem_members[i]->Declaration()->source);
+ atomic_composite_info_.Add(out, &sem_members[i]->Declaration()->source);
break;
} else {
- auto found = atomic_composite_info_.find(mem_type);
- if (found != atomic_composite_info_.end()) {
- atomic_composite_info_.emplace(out, found->second);
+ if (auto* found = atomic_composite_info_.Find(mem_type)) {
+ atomic_composite_info_.Add(out, *found);
break;
}
}
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 89aab13..bfb95b3 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -18,8 +18,6 @@
#include <memory>
#include <string>
#include <tuple>
-#include <unordered_map>
-#include <unordered_set>
#include <utility>
#include <vector>
@@ -434,13 +432,13 @@
SemHelper sem_;
Validator validator_;
ast::Extensions enabled_extensions_;
- std::vector<sem::Function*> entry_points_;
- std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
+ utils::Vector<sem::Function*, 8> entry_points_;
+ utils::Hashmap<const sem::Type*, const Source*, 8> atomic_composite_info_;
utils::Bitset<0> marked_;
ExprEvalStageConstraint expr_eval_stage_constraint_;
- std::unordered_map<OverrideId, const sem::Variable*> override_ids_;
- std::unordered_map<ArrayInitializerSig, sem::CallTarget*> array_inits_;
- std::unordered_map<StructInitializerSig, sem::CallTarget*> struct_inits_;
+ utils::Hashmap<OverrideId, const sem::Variable*, 8> override_ids_;
+ utils::Hashmap<ArrayInitializerSig, sem::CallTarget*, 8> array_inits_;
+ utils::Hashmap<StructInitializerSig, sem::CallTarget*, 8> struct_inits_;
sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
sem::CompoundStatement* current_compound_statement_ = nullptr;
diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h
index 9b0967b..12ef4a2 100644
--- a/src/tint/resolver/sem_helper.h
+++ b/src/tint/resolver/sem_helper.h
@@ -54,8 +54,8 @@
/// @param node the node to retrieve
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
- auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
- return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved)) : nullptr;
+ auto* resolved = dependencies_.resolved_symbols.Find(node);
+ return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
}
/// @returns the resolved type of the ast::Expression `expr`
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index f5389f7..8c4f478 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -16,8 +16,6 @@
#include <limits>
#include <string>
-#include <unordered_map>
-#include <unordered_set>
#include <utility>
#include <vector>
@@ -139,7 +137,7 @@
bool pointer_may_become_non_uniform = false;
/// The parameters that are required to be uniform for the contents of this pointer parameter to
/// be uniform at function exit.
- std::vector<const sem::Parameter*> pointer_param_output_sources;
+ utils::Vector<const sem::Parameter*, 8> pointer_param_output_sources;
/// The node in the graph that corresponds to this parameter's initial value.
Node* init_value;
/// The node in the graph that corresponds to this parameter's output value (or nullptr).
@@ -166,7 +164,7 @@
}
// Create nodes for parameters.
- parameters.resize(func->params.Length());
+ parameters.Resize(func->params.Length());
for (size_t i = 0; i < func->params.Length(); i++) {
auto* param = func->params[i];
auto param_name = builder->Symbols().NameFor(param->symbol);
@@ -177,7 +175,7 @@
if (sem->Type()->Is<sem::Pointer>()) {
node_init = CreateNode("ptrparam_" + name + "_init");
parameters[i].pointer_return_value = CreateNode("ptrparam_" + name + "_return");
- local_var_decls.insert(sem);
+ local_var_decls.Add(sem);
} else {
node_init = CreateNode("param_" + name);
}
@@ -194,7 +192,7 @@
/// The function's uniformity effects.
FunctionTag function_tag;
/// The uniformity requirements of the function's parameters.
- std::vector<ParameterInfo> parameters;
+ utils::Vector<ParameterInfo, 8> parameters;
/// The control flow graph.
utils::BlockAllocator<Node> nodes;
@@ -213,24 +211,31 @@
/// The set of a local read-write vars that are in scope at any given point in the process.
/// Includes pointer parameters.
- std::unordered_set<const sem::Variable*> local_var_decls;
+ utils::Hashset<const sem::Variable*, 8> local_var_decls;
/// The set of partial pointer variables - pointers that point to a subobject (into an array or
/// struct).
- std::unordered_set<const sem::Variable*> partial_ptrs;
+ utils::Hashset<const sem::Variable*, 4> partial_ptrs;
/// LoopSwitchInfo tracks information about the value of variables for a control flow construct.
struct LoopSwitchInfo {
/// The type of this control flow construct.
std::string type;
/// The input values for local variables at the start of this construct.
- std::unordered_map<const sem::Variable*, Node*> var_in_nodes;
+ utils::Hashmap<const sem::Variable*, Node*, 8> var_in_nodes;
/// The exit values for local variables at the end of this construct.
- std::unordered_map<const sem::Variable*, Node*> var_exit_nodes;
+ utils::Hashmap<const sem::Variable*, Node*, 8> var_exit_nodes;
};
- /// Map from control flow statements to the corresponding LoopSwitchInfo structure.
- std::unordered_map<const sem::Statement*, LoopSwitchInfo> loop_switch_infos;
+    /// @returns the LoopSwitchInfo for `stmt`, allocating it on the first call with `stmt`.
+    /// The map only stores pointers, so map growth does not move previously returned infos.
+    LoopSwitchInfo& LoopSwitchInfoFor(const sem::Statement* stmt) {
+        return *loop_switch_infos.GetOrCreate(stmt,
+                                              [&] { return loop_switch_info_allocator.Create(); });
+    }
+
+    /// Removes the map entry for `stmt`; the LoopSwitchInfo object is not deallocated here.
+    void RemoveLoopSwitchInfoFor(const sem::Statement* stmt) { loop_switch_infos.Remove(stmt); }
/// Create a new node.
/// @param tag a tag used to identify the node for debugging purposes
@@ -263,7 +268,13 @@
private:
/// A list of tags that have already been used within the current function.
- std::unordered_set<std::string> tags_;
+ utils::Hashset<std::string, 8> tags_;
+
+ /// Map from control flow statements to the corresponding LoopSwitchInfo structure.
+ utils::Hashmap<const sem::Statement*, LoopSwitchInfo*, 8> loop_switch_infos;
+
+ /// Allocator of LoopSwitchInfos
+ utils::BlockAllocator<LoopSwitchInfo> loop_switch_info_allocator;
};
/// UniformityGraph is used to analyze the uniformity requirements and effects of functions in a
@@ -312,7 +323,7 @@
diag::List& diagnostics_;
/// Map of analyzed function results.
- std::unordered_map<const ast::Function*, FunctionInfo> functions_;
+ utils::Hashmap<const ast::Function*, FunctionInfo, 8> functions_;
/// The function currently being analyzed.
FunctionInfo* current_function_;
@@ -329,8 +340,7 @@
/// @param func the function to process
/// @returns true if there are no uniformity issues, false otherwise
bool ProcessFunction(const ast::Function* func) {
- functions_.emplace(func, FunctionInfo(func, builder_));
- current_function_ = &functions_.at(func);
+ current_function_ = functions_.Add(func, FunctionInfo(func, builder_)).value;
// Process function body.
if (func->body) {
@@ -410,7 +420,7 @@
for (size_t j = 0; j < func->params.Length(); j++) {
auto* param_source = sem_.Get<sem::Parameter>(func->params[j]);
if (reachable.Contains(current_function_->parameters[j].init_value)) {
- current_function_->parameters[i].pointer_param_output_sources.push_back(
+ current_function_->parameters[i].pointer_param_output_sources.Push(
param_source);
}
}
@@ -439,7 +449,7 @@
},
[&](const ast::BlockStatement* b) {
- std::unordered_map<const sem::Variable*, Node*> scoped_assignments;
+ utils::Hashmap<const sem::Variable*, Node*, 8> scoped_assignments;
{
// Push a new scope for variable assignments in the block.
current_function_->variables.Push();
@@ -472,13 +482,13 @@
if (behaviors.Contains(sem::Behavior::kNext) ||
behaviors.Contains(sem::Behavior::kFallthrough)) {
for (auto var : scoped_assignments) {
- current_function_->variables.Set(var.first, var.second);
+ current_function_->variables.Set(var.key, var.value);
}
}
// Remove any variables declared in this scope from the set of in-scope variables.
for (auto decl : sem_.Get<sem::BlockStatement>(b)->Decls()) {
- current_function_->local_var_decls.erase(decl.value.variable);
+ current_function_->local_var_decls.Remove(decl.value.variable);
}
return cf;
@@ -489,8 +499,8 @@
auto* parent = sem_.Get(b)
->FindFirstParent<sem::SwitchStatement, sem::LoopStatement,
sem::ForLoopStatement, sem::WhileStatement>();
- TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
- auto& info = current_function_->loop_switch_infos.at(parent);
+
+ auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate variable values to the loop/switch exit nodes.
for (auto* var : current_function_->local_var_decls) {
@@ -502,7 +512,7 @@
}
// Add an edge from the variable exit node to its value at this point.
- auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@@ -526,8 +536,7 @@
{
auto* parent = sem_.Get(b)->FindFirstParent<sem::LoopStatement>();
- TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
- auto& info = current_function_->loop_switch_infos.at(parent);
+ auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate variable values to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
@@ -539,7 +548,7 @@
}
// Add an edge from the variable exit node to its value at this point.
- auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@@ -580,8 +589,7 @@
auto* parent = sem_.Get(c)
->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement,
sem::WhileStatement>();
- TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
- auto& info = current_function_->loop_switch_infos.at(parent);
+ auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate assignments to the loop input nodes.
for (auto* var : current_function_->local_var_decls) {
@@ -593,11 +601,11 @@
}
// Add an edge from the variable's loop input node to its value at this point.
- TINT_ASSERT(Resolver, info.var_in_nodes.count(var));
- auto* in_node = info.var_in_nodes.at(var);
+ auto** in_node = info.var_in_nodes.Find(var);
+ TINT_ASSERT(Resolver, in_node != nullptr);
auto* out_node = current_function_->variables.Get(var);
- if (out_node != in_node) {
- in_node->AddEdge(out_node);
+ if (out_node != *in_node) {
+ (*in_node)->AddEdge(out_node);
}
}
return cf;
@@ -618,7 +626,7 @@
}
auto* cf_start = cf_init;
- auto& info = current_function_->loop_switch_infos[sem_loop];
+ auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "forloop";
// Create input nodes for any variables declared before this loop.
@@ -626,7 +634,7 @@
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_forloop_in");
in_node->AddEdge(current_function_->variables.Get(v));
- info.var_in_nodes[v] = in_node;
+ info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
}
@@ -640,7 +648,7 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
- auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@@ -660,19 +668,19 @@
// Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) {
- auto* in_node = v.second;
- auto* out_node = current_function_->variables.Get(v.first);
+ auto* in_node = v.value;
+ auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) {
in_node->AddEdge(out_node);
}
}
// Set each variable's exit node as its value in the outer scope.
- for (auto v : info.var_exit_nodes) {
- current_function_->variables.Set(v.first, v.second);
+ for (auto& v : info.var_exit_nodes) {
+ current_function_->variables.Set(v.key, v.value);
}
- current_function_->loop_switch_infos.erase(sem_loop);
+ current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf;
@@ -687,7 +695,7 @@
auto* cf_start = cf;
- auto& info = current_function_->loop_switch_infos[sem_loop];
+ auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "whileloop";
// Create input nodes for any variables declared before this loop.
@@ -695,7 +703,7 @@
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_forloop_in");
in_node->AddEdge(current_function_->variables.Get(v));
- info.var_in_nodes[v] = in_node;
+ info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
}
@@ -710,7 +718,7 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
- auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
@@ -721,9 +729,9 @@
cfx->AddEdge(cf);
// Add edges from variable loop input nodes to their values at the end of the loop.
- for (auto v : info.var_in_nodes) {
- auto* in_node = v.second;
- auto* out_node = current_function_->variables.Get(v.first);
+ for (auto& v : info.var_in_nodes) {
+ auto* in_node = v.value;
+ auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) {
in_node->AddEdge(out_node);
}
@@ -731,10 +739,10 @@
// Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) {
- current_function_->variables.Set(v.first, v.second);
+ current_function_->variables.Set(v.key, v.value);
}
- current_function_->loop_switch_infos.erase(sem_loop);
+ current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf;
@@ -752,15 +760,15 @@
v->affects_control_flow = true;
v->AddEdge(v_cond);
- std::unordered_map<const sem::Variable*, Node*> true_vars;
- std::unordered_map<const sem::Variable*, Node*> false_vars;
+ utils::Hashmap<const sem::Variable*, Node*, 8> true_vars;
+ utils::Hashmap<const sem::Variable*, Node*, 8> false_vars;
// Helper to process a statement with a new scope for variable assignments.
// Populates `assigned_vars` with new nodes for any variables that are assigned in
// this statement.
auto process_in_scope =
[&](Node* cf_in, const ast::Statement* s,
- std::unordered_map<const sem::Variable*, Node*>& assigned_vars) {
+ utils::Hashmap<const sem::Variable*, Node*, 8>& assigned_vars) {
// Push a new scope for variable assignments.
current_function_->variables.Push();
@@ -790,7 +798,7 @@
// Update values for any variables assigned in the if or else blocks.
for (auto* var : current_function_->local_var_decls) {
// Skip variables not assigned in either block.
- if (true_vars.count(var) == 0 && false_vars.count(var) == 0) {
+ if (!true_vars.Contains(var) && !false_vars.Contains(var)) {
continue;
}
@@ -801,15 +809,15 @@
// Add edges to the assigned value or the initial value.
// Only add edges if the behavior for that block contains 'Next'.
if (true_has_next) {
- if (true_vars.count(var)) {
- out_node->AddEdge(true_vars.at(var));
+ if (true_vars.Contains(var)) {
+ out_node->AddEdge(*true_vars.Find(var));
} else {
out_node->AddEdge(current_function_->variables.Get(var));
}
}
if (false_has_next) {
- if (false_vars.count(var)) {
- out_node->AddEdge(false_vars.at(var));
+ if (false_vars.Contains(var)) {
+ out_node->AddEdge(*false_vars.Find(var));
} else {
out_node->AddEdge(current_function_->variables.Get(var));
}
@@ -845,7 +853,7 @@
auto* sem_loop = sem_.Get(l);
auto* cfx = CreateNode("loop_start");
- auto& info = current_function_->loop_switch_infos[sem_loop];
+ auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "loop";
// Create input nodes for any variables declared before this loop.
@@ -853,7 +861,7 @@
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_loop_in", v->Declaration());
in_node->AddEdge(current_function_->variables.Get(v));
- info.var_in_nodes[v] = in_node;
+ info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
}
@@ -868,8 +876,8 @@
// Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) {
- auto* in_node = v.second;
- auto* out_node = current_function_->variables.Get(v.first);
+ auto* in_node = v.value;
+ auto* out_node = current_function_->variables.Get(v.key);
if (out_node != in_node) {
in_node->AddEdge(out_node);
}
@@ -877,10 +885,10 @@
// Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) {
- current_function_->variables.Set(v.first, v.second);
+ current_function_->variables.Set(v.key, v.value);
}
- current_function_->loop_switch_infos.erase(sem_loop);
+ current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf;
@@ -925,7 +933,7 @@
cf_end = CreateNode("switch_CFend");
}
- auto& info = current_function_->loop_switch_infos[sem_switch];
+ auto& info = current_function_->LoopSwitchInfoFor(sem_switch);
info.type = "switch";
auto* cf_n = v;
@@ -958,12 +966,11 @@
}
// Add an edge from the variable exit node to its new value.
- auto* exit_node =
- utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
- auto name =
- builder_->Symbols().NameFor(var->Declaration()->symbol);
- return CreateNode(name + "_value_" + info.type + "_exit");
- });
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
+ auto name =
+ builder_->Symbols().NameFor(var->Declaration()->symbol);
+ return CreateNode(name + "_value_" + info.type + "_exit");
+ });
exit_node->AddEdge(current_function_->variables.Get(var));
}
}
@@ -974,7 +981,7 @@
// Update nodes for any variables assigned in the switch statement.
for (auto var : info.var_exit_nodes) {
- current_function_->variables.Set(var.first, var.second);
+ current_function_->variables.Set(var.key, var.value);
}
return cf_end ? cf_end : cf;
@@ -995,7 +1002,7 @@
auto* e = UnwrapIndirectAndAddressOfChain(unary_init);
if (e->IsAnyOf<ast::IndexAccessorExpression,
ast::MemberAccessorExpression>()) {
- current_function_->partial_ptrs.insert(sem_var);
+ current_function_->partial_ptrs.Add(sem_var);
}
}
}
@@ -1005,7 +1012,7 @@
current_function_->variables.Set(sem_var, node);
if (decl->variable->Is<ast::Var>()) {
- current_function_->local_var_decls.insert(
+ current_function_->local_var_decls.Add(
sem_.Get<sem::LocalVariable>(decl->variable));
}
@@ -1183,10 +1190,10 @@
// To determine if we're dereferencing a partial pointer, unwrap *&
// chains; if the final expression is an identifier, see if it's a
// partial pointer. If it's not an identifier, then it must be an
- // index/acessor expression, and thus a partial pointer.
+ // index/accessor expression, and thus a partial pointer.
auto* e = UnwrapIndirectAndAddressOfChain(u);
if (auto* var_user = sem_.Get<sem::VariableUser>(e)) {
- if (current_function_->partial_ptrs.count(var_user->Variable())) {
+ if (current_function_->partial_ptrs.Contains(var_user->Variable())) {
return true;
}
} else {
@@ -1290,7 +1297,7 @@
// Process call arguments
Node* cf_last_arg = cf;
- std::vector<Node*> args;
+ utils::Vector<Node*, 8> args;
for (size_t i = 0; i < call->args.Length(); i++) {
auto [cf_i, arg_i] = ProcessExpression(cf_last_arg, call->args[i]);
@@ -1303,7 +1310,7 @@
arg_node->AddEdge(arg_i);
cf_last_arg = cf_i;
- args.push_back(arg_node);
+ args.Push(arg_node);
}
// Note: This is an additional node that isn't described in the specification, for the
@@ -1341,11 +1348,11 @@
[&](const sem::Function* func) {
// We must have already analyzed the user-defined function since we process
// functions in dependency order.
- TINT_ASSERT(Resolver, functions_.count(func->Declaration()));
- auto& info = functions_.at(func->Declaration());
- callsite_tag = info.callsite_tag;
- function_tag = info.function_tag;
- func_info = &info;
+ auto* info = functions_.Find(func->Declaration());
+ TINT_ASSERT(Resolver, info != nullptr);
+ callsite_tag = info->callsite_tag;
+ function_tag = info->function_tag;
+ func_info = info;
},
[&](const sem::TypeInitializer*) {
callsite_tag = CallSiteNoRestriction;
@@ -1371,7 +1378,7 @@
result->AddEdge(cf_after);
// For each argument, add edges based on parameter tags.
- for (size_t i = 0; i < args.size(); i++) {
+ for (size_t i = 0; i < args.Length(); i++) {
if (func_info) {
switch (func_info->parameters[i].tag) {
case ParameterRequiredToBeUniform:
@@ -1429,11 +1436,11 @@
/// @param source the starting node
/// @param reachable the set of reachable nodes to populate, if required
void Traverse(Node* source, utils::UniqueVector<Node*, 4>* reachable = nullptr) {
- std::vector<Node*> to_visit{source};
+ utils::Vector<Node*, 8> to_visit{source};
- while (!to_visit.empty()) {
- auto* node = to_visit.back();
- to_visit.pop_back();
+ while (!to_visit.IsEmpty()) {
+ auto* node = to_visit.Back();
+ to_visit.Pop();
if (reachable) {
reachable->Add(node);
@@ -1441,7 +1448,7 @@
for (auto* to : node->edges) {
if (to->visited_from == nullptr) {
to->visited_from = node;
- to_visit.push_back(to);
+ to_visit.Push(to);
}
}
}
@@ -1473,8 +1480,8 @@
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
- auto& target_info = functions_.at(user->Declaration());
- for (auto* call_node : target_info.required_to_be_uniform->edges) {
+ auto* target_info = functions_.Find(user->Declaration());
+ for (auto* call_node : target_info->required_to_be_uniform->edges) {
if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
return FindBuiltinThatRequiresUniformity(child_call);
@@ -1643,9 +1650,9 @@
// If this is a call to a user-defined function, add a note to show the reason that the
// parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) {
- auto& next_function = functions_.at(user->Declaration());
- Node* next_cause = next_function.parameters[cause->arg_index].init_value;
- MakeError(next_function, next_cause, true);
+ auto* next_function = functions_.Find(user->Declaration());
+ Node* next_cause = next_function->parameters[cause->arg_index].init_value;
+ MakeError(*next_function, next_cause, true);
}
} else {
// The requirement was on a function callsite.
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index a6d98e2..42bee0d 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -580,8 +580,8 @@
bool Validator::GlobalVariable(
const sem::GlobalVariable* global,
- const std::unordered_map<OverrideId, const sem::Variable*>& override_ids,
- const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const {
+ const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_ids,
+ const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const {
auto* decl = global->Declaration();
if (global->AddressSpace() != ast::AddressSpace::kWorkgroup &&
IsArrayWithOverrideCount(global->Type())) {
@@ -702,7 +702,7 @@
// buffer variables with a read_write access mode.
bool Validator::AtomicVariable(
const sem::Variable* var,
- std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const {
+ const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const {
auto address_space = var->AddressSpace();
auto* decl = var->Declaration();
auto access = var->Access();
@@ -716,14 +716,13 @@
return false;
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
- auto found = atomic_composite_info.find(type);
- if (found != atomic_composite_info.end()) {
+ if (auto* found = atomic_composite_info.Find(type)) {
if (address_space != ast::AddressSpace::kStorage &&
address_space != ast::AddressSpace::kWorkgroup) {
AddError("atomic variables must have <storage> or <workgroup> address space",
source);
AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here",
- found->second);
+ **found);
return false;
} else if (address_space == ast::AddressSpace::kStorage &&
access != ast::Access::kReadWrite) {
@@ -732,7 +731,7 @@
"access mode",
source);
AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + "' is declared here",
- found->second);
+ **found);
return false;
}
}
@@ -783,7 +782,7 @@
bool Validator::Override(
const sem::GlobalVariable* v,
- const std::unordered_map<OverrideId, const sem::Variable*>& override_ids) const {
+ const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_ids) const {
auto* decl = v->Declaration();
auto* storage_ty = v->Type()->UnwrapRef();
@@ -796,12 +795,12 @@
for (auto* attr : decl->attributes) {
if (attr->Is<ast::IdAttribute>()) {
auto id = v->OverrideId();
- if (auto it = override_ids.find(id); it != override_ids.end() && it->second != v) {
+ if (auto* var = override_ids.Find(id); var && *var != v) {
AddError("@id values must be unique", attr->source);
- AddNote("a override with an ID of " + std::to_string(id.value) +
- " was previously declared here:",
- ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes)
- ->source);
+ AddNote(
+ "a override with an ID of " + std::to_string(id.value) +
+ " was previously declared here:",
+ ast::GetAttribute<ast::IdAttribute>((*var)->Declaration()->attributes)->source);
return false;
}
} else {
@@ -1093,8 +1092,8 @@
// order to catch conflicts.
// TODO(jrprice): This state could be stored in sem::Function instead, and then passed to
// sem::Function since it would be useful there too.
- std::unordered_set<ast::BuiltinValue> builtins;
- std::unordered_set<uint32_t> locations;
+ utils::Hashset<ast::BuiltinValue, 4> builtins;
+ utils::Hashset<uint32_t, 8> locations;
enum class ParamOrRetType {
kParameter,
kReturnType,
@@ -1130,7 +1129,7 @@
}
pipeline_io_attribute = attr;
- if (builtins.count(builtin->builtin)) {
+ if (builtins.Contains(builtin->builtin)) {
AddError(attr_to_str(builtin) +
" attribute appears multiple times as pipeline " +
(param_or_ret == ParamOrRetType::kParameter ? "input" : "output"),
@@ -1142,7 +1141,7 @@
/* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
return false;
}
- builtins.emplace(builtin->builtin);
+ builtins.Add(builtin->builtin);
} else if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", attr->source);
@@ -1287,8 +1286,8 @@
// Clear IO sets after parameter validation. Builtin and location attributes in return types
// should be validated independently from those used in parameters.
- builtins.clear();
- locations.clear();
+ builtins.Clear();
+ locations.Clear();
if (!func->ReturnType()->Is<sem::Void>()) {
if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(),
@@ -1299,7 +1298,7 @@
}
if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
- builtins.count(ast::BuiltinValue::kPosition) == 0) {
+ !builtins.Contains(ast::BuiltinValue::kPosition)) {
// Check module-scope variables, as the SPIR-V sanitizer generates these.
bool found = false;
for (auto* global : func->TransitivelyReferencedGlobals()) {
@@ -1327,18 +1326,18 @@
}
// Validate there are no resource variable binding collisions
- std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
+ utils::Hashmap<sem::BindingPoint, const ast::Variable*, 8> binding_points;
for (auto* global : func->TransitivelyReferencedGlobals()) {
auto* var_decl = global->Declaration()->As<ast::Var>();
if (!var_decl || !var_decl->HasBindingPoint()) {
continue;
}
auto bp = global->BindingPoint();
- auto res = binding_points.emplace(bp, var_decl);
- if (!res.second &&
+ if (auto added = binding_points.Add(bp, var_decl);
+ !added &&
IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kBindingPointCollision) &&
- IsValidationEnabled(res.first->second->attributes,
+ IsValidationEnabled((*added.value)->attributes,
ast::DisabledValidation::kBindingPointCollision)) {
// https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
// Bindings must not alias within a shader stage: two different variables in the
@@ -1350,7 +1349,7 @@
"' references multiple variables that use the same resource binding @group(" +
std::to_string(bp.group) + "), @binding(" + std::to_string(bp.binding) + ")",
var_decl->source);
- AddNote("first resource binding usage declared here", res.first->second->source);
+ AddNote("first resource binding usage declared here", (*added.value)->source);
return false;
}
}
@@ -1917,7 +1916,7 @@
return true;
}
-bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points) const {
+bool Validator::PipelineStages(const utils::VectorRef<sem::Function*> entry_points) const {
auto backtrace = [&](const sem::Function* func, const sem::Function* entry_point) {
if (func != entry_point) {
TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) {
@@ -2012,7 +2011,7 @@
return true;
}
-bool Validator::PushConstants(const std::vector<sem::Function*>& entry_points) const {
+bool Validator::PushConstants(const utils::VectorRef<sem::Function*> entry_points) const {
for (auto* entry_point : entry_points) {
// State checked and modified by check_push_constant so that it remembers previously seen
// push_constant variables for an entry-point.
@@ -2130,7 +2129,7 @@
return false;
}
- std::unordered_set<uint32_t> locations;
+ utils::Hashset<uint32_t, 8> locations;
for (auto* member : str->Members()) {
if (auto* r = member->Type()->As<sem::Array>()) {
if (r->IsRuntimeSized()) {
@@ -2248,7 +2247,7 @@
bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location,
const sem::Type* type,
- std::unordered_set<uint32_t>& locations,
+ utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input) const {
@@ -2269,12 +2268,11 @@
return false;
}
- if (locations.count(location)) {
+ if (!locations.Add(location)) {
AddError(attr_to_str(loc_attr, location) + " attribute appears multiple times",
loc_attr->source);
return false;
}
- locations.emplace(location);
return true;
}
@@ -2311,7 +2309,7 @@
}
const sem::CaseSelector* default_selector = nullptr;
- std::unordered_map<int64_t, Source> selectors;
+ utils::Hashmap<int64_t, Source, 4> selectors;
for (auto* case_stmt : s->body) {
auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt);
@@ -2338,18 +2336,16 @@
}
auto value = selector->Value()->As<uint32_t>();
- auto it = selectors.find(value);
- if (it != selectors.end()) {
+ if (auto added = selectors.Add(value, selector->Declaration()->source); !added) {
AddError("duplicate switch case '" +
(decl_ty->IsAnyOf<sem::I32, sem::AbstractNumeric>()
? std::to_string(i32(value))
: std::to_string(value)) +
"'",
selector->Declaration()->source);
- AddNote("previous case declared here", it->second);
+ AddNote("previous case declared here", *added.value);
return false;
}
- selectors.emplace(value, selector->Declaration()->source);
}
}
@@ -2477,12 +2473,12 @@
}
bool Validator::NoDuplicateAttributes(utils::VectorRef<const ast::Attribute*> attributes) const {
- std::unordered_map<const TypeInfo*, Source> seen;
+ utils::Hashmap<const TypeInfo*, Source, 8> seen;
for (auto* d : attributes) {
- auto res = seen.emplace(&d->TypeInfo(), d->source);
- if (!res.second && !d->Is<ast::InternalAttribute>()) {
+ auto added = seen.Add(&d->TypeInfo(), d->source);
+ if (!added && !d->Is<ast::InternalAttribute>()) {
AddError("duplicate " + d->Name() + " attribute", d->source);
- AddNote("first attribute declared here", res.first->second);
+ AddNote("first attribute declared here", *added.value);
return false;
}
}
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index 082147b..577b688 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -17,16 +17,15 @@
#include <set>
#include <string>
-#include <unordered_map>
-#include <unordered_set>
#include <utility>
-#include <vector>
#include "src/tint/ast/pipeline_stage.h"
#include "src/tint/program_builder.h"
#include "src/tint/resolver/sem_helper.h"
#include "src/tint/sem/evaluation_stage.h"
#include "src/tint/source.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/vector.h"
// Forward declarations
namespace tint::ast {
@@ -116,12 +115,12 @@
/// Validates pipeline stages
/// @param entry_points the entry points to the module
/// @returns true on success, false otherwise.
- bool PipelineStages(const std::vector<sem::Function*>& entry_points) const;
+ bool PipelineStages(const utils::VectorRef<sem::Function*> entry_points) const;
/// Validates push_constant variables
/// @param entry_points the entry points to the module
/// @returns true on success, false otherwise.
- bool PushConstants(const std::vector<sem::Function*>& entry_points) const;
+ bool PushConstants(const utils::VectorRef<sem::Function*> entry_points) const;
/// Validates aliases
/// @param alias the alias to validate
@@ -156,7 +155,7 @@
/// @returns true on success, false otherwise.
bool AtomicVariable(
const sem::Variable* var,
- std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const;
+ const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const;
/// Validates an assignment
/// @param a the assignment statement
@@ -248,8 +247,8 @@
/// @returns true on success, false otherwise
bool GlobalVariable(
const sem::GlobalVariable* var,
- const std::unordered_map<OverrideId, const sem::Variable*>& override_id,
- const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const;
+ const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_id,
+ const utils::Hashmap<const sem::Type*, const Source*, 8>& atomic_composite_info) const;
/// Validates a break-if statement
/// @param stmt the statement to validate
@@ -297,7 +296,7 @@
bool LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location,
const sem::Type* type,
- std::unordered_set<uint32_t>& locations,
+ utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input = false) const;
@@ -392,7 +391,7 @@
/// @param override_id the set of override ids in the module
/// @returns true on success, false otherwise.
bool Override(const sem::GlobalVariable* v,
- const std::unordered_map<OverrideId, const sem::Variable*>& override_id) const;
+ const utils::Hashmap<OverrideId, const sem::Variable*, 8>& override_id) const;
/// Validates a 'const' variable declaration
/// @param v the variable to validate
diff --git a/src/tint/scope_stack.h b/src/tint/scope_stack.h
index 6838f5b..a2da4dd 100644
--- a/src/tint/scope_stack.h
+++ b/src/tint/scope_stack.h
@@ -14,11 +14,11 @@
#ifndef SRC_TINT_SCOPE_STACK_H_
#define SRC_TINT_SCOPE_STACK_H_
-#include <unordered_map>
#include <utility>
-#include <vector>
#include "src/tint/symbol.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/vector.h"
namespace tint {
@@ -27,22 +27,13 @@
template <class K, class V>
class ScopeStack {
public:
- /// Constructor
- ScopeStack() {
- // Push global bucket
- stack_.push_back({});
- }
- /// Copy Constructor
- ScopeStack(const ScopeStack&) = default;
- ~ScopeStack() = default;
-
/// Push a new scope on to the stack
- void Push() { stack_.push_back({}); }
+ void Push() { stack_.Push({}); }
/// Pop the scope off the top of the stack
void Pop() {
- if (stack_.size() > 1) {
- stack_.pop_back();
+ if (stack_.Length() > 1) {
+ stack_.Pop();
}
}
@@ -52,8 +43,13 @@
/// @returns the old value if there was an existing key at the top of the
/// stack, otherwise the zero initializer for type T.
V Set(const K& key, V val) {
- std::swap(val, stack_.back()[key]);
- return val;
+ auto& back = stack_.Back();
+ if (auto* el = back.Find(key)) {
+ std::swap(val, *el);
+ return val;
+ }
+ back.Add(key, val);
+ return {};
}
/// Retrieves a value from the stack
@@ -61,10 +57,8 @@
/// @returns the value, or the zero initializer if the value was not found
V Get(const K& key) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
- auto& map = *iter;
- auto val = map.find(key);
- if (val != map.end()) {
- return val->second;
+ if (auto* val = iter->Find(key)) {
+ return *val;
}
}
@@ -73,16 +67,16 @@
/// Return the top scope of the stack.
/// @returns the top scope of the stack
- const std::unordered_map<K, V>& Top() const { return stack_.back(); }
+ const utils::Hashmap<K, V, 8>& Top() const { return stack_.Back(); }
/// Clear the scope stack.
void Clear() {
- stack_.clear();
- stack_.push_back({});
+ stack_.Clear();
+ stack_.Push({});
}
private:
- std::vector<std::unordered_map<K, V>> stack_;
+ utils::Vector<utils::Hashmap<K, V, 8>, 8> stack_ = {{}};
};
} // namespace tint
diff --git a/src/tint/utils/hash.h b/src/tint/utils/hash.h
index 717b35f..89cf0f0 100644
--- a/src/tint/utils/hash.h
+++ b/src/tint/utils/hash.h
@@ -157,9 +157,9 @@
template <typename T>
struct UnorderedKeyWrapper {
/// The wrapped value
- const T value;
+ T value;
/// The hash of value
- const size_t hash;
+ size_t hash;
/// Constructor
/// @param v the value to wrap
diff --git a/src/tint/utils/hashmap_base.h b/src/tint/utils/hashmap_base.h
index ca0712a..ee52dad 100644
--- a/src/tint/utils/hashmap_base.h
+++ b/src/tint/utils/hashmap_base.h
@@ -524,7 +524,7 @@
/// Shuffles slots for an insertion that has been placed one slot before `start`.
/// @param start the index of the first slot to start shuffling.
/// @param evicted the slot content that was evicted for the insertion.
- void InsertShuffle(size_t start, Slot evicted) {
+ void InsertShuffle(size_t start, Slot&& evicted) {
Scan(start, [&](size_t, size_t index) {
auto& slot = slots_[index];