[tint][ir] Further simplify FromProgram
Use a stack of tasks to walk and evaluate expressions. Uses fewer stacks and only type dispatches once per expression node.
Move all the expressions emission functions into EmitExpression() to prevent them accidentally being called by other emission functions. This reduces the expression related methods to just EmitExpression() and EmitValueExpression().
Move expr_to_result_ into EmitExpression(), substantially reducing the lifetime of this map, and reducing the likelihood it would spill to the heap.
Emit loads at expression emission time instead of when fetching the result. This is more truthful to the semantic info.
Bug: tint:1718
Change-Id: I73685b4d78d910922710f8779cd5cce26826157e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/141962
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index a999fda..7c5ad2a 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -83,7 +83,6 @@
#include "src/tint/lang/wgsl/ast/struct_member_size_attribute.h"
#include "src/tint/lang/wgsl/ast/switch_statement.h"
#include "src/tint/lang/wgsl/ast/templated_identifier.h"
-#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/lang/wgsl/ast/var.h"
#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
@@ -153,9 +152,7 @@
ir::Value* index = nullptr;
};
- /// Maps expressions to their result values
- utils::Hashmap<const ast::Expression*, std::variant<ir::Value*, VectorRefElementAccess>, 64>
- expr_to_result_;
+ using ValueOrVecElAccess = std::variant<ir::Value*, VectorRefElementAccess>;
/// The current block for expressions.
Block* current_block_ = nullptr;
@@ -210,36 +207,6 @@
return nullptr;
}
- Value* ResultForExpression(const ast::Expression* expr) {
- auto val = expr_to_result_.Get(expr);
- if (!val) {
- return nullptr;
- }
-
- if (auto* v = std::get_if<ir::Value*>(&(val.value()))) {
- // If this expression maps to sem::Load, insert a load instruction to get the result.
- auto* sem = program_->Sem().GetVal(expr);
- if ((*v)->Type()->Is<type::Pointer>() && sem->Is<sem::Load>()) {
- auto* load = builder_.Load(*v);
- current_block_->Append(load);
- // There are cases where we get the result for an expression multiple times (binary
- // that does a LHS load for one, so make sure we cache the load away for any
- // subsequent calls.
- expr_to_result_.Replace(expr, load->Result());
- return load->Result();
- }
- return *v;
- } else if (auto ref = std::get_if<VectorRefElementAccess>(&(val.value()))) {
- // Vector reference accesses need to map to LoadVectorElement()
- auto* load = builder_.LoadVectorElement(ref->vector, ref->index);
- current_block_->Append(load);
- return load->Result();
- }
-
- TINT_UNREACHABLE(IR, diagnostics_);
- return nullptr;
- }
-
ResultType EmitModule() {
auto* sem = program_->Sem().Module();
@@ -513,20 +480,6 @@
});
}
- void StoreResult(const ast::Expression* lhs, ir::Value* rhs) {
- auto val = expr_to_result_.Get(lhs);
- if (!val) {
- return;
- }
-
- auto b = builder_.With(current_block_);
- if (auto* v = std::get_if<ir::Value*>(&(val.value()))) {
- b.Store(*v, rhs);
- } else if (auto ref = std::get_if<VectorRefElementAccess>(&(val.value()))) {
- b.StoreVectorElement(ref->vector, ref->index, rhs);
- }
- }
-
void EmitAssignment(const ast::AssignmentStatement* stmt) {
// If assigning to a phony, just generate the RHS and we're done. Note that, because
// this isn't used, a subsequent transform could remove it due to it being dead code.
@@ -534,54 +487,55 @@
// used). If that happens we have to either fix this to store to a phony value, or make
// sure we pull the interface before doing the dead code elimination.
if (stmt->lhs->Is<ast::PhonyExpression>()) {
- (void)EmitExpressionWithResult(stmt->rhs);
+ (void)EmitValueExpression(stmt->rhs);
return;
}
- EmitExpression(stmt->lhs);
+ auto lhs = EmitExpression(stmt->lhs);
- auto rhs = EmitExpressionWithResult(stmt->rhs);
+ auto rhs = EmitValueExpression(stmt->rhs);
if (!rhs) {
return;
}
- StoreResult(stmt->lhs, rhs.Get());
+
+ auto b = builder_.With(current_block_);
+ if (auto* v = std::get_if<ir::Value*>(&lhs)) {
+ b.Store(*v, rhs);
+ } else if (auto ref = std::get_if<VectorRefElementAccess>(&lhs)) {
+ b.StoreVectorElement(ref->vector, ref->index, rhs);
+ }
}
void EmitIncrementDecrement(const ast::IncrementDecrementStatement* stmt) {
- EmitExpression(stmt->lhs);
+ auto lhs = EmitExpression(stmt->lhs);
auto* one = program_->TypeOf(stmt->lhs)->UnwrapRef()->is_signed_integer_scalar()
? builder_.Constant(1_i)
: builder_.Constant(1_u);
- EmitCompoundAssignment(stmt->lhs, one,
+ EmitCompoundAssignment(lhs, one,
stmt->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract);
}
void EmitCompoundAssignment(const ast::CompoundAssignmentStatement* stmt) {
- EmitExpression(stmt->lhs);
+ auto lhs = EmitExpression(stmt->lhs);
- auto rhs = EmitExpressionWithResult(stmt->rhs);
+ auto rhs = EmitValueExpression(stmt->rhs);
if (!rhs) {
return;
}
- EmitCompoundAssignment(stmt->lhs, rhs.Get(), stmt->op);
+ EmitCompoundAssignment(lhs, rhs, stmt->op);
}
- void EmitCompoundAssignment(const ast::Expression* lhs_expr, ir::Value* rhs, ast::BinaryOp op) {
- auto val = expr_to_result_.Get(lhs_expr);
- if (!val) {
- return;
- }
-
+ void EmitCompoundAssignment(ValueOrVecElAccess lhs, ir::Value* rhs, ast::BinaryOp op) {
auto b = builder_.With(current_block_);
- if (auto* v = std::get_if<ir::Value*>(&(val.value()))) {
+ if (auto* v = std::get_if<ir::Value*>(&lhs)) {
auto* load = b.Load(*v);
auto* ty = load->Result()->Type();
auto* inst = current_block_->Append(BinaryOp(ty, load->Result(), rhs, op));
b.Store(*v, inst);
- } else if (auto ref = std::get_if<VectorRefElementAccess>(&(val.value()))) {
+ } else if (auto ref = std::get_if<VectorRefElementAccess>(&lhs)) {
auto* load = b.LoadVectorElement(ref->vector, ref->index);
auto* ty = load->Result()->Type();
auto* inst = b.Append(BinaryOp(ty, load->Result(), rhs, op));
@@ -601,11 +555,11 @@
void EmitIf(const ast::IfStatement* stmt) {
// Emit the if condition into the end of the preceding block
- auto reg = EmitExpressionWithResult(stmt->condition);
+ auto reg = EmitValueExpression(stmt->condition);
if (!reg) {
return;
}
- auto* if_inst = builder_.If(reg.Get());
+ auto* if_inst = builder_.If(reg);
current_block_->Append(if_inst);
{
@@ -683,13 +637,13 @@
TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Body());
// Emit the while condition into the Start().target of the loop
- auto reg = EmitExpressionWithResult(stmt->condition);
+ auto reg = EmitValueExpression(stmt->condition);
if (!reg) {
return;
}
// Create an `if (cond) {} else {break;}` control flow
- auto* if_inst = builder_.If(reg.Get());
+ auto* if_inst = builder_.If(reg);
current_block_->Append(if_inst);
{
@@ -736,13 +690,13 @@
if (stmt->condition) {
// Emit the condition into the target target of the loop body
- auto reg = EmitExpressionWithResult(stmt->condition);
+ auto reg = EmitValueExpression(stmt->condition);
if (!reg) {
return;
}
// Create an `if (cond) {} else {break;}` control flow
- auto* if_inst = builder_.If(reg.Get());
+ auto* if_inst = builder_.If(reg);
current_block_->Append(if_inst);
{
@@ -770,11 +724,11 @@
void EmitSwitch(const ast::SwitchStatement* stmt) {
// Emit the condition into the preceding block
- auto reg = EmitExpressionWithResult(stmt->condition);
+ auto reg = EmitValueExpression(stmt->condition);
if (!reg) {
return;
}
- auto* switch_inst = builder_.Switch(reg.Get());
+ auto* switch_inst = builder_.Switch(reg);
current_block_->Append(switch_inst);
ControlStackScope scope(this, switch_inst);
@@ -802,11 +756,11 @@
void EmitReturn(const ast::ReturnStatement* stmt) {
Value* ret_value = nullptr;
if (stmt->value) {
- auto ret = EmitExpressionWithResult(stmt->value);
+ auto ret = EmitValueExpression(stmt->value);
if (!ret) {
return;
}
- ret_value = ret.Get();
+ ret_value = ret;
}
if (ret_value) {
SetTerminator(builder_.Return(current_function_, ret_value));
@@ -852,269 +806,467 @@
auto* current_control = FindEnclosingControl(ControlFlags::kExcludeSwitch);
// Emit the break-if condition into the end of the preceding block
- auto cond = EmitExpressionWithResult(stmt->condition);
+ auto cond = EmitValueExpression(stmt->condition);
if (!cond) {
return;
}
- SetTerminator(builder_.BreakIf(current_control->As<ir::Loop>(), cond.Get()));
+ SetTerminator(builder_.BreakIf(current_control->As<ir::Loop>(), cond));
}
- void EmitAccess(const ast::AccessorExpression* expr) {
- if (auto vec_access = AsVectorRefElementAccess(expr)) {
- expr_to_result_.Add(expr, vec_access.value());
- return;
- }
+ ValueOrVecElAccess EmitExpression(const ast::Expression* root) {
+ struct Emitter {
+ explicit Emitter(Impl& i) : impl(i) {}
- auto* obj = ResultForExpression(expr->object);
- if (!obj) {
- TINT_ASSERT(IR, false && "no object result");
- return;
- }
+ ValueOrVecElAccess Emit(const ast::Expression* root) {
+ // Process the root expression. This will likely add tasks.
+ Process(root);
- auto* sem = program_->Sem().Get(expr)->Unwrap();
+ // Execute all the tasks until all expressions have been resolved.
+ while (!tasks.IsEmpty()) {
+ auto task = tasks.Pop();
+ task();
+ }
- // The access result type should match the source result type. If the source is a pointer,
- // we generate a pointer.
- const type::Type* ty = sem->Type()->UnwrapRef()->Clone(clone_ctx_.type_ctx);
- if (auto* ptr = obj->Type()->As<type::Pointer>(); ptr && !ty->Is<type::Pointer>()) {
- ty = builder_.ir.Types().ptr(ptr->AddressSpace(), ty, ptr->Access());
- }
+ // Get the resolved root expression.
+ return Get(root);
+ }
- auto index = tint::Switch(
- sem,
- [&](const sem::IndexAccessorExpression* idx) -> ir::Value* {
- if (auto* v = idx->Index()->ConstantValue()) {
- if (auto* cv = v->Clone(clone_ctx_)) {
- return builder_.Constant(cv);
- }
- TINT_ASSERT(IR, false && "constant clone failed");
+ private:
+ Impl& impl;
+ utils::Vector<ir::Block*, 8> blocks;
+ utils::Vector<std::function<void()>, 64> tasks;
+ utils::Hashmap<const ast::Expression*, ValueOrVecElAccess, 64> bindings_;
+
+ void Bind(const ast::Expression* expr, ir::Value* value) {
+ // If this expression maps to sem::Load, insert a load instruction to get the result
+ if (impl.program_->Sem().Get<sem::Load>(expr)) {
+ auto* load = impl.builder_.Load(value);
+ impl.current_block_->Append(load);
+ value = load->Result();
+ }
+ bindings_.Add(expr, value);
+ }
+
+ void Bind(const ast::Expression* expr, const VectorRefElementAccess& access) {
+ // If this expression maps to sem::Load, insert a load instruction to get the result
+ if (impl.program_->Sem().Get<sem::Load>(expr)) {
+ auto* load = impl.builder_.LoadVectorElement(access.vector, access.index);
+ impl.current_block_->Append(load);
+ bindings_.Add(expr, load->Result());
+ } else {
+ bindings_.Add(expr, access);
+ }
+ }
+
+ ValueOrVecElAccess Get(const ast::Expression* expr) {
+ auto val = bindings_.Get(expr);
+ if (!val) {
return nullptr;
}
- return ResultForExpression(idx->Index()->Declaration());
- },
- [&](const sem::StructMemberAccess* access) -> ir::Value* {
- return builder_.Constant(u32((access->Member()->Index())));
- },
- [&](const sem::Swizzle* swizzle) -> ir::Value* {
- auto& indices = swizzle->Indices();
-
- // A single element swizzle is just treated as an accessor.
- if (indices.Length() == 1) {
- return builder_.Constant(u32(indices[0]));
- }
- auto* val = builder_.Swizzle(ty, obj, std::move(indices));
- current_block_->Append(val);
- expr_to_result_.Add(expr, val->Result());
- return nullptr;
- },
- [&](Default) {
- TINT_ICE(Writer, diagnostics_)
- << "invalid accessor: " + std::string(sem->TypeInfo().name);
- return nullptr;
- });
-
- if (!index) {
- return;
- }
-
- // If the object is an unnamed value (a subexpression, not a let) and is the result of
- // another access, then we can just append the index to that access.
- if (!mod.NameOf(obj).IsValid()) {
- if (auto* inst_res = obj->As<InstructionResult>()) {
- if (auto* access = inst_res->Source()->As<Access>()) {
- access->AddIndex(index);
- access->Result()->SetType(ty);
- expr_to_result_.Remove(expr->object);
- expr_to_result_.Add(expr, access->Result());
- // Move the access after the index expression.
- if (current_block_->Back() != access) {
- current_block_->Remove(access);
- current_block_->Append(access);
- }
- return;
- }
+ return *val;
}
- }
- // Create a new access
- auto* access = builder_.Access(ty, obj, index);
- current_block_->Append(access);
- expr_to_result_.Add(expr, access->Result());
- }
-
- struct WorkListData {
- const ast::Expression* inst = nullptr;
- const ast::Expression* emit_to = nullptr;
- ir::Block* parent_block = nullptr;
- };
-
- void EmitExpression(const ast::Expression* root) {
- // If this is a value that has been const-eval'd return the result.
- auto* sem = program_->Sem().GetVal(root);
- if (sem) {
- if (auto* v = sem->ConstantValue()) {
- if (auto* cv = v->Clone(clone_ctx_)) {
- auto* val = builder_.Constant(cv);
- expr_to_result_.Add(root, val);
- return;
+ Value* GetValue(const ast::Expression* expr) {
+ auto res = Get(expr);
+ if (auto** val = std::get_if<Value*>(&res)) {
+ return *val;
}
+ TINT_ICE(IR, impl.diagnostics_) << "expression did not resolve to a value";
+ return nullptr;
}
- }
- utils::Vector<WorkListData, 64> work_list;
+ void PushBlock(ir::Block* block) {
+ blocks.Push(impl.current_block_);
+ impl.current_block_ = block;
+ }
- auto process_work_list = [&](const ast::Expression* expr) {
- TINT_ASSERT(IR, !work_list.IsEmpty());
+ void PopBlock() { impl.current_block_ = blocks.Pop(); }
- auto& cur = work_list.Back();
- return tint::Switch(
- cur.inst, //
- [&](const ast::AccessorExpression* a) {
- EmitAccess(a);
- return true;
- },
- [&](const ast::UnaryOpExpression* u) {
- EmitUnary(u);
- return true;
- },
- [&](const ast::CallExpression* c) {
- EmitCall(c);
- return true;
- },
- [&](const ast::BitcastExpression* b) {
- EmitBitcast(b);
- return true;
- },
- [&](const ast::BinaryExpression* b) {
- switch (b->op) {
- case ast::BinaryOp::kLogicalAnd:
- case ast::BinaryOp::kLogicalOr: {
- if (expr == b->lhs) {
- // Store current_block_ first as the short circuit will set the
- // current_block_ to either the true or false branch.
- cur.parent_block = current_block_;
-
- auto* if_inst = EmitShortCircuit(b);
- if (if_inst) {
- control_stack_.Push(if_inst);
- cur.emit_to = b->rhs;
- return false;
- }
- } else if (expr == b->rhs) {
- control_stack_.Pop();
- EmitShortCircuitResult(b);
- current_block_ = cur.parent_block;
- } else {
- TINT_UNREACHABLE(IR, diagnostics_);
- }
- return true;
+ Value* EmitConstant(const ast::Expression* expr) {
+ if (auto* sem = impl.program_->Sem().GetVal(expr)) {
+ if (auto* v = sem->ConstantValue()) {
+ if (auto* cv = v->Clone(impl.clone_ctx_)) {
+ auto* val = impl.builder_.Constant(cv);
+ bindings_.Add(expr, val);
+ return val;
}
- default:
- if (expr == b->lhs) {
- // Force a load of the LHS if needed so the results come out in the
- // correct order.
- (void)ResultForExpression(expr);
- cur.emit_to = b->rhs;
- return false;
- } else if (expr == b->rhs) {
- EmitBinary(b);
- } else {
- TINT_UNREACHABLE(IR, diagnostics_);
- }
- return true;
}
- },
- [&](Default) {
- add_error(cur.inst->source,
- "unknown expression type: " + std::string(cur.inst->TypeInfo().name));
- return true;
- });
+ }
+ return nullptr;
+ }
+
+ void EmitAccess(const ast::AccessorExpression* expr) {
+ if (auto vec_access = AsVectorRefElementAccess(expr)) {
+ Bind(expr, vec_access.value());
+ return;
+ }
+
+ auto* obj = GetValue(expr->object);
+ if (!obj) {
+ TINT_ASSERT(IR, false && "no object result");
+ return;
+ }
+
+ auto* sem = impl.program_->Sem().Get(expr)->Unwrap();
+
+ // The access result type should match the source result type. If the source is a
+ // pointer, we generate a pointer.
+ const type::Type* ty = sem->Type()->UnwrapRef()->Clone(impl.clone_ctx_.type_ctx);
+ if (auto* ptr = obj->Type()->As<type::Pointer>(); ptr && !ty->Is<type::Pointer>()) {
+ ty = impl.builder_.ir.Types().ptr(ptr->AddressSpace(), ty, ptr->Access());
+ }
+
+ auto index = tint::Switch(
+ sem,
+ [&](const sem::IndexAccessorExpression* idx) -> ir::Value* {
+ if (auto* v = idx->Index()->ConstantValue()) {
+ if (auto* cv = v->Clone(impl.clone_ctx_)) {
+ return impl.builder_.Constant(cv);
+ }
+ TINT_ASSERT(IR, false && "constant clone failed");
+ return nullptr;
+ }
+ return GetValue(idx->Index()->Declaration());
+ },
+ [&](const sem::StructMemberAccess* access) -> ir::Value* {
+ return impl.builder_.Constant(u32((access->Member()->Index())));
+ },
+ [&](const sem::Swizzle* swizzle) -> ir::Value* {
+ auto& indices = swizzle->Indices();
+
+ // A single element swizzle is just treated as an accessor.
+ if (indices.Length() == 1) {
+ return impl.builder_.Constant(u32(indices[0]));
+ }
+ auto* val = impl.builder_.Swizzle(ty, obj, std::move(indices));
+ impl.current_block_->Append(val);
+ Bind(expr, val->Result());
+ return nullptr;
+ },
+ [&](Default) {
+ TINT_ICE(Writer, impl.diagnostics_)
+ << "invalid accessor: " + std::string(sem->TypeInfo().name);
+ return nullptr;
+ });
+
+ if (!index) {
+ return;
+ }
+
+ // If the object is an unnamed value (a subexpression, not a let) and is the result
+ // of another access, then we can just append the index to that access.
+ if (!impl.mod.NameOf(obj).IsValid()) {
+ if (auto* inst_res = obj->As<InstructionResult>()) {
+ if (auto* access = inst_res->Source()->As<Access>()) {
+ access->AddIndex(index);
+ access->Result()->SetType(ty);
+ bindings_.Remove(expr->object);
+ // Move the access after the index expression.
+ if (impl.current_block_->Back() != access) {
+ impl.current_block_->Remove(access);
+ impl.current_block_->Append(access);
+ }
+ Bind(expr, access->Result());
+ return;
+ }
+ }
+ }
+
+ // Create a new access
+ auto* access = impl.builder_.Access(ty, obj, index);
+ impl.current_block_->Append(access);
+ Bind(expr, access->Result());
+ }
+
+ void EmitBinary(const ast::BinaryExpression* b) {
+ auto* b_sem = impl.program_->Sem().Get(b);
+ auto* ty = b_sem->Type()->Clone(impl.clone_ctx_.type_ctx);
+ auto lhs = GetValue(b->lhs);
+ if (!lhs) {
+ return;
+ }
+ auto rhs = GetValue(b->rhs);
+ if (!rhs) {
+ return;
+ }
+ Binary* inst = impl.BinaryOp(ty, lhs, rhs, b->op);
+ if (!inst) {
+ return;
+ }
+ impl.current_block_->Append(inst);
+ Bind(b, inst->Result());
+ }
+
+ void EmitUnary(const ast::UnaryOpExpression* expr) {
+ auto val = GetValue(expr->expr);
+ if (!val) {
+ return;
+ }
+ auto* sem = impl.program_->Sem().Get(expr);
+ auto* ty = sem->Type()->Clone(impl.clone_ctx_.type_ctx);
+ Instruction* inst = nullptr;
+ switch (expr->op) {
+ case ast::UnaryOp::kAddressOf:
+ case ast::UnaryOp::kIndirection:
+ // 'address-of' and 'indirection' just fold away and we propagate the
+ // pointer.
+ Bind(expr, val);
+ return;
+ case ast::UnaryOp::kComplement:
+ inst = impl.builder_.Complement(ty, val);
+ break;
+ case ast::UnaryOp::kNegation:
+ inst = impl.builder_.Negation(ty, val);
+ break;
+ case ast::UnaryOp::kNot:
+ inst = impl.builder_.Not(ty, val);
+ break;
+ }
+ impl.current_block_->Append(inst);
+ Bind(expr, inst->Result());
+ }
+
+ void EmitBitcast(const ast::BitcastExpression* b) {
+ auto val = GetValue(b->expr);
+ if (!val) {
+ return;
+ }
+ auto* sem = impl.program_->Sem().Get(b);
+ auto* ty = sem->Type()->Clone(impl.clone_ctx_.type_ctx);
+ auto* inst = impl.builder_.Bitcast(ty, val);
+ impl.current_block_->Append(inst);
+ Bind(b, inst->Result());
+ }
+
+ void EmitCall(const ast::CallExpression* expr) {
+ // If this is a materialized semantic node, just use the constant value.
+ if (auto* mat = impl.program_->Sem().Get(expr)) {
+ if (mat->ConstantValue()) {
+ auto* cv = mat->ConstantValue()->Clone(impl.clone_ctx_);
+ if (!cv) {
+ impl.add_error(expr->source, "failed to get constant value for call " +
+ std::string(expr->TypeInfo().name));
+ return;
+ }
+ Bind(expr, impl.builder_.Constant(cv));
+ return;
+ }
+ }
+ utils::Vector<Value*, 8> args;
+ args.Reserve(expr->args.Length());
+ // Emit the arguments
+ for (const auto* arg : expr->args) {
+ auto value = GetValue(arg);
+ if (!value) {
+ impl.add_error(arg->source, "failed to convert arguments");
+ return;
+ }
+ args.Push(value);
+ }
+ auto* sem = impl.program_->Sem().Get<sem::Call>(expr);
+ if (!sem) {
+ impl.add_error(expr->source, "failed to get semantic information for call " +
+ std::string(expr->TypeInfo().name));
+ return;
+ }
+ auto* ty = sem->Target()->ReturnType()->Clone(impl.clone_ctx_.type_ctx);
+ Instruction* inst = nullptr;
+ // If this is a builtin function, emit the specific builtin value
+ if (auto* b = sem->Target()->As<sem::Builtin>()) {
+ inst = impl.builder_.Call(ty, b->Type(), args);
+ } else if (sem->Target()->As<sem::ValueConstructor>()) {
+ inst = impl.builder_.Construct(ty, std::move(args));
+ } else if (sem->Target()->Is<sem::ValueConversion>()) {
+ inst = impl.builder_.Convert(ty, args[0]);
+ } else if (expr->target->identifier->Is<ast::TemplatedIdentifier>()) {
+ TINT_UNIMPLEMENTED(IR, impl.diagnostics_) << "missing templated ident support";
+ return;
+ } else {
+ // Not a builtin and not a templated call, so this is a user function.
+ inst = impl.builder_.Call(
+ ty, impl.scopes_.Get(expr->target->identifier->symbol)->As<ir::Function>(),
+ std::move(args));
+ }
+ if (inst == nullptr) {
+ return;
+ }
+ impl.current_block_->Append(inst);
+ Bind(expr, inst->Result());
+ }
+
+ void EmitIdentifier(const ast::IdentifierExpression* i) {
+ auto* v = impl.scopes_.Get(i->identifier->symbol);
+ if (TINT_UNLIKELY(!v)) {
+ impl.add_error(i->source,
+ "unable to find identifier " + i->identifier->symbol.Name());
+ return;
+ }
+ Bind(i, v);
+ }
+
+ void EmitLiteral(const ast::LiteralExpression* lit) {
+ auto* sem = impl.program_->Sem().Get(lit);
+ if (!sem) {
+ impl.add_error(lit->source, "failed to get semantic information for node " +
+ std::string(lit->TypeInfo().name));
+ return;
+ }
+ auto* cv = sem->ConstantValue()->Clone(impl.clone_ctx_);
+ if (!cv) {
+ impl.add_error(lit->source, "failed to get constant value for node " +
+ std::string(lit->TypeInfo().name));
+ return;
+ }
+ auto* val = impl.builder_.Constant(cv);
+ Bind(lit, val);
+ }
+
+ std::optional<VectorRefElementAccess> AsVectorRefElementAccess(
+ const ast::Expression* expr) {
+ return AsVectorRefElementAccess(
+ impl.program_->Sem().Get<sem::ValueExpression>(expr)->UnwrapLoad());
+ }
+
+ std::optional<VectorRefElementAccess> AsVectorRefElementAccess(
+ const sem::ValueExpression* expr) {
+ auto* access = As<sem::AccessorExpression>(expr);
+ if (!access) {
+ return std::nullopt;
+ }
+
+ auto* ref = access->Object()->Type()->As<type::Reference>();
+ if (!ref) {
+ return std::nullopt;
+ }
+
+ if (!ref->StoreType()->Is<type::Vector>()) {
+ return std::nullopt;
+ }
+ return tint::Switch(
+ access,
+ [&](const sem::Swizzle* s) -> std::optional<VectorRefElementAccess> {
+ if (auto vec = GetValue(access->Object()->Declaration())) {
+ return VectorRefElementAccess{
+ vec, impl.builder_.Constant(u32(s->Indices()[0]))};
+ }
+ return std::nullopt;
+ },
+ [&](const sem::IndexAccessorExpression* i)
+ -> std::optional<VectorRefElementAccess> {
+ if (auto vec = GetValue(access->Object()->Declaration())) {
+ if (auto idx = GetValue(i->Index()->Declaration())) {
+ return VectorRefElementAccess{vec, idx};
+ }
+ }
+ return std::nullopt;
+ });
+ }
+
+ void BeginShortCircuit(const ast::BinaryExpression* expr) {
+ auto lhs = GetValue(expr->lhs);
+ if (!lhs) {
+ return;
+ }
+
+ auto& b = impl.builder_;
+ auto* if_inst = b.If(lhs);
+ impl.current_block_->Append(if_inst);
+
+ auto* result = b.InstructionResult(b.ir.Types().bool_());
+ if_inst->SetResults(result);
+
+ if (expr->op == ast::BinaryOp::kLogicalAnd) {
+ if_inst->False()->Append(b.ExitIf(if_inst, b.Constant(false)));
+ PushBlock(if_inst->True());
+ } else {
+ if_inst->True()->Append(b.ExitIf(if_inst, b.Constant(true)));
+ PushBlock(if_inst->False());
+ }
+
+ Bind(expr, result);
+ }
+
+ void EndShortCircuit(const ast::BinaryExpression* b) {
+ auto res = GetValue(b);
+ auto* src = res->As<InstructionResult>()->Source();
+ auto* if_ = src->As<ir::If>();
+ TINT_ASSERT_OR_RETURN(IR, if_);
+ auto rhs = GetValue(b->rhs);
+ if (!rhs) {
+ return;
+ }
+ impl.current_block_->Append(impl.builder_.ExitIf(if_, rhs));
+ PopBlock();
+ }
+
+ void Process(const ast::Expression* expr) {
+ if (EmitConstant(expr)) {
+ // If this is a value that has been const-eval'd, then no need to traverse
+ // deeper.
+ return;
+ }
+
+ tint::Switch(
+ expr, //
+ [&](const ast::BinaryExpression* e) {
+ if (e->op == ast::BinaryOp::kLogicalAnd ||
+ e->op == ast::BinaryOp::kLogicalOr) {
+ tasks.Push([=] { EndShortCircuit(e); });
+ tasks.Push([=] { Process(e->rhs); });
+ tasks.Push([=] { BeginShortCircuit(e); });
+ tasks.Push([=] { Process(e->lhs); });
+ } else {
+ tasks.Push([=] { EmitBinary(e); });
+ tasks.Push([=] { Process(e->rhs); });
+ tasks.Push([=] { Process(e->lhs); });
+ }
+ },
+ [&](const ast::IndexAccessorExpression* e) {
+ tasks.Push([=] { EmitAccess(e); });
+ tasks.Push([=] { Process(e->index); });
+ tasks.Push([=] { Process(e->object); });
+ },
+ [&](const ast::MemberAccessorExpression* e) {
+ tasks.Push([=] { EmitAccess(e); });
+ tasks.Push([=] { Process(e->object); });
+ },
+ [&](const ast::UnaryOpExpression* e) {
+ tasks.Push([=] { EmitUnary(e); });
+ tasks.Push([=] { Process(e->expr); });
+ },
+ [&](const ast::CallExpression* e) {
+ tasks.Push([=] { EmitCall(e); });
+ for (auto* arg : utils::Reverse(e->args)) {
+ tasks.Push([=] { Process(arg); });
+ }
+ },
+ [&](const ast::BitcastExpression* e) {
+ tasks.Push([=] { EmitBitcast(e); });
+ tasks.Push([=] { Process(e->expr); });
+ },
+ [&](const ast::LiteralExpression* e) { EmitLiteral(e); },
+ [&](const ast::IdentifierExpression* e) { EmitIdentifier(e); },
+ [&](Default) {
+ impl.add_error(expr->source,
+ "Unhandled: " + std::string(expr->TypeInfo().name));
+ });
+ }
};
- if (!ast::TraverseExpressions(root, diagnostics_, [&](const ast::Expression* expr) {
- bool handled = false;
- // If this is a value that has been const-eval'd return the result.
- if (auto* expr_sem = program_->Sem().GetVal(expr)) {
- if (auto* v = expr_sem->ConstantValue()) {
- if (auto* cv = v->Clone(clone_ctx_)) {
- auto* val = builder_.Constant(cv);
- expr_to_result_.Add(expr, val);
- handled = true;
- }
- }
- }
-
- if (!handled) {
- tint::Switch(
- expr, //
- [&](const ast::BinaryExpression* b) {
- work_list.Push({b, b->lhs});
- },
- [&](const ast::MemberAccessorExpression* m) {
- work_list.Push({m, m->object});
- },
- [&](const ast::IndexAccessorExpression* i) {
- work_list.Push({i, i->index});
- },
- [&](const ast::UnaryOpExpression* u) {
- work_list.Push({u, u->expr});
- },
- [&](const ast::CallExpression* c) {
- if (c->args.IsEmpty()) {
- EmitCall(c);
- } else {
- work_list.Push({c, c->args.Back()});
- }
- },
- [&](const ast::BitcastExpression* b) {
- work_list.Push({b, b->expr});
- },
- [&](const ast::LiteralExpression* l) { EmitLiteral(l); },
- [&](const ast::IdentifierExpression* i) { EmitIdentifier(i); },
- [&](Default) {
- add_error(expr->source,
- "Unhandled: " + std::string(expr->TypeInfo().name));
- });
- }
-
- auto* cur = expr;
- while (true) {
- if (work_list.IsEmpty() || work_list.Back().emit_to != cur) {
- break;
- }
-
- if (process_work_list(cur)) {
- // The processed work instruction maybe the `emit_to` instruction for the
- // next work item.
- cur = work_list.Pop().inst;
- } else {
- break;
- }
- }
-
- // Don't descend into anything that is const-eval'd as it's already computed
- if (auto* expr_sem = program_->Sem().GetVal(expr)) {
- if (expr_sem->ConstantValue()) {
- return ast::TraverseAction::Skip;
- }
- }
-
- return ast::TraverseAction::Descend;
- })) {
- return;
- }
-
- while (!work_list.IsEmpty()) {
- process_work_list(work_list.Back().inst);
- work_list.Pop();
- }
+ return Emitter(*this).Emit(root);
}
- utils::Result<ir::Value*> EmitExpressionWithResult(const ast::Expression* root) {
- EmitExpression(root);
- return ResultForExpression(root);
+ Value* EmitValueExpression(const ast::Expression* root) {
+ auto res = EmitExpression(root);
+ if (auto** val = std::get_if<Value*>(&res)) {
+ return *val;
+ }
+ TINT_ICE(IR, diagnostics_) << "expression did not resolve to a value";
+ return nullptr;
}
+ void EmitCall(const ast::CallStatement* stmt) { (void)EmitValueExpression(stmt->expr); }
+
void EmitVariable(const ast::Variable* var) {
auto* sem = program_->Sem().Get(var);
@@ -1128,11 +1280,11 @@
auto* val = builder_.Var(ty);
if (v->initializer) {
- auto init = EmitExpressionWithResult(v->initializer);
+ auto init = EmitValueExpression(v->initializer);
if (!init) {
return;
}
- val->SetInitializer(init.Get());
+ val->SetInitializer(init);
}
current_block_->Append(val);
@@ -1149,12 +1301,12 @@
},
[&](const ast::Let* l) {
auto* last_stmt = current_block_->Back();
- auto init = EmitExpressionWithResult(l->initializer);
+ auto init = EmitValueExpression(l->initializer);
if (!init) {
return;
}
- auto* value = init.Get();
+ auto* value = init;
if (current_block_->Back() == last_stmt) {
// Emitting the let's initializer didn't create an instruction.
// Create an ir::Let to give the let an instruction. This gives the let a
@@ -1190,203 +1342,6 @@
});
}
- void EmitShortCircuitResult(const ast::BinaryExpression* b) {
- auto res = ResultForExpression(b);
- auto* src = res->As<InstructionResult>()->Source();
- if (auto* if_ = src->As<ir::If>()) {
- auto rhs = ResultForExpression(b->rhs);
- if (!rhs) {
- return;
- }
- current_block_->Append(builder_.ExitIf(if_, rhs));
- } else {
- TINT_ASSERT(IR, false);
- }
- }
-
- ir::If* EmitShortCircuit(const ast::BinaryExpression* b) {
- auto lhs = ResultForExpression(b->lhs);
- if (!lhs) {
- return nullptr;
- }
-
- auto* if_inst = builder_.If(lhs);
- current_block_->Append(if_inst);
-
- auto* result = builder_.InstructionResult(builder_.ir.Types().bool_());
- if_inst->SetResults(result);
-
- if (b->op == ast::BinaryOp::kLogicalAnd) {
- if_inst->False()->Append(builder_.ExitIf(if_inst, builder_.Constant(false)));
- current_block_ = if_inst->True();
- } else {
- if_inst->True()->Append(builder_.ExitIf(if_inst, builder_.Constant(true)));
- current_block_ = if_inst->False();
- }
-
- expr_to_result_.Add(b, result);
- return if_inst;
- }
-
- void EmitBinary(const ast::BinaryExpression* b) {
- auto* b_sem = program_->Sem().Get(b);
- auto* ty = b_sem->Type()->Clone(clone_ctx_.type_ctx);
-
- auto lhs = ResultForExpression(b->lhs);
- if (!lhs) {
- return;
- }
-
- auto rhs = ResultForExpression(b->rhs);
- if (!rhs) {
- return;
- }
-
- Binary* inst = BinaryOp(ty, lhs, rhs, b->op);
- if (!inst) {
- return;
- }
-
- current_block_->Append(inst);
- expr_to_result_.Add(b, inst->Result());
- }
-
- void EmitUnary(const ast::UnaryOpExpression* expr) {
- auto val = ResultForExpression(expr->expr);
- if (!val) {
- return;
- }
-
- auto* sem = program_->Sem().Get(expr);
- auto* ty = sem->Type()->Clone(clone_ctx_.type_ctx);
-
- Instruction* inst = nullptr;
- switch (expr->op) {
- case ast::UnaryOp::kAddressOf:
- case ast::UnaryOp::kIndirection:
- // 'address-of' and 'indirection' just fold away and we propagate the pointer.
- expr_to_result_.Add(expr, val);
- return;
- case ast::UnaryOp::kComplement:
- inst = builder_.Complement(ty, val);
- break;
- case ast::UnaryOp::kNegation:
- inst = builder_.Negation(ty, val);
- break;
- case ast::UnaryOp::kNot:
- inst = builder_.Not(ty, val);
- break;
- }
-
- current_block_->Append(inst);
- expr_to_result_.Add(expr, inst->Result());
- }
-
- void EmitBitcast(const ast::BitcastExpression* b) {
- auto val = ResultForExpression(b->expr);
- if (!val) {
- return;
- }
-
- auto* sem = program_->Sem().Get(b);
- auto* ty = sem->Type()->Clone(clone_ctx_.type_ctx);
- auto* inst = builder_.Bitcast(ty, val);
-
- current_block_->Append(inst);
- expr_to_result_.Add(b, inst->Result());
- }
-
- void EmitCall(const ast::CallStatement* stmt) { (void)EmitExpressionWithResult(stmt->expr); }
-
- void EmitCall(const ast::CallExpression* expr) {
- // If this is a materialized semantic node, just use the constant value.
- if (auto* mat = program_->Sem().Get(expr)) {
- if (mat->ConstantValue()) {
- auto* cv = mat->ConstantValue()->Clone(clone_ctx_);
- if (!cv) {
- add_error(expr->source, "failed to get constant value for call " +
- std::string(expr->TypeInfo().name));
- return;
- }
- expr_to_result_.Add(expr, builder_.Constant(cv));
- return;
- }
- }
-
- utils::Vector<Value*, 8> args;
- args.Reserve(expr->args.Length());
-
- // Emit the arguments
- for (const auto* arg : expr->args) {
- auto value = ResultForExpression(arg);
- if (!value) {
- add_error(arg->source, "failed to convert arguments");
- return;
- }
- args.Push(value);
- }
-
- auto* sem = program_->Sem().Get<sem::Call>(expr);
- if (!sem) {
- add_error(expr->source, "failed to get semantic information for call " +
- std::string(expr->TypeInfo().name));
- return;
- }
-
- auto* ty = sem->Target()->ReturnType()->Clone(clone_ctx_.type_ctx);
-
- Instruction* inst = nullptr;
-
- // If this is a builtin function, emit the specific builtin value
- if (auto* b = sem->Target()->As<sem::Builtin>()) {
- inst = builder_.Call(ty, b->Type(), args);
- } else if (sem->Target()->As<sem::ValueConstructor>()) {
- inst = builder_.Construct(ty, std::move(args));
- } else if (sem->Target()->Is<sem::ValueConversion>()) {
- inst = builder_.Convert(ty, args[0]);
- } else if (expr->target->identifier->Is<ast::TemplatedIdentifier>()) {
- TINT_UNIMPLEMENTED(IR, diagnostics_) << "missing templated ident support";
- return;
- } else {
- // Not a builtin and not a templated call, so this is a user function.
- inst =
- builder_.Call(ty, scopes_.Get(expr->target->identifier->symbol)->As<ir::Function>(),
- std::move(args));
- }
- if (inst == nullptr) {
- return;
- }
- current_block_->Append(inst);
- expr_to_result_.Add(expr, inst->Result());
- }
-
- void EmitIdentifier(const ast::IdentifierExpression* i) {
- auto* v = scopes_.Get(i->identifier->symbol);
- if (TINT_UNLIKELY(!v)) {
- add_error(i->source, "unable to find identifier " + i->identifier->symbol.Name());
- return;
- }
- expr_to_result_.Add(i, v);
- }
-
- void EmitLiteral(const ast::LiteralExpression* lit) {
- auto* sem = program_->Sem().Get(lit);
- if (!sem) {
- add_error(lit->source, "failed to get semantic information for node " +
- std::string(lit->TypeInfo().name));
- return;
- }
-
- auto* cv = sem->ConstantValue()->Clone(clone_ctx_);
- if (!cv) {
- add_error(lit->source,
- "failed to get constant value for node " + std::string(lit->TypeInfo().name));
- return;
- }
- auto* val = builder_.Constant(cv);
- expr_to_result_.Add(lit, val);
- }
-
ir::Binary* BinaryOp(const type::Type* ty, ir::Value* lhs, ir::Value* rhs, ast::BinaryOp op) {
switch (op) {
case ast::BinaryOp::kAnd:
@@ -1432,44 +1387,6 @@
TINT_UNREACHABLE(IR, diagnostics_);
return nullptr;
}
-
- std::optional<VectorRefElementAccess> AsVectorRefElementAccess(const ast::Expression* expr) {
- return AsVectorRefElementAccess(
- program_->Sem().Get<sem::ValueExpression>(expr)->UnwrapLoad());
- }
-
- std::optional<VectorRefElementAccess> AsVectorRefElementAccess(
- const sem::ValueExpression* expr) {
- auto* access = As<sem::AccessorExpression>(expr);
- if (!access) {
- return std::nullopt;
- }
-
- auto* ref = access->Object()->Type()->As<type::Reference>();
- if (!ref) {
- return std::nullopt;
- }
-
- if (!ref->StoreType()->Is<type::Vector>()) {
- return std::nullopt;
- }
- return tint::Switch(
- access,
- [&](const sem::Swizzle* s) -> std::optional<VectorRefElementAccess> {
- if (auto vec = ResultForExpression(access->Object()->Declaration())) {
- return VectorRefElementAccess{vec, builder_.Constant(u32(s->Indices()[0]))};
- }
- return std::nullopt;
- },
- [&](const sem::IndexAccessorExpression* i) -> std::optional<VectorRefElementAccess> {
- if (auto vec = ResultForExpression(access->Object()->Declaration())) {
- if (auto idx = ResultForExpression(i->Index()->Declaration())) {
- return VectorRefElementAccess{vec, idx};
- }
- }
- return std::nullopt;
- });
- }
};
} // namespace