reader/spirv: Remove use of BlockStatement::append()
Introduce `StatementBuilder`s , which may hold mutable state, before being converted into the immutable AST node on completion of the `BlockStatement`.
Bug: tint:396
Bug: tint:390
Change-Id: I0381c4ae7948be0de02bc13e54e0037a72baaf0c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35506
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/block_statement.cc b/src/ast/block_statement.cc
index b0e5f4b..d952a22 100644
--- a/src/ast/block_statement.cc
+++ b/src/ast/block_statement.cc
@@ -24,6 +24,10 @@
BlockStatement::BlockStatement(const Source& source) : Base(source) {}
+BlockStatement::BlockStatement(const Source& source,
+ const StatementList& statements)
+ : Base(source), statements_(std::move(statements)) {}
+
BlockStatement::BlockStatement(BlockStatement&&) = default;
BlockStatement::~BlockStatement() = default;
diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h
index faa7aaa..730f7bf 100644
--- a/src/ast/block_statement.h
+++ b/src/ast/block_statement.h
@@ -30,6 +30,10 @@
/// Constructor
/// @param source the block statement source
explicit BlockStatement(const Source& source);
+ /// Constructor
+ /// @param source the block statement source
+ /// @param statements the block statements
+ BlockStatement(const Source& source, const StatementList& statements);
/// Move constructor
BlockStatement(BlockStatement&&);
~BlockStatement() override;
diff --git a/src/ast/statement.h b/src/ast/statement.h
index baff8d9..14808d5 100644
--- a/src/ast/statement.h
+++ b/src/ast/statement.h
@@ -42,6 +42,9 @@
Statement(const Statement&) = delete;
};
+/// A list of statements
+using StatementList = std::vector<Statement*>;
+
} // namespace ast
} // namespace tint
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index ce8deec..effad4f 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -608,6 +608,70 @@
std::unordered_set<uint32_t> visited_;
};
+/// A StatementBuilder for ast::SwitchStatment
+/// @see StatementBuilder
+struct SwitchStatementBuilder
+ : public Castable<SwitchStatementBuilder, StatementBuilder> {
+ /// Constructor
+ /// @param cond the switch statement condition
+ explicit SwitchStatementBuilder(ast::Expression* cond) : condition(cond) {}
+
+ /// @param mod the ast Module to build into
+ /// @returns the built ast::SwitchStatement
+ ast::SwitchStatement* Build(ast::Module* mod) const override {
+ // We've listed cases in reverse order in the switch statement.
+ // Reorder them to match the presentation order in WGSL.
+ auto reversed_cases = cases;
+ std::reverse(reversed_cases.begin(), reversed_cases.end());
+
+ return mod->create<ast::SwitchStatement>(Source{}, condition,
+ reversed_cases);
+ }
+
+ /// Switch statement condition
+ ast::Expression* const condition;
+ /// Switch statement cases
+ ast::CaseStatementList cases;
+};
+
+/// A StatementBuilder for ast::IfStatement
+/// @see StatementBuilder
+struct IfStatementBuilder
+ : public Castable<IfStatementBuilder, StatementBuilder> {
+ /// Constructor
+ /// @param c the if-statement condition
+ explicit IfStatementBuilder(ast::Expression* c) : cond(c) {}
+
+ /// @param mod the ast Module to build into
+ /// @returns the built ast::IfStatement
+ ast::IfStatement* Build(ast::Module* mod) const override {
+ return mod->create<ast::IfStatement>(Source{}, cond, body, else_stmts);
+ }
+
+ /// If-statement condition
+ ast::Expression* const cond;
+ /// If-statement block body
+ ast::BlockStatement* body = nullptr;
+ /// Optional if-statement else statements
+ ast::ElseStatementList else_stmts;
+};
+
+/// A StatementBuilder for ast::LoopStatement
+/// @see StatementBuilder
+struct LoopStatementBuilder
+ : public Castable<LoopStatementBuilder, StatementBuilder> {
+ /// @param mod the ast Module to build into
+ /// @returns the built ast::LoopStatement
+ ast::LoopStatement* Build(ast::Module* mod) const override {
+ return mod->create<ast::LoopStatement>(Source{}, body, continuing);
+ }
+
+ /// Loop-statement block body
+ ast::BlockStatement* body = nullptr;
+ /// Loop-statement continuing body
+ ast::BlockStatement* continuing = nullptr;
+};
+
} // namespace
BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb)
@@ -622,6 +686,17 @@
DefInfo::~DefInfo() = default;
+bool StatementBuilder::IsValid() const {
+ return true;
+}
+ast::Node* StatementBuilder::Clone(ast::CloneContext*) const {
+ return nullptr;
+}
+void StatementBuilder::to_str(std::ostream& out, size_t indent) const {
+ make_indent(out, indent);
+ out << "StatementBuilder" << std::endl;
+}
+
FunctionEmitter::FunctionEmitter(ParserImpl* pi,
const spvtools::opt::Function& function,
const EntryPointInfo* ep_info)
@@ -636,7 +711,7 @@
function_(function),
i32_(ast_module_.create<ast::type::I32>()),
ep_info_(ep_info) {
- PushNewStatementBlock(nullptr, 0, nullptr, nullptr, nullptr);
+ PushNewStatementBlock(nullptr, 0, nullptr, nullptr);
}
FunctionEmitter::FunctionEmitter(ParserImpl* pi,
@@ -646,32 +721,62 @@
FunctionEmitter::~FunctionEmitter() = default;
FunctionEmitter::StatementBlock::StatementBlock(
- const Construct* construct,
+ const spirv::Construct* construct,
uint32_t end_id,
- CompletionAction completion_action,
- ast::BlockStatement* statements,
+ FunctionEmitter::CompletionAction completion_action,
ast::CaseStatementList* cases)
: construct_(construct),
end_id_(end_id),
completion_action_(completion_action),
- statements_(statements),
cases_(cases) {}
-FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&&) = default;
+FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other)
+ : construct_(other.construct_),
+ end_id_(other.end_id_),
+ completion_action_(std::move(other.completion_action_)),
+ statements_(std::move(other.statements_)),
+ cases_(std::move(other.cases_)) {
+ other.statements_.clear();
+}
-FunctionEmitter::StatementBlock::~StatementBlock() = default;
+FunctionEmitter::StatementBlock::~StatementBlock() {
+ if (!finalized_) {
+ // Delete builders that have not been built with Finalize()
+ for (auto* statement : statements_) {
+ if (auto* builder = statement->As<StatementBuilder>()) {
+ delete builder;
+ }
+ }
+ }
+}
+
+void FunctionEmitter::StatementBlock::Finalize(ast::Module* mod) {
+ assert(!finalized_ /* Finalize() must only be called once */);
+ for (size_t i = 0; i < statements_.size(); i++) {
+ if (auto* builder = statements_[i]->As<StatementBuilder>()) {
+ statements_[i] = builder->Build(mod);
+ delete builder;
+ }
+ }
+
+ if (completion_action_ != nullptr) {
+ completion_action_(statements_);
+ }
+
+ finalized_ = true;
+}
+
+void FunctionEmitter::StatementBlock::Add(ast::Statement* statement) {
+ assert(!finalized_ /* Add() must not be called after Finalize() */);
+ statements_.emplace_back(statement);
+}
void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
uint32_t end_id,
- ast::BlockStatement* block,
ast::CaseStatementList* cases,
CompletionAction action) {
- if (block == nullptr) {
- block = create<ast::BlockStatement>(Source{});
- }
-
statements_stack_.emplace_back(
- StatementBlock{construct, end_id, action, block, cases});
+ StatementBlock{construct, end_id, action, cases});
}
void FunctionEmitter::PushGuard(const std::string& guard_name,
@@ -685,10 +790,12 @@
auto* cond = create<ast::IdentifierExpression>(
Source{}, ast_module_.RegisterSymbol(guard_name), guard_name);
- auto* body = create<ast::BlockStatement>(Source{});
- AddStatement(
- create<ast::IfStatement>(Source{}, cond, body, ast::ElseStatementList{}));
- PushNewStatementBlock(top.construct_, end_id, body, nullptr, nullptr);
+ auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
+
+ PushNewStatementBlock(
+ top.Construct(), end_id, nullptr, [=](const ast::StatementList& stmts) {
+ builder->body = create<ast::BlockStatement>(Source{}, stmts);
+ });
}
void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
@@ -696,31 +803,36 @@
const auto& top = statements_stack_.back();
auto* cond = MakeTrue(Source{});
- auto* body = create<ast::BlockStatement>(Source{});
- AddStatement(
- create<ast::IfStatement>(Source{}, cond, body, ast::ElseStatementList{}));
- PushNewStatementBlock(top.construct_, end_id, body, nullptr, nullptr);
+ auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
+
+ PushNewStatementBlock(
+ top.Construct(), end_id, nullptr, [=](const ast::StatementList& stmts) {
+ builder->body = create<ast::BlockStatement>(Source{}, stmts);
+ });
}
-const ast::BlockStatement* FunctionEmitter::ast_body() {
+const ast::StatementList FunctionEmitter::ast_body() {
assert(!statements_stack_.empty());
- return statements_stack_[0].statements_;
+ auto& entry = statements_stack_[0];
+ entry.Finalize(&ast_module_);
+ return entry.Statements();
}
ast::Statement* FunctionEmitter::AddStatement(ast::Statement* statement) {
assert(!statements_stack_.empty());
auto* result = statement;
if (result != nullptr) {
- statements_stack_.back().statements_->append(statement);
+ auto& block = statements_stack_.back();
+ block.Add(statement);
}
return result;
}
ast::Statement* FunctionEmitter::LastStatement() {
assert(!statements_stack_.empty());
- auto* statement_list = statements_stack_.back().statements_;
- assert(!statement_list->empty());
- return statement_list->last();
+ auto& statement_list = statements_stack_.back().Statements();
+ assert(!statement_list.empty());
+ return statement_list.back();
}
bool FunctionEmitter::Emit() {
@@ -748,7 +860,10 @@
<< statements_stack_.size();
}
- auto* body = statements_stack_[0].statements_;
+ statements_stack_[0].Finalize(&ast_module_);
+
+ auto& statements = statements_stack_[0].Statements();
+ auto* body = create<ast::BlockStatement>(Source{}, statements);
ast_module_.AddFunction(
create<ast::Function>(decl.source, ast_module_.RegisterSymbol(decl.name),
decl.name, std::move(decl.params), decl.return_type,
@@ -756,7 +871,7 @@
// Maintain the invariant by repopulating the one and only element.
statements_stack_.clear();
- PushNewStatementBlock(constructs_[0].get(), 0, nullptr, nullptr, nullptr);
+ PushNewStatementBlock(constructs_[0].get(), 0, nullptr, nullptr);
return success();
}
@@ -1935,7 +2050,7 @@
// TODO(dneto): refactor how the first construct is created vs.
// this statements stack entry is populated.
assert(statements_stack_.size() == 1);
- statements_stack_[0].construct_ = function_construct;
+ statements_stack_[0].SetConstruct(function_construct);
for (auto block_id : block_order()) {
if (!EmitBasicBlock(*GetBlockInfo(block_id))) {
@@ -1948,11 +2063,8 @@
bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
// Close off previous constructs.
while (!statements_stack_.empty() &&
- (statements_stack_.back().end_id_ == block_info.id)) {
- StatementBlock& sb = statements_stack_.back();
- if (sb.completion_action_ != nullptr) {
- sb.completion_action_();
- }
+ (statements_stack_.back().EndId() == block_info.id)) {
+ statements_stack_.back().Finalize(&ast_module_);
statements_stack_.pop_back();
}
if (statements_stack_.empty()) {
@@ -1965,7 +2077,7 @@
std::vector<const Construct*> entering_constructs; // inner most comes first
{
auto* here = block_info.construct;
- auto* const top_construct = statements_stack_.back().construct_;
+ auto* const top_construct = statements_stack_.back().Construct();
while (here != top_construct) {
// Only enter a construct at its header block.
if (here->begin_id == block_info.id) {
@@ -2152,42 +2264,9 @@
const auto condition_id =
block_info.basic_block->terminator()->GetSingleWordInOperand(0);
auto* cond = MakeExpression(condition_id).expr;
- auto* body = create<ast::BlockStatement>(Source{});
// Generate the code for the condition.
- // Use the IfBuilder to create the if-statement. The IfBuilder is constructed
- // as a std::shared_ptr and is captured by the then and else clause
- // CompletionAction lambdas, and so will only be destructed when the last
- // block is completed. The IfBuilder destructor constructs the IfStatement,
- // inserting it at the current insertion point in the current
- // ast::BlockStatement.
- struct IfBuilder {
- IfBuilder(ast::Module* mod,
- StatementBlock& statement_block,
- tint::ast::Expression* cond,
- ast::BlockStatement* body)
- : mod_(mod),
- dst_block_(statement_block.statements_),
- dst_block_insertion_point_(statement_block.statements_->size()),
- cond_(cond),
- body_(body) {}
-
- ~IfBuilder() {
- dst_block_->insert(
- dst_block_insertion_point_,
- mod_->create<ast::IfStatement>(Source{}, cond_, body_, else_stmts_));
- }
-
- ast::Module* mod_;
- ast::BlockStatement* dst_block_;
- size_t dst_block_insertion_point_;
- tint::ast::Expression* cond_;
- ast::BlockStatement* body_;
- ast::ElseStatementList else_stmts_;
- };
-
- auto if_builder = std::make_shared<IfBuilder>(
- &ast_module_, statements_stack_.back(), cond, body);
+ auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
// Compute the block IDs that should end the then-clause and the else-clause.
@@ -2225,17 +2304,16 @@
// Push statement blocks for the then-clause and the else-clause.
// But make sure we do it in the right order.
- auto push_else = [this, if_builder, else_end, construct]() {
+ auto push_else = [this, builder, else_end, construct]() {
// Push the else clause onto the stack first.
- auto* else_body = create<ast::BlockStatement>(Source{});
PushNewStatementBlock(
- construct, else_end, else_body, nullptr,
- [this, if_builder, else_body]() {
+ construct, else_end, nullptr, [=](const ast::StatementList& stmts) {
// Only set the else-clause if there are statements to fill it.
- if (!else_body->empty()) {
+ if (!stmts.empty()) {
// The "else" consists of the statement list from the top of
// statements stack, without an elseif condition.
- if_builder->else_stmts_.emplace_back(
+ auto* else_body = create<ast::BlockStatement>(Source{}, stmts);
+ builder->else_stmts.emplace_back(
create<ast::ElseStatement>(Source{}, nullptr, else_body));
}
});
@@ -2275,7 +2353,10 @@
}
// Push the then clause onto the stack.
- PushNewStatementBlock(construct, then_end, body, nullptr, [if_builder] {});
+ PushNewStatementBlock(
+ construct, then_end, nullptr, [=](const ast::StatementList& stmts) {
+ builder->body = create<ast::BlockStatement>(Source{}, stmts);
+ });
}
return success();
@@ -2293,14 +2374,11 @@
auto selector = MakeExpression(selector_id);
// First, push the statement block for the entire switch.
- ast::CaseStatementList case_list;
- auto* swch = create<ast::SwitchStatement>(Source{}, selector.expr, case_list);
- AddStatement(swch)->As<ast::SwitchStatement>();
+ auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);
// Grab a pointer to the case list. It will get buried in the statement block
// stack.
- auto* cases = &(swch->body());
- PushNewStatementBlock(construct, construct->end_id, nullptr, cases, nullptr);
+ PushNewStatementBlock(construct, construct->end_id, &swch->cases, nullptr);
// We will push statement-blocks onto the stack to gather the statements in
// the default clause and cases clauses. Determine the list of blocks
@@ -2367,21 +2445,27 @@
const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id
: construct->end_id;
- // Create the case clause. Temporarily put it in the wrong order
- // on the case statement list.
- auto* body = create<ast::BlockStatement>(Source{});
- cases->emplace_back(create<ast::CaseStatement>(Source{}, selectors, body));
-
- PushNewStatementBlock(construct, end_id, body, nullptr, nullptr);
+ // Reserve the case clause slot in swch->cases, push the new statement block
+ // for the case, and fill the case clause once the block is generated.
+ auto case_idx = swch->cases.size();
+ swch->cases.emplace_back(nullptr);
+ PushNewStatementBlock(
+ construct, end_id, nullptr, [=](const ast::StatementList& stmts) {
+ auto* body = create<ast::BlockStatement>(Source{}, stmts);
+ swch->cases[case_idx] =
+ create<ast::CaseStatement>(Source{}, selectors, body);
+ });
if ((default_info == clause_heads[i]) && has_selectors &&
construct->ContainsPos(default_info->pos)) {
// Generate a default clause with a just fallthrough.
- auto* stmts = create<ast::BlockStatement>(Source{});
- stmts->append(create<ast::FallthroughStatement>(Source{}));
+ auto* stmts = create<ast::BlockStatement>(
+ Source{}, ast::StatementList{
+ create<ast::FallthroughStatement>(Source{}),
+ });
auto* case_stmt =
create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts);
- cases->emplace_back(case_stmt);
+ swch->cases.emplace_back(case_stmt);
}
if (i == 0) {
@@ -2389,18 +2473,16 @@
}
}
- // We've listed cases in reverse order in the switch statement. Reorder them
- // to match the presentation order in WGSL.
- std::reverse(cases->begin(), cases->end());
-
return success();
}
bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
- auto* body = create<ast::BlockStatement>(Source{});
- AddStatement(create<ast::LoopStatement>(
- Source{}, body, create<ast::BlockStatement>(Source{})));
- PushNewStatementBlock(construct, construct->end_id, body, nullptr, nullptr);
+ auto* builder = AddStatementBuilder<LoopStatementBuilder>();
+ PushNewStatementBlock(construct, construct->end_id, nullptr,
+ [=](const ast::StatementList& stmts) {
+ builder->body =
+ create<ast::BlockStatement>(Source{}, stmts);
+ });
return success();
}
@@ -2408,13 +2490,16 @@
// A continue construct has the same depth as its associated loop
// construct. Start a continue construct.
auto* loop_candidate = LastStatement();
- auto* loop = loop_candidate->As<ast::LoopStatement>();
+ auto* loop = loop_candidate->As<LoopStatementBuilder>();
if (loop == nullptr) {
return Fail() << "internal error: starting continue construct, "
"expected loop on top of stack";
}
- PushNewStatementBlock(construct, construct->end_id, loop->continuing(),
- nullptr, nullptr);
+ PushNewStatementBlock(construct, construct->end_id, nullptr,
+ [=](const ast::StatementList& stmts) {
+ loop->continuing =
+ create<ast::BlockStatement>(Source{}, stmts);
+ });
return success();
}
@@ -2502,7 +2587,7 @@
AddStatement(MakeSimpleIf(cond, true_branch, false_branch));
if (!flow_guard.empty()) {
- PushGuard(flow_guard, statements_stack_.back().end_id_);
+ PushGuard(flow_guard, statements_stack_.back().EndId());
}
return true;
}
@@ -2600,17 +2685,18 @@
}
ast::ElseStatementList else_stmts;
if (else_stmt != nullptr) {
- auto* stmts = create<ast::BlockStatement>(Source{});
- stmts->append(else_stmt);
- else_stmts.emplace_back(
- create<ast::ElseStatement>(Source{}, nullptr, stmts));
+ ast::StatementList stmts{else_stmt};
+ else_stmts.emplace_back(create<ast::ElseStatement>(
+ Source{}, nullptr, create<ast::BlockStatement>(Source{}, stmts)));
}
- auto* if_block = create<ast::BlockStatement>(Source{});
+ ast::StatementList if_stmts;
+ if (then_stmt != nullptr) {
+ if_stmts.emplace_back(then_stmt);
+ }
+ auto* if_block = create<ast::BlockStatement>(Source{}, if_stmts);
auto* if_stmt =
create<ast::IfStatement>(Source{}, condition, if_block, else_stmts);
- if (then_stmt != nullptr) {
- if_block->append(then_stmt);
- }
+
return if_stmt;
}
@@ -4285,3 +4371,8 @@
} // namespace spirv
} // namespace reader
} // namespace tint
+
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::StatementBuilder);
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::SwitchStatementBuilder);
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::IfStatementBuilder);
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::LoopStatementBuilder);
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 5d24318..e1f0968 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -289,6 +289,29 @@
return o;
}
+/// A placeholder Statement that exists for the duration of building a
+/// StatementBlock. Once the StatementBlock is built, Build() will be called to
+/// construct the final AST node, which will be used in the place of this
+/// StatementBuilder.
+/// StatementBuilders are used to simplify construction of AST nodes that will
+/// become immutable. The builders may hold mutable state while the
+/// StatementBlock is being constructed, which becomes an immutable node on
+/// StatementBlock::Finalize().
+class StatementBuilder : public Castable<StatementBuilder, ast::Statement> {
+ public:
+ /// Constructor
+ StatementBuilder() : Base(Source{}) {}
+
+ /// @param mod the ast Module to build into
+ /// @returns the build AST node
+ virtual ast::Statement* Build(ast::Module* mod) const = 0;
+
+ private:
+ bool IsValid() const override;
+ Node* Clone(ast::CloneContext*) const override;
+ void to_str(std::ostream& out, size_t indent) const override;
+};
+
/// A FunctionEmitter emits a SPIR-V function onto a Tint AST module.
class FunctionEmitter {
public:
@@ -317,10 +340,10 @@
/// @returns true if emission has failed.
bool failed() const { return !success(); }
- /// Returns the body of the function. It is the bottom of the statement
- /// stack.
+ /// Finalizes any StatementBuilders returns the body of the function.
+ /// Must only be called once, and to be used only for testing.
/// @returns the body of the function.
- const ast::BlockStatement* ast_body();
+ const ast::StatementList ast_body();
/// Records failure.
/// @returns a FailStream on which to emit diagnostics.
@@ -811,6 +834,14 @@
/// @returns a pointer to the statement.
ast::Statement* AddStatement(ast::Statement* statement);
+ template <typename T, typename... ARGS>
+ T* AddStatementBuilder(ARGS&&... args) {
+ // The builder is temporary and is not owned by the module.
+ auto builder = new T(std::forward<ARGS>(args)...);
+ AddStatement(builder);
+ return builder;
+ }
+
/// Returns the source record for the given instruction.
/// @param inst the SPIR-V instruction
/// @return the Source record, or a default one
@@ -819,43 +850,79 @@
/// @returns the last statetment in the top of the statement stack.
ast::Statement* LastStatement();
- using CompletionAction = std::function<void()>;
+ using CompletionAction = std::function<void(const ast::StatementList&)>;
// A StatementBlock represents a braced-list of statements while it is being
// constructed.
- struct StatementBlock {
+ class StatementBlock {
+ public:
StatementBlock(const Construct* construct,
uint32_t end_id,
CompletionAction completion_action,
- ast::BlockStatement* statements,
ast::CaseStatementList* cases);
StatementBlock(StatementBlock&&);
~StatementBlock();
- // The construct to which this construct constributes.
- const Construct* construct_;
- // The ID of the block at which the completion action should be triggerd
- // and this statement block discarded. This is often the |end_id| of
- // |construct| itself.
- uint32_t end_id_;
- // The completion action finishes processing this statement block.
- CompletionAction completion_action_;
+ StatementBlock(const StatementBlock&) = delete;
+ StatementBlock& operator=(const StatementBlock&) = delete;
- // Only one of |statements| or |cases| is active.
+ /// Replaces any StatementBuilders with the built result, and calls the
+ /// completion callback (if set). Must only be called once, after all
+ /// statements have been added with Add().
+ /// @param mod the module
+ void Finalize(ast::Module* mod);
- // The list of statements being built, if this construct is not a switch.
- ast::BlockStatement* statements_ = nullptr;
- // The list of switch cases being built, if this construct is a switch.
+ /// Add() adds `statement` to the block.
+ /// Add() must not be called after calling Finalize().
+ void Add(ast::Statement* statement);
+
+ /// @param construct the construct which this construct constributes to
+ void SetConstruct(const Construct* construct) { construct_ = construct; }
+
+ /// @return the construct to which this construct constributes
+ const Construct* Construct() const { return construct_; }
+
+ /// @return the ID of the block at which the completion action should be
+ /// triggerd and this statement block discarded. This is often the `end_id`
+ /// of `construct` itself.
+ uint32_t EndId() const { return end_id_; }
+
+ /// @return the completion action finishes processing this statement block
+ CompletionAction CompletionAction() const { return completion_action_; }
+
+ /// @return the list of statements being built, if this construct is not a
+ /// switch.
+ const ast::StatementList& Statements() const { return statements_; }
+
+ /// @return the list of switch cases being built, if this construct is a
+ /// switch
+ ast::CaseStatementList* Cases() const { return cases_; }
+
+ private:
+ /// The construct to which this construct constributes.
+ const spirv::Construct* construct_;
+ /// The ID of the block at which the completion action should be triggerd
+ /// and this statement block discarded. This is often the `end_id` of
+ /// `construct` itself.
+ uint32_t const end_id_;
+ /// The completion action finishes processing this statement block.
+ FunctionEmitter::CompletionAction const completion_action_;
+
+ // Only one of `statements` or `cases` is active.
+
+ /// The list of statements being built, if this construct is not a switch.
+ ast::StatementList statements_;
+ /// The list of switch cases being built, if this construct is a switch.
ast::CaseStatementList* cases_ = nullptr;
+ /// True if Finalize() has been called.
+ bool finalized_ = false;
};
/// Pushes an empty statement block onto the statements stack.
- /// @param block the block to push into
/// @param cases the case list to push into
/// @param action the completion action for this block
void PushNewStatementBlock(const Construct* construct,
uint32_t end_id,
- ast::BlockStatement* block,
ast::CaseStatementList* cases,
CompletionAction action);
@@ -887,6 +954,8 @@
return ast_module_.create<T>(std::forward<ARGS>(args)...);
}
+ using StatementsStack = std::vector<StatementBlock>;
+
ParserImpl& parser_impl_;
ast::Module& ast_module_;
spvtools::opt::IRContext& ir_context_;
@@ -901,9 +970,9 @@
// A stack of statement lists. Each list is contained in a construct in
// the next deeper element of stack. The 0th entry represents the statements
// for the entire function. This stack is never empty.
- // The |construct| member for the 0th element is only valid during the
+ // The `construct` member for the 0th element is only valid during the
// lifetime of the EmitFunctionBodyStatements method.
- std::vector<StatementBlock> statements_stack_;
+ StatementsStack statements_stack_;
// The set of IDs that have already had an identifier name generated for it.
std::unordered_set<uint32_t> identifier_values_;
diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc
index 2c4d356..b8f004b 100644
--- a/src/reader/spirv/function_composite_test.cc
+++ b/src/reader/spirv/function_composite_test.cc
@@ -479,7 +479,8 @@
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
- EXPECT_THAT(ToString(p->get_module(), fe.ast_body()), HasSubstr(R"(
+ auto got = fe.ast_body();
+ EXPECT_THAT(ToString(p->get_module(), got), HasSubstr(R"(
VariableConst{
x_2
none
@@ -491,8 +492,8 @@
}
}
})"))
- << ToString(p->get_module(), fe.ast_body());
- EXPECT_THAT(ToString(p->get_module(), fe.ast_body()), HasSubstr(R"(
+ << ToString(p->get_module(), got);
+ EXPECT_THAT(ToString(p->get_module(), got), HasSubstr(R"(
VariableConst{
x_4
none
@@ -504,7 +505,7 @@
}
}
})"))
- << ToString(p->get_module(), fe.ast_body());
+ << ToString(p->get_module(), got);
}
TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) {
diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h
index 0342e47..e8283c6 100644
--- a/src/reader/spirv/parser_impl_test_helper.h
+++ b/src/reader/spirv/parser_impl_test_helper.h
@@ -57,13 +57,14 @@
// Use this form when you don't need to template any further.
using SpvParserTest = SpvParserTestBase<::testing::Test>;
-/// Returns the string dump of a function body.
-/// @param body the statement in the body
-/// @returnss the string dump of a function body.
+/// Returns the string dump of a statement list.
+/// @param mod the module
+/// @param stmts the statement list
+/// @returns the string dump of a statement list.
inline std::string ToString(const ast::Module& mod,
- const ast::BlockStatement* body) {
+ const ast::StatementList& stmts) {
std::ostringstream outs;
- for (const auto* stmt : *body) {
+ for (const auto* stmt : stmts) {
stmt->to_str(outs, 0);
}
return Demangler().Demangle(mod, outs.str());