Resolver: Validate usage of break

Also remove unused fields of Resolver (block_to_info_, block_infos_). We can put them back when they're actually needed.

Fixed: tint:190
Change-Id: I1a02a24eca7fba32b8e1120abb88040138a39c6a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44051
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index cfb3e19..da144c7 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -66,9 +66,8 @@
 Resolver::~Resolver() = default;
 
 Resolver::BlockInfo::BlockInfo(Resolver::BlockInfo::Type ty,
-                               Resolver::BlockInfo* p,
-                               const ast::BlockStatement* b)
-    : type(ty), parent(p), block(b) {}
+                               Resolver::BlockInfo* p)
+    : type(ty), parent(p) {}
 
 Resolver::BlockInfo::~BlockInfo() = default;
 
@@ -150,12 +149,8 @@
 }
 
 bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
-  auto* block =
-      block_infos_.Create(BlockInfo::Type::Generic, current_block_, stmt);
-  block_to_info_[stmt] = block;
-  ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
-
-  return Statements(stmt->list());
+  return BlockScope(BlockInfo::Type::kGeneric,
+                    [&] { return Statements(stmt->list()); });
 }
 
 bool Resolver::Statements(const ast::StatementList& stmts) {
@@ -219,18 +214,24 @@
     return BlockStatement(b);
   }
   if (stmt->Is<ast::BreakStatement>()) {
+    if (!current_block_->FindFirstParent(BlockInfo::Type::kLoop) &&
+        !current_block_->FindFirstParent(BlockInfo::Type::kSwitchCase)) {
+      diagnostics_.add_error("break statement must be in a loop or switch case",
+                             stmt->source());
+      return false;
+    }
     return true;
   }
   if (auto* c = stmt->As<ast::CallStatement>()) {
     return Expression(c->expr());
   }
   if (auto* c = stmt->As<ast::CaseStatement>()) {
-    return BlockStatement(c->body());
+    return CaseStatement(c);
   }
   if (stmt->Is<ast::ContinueStatement>()) {
     // Set if we've hit the first continue statement in our parent loop
     if (auto* loop_block =
-            current_block_->FindFirstParent(BlockInfo::Type::Loop)) {
+            current_block_->FindFirstParent(BlockInfo::Type::kLoop)) {
       if (loop_block->first_continue == size_t(~0)) {
         loop_block->first_continue = loop_block->decls.size();
       }
@@ -268,26 +269,20 @@
     // these would make their BlockInfo siblings as in the AST, but we want the
     // body BlockInfo to parent the continuing BlockInfo for semantics and
     // validation. Also, we need to set their types differently.
-    auto* block =
-        block_infos_.Create(BlockInfo::Type::Loop, current_block_, l->body());
-    block_to_info_[l->body()] = block;
-    ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
-
-    if (!Statements(l->body()->list())) {
-      return false;
-    }
-
-    if (l->has_continuing()) {
-      auto* cont_block = block_infos_.Create(BlockInfo::Type::LoopContinuing,
-                                             current_block_, l->continuing());
-      block_to_info_[l->continuing()] = cont_block;
-      ScopedAssignment<BlockInfo*> scope_sa2(current_block_, cont_block);
-
-      if (!Statements(l->continuing()->list())) {
+    return BlockScope(BlockInfo::Type::kLoop, [&] {
+      if (!Statements(l->body()->list())) {
         return false;
       }
-    }
-    return true;
+
+      if (l->has_continuing()) {
+        if (!BlockScope(BlockInfo::Type::kLoopContinuing,
+                        [&] { return Statements(l->continuing()->list()); })) {
+          return false;
+        }
+      }
+
+      return true;
+    });
   }
   if (auto* r = stmt->As<ast::ReturnStatement>()) {
     return Expression(r->value());
@@ -297,7 +292,7 @@
       return false;
     }
     for (auto* case_stmt : s->body()) {
-      if (!Statement(case_stmt)) {
+      if (!CaseStatement(case_stmt)) {
         return false;
       }
     }
@@ -316,6 +311,11 @@
   return false;
 }
 
+bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
+  return BlockScope(BlockInfo::Type::kSwitchCase,
+                    [&] { return Statements(stmt->body()->list()); });
+}
+
 bool Resolver::Expressions(const ast::ExpressionList& list) {
   for (auto* expr : list) {
     if (!Expression(expr)) {
@@ -395,8 +395,7 @@
   } else if (auto* arr = parent_type->As<type::Array>()) {
     if (!arr->type()->is_scalar()) {
       // If we extract a non-scalar from an array then we also get a pointer. We
-      // will generate a Function storage class variable to store this
-      // into.
+      // will generate a Function storage class variable to store this into.
       ret = builder_->create<type::Pointer>(ret, ast::StorageClass::kFunction);
     }
   }
@@ -573,9 +572,9 @@
     // refer to a variable that is bypassed by a continue statement in the
     // loop's body block.
     if (auto* continuing_block =
-            current_block_->FindFirstParent(BlockInfo::Type::LoopContinuing)) {
+            current_block_->FindFirstParent(BlockInfo::Type::kLoopContinuing)) {
       auto* loop_block =
-          continuing_block->FindFirstParent(BlockInfo::Type::Loop);
+          continuing_block->FindFirstParent(BlockInfo::Type::kLoop);
       if (loop_block->first_continue != size_t(~0)) {
         auto& decls = loop_block->decls;
         // If our identifier is in loop_block->decls, make sure its index is
@@ -946,6 +945,13 @@
   }
 }
 
+template <typename F>
+bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) {
+  BlockInfo block_info(type, current_block_);
+  ScopedAssignment<BlockInfo*> sa(current_block_, &block_info);
+  return callback();
+}
+
 Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
     : declaration(decl), storage_class(decl->declared_storage_class()) {}
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index f07fab5..4a095be 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -33,6 +33,7 @@
 class BinaryExpression;
 class BitcastExpression;
 class CallExpression;
+class CaseStatement;
 class ConstructorExpression;
 class Function;
 class IdentifierExpression;
@@ -105,9 +106,9 @@
   /// parent block and variables declared in the block.
   /// Used to validate variable scoping rules.
   struct BlockInfo {
-    enum class Type { Generic, Loop, LoopContinuing };
+    enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
 
-    BlockInfo(Type type, BlockInfo* parent, const ast::BlockStatement* block);
+    BlockInfo(Type type, BlockInfo* parent);
     ~BlockInfo();
 
     template <typename Pred>
@@ -124,9 +125,8 @@
           [ty](auto* block_info) { return block_info->type == ty; });
     }
 
-    const Type type;
-    BlockInfo* parent;
-    const ast::BlockStatement* block;
+    Type const type;
+    BlockInfo* const parent;
     std::vector<const ast::Variable*> decls;
 
     // first_continue is set to the index of the first variable in decls
@@ -134,9 +134,6 @@
     constexpr static size_t kNoContinue = size_t(~0);
     size_t first_continue = kNoContinue;
   };
-  std::unordered_map<const ast::BlockStatement*, BlockInfo*> block_to_info_;
-  BlockAllocator<BlockInfo> block_infos_;
-  BlockInfo* current_block_ = nullptr;
 
   /// Resolves the program, without creating final the semantic nodes.
   /// @returns true on success, false on error
@@ -200,6 +197,7 @@
   bool Binary(ast::BinaryExpression* expr);
   bool Bitcast(ast::BitcastExpression* expr);
   bool Call(ast::CallExpression* expr);
+  bool CaseStatement(ast::CaseStatement* stmt);
   bool Constructor(ast::ConstructorExpression* expr);
   bool Identifier(ast::IdentifierExpression* expr);
   bool IntrinsicCall(ast::CallExpression* call,
@@ -221,9 +219,16 @@
   /// @param type the resolved type
   void SetType(ast::Expression* expr, type::Type* type);
 
+  /// Constructs a new BlockInfo with the given type and with #current_block_ as
+  /// its parent, assigns this to #current_block_, and then calls `callback`.
+  /// The original #current_block_ is restored on exit.
+  template <typename F>
+  bool BlockScope(BlockInfo::Type type, F&& callback);
+
   ProgramBuilder* const builder_;
   std::unique_ptr<IntrinsicTable> const intrinsic_table_;
   diag::List diagnostics_;
+  BlockInfo* current_block_ = nullptr;
   ScopeStack<VariableInfo*> variable_stack_;
   std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
   std::unordered_map<ast::Function*, FunctionInfo*> function_to_info_;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index d7c971d..d3613cd 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -17,6 +17,7 @@
 #include "gmock/gmock.h"
 #include "src/ast/assignment_statement.h"
 #include "src/ast/bitcast_expression.h"
+#include "src/ast/break_statement.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/continue_statement.h"
 #include "src/ast/if_statement.h"
@@ -476,10 +477,37 @@
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
+TEST_F(ResolverTest, Stmt_ContinueInLoop) {
+  WrapInFunction(Loop(Block(create<ast::ContinueStatement>(Source{{12, 34}}))));
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
 TEST_F(ResolverTest, Stmt_ContinueNotInLoop) {
-  WrapInFunction(create<ast::ContinueStatement>());
+  WrapInFunction(create<ast::ContinueStatement>(Source{{12, 34}}));
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(r()->error(), "error: continue statement must be in a loop");
+  EXPECT_EQ(r()->error(), "12:34 error: continue statement must be in a loop");
+}
+
+TEST_F(ResolverTest, Stmt_BreakInLoop) {
+  WrapInFunction(Loop(Block(create<ast::BreakStatement>(Source{{12, 34}}))));
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_BreakInSwitch) {
+  WrapInFunction(Loop(Block(create<ast::SwitchStatement>(
+      Expr(1), ast::CaseStatementList{
+                   create<ast::CaseStatement>(
+                       ast::CaseSelectorList{Literal(1)},
+                       Block(create<ast::BreakStatement>(Source{{12, 34}}))),
+               }))));
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_BreakNotInLoopOrSwitch) {
+  WrapInFunction(create<ast::BreakStatement>(Source{{12, 34}}));
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: break statement must be in a loop or switch case");
 }
 
 TEST_F(ResolverTest, Stmt_Return) {