tint: Have ast::CallExpression use ast::Identifier
Instead of ast::IdentifierExpression.
The name is not an expression, as it resolves to a function, builtin or
type.
Bug: tint:1257
Change-Id: I13143f2bbc208e9e2934dad20fe5c9aa59520b68
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118341
Kokoro: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/ast/call_expression.cc b/src/tint/ast/call_expression.cc
index 5a8aa96..145092f 100644
--- a/src/tint/ast/call_expression.cc
+++ b/src/tint/ast/call_expression.cc
@@ -23,7 +23,7 @@
namespace tint::ast {
namespace {
-CallExpression::Target ToTarget(const IdentifierExpression* name) {
+CallExpression::Target ToTarget(const Identifier* name) {
CallExpression::Target target;
target.name = name;
return target;
@@ -38,7 +38,7 @@
CallExpression::CallExpression(ProgramID pid,
NodeID nid,
const Source& src,
- const IdentifierExpression* name,
+ const Identifier* name,
utils::VectorRef<const Expression*> a)
: Base(pid, nid, src), target(ToTarget(name)), args(std::move(a)) {
TINT_ASSERT(AST, name);
diff --git a/src/tint/ast/call_expression.h b/src/tint/ast/call_expression.h
index 587a88f..a3b6edc 100644
--- a/src/tint/ast/call_expression.h
+++ b/src/tint/ast/call_expression.h
@@ -20,7 +20,7 @@
// Forward declarations
namespace tint::ast {
class Type;
-class IdentifierExpression;
+class Identifier;
} // namespace tint::ast
namespace tint::ast {
@@ -41,7 +41,7 @@
CallExpression(ProgramID pid,
NodeID nid,
const Source& source,
- const IdentifierExpression* name,
+ const Identifier* name,
utils::VectorRef<const Expression*> args);
/// Constructor
@@ -71,7 +71,7 @@
struct Target {
/// name is a function or builtin to call, or type name to construct or
/// cast-to
- const IdentifierExpression* name = nullptr;
+ const Identifier* name = nullptr;
/// type to construct or cast-to
const Type* type = nullptr;
};
diff --git a/src/tint/ast/call_expression_test.cc b/src/tint/ast/call_expression_test.cc
index 8b5b6a9..774d6a2 100644
--- a/src/tint/ast/call_expression_test.cc
+++ b/src/tint/ast/call_expression_test.cc
@@ -21,13 +21,13 @@
using CallExpressionTest = TestHelper;
TEST_F(CallExpressionTest, CreationIdentifier) {
- auto* func = Expr("func");
+ auto* func = Ident("func");
utils::Vector params{
Expr("param1"),
Expr("param2"),
};
- auto* stmt = create<CallExpression>(func, params);
+ auto* stmt = Call(func, params);
EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr);
@@ -38,8 +38,8 @@
}
TEST_F(CallExpressionTest, CreationIdentifier_WithSource) {
- auto* func = Expr("func");
- auto* stmt = create<CallExpression>(Source{{20, 2}}, func, utils::Empty);
+ auto* func = Ident("func");
+ auto* stmt = Call(Source{{20, 2}}, func);
EXPECT_EQ(stmt->target.name, func);
EXPECT_EQ(stmt->target.type, nullptr);
@@ -55,7 +55,7 @@
Expr("param2"),
};
- auto* stmt = create<CallExpression>(type, params);
+ auto* stmt = Construct(type, params);
EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type);
@@ -67,7 +67,7 @@
TEST_F(CallExpressionTest, CreationType_WithSource) {
auto* type = ty.f32();
- auto* stmt = create<CallExpression>(Source{{20, 2}}, type, utils::Empty);
+ auto* stmt = Construct(Source{{20, 2}}, type);
EXPECT_EQ(stmt->target.name, nullptr);
EXPECT_EQ(stmt->target.type, type);
@@ -77,8 +77,8 @@
}
TEST_F(CallExpressionTest, IsCall) {
- auto* func = Expr("func");
- auto* stmt = create<CallExpression>(func, utils::Empty);
+ auto* func = Ident("func");
+ auto* stmt = Call(func);
EXPECT_TRUE(stmt->Is<CallExpression>());
}
@@ -86,7 +86,7 @@
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.create<CallExpression>(static_cast<IdentifierExpression*>(nullptr), utils::Empty);
+ b.Call(static_cast<Identifier*>(nullptr));
},
"internal compiler error");
}
@@ -95,7 +95,7 @@
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.create<CallExpression>(static_cast<Type*>(nullptr), utils::Empty);
+ b.Construct(static_cast<Type*>(nullptr));
},
"internal compiler error");
}
@@ -104,11 +104,11 @@
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.create<CallExpression>(b.Expr("func"), utils::Vector{
- b.Expr("param1"),
- nullptr,
- b.Expr("param2"),
- });
+ b.Call(b.Ident("func"), utils::Vector{
+ b.Expr("param1"),
+ nullptr,
+ b.Expr("param2"),
+ });
},
"internal compiler error");
}
@@ -118,7 +118,7 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.create<CallExpression>(b2.Expr("func"), utils::Empty);
+ b1.Call(b2.Ident("func"));
},
"internal compiler error");
}
@@ -128,7 +128,7 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.create<CallExpression>(b2.ty.f32(), utils::Empty);
+ b1.Construct(b2.ty.f32());
},
"internal compiler error");
}
@@ -138,7 +138,7 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.create<CallExpression>(b1.Expr("func"), utils::Vector{b2.Expr("param1")});
+ b1.Call(b1.Ident("func"), b2.Expr("param1"));
},
"internal compiler error");
}
diff --git a/src/tint/ast/call_statement_test.cc b/src/tint/ast/call_statement_test.cc
index 84d2b41..a185911 100644
--- a/src/tint/ast/call_statement_test.cc
+++ b/src/tint/ast/call_statement_test.cc
@@ -23,14 +23,14 @@
using CallStatementTest = TestHelper;
TEST_F(CallStatementTest, Creation) {
- auto* expr = create<CallExpression>(Expr("func"), utils::Empty);
+ auto* expr = Call("func");
- auto* c = create<CallStatement>(expr);
+ auto* c = CallStmt(expr);
EXPECT_EQ(c->expr, expr);
}
TEST_F(CallStatementTest, IsCall) {
- auto* c = create<CallStatement>(Call("f"));
+ auto* c = CallStmt(Call("f"));
EXPECT_TRUE(c->Is<CallStatement>());
}
@@ -38,7 +38,7 @@
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.create<CallStatement>(nullptr);
+ b.CallStmt(nullptr);
},
"internal compiler error");
}
@@ -48,7 +48,7 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.create<CallStatement>(b2.create<CallExpression>(b2.Expr("func"), utils::Empty));
+ b1.CallStmt(b2.Call("func"));
},
"internal compiler error");
}
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index aaee0af..ce3712d 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -1160,7 +1160,13 @@
/// @return an ast::Identifier with the given symbol
template <typename IDENTIFIER>
const ast::Identifier* Ident(IDENTIFIER&& identifier) {
- return create<ast::Identifier>(Sym(std::forward<IDENTIFIER>(identifier)));
+ if constexpr (traits::IsTypeOrDerived<
+ std::decay_t<std::remove_pointer_t<std::decay_t<IDENTIFIER>>>,
+ ast::Identifier>) {
+ return identifier; // Pass-through
+ } else {
+ return create<ast::Identifier>(Sym(std::forward<IDENTIFIER>(identifier)));
+ }
}
/// @param expr the expression
@@ -2054,7 +2060,7 @@
/// arguments of `args` converted to `ast::Expression`s using `Expr()`.
template <typename NAME, typename... ARGS>
const ast::CallExpression* Call(const Source& source, NAME&& func, ARGS&&... args) {
- return create<ast::CallExpression>(source, Expr(func),
+ return create<ast::CallExpression>(source, Ident(func),
ExprList(std::forward<ARGS>(args)...));
}
@@ -2064,7 +2070,7 @@
/// arguments of `args` converted to `ast::Expression`s using `Expr()`.
template <typename NAME, typename... ARGS, typename = DisableIfSource<NAME>>
const ast::CallExpression* Call(NAME&& func, ARGS&&... args) {
- return create<ast::CallExpression>(Expr(func), ExprList(std::forward<ARGS>(args)...));
+ return create<ast::CallExpression>(Ident(func), ExprList(std::forward<ARGS>(args)...));
}
/// @param source the source information
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index bf4088a..f068a62 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -1316,10 +1316,10 @@
// Call the inner function. It has no parameters.
stmts.Push(create<ast::CallStatement>(
source,
- create<ast::CallExpression>(source,
- create<ast::IdentifierExpression>(
- source, builder_.Symbols().Register(ep_info_->inner_name)),
- utils::Empty)));
+ create<ast::CallExpression>(
+ source,
+ create<ast::Identifier>(source, builder_.Symbols().Register(ep_info_->inner_name)),
+ utils::Empty)));
// Pipeline outputs are mapped to the return value.
if (ep_info_->outputs.IsEmpty()) {
@@ -3854,7 +3854,7 @@
params.Push(MakeOperand(inst, 0).expr);
return {ast_type, create<ast::CallExpression>(
Source{},
- create<ast::IdentifierExpression>(
+ create<ast::Identifier>(
Source{}, builder_.Symbols().Register(unary_builtin_name)),
std::move(params))};
}
@@ -4106,7 +4106,7 @@
return {};
}
- auto* func = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
+ auto* func = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList operands;
const Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs.
@@ -5212,7 +5212,7 @@
bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
// We ignore function attributes such as Inline, DontInline, Pure, Const.
auto name = namer_.Name(inst.GetSingleWordInOperand(0));
- auto* function = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
+ auto* function = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList args;
for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) {
@@ -5302,7 +5302,7 @@
TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instruction& inst) {
const auto builtin = GetBuiltin(opcode(inst));
auto* name = sem::str(builtin);
- auto* ident = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
+ auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
ExpressionList params;
const Type* first_operand_type = nullptr;
@@ -5341,11 +5341,10 @@
params.Push(true_value.expr);
// The condition goes last.
params.Push(condition.expr);
- return {op_ty,
- create<ast::CallExpression>(Source{},
- create<ast::IdentifierExpression>(
- Source{}, builder_.Symbols().Register("select")),
- std::move(params))};
+ return {op_ty, create<ast::CallExpression>(
+ Source{},
+ create<ast::Identifier>(Source{}, builder_.Symbols().Register("select")),
+ std::move(params))};
}
return {};
}
@@ -5650,8 +5649,7 @@
return false;
}
- auto* ident =
- create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(builtin_name));
+ auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(builtin_name));
auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(args));
if (inst.type_id() != 0) {
@@ -5741,8 +5739,8 @@
// Invoke textureDimensions.
// If the texture is arrayed, combine with the result from
// textureNumLayers.
- auto* dims_ident = create<ast::IdentifierExpression>(
- Source{}, builder_.Symbols().Register("textureDimensions"));
+ auto* dims_ident =
+ create<ast::Identifier>(Source{}, builder_.Symbols().Register("textureDimensions"));
ExpressionList dims_args{GetImageExpression(inst)};
if (op == spv::Op::OpImageQuerySizeLod) {
dims_args.Push(MakeOperand(inst, 1).expr);
@@ -5758,7 +5756,7 @@
}
exprs.Push(dims_call);
if (ast::IsTextureArray(dims)) {
- auto* layers_ident = create<ast::IdentifierExpression>(
+ auto* layers_ident = create<ast::Identifier>(
Source{}, builder_.Symbols().Register("textureNumLayers"));
auto num_layers = create<ast::CallExpression>(
Source{}, layers_ident, utils::Vector{GetImageExpression(inst)});
@@ -5789,7 +5787,7 @@
const auto* name =
(op == spv::Op::OpImageQueryLevels) ? "textureNumLevels" : "textureNumSamples";
auto* levels_ident =
- create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
+ create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
const ast::Expression* ast_expr = create<ast::CallExpression>(
Source{}, levels_ident, utils::Vector{GetImageExpression(inst)});
auto* result_type = parser_impl_.ConvertType(inst.type_id());
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 524b85f..dd6ef91 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -2436,7 +2436,7 @@
t.source(),
create<ast::CallExpression>(
t.source(),
- create<ast::IdentifierExpression>(t.source(), builder_.Symbols().Register(t.to_str())),
+ create<ast::Identifier>(t.source(), builder_.Symbols().Register(t.to_str())),
std::move(params.value)));
}
@@ -2642,19 +2642,19 @@
"in parentheses");
}
- auto* ident =
- create<ast::IdentifierExpression>(t.source(), builder_.Symbols().Register(t.to_str()));
-
if (peek_is(Token::Type::kParenLeft)) {
auto params = expect_argument_expression_list("function call");
if (params.errored) {
return Failure::kErrored;
}
+ auto* ident =
+ create<ast::Identifier>(t.source(), builder_.Symbols().Register(t.to_str()));
return create<ast::CallExpression>(t.source(), ident, std::move(params.value));
}
- return ident;
+ return create<ast::IdentifierExpression>(t.source(),
+ builder_.Symbols().Register(t.to_str()));
}
if (t.Is(Token::Type::kParenLeft)) {
diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc
index 2d3f2f3..281ca47 100644
--- a/src/tint/resolver/builtin_validation_test.cc
+++ b/src/tint/resolver/builtin_validation_test.cc
@@ -39,10 +39,7 @@
TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageDirect) {
// @compute @workgroup_size(1) fn func { return dpdx(1.0); }
- auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
- utils::Vector{
- Expr(1_f),
- });
+ auto* dpdx = Call(Source{{3, 4}}, "dpdx", 1_f);
Func(Source{{1, 2}}, "func", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(dpdx),
@@ -62,10 +59,7 @@
// fn f2 { f1(); }
// @compute @workgroup_size(1) fn main { return f2(); }
- auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
- utils::Vector{
- Expr(1_f),
- });
+ auto* dpdx = Call(Source{{3, 4}}, "dpdx", 1_f);
Func(Source{{1, 2}}, "f0", utils::Empty, ty.void_(),
utils::Vector{
CallStmt(dpdx),
@@ -138,7 +132,7 @@
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalConstUsedAsFunction) {
GlobalConst(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i));
- WrapInFunction(Call(Expr(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
+ WrapInFunction(Call(Ident(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix'
@@ -167,7 +161,7 @@
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalVarUsedAsFunction) {
GlobalVar(Source{{12, 34}}, "mix", ty.i32(), Expr(1_i), type::AddressSpace::kPrivate);
- WrapInFunction(Call(Expr(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
+ WrapInFunction(Call(Ident(Source{{56, 78}}, "mix"), 1_f, 2_f, 3_f));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: cannot call variable 'mix'
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 5ece264..5da6303 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -39,6 +39,7 @@
#include "src/tint/ast/for_loop_statement.h"
#include "src/tint/ast/i32.h"
#include "src/tint/ast/id_attribute.h"
+#include "src/tint/ast/identifier.h"
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/ast/internal_attribute.h"
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index f1f0a0c..84b5469 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -548,7 +548,7 @@
return node;
}
case SymbolUseKind::CallFunction: {
- auto* node = b.Expr(source, symbol);
+ auto* node = b.Ident(source, symbol);
statements.Push(b.CallStmt(b.Call(node)));
return node;
}
@@ -651,7 +651,8 @@
// fn A() { B(); }
// fn B() {}
- Func("A", utils::Empty, ty.void_(), utils::Vector{CallStmt(Call(Expr(Source{{12, 34}}, "B")))});
+ Func("A", utils::Empty, ty.void_(),
+ utils::Vector{CallStmt(Call(Ident(Source{{12, 34}}, "B")))});
Func(Source{{56, 78}}, "B", utils::Empty, ty.void_(), utils::Vector{Return()});
Build();
@@ -812,7 +813,7 @@
// fn main() { main(); }
Func(Source{{12, 34}}, "main", utils::Empty, ty.void_(),
- utils::Vector{CallStmt(Call(Expr(Source{{56, 78}}, "main")))});
+ utils::Vector{CallStmt(Call(Ident(Source{{56, 78}}, "main")))});
Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main'
56:78 note: function 'main' calls function 'main' here)");
@@ -826,17 +827,17 @@
// 5: fn b() { c(); }
Func(Source{{1, 1}}, "a", utils::Empty, ty.void_(),
- utils::Vector{CallStmt(Call(Expr(Source{{1, 10}}, "b")))});
+ utils::Vector{CallStmt(Call(Ident(Source{{1, 10}}, "b")))});
Func(Source{{2, 1}}, "e", utils::Empty, ty.void_(), utils::Empty);
Func(Source{{3, 1}}, "d", utils::Empty, ty.void_(),
utils::Vector{
- CallStmt(Call(Expr(Source{{3, 10}}, "e"))),
- CallStmt(Call(Expr(Source{{3, 10}}, "b"))),
+ CallStmt(Call(Ident(Source{{3, 10}}, "e"))),
+ CallStmt(Call(Ident(Source{{3, 10}}, "b"))),
});
Func(Source{{4, 1}}, "c", utils::Empty, ty.void_(),
- utils::Vector{CallStmt(Call(Expr(Source{{4, 10}}, "d")))});
+ utils::Vector{CallStmt(Call(Ident(Source{{4, 10}}, "d")))});
Func(Source{{5, 1}}, "b", utils::Empty, ty.void_(),
- utils::Vector{CallStmt(Call(Expr(Source{{5, 10}}, "c")))});
+ utils::Vector{CallStmt(Call(Ident(Source{{5, 10}}, "c")))});
Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b'
5:10 note: function 'b' calls function 'c' here
@@ -1232,7 +1233,7 @@
};
#define V add_use(value_decl, Expr(value_sym), __LINE__, "V()")
#define T add_use(type_decl, ty.type_name(type_sym), __LINE__, "T()")
-#define F add_use(func_decl, Expr(func_sym), __LINE__, "F()")
+#define F add_use(func_decl, Ident(func_sym), __LINE__, "F()")
Alias(Sym(), T);
Structure(Sym(), //
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index b8636f9..719db08 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -5330,7 +5330,7 @@
args.Push(b.AddressOf(name));
}
main_body.Push(b.Assign("v0", "non_uniform_global"));
- main_body.Push(b.CallStmt(b.create<ast::CallExpression>(b.Expr("foo"), args)));
+ main_body.Push(b.CallStmt(b.create<ast::CallExpression>(b.Ident("foo"), args)));
main_body.Push(b.If(b.Equal("v254", 0_i), b.Block(b.CallStmt(b.Call("workgroupBarrier")))));
b.Func("main", utils::Empty, ty.void_(), main_body);
diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc
index 9ef4443..ebb6d7b 100644
--- a/src/tint/transform/multiplanar_external_texture.cc
+++ b/src/tint/transform/multiplanar_external_texture.cc
@@ -434,14 +434,13 @@
buildTextureBuiltinBody(sem::BuiltinType::kTextureSampleBaseClampToEdge));
}
- const ast::IdentifierExpression* exp = b.Expr(texture_sample_external_sym);
- return b.Call(exp, utils::Vector{
- plane_0_binding_param,
- b.Expr(syms.plane_1),
- ctx.Clone(expr->args[1]),
- ctx.Clone(expr->args[2]),
- b.Expr(syms.params),
- });
+ return b.Call(texture_sample_external_sym, utils::Vector{
+ plane_0_binding_param,
+ b.Expr(syms.plane_1),
+ ctx.Clone(expr->args[1]),
+ ctx.Clone(expr->args[2]),
+ b.Expr(syms.params),
+ });
}
/// Creates the textureLoadExternal function if needed and returns a call expression to it.
diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc
index d16bf8c..ae6c046 100644
--- a/src/tint/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/transform/promote_side_effects_to_decl.cc
@@ -512,9 +512,6 @@
return clone_maybe_hoisted(bitcast);
},
[&](const ast::CallExpression* call) {
- if (call->target.name) {
- ctx.Replace(call->target.name, decompose(call->target.name));
- }
for (auto* a : call->args) {
ctx.Replace(a, decompose(a));
}
diff --git a/src/tint/transform/renamer.cc b/src/tint/transform/renamer.cc
index 6f485a4..325d3ac 100644
--- a/src/tint/transform/renamer.cc
+++ b/src/tint/transform/renamer.cc
@@ -1262,7 +1262,9 @@
CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
// Identifiers that need to keep their symbols preserved.
- utils::Hashset<const ast::IdentifierExpression*, 8> preserved_identifiers;
+ utils::Hashset<const ast::Identifier*, 8> preserved_identifiers;
+ // Identifiers expressions that need to keep their symbols preserved.
+ utils::Hashset<const ast::IdentifierExpression*, 8> preserved_identifiers_expressions;
// Type names that need to keep their symbols preserved.
utils::Hashset<const ast::TypeName*, 8> preserved_type_names;
@@ -1287,11 +1289,11 @@
[&](const ast::MemberAccessorExpression* accessor) {
auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
if (sem->Is<sem::Swizzle>()) {
- preserved_identifiers.Add(accessor->member);
+ preserved_identifiers_expressions.Add(accessor->member);
} else if (auto* str_expr = src->Sem().Get(accessor->structure)) {
if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
if (ty->Declaration() == nullptr) { // Builtin structure
- preserved_identifiers.Add(accessor->member);
+ preserved_identifiers_expressions.Add(accessor->member);
}
}
}
@@ -1314,7 +1316,7 @@
}
},
[&](const ast::DiagnosticControl* diagnostic) {
- preserved_identifiers.Add(diagnostic->rule_name);
+ preserved_identifiers_expressions.Add(diagnostic->rule_name);
},
[&](const ast::TypeName* type_name) {
if (is_type_short_name(type_name->name)) {
@@ -1376,11 +1378,21 @@
return sym_out;
});
- ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
+ ctx.ReplaceAll([&](const ast::Identifier* ident) -> const ast::Identifier* {
if (preserved_identifiers.Contains(ident)) {
auto sym_in = ident->symbol;
auto str = src->Symbols().NameFor(sym_in);
auto sym_out = b.Symbols().Register(str);
+ return ctx.dst->create<ast::Identifier>(ctx.Clone(ident->source), sym_out);
+ }
+ return nullptr; // Clone ident. Uses the symbol remapping above.
+ });
+
+ ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
+ if (preserved_identifiers_expressions.Contains(ident)) {
+ auto sym_in = ident->symbol;
+ auto str = src->Symbols().NameFor(sym_in);
+ auto sym_out = b.Symbols().Register(str);
return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out);
}
return nullptr; // Clone ident. Uses the symbol remapping above.
diff --git a/src/tint/writer/glsl/generator_impl_import_test.cc b/src/tint/writer/glsl/generator_impl_import_test.cc
index 54c2405..adfebe4 100644
--- a/src/tint/writer/glsl/generator_impl_import_test.cc
+++ b/src/tint/writer/glsl/generator_impl_import_test.cc
@@ -34,8 +34,7 @@
TEST_P(GlslImportData_SingleParamTest, FloatScalar) {
auto param = GetParam();
- auto* ident = Expr(param.name);
- auto* expr = Call(ident, 1_f);
+ auto* expr = Call(param.name, 1_f);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
@@ -91,8 +90,7 @@
TEST_P(GlslImportData_SingleVectorParamTest, FloatVector) {
auto param = GetParam();
- auto* ident = Expr(param.name);
- auto* expr = Call(ident, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
+ auto* expr = Call(param.name, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/hlsl/generator_impl_import_test.cc b/src/tint/writer/hlsl/generator_impl_import_test.cc
index 9d00c3f..de864c4 100644
--- a/src/tint/writer/hlsl/generator_impl_import_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_import_test.cc
@@ -34,8 +34,7 @@
TEST_P(HlslImportData_SingleParamTest, FloatScalar) {
auto param = GetParam();
- auto* ident = Expr(param.name);
- auto* expr = Call(ident, 1_f);
+ auto* expr = Call(param.name, 1_f);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
@@ -90,8 +89,7 @@
TEST_P(HlslImportData_SingleVectorParamTest, FloatVector) {
auto param = GetParam();
- auto* ident = Expr(param.name);
- auto* expr = Call(ident, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
+ auto* expr = Call(param.name, vec3<f32>(0.1_f, 0.2_f, 0.3_f));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/msl/generator_impl_builtin_texture_test.cc b/src/tint/writer/msl/generator_impl_builtin_texture_test.cc
index ef5b427..5414416 100644
--- a/src/tint/writer/msl/generator_impl_builtin_texture_test.cc
+++ b/src/tint/writer/msl/generator_impl_builtin_texture_test.cc
@@ -276,7 +276,7 @@
param.BuildTextureVariable(this);
param.BuildSamplerVariable(this);
- auto* call = Call(Expr(param.function), param.args(this));
+ auto* call = Call(Ident(param.function), param.args(this));
auto* stmt = CallStmt(call);
Func("main", utils::Empty, ty.void_(), utils::Vector{stmt},
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 4b9d512..3619c57 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -239,9 +239,7 @@
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
if (expr->target.name) {
- if (!EmitExpression(out, expr->target.name)) {
- return false;
- }
+ out << program_->Symbols().NameFor(expr->target.name->symbol);
} else if (TINT_LIKELY(expr->target.type)) {
if (!EmitType(out, expr->target.type)) {
return false;