sem: Split BlockStatement up into subclasses

Allows us to put block-type-specific data on the specific subtype instead of littering a common base class

Change-Id: If4a327a8ee52d5911308f38b518ec07c3ceebcb7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51367
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 2450f14..d05d058 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1251,8 +1251,8 @@
           << "Resolver::Function() called with a current statement";
       return false;
     }
-    sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>(
-        func->body(), nullptr, sem::BlockStatement::Type::kGeneric);
+    auto* sem_block =
+        builder_->create<sem::BlockStatement>(func->body(), nullptr);
     builder_->Sem().Add(func->body(), sem_block);
     TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
     if (!BlockScope(func->body(),
@@ -1379,8 +1379,7 @@
   sem::Statement* sem_statement;
   if (stmt->As<ast::BlockStatement>()) {
     sem_statement = builder_->create<sem::BlockStatement>(
-        stmt->As<ast::BlockStatement>(), current_statement_,
-        sem::BlockStatement::Type::kGeneric);
+        stmt->As<ast::BlockStatement>(), current_statement_);
   } else {
     sem_statement = builder_->create<sem::Statement>(stmt, current_statement_);
   }
@@ -1403,9 +1402,8 @@
     return BlockScope(b, [&] { return Statements(b->list()); });
   }
   if (stmt->Is<ast::BreakStatement>()) {
-    if (!current_block_->FindFirstParent(sem::BlockStatement::Type::kLoop) &&
-        !current_block_->FindFirstParent(
-            sem::BlockStatement::Type::kSwitchCase)) {
+    if (!current_block_->FindFirstParent<sem::LoopBlockStatement>() &&
+        !current_block_->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
       diagnostics_.add_error("break statement must be in a loop or switch case",
                              stmt->source());
       return false;
@@ -1422,9 +1420,9 @@
   if (stmt->Is<ast::ContinueStatement>()) {
     // Set if we've hit the first continue statement in our parent loop
     if (auto* loop_block =
-            current_block_->FindFirstParent(sem::BlockStatement::Type::kLoop)) {
+            current_block_->FindFirstParent<sem::LoopBlockStatement>()) {
       if (loop_block->FirstContinue() == size_t(~0)) {
-        const_cast<sem::BlockStatement*>(loop_block)
+        const_cast<sem::LoopBlockStatement*>(loop_block)
             ->SetFirstContinue(loop_block->Decls().size());
       }
     } else {
@@ -1468,8 +1466,8 @@
   for (auto* sel : stmt->selectors()) {
     Mark(sel);
   }
-  sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>(
-      stmt->body(), current_statement_, sem::BlockStatement::Type::kSwitchCase);
+  auto* sem_block = builder_->create<sem::SwitchCaseBlockStatement>(
+      stmt->body(), current_statement_);
   builder_->Sem().Add(stmt->body(), sem_block);
   TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
   return BlockScope(stmt->body(),
@@ -1492,8 +1490,8 @@
 
   Mark(stmt->body());
   {
-    sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>(
-        stmt->body(), current_statement_, sem::BlockStatement::Type::kGeneric);
+    auto* sem_block =
+        builder_->create<sem::BlockStatement>(stmt->body(), current_statement_);
     builder_->Sem().Add(stmt->body(), sem_block);
     TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
     if (!BlockScope(stmt->body(),
@@ -1525,9 +1523,8 @@
     }
     Mark(else_stmt->body());
     {
-      sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>(
-          else_stmt->body(), current_statement_,
-          sem::BlockStatement::Type::kGeneric);
+      auto* sem_block = builder_->create<sem::BlockStatement>(
+          else_stmt->body(), current_statement_);
       builder_->Sem().Add(else_stmt->body(), sem_block);
       TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
       if (!BlockScope(else_stmt->body(),
@@ -1546,8 +1543,8 @@
   // validation. Also, we need to set their types differently.
   Mark(stmt->body());
 
-  auto* sem_block_body = builder_->create<sem::BlockStatement>(
-      stmt->body(), current_statement_, sem::BlockStatement::Type::kLoop);
+  auto* sem_block_body = builder_->create<sem::LoopBlockStatement>(
+      stmt->body(), current_statement_);
   builder_->Sem().Add(stmt->body(), sem_block_body);
   TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_body);
   return BlockScope(stmt->body(), [&] {
@@ -1558,9 +1555,9 @@
       Mark(stmt->continuing());
     }
     if (stmt->has_continuing()) {
-      auto* sem_block_continuing = builder_->create<sem::BlockStatement>(
-          stmt->continuing(), current_statement_,
-          sem::BlockStatement::Type::kLoopContinuing);
+      auto* sem_block_continuing =
+          builder_->create<sem::LoopContinuingBlockStatement>(
+              stmt->continuing(), current_statement_);
       builder_->Sem().Add(stmt->continuing(), sem_block_continuing);
       TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_continuing);
       if (!BlockScope(stmt->continuing(),
@@ -1909,10 +1906,11 @@
       // If identifier is part of a loop continuing block, make sure it
       // doesn't refer to a variable that is bypassed by a continue statement
       // in the loop's body block.
-      if (auto* continuing_block = current_block_->FindFirstParent(
-              sem::BlockStatement::Type::kLoopContinuing)) {
+      if (auto* continuing_block =
+              current_block_
+                  ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
         auto* loop_block =
-            continuing_block->FindFirstParent(sem::BlockStatement::Type::kLoop);
+            continuing_block->FindFirstParent<sem::LoopBlockStatement>();
         if (loop_block->FirstContinue() != size_t(~0)) {
           auto& decls = loop_block->Decls();
           // If our identifier is in loop_block->decls, make sure its index is
diff --git a/src/sem/block_statement.cc b/src/sem/block_statement.cc
index 8d63f0d..bc4c845 100644
--- a/src/sem/block_statement.cc
+++ b/src/sem/block_statement.cc
@@ -17,35 +17,47 @@
 #include "src/ast/block_statement.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::sem::BlockStatement);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::LoopBlockStatement);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::LoopContinuingBlockStatement);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchCaseBlockStatement);
 
 namespace tint {
 namespace sem {
 
 BlockStatement::BlockStatement(const ast::BlockStatement* declaration,
-                               const Statement* parent,
-                               Type type)
-    : Base(declaration, parent), type_(type) {}
+                               const Statement* parent)
+    : Base(declaration, parent) {}
 
 BlockStatement::~BlockStatement() = default;
 
-const BlockStatement* BlockStatement::FindFirstParent(
-    BlockStatement::Type ty) const {
-  return FindFirstParent(
-      [ty](auto* block_info) { return block_info->type_ == ty; });
-}
-
 const ast::BlockStatement* BlockStatement::Declaration() const {
   return Base::Declaration()->As<ast::BlockStatement>();
 }
 
-void BlockStatement::SetFirstContinue(size_t first_continue) {
-  TINT_ASSERT(type_ == Type::kLoop);
-  first_continue_ = first_continue;
-}
-
 void BlockStatement::AddDecl(ast::Variable* var) {
   decls_.push_back(var);
 }
 
+LoopBlockStatement::LoopBlockStatement(const ast::BlockStatement* declaration,
+                                       const Statement* parent)
+    : Base(declaration, parent) {}
+LoopBlockStatement::~LoopBlockStatement() = default;
+
+void LoopBlockStatement::SetFirstContinue(size_t first_continue) {
+  first_continue_ = first_continue;
+}
+
+LoopContinuingBlockStatement::LoopContinuingBlockStatement(
+    const ast::BlockStatement* declaration,
+    const Statement* parent)
+    : Base(declaration, parent) {}
+LoopContinuingBlockStatement::~LoopContinuingBlockStatement() = default;
+
+SwitchCaseBlockStatement::SwitchCaseBlockStatement(
+    const ast::BlockStatement* declaration,
+    const Statement* parent)
+    : Base(declaration, parent) {}
+SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default;
+
 }  // namespace sem
 }  // namespace tint
diff --git a/src/sem/block_statement.h b/src/sem/block_statement.h
index 5e0f481..7989df5 100644
--- a/src/sem/block_statement.h
+++ b/src/sem/block_statement.h
@@ -17,7 +17,6 @@
 
 #include <vector>
 
-#include "src/debug.h"
 #include "src/sem/statement.h"
 
 namespace tint {
@@ -34,15 +33,11 @@
 /// declared in the block.
 class BlockStatement : public Castable<BlockStatement, Statement> {
  public:
-  enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
-
   /// Constructor
   /// @param declaration the AST node for this block statement
   /// @param parent the owning statement
-  /// @param type the type of block this is
   BlockStatement(const ast::BlockStatement* declaration,
-                 const Statement* parent,
-                 Type type);
+                 const Statement* parent);
 
   /// Destructor
   ~BlockStatement() override;
@@ -52,7 +47,7 @@
   const ast::BlockStatement* Declaration() const;
 
   /// @returns the closest enclosing block that satisfies the given predicate,
-  ///          which may be the block itself, or nullptr if no match is found
+  /// which may be the block itself, or nullptr if no match is found
   /// @param pred a predicate that the resulting block must satisfy
   template <typename Pred>
   const BlockStatement* FindFirstParent(Pred&& pred) const {
@@ -63,21 +58,47 @@
     return curr;
   }
 
-  /// @returns the closest enclosing block that matches the given type, which
-  ///          may be the block itself, or nullptr if no match is found
-  /// @param ty the type of block to be searched for
-  const BlockStatement* FindFirstParent(BlockStatement::Type ty) const;
+  /// @returns the statement itself if it matches the template type `T`,
+  /// otherwise the nearest enclosing block that matches `T`, or nullptr if
+  /// there is none.
+  template <typename T>
+  const T* FindFirstParent() const {
+    const BlockStatement* curr = this;
+    while (curr) {
+      if (auto* block = curr->As<T>()) {
+        return block;
+      }
+      curr = curr->Block();
+    }
+    return nullptr;
+  }
 
   /// @returns the declarations associated with this block
   const std::vector<const ast::Variable*>& Decls() const { return decls_; }
 
-  /// Requires that this is a loop block.
+  /// Associates a declaration with this block.
+  /// @param var a variable declaration to be added to the block
+  void AddDecl(ast::Variable* var);
+
+ private:
+  std::vector<const ast::Variable*> decls_;
+};
+
+/// Holds semantic information about a loop block
+class LoopBlockStatement : public Castable<LoopBlockStatement, BlockStatement> {
+ public:
+  /// Constructor
+  /// @param declaration the AST node for this block statement
+  /// @param parent the owning statement
+  LoopBlockStatement(const ast::BlockStatement* declaration,
+                     const Statement* parent);
+
+  /// Destructor
+  ~LoopBlockStatement() override;
+
   /// @returns the index of the first variable declared after the first continue
-  ///          statement
-  size_t FirstContinue() const {
-    TINT_ASSERT(type_ == Type::kLoop);
-    return first_continue_;
-  }
+  /// statement
+  size_t FirstContinue() const { return first_continue_; }
 
   /// Requires that this is a loop block.
   /// Allows the resolver to set the index of the first variable declared after
@@ -85,20 +106,41 @@
   /// @param first_continue index of the relevant variable
   void SetFirstContinue(size_t first_continue);
 
-  /// Allows the resolver to associate a declaration with this block.
-  /// @param var a variable declaration to be added to the block
-  void AddDecl(ast::Variable* var);
-
  private:
-  Type const type_;
-  std::vector<const ast::Variable*> decls_;
-
   // first_continue is set to the index of the first variable in decls
   // declared after the first continue statement in a loop block, if any.
   constexpr static size_t kNoContinue = size_t(~0);
   size_t first_continue_ = kNoContinue;
 };
 
+/// Holds semantic information about a loop continuing block
+class LoopContinuingBlockStatement
+    : public Castable<LoopContinuingBlockStatement, BlockStatement> {
+ public:
+  /// Constructor
+  /// @param declaration the AST node for this block statement
+  /// @param parent the owning statement
+  LoopContinuingBlockStatement(const ast::BlockStatement* declaration,
+                               const Statement* parent);
+
+  /// Destructor
+  ~LoopContinuingBlockStatement() override;
+};
+
+/// Holds semantic information about a switch case block
+class SwitchCaseBlockStatement
+    : public Castable<SwitchCaseBlockStatement, BlockStatement> {
+ public:
+  /// Constructor
+  /// @param declaration the AST node for this block statement
+  /// @param parent the owning statement
+  SwitchCaseBlockStatement(const ast::BlockStatement* declaration,
+                           const Statement* parent);
+
+  /// Destructor
+  ~SwitchCaseBlockStatement() override;
+};
+
 }  // namespace sem
 }  // namespace tint