reader/spirv: Remove use of BlockStatement::append()

Introduce `StatementBuilder`s , which may hold mutable state, before being converted into the immutable AST node on completion of the `BlockStatement`.

Bug: tint:396
Bug: tint:390
Change-Id: I0381c4ae7948be0de02bc13e54e0037a72baaf0c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35506
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/block_statement.cc b/src/ast/block_statement.cc
index b0e5f4b..d952a22 100644
--- a/src/ast/block_statement.cc
+++ b/src/ast/block_statement.cc
@@ -24,6 +24,10 @@
 
 BlockStatement::BlockStatement(const Source& source) : Base(source) {}
 
+BlockStatement::BlockStatement(const Source& source,
+                               const StatementList& statements)
+    : Base(source), statements_(std::move(statements)) {}
+
 BlockStatement::BlockStatement(BlockStatement&&) = default;
 
 BlockStatement::~BlockStatement() = default;
diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h
index faa7aaa..730f7bf 100644
--- a/src/ast/block_statement.h
+++ b/src/ast/block_statement.h
@@ -30,6 +30,10 @@
   /// Constructor
   /// @param source the block statement source
   explicit BlockStatement(const Source& source);
+  /// Constructor
+  /// @param source the block statement source
+  /// @param statements the block statements
+  BlockStatement(const Source& source, const StatementList& statements);
   /// Move constructor
   BlockStatement(BlockStatement&&);
   ~BlockStatement() override;
diff --git a/src/ast/statement.h b/src/ast/statement.h
index baff8d9..14808d5 100644
--- a/src/ast/statement.h
+++ b/src/ast/statement.h
@@ -42,6 +42,9 @@
   Statement(const Statement&) = delete;
 };
 
+/// A list of statements
+using StatementList = std::vector<Statement*>;
+
 }  // namespace ast
 }  // namespace tint
 
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index ce8deec..effad4f 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -608,6 +608,70 @@
   std::unordered_set<uint32_t> visited_;
 };
 
+/// A StatementBuilder for ast::SwitchStatment
+/// @see StatementBuilder
+struct SwitchStatementBuilder
+    : public Castable<SwitchStatementBuilder, StatementBuilder> {
+  /// Constructor
+  /// @param cond the switch statement condition
+  explicit SwitchStatementBuilder(ast::Expression* cond) : condition(cond) {}
+
+  /// @param mod the ast Module to build into
+  /// @returns the built ast::SwitchStatement
+  ast::SwitchStatement* Build(ast::Module* mod) const override {
+    // We've listed cases in reverse order in the switch statement.
+    // Reorder them to match the presentation order in WGSL.
+    auto reversed_cases = cases;
+    std::reverse(reversed_cases.begin(), reversed_cases.end());
+
+    return mod->create<ast::SwitchStatement>(Source{}, condition,
+                                             reversed_cases);
+  }
+
+  /// Switch statement condition
+  ast::Expression* const condition;
+  /// Switch statement cases
+  ast::CaseStatementList cases;
+};
+
+/// A StatementBuilder for ast::IfStatement
+/// @see StatementBuilder
+struct IfStatementBuilder
+    : public Castable<IfStatementBuilder, StatementBuilder> {
+  /// Constructor
+  /// @param c the if-statement condition
+  explicit IfStatementBuilder(ast::Expression* c) : cond(c) {}
+
+  /// @param mod the ast Module to build into
+  /// @returns the built ast::IfStatement
+  ast::IfStatement* Build(ast::Module* mod) const override {
+    return mod->create<ast::IfStatement>(Source{}, cond, body, else_stmts);
+  }
+
+  /// If-statement condition
+  ast::Expression* const cond;
+  /// If-statement block body
+  ast::BlockStatement* body = nullptr;
+  /// Optional if-statement else statements
+  ast::ElseStatementList else_stmts;
+};
+
+/// A StatementBuilder for ast::LoopStatement
+/// @see StatementBuilder
+struct LoopStatementBuilder
+    : public Castable<LoopStatementBuilder, StatementBuilder> {
+  /// @param mod the ast Module to build into
+  /// @returns the built ast::LoopStatement
+  ast::LoopStatement* Build(ast::Module* mod) const override {
+    return mod->create<ast::LoopStatement>(Source{}, body, continuing);
+  }
+
+  /// Loop-statement block body
+  ast::BlockStatement* body = nullptr;
+  /// Loop-statement continuing body
+  ast::BlockStatement* continuing = nullptr;
+};
+
 }  // namespace
 
 BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb)
@@ -622,6 +686,17 @@
 
 DefInfo::~DefInfo() = default;
 
+bool StatementBuilder::IsValid() const {
+  return true;
+}
+ast::Node* StatementBuilder::Clone(ast::CloneContext*) const {
+  return nullptr;
+}
+void StatementBuilder::to_str(std::ostream& out, size_t indent) const {
+  make_indent(out, indent);
+  out << "StatementBuilder" << std::endl;
+}
+
 FunctionEmitter::FunctionEmitter(ParserImpl* pi,
                                  const spvtools::opt::Function& function,
                                  const EntryPointInfo* ep_info)
@@ -636,7 +711,7 @@
       function_(function),
       i32_(ast_module_.create<ast::type::I32>()),
       ep_info_(ep_info) {
-  PushNewStatementBlock(nullptr, 0, nullptr, nullptr, nullptr);
+  PushNewStatementBlock(nullptr, 0, nullptr, nullptr);
 }
 
 FunctionEmitter::FunctionEmitter(ParserImpl* pi,
@@ -646,32 +721,62 @@
 FunctionEmitter::~FunctionEmitter() = default;
 
 FunctionEmitter::StatementBlock::StatementBlock(
-    const Construct* construct,
+    const spirv::Construct* construct,
     uint32_t end_id,
-    CompletionAction completion_action,
-    ast::BlockStatement* statements,
+    FunctionEmitter::CompletionAction completion_action,
     ast::CaseStatementList* cases)
     : construct_(construct),
       end_id_(end_id),
       completion_action_(completion_action),
-      statements_(statements),
       cases_(cases) {}
 
-FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&&) = default;
+FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other)
+    : construct_(other.construct_),
+      end_id_(other.end_id_),
+      completion_action_(std::move(other.completion_action_)),
+      statements_(std::move(other.statements_)),
+      cases_(std::move(other.cases_)) {
+  other.statements_.clear();
+}
 
-FunctionEmitter::StatementBlock::~StatementBlock() = default;
+FunctionEmitter::StatementBlock::~StatementBlock() {
+  if (!finalized_) {
+    // Delete builders that have not been built with Finalize()
+    for (auto* statement : statements_) {
+      if (auto* builder = statement->As<StatementBuilder>()) {
+        delete builder;
+      }
+    }
+  }
+}
+
+void FunctionEmitter::StatementBlock::Finalize(ast::Module* mod) {
+  assert(!finalized_ /* Finalize() must only be called once */);
+  for (size_t i = 0; i < statements_.size(); i++) {
+    if (auto* builder = statements_[i]->As<StatementBuilder>()) {
+      statements_[i] = builder->Build(mod);
+      delete builder;
+    }
+  }
+
+  if (completion_action_ != nullptr) {
+    completion_action_(statements_);
+  }
+
+  finalized_ = true;
+}
+
+void FunctionEmitter::StatementBlock::Add(ast::Statement* statement) {
+  assert(!finalized_ /* Add() must not be called after Finalize() */);
+  statements_.emplace_back(statement);
+}
 
 void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
                                             uint32_t end_id,
-                                            ast::BlockStatement* block,
                                             ast::CaseStatementList* cases,
                                             CompletionAction action) {
-  if (block == nullptr) {
-    block = create<ast::BlockStatement>(Source{});
-  }
-
   statements_stack_.emplace_back(
-      StatementBlock{construct, end_id, action, block, cases});
+      StatementBlock{construct, end_id, action, cases});
 }
 
 void FunctionEmitter::PushGuard(const std::string& guard_name,
@@ -685,10 +790,12 @@
 
   auto* cond = create<ast::IdentifierExpression>(
       Source{}, ast_module_.RegisterSymbol(guard_name), guard_name);
-  auto* body = create<ast::BlockStatement>(Source{});
-  AddStatement(
-      create<ast::IfStatement>(Source{}, cond, body, ast::ElseStatementList{}));
-  PushNewStatementBlock(top.construct_, end_id, body, nullptr, nullptr);
+  auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
+
+  PushNewStatementBlock(
+      top.Construct(), end_id, nullptr, [=](const ast::StatementList& stmts) {
+        builder->body = create<ast::BlockStatement>(Source{}, stmts);
+      });
 }
 
 void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
@@ -696,31 +803,36 @@
   const auto& top = statements_stack_.back();
 
   auto* cond = MakeTrue(Source{});
-  auto* body = create<ast::BlockStatement>(Source{});
-  AddStatement(
-      create<ast::IfStatement>(Source{}, cond, body, ast::ElseStatementList{}));
-  PushNewStatementBlock(top.construct_, end_id, body, nullptr, nullptr);
+  auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
+
+  PushNewStatementBlock(
+      top.Construct(), end_id, nullptr, [=](const ast::StatementList& stmts) {
+        builder->body = create<ast::BlockStatement>(Source{}, stmts);
+      });
 }
 
-const ast::BlockStatement* FunctionEmitter::ast_body() {
+const ast::StatementList FunctionEmitter::ast_body() {
   assert(!statements_stack_.empty());
-  return statements_stack_[0].statements_;
+  auto& entry = statements_stack_[0];
+  entry.Finalize(&ast_module_);
+  return entry.Statements();
 }
 
 ast::Statement* FunctionEmitter::AddStatement(ast::Statement* statement) {
   assert(!statements_stack_.empty());
   auto* result = statement;
   if (result != nullptr) {
-    statements_stack_.back().statements_->append(statement);
+    auto& block = statements_stack_.back();
+    block.Add(statement);
   }
   return result;
 }
 
 ast::Statement* FunctionEmitter::LastStatement() {
   assert(!statements_stack_.empty());
-  auto* statement_list = statements_stack_.back().statements_;
-  assert(!statement_list->empty());
-  return statement_list->last();
+  auto& statement_list = statements_stack_.back().Statements();
+  assert(!statement_list.empty());
+  return statement_list.back();
 }
 
 bool FunctionEmitter::Emit() {
@@ -748,7 +860,10 @@
                   << statements_stack_.size();
   }
 
-  auto* body = statements_stack_[0].statements_;
+  statements_stack_[0].Finalize(&ast_module_);
+
+  auto& statements = statements_stack_[0].Statements();
+  auto* body = create<ast::BlockStatement>(Source{}, statements);
   ast_module_.AddFunction(
       create<ast::Function>(decl.source, ast_module_.RegisterSymbol(decl.name),
                             decl.name, std::move(decl.params), decl.return_type,
@@ -756,7 +871,7 @@
 
   // Maintain the invariant by repopulating the one and only element.
   statements_stack_.clear();
-  PushNewStatementBlock(constructs_[0].get(), 0, nullptr, nullptr, nullptr);
+  PushNewStatementBlock(constructs_[0].get(), 0, nullptr, nullptr);
 
   return success();
 }
@@ -1935,7 +2050,7 @@
   // TODO(dneto): refactor how the first construct is created vs.
   // this statements stack entry is populated.
   assert(statements_stack_.size() == 1);
-  statements_stack_[0].construct_ = function_construct;
+  statements_stack_[0].SetConstruct(function_construct);
 
   for (auto block_id : block_order()) {
     if (!EmitBasicBlock(*GetBlockInfo(block_id))) {
@@ -1948,11 +2063,8 @@
 bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
   // Close off previous constructs.
   while (!statements_stack_.empty() &&
-         (statements_stack_.back().end_id_ == block_info.id)) {
-    StatementBlock& sb = statements_stack_.back();
-    if (sb.completion_action_ != nullptr) {
-      sb.completion_action_();
-    }
+         (statements_stack_.back().EndId() == block_info.id)) {
+    statements_stack_.back().Finalize(&ast_module_);
     statements_stack_.pop_back();
   }
   if (statements_stack_.empty()) {
@@ -1965,7 +2077,7 @@
   std::vector<const Construct*> entering_constructs;  // inner most comes first
   {
     auto* here = block_info.construct;
-    auto* const top_construct = statements_stack_.back().construct_;
+    auto* const top_construct = statements_stack_.back().Construct();
     while (here != top_construct) {
       // Only enter a construct at its header block.
       if (here->begin_id == block_info.id) {
@@ -2152,42 +2264,9 @@
   const auto condition_id =
       block_info.basic_block->terminator()->GetSingleWordInOperand(0);
   auto* cond = MakeExpression(condition_id).expr;
-  auto* body = create<ast::BlockStatement>(Source{});
 
   // Generate the code for the condition.
-  // Use the IfBuilder to create the if-statement. The IfBuilder is constructed
-  // as a std::shared_ptr and is captured by the then and else clause
-  // CompletionAction lambdas, and so will only be destructed when the last
-  // block is completed. The IfBuilder destructor constructs the IfStatement,
-  // inserting it at the current insertion point in the current
-  // ast::BlockStatement.
-  struct IfBuilder {
-    IfBuilder(ast::Module* mod,
-              StatementBlock& statement_block,
-              tint::ast::Expression* cond,
-              ast::BlockStatement* body)
-        : mod_(mod),
-          dst_block_(statement_block.statements_),
-          dst_block_insertion_point_(statement_block.statements_->size()),
-          cond_(cond),
-          body_(body) {}
-
-    ~IfBuilder() {
-      dst_block_->insert(
-          dst_block_insertion_point_,
-          mod_->create<ast::IfStatement>(Source{}, cond_, body_, else_stmts_));
-    }
-
-    ast::Module* mod_;
-    ast::BlockStatement* dst_block_;
-    size_t dst_block_insertion_point_;
-    tint::ast::Expression* cond_;
-    ast::BlockStatement* body_;
-    ast::ElseStatementList else_stmts_;
-  };
-
-  auto if_builder = std::make_shared<IfBuilder>(
-      &ast_module_, statements_stack_.back(), cond, body);
+  auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
 
   // Compute the block IDs that should end the then-clause and the else-clause.
 
@@ -2225,17 +2304,16 @@
 
   // Push statement blocks for the then-clause and the else-clause.
   // But make sure we do it in the right order.
-  auto push_else = [this, if_builder, else_end, construct]() {
+  auto push_else = [this, builder, else_end, construct]() {
     // Push the else clause onto the stack first.
-    auto* else_body = create<ast::BlockStatement>(Source{});
     PushNewStatementBlock(
-        construct, else_end, else_body, nullptr,
-        [this, if_builder, else_body]() {
+        construct, else_end, nullptr, [=](const ast::StatementList& stmts) {
           // Only set the else-clause if there are statements to fill it.
-          if (!else_body->empty()) {
+          if (!stmts.empty()) {
             // The "else" consists of the statement list from the top of
             // statements stack, without an elseif condition.
-            if_builder->else_stmts_.emplace_back(
+            auto* else_body = create<ast::BlockStatement>(Source{}, stmts);
+            builder->else_stmts.emplace_back(
                 create<ast::ElseStatement>(Source{}, nullptr, else_body));
           }
         });
@@ -2275,7 +2353,10 @@
     }
 
     // Push the then clause onto the stack.
-    PushNewStatementBlock(construct, then_end, body, nullptr, [if_builder] {});
+    PushNewStatementBlock(
+        construct, then_end, nullptr, [=](const ast::StatementList& stmts) {
+          builder->body = create<ast::BlockStatement>(Source{}, stmts);
+        });
   }
 
   return success();
@@ -2293,14 +2374,11 @@
   auto selector = MakeExpression(selector_id);
 
   // First, push the statement block for the entire switch.
-  ast::CaseStatementList case_list;
-  auto* swch = create<ast::SwitchStatement>(Source{}, selector.expr, case_list);
-  AddStatement(swch)->As<ast::SwitchStatement>();
+  auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);
 
   // Grab a pointer to the case list.  It will get buried in the statement block
   // stack.
-  auto* cases = &(swch->body());
-  PushNewStatementBlock(construct, construct->end_id, nullptr, cases, nullptr);
+  PushNewStatementBlock(construct, construct->end_id, &swch->cases, nullptr);
 
   // We will push statement-blocks onto the stack to gather the statements in
   // the default clause and cases clauses. Determine the list of blocks
@@ -2367,21 +2445,27 @@
     const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id
                                                       : construct->end_id;
 
-    // Create the case clause.  Temporarily put it in the wrong order
-    // on the case statement list.
-    auto* body = create<ast::BlockStatement>(Source{});
-    cases->emplace_back(create<ast::CaseStatement>(Source{}, selectors, body));
-
-    PushNewStatementBlock(construct, end_id, body, nullptr, nullptr);
+    // Reserve the case clause slot in swch->cases, push the new statement block
+    // for the case, and fill the case clause once the block is generated.
+    auto case_idx = swch->cases.size();
+    swch->cases.emplace_back(nullptr);
+    PushNewStatementBlock(
+        construct, end_id, nullptr, [=](const ast::StatementList& stmts) {
+          auto* body = create<ast::BlockStatement>(Source{}, stmts);
+          swch->cases[case_idx] =
+              create<ast::CaseStatement>(Source{}, selectors, body);
+        });
 
     if ((default_info == clause_heads[i]) && has_selectors &&
         construct->ContainsPos(default_info->pos)) {
       // Generate a default clause with a just fallthrough.
-      auto* stmts = create<ast::BlockStatement>(Source{});
-      stmts->append(create<ast::FallthroughStatement>(Source{}));
+      auto* stmts = create<ast::BlockStatement>(
+          Source{}, ast::StatementList{
+                        create<ast::FallthroughStatement>(Source{}),
+                    });
       auto* case_stmt =
           create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts);
-      cases->emplace_back(case_stmt);
+      swch->cases.emplace_back(case_stmt);
     }
 
     if (i == 0) {
@@ -2389,18 +2473,16 @@
     }
   }
 
-  // We've listed cases in reverse order in the switch statement. Reorder them
-  // to match the presentation order in WGSL.
-  std::reverse(cases->begin(), cases->end());
-
   return success();
 }
 
 bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
-  auto* body = create<ast::BlockStatement>(Source{});
-  AddStatement(create<ast::LoopStatement>(
-      Source{}, body, create<ast::BlockStatement>(Source{})));
-  PushNewStatementBlock(construct, construct->end_id, body, nullptr, nullptr);
+  auto* builder = AddStatementBuilder<LoopStatementBuilder>();
+  PushNewStatementBlock(construct, construct->end_id, nullptr,
+                        [=](const ast::StatementList& stmts) {
+                          builder->body =
+                              create<ast::BlockStatement>(Source{}, stmts);
+                        });
   return success();
 }
 
@@ -2408,13 +2490,16 @@
   // A continue construct has the same depth as its associated loop
   // construct. Start a continue construct.
   auto* loop_candidate = LastStatement();
-  auto* loop = loop_candidate->As<ast::LoopStatement>();
+  auto* loop = loop_candidate->As<LoopStatementBuilder>();
   if (loop == nullptr) {
     return Fail() << "internal error: starting continue construct, "
                      "expected loop on top of stack";
   }
-  PushNewStatementBlock(construct, construct->end_id, loop->continuing(),
-                        nullptr, nullptr);
+  PushNewStatementBlock(construct, construct->end_id, nullptr,
+                        [=](const ast::StatementList& stmts) {
+                          loop->continuing =
+                              create<ast::BlockStatement>(Source{}, stmts);
+                        });
 
   return success();
 }
@@ -2502,7 +2587,7 @@
 
       AddStatement(MakeSimpleIf(cond, true_branch, false_branch));
       if (!flow_guard.empty()) {
-        PushGuard(flow_guard, statements_stack_.back().end_id_);
+        PushGuard(flow_guard, statements_stack_.back().EndId());
       }
       return true;
     }
@@ -2600,17 +2685,18 @@
   }
   ast::ElseStatementList else_stmts;
   if (else_stmt != nullptr) {
-    auto* stmts = create<ast::BlockStatement>(Source{});
-    stmts->append(else_stmt);
-    else_stmts.emplace_back(
-        create<ast::ElseStatement>(Source{}, nullptr, stmts));
+    ast::StatementList stmts{else_stmt};
+    else_stmts.emplace_back(create<ast::ElseStatement>(
+        Source{}, nullptr, create<ast::BlockStatement>(Source{}, stmts)));
   }
-  auto* if_block = create<ast::BlockStatement>(Source{});
+  ast::StatementList if_stmts;
+  if (then_stmt != nullptr) {
+    if_stmts.emplace_back(then_stmt);
+  }
+  auto* if_block = create<ast::BlockStatement>(Source{}, if_stmts);
   auto* if_stmt =
       create<ast::IfStatement>(Source{}, condition, if_block, else_stmts);
-  if (then_stmt != nullptr) {
-    if_block->append(then_stmt);
-  }
+
   return if_stmt;
 }
 
@@ -4285,3 +4371,8 @@
 }  // namespace spirv
 }  // namespace reader
 }  // namespace tint
+
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::StatementBuilder);
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::SwitchStatementBuilder);
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::IfStatementBuilder);
+TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::LoopStatementBuilder);
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 5d24318..e1f0968 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -289,6 +289,29 @@
   return o;
 }
 
+/// A placeholder Statement that exists for the duration of building a
+/// StatementBlock. Once the StatementBlock is built, Build() will be called to
+/// construct the final AST node, which will be used in the place of this
+/// StatementBuilder.
+/// StatementBuilders are used to simplify construction of AST nodes that will
+/// become immutable. The builders may hold mutable state while the
+/// StatementBlock is being constructed, which becomes an immutable node on
+/// StatementBlock::Finalize().
+class StatementBuilder : public Castable<StatementBuilder, ast::Statement> {
+ public:
+  /// Constructor
+  StatementBuilder() : Base(Source{}) {}
+
+  /// @param mod the ast Module to build into
+  /// @returns the build AST node
+  virtual ast::Statement* Build(ast::Module* mod) const = 0;
+
+ private:
+  bool IsValid() const override;
+  Node* Clone(ast::CloneContext*) const override;
+  void to_str(std::ostream& out, size_t indent) const override;
+};
+
 /// A FunctionEmitter emits a SPIR-V function onto a Tint AST module.
 class FunctionEmitter {
  public:
@@ -317,10 +340,10 @@
   /// @returns true if emission has failed.
   bool failed() const { return !success(); }
 
-  /// Returns the body of the function.  It is the bottom of the statement
-  /// stack.
+  /// Finalizes any StatementBuilders returns the body of the function.
+  /// Must only be called once, and to be used only for testing.
   /// @returns the body of the function.
-  const ast::BlockStatement* ast_body();
+  const ast::StatementList ast_body();
 
   /// Records failure.
   /// @returns a FailStream on which to emit diagnostics.
@@ -811,6 +834,14 @@
   /// @returns a pointer to the statement.
   ast::Statement* AddStatement(ast::Statement* statement);
 
+  template <typename T, typename... ARGS>
+  T* AddStatementBuilder(ARGS&&... args) {
+    // The builder is temporary and is not owned by the module.
+    auto builder = new T(std::forward<ARGS>(args)...);
+    AddStatement(builder);
+    return builder;
+  }
+
   /// Returns the source record for the given instruction.
   /// @param inst the SPIR-V instruction
   /// @return the Source record, or a default one
@@ -819,43 +850,79 @@
   /// @returns the last statetment in the top of the statement stack.
   ast::Statement* LastStatement();
 
-  using CompletionAction = std::function<void()>;
+  using CompletionAction = std::function<void(const ast::StatementList&)>;
 
   // A StatementBlock represents a braced-list of statements while it is being
   // constructed.
-  struct StatementBlock {
+  class StatementBlock {
+   public:
     StatementBlock(const Construct* construct,
                    uint32_t end_id,
                    CompletionAction completion_action,
-                   ast::BlockStatement* statements,
                    ast::CaseStatementList* cases);
     StatementBlock(StatementBlock&&);
     ~StatementBlock();
 
-    // The construct to which this construct constributes.
-    const Construct* construct_;
-    // The ID of the block at which the completion action should be triggerd
-    // and this statement block discarded. This is often the |end_id| of
-    // |construct| itself.
-    uint32_t end_id_;
-    // The completion action finishes processing this statement block.
-    CompletionAction completion_action_;
+    StatementBlock(const StatementBlock&) = delete;
+    StatementBlock& operator=(const StatementBlock&) = delete;
 
-    // Only one of |statements| or |cases| is active.
+    /// Replaces any StatementBuilders with the built result, and calls the
+    /// completion callback (if set). Must only be called once, after all
+    /// statements have been added with Add().
+    /// @param mod the module
+    void Finalize(ast::Module* mod);
 
-    // The list of statements being built, if this construct is not a switch.
-    ast::BlockStatement* statements_ = nullptr;
-    // The list of switch cases being built, if this construct is a switch.
+    /// Add() adds `statement` to the block.
+    /// Add() must not be called after calling Finalize().
+    void Add(ast::Statement* statement);
+
+    /// @param construct the construct which this construct constributes to
+    void SetConstruct(const Construct* construct) { construct_ = construct; }
+
+    /// @return the construct to which this construct constributes
+    const Construct* Construct() const { return construct_; }
+
+    /// @return the ID of the block at which the completion action should be
+    /// triggerd and this statement block discarded. This is often the `end_id`
+    /// of `construct` itself.
+    uint32_t EndId() const { return end_id_; }
+
+    /// @return the completion action finishes processing this statement block
+    CompletionAction CompletionAction() const { return completion_action_; }
+
+    /// @return the list of statements being built, if this construct is not a
+    /// switch.
+    const ast::StatementList& Statements() const { return statements_; }
+
+    /// @return the list of switch cases being built, if this construct is a
+    /// switch
+    ast::CaseStatementList* Cases() const { return cases_; }
+
+   private:
+    /// The construct to which this construct constributes.
+    const spirv::Construct* construct_;
+    /// The ID of the block at which the completion action should be triggerd
+    /// and this statement block discarded. This is often the `end_id` of
+    /// `construct` itself.
+    uint32_t const end_id_;
+    /// The completion action finishes processing this statement block.
+    FunctionEmitter::CompletionAction const completion_action_;
+
+    // Only one of `statements` or `cases` is active.
+
+    /// The list of statements being built, if this construct is not a switch.
+    ast::StatementList statements_;
+    /// The list of switch cases being built, if this construct is a switch.
     ast::CaseStatementList* cases_ = nullptr;
+    /// True if Finalize() has been called.
+    bool finalized_ = false;
   };
 
   /// Pushes an empty statement block onto the statements stack.
-  /// @param block the block to push into
   /// @param cases the case list to push into
   /// @param action the completion action for this block
   void PushNewStatementBlock(const Construct* construct,
                              uint32_t end_id,
-                             ast::BlockStatement* block,
                              ast::CaseStatementList* cases,
                              CompletionAction action);
 
@@ -887,6 +954,8 @@
     return ast_module_.create<T>(std::forward<ARGS>(args)...);
   }
 
+  using StatementsStack = std::vector<StatementBlock>;
+
   ParserImpl& parser_impl_;
   ast::Module& ast_module_;
   spvtools::opt::IRContext& ir_context_;
@@ -901,9 +970,9 @@
   // A stack of statement lists. Each list is contained in a construct in
   // the next deeper element of stack. The 0th entry represents the statements
   // for the entire function.  This stack is never empty.
-  // The |construct| member for the 0th element is only valid during the
+  // The `construct` member for the 0th element is only valid during the
   // lifetime of the EmitFunctionBodyStatements method.
-  std::vector<StatementBlock> statements_stack_;
+  StatementsStack statements_stack_;
 
   // The set of IDs that have already had an identifier name generated for it.
   std::unordered_set<uint32_t> identifier_values_;
diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc
index 2c4d356..b8f004b 100644
--- a/src/reader/spirv/function_composite_test.cc
+++ b/src/reader/spirv/function_composite_test.cc
@@ -479,7 +479,8 @@
   ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
   FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
   EXPECT_TRUE(fe.EmitBody()) << p->error();
-  EXPECT_THAT(ToString(p->get_module(), fe.ast_body()), HasSubstr(R"(
+  auto got = fe.ast_body();
+  EXPECT_THAT(ToString(p->get_module(), got), HasSubstr(R"(
   VariableConst{
     x_2
     none
@@ -491,8 +492,8 @@
       }
     }
   })"))
-      << ToString(p->get_module(), fe.ast_body());
-  EXPECT_THAT(ToString(p->get_module(), fe.ast_body()), HasSubstr(R"(
+      << ToString(p->get_module(), got);
+  EXPECT_THAT(ToString(p->get_module(), got), HasSubstr(R"(
   VariableConst{
     x_4
     none
@@ -504,7 +505,7 @@
       }
     }
   })"))
-      << ToString(p->get_module(), fe.ast_body());
+      << ToString(p->get_module(), got);
 }
 
 TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) {
diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h
index 0342e47..e8283c6 100644
--- a/src/reader/spirv/parser_impl_test_helper.h
+++ b/src/reader/spirv/parser_impl_test_helper.h
@@ -57,13 +57,14 @@
 // Use this form when you don't need to template any further.
 using SpvParserTest = SpvParserTestBase<::testing::Test>;
 
-/// Returns the string dump of a function body.
-/// @param body the statement in the body
-/// @returnss the string dump of a function body.
+/// Returns the string dump of a statement list.
+/// @param mod the module
+/// @param stmts the statement list
+/// @returns the string dump of a statement list.
 inline std::string ToString(const ast::Module& mod,
-                            const ast::BlockStatement* body) {
+                            const ast::StatementList& stmts) {
   std::ostringstream outs;
-  for (const auto* stmt : *body) {
+  for (const auto* stmt : stmts) {
     stmt->to_str(outs, 0);
   }
   return Demangler().Demangle(mod, outs.str());