| // Copyright 2023 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/tint/lang/core/ir/from_program.h" |
| |
| #include <iostream> |
| #include <unordered_map> |
| #include <utility> |
| #include <variant> |
| #include <vector> |
| |
| #include "src/tint/lang/core/ir/block_param.h" |
| #include "src/tint/lang/core/ir/builder.h" |
| #include "src/tint/lang/core/ir/exit_if.h" |
| #include "src/tint/lang/core/ir/exit_loop.h" |
| #include "src/tint/lang/core/ir/exit_switch.h" |
| #include "src/tint/lang/core/ir/function.h" |
| #include "src/tint/lang/core/ir/if.h" |
| #include "src/tint/lang/core/ir/loop.h" |
| #include "src/tint/lang/core/ir/module.h" |
| #include "src/tint/lang/core/ir/store.h" |
| #include "src/tint/lang/core/ir/switch.h" |
| #include "src/tint/lang/core/ir/value.h" |
| #include "src/tint/lang/core/type/pointer.h" |
| #include "src/tint/lang/core/type/reference.h" |
| #include "src/tint/lang/core/type/struct.h" |
| #include "src/tint/lang/core/type/void.h" |
| #include "src/tint/lang/wgsl/ast/accessor_expression.h" |
| #include "src/tint/lang/wgsl/ast/alias.h" |
| #include "src/tint/lang/wgsl/ast/assignment_statement.h" |
| #include "src/tint/lang/wgsl/ast/binary_expression.h" |
| #include "src/tint/lang/wgsl/ast/bitcast_expression.h" |
| #include "src/tint/lang/wgsl/ast/block_statement.h" |
| #include "src/tint/lang/wgsl/ast/bool_literal_expression.h" |
| #include "src/tint/lang/wgsl/ast/break_if_statement.h" |
| #include "src/tint/lang/wgsl/ast/break_statement.h" |
| #include "src/tint/lang/wgsl/ast/call_expression.h" |
| #include "src/tint/lang/wgsl/ast/call_statement.h" |
| #include "src/tint/lang/wgsl/ast/compound_assignment_statement.h" |
| #include "src/tint/lang/wgsl/ast/const.h" |
| #include "src/tint/lang/wgsl/ast/const_assert.h" |
| #include "src/tint/lang/wgsl/ast/continue_statement.h" |
| #include "src/tint/lang/wgsl/ast/diagnostic_directive.h" |
| #include "src/tint/lang/wgsl/ast/discard_statement.h" |
| #include "src/tint/lang/wgsl/ast/enable.h" |
| #include "src/tint/lang/wgsl/ast/float_literal_expression.h" |
| #include "src/tint/lang/wgsl/ast/for_loop_statement.h" |
| #include "src/tint/lang/wgsl/ast/function.h" |
| #include "src/tint/lang/wgsl/ast/id_attribute.h" |
| #include "src/tint/lang/wgsl/ast/identifier.h" |
| #include "src/tint/lang/wgsl/ast/identifier_expression.h" |
| #include "src/tint/lang/wgsl/ast/if_statement.h" |
| #include "src/tint/lang/wgsl/ast/increment_decrement_statement.h" |
| #include "src/tint/lang/wgsl/ast/index_accessor_expression.h" |
| #include "src/tint/lang/wgsl/ast/int_literal_expression.h" |
| #include "src/tint/lang/wgsl/ast/interpolate_attribute.h" |
| #include "src/tint/lang/wgsl/ast/invariant_attribute.h" |
| #include "src/tint/lang/wgsl/ast/let.h" |
| #include "src/tint/lang/wgsl/ast/literal_expression.h" |
| #include "src/tint/lang/wgsl/ast/loop_statement.h" |
| #include "src/tint/lang/wgsl/ast/member_accessor_expression.h" |
| #include "src/tint/lang/wgsl/ast/override.h" |
| #include "src/tint/lang/wgsl/ast/phony_expression.h" |
| #include "src/tint/lang/wgsl/ast/return_statement.h" |
| #include "src/tint/lang/wgsl/ast/statement.h" |
| #include "src/tint/lang/wgsl/ast/struct.h" |
| #include "src/tint/lang/wgsl/ast/struct_member_align_attribute.h" |
| #include "src/tint/lang/wgsl/ast/struct_member_size_attribute.h" |
| #include "src/tint/lang/wgsl/ast/switch_statement.h" |
| #include "src/tint/lang/wgsl/ast/templated_identifier.h" |
| #include "src/tint/lang/wgsl/ast/unary_op_expression.h" |
| #include "src/tint/lang/wgsl/ast/var.h" |
| #include "src/tint/lang/wgsl/ast/variable_decl_statement.h" |
| #include "src/tint/lang/wgsl/ast/while_statement.h" |
| #include "src/tint/lang/wgsl/program/program.h" |
| #include "src/tint/lang/wgsl/sem/builtin.h" |
| #include "src/tint/lang/wgsl/sem/call.h" |
| #include "src/tint/lang/wgsl/sem/function.h" |
| #include "src/tint/lang/wgsl/sem/index_accessor_expression.h" |
| #include "src/tint/lang/wgsl/sem/load.h" |
| #include "src/tint/lang/wgsl/sem/materialize.h" |
| #include "src/tint/lang/wgsl/sem/member_accessor_expression.h" |
| #include "src/tint/lang/wgsl/sem/module.h" |
| #include "src/tint/lang/wgsl/sem/switch_statement.h" |
| #include "src/tint/lang/wgsl/sem/value_constructor.h" |
| #include "src/tint/lang/wgsl/sem/value_conversion.h" |
| #include "src/tint/lang/wgsl/sem/value_expression.h" |
| #include "src/tint/lang/wgsl/sem/variable.h" |
| #include "src/tint/utils/containers/reverse.h" |
| #include "src/tint/utils/containers/scope_stack.h" |
| #include "src/tint/utils/macros/defer.h" |
| #include "src/tint/utils/macros/scoped_assignment.h" |
| #include "src/tint/utils/result/result.h" |
| #include "src/tint/utils/rtti/switch.h" |
| |
| using namespace tint::number_suffixes; // NOLINT |
| |
| namespace tint::ir { |
| |
| namespace { |
| |
| using ResultType = utils::Result<Module, diag::List>; |
| |
| /// Impl is the private-implementation of FromProgram(). |
| class Impl { |
| public: |
| /// Constructor |
| /// @param program the program to convert to IR |
| explicit Impl(const Program* program) : program_(program) {} |
| |
| /// Builds an IR module from the program passed to the constructor. |
| /// @return the IR module or an error. |
| ResultType Build() { return EmitModule(); } |
| |
| private: |
/// Options for FindEnclosingControl().
enum class ControlFlags {
    /// Both loops and switches are candidates.
    kNone,
    /// Only loops are candidates; switches on the control stack are skipped.
    kExcludeSwitch,
};
| |
| // The input Program |
| const Program* program_ = nullptr; |
| |
| /// The IR module being built |
| Module mod; |
| |
| /// The IR builder being used by the impl. |
| Builder builder_{mod}; |
| |
| // The clone context used to clone data from #program_ |
| constant::CloneContext clone_ctx_{ |
| /* type_ctx */ type::CloneContext{ |
| /* src */ {&program_->Symbols()}, |
| /* dst */ {&builder_.ir.symbols, &builder_.ir.Types()}, |
| }, |
| /* dst */ {builder_.ir.constant_values}, |
| }; |
| |
| /// The stack of flow control instructions. |
| utils::Vector<ControlInstruction*, 8> control_stack_; |
| |
| struct VectorRefElementAccess { |
| ir::Value* vector = nullptr; |
| ir::Value* index = nullptr; |
| }; |
| |
| using ValueOrVecElAccess = std::variant<ir::Value*, VectorRefElementAccess>; |
| |
| /// The current block for expressions. |
| Block* current_block_ = nullptr; |
| |
| /// The current function being processed. |
| Function* current_function_ = nullptr; |
| |
| /// The current stack of scopes being processed. |
| ScopeStack<Symbol, Value*> scopes_; |
| |
| /// The diagnostics that have been raised. |
| diag::List diagnostics_; |
| |
| class StackScope { |
| public: |
| explicit StackScope(Impl* impl) : impl_(impl) { impl->scopes_.Push(); } |
| |
| ~StackScope() { impl_->scopes_.Pop(); } |
| |
| protected: |
| Impl* impl_; |
| }; |
| |
| class ControlStackScope : public StackScope { |
| public: |
| ControlStackScope(Impl* impl, ControlInstruction* b) : StackScope(impl) { |
| impl_->control_stack_.Push(b); |
| } |
| |
| ~ControlStackScope() { impl_->control_stack_.Pop(); } |
| }; |
| |
| void add_error(const Source& s, const std::string& err) { |
| diagnostics_.add_error(tint::diag::System::IR, err, s); |
| } |
| |
| bool NeedTerminator() { return current_block_ && !current_block_->HasTerminator(); } |
| |
| void SetTerminator(Terminator* terminator) { |
| TINT_ASSERT(IR, current_block_); |
| TINT_ASSERT(IR, !current_block_->HasTerminator()); |
| |
| current_block_->Append(terminator); |
| current_block_ = nullptr; |
| } |
| |
| Instruction* FindEnclosingControl(ControlFlags flags) { |
| for (auto it = control_stack_.rbegin(); it != control_stack_.rend(); ++it) { |
| if ((*it)->Is<Loop>()) { |
| return *it; |
| } |
| if (flags == ControlFlags::kExcludeSwitch) { |
| continue; |
| } |
| if ((*it)->Is<Switch>()) { |
| return *it; |
| } |
| } |
| return nullptr; |
| } |
| |
| ResultType EmitModule() { |
| auto* sem = program_->Sem().Module(); |
| |
| for (auto* decl : sem->DependencyOrderedDeclarations()) { |
| tint::Switch( |
| decl, // |
| [&](const ast::Struct*) { |
| // Will be encoded into the `type::Struct` when used. We will then hoist all |
| // used structs up to module scope when converting IR. |
| }, |
| [&](const ast::Alias*) { |
| // Folded away and doesn't appear in the IR. |
| }, |
| [&](const ast::Variable* var) { |
| // Setup the current block to be the root block for the module. The builder |
| // will handle creating it if it doesn't exist already. |
| TINT_SCOPED_ASSIGNMENT(current_block_, builder_.RootBlock()); |
| EmitVariable(var); |
| }, |
| [&](const ast::Function* func) { EmitFunction(func); }, |
| [&](const ast::Enable*) { |
| // TODO(dsinclair): Implement? I think these need to be passed along so further |
| // stages know what is enabled. |
| }, |
| [&](const ast::ConstAssert*) { |
| // Evaluated by the resolver, drop from the IR. |
| }, |
| [&](const ast::DiagnosticDirective*) { |
| // Ignored for now. |
| }, |
| [&](Default) { |
| add_error(decl->source, "unknown type: " + std::string(decl->TypeInfo().name)); |
| }); |
| } |
| |
| if (diagnostics_.contains_errors()) { |
| return ResultType(std::move(diagnostics_)); |
| } |
| |
| return ResultType{std::move(mod)}; |
| } |
| |
| builtin::Interpolation ExtractInterpolation(const ast::InterpolateAttribute* interp) { |
| auto type = program_->Sem() |
| .Get(interp->type) |
| ->As<sem::BuiltinEnumExpression<builtin::InterpolationType>>(); |
| builtin::InterpolationType interpolation_type = type->Value(); |
| |
| builtin::InterpolationSampling interpolation_sampling = |
| builtin::InterpolationSampling::kUndefined; |
| if (interp->sampling) { |
| auto sampling = program_->Sem() |
| .Get(interp->sampling) |
| ->As<sem::BuiltinEnumExpression<builtin::InterpolationSampling>>(); |
| interpolation_sampling = sampling->Value(); |
| } |
| |
| return builtin::Interpolation{interpolation_type, interpolation_sampling}; |
| } |
| |
| void EmitFunction(const ast::Function* ast_func) { |
| // The flow stack should have been emptied when the previous function finished building. |
| TINT_ASSERT(IR, control_stack_.IsEmpty()); |
| |
| const auto* sem = program_->Sem().Get(ast_func); |
| |
| auto* ir_func = builder_.Function(ast_func->name->symbol.NameView(), |
| sem->ReturnType()->Clone(clone_ctx_.type_ctx)); |
| current_function_ = ir_func; |
| |
| scopes_.Set(ast_func->name->symbol, ir_func); |
| |
| if (ast_func->IsEntryPoint()) { |
| switch (ast_func->PipelineStage()) { |
| case ast::PipelineStage::kVertex: |
| ir_func->SetStage(Function::PipelineStage::kVertex); |
| break; |
| case ast::PipelineStage::kFragment: |
| ir_func->SetStage(Function::PipelineStage::kFragment); |
| break; |
| case ast::PipelineStage::kCompute: { |
| ir_func->SetStage(Function::PipelineStage::kCompute); |
| |
| auto wg_size = sem->WorkgroupSize(); |
| ir_func->SetWorkgroupSize(wg_size[0].value(), wg_size[1].value_or(1), |
| wg_size[2].value_or(1)); |
| break; |
| } |
| default: { |
| TINT_ICE(IR, diagnostics_) << "Invalid pipeline stage"; |
| return; |
| } |
| } |
| |
| // Note, interpolated is only valid when paired with Location, so it will only be set |
| // when the location is set. |
| std::optional<builtin::Interpolation> interpolation; |
| for (auto* attr : ast_func->return_type_attributes) { |
| tint::Switch( |
| attr, // |
| [&](const ast::InterpolateAttribute* interp) { |
| interpolation = ExtractInterpolation(interp); |
| }, |
| [&](const ast::InvariantAttribute*) { ir_func->SetReturnInvariant(true); }, |
| [&](const ast::BuiltinAttribute* b) { |
| if (auto* ident_sem = |
| program_->Sem() |
| .Get(b) |
| ->As<sem::BuiltinEnumExpression<builtin::BuiltinValue>>()) { |
| switch (ident_sem->Value()) { |
| case builtin::BuiltinValue::kPosition: |
| ir_func->SetReturnBuiltin(Function::ReturnBuiltin::kPosition); |
| break; |
| case builtin::BuiltinValue::kFragDepth: |
| ir_func->SetReturnBuiltin(Function::ReturnBuiltin::kFragDepth); |
| break; |
| case builtin::BuiltinValue::kSampleMask: |
| ir_func->SetReturnBuiltin(Function::ReturnBuiltin::kSampleMask); |
| break; |
| default: |
| TINT_ICE(IR, diagnostics_) |
| << "Unknown builtin value in return attributes " |
| << ident_sem->Value(); |
| return; |
| } |
| } else { |
| TINT_ICE(IR, diagnostics_) << "Builtin attribute sem invalid"; |
| return; |
| } |
| }); |
| } |
| if (sem->ReturnLocation().has_value()) { |
| ir_func->SetReturnLocation(sem->ReturnLocation().value(), interpolation); |
| } |
| } |
| |
| scopes_.Push(); |
| TINT_DEFER(scopes_.Pop()); |
| |
| utils::Vector<FunctionParam*, 1> params; |
| for (auto* p : ast_func->params) { |
| const auto* param_sem = program_->Sem().Get(p)->As<sem::Parameter>(); |
| auto* ty = param_sem->Type()->Clone(clone_ctx_.type_ctx); |
| auto* param = builder_.FunctionParam(p->name->symbol.NameView(), ty); |
| |
| // Note, interpolated is only valid when paired with Location, so it will only be set |
| // when the location is set. |
| std::optional<builtin::Interpolation> interpolation; |
| for (auto* attr : p->attributes) { |
| tint::Switch( |
| attr, // |
| [&](const ast::InterpolateAttribute* interp) { |
| interpolation = ExtractInterpolation(interp); |
| }, |
| [&](const ast::InvariantAttribute*) { param->SetInvariant(true); }, |
| [&](const ast::BuiltinAttribute* b) { |
| if (auto* ident_sem = |
| program_->Sem() |
| .Get(b) |
| ->As<sem::BuiltinEnumExpression<builtin::BuiltinValue>>()) { |
| switch (ident_sem->Value()) { |
| case builtin::BuiltinValue::kVertexIndex: |
| param->SetBuiltin(FunctionParam::Builtin::kVertexIndex); |
| break; |
| case builtin::BuiltinValue::kInstanceIndex: |
| param->SetBuiltin(FunctionParam::Builtin::kInstanceIndex); |
| break; |
| case builtin::BuiltinValue::kPosition: |
| param->SetBuiltin(FunctionParam::Builtin::kPosition); |
| break; |
| case builtin::BuiltinValue::kFrontFacing: |
| param->SetBuiltin(FunctionParam::Builtin::kFrontFacing); |
| break; |
| case builtin::BuiltinValue::kLocalInvocationId: |
| param->SetBuiltin(FunctionParam::Builtin::kLocalInvocationId); |
| break; |
| case builtin::BuiltinValue::kLocalInvocationIndex: |
| param->SetBuiltin( |
| FunctionParam::Builtin::kLocalInvocationIndex); |
| break; |
| case builtin::BuiltinValue::kGlobalInvocationId: |
| param->SetBuiltin(FunctionParam::Builtin::kGlobalInvocationId); |
| break; |
| case builtin::BuiltinValue::kWorkgroupId: |
| param->SetBuiltin(FunctionParam::Builtin::kWorkgroupId); |
| break; |
| case builtin::BuiltinValue::kNumWorkgroups: |
| param->SetBuiltin(FunctionParam::Builtin::kNumWorkgroups); |
| break; |
| case builtin::BuiltinValue::kSampleIndex: |
| param->SetBuiltin(FunctionParam::Builtin::kSampleIndex); |
| break; |
| case builtin::BuiltinValue::kSampleMask: |
| param->SetBuiltin(FunctionParam::Builtin::kSampleMask); |
| break; |
| default: |
| TINT_ICE(IR, diagnostics_) |
| << "Unknown builtin value in parameter attributes " |
| << ident_sem->Value(); |
| return; |
| } |
| } else { |
| TINT_ICE(IR, diagnostics_) << "Builtin attribute sem invalid"; |
| return; |
| } |
| }); |
| |
| if (param_sem->Location().has_value()) { |
| param->SetLocation(param_sem->Location().value(), interpolation); |
| } |
| if (param_sem->BindingPoint().has_value()) { |
| param->SetBindingPoint(param_sem->BindingPoint()->group, |
| param_sem->BindingPoint()->binding); |
| } |
| } |
| |
| scopes_.Set(p->name->symbol, param); |
| params.Push(param); |
| } |
| ir_func->SetParams(params); |
| |
| TINT_SCOPED_ASSIGNMENT(current_block_, ir_func->Block()); |
| EmitBlock(ast_func->body); |
| |
| // Add a terminator if one was not already created. |
| if (NeedTerminator()) { |
| if (!program_->Sem().Get(ast_func->body)->Behaviors().Contains(sem::Behavior::kNext)) { |
| SetTerminator(builder_.Unreachable()); |
| } else { |
| SetTerminator(builder_.Return(current_function_)); |
| } |
| } |
| |
| TINT_ASSERT(IR, control_stack_.IsEmpty()); |
| current_block_ = nullptr; |
| current_function_ = nullptr; |
| } |
| |
| void EmitStatements(utils::VectorRef<const ast::Statement*> stmts) { |
| for (auto* s : stmts) { |
| EmitStatement(s); |
| |
| if (auto* sem = program_->Sem().Get(s); |
| sem && !sem->Behaviors().Contains(sem::Behavior::kNext)) { |
| break; // Unreachable statement. |
| } |
| } |
| } |
| |
| void EmitStatement(const ast::Statement* stmt) { |
| tint::Switch( |
| stmt, // |
| [&](const ast::AssignmentStatement* a) { EmitAssignment(a); }, |
| [&](const ast::BlockStatement* b) { EmitBlock(b); }, |
| [&](const ast::BreakStatement* b) { EmitBreak(b); }, |
| [&](const ast::BreakIfStatement* b) { EmitBreakIf(b); }, |
| [&](const ast::CallStatement* c) { EmitCall(c); }, |
| [&](const ast::CompoundAssignmentStatement* c) { EmitCompoundAssignment(c); }, |
| [&](const ast::ContinueStatement* c) { EmitContinue(c); }, |
| [&](const ast::DiscardStatement* d) { EmitDiscard(d); }, |
| [&](const ast::IfStatement* i) { EmitIf(i); }, |
| [&](const ast::LoopStatement* l) { EmitLoop(l); }, |
| [&](const ast::ForLoopStatement* l) { EmitForLoop(l); }, |
| [&](const ast::WhileStatement* l) { EmitWhile(l); }, |
| [&](const ast::ReturnStatement* r) { EmitReturn(r); }, |
| [&](const ast::SwitchStatement* s) { EmitSwitch(s); }, |
| [&](const ast::VariableDeclStatement* v) { EmitVariable(v->variable); }, |
| [&](const ast::IncrementDecrementStatement* i) { EmitIncrementDecrement(i); }, |
| [&](const ast::ConstAssert*) { |
| // Not emitted |
| }, |
| [&](Default) { |
| add_error(stmt->source, |
| "unknown statement type: " + std::string(stmt->TypeInfo().name)); |
| }); |
| } |
| |
| void EmitAssignment(const ast::AssignmentStatement* stmt) { |
| // If assigning to a phony, just generate the RHS and we're done. Note that, because |
| // this isn't used, a subsequent transform could remove it due to it being dead code. |
| // This could then change the interface for the program (i.e. a global var no longer |
| // used). If that happens we have to either fix this to store to a phony value, or make |
| // sure we pull the interface before doing the dead code elimination. |
| if (stmt->lhs->Is<ast::PhonyExpression>()) { |
| (void)EmitValueExpression(stmt->rhs); |
| return; |
| } |
| |
| auto lhs = EmitExpression(stmt->lhs); |
| |
| auto rhs = EmitValueExpression(stmt->rhs); |
| if (!rhs) { |
| return; |
| } |
| |
| auto b = builder_.With(current_block_); |
| if (auto* v = std::get_if<ir::Value*>(&lhs)) { |
| b.Store(*v, rhs); |
| } else if (auto ref = std::get_if<VectorRefElementAccess>(&lhs)) { |
| b.StoreVectorElement(ref->vector, ref->index, rhs); |
| } |
| } |
| |
| void EmitIncrementDecrement(const ast::IncrementDecrementStatement* stmt) { |
| auto lhs = EmitExpression(stmt->lhs); |
| |
| auto* one = program_->TypeOf(stmt->lhs)->UnwrapRef()->is_signed_integer_scalar() |
| ? builder_.Constant(1_i) |
| : builder_.Constant(1_u); |
| |
| EmitCompoundAssignment(lhs, one, |
| stmt->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract); |
| } |
| |
| void EmitCompoundAssignment(const ast::CompoundAssignmentStatement* stmt) { |
| auto lhs = EmitExpression(stmt->lhs); |
| |
| auto rhs = EmitValueExpression(stmt->rhs); |
| if (!rhs) { |
| return; |
| } |
| |
| EmitCompoundAssignment(lhs, rhs, stmt->op); |
| } |
| |
| void EmitCompoundAssignment(ValueOrVecElAccess lhs, ir::Value* rhs, ast::BinaryOp op) { |
| auto b = builder_.With(current_block_); |
| if (auto* v = std::get_if<ir::Value*>(&lhs)) { |
| auto* load = b.Load(*v); |
| auto* ty = load->Result()->Type(); |
| auto* inst = current_block_->Append(BinaryOp(ty, load->Result(), rhs, op)); |
| b.Store(*v, inst); |
| } else if (auto ref = std::get_if<VectorRefElementAccess>(&lhs)) { |
| auto* load = b.LoadVectorElement(ref->vector, ref->index); |
| auto* ty = load->Result()->Type(); |
| auto* inst = b.Append(BinaryOp(ty, load->Result(), rhs, op)); |
| b.StoreVectorElement(ref->vector, ref->index, inst); |
| } |
| } |
| |
| void EmitBlock(const ast::BlockStatement* block) { |
| scopes_.Push(); |
| TINT_DEFER(scopes_.Pop()); |
| |
| // Note, this doesn't need to emit a Block as the current block should be sufficient as the |
| // blocks all get flattened. Each flow control node will inject the basic blocks it |
| // requires. |
| EmitStatements(block->statements); |
| } |
| |
| void EmitIf(const ast::IfStatement* stmt) { |
| // Emit the if condition into the end of the preceding block |
| auto reg = EmitValueExpression(stmt->condition); |
| if (!reg) { |
| return; |
| } |
| auto* if_inst = builder_.If(reg); |
| current_block_->Append(if_inst); |
| |
| { |
| ControlStackScope scope(this, if_inst); |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, if_inst->True()); |
| EmitBlock(stmt->body); |
| |
| // If the true block did not terminate, then emit an exit_if |
| if (NeedTerminator()) { |
| SetTerminator(builder_.ExitIf(if_inst)); |
| } |
| } |
| |
| if (stmt->else_statement) { |
| TINT_SCOPED_ASSIGNMENT(current_block_, if_inst->False()); |
| EmitStatement(stmt->else_statement); |
| |
| // If the false block did not terminate, then emit an exit_if |
| if (NeedTerminator()) { |
| SetTerminator(builder_.ExitIf(if_inst)); |
| } |
| } |
| } |
| } |
| |
| void EmitLoop(const ast::LoopStatement* stmt) { |
| auto* loop_inst = builder_.Loop(); |
| current_block_->Append(loop_inst); |
| |
| // Note: The loop doesn't use EmitBlock because it needs the scope stack to not get popped |
| // until after the continuing block. |
| |
| ControlStackScope scope(this, loop_inst); |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Body()); |
| |
| EmitStatements(stmt->body->statements); |
| |
| // The current block didn't `break`, `return` or `continue`, go to the continuing block. |
| if (NeedTerminator()) { |
| SetTerminator(builder_.Continue(loop_inst)); |
| } |
| } |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Continuing()); |
| if (stmt->continuing) { |
| EmitBlock(stmt->continuing); |
| } |
| // Branch back to the start block if the continue target didn't terminate already |
| if (NeedTerminator()) { |
| SetTerminator(builder_.NextIteration(loop_inst)); |
| } |
| } |
| } |
| |
| void EmitWhile(const ast::WhileStatement* stmt) { |
| auto* loop_inst = builder_.Loop(); |
| current_block_->Append(loop_inst); |
| |
| ControlStackScope scope(this, loop_inst); |
| |
| // Continue is always empty, just go back to the start |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Continuing()); |
| SetTerminator(builder_.NextIteration(loop_inst)); |
| } |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Body()); |
| |
| // Emit the while condition into the Start().target of the loop |
| auto reg = EmitValueExpression(stmt->condition); |
| if (!reg) { |
| return; |
| } |
| |
| // Create an `if (cond) {} else {break;}` control flow |
| auto* if_inst = builder_.If(reg); |
| current_block_->Append(if_inst); |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, if_inst->True()); |
| SetTerminator(builder_.ExitIf(if_inst)); |
| } |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, if_inst->False()); |
| SetTerminator(builder_.ExitLoop(loop_inst)); |
| } |
| |
| EmitStatements(stmt->body->statements); |
| |
| // The current block didn't `break`, `return` or `continue`, go to the continuing block. |
| if (NeedTerminator()) { |
| SetTerminator(builder_.Continue(loop_inst)); |
| } |
| } |
| } |
| |
| void EmitForLoop(const ast::ForLoopStatement* stmt) { |
| auto* loop_inst = builder_.Loop(); |
| current_block_->Append(loop_inst); |
| |
| ControlStackScope scope(this, loop_inst); |
| |
| if (stmt->initializer) { |
| TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Initializer()); |
| |
| // Emit the for initializer before branching to the loop body |
| EmitStatement(stmt->initializer); |
| |
| if (NeedTerminator()) { |
| SetTerminator(builder_.NextIteration(loop_inst)); |
| } |
| } |
| |
| TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Body()); |
| |
| if (stmt->condition) { |
| // Emit the condition into the target target of the loop body |
| auto reg = EmitValueExpression(stmt->condition); |
| if (!reg) { |
| return; |
| } |
| |
| // Create an `if (cond) {} else {break;}` control flow |
| auto* if_inst = builder_.If(reg); |
| current_block_->Append(if_inst); |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, if_inst->True()); |
| SetTerminator(builder_.ExitIf(if_inst)); |
| } |
| |
| { |
| TINT_SCOPED_ASSIGNMENT(current_block_, if_inst->False()); |
| SetTerminator(builder_.ExitLoop(loop_inst)); |
| } |
| } |
| |
| EmitBlock(stmt->body); |
| if (NeedTerminator()) { |
| SetTerminator(builder_.Continue(loop_inst)); |
| } |
| |
| if (stmt->continuing) { |
| TINT_SCOPED_ASSIGNMENT(current_block_, loop_inst->Continuing()); |
| EmitStatement(stmt->continuing); |
| SetTerminator(builder_.NextIteration(loop_inst)); |
| } |
| } |
| |
| void EmitSwitch(const ast::SwitchStatement* stmt) { |
| // Emit the condition into the preceding block |
| auto reg = EmitValueExpression(stmt->condition); |
| if (!reg) { |
| return; |
| } |
| auto* switch_inst = builder_.Switch(reg); |
| current_block_->Append(switch_inst); |
| |
| ControlStackScope scope(this, switch_inst); |
| |
| const auto* sem = program_->Sem().Get(stmt); |
| for (const auto* c : sem->Cases()) { |
| utils::Vector<Switch::CaseSelector, 4> selectors; |
| for (const auto* selector : c->Selectors()) { |
| if (selector->IsDefault()) { |
| selectors.Push({nullptr}); |
| } else { |
| selectors.Push({builder_.Constant(selector->Value()->Clone(clone_ctx_))}); |
| } |
| } |
| |
| TINT_SCOPED_ASSIGNMENT(current_block_, builder_.Case(switch_inst, selectors)); |
| EmitBlock(c->Body()->Declaration()); |
| |
| if (NeedTerminator()) { |
| SetTerminator(builder_.ExitSwitch(switch_inst)); |
| } |
| } |
| } |
| |
| void EmitReturn(const ast::ReturnStatement* stmt) { |
| Value* ret_value = nullptr; |
| if (stmt->value) { |
| auto ret = EmitValueExpression(stmt->value); |
| if (!ret) { |
| return; |
| } |
| ret_value = ret; |
| } |
| if (ret_value) { |
| SetTerminator(builder_.Return(current_function_, ret_value)); |
| } else { |
| SetTerminator(builder_.Return(current_function_)); |
| } |
| } |
| |
| void EmitBreak(const ast::BreakStatement*) { |
| auto* current_control = FindEnclosingControl(ControlFlags::kNone); |
| TINT_ASSERT(IR, current_control); |
| |
| if (auto* c = current_control->As<Loop>()) { |
| SetTerminator(builder_.ExitLoop(c)); |
| } else if (auto* s = current_control->As<Switch>()) { |
| SetTerminator(builder_.ExitSwitch(s)); |
| } else { |
| TINT_UNREACHABLE(IR, diagnostics_); |
| } |
| } |
| |
| void EmitContinue(const ast::ContinueStatement*) { |
| auto* current_control = FindEnclosingControl(ControlFlags::kExcludeSwitch); |
| TINT_ASSERT(IR, current_control); |
| |
| if (auto* c = current_control->As<Loop>()) { |
| SetTerminator(builder_.Continue(c)); |
| } else { |
| TINT_UNREACHABLE(IR, diagnostics_); |
| } |
| } |
| |
| // Discard is being treated as an instruction. The semantics in WGSL is demote_to_helper, so |
| // the code has to continue as before it just predicates writes. If WGSL grows some kind of |
| // terminating discard that would probably make sense as a Block but would then require |
| // figuring out the multi-level exit that is triggered. |
| void EmitDiscard(const ast::DiscardStatement*) { |
| auto* inst = builder_.Discard(); |
| current_block_->Append(inst); |
| } |
| |
| void EmitBreakIf(const ast::BreakIfStatement* stmt) { |
| auto* current_control = FindEnclosingControl(ControlFlags::kExcludeSwitch); |
| |
| // Emit the break-if condition into the end of the preceding block |
| auto cond = EmitValueExpression(stmt->condition); |
| if (!cond) { |
| return; |
| } |
| SetTerminator(builder_.BreakIf(current_control->As<ir::Loop>(), cond)); |
| } |
| |
| ValueOrVecElAccess EmitExpression(const ast::Expression* root) { |
| struct Emitter { |
| explicit Emitter(Impl& i) : impl(i) {} |
| |
| ValueOrVecElAccess Emit(const ast::Expression* root) { |
| // Process the root expression. This will likely add tasks. |
| Process(root); |
| |
| // Execute all the tasks until all expressions have been resolved. |
| while (!tasks.IsEmpty()) { |
| auto task = tasks.Pop(); |
| task(); |
| } |
| |
| // Get the resolved root expression. |
| return Get(root); |
| } |
| |
| private: |
| Impl& impl; |
| utils::Vector<ir::Block*, 8> blocks; |
| utils::Vector<std::function<void()>, 64> tasks; |
| utils::Hashmap<const ast::Expression*, ValueOrVecElAccess, 64> bindings_; |
| |
| void Bind(const ast::Expression* expr, ir::Value* value) { |
| // If this expression maps to sem::Load, insert a load instruction to get the result |
| if (impl.program_->Sem().Get<sem::Load>(expr)) { |
| auto* load = impl.builder_.Load(value); |
| impl.current_block_->Append(load); |
| value = load->Result(); |
| } |
| bindings_.Add(expr, value); |
| } |
| |
| void Bind(const ast::Expression* expr, const VectorRefElementAccess& access) { |
| // If this expression maps to sem::Load, insert a load instruction to get the result |
| if (impl.program_->Sem().Get<sem::Load>(expr)) { |
| auto* load = impl.builder_.LoadVectorElement(access.vector, access.index); |
| impl.current_block_->Append(load); |
| bindings_.Add(expr, load->Result()); |
| } else { |
| bindings_.Add(expr, access); |
| } |
| } |
| |
| ValueOrVecElAccess Get(const ast::Expression* expr) { |
| auto val = bindings_.Get(expr); |
| if (!val) { |
| return nullptr; |
| } |
| return *val; |
| } |
| |
| Value* GetValue(const ast::Expression* expr) { |
| auto res = Get(expr); |
| if (auto** val = std::get_if<Value*>(&res)) { |
| return *val; |
| } |
| TINT_ICE(IR, impl.diagnostics_) << "expression did not resolve to a value"; |
| return nullptr; |
| } |
| |
| void PushBlock(ir::Block* block) { |
| blocks.Push(impl.current_block_); |
| impl.current_block_ = block; |
| } |
| |
| void PopBlock() { impl.current_block_ = blocks.Pop(); } |
| |
| Value* EmitConstant(const ast::Expression* expr) { |
| if (auto* sem = impl.program_->Sem().GetVal(expr)) { |
| if (auto* v = sem->ConstantValue()) { |
| if (auto* cv = v->Clone(impl.clone_ctx_)) { |
| auto* val = impl.builder_.Constant(cv); |
| bindings_.Add(expr, val); |
| return val; |
| } |
| } |
| } |
| return nullptr; |
| } |
| |
| /// Emits the IR for an accessor expression (index accessor, struct member access or |
| /// swizzle) and binds the result to `expr`. |
| /// @param expr the accessor expression to emit |
| void EmitAccess(const ast::AccessorExpression* expr) { |
|     // An element access on a reference to a vector is bound as a vector + index pair |
|     // rather than being emitted as an access instruction. |
|     if (auto vec_access = AsVectorRefElementAccess(expr)) { |
|         Bind(expr, vec_access.value()); |
|         return; |
|     } |
| |
|     auto* obj = GetValue(expr->object); |
|     if (!obj) { |
|         TINT_ASSERT(IR, false && "no object result"); |
|         return; |
|     } |
| |
|     auto* sem = impl.program_->Sem().Get(expr)->Unwrap(); |
| |
|     // The access result type should match the source result type. If the source is a |
|     // pointer, we generate a pointer. |
|     const type::Type* ty = sem->Type()->UnwrapRef()->Clone(impl.clone_ctx_.type_ctx); |
|     if (auto* ptr = obj->Type()->As<type::Pointer>(); ptr && !ty->Is<type::Pointer>()) { |
|         ty = impl.builder_.ir.Types().ptr(ptr->AddressSpace(), ty, ptr->Access()); |
|     } |
| |
|     // Resolve the index operand of the access. A null result means there is nothing left |
|     // to emit here: either the expression was already bound (multi-element swizzle) or an |
|     // error was raised. |
|     auto index = tint::Switch( |
|         sem, |
|         [&](const sem::IndexAccessorExpression* idx) -> ir::Value* { |
|             // Prefer the constant value of the index when it has one. |
|             if (auto* v = idx->Index()->ConstantValue()) { |
|                 if (auto* cv = v->Clone(impl.clone_ctx_)) { |
|                     return impl.builder_.Constant(cv); |
|                 } |
|                 TINT_ASSERT(IR, false && "constant clone failed"); |
|                 return nullptr; |
|             } |
|             return GetValue(idx->Index()->Declaration()); |
|         }, |
|         [&](const sem::StructMemberAccess* access) -> ir::Value* { |
|             // Struct members are addressed by their member index. |
|             return impl.builder_.Constant(u32((access->Member()->Index()))); |
|         }, |
|         [&](const sem::Swizzle* swizzle) -> ir::Value* { |
|             auto& indices = swizzle->Indices(); |
| |
|             // A single element swizzle is just treated as an accessor. |
|             if (indices.Length() == 1) { |
|                 return impl.builder_.Constant(u32(indices[0])); |
|             } |
|             // Multi-element swizzles get their own instruction, which is bound here. |
|             auto* val = impl.builder_.Swizzle(ty, obj, std::move(indices)); |
|             impl.current_block_->Append(val); |
|             Bind(expr, val->Result()); |
|             return nullptr; |
|         }, |
|         [&](Default) { |
|             // Reported against the IR system, like every other ICE raised by this file. |
|             TINT_ICE(IR, impl.diagnostics_) |
|                 << "invalid accessor: " + std::string(sem->TypeInfo().name); |
|             return nullptr; |
|         }); |
| |
|     if (!index) { |
|         return; |
|     } |
| |
|     // If the object is an unnamed value (a subexpression, not a let) and is the result |
|     // of another access, then we can just append the index to that access. |
|     if (!impl.mod.NameOf(obj).IsValid()) { |
|         if (auto* inst_res = obj->As<InstructionResult>()) { |
|             if (auto* access = inst_res->Source()->As<Access>()) { |
|                 access->AddIndex(index); |
|                 access->Result()->SetType(ty); |
|                 bindings_.Remove(expr->object); |
|                 // Move the access after the index expression. |
|                 if (impl.current_block_->Back() != access) { |
|                     impl.current_block_->Remove(access); |
|                     impl.current_block_->Append(access); |
|                 } |
|                 Bind(expr, access->Result()); |
|                 return; |
|             } |
|         } |
|     } |
| |
|     // Create a new access |
|     auto* access = impl.builder_.Access(ty, obj, index); |
|     impl.current_block_->Append(access); |
|     Bind(expr, access->Result()); |
| } |
| |
| /// Emits the IR for a binary expression that is not a short-circuiting operator and binds |
| /// the resulting instruction's result to `b`. |
| /// Nothing is emitted if either operand has no value or no instruction could be built. |
| /// @param b the binary expression to emit |
| void EmitBinary(const ast::BinaryExpression* b) { |
|     auto* result_ty = impl.program_->Sem().Get(b)->Type()->Clone(impl.clone_ctx_.type_ctx); |
| |
|     // The right-hand side is only looked up once the left-hand side is known to be valid. |
|     auto* lhs_val = GetValue(b->lhs); |
|     auto* rhs_val = lhs_val ? GetValue(b->rhs) : nullptr; |
|     if (!lhs_val || !rhs_val) { |
|         return; |
|     } |
| |
|     if (Binary* binary = impl.BinaryOp(result_ty, lhs_val, rhs_val, b->op)) { |
|         impl.current_block_->Append(binary); |
|         Bind(b, binary->Result()); |
|     } |
| } |
| |
| /// Emits the IR for a unary operator expression and binds the result to `expr`. |
| /// Address-of and indirection emit no instruction: `expr` is bound to the operand's value. |
| /// @param expr the unary expression to emit |
| void EmitUnary(const ast::UnaryOpExpression* expr) { |
|     auto val = GetValue(expr->expr); |
|     if (!val) { |
|         return; |
|     } |
|     auto* sem = impl.program_->Sem().Get(expr); |
|     auto* ty = sem->Type()->Clone(impl.clone_ctx_.type_ctx); |
|     Instruction* inst = nullptr; |
|     switch (expr->op) { |
|         case ast::UnaryOp::kAddressOf: |
|         case ast::UnaryOp::kIndirection: |
|             // 'address-of' and 'indirection' just fold away and we propagate the |
|             // pointer. |
|             Bind(expr, val); |
|             return; |
|         case ast::UnaryOp::kComplement: |
|             inst = impl.builder_.Complement(ty, val); |
|             break; |
|         case ast::UnaryOp::kNegation: |
|             inst = impl.builder_.Negation(ty, val); |
|             break; |
|         case ast::UnaryOp::kNot: |
|             inst = impl.builder_.Not(ty, val); |
|             break; |
|     } |
|     if (!inst) { |
|         // The switch has no default, so an operator value outside the cases above leaves |
|         // `inst` null. Report it instead of appending and dereferencing a null instruction. |
|         TINT_ICE(IR, impl.diagnostics_) << "unhandled unary operator"; |
|         return; |
|     } |
|     impl.current_block_->Append(inst); |
|     Bind(expr, inst->Result()); |
| } |
| |
| /// Emits a bitcast instruction for `b` and binds its result to `b`. |
| /// Nothing is emitted if the operand has no value. |
| /// @param b the bitcast expression to emit |
| void EmitBitcast(const ast::BitcastExpression* b) { |
|     auto* operand = GetValue(b->expr); |
|     if (operand) { |
|         // The result type is the semantic type of the bitcast, cloned into the IR module. |
|         auto* result_ty = impl.program_->Sem().Get(b)->Type()->Clone(impl.clone_ctx_.type_ctx); |
|         auto* bitcast = impl.builder_.Bitcast(result_ty, operand); |
|         impl.current_block_->Append(bitcast); |
|         Bind(b, bitcast->Result()); |
|     } |
| } |
| |
| /// Emits the IR for a call expression and binds the result to `expr`. |
| /// A call with a constant value is bound directly to that constant. Otherwise a builtin |
| /// call, construct, convert or user function call instruction is appended to the current |
| /// block. |
| /// @param expr the call expression to emit |
| void EmitCall(const ast::CallExpression* expr) { |
|     // If this is a materialized semantic node, just use the constant value. |
|     if (auto* mat = impl.program_->Sem().Get(expr)) { |
|         if (mat->ConstantValue()) { |
|             auto* cv = mat->ConstantValue()->Clone(impl.clone_ctx_); |
|             if (!cv) { |
|                 impl.add_error(expr->source, "failed to get constant value for call " + |
|                                                  std::string(expr->TypeInfo().name)); |
|                 return; |
|             } |
|             Bind(expr, impl.builder_.Constant(cv)); |
|             return; |
|         } |
|     } |
|     utils::Vector<Value*, 8> args; |
|     args.Reserve(expr->args.Length()); |
|     // Emit the arguments |
|     for (const auto* arg : expr->args) { |
|         auto value = GetValue(arg); |
|         if (!value) { |
|             impl.add_error(arg->source, "failed to convert arguments"); |
|             return; |
|         } |
|         args.Push(value); |
|     } |
|     auto* sem = impl.program_->Sem().Get<sem::Call>(expr); |
|     if (!sem) { |
|         impl.add_error(expr->source, "failed to get semantic information for call " + |
|                                          std::string(expr->TypeInfo().name)); |
|         return; |
|     } |
|     auto* ty = sem->Target()->ReturnType()->Clone(impl.clone_ctx_.type_ctx); |
|     Instruction* inst = nullptr; |
|     // If this is a builtin function, emit the specific builtin value |
|     if (auto* b = sem->Target()->As<sem::Builtin>()) { |
|         inst = impl.builder_.Call(ty, b->Type(), args); |
|     } else if (sem->Target()->As<sem::ValueConstructor>()) { |
|         inst = impl.builder_.Construct(ty, std::move(args)); |
|     } else if (sem->Target()->Is<sem::ValueConversion>()) { |
|         inst = impl.builder_.Convert(ty, args[0]); |
|     } else if (expr->target->identifier->Is<ast::TemplatedIdentifier>()) { |
|         TINT_UNIMPLEMENTED(IR, impl.diagnostics_) << "missing templated ident support"; |
|         return; |
|     } else { |
|         // Not a builtin and not a templated call, so this is a user function. |
|         // The scope lookup is checked, as EmitIdentifier() does, so that a missing or |
|         // non-function entry is reported rather than dereferenced. |
|         auto* target = impl.scopes_.Get(expr->target->identifier->symbol); |
|         auto* func = target ? target->As<ir::Function>() : nullptr; |
|         if (!func) { |
|             impl.add_error(expr->source, "unable to find function " + |
|                                              expr->target->identifier->symbol.Name()); |
|             return; |
|         } |
|         inst = impl.builder_.Call(ty, func, std::move(args)); |
|     } |
|     if (inst == nullptr) { |
|         return; |
|     } |
|     impl.current_block_->Append(inst); |
|     Bind(expr, inst->Result()); |
| } |
| |
| /// Binds an identifier expression to the value registered in `impl.scopes_` for its symbol. |
| /// An error is added if the symbol has no entry. |
| /// @param i the identifier expression to emit |
| void EmitIdentifier(const ast::IdentifierExpression* i) { |
|     auto* resolved = impl.scopes_.Get(i->identifier->symbol); |
|     if (TINT_UNLIKELY(!resolved)) { |
|         impl.add_error(i->source, |
|                        "unable to find identifier " + i->identifier->symbol.Name()); |
|     } else { |
|         Bind(i, resolved); |
|     } |
| } |
| |
| /// Emits a literal expression as an IR constant and binds it to `lit`. |
| /// An error is added if the literal has no semantic node or no usable constant value. |
| /// @param lit the literal expression to emit |
| void EmitLiteral(const ast::LiteralExpression* lit) { |
|     auto* sem = impl.program_->Sem().Get(lit); |
|     if (!sem) { |
|         impl.add_error(lit->source, "failed to get semantic information for node " + |
|                                         std::string(lit->TypeInfo().name)); |
|         return; |
|     } |
|     // Process() only calls this after EmitConstant() has returned nullptr, which includes |
|     // the case where there is no constant value at all, so check it before cloning. |
|     auto* constant = sem->ConstantValue(); |
|     auto* cv = constant ? constant->Clone(impl.clone_ctx_) : nullptr; |
|     if (!cv) { |
|         impl.add_error(lit->source, "failed to get constant value for node " + |
|                                         std::string(lit->TypeInfo().name)); |
|         return; |
|     } |
|     auto* val = impl.builder_.Constant(cv); |
|     Bind(lit, val); |
| } |
| |
| /// Looks up the semantic value expression for `expr`, unwraps any load, and forwards to |
| /// the overload that takes a semantic expression. |
| /// @param expr the AST expression to inspect |
| /// @returns the result of the semantic-expression overload |
| std::optional<VectorRefElementAccess> AsVectorRefElementAccess( |
|     const ast::Expression* expr) { |
|     auto* sem_expr = impl.program_->Sem().Get<sem::ValueExpression>(expr); |
|     return AsVectorRefElementAccess(sem_expr->UnwrapLoad()); |
| } |
| |
| /// Checks whether `expr` is an accessor whose object is a reference to a vector and, if |
| /// so, returns the vector value together with the value of the accessed element index. |
| /// @param expr the semantic expression to inspect |
| /// @returns the vector + index pair, or std::nullopt if `expr` is not such an access or |
| /// one of the operand values could not be obtained |
| std::optional<VectorRefElementAccess> AsVectorRefElementAccess( |
|     const sem::ValueExpression* expr) { |
|     // Only accessors whose object is a reference to a vector qualify. |
|     auto* access = As<sem::AccessorExpression>(expr); |
|     auto* ref = access ? access->Object()->Type()->As<type::Reference>() : nullptr; |
|     if (!ref || !ref->StoreType()->Is<type::Vector>()) { |
|         return std::nullopt; |
|     } |
| |
|     return tint::Switch( |
|         access, |
|         [&](const sem::Swizzle* swizzle) -> std::optional<VectorRefElementAccess> { |
|             auto vector = GetValue(access->Object()->Declaration()); |
|             if (!vector) { |
|                 return std::nullopt; |
|             } |
|             // The first swizzle index is used as the element index. |
|             return VectorRefElementAccess{ |
|                 vector, impl.builder_.Constant(u32(swizzle->Indices()[0]))}; |
|         }, |
|         [&](const sem::IndexAccessorExpression* indexed) |
|             -> std::optional<VectorRefElementAccess> { |
|             auto vector = GetValue(access->Object()->Declaration()); |
|             if (!vector) { |
|                 return std::nullopt; |
|             } |
|             auto element = GetValue(indexed->Index()->Declaration()); |
|             if (!element) { |
|                 return std::nullopt; |
|             } |
|             return VectorRefElementAccess{vector, element}; |
|         }); |
| } |
| |
| /// Starts the emission of a short-circuiting binary expression. |
| /// Appends an `if` on the left-hand side value, gives it a bool result, terminates the |
| /// branch that skips the right-hand side with a constant exit, and makes the other |
| /// branch the current block. `expr` is bound to the result of the `if`. |
| /// @param expr the short-circuiting binary expression |
| void BeginShortCircuit(const ast::BinaryExpression* expr) { |
|     auto* condition = GetValue(expr->lhs); |
|     if (!condition) { |
|         return; |
|     } |
| |
|     auto& builder = impl.builder_; |
|     auto* if_inst = builder.If(condition); |
|     impl.current_block_->Append(if_inst); |
| |
|     auto* if_result = builder.InstructionResult(builder.ir.Types().bool_()); |
|     if_inst->SetResults(if_result); |
| |
|     // For logical-and, a false lhs skips the rhs and yields false. Any other operator |
|     // takes the logical-or shape, where a true lhs skips the rhs and yields true. |
|     const bool is_and = expr->op == ast::BinaryOp::kLogicalAnd; |
|     auto* skip_block = is_and ? if_inst->False() : if_inst->True(); |
|     auto* rhs_block = is_and ? if_inst->True() : if_inst->False(); |
|     skip_block->Append(builder.ExitIf(if_inst, builder.Constant(!is_and))); |
|     PushBlock(rhs_block); |
| |
|     Bind(expr, if_result); |
| } |
| |
| /// Finishes the emission of a short-circuiting binary expression. |
| /// Exits the `if` that `b` is bound to with the right-hand side value, then restores the |
| /// block that was current before the `if` branch was pushed. |
| /// @param b the short-circuiting binary expression |
| void EndShortCircuit(const ast::BinaryExpression* b) { |
|     auto* res = GetValue(b); |
|     if (!res) { |
|         // BeginShortCircuit() returns before binding `b` (and before pushing a block) when |
|         // the left-hand side has no value, so there is nothing to finish. |
|         return; |
|     } |
|     auto* inst_res = res->As<InstructionResult>(); |
|     TINT_ASSERT_OR_RETURN(IR, inst_res); |
|     auto* if_ = inst_res->Source()->As<ir::If>(); |
|     TINT_ASSERT_OR_RETURN(IR, if_); |
| |
|     if (auto* rhs = GetValue(b->rhs)) { |
|         impl.current_block_->Append(impl.builder_.ExitIf(if_, rhs)); |
|     } |
|     // Always undo the PushBlock() done by BeginShortCircuit(), even if the right-hand side |
|     // had no value, so the block stack stays balanced. |
|     PopBlock(); |
| } |
| |
| /// Emits `expr` directly if it is a constant, literal or identifier. For every other |
| /// handled expression kind, queues tasks on `tasks` that process the sub-expressions and |
| /// then emit `expr`. |
| /// NOTE(review): the push order below implies `tasks` is run last-in-first-out (the emit |
| /// task is pushed before its operands); the task runner is not visible here - confirm. |
| /// @param expr the expression to process |
| void Process(const ast::Expression* expr) { |
|     // If this is a value that has been const-eval'd, then no need to traverse deeper. |
|     if (EmitConstant(expr)) { |
|         return; |
|     } |
| |
|     tint::Switch( |
|         expr,  // |
|         [&](const ast::BinaryExpression* bin) { |
|             const bool short_circuit = bin->op == ast::BinaryOp::kLogicalAnd || |
|                                        bin->op == ast::BinaryOp::kLogicalOr; |
|             // Final step: close the short-circuit `if`, or emit the plain binary op. |
|             if (short_circuit) { |
|                 tasks.Push([this, bin] { EndShortCircuit(bin); }); |
|             } else { |
|                 tasks.Push([this, bin] { EmitBinary(bin); }); |
|             } |
|             tasks.Push([this, bin] { Process(bin->rhs); }); |
|             // Short-circuit operators open their `if` between the two operands. |
|             if (short_circuit) { |
|                 tasks.Push([this, bin] { BeginShortCircuit(bin); }); |
|             } |
|             tasks.Push([this, bin] { Process(bin->lhs); }); |
|         }, |
|         [&](const ast::IndexAccessorExpression* idx) { |
|             tasks.Push([this, idx] { EmitAccess(idx); }); |
|             tasks.Push([this, idx] { Process(idx->index); }); |
|             tasks.Push([this, idx] { Process(idx->object); }); |
|         }, |
|         [&](const ast::MemberAccessorExpression* member) { |
|             tasks.Push([this, member] { EmitAccess(member); }); |
|             tasks.Push([this, member] { Process(member->object); }); |
|         }, |
|         [&](const ast::UnaryOpExpression* unary) { |
|             tasks.Push([this, unary] { EmitUnary(unary); }); |
|             tasks.Push([this, unary] { Process(unary->expr); }); |
|         }, |
|         [&](const ast::CallExpression* call) { |
|             tasks.Push([this, call] { EmitCall(call); }); |
|             for (auto* arg : utils::Reverse(call->args)) { |
|                 tasks.Push([this, arg] { Process(arg); }); |
|             } |
|         }, |
|         [&](const ast::BitcastExpression* bitcast) { |
|             tasks.Push([this, bitcast] { EmitBitcast(bitcast); }); |
|             tasks.Push([this, bitcast] { Process(bitcast->expr); }); |
|         }, |
|         [&](const ast::LiteralExpression* lit) { EmitLiteral(lit); }, |
|         [&](const ast::IdentifierExpression* ident) { EmitIdentifier(ident); }, |
|         [&](Default) { |
|             impl.add_error(expr->source, |
|                            "Unhandled: " + std::string(expr->TypeInfo().name)); |
|         }); |
| } |
| }; |
| |
| return Emitter(*this).Emit(root); |
| } |
| |
| /// Emits `root` and returns the value it resolved to. |
| /// Raises an internal compiler error if the expression did not resolve to a value. |
| /// @param root the expression to emit |
| /// @returns the resulting value, or nullptr if the expression did not resolve to a value |
| Value* EmitValueExpression(const ast::Expression* root) { |
|     auto emitted = EmitExpression(root); |
|     if (std::holds_alternative<Value*>(emitted)) { |
|         return std::get<Value*>(emitted); |
|     } |
|     TINT_ICE(IR, diagnostics_) << "expression did not resolve to a value"; |
|     return nullptr; |
| } |
| |
| /// Emits the call expression of a call statement. The resulting value is discarded. |
| /// @param stmt the call statement to emit |
| void EmitCall(const ast::CallStatement* stmt) { |
|     static_cast<void>(EmitValueExpression(stmt->expr)); |
| } |
| |
| /// Emits the IR for a variable declaration. |
| /// `var` produces a var instruction, `let` registers the initializer value under the let's |
| /// name, `override` is an error, and `const` emits nothing. |
| /// @param var the variable declaration to emit |
| void EmitVariable(const ast::Variable* var) { |
|     auto* sem = program_->Sem().Get(var); |
| |
|     tint::Switch(  // |
|         var, |
|         [&](const ast::Var* v) { |
|             // The IR var has a pointer type built from the semantic reference type. |
|             auto* ref = sem->Type()->As<type::Reference>(); |
|             auto* ptr_ty = builder_.ir.Types().Get<type::Pointer>( |
|                 ref->AddressSpace(), ref->StoreType()->Clone(clone_ctx_.type_ctx), |
|                 ref->Access()); |
| |
|             auto* ir_var = builder_.Var(ptr_ty); |
|             if (v->initializer) { |
|                 auto* init = EmitValueExpression(v->initializer); |
|                 if (!init) { |
|                     return; |
|                 } |
|                 ir_var->SetInitializer(init); |
|             } |
|             current_block_->Append(ir_var); |
| |
|             // Copy the binding point across for global variables that declare one. |
|             auto* global = sem->As<sem::GlobalVariable>(); |
|             if (global && var->HasBindingPoint()) { |
|                 auto binding_point = global->BindingPoint().value(); |
|                 ir_var->SetBindingPoint(binding_point.group, binding_point.binding); |
|             } |
| |
|             // Register the var's result so later uses of the symbol resolve to it. |
|             scopes_.Set(v->name->symbol, ir_var->Result()); |
| |
|             // Keep the source name of the var on the instruction. |
|             builder_.ir.SetName(ir_var, v->name->symbol.Name()); |
|         }, |
|         [&](const ast::Let* l) { |
|             auto* back_before = current_block_->Back(); |
|             auto* init = EmitValueExpression(l->initializer); |
|             if (!init) { |
|                 return; |
|             } |
| |
|             auto* value = init; |
|             const bool emitted_instruction = current_block_->Back() != back_before; |
|             if (emitted_instruction) { |
|                 // The initializer produced an instruction; name its value after the let. |
|                 builder_.ir.SetName(value, l->name->symbol.Name()); |
|             } else { |
|                 // The initializer created no instruction. Append an ir::Let so the let |
|                 // still has a place of declaration and a name, which preserves its |
|                 // runtime semantics and lets IR consumers produce a variable or debug |
|                 // info for it. |
|                 auto* let = current_block_->Append(builder_.Let(l->name->symbol.Name(), value)); |
|                 value = let->Result(); |
|             } |
| |
|             // Register the value so later uses of the symbol resolve to it. |
|             scopes_.Set(l->name->symbol, value); |
|         }, |
|         [&](const ast::Override*) { |
|             add_error(var->source, |
|                       "found an `Override` variable. The SubstituteOverrides " |
|                       "transform must be run before converting to IR"); |
|         }, |
|         [&](const ast::Const*) { |
|             // Nothing to emit. Const-eval should already have replaced every use of the |
|             // const with a `constant::` value, so the declaration should never be |
|             // referenced. |
|             // |
|             // TODO(dsinclair): Probably want to store the const variable somewhere and then |
|             // in identifier expression log an error if we ever see a const identifier. Add |
|             // this when identifiers and variables are supported. |
|         }, |
|         [&](Default) { |
|             add_error(var->source, "unknown variable: " + std::string(var->TypeInfo().name)); |
|         }); |
| } |
| |
| /// Builds the IR binary instruction that corresponds to an AST binary operator. |
| /// The instruction is created but not appended to any block here. |
| /// @param ty the result type of the instruction |
| /// @param lhs the left-hand side operand |
| /// @param rhs the right-hand side operand |
| /// @param op the AST binary operator |
| /// @returns the new instruction, or nullptr (after raising an internal compiler error) |
| /// for the short-circuiting operators and for `kNone` |
| ir::Binary* BinaryOp(const type::Type* ty, ir::Value* lhs, ir::Value* rhs, ast::BinaryOp op) { |
|     switch (op) { |
|         // Operators that must never reach this function. |
|         case ast::BinaryOp::kNone: |
|             TINT_ICE(IR, diagnostics_) << "missing binary operand type"; |
|             return nullptr; |
|         case ast::BinaryOp::kLogicalAnd: |
|         case ast::BinaryOp::kLogicalOr: |
|             TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled"; |
|             return nullptr; |
| |
|         // Arithmetic. |
|         case ast::BinaryOp::kAdd: |
|             return builder_.Add(ty, lhs, rhs); |
|         case ast::BinaryOp::kSubtract: |
|             return builder_.Subtract(ty, lhs, rhs); |
|         case ast::BinaryOp::kMultiply: |
|             return builder_.Multiply(ty, lhs, rhs); |
|         case ast::BinaryOp::kDivide: |
|             return builder_.Divide(ty, lhs, rhs); |
|         case ast::BinaryOp::kModulo: |
|             return builder_.Modulo(ty, lhs, rhs); |
| |
|         // Bitwise and shifts. |
|         case ast::BinaryOp::kAnd: |
|             return builder_.And(ty, lhs, rhs); |
|         case ast::BinaryOp::kOr: |
|             return builder_.Or(ty, lhs, rhs); |
|         case ast::BinaryOp::kXor: |
|             return builder_.Xor(ty, lhs, rhs); |
|         case ast::BinaryOp::kShiftLeft: |
|             return builder_.ShiftLeft(ty, lhs, rhs); |
|         case ast::BinaryOp::kShiftRight: |
|             return builder_.ShiftRight(ty, lhs, rhs); |
| |
|         // Comparisons. |
|         case ast::BinaryOp::kEqual: |
|             return builder_.Equal(ty, lhs, rhs); |
|         case ast::BinaryOp::kNotEqual: |
|             return builder_.NotEqual(ty, lhs, rhs); |
|         case ast::BinaryOp::kLessThan: |
|             return builder_.LessThan(ty, lhs, rhs); |
|         case ast::BinaryOp::kLessThanEqual: |
|             return builder_.LessThanEqual(ty, lhs, rhs); |
|         case ast::BinaryOp::kGreaterThan: |
|             return builder_.GreaterThan(ty, lhs, rhs); |
|         case ast::BinaryOp::kGreaterThanEqual: |
|             return builder_.GreaterThanEqual(ty, lhs, rhs); |
|     } |
|     TINT_UNREACHABLE(IR, diagnostics_); |
|     return nullptr; |
| } |
| }; |
| |
| } // namespace |
| |
| /// Builds an IR module from `program`. |
| /// @param program the program to convert |
| /// @returns the built module, or an error string if `program` is not valid or the build |
| /// failed |
| utils::Result<Module, std::string> FromProgram(const Program* program) { |
|     const bool input_ok = program->IsValid(); |
|     if (!input_ok) { |
|         return std::string("input program is not valid"); |
|     } |
| |
|     Impl builder_impl(program); |
|     auto built = builder_impl.Build(); |
|     if (!built) { |
|         // Surface the build failure as a plain string. |
|         return built.Failure().str(); |
|     } |
| |
|     return built.Move(); |
| } |
| |
| } // namespace tint::ir |