Add std::hash<tint::Symbol> specialization
Allows symbols to be used as keys for std::unordered_map and std::unordered_set.
Replace all map / set use of uint32_t for Symbol, where applicable.
Change-Id: If142b4ad1f0ee65bc62209ae2f277e7746be19bb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37262
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/scope_stack.h b/src/scope_stack.h
index e266886..a1619c4 100644
--- a/src/scope_stack.h
+++ b/src/scope_stack.h
@@ -48,14 +48,12 @@
/// Set a global variable in the stack
/// @param symbol the symbol of the variable
/// @param val the value
- void set_global(const Symbol& symbol, T val) {
- stack_[0][symbol.value()] = val;
- }
+ void set_global(const Symbol& symbol, T val) { stack_[0][symbol] = val; }
/// Sets variable into the top most scope of the stack
/// @param symbol the symbol of the variable
/// @param val the value
- void set(const Symbol& symbol, T val) { stack_.back()[symbol.value()] = val; }
+ void set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
/// Checks for the given `symbol` in the stack
/// @param symbol the symbol to look for
@@ -79,7 +77,7 @@
bool get(const Symbol& symbol, T* ret, bool* is_global) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
auto& map = *iter;
- auto val = map.find(symbol.value());
+ auto val = map.find(symbol);
if (val != map.end()) {
if (ret) {
@@ -95,7 +93,7 @@
}
private:
- std::vector<std::unordered_map<uint32_t, T>> stack_;
+ std::vector<std::unordered_map<Symbol, T>> stack_;
};
} // namespace tint
diff --git a/src/symbol.h b/src/symbol.h
index c5ab39b..3a92d6f 100644
--- a/src/symbol.h
+++ b/src/symbol.h
@@ -67,4 +67,20 @@
} // namespace tint
+namespace std {
+
+/// Custom std::hash specialization for tint::Symbol so symbols can be used as
+/// keys for std::unordered_map and std::unordered_set.
+template <>
+class hash<tint::Symbol> {
+ public:
+ /// @param sym the symbol to return
+ /// @return the Symbol internal value
+ inline std::size_t operator()(const tint::Symbol& sym) const {
+ return static_cast<std::size_t>(sym.value());
+ }
+};
+
+} // namespace std
+
#endif // SRC_SYMBOL_H_
diff --git a/src/symbol_table.cc b/src/symbol_table.cc
index 8998ee2..11b903e 100644
--- a/src/symbol_table.cc
+++ b/src/symbol_table.cc
@@ -40,7 +40,7 @@
++next_symbol_;
name_to_symbol_[name] = sym;
- symbol_to_name_[sym.value()] = name;
+ symbol_to_name_[sym] = name;
return sym;
}
@@ -51,7 +51,7 @@
}
std::string SymbolTable::NameFor(const Symbol symbol) const {
- auto it = symbol_to_name_.find(symbol.value());
+ auto it = symbol_to_name_.find(symbol);
if (it == symbol_to_name_.end())
return "";
diff --git a/src/symbol_table.h b/src/symbol_table.h
index 1c08598..b3bf3fe 100644
--- a/src/symbol_table.h
+++ b/src/symbol_table.h
@@ -62,7 +62,7 @@
// The value to be associated to the next registered symbol table entry.
uint32_t next_symbol_ = 1;
- std::unordered_map<uint32_t, std::string> symbol_to_name_;
+ std::unordered_map<Symbol, std::string> symbol_to_name_;
std::unordered_map<std::string, Symbol> name_to_symbol_;
};
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 0dc61d4..f16104d 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -124,7 +124,7 @@
if (!func->IsEntryPoint()) {
continue;
}
- for (const auto& callee : caller_to_callee_[func->symbol().value()]) {
+ for (const auto& callee : caller_to_callee_[func->symbol()]) {
set_entry_points(callee, func->symbol());
}
}
@@ -133,9 +133,9 @@
}
void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) {
- symbol_to_function_[fn_sym.value()]->add_ancestor_entry_point(ep_sym);
+ symbol_to_function_[fn_sym]->add_ancestor_entry_point(ep_sym);
- for (const auto& callee : caller_to_callee_[fn_sym.value()]) {
+ for (const auto& callee : caller_to_callee_[fn_sym]) {
set_entry_points(callee, ep_sym);
}
}
@@ -150,7 +150,7 @@
}
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
- symbol_to_function_[func->symbol().value()] = func;
+ symbol_to_function_[func->symbol()] = func;
current_function_ = func;
@@ -389,7 +389,7 @@
}
} else {
if (current_function_) {
- caller_to_callee_[current_function_->symbol().value()].push_back(
+ caller_to_callee_[current_function_->symbol()].push_back(
ident->symbol());
auto* callee_func = mod_->FindFunctionBySymbol(ident->symbol());
@@ -906,7 +906,7 @@
return true;
}
- auto iter = symbol_to_function_.find(symbol.value());
+ auto iter = symbol_to_function_.find(symbol);
if (iter != symbol_to_function_.end()) {
expr->set_result_type(iter->second->return_type());
return true;
diff --git a/src/type_determiner.h b/src/type_determiner.h
index d13b97c..479bbca 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -129,11 +129,11 @@
ast::Module* mod_;
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
- std::unordered_map<uint32_t, ast::Function*> symbol_to_function_;
+ std::unordered_map<Symbol, ast::Function*> symbol_to_function_;
ast::Function* current_function_ = nullptr;
// Map from caller functions to callee functions.
- std::unordered_map<uint32_t, std::vector<Symbol>> caller_to_callee_;
+ std::unordered_map<Symbol, std::vector<Symbol>> caller_to_callee_;
};
} // namespace tint
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index af842ae..145a8cb 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -160,7 +160,7 @@
}
}
- std::unordered_set<uint32_t> emitted_globals;
+ std::unordered_set<Symbol> emitted_globals;
// Make sure all entry point data is emitted before the entry point functions
for (auto* func : module_->functions()) {
if (!func->IsEntryPoint()) {
@@ -198,14 +198,14 @@
Symbol sym;
switch (type) {
case VarType::kIn: {
- auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ auto in_it = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_it != ep_sym_to_in_data_.end()) {
sym = in_it->second.var_symbol;
}
break;
}
case VarType::kOut: {
- auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ auto outit = ep_sym_to_out_data_.find(current_ep_sym_);
if (outit != ep_sym_to_out_data_.end()) {
sym = outit->second.var_symbol;
}
@@ -1279,14 +1279,14 @@
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
- auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
+ auto in_it = ep_sym_to_in_data_.find(ep_sym);
if (in_it != ep_sym_to_in_data_.end()) {
out << "in " << namer_->NameFor(in_it->second.struct_symbol) << " "
<< namer_->NameFor(in_it->second.var_symbol);
first = false;
}
- auto outit = ep_sym_to_out_data_.find(ep_sym.value());
+ auto outit = ep_sym_to_out_data_.find(ep_sym);
if (outit != ep_sym_to_out_data_.end()) {
if (!first) {
out << ", ";
@@ -1328,7 +1328,7 @@
bool GeneratorImpl::EmitEntryPointData(
std::ostream& out,
ast::Function* func,
- std::unordered_set<uint32_t>& emitted_globals) {
+ std::unordered_set<Symbol>& emitted_globals) {
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> in_variables;
std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>> outvariables;
for (auto data : func->referenced_location_variables()) {
@@ -1369,10 +1369,10 @@
// If the global has already been emitted we skip it, it's been emitted by
// a previous entry point.
- if (emitted_globals.count(var->symbol().value()) != 0) {
+ if (emitted_globals.count(var->symbol()) != 0) {
continue;
}
- emitted_globals.insert(var->symbol().value());
+ emitted_globals.insert(var->symbol());
auto* type = var->type()->UnwrapIfNeeded();
if (auto* strct = type->As<ast::type::Struct>()) {
@@ -1413,10 +1413,10 @@
// If the global has already been emitted we skip it, it's been emitted by
// a previous entry point.
- if (emitted_globals.count(var->symbol().value()) != 0) {
+ if (emitted_globals.count(var->symbol()) != 0) {
continue;
}
- emitted_globals.insert(var->symbol().value());
+ emitted_globals.insert(var->symbol());
auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac == nullptr) {
@@ -1439,8 +1439,8 @@
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
- ep_sym_to_in_data_[func->symbol().value()] = {
- in_struct_sym, module_->RegisterSymbol(in_var_name)};
+ ep_sym_to_in_data_[func->symbol()] = {in_struct_sym,
+ module_->RegisterSymbol(in_var_name)};
make_indent(out);
out << "struct " << namer_->NameFor(in_struct_sym) << " {" << std::endl;
@@ -1486,7 +1486,7 @@
auto outstruct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
auto outvar_name = namer_->GenerateName(kTintStructOutVarPrefix);
- ep_sym_to_out_data_[func->symbol().value()] = {
+ ep_sym_to_out_data_[func->symbol()] = {
outstruct_sym, module_->RegisterSymbol(outvar_name)};
make_indent(out);
@@ -1577,7 +1577,7 @@
make_indent(out);
}
- auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
bool has_outdata = outdata != ep_sym_to_out_data_.end();
if (has_outdata) {
out << namer_->NameFor(outdata->second.struct_symbol);
@@ -1586,7 +1586,7 @@
}
out << " " << namer_->NameFor(current_ep_sym_) << "(";
- auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_data != ep_sym_to_in_data_.end()) {
out << namer_->NameFor(in_data->second.struct_symbol) << " "
<< namer_->NameFor(in_data->second.var_symbol);
@@ -2023,7 +2023,7 @@
if (generating_entry_point_) {
out << "return";
- auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ auto outdata = ep_sym_to_out_data_.find(current_ep_sym_);
if (outdata != ep_sym_to_out_data_.end()) {
out << " " << namer_->NameFor(outdata->second.var_symbol);
}
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index 8f8a37b..8a96cc5 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -225,7 +225,7 @@
/// @returns true if the entry point data was emitted
bool EmitEntryPointData(std::ostream& out,
ast::Function* func,
- std::unordered_set<uint32_t>& emitted_globals);
+ std::unordered_set<Symbol>& emitted_globals);
/// Handles emitting the entry point function
/// @param out the output stream
/// @param func the entry point
@@ -395,8 +395,8 @@
bool generating_entry_point_ = false;
uint32_t loop_emission_counter_ = 0;
ScopeStack<ast::Variable*> global_variables_;
- std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
- std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
+ std::unordered_map<Symbol, EntryPointData> ep_sym_to_in_data_;
+ std::unordered_map<Symbol, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did
diff --git a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc
index 2f58c8e..8dad257 100644
--- a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc
+++ b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc
@@ -72,7 +72,7 @@
mod->AddFunction(func);
- std::unordered_set<uint32_t> globals;
+ std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@@ -122,7 +122,7 @@
mod->AddFunction(func);
- std::unordered_set<uint32_t> globals;
+ std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@@ -172,7 +172,7 @@
mod->AddFunction(func);
- std::unordered_set<uint32_t> globals;
+ std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@@ -222,7 +222,7 @@
mod->AddFunction(func);
- std::unordered_set<uint32_t> globals;
+ std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@@ -269,7 +269,7 @@
mod->AddFunction(func);
- std::unordered_set<uint32_t> globals;
+ std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@@ -311,7 +311,7 @@
mod->AddFunction(func);
- std::unordered_set<uint32_t> globals;
+ std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_FALSE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
@@ -361,7 +361,7 @@
mod->AddFunction(func);
- std::unordered_set<uint32_t> globals;
+ std::unordered_set<Symbol> globals;
ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(gen.EmitEntryPointData(out, func, globals)) << gen.error();
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index e21221e..63f760e 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -400,14 +400,14 @@
Symbol sym;
switch (type) {
case VarType::kIn: {
- auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ auto in_it = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_it != ep_sym_to_in_data_.end()) {
sym = in_it->second.var_symbol;
}
break;
}
case VarType::kOut: {
- auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ auto out_it = ep_sym_to_out_data_.find(current_ep_sym_);
if (out_it != ep_sym_to_out_data_.end()) {
sym = out_it->second.var_symbol;
}
@@ -1061,7 +1061,7 @@
auto in_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kInStructNameSuffix));
auto in_var_name = namer_->GenerateName(kTintStructInVarPrefix);
- ep_sym_to_in_data_[func->symbol().value()] = {
+ ep_sym_to_in_data_[func->symbol()] = {
in_struct_sym, module_->RegisterSymbol(in_var_name)};
make_indent();
@@ -1099,7 +1099,7 @@
auto out_struct_sym = module_->RegisterSymbol(namer_->GenerateName(
module_->SymbolToName(func->symbol()) + "_" + kOutStructNameSuffix));
auto out_var_name = namer_->GenerateName(kTintStructOutVarPrefix);
- ep_sym_to_out_data_[func->symbol().value()] = {
+ ep_sym_to_out_data_[func->symbol()] = {
out_struct_sym, module_->RegisterSymbol(out_var_name)};
make_indent();
@@ -1284,14 +1284,14 @@
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
- auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
+ auto in_it = ep_sym_to_in_data_.find(ep_sym);
if (in_it != ep_sym_to_in_data_.end()) {
out_ << "thread " << namer_->NameFor(in_it->second.struct_symbol) << "& "
<< namer_->NameFor(in_it->second.var_symbol);
first = false;
}
- auto out_it = ep_sym_to_out_data_.find(ep_sym.value());
+ auto out_it = ep_sym_to_out_data_.find(ep_sym);
if (out_it != ep_sym_to_out_data_.end()) {
if (!first) {
out_ << ", ";
@@ -1421,7 +1421,7 @@
// This is an entry point, the return type is the entry point output structure
// if one exists, or void otherwise.
- auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
bool has_out_data = out_data != ep_sym_to_out_data_.end();
if (has_out_data) {
out_ << namer_->NameFor(out_data->second.struct_symbol);
@@ -1431,7 +1431,7 @@
out_ << " " << namer_->NameFor(func->symbol()) << "(";
bool first = true;
- auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_data != ep_sym_to_in_data_.end()) {
out_ << namer_->NameFor(in_data->second.struct_symbol) << " "
<< namer_->NameFor(in_data->second.var_symbol) << " [[stage_in]]";
@@ -1734,7 +1734,7 @@
out_ << "return";
if (generating_entry_point_) {
- auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
if (out_data != ep_sym_to_out_data_.end()) {
out_ << " " << namer_->NameFor(out_data->second.var_symbol);
}
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index ece2afe..75ea8d4 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -281,8 +281,8 @@
uint32_t loop_emission_counter_ = 0;
Namer* namer_;
- std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
- std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
+ std::unordered_map<Symbol, EntryPointData> ep_sym_to_in_data_;
+ std::unordered_map<Symbol, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 9cdf602..edc1fa3 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -587,7 +587,7 @@
scope_stack_.pop_scope();
- func_symbol_to_id_[func->symbol().value()] = func_id;
+ func_symbol_to_id_[func->symbol()] = func_id;
return true;
}
@@ -1814,7 +1814,7 @@
OperandList ops = {Operand::Int(type_id), result};
- auto func_id = func_symbol_to_id_[ident->symbol().value()];
+ auto func_id = func_symbol_to_id_[ident->symbol()];
if (func_id == 0) {
error_ = "unable to find called function: " +
mod_->SymbolToName(ident->symbol());
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 2c2c38d..e985e93 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -508,7 +508,7 @@
std::vector<Function> functions_;
std::unordered_map<std::string, uint32_t> import_name_to_id_;
- std::unordered_map<uint32_t, uint32_t> func_symbol_to_id_;
+ std::unordered_map<Symbol, uint32_t> func_symbol_to_id_;
std::unordered_map<std::string, uint32_t> type_name_to_id_;
std::unordered_map<std::string, uint32_t> const_to_id_;
std::unordered_map<std::string, uint32_t>