resolver: Refactor Statement handling

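Break up Resolver::Statement() into multiple resolver functions, one per
statement type. Each of these now returns the resolved sem node (nullptr on
failure) instead of a bool, and Resolver::Scope() is replaced by
Resolver::StatementScope(), which also registers the AST -> SEM mapping.
Move simple statement validation out to resolver_validation.cc.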

Change-Id: Ifa29433af0a9afa39a66ac3e4f7ca376351adfbf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71102
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 173220c..5d4575d 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -654,9 +654,9 @@
           << "Resolver::Function() called with a current compound statement";
       return nullptr;
     }
-    auto* sem_block = builder_->create<sem::FunctionBlockStatement>(func);
-    builder_->Sem().Add(decl->body, sem_block);
-    if (!Scope(sem_block, [&] { return Statements(decl->body->statements); })) {
+    if (!StatementScope(decl->body,
+                        builder_->create<sem::FunctionBlockStatement>(func),
+                        [&] { return Statements(decl->body->statements); })) {
       return nullptr;
     }
   }
@@ -796,7 +796,8 @@
 bool Resolver::Statements(const ast::StatementList& stmts) {
   for (auto* stmt : stmts) {
     Mark(stmt);
-    if (!Statement(stmt)) {
+    auto* sem = Statement(stmt);
+    if (!sem) {
       return false;
     }
   }
@@ -807,18 +808,18 @@
   return true;
 }
 
-bool Resolver::Statement(const ast::Statement* stmt) {
+sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
   if (stmt->Is<ast::CaseStatement>()) {
     AddError("case statement can only be used inside a switch statement",
              stmt->source);
-    return false;
+    return nullptr;
   }
   if (stmt->Is<ast::ElseStatement>()) {
     TINT_ICE(Resolver, diagnostics_)
         << "Resolver::Statement() encountered an Else statement. Else "
            "statements are embedded in If statements, so should never be "
            "encountered as top-level statements";
-    return false;
+    return nullptr;
   }
 
   // Compound statements. These create their own sem::CompoundStatement
@@ -840,69 +841,26 @@
   }
 
   // Non-Compound statements
-  sem::Statement* sem_statement = builder_->create<sem::Statement>(
-      stmt, current_compound_statement_, current_function_);
-  builder_->Sem().Add(stmt, sem_statement);
-  TINT_SCOPED_ASSIGNMENT(current_statement_, sem_statement);
   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
-    return Assignment(a);
+    return AssignmentStatement(a);
   }
-  if (stmt->Is<ast::BreakStatement>()) {
-    if (!sem_statement->FindFirstParent<sem::LoopBlockStatement>() &&
-        !sem_statement->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
-      AddError("break statement must be in a loop or switch case",
-               stmt->source);
-      return false;
-    }
-    return true;
+  if (auto* b = stmt->As<ast::BreakStatement>()) {
+    return BreakStatement(b);
   }
   if (auto* c = stmt->As<ast::CallStatement>()) {
-    if (!Expression(c->expr)) {
-      return false;
-    }
-    return true;
+    return CallStatement(c);
   }
   if (auto* c = stmt->As<ast::ContinueStatement>()) {
-    // Set if we've hit the first continue statement in our parent loop
-    if (auto* block =
-            current_block_->FindFirstParent<
-                sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) {
-      if (auto* loop_block = block->As<sem::LoopBlockStatement>()) {
-        if (!loop_block->FirstContinue()) {
-          const_cast<sem::LoopBlockStatement*>(loop_block)
-              ->SetFirstContinue(c, loop_block->Decls().size());
-        }
-      } else {
-        AddError("continuing blocks must not contain a continue statement",
-                 stmt->source);
-        return false;
-      }
-    } else {
-      AddError("continue statement must be in a loop", stmt->source);
-      return false;
-    }
-
-    return true;
+    return ContinueStatement(c);
   }
-  if (stmt->Is<ast::DiscardStatement>()) {
-    if (auto* continuing =
-            sem_statement
-                ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
-      AddError("continuing blocks must not contain a discard statement",
-               stmt->source);
-      if (continuing != sem_statement->Parent()) {
-        AddNote("see continuing block here", continuing->Declaration()->source);
-      }
-      return false;
-    }
-    current_function_->SetHasDiscard();
-    return true;
+  if (auto* d = stmt->As<ast::DiscardStatement>()) {
+    return DiscardStatement(d);
   }
-  if (stmt->Is<ast::FallthroughStatement>()) {
-    return true;
+  if (auto* f = stmt->As<ast::FallthroughStatement>()) {
+    return FallthroughStatement(f);
   }
   if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    return Return(r);
+    return ReturnStatement(r);
   }
   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
     return VariableDeclStatement(v);
@@ -910,43 +868,38 @@
 
   AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
            stmt->source);
-  return false;
+  return nullptr;
 }
 
-bool Resolver::CaseStatement(const ast::CaseStatement* stmt) {
+sem::SwitchCaseBlockStatement* Resolver::CaseStatement(
+    const ast::CaseStatement* stmt) {
   auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
       stmt->body, current_compound_statement_, current_function_);
-  builder_->Sem().Add(stmt, sem);
-  builder_->Sem().Add(stmt->body, sem);
-  Mark(stmt->body);
-  for (auto* sel : stmt->selectors) {
-    Mark(sel);
-  }
-  return Scope(sem, [&] { return Statements(stmt->body->statements); });
+  return StatementScope(stmt, sem, [&] {
+    builder_->Sem().Add(stmt->body, sem);
+    Mark(stmt->body);
+    for (auto* sel : stmt->selectors) {
+      Mark(sel);
+    }
+    return Statements(stmt->body->statements);
+  });
 }
 
-bool Resolver::IfStatement(const ast::IfStatement* stmt) {
+sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
   auto* sem = builder_->create<sem::IfStatement>(
       stmt, current_compound_statement_, current_function_);
-  builder_->Sem().Add(stmt, sem);
-  return Scope(sem, [&] {
-    if (!Expression(stmt->condition)) {
+  return StatementScope(stmt, sem, [&] {
+    auto* cond = Expression(stmt->condition);
+    if (!cond) {
       return false;
     }
-
-    auto* cond_type = TypeOf(stmt->condition)->UnwrapRef();
-    if (!cond_type->Is<sem::Bool>()) {
-      AddError(
-          "if statement condition must be bool, got " + TypeNameOf(cond_type),
-          stmt->condition->source);
-      return false;
-    }
+    sem->SetCondition(cond);
 
     Mark(stmt->body);
     auto* body = builder_->create<sem::BlockStatement>(
         stmt->body, current_compound_statement_, current_function_);
-    builder_->Sem().Add(stmt->body, body);
-    if (!Scope(body, [&] { return Statements(stmt->body->statements); })) {
+    if (!StatementScope(stmt->body, body,
+                        [&] { return Statements(stmt->body->statements); })) {
       return false;
     }
 
@@ -956,59 +909,56 @@
         return false;
       }
     }
-    return true;
+
+    return ValidateIfStatement(sem);
   });
 }
 
-bool Resolver::ElseStatement(const ast::ElseStatement* stmt) {
+sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
   auto* sem = builder_->create<sem::ElseStatement>(
       stmt, current_compound_statement_, current_function_);
-  builder_->Sem().Add(stmt, sem);
-  return Scope(sem, [&] {
-    if (auto* cond = stmt->condition) {
-      if (!Expression(cond)) {
+  return StatementScope(stmt, sem, [&] {
+    if (auto* cond_expr = stmt->condition) {
+      auto* cond = Expression(cond_expr);
+      if (!cond) {
         return false;
       }
-
-      auto* else_cond_type = TypeOf(cond)->UnwrapRef();
-      if (!else_cond_type->Is<sem::Bool>()) {
-        AddError("else statement condition must be bool, got " +
-                     TypeNameOf(else_cond_type),
-                 cond->source);
-        return false;
-      }
+      sem->SetCondition(cond);
     }
 
     Mark(stmt->body);
     auto* body = builder_->create<sem::BlockStatement>(
         stmt->body, current_compound_statement_, current_function_);
-    builder_->Sem().Add(stmt->body, body);
-    return Scope(body, [&] { return Statements(stmt->body->statements); });
+    if (!StatementScope(stmt->body, body,
+                        [&] { return Statements(stmt->body->statements); })) {
+      return false;
+    }
+
+    return ValidateElseStatement(sem);
   });
 }
 
-bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
+sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) {
   auto* sem = builder_->create<sem::BlockStatement>(
       stmt->As<ast::BlockStatement>(), current_compound_statement_,
       current_function_);
-  builder_->Sem().Add(stmt, sem);
-  return Scope(sem, [&] { return Statements(stmt->statements); });
+  return StatementScope(stmt, sem,
+                        [&] { return Statements(stmt->statements); });
 }
 
-bool Resolver::LoopStatement(const ast::LoopStatement* stmt) {
+sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
   auto* sem = builder_->create<sem::LoopStatement>(
       stmt, current_compound_statement_, current_function_);
-  builder_->Sem().Add(stmt, sem);
-  return Scope(sem, [&] {
+  return StatementScope(stmt, sem, [&] {
     Mark(stmt->body);
 
     auto* body = builder_->create<sem::LoopBlockStatement>(
         stmt->body, current_compound_statement_, current_function_);
-    builder_->Sem().Add(stmt->body, body);
-    return Scope(body, [&] {
+    return StatementScope(stmt->body, body, [&] {
       if (!Statements(stmt->body->statements)) {
         return false;
       }
+
       if (stmt->continuing) {
         Mark(stmt->continuing);
         if (!stmt->continuing->Empty()) {
@@ -1016,24 +966,22 @@
               builder_->create<sem::LoopContinuingBlockStatement>(
                   stmt->continuing, current_compound_statement_,
                   current_function_);
-          builder_->Sem().Add(stmt->continuing, continuing);
-          if (!Scope(continuing, [&] {
-                return Statements(stmt->continuing->statements);
-              })) {
-            return false;
-          }
+          return StatementScope(stmt->continuing, continuing, [&] {
+                   return Statements(stmt->continuing->statements);
+                 }) != nullptr;
         }
       }
+
       return true;
     });
   });
 }
 
-bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
+sem::ForLoopStatement* Resolver::ForLoopStatement(
+    const ast::ForLoopStatement* stmt) {
   auto* sem = builder_->create<sem::ForLoopStatement>(
       stmt, current_compound_statement_, current_function_);
-  builder_->Sem().Add(stmt, sem);
-  return Scope(sem, [&] {
+  return StatementScope(stmt, sem, [&] {
     if (auto* initializer = stmt->initializer) {
       Mark(initializer);
       if (!Statement(initializer)) {
@@ -1041,17 +989,12 @@
       }
     }
 
-    if (auto* condition = stmt->condition) {
-      if (!Expression(condition)) {
+    if (auto* cond_expr = stmt->condition) {
+      auto* cond = Expression(cond_expr);
+      if (!cond) {
         return false;
       }
-
-      auto* cond_ty = TypeOf(condition)->UnwrapRef();
-      if (!cond_ty->Is<sem::Bool>()) {
-        AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
-                 condition->source);
-        return false;
-      }
+      sem->SetCondition(cond);
     }
 
     if (auto* continuing = stmt->continuing) {
@@ -1065,8 +1008,12 @@
 
     auto* body = builder_->create<sem::LoopBlockStatement>(
         stmt->body, current_compound_statement_, current_function_);
-    builder_->Sem().Add(stmt->body, body);
-    return Scope(body, [&] { return Statements(stmt->body->statements); });
+    if (!StatementScope(stmt->body, body,
+                        [&] { return Statements(stmt->body->statements); })) {
+      return false;
+    }
+
+    return ValidateForLoopStatement(sem);
   });
 }
 
@@ -1930,33 +1877,6 @@
   return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
 }
 
-bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
-  Mark(stmt->variable);
-
-  auto* var = Variable(stmt->variable, VariableKind::kLocal);
-  if (!var) {
-    return false;
-  }
-
-  for (auto* deco : stmt->variable->decorations) {
-    Mark(deco);
-    if (!deco->Is<ast::InternalDecoration>()) {
-      AddError("decorations are not valid on local variables", deco->source);
-      return false;
-    }
-  }
-
-  if (current_block_) {  // Not all statements are inside a block
-    current_block_->AddDecl(stmt->variable);
-  }
-
-  if (!ValidateVariable(var)) {
-    return false;
-  }
-
-  return true;
-}
-
 sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
   sem::Type* result = nullptr;
   if (auto* alias = named_type->As<ast::Alias>()) {
@@ -2318,45 +2238,127 @@
   return out;
 }
 
-bool Resolver::Return(const ast::ReturnStatement* ret) {
-  if (auto* value = ret->value) {
-    if (!Expression(value)) {
-      return false;
+sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] {
+    if (auto* value = stmt->value) {
+      if (!Expression(value)) {
+        return false;
+      }
     }
-  }
 
-  // Validate after processing the return value expression so that its type is
-  // available for validation.
-  return ValidateReturn(ret);
+    // Validate after processing the return value expression so that its type is
+    // available for validation.
+    return ValidateReturn(stmt);
+  });
 }
 
-bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
+sem::SwitchStatement* Resolver::SwitchStatement(
+    const ast::SwitchStatement* stmt) {
   auto* sem = builder_->create<sem::SwitchStatement>(
       stmt, current_compound_statement_, current_function_);
-  builder_->Sem().Add(stmt, sem);
-  return Scope(sem, [&] {
+  return StatementScope(stmt, sem, [&] {
     if (!Expression(stmt->condition)) {
       return false;
     }
+
     for (auto* case_stmt : stmt->body) {
       Mark(case_stmt);
       if (!CaseStatement(case_stmt)) {
         return false;
       }
     }
-    if (!ValidateSwitch(stmt)) {
-      return false;
-    }
-    return true;
+
+    return ValidateSwitch(stmt);
   });
 }
 
-bool Resolver::Assignment(const ast::AssignmentStatement* a) {
-  if (!Expression(a->lhs) || !Expression(a->rhs)) {
-    return false;
-  }
+sem::Statement* Resolver::VariableDeclStatement(
+    const ast::VariableDeclStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] {
+    Mark(stmt->variable);
 
-  return ValidateAssignment(a);
+    auto* var = Variable(stmt->variable, VariableKind::kLocal);
+    if (!var) {
+      return false;
+    }
+
+    for (auto* deco : stmt->variable->decorations) {
+      Mark(deco);
+      if (!deco->Is<ast::InternalDecoration>()) {
+        AddError("decorations are not valid on local variables", deco->source);
+        return false;
+      }
+    }
+
+    if (current_block_) {  // Not all statements are inside a block
+      current_block_->AddDecl(stmt->variable);
+    }
+
+    return ValidateVariable(var);
+  });
+}
+
+sem::Statement* Resolver::AssignmentStatement(
+    const ast::AssignmentStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] {
+    if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) {
+      return false;
+    }
+
+    return ValidateAssignment(stmt);
+  });
+}
+
+sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); });
+}
+
+sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); });
+}
+
+sem::Statement* Resolver::ContinueStatement(
+    const ast::ContinueStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] {
+    // Record the first continue statement seen in the enclosing loop block
+    if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
+      if (!block->FirstContinue()) {
+        const_cast<sem::LoopBlockStatement*>(block)->SetFirstContinue(
+            stmt, block->Decls().size());
+      }
+    }
+
+    return ValidateContinueStatement(sem);
+  });
+}
+
+sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] {
+    current_function_->SetHasDiscard();
+
+    return ValidateDiscardStatement(sem);
+  });
+}
+
+sem::Statement* Resolver::FallthroughStatement(
+    const ast::FallthroughStatement* stmt) {
+  auto* sem = builder_->create<sem::Statement>(
+      stmt, current_compound_statement_, current_function_);
+  return StatementScope(stmt, sem, [&] { return true; });
 }
 
 bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
@@ -2399,22 +2401,28 @@
   return true;
 }
 
-template <typename F>
-bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) {
-  auto* prev_current_statement = current_statement_;
-  auto* prev_current_compound_statement = current_compound_statement_;
-  auto* prev_current_block = current_block_;
-  current_statement_ = stmt;
-  current_compound_statement_ = stmt;
-  current_block_ = stmt->As<sem::BlockStatement>();
+template <typename SEM, typename F>
+SEM* Resolver::StatementScope(const ast::Statement* ast,
+                              SEM* sem,
+                              F&& callback) {
+  builder_->Sem().Add(ast, sem);
 
-  TINT_DEFER({
-    current_block_ = prev_current_block;
-    current_compound_statement_ = prev_current_compound_statement;
-    current_statement_ = prev_current_statement;
-  });
+  auto* as_compound =
+      As<sem::CompoundStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
+  auto* as_block =
+      As<sem::BlockStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
 
-  return callback();
+  TINT_SCOPED_ASSIGNMENT(current_statement_, sem);
+  TINT_SCOPED_ASSIGNMENT(
+      current_compound_statement_,
+      as_compound ? as_compound : current_compound_statement_);
+  TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_);
+
+  if (!callback()) {
+    return nullptr;
+  }
+
+  return sem;
 }
 
 std::string Resolver::VectorPretty(uint32_t size,
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 08a9c94..f2a2fc8 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -59,8 +59,15 @@
 namespace sem {
 class Array;
 class Atomic;
+class BlockStatement;
+class ElseStatement;
+class ForLoopStatement;
+class IfStatement;
 class Intrinsic;
+class LoopStatement;
 class Statement;
+class SwitchCaseBlockStatement;
+class SwitchStatement;
 class TypeConstructor;
 }  // namespace sem
 
@@ -198,20 +205,26 @@
 
   // Statement resolving methods
-  // Each return true on success, false on failure.
+  // Each return non-null / true on success, nullptr / false on failure.
-  bool Assignment(const ast::AssignmentStatement* a);
-  bool BlockStatement(const ast::BlockStatement*);
-  bool CaseStatement(const ast::CaseStatement*);
-  bool ElseStatement(const ast::ElseStatement*);
-  bool ForLoopStatement(const ast::ForLoopStatement*);
-  bool Parameter(const ast::Variable* param);
-  bool GlobalVariable(const ast::Variable* var);
-  bool IfStatement(const ast::IfStatement*);
-  bool LoopStatement(const ast::LoopStatement*);
-  bool Return(const ast::ReturnStatement* ret);
-  bool Statement(const ast::Statement*);
+  sem::Statement* AssignmentStatement(const ast::AssignmentStatement*);
+  sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
+  sem::Statement* BreakStatement(const ast::BreakStatement*);
+  sem::Statement* CallStatement(const ast::CallStatement*);
+  sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*);
+  sem::Statement* ContinueStatement(const ast::ContinueStatement*);
+  sem::Statement* DiscardStatement(const ast::DiscardStatement*);
+  sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
+  sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
+  sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
+  sem::Statement* Parameter(const ast::Variable*);
+  sem::IfStatement* IfStatement(const ast::IfStatement*);
+  sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
+  sem::Statement* ReturnStatement(const ast::ReturnStatement*);
+  sem::Statement* Statement(const ast::Statement*);
+  sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
+  sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
   bool Statements(const ast::StatementList&);
-  bool SwitchStatement(const ast::SwitchStatement* s);
-  bool VariableDeclStatement(const ast::VariableDeclStatement*);
+
+  bool GlobalVariable(const ast::Variable*);
 
   // AST and Type validation methods
   // Each return true on success, false on failure.
@@ -224,13 +237,19 @@
   bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
   bool ValidateAtomicVariable(const sem::Variable* var);
   bool ValidateAssignment(const ast::AssignmentStatement* a);
+  bool ValidateBreakStatement(const sem::Statement* stmt);
+  bool ValidateContinueStatement(const sem::Statement* stmt);
+  bool ValidateDiscardStatement(const sem::Statement* stmt);
   bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
                                  const sem::Type* storage_type,
                                  const bool is_input);
+  bool ValidateElseStatement(const sem::ElseStatement* stmt);
   bool ValidateEntryPoint(const sem::Function* func);
+  bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
   bool ValidateFunction(const sem::Function* func);
   bool ValidateFunctionCall(const sem::Call* call);
   bool ValidateGlobalVariable(const sem::Variable* var);
+  bool ValidateIfStatement(const sem::IfStatement* stmt);
   bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
                                      const sem::Type* storage_type);
   bool ValidateIntrinsicCall(const sem::Call* call);
@@ -369,14 +388,19 @@
   /// @param lit the literal
   sem::Type* TypeOf(const ast::LiteralExpression* lit);
 
-  /// Assigns `stmt` to #current_statement_, #current_compound_statement_, and
-  /// possibly #current_block_, pushes the variable scope, then calls
-  /// `callback`. Before returning #current_statement_,
-  /// #current_compound_statement_, and #current_block_ are restored to their
-  /// original values, and the variable scope is popped.
-  /// @returns the value returned by callback
-  template <typename F>
-  bool Scope(sem::CompoundStatement* stmt, F&& callback);
+  /// StatementScope() does the following:
+  /// * Creates the AST -> SEM mapping.
+  /// * Assigns `sem` to #current_statement_
+  /// * Assigns `sem` to #current_compound_statement_ if `sem` derives from
+  ///   sem::CompoundStatement.
+  /// * Assigns `sem` to #current_block_ if `sem` derives from
+  ///   sem::BlockStatement.
+  /// * Then calls `callback`.
+  /// * Before returning, #current_statement_, #current_compound_statement_, and
+  ///   #current_block_ are restored to their original values.
+  /// @returns `sem` if `callback` returns true, otherwise `nullptr`.
+  template <typename SEM, typename F>
+  SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
 
   /// Returns a human-readable string representation of the vector type name
   /// with the given parameters.
diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc
index 5698138..b7cb04e 100644
--- a/src/resolver/resolver_validation.cc
+++ b/src/resolver/resolver_validation.cc
@@ -1344,6 +1344,82 @@
   return true;
 }
 
+bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
+  if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
+      !stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
+    AddError("break statement must be in a loop or switch case",
+             stmt->Declaration()->source);
+    return false;
+  }
+  return true;
+}
+
+bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
+  if (auto* block =
+          stmt->FindFirstParent<sem::LoopBlockStatement,
+                                sem::LoopContinuingBlockStatement>()) {
+    if (block->Is<sem::LoopContinuingBlockStatement>()) {
+      AddError("continuing blocks must not contain a continue statement",
+               stmt->Declaration()->source);
+      return false;
+    }
+  } else {
+    AddError("continue statement must be in a loop",
+             stmt->Declaration()->source);
+    return false;
+  }
+
+  return true;
+}
+
+bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
+  if (auto* continuing =
+          stmt->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
+    AddError("continuing blocks must not contain a discard statement",
+             stmt->Declaration()->source);
+    if (continuing != stmt->Parent()) {
+      AddNote("see continuing block here", continuing->Declaration()->source);
+    }
+    return false;
+  }
+  return true;
+}
+
+bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
+  if (auto* cond = stmt->Condition()) {
+    auto* cond_ty = cond->Type()->UnwrapRef();
+    if (!cond_ty->Is<sem::Bool>()) {
+      AddError(
+          "else statement condition must be bool, got " + TypeNameOf(cond_ty),
+          stmt->Condition()->Declaration()->source);
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
+  if (auto* cond = stmt->Condition()) {
+    auto* cond_ty = cond->Type()->UnwrapRef();
+    if (!cond_ty->Is<sem::Bool>()) {
+      AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
+               stmt->Condition()->Declaration()->source);
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
+  auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
+  if (!cond_ty->Is<sem::Bool>()) {
+    AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
+             stmt->Condition()->Declaration()->source);
+    return false;
+  }
+  return true;
+}
+
 bool Resolver::ValidateIntrinsicCall(const sem::Call* call) {
   if (call->Type()->Is<sem::Void>()) {
     bool is_call_statement = false;
@@ -2103,8 +2179,8 @@
 }
 
 bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
-  auto* cond_type = TypeOf(s->condition)->UnwrapRef();
-  if (!cond_type->is_integer_scalar()) {
+  auto* cond_ty = TypeOf(s->condition)->UnwrapRef();
+  if (!cond_ty->is_integer_scalar()) {
     AddError(
         "switch statement selector expression must be of a "
         "scalar integer type",
@@ -2127,7 +2203,7 @@
     }
 
     for (auto* selector : case_stmt->selectors) {
-      if (cond_type != TypeOf(selector)) {
+      if (cond_ty != TypeOf(selector)) {
         AddError(
             "the case selector values must have the same "
             "type as the selector expression.",
diff --git a/src/sem/for_loop_statement.h b/src/sem/for_loop_statement.h
index ff89241..2f287bb 100644
--- a/src/sem/for_loop_statement.h
+++ b/src/sem/for_loop_statement.h
@@ -21,6 +21,9 @@
 namespace ast {
 class ForLoopStatement;
 }  // namespace ast
+namespace sem {
+class Expression;
+}  // namespace sem
 }  // namespace tint
 
 namespace tint {
@@ -39,6 +42,16 @@
 
   /// Destructor
   ~ForLoopStatement() override;
+
+  /// @returns the for-loop condition expression
+  const Expression* Condition() const { return condition_; }
+
+  /// Sets the for-loop condition expression
+  /// @param condition the for-loop condition expression
+  void SetCondition(const Expression* condition) { condition_ = condition; }
+
+ private:
+  const Expression* condition_ = nullptr;
 };
 
 }  // namespace sem
diff --git a/src/sem/if_statement.h b/src/sem/if_statement.h
index 6c25fca..a8c9c2e 100644
--- a/src/sem/if_statement.h
+++ b/src/sem/if_statement.h
@@ -23,6 +23,9 @@
 class IfStatement;
 class ElseStatement;
 }  // namespace ast
+namespace sem {
+class Expression;
+}  // namespace sem
 }  // namespace tint
 
 namespace tint {
@@ -41,6 +44,16 @@
 
   /// Destructor
   ~IfStatement() override;
+
+  /// @returns the if-statement condition expression
+  const Expression* Condition() const { return condition_; }
+
+  /// Sets the if-statement condition expression
+  /// @param condition the if condition expression
+  void SetCondition(const Expression* condition) { condition_ = condition; }
+
+ private:
+  const Expression* condition_ = nullptr;
 };
 
 /// Holds semantic information about an else statement
@@ -56,6 +69,16 @@
 
   /// Destructor
   ~ElseStatement() override;
+
+  /// @returns the else-statement condition expression
+  const Expression* Condition() const { return condition_; }
+
+  /// Sets the else-statement condition expression
+  /// @param condition the else condition expression
+  void SetCondition(const Expression* condition) { condition_ = condition; }
+
+ private:
+  const Expression* condition_ = nullptr;
 };
 
 }  // namespace sem
diff --git a/src/sem/member_accessor_expression.h b/src/sem/member_accessor_expression.h
index fcc6c6f..6d444f0 100644
--- a/src/sem/member_accessor_expression.h
+++ b/src/sem/member_accessor_expression.h
@@ -82,7 +82,7 @@
   /// Constructor
   /// @param declaration the AST node
   /// @param type the resolved type of the expression
-  /// @param statement the statement that
+  /// @param statement the statement that owns this expression
   /// @param indices the swizzle indices
   Swizzle(const ast::MemberAccessorExpression* declaration,
           const sem::Type* type,
diff --git a/src/sem/variable.h b/src/sem/variable.h
index a389eed..ac7dac1 100644
--- a/src/sem/variable.h
+++ b/src/sem/variable.h
@@ -234,7 +234,7 @@
   const sem::Variable* Variable() const { return variable_; }
 
  private:
-  sem::Variable const* const variable_;
+  const sem::Variable* const variable_;
 };
 
 }  // namespace sem