ProgramBuilder: New helpers,change WrapInStatement
Add AddressOf() and Deref()
Add overloads of Expr() that take a source
Change WrapInStatement() to create a `let`. Unlike `var`, `let` can be
used to hold pointers.
Bug: tint:727
Change-Id: Ib2cd7ab7a7056862e064943dea04387f7e466212
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51183
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/program_builder.cc b/src/program_builder.cc
index a59b03a..6cf5eac 100644
--- a/src/program_builder.cc
+++ b/src/program_builder.cc
@@ -20,6 +20,7 @@
#include "src/debug.h"
#include "src/demangler.h"
#include "src/sem/expression.h"
+#include "src/sem/variable.h"
namespace tint {
@@ -90,6 +91,11 @@
return sem ? sem->Type() : nullptr;
}
+sem::Type* ProgramBuilder::TypeOf(const ast::Variable* var) const {
+ auto* sem = Sem().Get(var);
+ return sem ? sem->Type() : nullptr;
+}
+
const sem::Type* ProgramBuilder::TypeOf(const ast::Type* type) const {
return Sem().Get(type);
}
@@ -162,8 +168,11 @@
}
ast::Statement* ProgramBuilder::WrapInStatement(ast::Expression* expr) {
+ if (auto* ce = expr->As<ast::CallExpression>()) {
+ return create<ast::CallStatement>(ce);
+ }
// Create a temporary variable of inferred type from expr.
- return Decl(Var(symbols_.New(), nullptr, ast::StorageClass::kFunction, expr));
+ return Decl(Const(symbols_.New(), nullptr, expr));
}
ast::VariableDeclStatement* ProgramBuilder::WrapInStatement(ast::Variable* v) {
diff --git a/src/program_builder.h b/src/program_builder.h
index 42375a9..74367bc 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -55,6 +55,7 @@
#include "src/ast/type_name.h"
#include "src/ast/u32.h"
#include "src/ast/uint_literal.h"
+#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/ast/vector.h"
#include "src/ast/void.h"
@@ -920,10 +921,11 @@
/// @return nullptr
ast::IdentifierExpression* Expr(std::nullptr_t) { return nullptr; }
- /// @param name the identifier name
- /// @return an ast::IdentifierExpression with the given name
- ast::IdentifierExpression* Expr(const std::string& name) {
- return create<ast::IdentifierExpression>(Symbols().Register(name));
+ /// @param source the source information
+ /// @param symbol the identifier symbol
+ /// @return an ast::IdentifierExpression with the given symbol
+ ast::IdentifierExpression* Expr(const Source& source, Symbol symbol) {
+ return create<ast::IdentifierExpression>(source, symbol);
}
/// @param symbol the identifier symbol
@@ -932,6 +934,14 @@
return create<ast::IdentifierExpression>(symbol);
}
+ /// @param source the source information
+ /// @param variable the AST variable
+ /// @return an ast::IdentifierExpression with the variable's symbol
+ ast::IdentifierExpression* Expr(const Source& source,
+ ast::Variable* variable) {
+ return create<ast::IdentifierExpression>(source, variable->symbol());
+ }
+
/// @param variable the AST variable
/// @return an ast::IdentifierExpression with the variable's symbol
ast::IdentifierExpression* Expr(ast::Variable* variable) {
@@ -941,6 +951,19 @@
/// @param source the source information
/// @param name the identifier name
/// @return an ast::IdentifierExpression with the given name
+ ast::IdentifierExpression* Expr(const Source& source, const char* name) {
+ return create<ast::IdentifierExpression>(source, Symbols().Register(name));
+ }
+
+ /// @param name the identifier name
+ /// @return an ast::IdentifierExpression with the given name
+ ast::IdentifierExpression* Expr(const char* name) {
+ return create<ast::IdentifierExpression>(Symbols().Register(name));
+ }
+
+ /// @param source the source information
+ /// @param name the identifier name
+ /// @return an ast::IdentifierExpression with the given name
ast::IdentifierExpression* Expr(const Source& source,
const std::string& name) {
return create<ast::IdentifierExpression>(source, Symbols().Register(name));
@@ -948,28 +971,56 @@
/// @param name the identifier name
/// @return an ast::IdentifierExpression with the given name
- ast::IdentifierExpression* Expr(const char* name) {
+ ast::IdentifierExpression* Expr(const std::string& name) {
return create<ast::IdentifierExpression>(Symbols().Register(name));
}
+ /// @param source the source information
+ /// @param value the boolean value
+ /// @return a Scalar constructor for the given value
+ ast::ScalarConstructorExpression* Expr(const Source& source, bool value) {
+ return create<ast::ScalarConstructorExpression>(source, Literal(value));
+ }
+
/// @param value the boolean value
/// @return a Scalar constructor for the given value
ast::ScalarConstructorExpression* Expr(bool value) {
return create<ast::ScalarConstructorExpression>(Literal(value));
}
+ /// @param source the source information
+ /// @param value the float value
+ /// @return a Scalar constructor for the given value
+ ast::ScalarConstructorExpression* Expr(const Source& source, f32 value) {
+ return create<ast::ScalarConstructorExpression>(source, Literal(value));
+ }
+
/// @param value the float value
/// @return a Scalar constructor for the given value
ast::ScalarConstructorExpression* Expr(f32 value) {
return create<ast::ScalarConstructorExpression>(Literal(value));
}
+ /// @param source the source information
+ /// @param value the integer value
+ /// @return a Scalar constructor for the given value
+ ast::ScalarConstructorExpression* Expr(const Source& source, i32 value) {
+ return create<ast::ScalarConstructorExpression>(source, Literal(value));
+ }
+
/// @param value the integer value
/// @return a Scalar constructor for the given value
ast::ScalarConstructorExpression* Expr(i32 value) {
return create<ast::ScalarConstructorExpression>(Literal(value));
}
+ /// @param source the source information
+ /// @param value the unsigned int value
+ /// @return a Scalar constructor for the given value
+ ast::ScalarConstructorExpression* Expr(const Source& source, u32 value) {
+ return create<ast::ScalarConstructorExpression>(source, Literal(value));
+ }
+
/// @param value the unsigned int value
/// @return a Scalar constructor for the given value
ast::ScalarConstructorExpression* Expr(u32 value) {
@@ -1354,6 +1405,40 @@
return var;
}
+ /// @param source the source information
+ /// @param expr the expression to take the address of
+ /// @return an ast::UnaryOpExpression that takes the address of `expr`
+ template <typename EXPR>
+ ast::UnaryOpExpression* AddressOf(const Source& source, EXPR&& expr) {
+ return create<ast::UnaryOpExpression>(source, ast::UnaryOp::kAddressOf,
+ Expr(std::forward<EXPR>(expr)));
+ }
+
+ /// @param expr the expression to take the address of
+ /// @return an ast::UnaryOpExpression that takes the address of `expr`
+ template <typename EXPR>
+ ast::UnaryOpExpression* AddressOf(EXPR&& expr) {
+ return create<ast::UnaryOpExpression>(ast::UnaryOp::kAddressOf,
+ Expr(std::forward<EXPR>(expr)));
+ }
+
+ /// @param source the source information
+ /// @param expr the expression to perform an indirection on
+ /// @return an ast::UnaryOpExpression that dereferences the pointer `expr`
+ template <typename EXPR>
+ ast::UnaryOpExpression* Deref(const Source& source, EXPR&& expr) {
+ return create<ast::UnaryOpExpression>(source, ast::UnaryOp::kIndirection,
+ Expr(std::forward<EXPR>(expr)));
+ }
+
+ /// @param expr the expression to perform an indirection on
+ /// @return an ast::UnaryOpExpression that dereferences the pointer `expr`
+ template <typename EXPR>
+ ast::UnaryOpExpression* Deref(EXPR&& expr) {
+ return create<ast::UnaryOpExpression>(ast::UnaryOp::kIndirection,
+ Expr(std::forward<EXPR>(expr)));
+ }
+
/// @param func the function name
/// @param args the function call arguments
/// @returns a `ast::CallExpression` to the function `func`, with the
@@ -1845,6 +1930,14 @@
/// expression has no resolved type.
sem::Type* TypeOf(const ast::Expression* expr) const;
+ /// Helper for returning the resolved semantic type of the variable `var`.
+ /// @note As the Resolver is run when the Program is built, this will only be
+ /// useful for the Resolver itself and tests that use their own Resolver.
+ /// @param var the AST variable
+ /// @return the resolved semantic type for the variable, or nullptr if the
+ /// variable has no resolved type.
+ sem::Type* TypeOf(const ast::Variable* var) const;
+
/// Helper for returning the resolved semantic type of the AST type `type`.
/// @note As the Resolver is run when the Program is built, this will only be
/// useful for the Resolver itself and tests that use their own Resolver.
diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc
index 8f0d658..423a046 100644
--- a/src/writer/hlsl/generator_impl_member_accessor_test.cc
+++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc
@@ -128,7 +128,7 @@
Global("str", s, ast::StorageClass::kPrivate);
auto* expr = MemberAccessor("str", "mem");
- WrapInFunction(expr);
+ WrapInFunction(Var("expr", ty.f32(), ast::StorageClass::kNone, expr));
GeneratorImpl& gen = SanitizeAndBuild();
@@ -141,7 +141,7 @@
[numthreads(1, 1, 1)]
void test_function() {
- float tint_symbol = str.mem;
+ float expr = str.mem;
return;
}