Add a symbol to the Function AST node.
This Cl adds a Symbol representing the function name to the function
AST. The symbol is added alongside the name for now. When all usages of
the function name are removed then the string version will be removed
from the constructor.
Change-Id: Ib2450e5fe531e988b25bb7d2937acc6af2187871
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35220
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 26522e8..ae69287 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -31,12 +31,14 @@
namespace ast {
Function::Function(const Source& source,
+ Symbol symbol,
const std::string& name,
VariableList params,
type::Type* return_type,
BlockStatement* body,
FunctionDecorationList decorations)
: Base(source),
+ symbol_(symbol),
name_(name),
params_(std::move(params)),
return_type_(return_type),
@@ -202,7 +204,7 @@
return ret;
}
-void Function::add_ancestor_entry_point(const std::string& ep) {
+void Function::add_ancestor_entry_point(Symbol ep) {
for (const auto& point : ancestor_entry_points_) {
if (point == ep) {
return;
@@ -211,9 +213,9 @@
ancestor_entry_points_.push_back(ep);
}
-bool Function::HasAncestorEntryPoint(const std::string& name) const {
+bool Function::HasAncestorEntryPoint(Symbol symbol) const {
for (const auto& point : ancestor_entry_points_) {
- if (point == name) {
+ if (point == symbol) {
return true;
}
}
@@ -226,7 +228,7 @@
Function* Function::Clone(CloneContext* ctx) const {
return ctx->mod->create<Function>(
- ctx->Clone(source()), name_, ctx->Clone(params_),
+ ctx->Clone(source()), symbol_, name_, ctx->Clone(params_),
ctx->Clone(return_type_), ctx->Clone(body_), ctx->Clone(decorations_));
}
@@ -238,7 +240,7 @@
if (body_ == nullptr || !body_->IsValid()) {
return false;
}
- if (name_.length() == 0) {
+ if (name_.length() == 0 || !symbol_.IsValid()) {
return false;
}
if (return_type_ == nullptr) {
@@ -249,7 +251,7 @@
void Function::to_str(std::ostream& out, size_t indent) const {
make_indent(out, indent);
- out << "Function " << name_ << " -> " << return_type_->type_name()
+ out << "Function " << symbol_.to_str() << " -> " << return_type_->type_name()
<< std::endl;
for (auto* deco : decorations()) {
diff --git a/src/ast/function.h b/src/ast/function.h
index 61eb529..14b6789 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -35,6 +35,7 @@
#include "src/ast/type/sampler_type.h"
#include "src/ast/type/type.h"
#include "src/ast/variable.h"
+#include "src/symbol.h"
namespace tint {
namespace ast {
@@ -52,12 +53,14 @@
/// Create a function
/// @param source the variable source
+ /// @param symbol the function symbol
/// @param name the function name
/// @param params the function parameters
/// @param return_type the return type
/// @param body the function body
/// @param decorations the function decorations
Function(const Source& source,
+ Symbol symbol,
const std::string& name,
VariableList params,
type::Type* return_type,
@@ -68,6 +71,8 @@
~Function() override;
+ /// @returns the function symbol
+ Symbol symbol() const { return symbol_; }
/// @returns the function name
const std::string& name() { return name_; }
/// @returns the function params
@@ -150,15 +155,15 @@
/// Adds an ancestor entry point
/// @param ep the entry point ancestor
- void add_ancestor_entry_point(const std::string& ep);
+ void add_ancestor_entry_point(Symbol ep);
/// @returns the ancestor entry points
- const std::vector<std::string>& ancestor_entry_points() const {
+ const std::vector<Symbol>& ancestor_entry_points() const {
return ancestor_entry_points_;
}
/// Checks if the given entry point is an ancestor
- /// @param name the entry point name
- /// @returns true if `name` is an ancestor entry point of this function
- bool HasAncestorEntryPoint(const std::string& name) const;
+ /// @param sym the entry point symbol
+ /// @returns true if `sym` is an ancestor entry point of this function
+ bool HasAncestorEntryPoint(Symbol sym) const;
/// @returns the function return type.
type::Type* return_type() const { return return_type_; }
@@ -197,13 +202,14 @@
const std::vector<std::pair<Variable*, Function::BindingInfo>>
ReferencedSampledTextureVariablesImpl(bool multisampled) const;
+ Symbol symbol_;
std::string name_;
VariableList params_;
type::Type* return_type_ = nullptr;
BlockStatement* body_ = nullptr;
std::vector<Variable*> referenced_module_vars_;
std::vector<Variable*> local_referenced_module_vars_;
- std::vector<std::string> ancestor_entry_points_;
+ std::vector<Symbol> ancestor_entry_points_;
FunctionDecorationList decorations_;
};
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 68f9921..e341fc8 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -35,14 +35,18 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
auto* var = params[0];
- Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
- FunctionDecorationList{});
+ Function f(Source{}, func_sym, "func", params, &void_type,
+ create<BlockStatement>(), FunctionDecorationList{});
+ EXPECT_EQ(f.symbol(), func_sym);
EXPECT_EQ(f.name(), "func");
ASSERT_EQ(f.params().size(), 1u);
EXPECT_EQ(f.return_type(), &void_type);
@@ -53,13 +57,16 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
- Function f(Source{Source::Location{20, 2}}, "func", params, &void_type,
- create<BlockStatement>(), FunctionDecorationList{});
+ Function f(Source{Source::Location{20, 2}}, func_sym, "func", params,
+ &void_type, create<BlockStatement>(), FunctionDecorationList{});
auto src = f.source();
EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u);
@@ -69,9 +76,12 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
Variable v(Source{}, "var", StorageClass::kInput, &i32, false, nullptr,
ast::VariableDecorationList{});
- Function f(Source{}, "func", VariableList{}, &void_type,
+ Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(&v);
@@ -92,6 +102,9 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32,
false, nullptr,
ast::VariableDecorationList{
@@ -116,7 +129,7 @@
create<BuiltinDecoration>(Builtin::kFragDepth, Source{}),
});
- Function f(Source{}, "func", VariableList{}, &void_type,
+ Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(loc1);
@@ -137,6 +150,9 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32,
false, nullptr,
ast::VariableDecorationList{
@@ -161,7 +177,7 @@
create<BuiltinDecoration>(Builtin::kFragDepth, Source{}),
});
- Function f(Source{}, "func", VariableList{}, &void_type,
+ Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(loc1);
@@ -180,22 +196,30 @@
TEST_F(FunctionTest, AddDuplicateEntryPoints) {
type::Void void_type;
- Function f(Source{}, "func", VariableList{}, &void_type,
+
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+ auto main_sym = m.RegisterSymbol("main");
+
+ Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
- f.add_ancestor_entry_point("main");
+ f.add_ancestor_entry_point(main_sym);
ASSERT_EQ(1u, f.ancestor_entry_points().size());
- EXPECT_EQ("main", f.ancestor_entry_points()[0]);
+ EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]);
- f.add_ancestor_entry_point("main");
+ f.add_ancestor_entry_point(main_sym);
ASSERT_EQ(1u, f.ancestor_entry_points().size());
- EXPECT_EQ("main", f.ancestor_entry_points()[0]);
+ EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]);
}
TEST_F(FunctionTest, IsValid) {
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@@ -204,21 +228,27 @@
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
- Function f(Source{}, "func", params, &void_type, body,
+ Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_TRUE(f.IsValid());
}
-TEST_F(FunctionTest, IsValid_EmptyName) {
+TEST_F(FunctionTest, IsValid_InvalidName) {
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
- Function f(Source{}, "", params, &void_type, create<BlockStatement>(),
+ auto* body = create<BlockStatement>();
+ body->append(create<DiscardStatement>());
+
+ Function f(Source{}, func_sym, "", params, &void_type, body,
FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@@ -226,13 +256,16 @@
TEST_F(FunctionTest, IsValid_MissingReturnType) {
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
- Function f(Source{}, "func", params, nullptr, create<BlockStatement>(),
- FunctionDecorationList{});
+ Function f(Source{}, func_sym, "func", params, nullptr,
+ create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@@ -240,27 +273,33 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
params.push_back(nullptr);
- Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
- FunctionDecorationList{});
+ Function f(Source{}, func_sym, "func", params, &void_type,
+ create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
TEST_F(FunctionTest, IsValid_InvalidParam) {
type::Void void_type;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone,
nullptr, false, nullptr,
ast::VariableDecorationList{}));
- Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
- FunctionDecorationList{});
+ Function f(Source{}, func_sym, "func", params, &void_type,
+ create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@@ -268,6 +307,9 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@@ -277,7 +319,7 @@
body->append(create<DiscardStatement>());
body->append(nullptr);
- Function f(Source{}, "func", params, &void_type, body,
+ Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
@@ -287,6 +329,9 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@@ -296,7 +341,7 @@
body->append(create<DiscardStatement>());
body->append(nullptr);
- Function f(Source{}, "func", params, &void_type, body,
+ Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@@ -305,14 +350,18 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
- Function f(Source{}, "func", {}, &void_type, body, FunctionDecorationList{});
+ Function f(Source{}, func_sym, "func", {}, &void_type, body,
+ FunctionDecorationList{});
std::ostringstream out;
f.to_str(out, 2);
- EXPECT_EQ(out.str(), R"( Function func -> __void
+ EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
()
{
Discard{}
@@ -324,16 +373,19 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
Function f(
- Source{}, "func", {}, &void_type, body,
+ Source{}, func_sym, "func", {}, &void_type, body,
FunctionDecorationList{create<WorkgroupDecoration>(2, 4, 6, Source{})});
std::ostringstream out;
f.to_str(out, 2);
- EXPECT_EQ(out.str(), R"( Function func -> __void
+ EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
WorkgroupDecoration{2 4 6}
()
{
@@ -346,6 +398,9 @@
type::Void void_type;
type::I32 i32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@@ -354,12 +409,12 @@
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
- Function f(Source{}, "func", params, &void_type, body,
+ Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
std::ostringstream out;
f.to_str(out, 2);
- EXPECT_EQ(out.str(), R"( Function func -> __void
+ EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
(
Variable{
var
@@ -376,8 +431,11 @@
TEST_F(FunctionTest, TypeName) {
type::Void void_type;
- Function f(Source{}, "func", {}, &void_type, create<BlockStatement>(),
- FunctionDecorationList{});
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
+ Function f(Source{}, func_sym, "func", {}, &void_type,
+ create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.type_name(), "__func__void");
}
@@ -386,6 +444,9 @@
type::I32 i32;
type::F32 f32;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
params.push_back(create<Variable>(Source{}, "var1", StorageClass::kNone, &i32,
false, nullptr,
@@ -394,19 +455,22 @@
false, nullptr,
ast::VariableDecorationList{}));
- Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
- FunctionDecorationList{});
+ Function f(Source{}, func_sym, "func", params, &void_type,
+ create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.type_name(), "__func__void__i32__f32");
}
TEST_F(FunctionTest, GetLastStatement) {
type::Void void_type;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
auto* body = create<BlockStatement>();
auto* stmt = create<DiscardStatement>();
body->append(stmt);
- Function f(Source{}, "func", params, &void_type, body,
+ Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_EQ(f.get_last_statement(), stmt);
@@ -415,9 +479,12 @@
TEST_F(FunctionTest, GetLastStatement_nullptr) {
type::Void void_type;
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
VariableList params;
auto* body = create<BlockStatement>();
- Function f(Source{}, "func", params, &void_type, body,
+ Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_EQ(f.get_last_statement(), nullptr);
@@ -425,8 +492,12 @@
TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
type::Void void_type;
- Function f(Source{}, "f", {}, &void_type, create<BlockStatement>(),
- FunctionDecorationList{});
+
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
+ Function f(Source{}, func_sym, "func", {}, &void_type,
+ create<BlockStatement>(), FunctionDecorationList{});
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
@@ -438,7 +509,12 @@
TEST_F(FunctionTest, WorkgroupSize) {
type::Void void_type;
- Function f(Source{}, "f", {}, &void_type, create<BlockStatement>(),
+
+ Module m;
+ auto func_sym = m.RegisterSymbol("func");
+
+ Function f(Source{}, func_sym, "func", {}, &void_type,
+ create<BlockStatement>(),
{create<WorkgroupDecoration>(2u, 4u, 6u, Source{})});
uint32_t x = 0;
diff --git a/src/ast/module.cc b/src/ast/module.cc
index 5dcc12b..2132e1f 100644
--- a/src/ast/module.cc
+++ b/src/ast/module.cc
@@ -47,21 +47,23 @@
for (auto* func : functions_) {
ctx->mod->functions_.emplace_back(ctx->Clone(func));
}
+
+ ctx->mod->symbol_table_ = symbol_table_;
}
-Function* Module::FindFunctionByName(const std::string& name) const {
+Function* Module::FindFunctionBySymbol(Symbol sym) const {
for (auto* func : functions_) {
- if (func->name() == name) {
+ if (func->symbol() == sym) {
return func;
}
}
return nullptr;
}
-Function* Module::FindFunctionByNameAndStage(const std::string& name,
- PipelineStage stage) const {
+Function* Module::FindFunctionBySymbolAndStage(Symbol sym,
+ PipelineStage stage) const {
for (auto* func : functions_) {
- if (func->name() == name && func->pipeline_stage() == stage) {
+ if (func->symbol() == sym && func->pipeline_stage() == stage) {
return func;
}
}
@@ -81,6 +83,10 @@
return symbol_table_.Register(name);
}
+Symbol Module::GetSymbol(const std::string& name) const {
+ return symbol_table_.GetSymbol(name);
+}
+
std::string Module::SymbolToName(const Symbol sym) const {
return symbol_table_.NameFor(sym);
}
diff --git a/src/ast/module.h b/src/ast/module.h
index b274be1..1facd79 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -89,15 +89,14 @@
/// @returns the modules functions
const FunctionList& functions() const { return functions_; }
/// Returns the function with the given name
- /// @param name the name to search for
+ /// @param sym the function symbol to search for
/// @returns the associated function or nullptr if none exists
- Function* FindFunctionByName(const std::string& name) const;
+ Function* FindFunctionBySymbol(Symbol sym) const;
/// Returns the function with the given name
- /// @param name the name to search for
+ /// @param sym the function symbol to search for
/// @param stage the pipeline stage
/// @returns the associated function or nullptr if none exists
- Function* FindFunctionByNameAndStage(const std::string& name,
- PipelineStage stage) const;
+ Function* FindFunctionBySymbolAndStage(Symbol sym, PipelineStage stage) const;
/// @param stage the pipeline stage
/// @returns true if the module contains an entrypoint function with the given
/// stage
@@ -169,6 +168,11 @@
/// previously generated symbol will be returned.
Symbol RegisterSymbol(const std::string& name);
+ /// Returns the symbol for `name`
+ /// @param name the name to lookup
+ /// @returns the symbol for name or symbol::kInvalid
+ Symbol GetSymbol(const std::string& name) const;
+
/// Returns the `name` for `sym`
/// @param sym the symbol to retrieve the name for
/// @returns the use provided `name` for the symbol or "" if not found
diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc
index 3303235..6914ecc 100644
--- a/src/ast/module_test.cc
+++ b/src/ast/module_test.cc
@@ -48,16 +48,17 @@
type::F32 f32;
Module m;
+ auto func_sym = m.RegisterSymbol("main");
auto* func =
- create<Function>(Source{}, "main", VariableList{}, &f32,
+ create<Function>(Source{}, func_sym, "main", VariableList{}, &f32,
create<BlockStatement>(), ast::FunctionDecorationList{});
m.AddFunction(func);
- EXPECT_EQ(func, m.FindFunctionByName("main"));
+ EXPECT_EQ(func, m.FindFunctionBySymbol(func_sym));
}
TEST_F(ModuleTest, LookupFunctionMissing) {
Module m;
- EXPECT_EQ(nullptr, m.FindFunctionByName("Missing"));
+ EXPECT_EQ(nullptr, m.FindFunctionBySymbol(m.RegisterSymbol("Missing")));
}
TEST_F(ModuleTest, IsValid_Empty) {
@@ -127,11 +128,12 @@
TEST_F(ModuleTest, IsValid_Function) {
type::F32 f32;
- auto* func =
- create<Function>(Source{}, "main", VariableList(), &f32,
- create<BlockStatement>(), ast::FunctionDecorationList{});
Module m;
+
+ auto* func = create<Function>(Source{}, m.RegisterSymbol("main"), "main",
+ VariableList(), &f32, create<BlockStatement>(),
+ ast::FunctionDecorationList{});
m.AddFunction(func);
EXPECT_TRUE(m.IsValid());
}
@@ -144,10 +146,13 @@
TEST_F(ModuleTest, IsValid_Invalid_Function) {
VariableList p;
- auto* func = create<Function>(Source{}, "", p, nullptr, nullptr,
- ast::FunctionDecorationList{});
Module m;
+
+ auto* func =
+ create<Function>(Source{}, m.RegisterSymbol("main"), "main", p, nullptr,
+ nullptr, ast::FunctionDecorationList{});
+
m.AddFunction(func);
EXPECT_FALSE(m.IsValid());
}
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index 1bbbe8b..7ef2311 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -267,7 +267,7 @@
}
ast::Function* Inspector::FindEntryPointByName(const std::string& name) {
- auto* func = module_.FindFunctionByName(name);
+ auto* func = module_.FindFunctionBySymbol(module_.GetSymbol(name));
if (!func) {
error_ += name + " was not found!";
return nullptr;
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index 46ac6e4..fbf24e5 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -83,8 +83,9 @@
ast::FunctionDecorationList decorations = {}) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, name, ast::VariableList(),
- void_type(), body, decorations);
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(name), name,
+ ast::VariableList(), void_type(), body,
+ decorations);
}
/// Generates a function that calls another
@@ -102,8 +103,9 @@
create<ast::CallExpression>(ident_expr, ast::ExpressionList());
body->append(create<ast::CallStatement>(call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, caller, ast::VariableList(),
- void_type(), body, decorations);
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(caller),
+ caller, ast::VariableList(), void_type(), body,
+ decorations);
}
/// Add In/Out variables to the global variables
@@ -154,8 +156,9 @@
create<ast::IdentifierExpression>(in)));
}
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, name, ast::VariableList(),
- void_type(), body, decorations);
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(name), name,
+ ast::VariableList(), void_type(), body,
+ decorations);
}
/// Generates a function that references in/out variables and calls another
@@ -184,8 +187,9 @@
create<ast::CallExpression>(ident_expr, ast::ExpressionList());
body->append(create<ast::CallStatement>(call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, caller, ast::VariableList(),
- void_type(), body, decorations);
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(caller),
+ caller, ast::VariableList(), void_type(), body,
+ decorations);
}
/// Add a Constant ID to the global variables.
@@ -445,9 +449,9 @@
}
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, func_name, ast::VariableList(),
- void_type(), body,
- ast::FunctionDecorationList{});
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
+ func_name, ast::VariableList(), void_type(),
+ body, ast::FunctionDecorationList{});
}
/// Adds a regular sampler variable to the module
@@ -587,8 +591,9 @@
create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, func_name, ast::VariableList(),
- void_type(), body, decorations);
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
+ func_name, ast::VariableList(), void_type(),
+ body, decorations);
}
/// Generates a function that references a specific sampler variable
@@ -634,8 +639,9 @@
create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, func_name, ast::VariableList(),
- void_type(), body, decorations);
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
+ func_name, ast::VariableList(), void_type(),
+ body, decorations);
}
/// Generates a function that references a specific comparison sampler
@@ -682,8 +688,9 @@
create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
- return create<ast::Function>(Source{}, func_name, ast::VariableList(),
- void_type(), body, decorations);
+ return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
+ func_name, ast::VariableList(), void_type(),
+ body, decorations);
}
/// Gets an appropriate type for the data in a given texture type.
@@ -1513,7 +1520,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>(
- Source{}, "ep_func", ast::VariableList(), void_type(), body,
+ Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
+ ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -1659,7 +1667,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>(
- Source{}, "ep_func", ast::VariableList(), void_type(), body,
+ Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
+ ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -1832,7 +1841,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>(
- Source{}, "ep_func", ast::VariableList(), void_type(), body,
+ Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
+ ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index ec024bb..c8affa7 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -761,9 +761,10 @@
}
auto* body = statements_stack_[0].statements_;
- ast_module_.AddFunction(create<ast::Function>(
- decl.source, decl.name, std::move(decl.params), decl.return_type, body,
- std::move(decl.decorations)));
+ ast_module_.AddFunction(
+ create<ast::Function>(decl.source, ast_module_.RegisterSymbol(decl.name),
+ decl.name, std::move(decl.params), decl.return_type,
+ body, std::move(decl.decorations)));
// Maintain the invariant by repopulating the one and only element.
statements_stack_.clear();
diff --git a/src/reader/spirv/function_call_test.cc b/src/reader/spirv/function_call_test.cc
index a30ec37..fabf819 100644
--- a/src/reader/spirv/function_call_test.cc
+++ b/src/reader/spirv/function_call_test.cc
@@ -46,14 +46,16 @@
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error();
- const auto module_ast_str = p->module().to_str();
+ const auto module_ast_str = p->get_module().to_str();
EXPECT_THAT(module_ast_str, Eq(R"(Module{
- Function x_50 -> __void
+ Function )" + p->get_module().GetSymbol("x_50").to_str() +
+ R"( -> __void
()
{
Return{}
}
- Function x_100 -> __void
+ Function )" + p->get_module().GetSymbol("x_100").to_str() +
+ R"( -> __void
()
{
Call[not set]{
@@ -214,9 +216,10 @@
)"));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error();
EXPECT_TRUE(p->error().empty());
- const auto module_ast_str = p->module().to_str();
+ const auto module_ast_str = p->get_module().to_str();
EXPECT_THAT(module_ast_str, HasSubstr(R"(Module{
- Function x_50 -> __u32
+ Function )" + p->get_module().GetSymbol("x_50").to_str() +
+ R"( -> __u32
(
VariableConst{
x_51
@@ -240,7 +243,8 @@
}
}
}
- Function x_100 -> __void
+ Function )" + p->get_module().GetSymbol("x_100").to_str() +
+ R"( -> __void
()
{
VariableDeclStatement{
diff --git a/src/reader/spirv/function_decl_test.cc b/src/reader/spirv/function_decl_test.cc
index 2223893..398f8fa 100644
--- a/src/reader/spirv/function_decl_test.cc
+++ b/src/reader/spirv/function_decl_test.cc
@@ -59,9 +59,10 @@
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
- auto got = p->module().to_str();
- auto* expect = R"(Module{
- Function x_100 -> __void
+ auto got = p->get_module().to_str();
+ auto expect = R"(Module{
+ Function )" + p->get_module().GetSymbol("x_100").to_str() +
+ R"( -> __void
()
{
Return{}
@@ -83,9 +84,10 @@
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
- auto got = p->module().to_str();
- auto* expect = R"(Module{
- Function x_100 -> __f32
+ auto got = p->get_module().to_str();
+ auto expect = R"(Module{
+ Function )" + p->get_module().GetSymbol("x_100").to_str() +
+ R"( -> __f32
()
{
Return{
@@ -115,9 +117,10 @@
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
- auto got = p->module().to_str();
- auto* expect = R"(Module{
- Function x_100 -> __void
+ auto got = p->get_module().to_str();
+ auto expect = R"(Module{
+ Function )" + p->get_module().GetSymbol("x_100").to_str() +
+ R"( -> __void
(
VariableConst{
a
@@ -159,9 +162,10 @@
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
- auto got = p->module().to_str();
- auto* expect = R"(Module{
- Function x_100 -> __void
+ auto got = p->get_module().to_str();
+ auto expect = R"(Module{
+ Function )" + p->get_module().GetSymbol("x_100").to_str() +
+ R"( -> __void
(
VariableConst{
x_14
diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc
index f322de2..dab5f3a 100644
--- a/src/reader/spirv/parser_impl_function_decl_test.cc
+++ b/src/reader/spirv/parser_impl_function_decl_test.cc
@@ -53,7 +53,7 @@
auto p = parser(test::Assemble(CommonTypes()));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, Not(HasSubstr("Function{")));
}
@@ -64,7 +64,7 @@
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, Not(HasSubstr("Function{")));
}
@@ -79,9 +79,10 @@
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function main -> __void
+ Function )" + p->get_module().GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{)"));
@@ -98,9 +99,10 @@
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function main -> __void
+ Function )" + p->get_module().GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{fragment}
()
{)"));
@@ -117,9 +119,10 @@
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function main -> __void
+ Function )" + p->get_module().GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{compute}
()
{)"));
@@ -138,14 +141,16 @@
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function frag_main -> __void
+ Function )" + p->get_module().GetSymbol("frag_main").to_str() +
+ R"( -> __void
StageDecoration{fragment}
()
{)"));
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function comp_main -> __void
+ Function )" + p->get_module().GetSymbol("comp_main").to_str() +
+ R"( -> __void
StageDecoration{compute}
()
{)"));
@@ -160,9 +165,10 @@
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function main -> __void
+ Function )" + p->get_module().GetSymbol("main").to_str() +
+ R"( -> __void
()
{)"));
}
@@ -193,9 +199,10 @@
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function leaf -> __u32
+ Function )" + p->get_module().GetSymbol("leaf").to_str() +
+ R"( -> __u32
()
{
Return{
@@ -204,7 +211,8 @@
}
}
}
- Function branch -> __u32
+ Function )" + p->get_module().GetSymbol("branch").to_str() +
+ R"( -> __u32
()
{
VariableDeclStatement{
@@ -227,7 +235,8 @@
}
}
}
- Function root -> __void
+ Function )" + p->get_module().GetSymbol("root").to_str() +
+ R"( -> __void
()
{
VariableDeclStatement{
@@ -260,9 +269,10 @@
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function ret_float -> __f32
+ Function )" + p->get_module().GetSymbol("ret_float").to_str() +
+ R"( -> __f32
()
{
Return{
@@ -289,9 +299,10 @@
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function mixed_params -> __void
+ Function )" + p->get_module().GetSymbol("mixed_params").to_str() +
+ R"( -> __void
(
VariableConst{
a
@@ -328,9 +339,10 @@
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
- const auto module_ast = p->module().to_str();
+ const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
- Function mixed_params -> __void
+ Function )" + p->get_module().GetSymbol("mixed_params").to_str() +
+ R"( -> __void
(
VariableConst{
x_14
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 67a4651..04ff79b 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1280,9 +1280,9 @@
if (errored)
return Failure::kErrored;
- return create<ast::Function>(header->source, header->name, header->params,
- header->return_type, body.value,
- func_decos.value);
+ return create<ast::Function>(
+ header->source, module_.RegisterSymbol(header->name), header->name,
+ header->params, header->return_type, body.value, func_decos.value);
}
// function_type_decl
diff --git a/src/symbol_table.cc b/src/symbol_table.cc
index 13fe13f..8998ee2 100644
--- a/src/symbol_table.cc
+++ b/src/symbol_table.cc
@@ -18,10 +18,14 @@
SymbolTable::SymbolTable() = default;
+SymbolTable::SymbolTable(const SymbolTable&) = default;
+
SymbolTable::SymbolTable(SymbolTable&&) = default;
SymbolTable::~SymbolTable() = default;
+SymbolTable& SymbolTable::operator=(const SymbolTable& other) = default;
+
SymbolTable& SymbolTable::operator=(SymbolTable&&) = default;
Symbol SymbolTable::Register(const std::string& name) {
@@ -41,6 +45,11 @@
return sym;
}
+Symbol SymbolTable::GetSymbol(const std::string& name) const {
+ auto it = name_to_symbol_.find(name);
+ return it != name_to_symbol_.end() ? it->second : Symbol();
+}
+
std::string SymbolTable::NameFor(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol.value());
if (it == symbol_to_name_.end())
diff --git a/src/symbol_table.h b/src/symbol_table.h
index e3351e1..1c08598 100644
--- a/src/symbol_table.h
+++ b/src/symbol_table.h
@@ -27,11 +27,17 @@
public:
/// Constructor
SymbolTable();
+ /// Copy constructor
+ SymbolTable(const SymbolTable&);
/// Move Constructor
SymbolTable(SymbolTable&&);
/// Destructor
~SymbolTable();
+ /// Copy assignment
+ /// @param other the symbol table to copy
+ /// @returns the new symbol table
+ SymbolTable& operator=(const SymbolTable& other);
/// Move assignment
/// @param other the symbol table to move
/// @returns the symbol table
@@ -42,6 +48,11 @@
/// @returns the symbol representing the given name
Symbol Register(const std::string& name);
+ /// Returns the symbol for the given `name`
+ /// @param name the name to lookup
+ /// @returns the symbol for the name or symbol::kInvalid if not found.
+ Symbol GetSymbol(const std::string& name) const;
+
/// Returns the name for the given symbol
/// @param symbol the symbol to retrieve the name for
/// @returns the symbol name or "" if not found
diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc
index 6e1f171..24a7242 100644
--- a/src/transform/bound_array_accessors_test.cc
+++ b/src/transform/bound_array_accessors_test.cc
@@ -51,7 +51,7 @@
template <typename T = ast::Expression>
T* FindVariable(ast::Module* mod, std::string name) {
- if (auto* func = mod->FindFunctionByName("func")) {
+ if (auto* func = mod->FindFunctionBySymbol(mod->RegisterSymbol("func"))) {
for (auto* stmt : *func->body()) {
if (auto* decl = stmt->As<ast::VariableDeclStatement>()) {
if (auto* var = decl->variable()) {
@@ -92,9 +92,9 @@
struct ModuleBuilder : public ast::BuilderWithModule {
ModuleBuilder() : body_(create<ast::BlockStatement>()) {
- mod->AddFunction(create<ast::Function>(Source{}, "func",
- ast::VariableList{}, ty.void_, body_,
- ast::FunctionDecorationList{}));
+ mod->AddFunction(create<ast::Function>(
+ Source{}, mod->RegisterSymbol("func"), "func", ast::VariableList{},
+ ty.void_, body_, ast::FunctionDecorationList{}));
}
ast::Module Module() {
diff --git a/src/transform/emit_vertex_point_size_test.cc b/src/transform/emit_vertex_point_size_test.cc
index 7a08e5f..57a9047 100644
--- a/src/transform/emit_vertex_point_size_test.cc
+++ b/src/transform/emit_vertex_point_size_test.cc
@@ -58,23 +58,26 @@
Var("builtin_assignments_should_happen_before_this",
tint::ast::StorageClass::kFunction, ty.f32)));
- mod->AddFunction(
- create<ast::Function>(Source{}, "non_entry_a", ast::VariableList{},
- ty.void_, create<ast::BlockStatement>(Source{}),
- ast::FunctionDecorationList{}));
+ auto a_sym = mod->RegisterSymbol("non_entry_a");
+ mod->AddFunction(create<ast::Function>(
+ Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_,
+ create<ast::BlockStatement>(Source{}),
+ ast::FunctionDecorationList{}));
+ auto entry_sym = mod->RegisterSymbol("entry");
auto* entry = create<ast::Function>(
- Source{}, "entry", ast::VariableList{}, ty.void_, block,
+ Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_, block,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex,
Source{}),
});
mod->AddFunction(entry);
- mod->AddFunction(
- create<ast::Function>(Source{}, "non_entry_b", ast::VariableList{},
- ty.void_, create<ast::BlockStatement>(Source{}),
- ast::FunctionDecorationList{}));
+ auto b_sym = mod->RegisterSymbol("non_entry_b");
+ mod->AddFunction(create<ast::Function>(
+ Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_,
+ create<ast::BlockStatement>(Source{}),
+ ast::FunctionDecorationList{}));
}
};
@@ -82,7 +85,7 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- auto* expected = R"(Module{
+ auto expected = R"(Module{
Variable{
Decorations{
BuiltinDecoration{pointsize}
@@ -91,11 +94,13 @@
out
__f32
}
- Function non_entry_a -> __void
+ Function )" + result.module.RegisterSymbol("non_entry_a").to_str() +
+ R"( -> __void
()
{
}
- Function entry -> __void
+ Function )" + result.module.RegisterSymbol("entry").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
@@ -111,7 +116,8 @@
}
}
}
- Function non_entry_b -> __void
+ Function )" + result.module.RegisterSymbol("non_entry_b").to_str() +
+ R"( -> __void
()
{
}
@@ -123,23 +129,26 @@
TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
struct Builder : ModuleBuilder {
void Build() override {
- mod->AddFunction(
- create<ast::Function>(Source{}, "non_entry_a", ast::VariableList{},
- ty.void_, create<ast::BlockStatement>(Source{}),
- ast::FunctionDecorationList{}));
+ auto a_sym = mod->RegisterSymbol("non_entry_a");
+ mod->AddFunction(create<ast::Function>(
+ Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_,
+ create<ast::BlockStatement>(Source{}),
+ ast::FunctionDecorationList{}));
- mod->AddFunction(
- create<ast::Function>(Source{}, "entry", ast::VariableList{},
- ty.void_, create<ast::BlockStatement>(Source{}),
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kVertex, Source{}),
- }));
+ auto entry_sym = mod->RegisterSymbol("entry");
+ mod->AddFunction(create<ast::Function>(
+ Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_,
+ create<ast::BlockStatement>(Source{}),
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kVertex,
+ Source{}),
+ }));
- mod->AddFunction(
- create<ast::Function>(Source{}, "non_entry_b", ast::VariableList{},
- ty.void_, create<ast::BlockStatement>(Source{}),
- ast::FunctionDecorationList{}));
+ auto b_sym = mod->RegisterSymbol("non_entry_b");
+ mod->AddFunction(create<ast::Function>(
+ Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_,
+ create<ast::BlockStatement>(Source{}),
+ ast::FunctionDecorationList{}));
}
};
@@ -147,7 +156,7 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- auto* expected = R"(Module{
+ auto expected = R"(Module{
Variable{
Decorations{
BuiltinDecoration{pointsize}
@@ -156,11 +165,13 @@
out
__f32
}
- Function non_entry_a -> __void
+ Function )" + result.module.RegisterSymbol("non_entry_a").to_str() +
+ R"( -> __void
()
{
}
- Function entry -> __void
+ Function )" + result.module.RegisterSymbol("entry").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
@@ -169,7 +180,8 @@
ScalarConstructor[__f32]{1.000000}
}
}
- Function non_entry_b -> __void
+ Function )" + result.module.RegisterSymbol("non_entry_b").to_str() +
+ R"( -> __void
()
{
}
@@ -181,8 +193,9 @@
TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
struct Builder : ModuleBuilder {
void Build() override {
+ auto frag_sym = mod->RegisterSymbol("fragment_entry");
auto* fragment_entry = create<ast::Function>(
- Source{}, "fragment_entry", ast::VariableList{}, ty.void_,
+ Source{}, frag_sym, "fragment_entry", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment,
@@ -190,13 +203,14 @@
});
mod->AddFunction(fragment_entry);
- auto* compute_entry =
- create<ast::Function>(Source{}, "compute_entry", ast::VariableList{},
- ty.void_, create<ast::BlockStatement>(Source{}),
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto comp_sym = mod->RegisterSymbol("compute_entry");
+ auto* compute_entry = create<ast::Function>(
+ Source{}, comp_sym, "compute_entry", ast::VariableList{}, ty.void_,
+ create<ast::BlockStatement>(Source{}),
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod->AddFunction(compute_entry);
}
};
@@ -205,13 +219,15 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- auto* expected = R"(Module{
- Function fragment_entry -> __void
+ auto expected = R"(Module{
+ Function )" + result.module.RegisterSymbol("fragment_entry").to_str() +
+ R"( -> __void
StageDecoration{fragment}
()
{
}
- Function compute_entry -> __void
+ Function )" + result.module.RegisterSymbol("compute_entry").to_str() +
+ R"( -> __void
StageDecoration{compute}
()
{
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index 8738b7e..b398935 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -169,9 +169,9 @@
body->append(ctx.Clone(s));
}
return ctx.mod->create<ast::Function>(
- ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()),
- ctx.Clone(func->return_type()), ctx.Clone(body),
- ctx.Clone(func->decorations()));
+ ctx.Clone(func->source()), func->symbol(), func->name(),
+ ctx.Clone(func->params()), ctx.Clone(func->return_type()),
+ ctx.Clone(body), ctx.Clone(func->decorations()));
});
in->Clone(&ctx);
diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc
index c3b857c..7a27557 100644
--- a/src/transform/first_index_offset_test.cc
+++ b/src/transform/first_index_offset_test.cc
@@ -58,9 +58,9 @@
ast::Function* AddFunction(const std::string& name,
ast::VariableList params = {}) {
- auto* func = create<ast::Function>(Source{}, name, std::move(params),
- ty.u32, create<ast::BlockStatement>(),
- ast::FunctionDecorationList());
+ auto* func = create<ast::Function>(
+ Source{}, mod->RegisterSymbol(name), name, std::move(params), ty.u32,
+ create<ast::BlockStatement>(), ast::FunctionDecorationList());
mod->AddFunction(func);
return func;
}
@@ -154,7 +154,7 @@
uniform
__struct_TintFirstIndexOffsetData
}
- Function test -> __u32
+ Function tint_symbol_1 -> __u32
()
{
VariableDeclStatement{
@@ -229,7 +229,7 @@
uniform
__struct_TintFirstIndexOffsetData
}
- Function test -> __u32
+ Function tint_symbol_1 -> __u32
()
{
VariableDeclStatement{
@@ -317,7 +317,7 @@
uniform
__struct_TintFirstIndexOffsetData
}
- Function test -> __u32
+ Function tint_symbol_1 -> __u32
()
{
Return{
@@ -389,7 +389,7 @@
uniform
__struct_TintFirstIndexOffsetData
}
- Function func1 -> __u32
+ Function tint_symbol_1 -> __u32
()
{
VariableDeclStatement{
@@ -415,7 +415,7 @@
}
}
}
- Function func2 -> __u32
+ Function tint_symbol_2 -> __u32
()
{
Return{
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index d35f3d8..6143d15 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -84,8 +84,8 @@
}
// Find entry point
- auto* func = mod->FindFunctionByNameAndStage(cfg.entry_point_name,
- ast::PipelineStage::kVertex);
+ auto* func = mod->FindFunctionBySymbolAndStage(
+ mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex);
if (func == nullptr) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
@@ -94,9 +94,6 @@
return out;
}
- // Save the vertex function
- auto* vertex_func = mod->FindFunctionByName(func->name());
-
// TODO(idanr): Need to check shader locations in descriptor cover all
// attributes
@@ -108,7 +105,7 @@
state.FindOrInsertInstanceIndexIfUsed();
state.ConvertVertexInputVariablesToPrivate();
state.AddVertexStorageBuffers();
- state.AddVertexPullingPreamble(vertex_func);
+ state.AddVertexPullingPreamble(func);
return out;
}
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index 8f2a177..c027920 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -47,8 +47,8 @@
// Create basic module with an entry point and vertex function
void InitBasicModule() {
auto* func = create<ast::Function>(
- Source{}, "main", ast::VariableList{}, mod_->create<ast::type::Void>(),
- create<ast::BlockStatement>(),
+ Source{}, mod_->RegisterSymbol("main"), "main", ast::VariableList{},
+ mod_->create<ast::type::Void>(), create<ast::BlockStatement>(),
ast::FunctionDecorationList{create<ast::StageDecoration>(
ast::PipelineStage::kVertex, Source{})});
mod()->AddFunction(func);
@@ -134,8 +134,8 @@
TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
auto* func = create<ast::Function>(
- Source{}, "main", ast::VariableList{}, mod()->create<ast::type::Void>(),
- create<ast::BlockStatement>(),
+ Source{}, mod()->RegisterSymbol("main"), "main", ast::VariableList{},
+ mod()->create<ast::type::Void>(), create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -152,7 +152,8 @@
InitBasicModule();
InitTransform({});
auto result = manager()->Run(mod());
- ASSERT_FALSE(result.diagnostics.contains_errors());
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
}
TEST_F(VertexPullingTest, OneAttribute) {
@@ -164,7 +165,8 @@
InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
auto result = manager()->Run(mod());
- ASSERT_FALSE(result.diagnostics.contains_errors());
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@@ -193,7 +195,8 @@
storage_buffer
__struct_TintVertexData
}
- Function main -> __void
+ Function )" + result.module.GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
@@ -250,7 +253,8 @@
{{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}});
auto result = manager()->Run(mod());
- ASSERT_FALSE(result.diagnostics.contains_errors());
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@@ -279,7 +283,8 @@
storage_buffer
__struct_TintVertexData
}
- Function main -> __void
+ Function )" + result.module.GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
@@ -336,7 +341,8 @@
transform()->SetPullingBufferBindingSet(5);
auto result = manager()->Run(mod());
- ASSERT_FALSE(result.diagnostics.contains_errors());
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@@ -365,7 +371,8 @@
storage_buffer
__struct_TintVertexData
}
- Function main -> __void
+ Function )" + result.module.GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
@@ -451,7 +458,8 @@
{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}});
auto result = manager()->Run(mod());
- ASSERT_FALSE(result.diagnostics.contains_errors());
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@@ -502,7 +510,8 @@
storage_buffer
__struct_TintVertexData
}
- Function main -> __void
+ Function )" + result.module.GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
@@ -592,7 +601,8 @@
{{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}});
auto result = manager()->Run(mod());
- ASSERT_FALSE(result.diagnostics.contains_errors());
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@@ -626,7 +636,8 @@
storage_buffer
__struct_TintVertexData
}
- Function main -> __void
+ Function )" + result.module.GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
@@ -778,7 +789,8 @@
{16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}});
auto result = manager()->Run(mod());
- ASSERT_FALSE(result.diagnostics.contains_errors());
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@@ -835,7 +847,8 @@
storage_buffer
__struct_TintVertexData
}
- Function main -> __void
+ Function )" + result.module.GetSymbol("main").to_str() +
+ R"( -> __void
StageDecoration{vertex}
()
{
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 91fdba5..86e8d2e 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -122,7 +122,7 @@
continue;
}
for (const auto& callee : caller_to_callee_[func->name()]) {
- set_entry_points(callee, func->name());
+ set_entry_points(callee, func->symbol());
}
}
@@ -130,11 +130,11 @@
}
void TypeDeterminer::set_entry_points(const std::string& fn_name,
- const std::string& ep_name) {
- name_to_function_[fn_name]->add_ancestor_entry_point(ep_name);
+ Symbol ep_sym) {
+ name_to_function_[fn_name]->add_ancestor_entry_point(ep_sym);
for (const auto& callee : caller_to_callee_[fn_name]) {
- set_entry_points(callee, ep_name);
+ set_entry_points(callee, ep_sym);
}
}
@@ -389,7 +389,8 @@
if (current_function_) {
caller_to_callee_[current_function_->name()].push_back(ident->name());
- auto* callee_func = mod_->FindFunctionByName(ident->name());
+ auto* callee_func =
+ mod_->FindFunctionBySymbol(mod_->GetSymbol(ident->name()));
if (callee_func == nullptr) {
set_error(expr->source(),
"unable to find called function: " + ident->name());
diff --git a/src/type_determiner.h b/src/type_determiner.h
index f7cb587..0b1157a 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -113,7 +113,7 @@
private:
void set_error(const Source& src, const std::string& msg);
void set_referenced_from_function_if_needed(ast::Variable* var, bool local);
- void set_entry_points(const std::string& fn_name, const std::string& ep_name);
+ void set_entry_points(const std::string& fn_name, Symbol ep_sym);
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
bool DetermineBinary(ast::BinaryExpression* expr);
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index c798986..a45e896 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -341,9 +341,9 @@
ast::type::F32 f32;
ast::VariableList params;
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
- create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
+ create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@@ -372,15 +372,16 @@
auto* main_body = create<ast::BlockStatement>();
main_body->append(create<ast::CallStatement>(call_expr));
main_body->append(create<ast::ReturnStatement>(Source{}));
- auto* func_main =
- create<ast::Function>(Source{}, "main", params0, &f32, main_body,
- ast::FunctionDecorationList{});
+ auto* func_main = create<ast::Function>(Source{}, mod->RegisterSymbol("main"),
+ "main", params0, &f32, main_body,
+ ast::FunctionDecorationList{});
mod->AddFunction(func_main);
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func = create<ast::Function>(Source{}, "func", params0, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("func"), "func",
+ params0, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
EXPECT_FALSE(td()->Determine()) << td()->error();
@@ -639,9 +640,9 @@
ast::type::F32 f32;
ast::VariableList params;
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
- create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
+ create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@@ -659,9 +660,9 @@
ast::type::F32 f32;
ast::VariableList params;
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
- create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
+ create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@@ -809,8 +810,8 @@
body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var")));
- ast::Function f(Source{}, "my_func", {}, &f32, body,
- ast::FunctionDecorationList{});
+ ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32,
+ body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f));
@@ -836,8 +837,8 @@
body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var")));
- ast::Function f(Source{}, "my_func", {}, &f32, body,
- ast::FunctionDecorationList{});
+ ast::Function f(Source{}, mod->RegisterSymbol("myfunc"), "my_func", {}, &f32,
+ body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f));
@@ -868,8 +869,8 @@
body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var")));
- ast::Function f(Source{}, "my_func", {}, &f32, body,
- ast::FunctionDecorationList{});
+ ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32,
+ body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f));
@@ -885,9 +886,9 @@
ast::type::F32 f32;
ast::VariableList params;
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
- create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
+ create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@@ -968,8 +969,9 @@
body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("priv_var"),
create<ast::IdentifierExpression>("priv_var")));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
+ params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
@@ -1049,8 +1051,9 @@
create<ast::IdentifierExpression>("priv_var"),
create<ast::IdentifierExpression>("priv_var")));
ast::VariableList params;
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
+ params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
@@ -1059,8 +1062,9 @@
create<ast::IdentifierExpression>("out_var"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("my_func"),
ast::ExpressionList{})));
- auto* func2 = create<ast::Function>(Source{}, "func", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func2 =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("func"), "func",
+ params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func2);
@@ -1096,8 +1100,9 @@
create<ast::FloatLiteral>(&f32, 1.f))));
ast::VariableList params;
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
+ params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
@@ -2636,8 +2641,9 @@
auto* body = create<ast::BlockStatement>();
body->append(stmt);
- auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{},
- &i32, body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
+ "func", ast::VariableList{}, &i32, body,
+ ast::FunctionDecorationList{});
mod->AddFunction(func);
@@ -2660,8 +2666,9 @@
auto* body = create<ast::BlockStatement>();
body->append(stmt);
- auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{},
- &i32, body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
+ "func", ast::VariableList{}, &i32, body,
+ ast::FunctionDecorationList{});
mod->AddFunction(func);
@@ -2684,8 +2691,9 @@
auto* body = create<ast::BlockStatement>();
body->append(stmt);
- auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{},
- &i32, body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
+ "func", ast::VariableList{}, &i32, body,
+ ast::FunctionDecorationList{});
mod->AddFunction(func);
@@ -4857,24 +4865,27 @@
ast::VariableList params;
auto* body = create<ast::BlockStatement>();
- auto* func_b = create<ast::Function>(Source{}, "b", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func_b =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("b"), "b", params,
+ &f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("second"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("b"),
ast::ExpressionList{})));
- auto* func_c = create<ast::Function>(Source{}, "c", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func_c =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("c"), "c", params,
+ &f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("first"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("c"),
ast::ExpressionList{})));
- auto* func_a = create<ast::Function>(Source{}, "a", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func_a =
+ create<ast::Function>(Source{}, mod->RegisterSymbol("a"), "a", params,
+ &f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>(
@@ -4886,7 +4897,7 @@
create<ast::CallExpression>(create<ast::IdentifierExpression>("b"),
ast::ExpressionList{})));
auto* ep_1 = create<ast::Function>(
- Source{}, "ep_1", params, &f32, body,
+ Source{}, mod->RegisterSymbol("ep_1"), "ep_1", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -4897,7 +4908,7 @@
create<ast::CallExpression>(create<ast::IdentifierExpression>("c"),
ast::ExpressionList{})));
auto* ep_2 = create<ast::Function>(
- Source{}, "ep_2", params, &f32, body,
+ Source{}, mod->RegisterSymbol("ep_2"), "ep_2", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -4954,17 +4965,17 @@
const auto& b_eps = func_b->ancestor_entry_points();
ASSERT_EQ(2u, b_eps.size());
- EXPECT_EQ("ep_1", b_eps[0]);
- EXPECT_EQ("ep_2", b_eps[1]);
+ EXPECT_EQ(mod->RegisterSymbol("ep_1"), b_eps[0]);
+ EXPECT_EQ(mod->RegisterSymbol("ep_2"), b_eps[1]);
const auto& a_eps = func_a->ancestor_entry_points();
ASSERT_EQ(1u, a_eps.size());
- EXPECT_EQ("ep_1", a_eps[0]);
+ EXPECT_EQ(mod->RegisterSymbol("ep_1"), a_eps[0]);
const auto& c_eps = func_c->ancestor_entry_points();
ASSERT_EQ(2u, c_eps.size());
- EXPECT_EQ("ep_1", c_eps[0]);
- EXPECT_EQ("ep_2", c_eps[1]);
+ EXPECT_EQ(mod->RegisterSymbol("ep_1"), c_eps[0]);
+ EXPECT_EQ(mod->RegisterSymbol("ep_2"), c_eps[1]);
EXPECT_TRUE(ep_1->ancestor_entry_points().empty());
EXPECT_TRUE(ep_2->ancestor_entry_points().empty());
diff --git a/src/validator/validator_function_test.cc b/src/validator/validator_function_test.cc
index c335c22..780b542 100644
--- a/src/validator/validator_function_test.cc
+++ b/src/validator/validator_function_test.cc
@@ -54,7 +54,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
auto* func = create<ast::Function>(
- Source{Source::Location{12, 34}}, "func", params, &void_type, body,
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
+ params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -71,8 +72,8 @@
ast::type::Void void_type;
ast::VariableList params;
auto* func = create<ast::Function>(
- Source{Source::Location{12, 34}}, "func", params, &void_type,
- create<ast::BlockStatement>(),
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
+ params, &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -100,9 +101,9 @@
ast::type::Void void_type;
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
- auto* func =
- create<ast::Function>(Source{Source::Location{12, 34}}, "func", params,
- &i32, body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
+ params, &i32, body, ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
@@ -117,8 +118,9 @@
ast::type::I32 i32;
ast::VariableList params;
auto* func = create<ast::Function>(
- Source{Source::Location{12, 34}}, "func", params, &i32,
- create<ast::BlockStatement>(), ast::FunctionDecorationList{});
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
+ params, &i32, create<ast::BlockStatement>(),
+ ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
@@ -136,7 +138,7 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "func", params, &void_type, body,
+ Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -157,7 +159,8 @@
body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
return_expr));
- auto* func = create<ast::Function>(Source{}, "func", params, &void_type, body,
+ auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
+ "func", params, &void_type, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func);
@@ -180,8 +183,9 @@
body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
return_expr));
- auto* func = create<ast::Function>(Source{}, "func", params, &f32, body,
- ast::FunctionDecorationList{});
+ auto* func =
+ create<ast::Function>(Source{}, mod()->RegisterSymbol("func"), "func",
+ params, &f32, body, ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
@@ -204,8 +208,9 @@
create<ast::SintLiteral>(&i32, 2));
body->append(create<ast::ReturnStatement>(Source{}, return_expr));
- auto* func = create<ast::Function>(Source{}, "func", params, &i32, body,
- ast::FunctionDecorationList{});
+ auto* func =
+ create<ast::Function>(Source{}, mod()->RegisterSymbol("func"), "func",
+ params, &i32, body, ast::FunctionDecorationList{});
ast::VariableList params_copy;
auto* body_copy = create<ast::BlockStatement>();
@@ -213,9 +218,9 @@
create<ast::SintLiteral>(&i32, 2));
body_copy->append(create<ast::ReturnStatement>(Source{}, return_expr_copy));
- auto* func_copy = create<ast::Function>(Source{Source::Location{12, 34}},
- "func", params_copy, &i32, body_copy,
- ast::FunctionDecorationList{});
+ auto* func_copy = create<ast::Function>(
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
+ params_copy, &i32, body_copy, ast::FunctionDecorationList{});
mod()->AddFunction(func);
mod()->AddFunction(func_copy);
@@ -237,7 +242,8 @@
auto* body0 = create<ast::BlockStatement>();
body0->append(create<ast::CallStatement>(call_expr));
body0->append(create<ast::ReturnStatement>(Source{}));
- auto* func0 = create<ast::Function>(Source{}, "func", params0, &f32, body0,
+ auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
+ "func", params0, &f32, body0,
ast::FunctionDecorationList{});
mod()->AddFunction(func0);
@@ -268,7 +274,8 @@
create<ast::SintLiteral>(&i32, 2));
body0->append(create<ast::ReturnStatement>(Source{}, return_expr));
- auto* func0 = create<ast::Function>(Source{}, "func", params0, &i32, body0,
+ auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
+ "func", params0, &i32, body0,
ast::FunctionDecorationList{});
mod()->AddFunction(func0);
@@ -288,7 +295,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}, return_expr));
auto* func = create<ast::Function>(
- Source{Source::Location{12, 34}}, "vtx_main", params, &i32, body,
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_main"),
+ "vtx_main", params, &i32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -317,7 +325,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{Source::Location{12, 34}}, "vtx_func", params, &void_type, body,
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_func"),
+ "vtx_func", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -339,7 +348,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{Source::Location{12, 34}}, "main", params, &void_type, body,
+ Source{Source::Location{12, 34}}, mod()->RegisterSymbol("main"), "main",
+ params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
@@ -361,7 +371,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "vtx_func", params, &void_type, body,
+ Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -377,8 +388,9 @@
ast::VariableList params;
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func = create<ast::Function>(Source{}, "vtx_func", params, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params,
+ &void_type, body, ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
diff --git a/src/validator/validator_test.cc b/src/validator/validator_test.cc
index d993b1e..5931fc0 100644
--- a/src/validator/validator_test.cc
+++ b/src/validator/validator_test.cc
@@ -332,7 +332,8 @@
body->append(create<ast::AssignmentStatement>(
Source{Source::Location{12, 34}}, lhs, rhs));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
+ auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
+ "my_func", params, &f32, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func);
@@ -370,7 +371,8 @@
Source{Source::Location{12, 34}}, lhs, rhs));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "my_func", params, &void_type, body,
+ Source{}, mod()->RegisterSymbol("my_func"), "my_func", params, &void_type,
+ body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -587,8 +589,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
+ "my_func", params, &void_type, body,
+ ast::FunctionDecorationList{});
mod()->AddFunction(func);
@@ -631,8 +634,9 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var_a_float));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
+ "my_func", params, &void_type, body,
+ ast::FunctionDecorationList{});
mod()->AddFunction(func);
@@ -759,8 +763,9 @@
body0->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var0));
body0->append(create<ast::ReturnStatement>(Source{}));
- auto* func0 = create<ast::Function>(Source{}, "func0", params0, &void_type,
- body0, ast::FunctionDecorationList{});
+ auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func0"),
+ "func0", params0, &void_type, body0,
+ ast::FunctionDecorationList{});
ast::VariableList params1;
auto* body1 = create<ast::BlockStatement>();
@@ -768,7 +773,8 @@
Source{Source::Location{13, 34}}, var1));
body1->append(create<ast::ReturnStatement>(Source{}));
auto* func1 = create<ast::Function>(
- Source{}, "func1", params1, &void_type, body1,
+ Source{}, mod()->RegisterSymbol("func1"), "func1", params1, &void_type,
+ body1,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
diff --git a/src/validator/validator_type_test.cc b/src/validator/validator_type_test.cc
index 2fbaeff..ac8cc05 100644
--- a/src/validator/validator_type_test.cc
+++ b/src/validator/validator_type_test.cc
@@ -206,8 +206,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var));
+
auto* func = create<ast::Function>(
- Source{}, "func", params, &void_type, body,
+ Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index c52c8df..fa49a4b 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -196,15 +196,15 @@
std::string name = "";
switch (type) {
case VarType::kIn: {
- auto in_it = ep_name_to_in_data_.find(current_ep_name_);
- if (in_it != ep_name_to_in_data_.end()) {
+ auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ if (in_it != ep_sym_to_in_data_.end()) {
name = in_it->second.var_name;
}
break;
}
case VarType::kOut: {
- auto outit = ep_name_to_out_data_.find(current_ep_name_);
- if (outit != ep_name_to_out_data_.end()) {
+ auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ if (outit != ep_sym_to_out_data_.end()) {
name = outit->second.var_name;
}
break;
@@ -668,12 +668,14 @@
}
auto name = ident->name();
- auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
+ auto caller_sym = module_->GetSymbol(name);
+ auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" +
+ caller_sym.to_str());
if (it != ep_func_name_remapped_.end()) {
name = it->second;
}
- auto* func = module_->FindFunctionByName(ident->name());
+ auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name()));
if (func == nullptr) {
error_ = "Unable to find function: " + name;
return false;
@@ -1189,15 +1191,15 @@
has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) {
- for (const auto& ep_name : func->ancestor_entry_points()) {
- if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_name)) {
+ for (const auto& ep_sym : func->ancestor_entry_points()) {
+ if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_sym)) {
return false;
}
out << std::endl;
}
} else {
// Emit as non-duplicated
- if (!EmitFunctionInternal(out, func, false, "")) {
+ if (!EmitFunctionInternal(out, func, false, Symbol())) {
return false;
}
out << std::endl;
@@ -1209,8 +1211,8 @@
bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
ast::Function* func,
bool emit_duplicate_functions,
- const std::string& ep_name) {
- auto name = func->name();
+ Symbol ep_sym) {
+ auto name = func->symbol().to_str();
if (!EmitType(out, func->return_type(), "")) {
return false;
@@ -1219,10 +1221,15 @@
out << " ";
if (emit_duplicate_functions) {
- name = generate_name(name + "_" + ep_name);
- ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
+ auto func_name = name;
+ auto ep_name = ep_sym.to_str();
+ // TODO(dsinclair): The SymbolToName should go away and just use
+ // to_str() here when the conversion is complete.
+ name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym));
+ ep_func_name_remapped_[ep_name + "_" + func_name] = name;
} else {
- name = namer_.NameFor(name);
+ // TODO(dsinclair): this should be updated to a remapped name
+ name = namer_.NameFor(func->name());
}
out << name << "(";
@@ -1234,15 +1241,15 @@
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
- auto in_it = ep_name_to_in_data_.find(ep_name);
- if (in_it != ep_name_to_in_data_.end()) {
+ auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
+ if (in_it != ep_sym_to_in_data_.end()) {
out << "in " << in_it->second.struct_name << " "
<< in_it->second.var_name;
first = false;
}
- auto outit = ep_name_to_out_data_.find(ep_name);
- if (outit != ep_name_to_out_data_.end()) {
+ auto outit = ep_sym_to_out_data_.find(ep_sym.value());
+ if (outit != ep_sym_to_out_data_.end()) {
if (!first) {
out << ", ";
}
@@ -1269,13 +1276,13 @@
out << ") ";
- current_ep_name_ = ep_name;
+ current_ep_sym_ = ep_sym;
if (!EmitBlockAndNewline(out, func->body())) {
return false;
}
- current_ep_name_ = "";
+ current_ep_sym_ = Symbol();
return true;
}
@@ -1392,7 +1399,7 @@
auto in_struct_name =
generate_name(func->name() + "_" + kInStructNameSuffix);
auto in_var_name = generate_name(kTintStructInVarPrefix);
- ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name};
+ ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name};
make_indent(out);
out << "struct " << in_struct_name << " {" << std::endl;
@@ -1438,7 +1445,7 @@
auto outstruct_name =
generate_name(func->name() + "_" + kOutStructNameSuffix);
auto outvar_name = generate_name(kTintStructOutVarPrefix);
- ep_name_to_out_data_[func->name()] = {outstruct_name, outvar_name};
+ ep_sym_to_out_data_[func->symbol().value()] = {outstruct_name, outvar_name};
make_indent(out);
out << "struct " << outstruct_name << " {" << std::endl;
@@ -1516,7 +1523,7 @@
ast::Function* func) {
make_indent(out);
- current_ep_name_ = func->name();
+ current_ep_sym_ = func->symbol();
if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
uint32_t x = 0;
@@ -1528,17 +1535,18 @@
make_indent(out);
}
- auto outdata = ep_name_to_out_data_.find(current_ep_name_);
- bool has_outdata = outdata != ep_name_to_out_data_.end();
+ auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ bool has_outdata = outdata != ep_sym_to_out_data_.end();
if (has_outdata) {
out << outdata->second.struct_name;
} else {
out << "void";
}
- out << " " << namer_.NameFor(current_ep_name_) << "(";
+ // TODO(dsinclair): This should output the remapped name
+ out << " " << namer_.NameFor(module_->SymbolToName(current_ep_sym_)) << "(";
- auto in_data = ep_name_to_in_data_.find(current_ep_name_);
- if (in_data != ep_name_to_in_data_.end()) {
+ auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ if (in_data != ep_sym_to_in_data_.end()) {
out << in_data->second.struct_name << " " << in_data->second.var_name;
}
out << ") {" << std::endl;
@@ -1563,7 +1571,7 @@
make_indent(out);
out << "}" << std::endl;
- current_ep_name_ = "";
+ current_ep_sym_ = Symbol();
return true;
}
@@ -1966,8 +1974,8 @@
if (generating_entry_point_) {
out << "return";
- auto outdata = ep_name_to_out_data_.find(current_ep_name_);
- if (outdata != ep_name_to_out_data_.end()) {
+ auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ if (outdata != ep_sym_to_out_data_.end()) {
out << " " << outdata->second.var_name;
}
} else if (stmt->has_value()) {
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index 6694e26..ebd98b9 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -210,12 +210,12 @@
/// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point
- /// @param ep_name the current entry point or blank if none set
+ /// @param ep_sym the current entry point or symbol::kInvalid if none set
/// @returns true if the function was emitted.
bool EmitFunctionInternal(std::ostream& out,
ast::Function* func,
bool emit_duplicate_functions,
- const std::string& ep_name);
+ Symbol ep_sym);
/// Handles emitting information for an entry point
/// @param out the output stream
/// @param func the entry point
@@ -397,12 +397,12 @@
Namer namer_;
ast::Module* module_ = nullptr;
- std::string current_ep_name_;
+ Symbol current_ep_sym_;
bool generating_entry_point_ = false;
uint32_t loop_emission_counter_ = 0;
ScopeStack<ast::Variable*> global_variables_;
- std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_;
- std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_;
+ std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
+ std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index 20654eb..318c68b 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -613,9 +613,9 @@
ast::type::Void void_type;
- auto* func = create<ast::Function>(Source{}, "foo", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("foo"), "foo", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ast::ExpressionList params;
diff --git a/src/writer/hlsl/generator_impl_call_test.cc b/src/writer/hlsl/generator_impl_call_test.cc
index e185dcc..3311837 100644
--- a/src/writer/hlsl/generator_impl_call_test.cc
+++ b/src/writer/hlsl/generator_impl_call_test.cc
@@ -35,9 +35,9 @@
auto* id = create<ast::IdentifierExpression>("my_func");
ast::CallExpression call(id, {});
- auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error();
@@ -53,9 +53,9 @@
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallExpression call(id, params);
- auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error();
@@ -71,9 +71,9 @@
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallStatement call(create<ast::CallExpression>(id, params));
- auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(out, &call)) << gen.error();
diff --git a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc
index 57c8553..320b478 100644
--- a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc
+++ b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc
@@ -91,7 +91,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "vtx_main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -164,7 +164,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "vtx_main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -237,7 +237,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -309,7 +309,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -378,7 +378,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@@ -442,7 +442,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@@ -512,7 +512,7 @@
create<ast::IdentifierExpression>("x"))));
auto* func = create<ast::Function>(
- Source{}, "main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 95543a7..cab53fc 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -57,9 +57,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
+ "my_func", ast::VariableList{}, &void_type,
+ body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@@ -77,9 +77,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "GeometryShader", ast::VariableList{},
- &void_type, body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader",
+ ast::VariableList{}, &void_type, body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@@ -118,8 +118,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
+ "my_func", params, &void_type, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@@ -174,7 +175,8 @@
create<ast::IdentifierExpression>("foo")));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -245,7 +247,8 @@
create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -309,7 +312,8 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -380,7 +384,8 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -455,7 +460,8 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -526,7 +532,8 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -594,7 +601,8 @@
body->append(assign);
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -682,8 +690,9 @@
create<ast::IdentifierExpression>("param")));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("foo")));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -698,7 +707,7 @@
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -766,8 +775,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -782,7 +792,7 @@
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -863,8 +873,9 @@
create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -879,7 +890,7 @@
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -948,8 +959,9 @@
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x"))));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -971,7 +983,8 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -1034,8 +1047,9 @@
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x"))));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -1057,7 +1071,8 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -1122,7 +1137,7 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -1152,8 +1167,8 @@
ast::type::Void void_type;
auto* func = create<ast::Function>(
- Source{}, "GeometryShader", ast::VariableList{}, &void_type,
- create<ast::BlockStatement>(),
+ Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader",
+ ast::VariableList{}, &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -1175,7 +1190,7 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@@ -1200,7 +1215,7 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
@@ -1236,8 +1251,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
+ "my_func", params, &void_type, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@@ -1317,12 +1333,12 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "a", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod.AddFunction(func);
}
@@ -1343,12 +1359,12 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "b", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod.AddFunction(func);
}
diff --git a/src/writer/hlsl/generator_impl_test.cc b/src/writer/hlsl/generator_impl_test.cc
index dfdaf95..1494fcf 100644
--- a/src/writer/hlsl/generator_impl_test.cc
+++ b/src/writer/hlsl/generator_impl_test.cc
@@ -29,9 +29,9 @@
TEST_F(HlslGeneratorImplTest, Generate) {
ast::type::Void void_type;
- auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.Generate(out)) << gen.error();
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 000c3d7..4af9556 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -411,15 +411,15 @@
std::string name = "";
switch (type) {
case VarType::kIn: {
- auto in_it = ep_name_to_in_data_.find(current_ep_name_);
- if (in_it != ep_name_to_in_data_.end()) {
+ auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ if (in_it != ep_sym_to_in_data_.end()) {
name = in_it->second.var_name;
}
break;
}
case VarType::kOut: {
- auto out_it = ep_name_to_out_data_.find(current_ep_name_);
- if (out_it != ep_name_to_out_data_.end()) {
+ auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ if (out_it != ep_sym_to_out_data_.end()) {
name = out_it->second.var_name;
}
break;
@@ -573,12 +573,14 @@
}
auto name = ident->name();
- auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
+ auto caller_sym = module_->GetSymbol(name);
+ auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" +
+ caller_sym.to_str());
if (it != ep_func_name_remapped_.end()) {
name = it->second;
}
- auto* func = module_->FindFunctionByName(ident->name());
+ auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name()));
if (func == nullptr) {
error_ = "Unable to find function: " + name;
return false;
@@ -1026,7 +1028,7 @@
auto in_struct_name =
generate_name(func->name() + "_" + kInStructNameSuffix);
auto in_var_name = generate_name(kTintStructInVarPrefix);
- ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name};
+ ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name};
make_indent();
out_ << "struct " << in_struct_name << " {" << std::endl;
@@ -1063,7 +1065,8 @@
auto out_struct_name =
generate_name(func->name() + "_" + kOutStructNameSuffix);
auto out_var_name = generate_name(kTintStructOutVarPrefix);
- ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name};
+ ep_sym_to_out_data_[func->symbol().value()] = {out_struct_name,
+ out_var_name};
make_indent();
out_ << "struct " << out_struct_name << " {" << std::endl;
@@ -1205,15 +1208,15 @@
has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) {
- for (const auto& ep_name : func->ancestor_entry_points()) {
- if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
+ for (const auto& ep_sym : func->ancestor_entry_points()) {
+ if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) {
return false;
}
out_ << std::endl;
}
} else {
// Emit as non-duplicated
- if (!EmitFunctionInternal(func, false, "")) {
+ if (!EmitFunctionInternal(func, false, Symbol())) {
return false;
}
out_ << std::endl;
@@ -1224,19 +1227,23 @@
bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
- const std::string& ep_name) {
- auto name = func->name();
-
+ Symbol ep_sym) {
+ auto name = func->symbol().to_str();
if (!EmitType(func->return_type(), "")) {
return false;
}
out_ << " ";
if (emit_duplicate_functions) {
- name = generate_name(name + "_" + ep_name);
- ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
+ auto func_name = name;
+ auto ep_name = ep_sym.to_str();
+ // TODO(dsinclair): The SymbolToName should go away and just use
+ // to_str() here when the conversion is complete.
+ name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym));
+ ep_func_name_remapped_[ep_name + "_" + func_name] = name;
} else {
- name = namer_.NameFor(name);
+ // TODO(dsinclair): this should be updated to a remapped name
+ name = namer_.NameFor(func->name());
}
out_ << name << "(";
@@ -1247,15 +1254,15 @@
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
- auto in_it = ep_name_to_in_data_.find(ep_name);
- if (in_it != ep_name_to_in_data_.end()) {
+ auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
+ if (in_it != ep_sym_to_in_data_.end()) {
out_ << "thread " << in_it->second.struct_name << "& "
<< in_it->second.var_name;
first = false;
}
- auto out_it = ep_name_to_out_data_.find(ep_name);
- if (out_it != ep_name_to_out_data_.end()) {
+ auto out_it = ep_sym_to_out_data_.find(ep_sym.value());
+ if (out_it != ep_sym_to_out_data_.end()) {
if (!first) {
out_ << ", ";
}
@@ -1337,13 +1344,13 @@
out_ << ") ";
- current_ep_name_ = ep_name;
+ current_ep_sym_ = ep_sym;
if (!EmitBlockAndNewline(func->body())) {
return false;
}
- current_ep_name_ = "";
+ current_ep_sym_ = Symbol();
return true;
}
@@ -1377,25 +1384,25 @@
bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
make_indent();
- current_ep_name_ = func->name();
+ current_ep_sym_ = func->symbol();
EmitStage(func->pipeline_stage());
out_ << " ";
// This is an entry point, the return type is the entry point output structure
// if one exists, or void otherwise.
- auto out_data = ep_name_to_out_data_.find(current_ep_name_);
- bool has_out_data = out_data != ep_name_to_out_data_.end();
+ auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ bool has_out_data = out_data != ep_sym_to_out_data_.end();
if (has_out_data) {
out_ << out_data->second.struct_name;
} else {
out_ << "void";
}
- out_ << " " << namer_.NameFor(current_ep_name_) << "(";
+ out_ << " " << namer_.NameFor(func->name()) << "(";
bool first = true;
- auto in_data = ep_name_to_in_data_.find(current_ep_name_);
- if (in_data != ep_name_to_in_data_.end()) {
+ auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
+ if (in_data != ep_sym_to_in_data_.end()) {
out_ << in_data->second.struct_name << " " << in_data->second.var_name
<< " [[stage_in]]";
first = false;
@@ -1503,7 +1510,7 @@
make_indent();
out_ << "}" << std::endl;
- current_ep_name_ = "";
+ current_ep_sym_ = Symbol();
return true;
}
@@ -1687,8 +1694,8 @@
out_ << "return";
if (generating_entry_point_) {
- auto out_data = ep_name_to_out_data_.find(current_ep_name_);
- if (out_data != ep_name_to_out_data_.end()) {
+ auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
+ if (out_data != ep_sym_to_out_data_.end()) {
out_ << " " << out_data->second.var_name;
}
} else if (stmt->has_value()) {
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 0e4ef7d..d087c7c 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -156,11 +156,11 @@
/// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point
- /// @param ep_name the current entry point or blank if none set
+ /// @param ep_sym the current entry point or symbol::kInvalid if not set
/// @returns true if the function was emitted.
bool EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
- const std::string& ep_name);
+ Symbol ep_sym);
/// Handles generating an identifier expression
/// @param expr the identifier expression
/// @returns true if the identifier was emitted
@@ -282,13 +282,13 @@
Namer namer_;
ScopeStack<ast::Variable*> global_variables_;
- std::string current_ep_name_;
+ Symbol current_ep_sym_;
bool generating_entry_point_ = false;
const ast::Module* module_ = nullptr;
uint32_t loop_emission_counter_ = 0;
- std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_;
- std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_;
+ std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
+ std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did
diff --git a/src/writer/msl/generator_impl_call_test.cc b/src/writer/msl/generator_impl_call_test.cc
index c27e556..7102424 100644
--- a/src/writer/msl/generator_impl_call_test.cc
+++ b/src/writer/msl/generator_impl_call_test.cc
@@ -37,9 +37,9 @@
auto* id = create<ast::IdentifierExpression>("my_func");
ast::CallExpression call(id, {});
- auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error();
@@ -55,9 +55,9 @@
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallExpression call(id, params);
- auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error();
@@ -73,9 +73,9 @@
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallStatement call(create<ast::CallExpression>(id, params));
- auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, create<ast::BlockStatement>(),
- ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
diff --git a/src/writer/msl/generator_impl_function_entry_point_data_test.cc b/src/writer/msl/generator_impl_function_entry_point_data_test.cc
index ac6b4c8..f43ea32 100644
--- a/src/writer/msl/generator_impl_function_entry_point_data_test.cc
+++ b/src/writer/msl/generator_impl_function_entry_point_data_test.cc
@@ -90,7 +90,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "vtx_main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -160,7 +160,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "vtx_main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -229,7 +229,7 @@
create<ast::IdentifierExpression>("bar"),
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -299,7 +299,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -366,7 +366,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@@ -428,7 +428,7 @@
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
- Source{}, "main", params, &f32, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@@ -496,7 +496,7 @@
create<ast::IdentifierExpression>("x"))));
auto* func = create<ast::Function>(
- Source{}, "main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index 03934cb..950e8fe 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -60,9 +60,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "my_func", ast::VariableList{},
- &void_type, body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
+ "my_func", ast::VariableList{}, &void_type,
+ body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@@ -82,9 +82,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "main", ast::VariableList{}, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("main"),
+ "main", ast::VariableList{}, &void_type,
+ body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@@ -125,8 +125,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
+ "my_func", params, &void_type, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@@ -183,7 +184,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{create<ast::StageDecoration>(
ast::PipelineStage::kFragment, Source{})});
@@ -257,7 +259,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -321,7 +324,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -397,7 +401,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -478,7 +483,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -572,8 +578,9 @@
create<ast::IdentifierExpression>("param")));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("foo")));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -588,7 +595,7 @@
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -659,8 +666,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -676,7 +684,7 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -760,8 +768,9 @@
create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -776,7 +785,7 @@
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -843,8 +852,9 @@
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x"))));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -867,7 +877,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -943,8 +954,9 @@
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("b"))));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -967,7 +979,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -1049,8 +1062,9 @@
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("b"))));
- auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
- body, ast::FunctionDecorationList{});
+ auto* sub_func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@@ -1073,7 +1087,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
- Source{}, "frag_main", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -1145,7 +1160,7 @@
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
- Source{}, "ep_1", params, &void_type, body,
+ Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -1177,8 +1192,8 @@
ast::type::Void void_type;
auto* func = create<ast::Function>(
- Source{}, "main", ast::VariableList{}, &void_type,
- create<ast::BlockStatement>(),
+ Source{}, mod.RegisterSymbol("main"), "main", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@@ -1212,8 +1227,9 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
- body, ast::FunctionDecorationList{});
+ auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
+ "my_func", params, &void_type, body,
+ ast::FunctionDecorationList{});
mod.AddFunction(func);
@@ -1298,12 +1314,12 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "a", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod.AddFunction(func);
}
@@ -1325,12 +1341,12 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "b", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod.AddFunction(func);
}
diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc
index 06fdbc7..0ba3247 100644
--- a/src/writer/msl/generator_impl_test.cc
+++ b/src/writer/msl/generator_impl_test.cc
@@ -51,8 +51,8 @@
ast::type::Void void_type;
auto* func = create<ast::Function>(
- Source{}, "my_func", ast::VariableList{}, &void_type,
- create<ast::BlockStatement>(),
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc
index 1f2d50b..a4585a5 100644
--- a/src/writer/spirv/builder_call_test.cc
+++ b/src/writer/spirv/builder_call_test.cc
@@ -65,11 +65,11 @@
Source{}, create<ast::BinaryExpression>(
ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"),
create<ast::IdentifierExpression>("b"))));
- ast::Function a_func(Source{}, "a_func", func_params, &f32, body,
- ast::FunctionDecorationList{});
+ ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func",
+ func_params, &f32, body, ast::FunctionDecorationList{});
- ast::Function func(Source{}, "main", {}, &void_type,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ast::ExpressionList call_params;
@@ -143,11 +143,12 @@
ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"),
create<ast::IdentifierExpression>("b"))));
- ast::Function a_func(Source{}, "a_func", func_params, &void_type, body,
+ ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func",
+ func_params, &void_type, body,
ast::FunctionDecorationList{});
- ast::Function func(Source{}, "main", {}, &void_type,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ast::ExpressionList call_params;
diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc
index 6f835e1..dfab021 100644
--- a/src/writer/spirv/builder_function_decoration_test.cc
+++ b/src/writer/spirv/builder_function_decoration_test.cc
@@ -42,7 +42,8 @@
ast::type::Void void_type;
ast::Function func(
- Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
+ Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
+ create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -67,8 +68,8 @@
ast::type::Void void_type;
- ast::Function func(Source{}, "main", {}, &void_type,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(params.stage, Source{}),
});
@@ -97,7 +98,8 @@
ast::type::Void void_type;
ast::Function func(
- Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
+ Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
+ create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -174,7 +176,7 @@
create<ast::IdentifierExpression>("my_in")));
ast::Function func(
- Source{}, "main", {}, &void_type, body,
+ Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@@ -244,7 +246,8 @@
ast::type::Void void_type;
ast::Function func(
- Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
+ Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
+ create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -259,7 +262,8 @@
ast::type::Void void_type;
ast::Function func(
- Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
+ Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
+ create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@@ -274,7 +278,8 @@
ast::type::Void void_type;
ast::Function func(
- Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
+ Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
+ create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
@@ -290,13 +295,15 @@
ast::type::Void void_type;
ast::Function func1(
- Source{}, "main1", {}, &void_type, create<ast::BlockStatement>(),
+ Source{}, mod->RegisterSymbol("main1"), "main1", {}, &void_type,
+ create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
ast::Function func2(
- Source{}, "main2", {}, &void_type, create<ast::BlockStatement>(),
+ Source{}, mod->RegisterSymbol("main2"), "main2", {}, &void_type,
+ create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc
index 51da306..983a3b2 100644
--- a/src/writer/spirv/builder_function_test.cc
+++ b/src/writer/spirv/builder_function_test.cc
@@ -47,8 +47,8 @@
TEST_F(BuilderTest, Function_Empty) {
ast::type::Void void_type;
- ast::Function func(Source{}, "a_func", {}, &void_type,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
@@ -68,8 +68,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- ast::Function func(Source{}, "a_func", {}, &void_type, body,
- ast::FunctionDecorationList{});
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@@ -101,8 +101,8 @@
Source{}, create<ast::IdentifierExpression>("a")));
ASSERT_TRUE(td.DetermineResultType(body)) << td.error();
- ast::Function func(Source{}, "a_func", {}, &void_type, body,
- ast::FunctionDecorationList{});
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error();
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -128,8 +128,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::DiscardStatement>());
- ast::Function func(Source{}, "a_func", {}, &void_type, body,
- ast::FunctionDecorationList{});
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@@ -168,8 +168,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("a")));
- ast::Function func(Source{}, "a_func", params, &f32, body,
- ast::FunctionDecorationList{});
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", params,
+ &f32, body, ast::FunctionDecorationList{});
td.RegisterVariableForTesting(func.params()[0]);
td.RegisterVariableForTesting(func.params()[1]);
@@ -197,8 +197,8 @@
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
- ast::Function func(Source{}, "a_func", {}, &void_type, body,
- ast::FunctionDecorationList{});
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@@ -213,8 +213,8 @@
TEST_F(BuilderTest, FunctionType) {
ast::type::Void void_type;
- ast::Function func(Source{}, "a_func", {}, &void_type,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
@@ -225,11 +225,11 @@
TEST_F(BuilderTest, FunctionType_DeDuplicate) {
ast::type::Void void_type;
- ast::Function func1(Source{}, "a_func", {}, &void_type,
- create<ast::BlockStatement>(),
+ ast::Function func1(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
- ast::Function func2(Source{}, "b_func", {}, &void_type,
- create<ast::BlockStatement>(),
+ ast::Function func2(Source{}, mod->RegisterSymbol("b_func"), "b_func", {},
+ &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func1));
@@ -307,12 +307,12 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "a", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod->RegisterSymbol("a"), "a", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod->AddFunction(func);
}
@@ -334,12 +334,12 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "b", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod->RegisterSymbol("b"), "b", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod->AddFunction(func);
}
diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc
index 9839711..d839c9c 100644
--- a/src/writer/spirv/builder_intrinsic_test.cc
+++ b/src/writer/spirv/builder_intrinsic_test.cc
@@ -471,8 +471,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
@@ -505,8 +505,8 @@
auto expr = Call(param.name, 1.0f);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -533,8 +533,8 @@
auto expr = Call(param.name, vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -587,8 +587,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -612,8 +612,8 @@
auto expr = Call("length", vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -639,8 +639,8 @@
auto expr = Call("normalize", vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -671,8 +671,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -700,8 +700,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -737,8 +737,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -763,8 +763,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -792,8 +792,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -823,8 +823,8 @@
auto expr = Call(param.name, 1.0f, 1.0f, 1.0f);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -853,8 +853,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -894,8 +894,8 @@
auto expr = Call(param.name, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -922,8 +922,8 @@
auto expr = Call(param.name, vec2<i32>(1, 1));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -957,8 +957,8 @@
auto expr = Call(param.name, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -985,8 +985,8 @@
auto expr = Call(param.name, vec2<u32>(1u, 1u));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1020,8 +1020,8 @@
auto expr = Call(param.name, 1, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1048,8 +1048,8 @@
auto expr = Call(param.name, vec2<i32>(1, 1), vec2<i32>(1, 1));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1084,8 +1084,8 @@
auto expr = Call(param.name, 1u, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1112,8 +1112,8 @@
auto expr = Call(param.name, vec2<u32>(1u, 1u), vec2<u32>(1u, 1u));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1148,8 +1148,8 @@
auto expr = Call(param.name, 1, 1, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1178,8 +1178,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1213,8 +1213,8 @@
auto expr = Call(param.name, 1u, 1u, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1243,8 +1243,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1276,8 +1276,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1320,8 +1320,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1360,8 +1360,8 @@
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@@ -1405,8 +1405,8 @@
auto expr = Call("arrayLength", "ptr_var");
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, ty.void_,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc
index 8ae74b0..b57cee9 100644
--- a/src/writer/spirv/builder_switch_test.cc
+++ b/src/writer/spirv/builder_switch_test.cc
@@ -121,8 +121,8 @@
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, &i32,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@@ -201,8 +201,8 @@
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, &i32,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@@ -300,8 +300,8 @@
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, &i32,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@@ -408,8 +408,8 @@
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, &i32,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@@ -495,8 +495,8 @@
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, &i32,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@@ -563,8 +563,8 @@
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
- ast::Function func(Source{}, "a_func", {}, &i32,
- create<ast::BlockStatement>(),
+ ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
+ &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 5abd8db..58982dd 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -113,7 +113,8 @@
bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module,
ast::PipelineStage stage,
const std::string& name) {
- auto* func = module.FindFunctionByNameAndStage(name, stage);
+ auto* func =
+ module.FindFunctionBySymbolAndStage(module.GetSymbol(name), stage);
if (func == nullptr) {
error_ = "Unable to find requested entry point: " + name;
return false;
@@ -153,7 +154,7 @@
}
for (auto* f : module.functions()) {
- if (!f->HasAncestorEntryPoint(name)) {
+ if (!f->HasAncestorEntryPoint(module.GetSymbol(name))) {
continue;
}
diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc
index b9d3a0c..573b391 100644
--- a/src/writer/wgsl/generator_impl_function_test.cc
+++ b/src/writer/wgsl/generator_impl_function_test.cc
@@ -46,8 +46,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
ast::type::Void void_type;
- ast::Function func(Source{}, "my_func", {}, &void_type, body,
- ast::FunctionDecorationList{});
+ ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {},
+ &void_type, body, ast::FunctionDecorationList{});
gen.increment_indent();
@@ -85,8 +85,8 @@
ast::VariableDecorationList{})); // decorations
ast::type::Void void_type;
- ast::Function func(Source{}, "my_func", params, &void_type, body,
- ast::FunctionDecorationList{});
+ ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", params,
+ &void_type, body, ast::FunctionDecorationList{});
gen.increment_indent();
@@ -104,7 +104,8 @@
body->append(create<ast::ReturnStatement>(Source{}));
ast::type::Void void_type;
- ast::Function func(Source{}, "my_func", {}, &void_type, body,
+ ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {},
+ &void_type, body,
ast::FunctionDecorationList{
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
});
@@ -127,7 +128,7 @@
ast::type::Void void_type;
ast::Function func(
- Source{}, "my_func", {}, &void_type, body,
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@@ -150,7 +151,7 @@
ast::type::Void void_type;
ast::Function func(
- Source{}, "my_func", {}, &void_type, body,
+ Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
@@ -237,12 +238,12 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "a", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod.AddFunction(func);
}
@@ -264,12 +265,12 @@
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
- auto* func =
- create<ast::Function>(Source{}, "b", params, &void_type, body,
- ast::FunctionDecorationList{
- create<ast::StageDecoration>(
- ast::PipelineStage::kCompute, Source{}),
- });
+ auto* func = create<ast::Function>(
+ Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute,
+ Source{}),
+ });
mod.AddFunction(func);
}
diff --git a/src/writer/wgsl/generator_impl_test.cc b/src/writer/wgsl/generator_impl_test.cc
index 5343153..1b8da48 100644
--- a/src/writer/wgsl/generator_impl_test.cc
+++ b/src/writer/wgsl/generator_impl_test.cc
@@ -33,8 +33,9 @@
ast::type::Void void_type;
mod.AddFunction(create<ast::Function>(
- Source{}, "my_func", ast::VariableList{}, &void_type,
- create<ast::BlockStatement>(), ast::FunctionDecorationList{}));
+ Source{}, mod.RegisterSymbol("a_func"), "my_func", ast::VariableList{},
+ &void_type, create<ast::BlockStatement>(),
+ ast::FunctionDecorationList{}));
ASSERT_TRUE(gen.Generate(mod)) << gen.error();
EXPECT_EQ(gen.result(), R"(fn my_func() -> void {