| // Copyright 2022 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/tint/transform/unwind_discard_functions.h" |
| |
| #include <memory> |
| #include <string> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| |
| #include "src/tint/ast/discard_statement.h" |
| #include "src/tint/ast/return_statement.h" |
| #include "src/tint/ast/traverse_expressions.h" |
| #include "src/tint/sem/block_statement.h" |
| #include "src/tint/sem/call.h" |
| #include "src/tint/sem/for_loop_statement.h" |
| #include "src/tint/sem/function.h" |
| #include "src/tint/sem/if_statement.h" |
| #include "src/tint/transform/utils/get_insertion_point.h" |
| |
| TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions); |
| |
| namespace tint::transform { |
| namespace { |
| |
| class State { |
| private: |
| CloneContext& ctx; |
| ProgramBuilder& b; |
| const sem::Info& sem; |
| Symbol module_discard_var_name; // Use ModuleDiscardVarName() to read |
| Symbol module_discard_func_name; // Use ModuleDiscardFuncName() to read |
| |
| // Returns true if `sem_expr` contains a call expression that may |
| // (transitively) execute a discard statement. |
| bool MayDiscard(const sem::Expression* sem_expr) { |
| return sem_expr && sem_expr->Behaviors().Contains(sem::Behavior::kDiscard); |
| } |
| |
| // Lazily creates and returns the name of the module bool variable for whether |
| // to discard: "tint_discard". |
| Symbol ModuleDiscardVarName() { |
| if (!module_discard_var_name.IsValid()) { |
| module_discard_var_name = b.Symbols().New("tint_discard"); |
| ctx.dst->GlobalVar(module_discard_var_name, b.ty.bool_(), b.Expr(false), |
| ast::StorageClass::kPrivate); |
| } |
| return module_discard_var_name; |
| } |
| |
| // Lazily creates and returns the name of the function that contains a single |
| // discard statement: "tint_discard_func". |
| // We do this to avoid having multiple discard statements in a single program, |
| // which causes problems in certain backends (see crbug.com/1118). |
| Symbol ModuleDiscardFuncName() { |
| if (!module_discard_func_name.IsValid()) { |
| module_discard_func_name = b.Symbols().New("tint_discard_func"); |
| b.Func(module_discard_func_name, {}, b.ty.void_(), {b.Discard()}); |
| } |
| return module_discard_func_name; |
| } |
| |
| // Creates "return <default return value>;" based on the return type of |
| // `stmt`'s owning function. |
| const ast::ReturnStatement* Return(const ast::Statement* stmt) { |
| const ast::Expression* ret_val = nullptr; |
| auto* ret_type = sem.Get(stmt)->Function()->Declaration()->return_type; |
| if (!ret_type->Is<ast::Void>()) { |
| ret_val = b.Construct(ctx.Clone(ret_type)); |
| } |
| return b.Return(ret_val); |
| } |
| |
| // Returns true if the function `stmt` is in is an entry point |
| bool IsInEntryPointFunc(const ast::Statement* stmt) { |
| return sem.Get(stmt)->Function()->Declaration()->IsEntryPoint(); |
| } |
| |
| // Creates "tint_discard_func();" |
| const ast::CallStatement* CallDiscardFunc() { |
| auto func_name = ModuleDiscardFuncName(); |
| return b.CallStmt(b.Call(func_name)); |
| } |
| |
| // Creates and returns a new if-statement of the form: |
| // |
| // if (tint_discard) { |
| // return <default value>; |
| // } |
| // |
| // or if `stmt` is in a entry point function: |
| // |
| // if (tint_discard) { |
| // tint_discard_func(); |
| // return <default value>; |
| // } |
| // |
| const ast::IfStatement* IfDiscardReturn(const ast::Statement* stmt) { |
| ast::StatementList stmts; |
| |
| // For entry point functions, also emit the discard statement |
| if (IsInEntryPointFunc(stmt)) { |
| stmts.emplace_back(CallDiscardFunc()); |
| } |
| |
| stmts.emplace_back(Return(stmt)); |
| |
| auto var_name = ModuleDiscardVarName(); |
| return b.If(var_name, b.Block(stmts)); |
| } |
| |
| // Hoists `sem_expr` to a let followed by an `IfDiscardReturn` before `stmt`. |
| // For example, if `stmt` is: |
| // |
| // return f(); |
| // |
| // This function will transform this to: |
| // |
| // let t1 = f(); |
| // if (tint_discard) { |
| // return; |
| // } |
| // return t1; |
| // |
| const ast::Statement* HoistAndInsertBefore(const ast::Statement* stmt, |
| const sem::Expression* sem_expr) { |
| auto* expr = sem_expr->Declaration(); |
| |
| auto ip = utils::GetInsertionPoint(ctx, stmt); |
| auto var_name = b.Sym(); |
| auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr))); |
| ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl); |
| |
| ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, IfDiscardReturn(stmt)); |
| |
| auto* var_expr = b.Expr(var_name); |
| |
| // Special handling for CallStatement as we can only replace its expression |
| // with a CallExpression. |
| if (stmt->Is<ast::CallStatement>()) { |
| // We could replace the call statement with no statement, but we can't do |
| // that with transforms (yet), so just return a phony assignment. |
| return b.Assign(b.Phony(), var_expr); |
| } |
| |
| ctx.Replace(expr, var_expr); |
| return ctx.CloneWithoutTransform(stmt); |
| } |
| |
| // Returns true if `stmt` is a for-loop initializer statement. |
| bool IsForLoopInitStatement(const ast::Statement* stmt) { |
| if (auto* sem_stmt = sem.Get(stmt)) { |
| if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) { |
| return sem_fl->Declaration()->initializer == stmt; |
| } |
| } |
| return false; |
| } |
| |
| // Inserts an `IfDiscardReturn` after `stmt` if possible (i.e. `stmt` is not |
| // in a for-loop init), otherwise falls back to HoistAndInsertBefore, hoisting |
| // `sem_expr` to a let followed by an `IfDiscardReturn` before `stmt`. |
| // |
| // For example, if `stmt` is: |
| // |
| // let r = f(); |
| // |
| // This function will transform this to: |
| // |
| // let r = f(); |
| // if (tint_discard) { |
| // return; |
| // } |
| const ast::Statement* TryInsertAfter(const ast::Statement* stmt, |
| const sem::Expression* sem_expr) { |
| // If `stmt` is the init of a for-loop, hoist and insert before instead. |
| if (IsForLoopInitStatement(stmt)) { |
| return HoistAndInsertBefore(stmt, sem_expr); |
| } |
| |
| auto ip = utils::GetInsertionPoint(ctx, stmt); |
| ctx.InsertAfter(ip.first->Declaration()->statements, ip.second, IfDiscardReturn(stmt)); |
| return nullptr; // Don't replace current statement |
| } |
| |
| // Replaces the input discard statement with either setting the module level |
| // discard bool ("tint_discard = true"), or calling the discard function |
| // ("tint_discard_func()"), followed by a default return statement. |
| // |
| // Replaces "discard;" with: |
| // |
| // tint_discard = true; |
| // return; |
| // |
| // Or if `stmt` is a entry point function, replaces with: |
| // |
| // tint_discard_func(); |
| // return; |
| // |
| const ast::Statement* ReplaceDiscardStatement(const ast::DiscardStatement* stmt) { |
| const ast::Statement* to_insert = nullptr; |
| if (IsInEntryPointFunc(stmt)) { |
| to_insert = CallDiscardFunc(); |
| } else { |
| auto var_name = ModuleDiscardVarName(); |
| to_insert = b.Assign(var_name, true); |
| } |
| |
| auto ip = utils::GetInsertionPoint(ctx, stmt); |
| ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert); |
| return Return(stmt); |
| } |
| |
| // Handle statement |
| const ast::Statement* Statement(const ast::Statement* stmt) { |
| return Switch( |
| stmt, |
| [&](const ast::DiscardStatement* s) -> const ast::Statement* { |
| return ReplaceDiscardStatement(s); |
| }, |
| [&](const ast::AssignmentStatement* s) -> const ast::Statement* { |
| auto* sem_lhs = sem.Get(s->lhs); |
| auto* sem_rhs = sem.Get(s->rhs); |
| if (MayDiscard(sem_lhs)) { |
| if (MayDiscard(sem_rhs)) { |
| TINT_ICE(Transform, b.Diagnostics()) |
| << "Unexpected: both sides of assignment statement may " |
| "discard. Make sure transform::PromoteSideEffectsToDecl " |
| "was run first."; |
| } |
| return TryInsertAfter(s, sem_lhs); |
| } else if (MayDiscard(sem_rhs)) { |
| return TryInsertAfter(s, sem_rhs); |
| } |
| return nullptr; |
| }, |
| [&](const ast::CallStatement* s) -> const ast::Statement* { |
| auto* sem_expr = sem.Get(s->expr); |
| if (!MayDiscard(sem_expr)) { |
| return nullptr; |
| } |
| return TryInsertAfter(s, sem_expr); |
| }, |
| [&](const ast::ForLoopStatement* s) -> const ast::Statement* { |
| if (MayDiscard(sem.Get(s->condition))) { |
| TINT_ICE(Transform, b.Diagnostics()) |
| << "Unexpected ForLoopStatement condition that may discard. " |
| "Make sure transform::PromoteSideEffectsToDecl was run " |
| "first."; |
| } |
| return nullptr; |
| }, |
| [&](const ast::WhileStatement* s) -> const ast::Statement* { |
| if (MayDiscard(sem.Get(s->condition))) { |
| TINT_ICE(Transform, b.Diagnostics()) |
| << "Unexpected WhileStatement condition that may discard. " |
| "Make sure transform::PromoteSideEffectsToDecl was run " |
| "first."; |
| } |
| return nullptr; |
| }, |
| [&](const ast::IfStatement* s) -> const ast::Statement* { |
| auto* sem_expr = sem.Get(s->condition); |
| if (!MayDiscard(sem_expr)) { |
| return nullptr; |
| } |
| return HoistAndInsertBefore(s, sem_expr); |
| }, |
| [&](const ast::ReturnStatement* s) -> const ast::Statement* { |
| auto* sem_expr = sem.Get(s->value); |
| if (!MayDiscard(sem_expr)) { |
| return nullptr; |
| } |
| return HoistAndInsertBefore(s, sem_expr); |
| }, |
| [&](const ast::SwitchStatement* s) -> const ast::Statement* { |
| auto* sem_expr = sem.Get(s->condition); |
| if (!MayDiscard(sem_expr)) { |
| return nullptr; |
| } |
| return HoistAndInsertBefore(s, sem_expr); |
| }, |
| [&](const ast::VariableDeclStatement* s) -> const ast::Statement* { |
| auto* var = s->variable; |
| if (!var->constructor) { |
| return nullptr; |
| } |
| auto* sem_expr = sem.Get(var->constructor); |
| if (!MayDiscard(sem_expr)) { |
| return nullptr; |
| } |
| return TryInsertAfter(s, sem_expr); |
| }); |
| } |
| |
| public: |
| /// Constructor |
| /// @param ctx_in the context |
| explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} |
| |
| /// Runs the transform |
| void Run() { |
| ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { |
| // Iterate block statements and replace them as needed. |
| for (auto* stmt : block->statements) { |
| if (auto* new_stmt = Statement(stmt)) { |
| ctx.Replace(stmt, new_stmt); |
| } |
| |
| // Handle for loops, as they are the only other AST node that |
| // contains statements outside of BlockStatements. |
| if (auto* fl = stmt->As<ast::ForLoopStatement>()) { |
| if (auto* new_stmt = Statement(fl->initializer)) { |
| ctx.Replace(fl->initializer, new_stmt); |
| } |
| if (auto* new_stmt = Statement(fl->continuing)) { |
| // NOTE: Should never reach here as we cannot discard in a |
| // continuing block. |
| ctx.Replace(fl->continuing, new_stmt); |
| } |
| } |
| } |
| |
| return nullptr; |
| }); |
| |
| ctx.Clone(); |
| } |
| }; |
| |
| } // namespace |
| |
| UnwindDiscardFunctions::UnwindDiscardFunctions() = default; |
| UnwindDiscardFunctions::~UnwindDiscardFunctions() = default; |
| |
| void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |
| State state(ctx); |
| state.Run(); |
| } |
| |
| bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const { |
| auto& sem = program->Sem(); |
| for (auto* f : program->AST().Functions()) { |
| if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| } // namespace tint::transform |