sem: Replace SwitchCaseBlockStatement with CaseStatement

The SwitchCaseBlockStatement was bound to the BlockStatement of an ast::CaseStatement, but we had nothing that mapped to the actual ast::CaseStatement.
sem::CaseStatement replaces sem::SwitchCaseBlockStatement, and has a Block() accessor, providing a superset of the old behavior.

With this, we can now easily validate the `fallthrough` rules directly, instead of scanning the switch case. This keeps the validation more tigtly coupled to the ast / sem nodes.

Change-Id: I0f22eba37bb164b9e071a6166c7a41fc1a5ac532
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71460
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/docs/compound_statements.md b/docs/compound_statements.md
index 7d28b1d..a113cce 100644
--- a/docs/compound_statements.md
+++ b/docs/compound_statements.md
@@ -105,11 +105,15 @@
 ```
 sem::SwitchStatement {
     sem::Expression condition
-    sem::SwitchCaseBlockStatement {
-        sem::Statement statement_a
+    sem::CaseStatement {
+        sem::BlockStatement {
+            sem::Statement statement_a
+        }
     }
-    sem::SwitchCaseBlockStatement {
-        sem::Statement statement_b
+    sem::CaseStatement {
+        sem::BlockStatement {
+            sem::Statement statement_b
+        }
     }
 }
 ```
diff --git a/src/resolver/compound_statement_test.cc b/src/resolver/compound_statement_test.cc
index db91d7c..7de565c 100644
--- a/src/resolver/compound_statement_test.cc
+++ b/src/resolver/compound_statement_test.cc
@@ -343,31 +343,34 @@
   {
     auto* s = Sem().Get(stmt_a);
     ASSERT_NE(s, nullptr);
-    EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
+    EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
     EXPECT_EQ(s->Parent(), s->Block());
-    EXPECT_EQ(s->Parent()->Parent(),
-              s->FindFirstParent<sem::SwitchStatement>());
+    EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
     EXPECT_EQ(s->Parent()->Parent()->Parent(),
+              s->FindFirstParent<sem::SwitchStatement>());
+    EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
               s->FindFirstParent<sem::FunctionBlockStatement>());
   }
   {
     auto* s = Sem().Get(stmt_b);
     ASSERT_NE(s, nullptr);
-    EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
+    EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
     EXPECT_EQ(s->Parent(), s->Block());
-    EXPECT_EQ(s->Parent()->Parent(),
-              s->FindFirstParent<sem::SwitchStatement>());
+    EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
     EXPECT_EQ(s->Parent()->Parent()->Parent(),
+              s->FindFirstParent<sem::SwitchStatement>());
+    EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
               s->FindFirstParent<sem::FunctionBlockStatement>());
   }
   {
     auto* s = Sem().Get(stmt_c);
     ASSERT_NE(s, nullptr);
-    EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
+    EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
     EXPECT_EQ(s->Parent(), s->Block());
-    EXPECT_EQ(s->Parent()->Parent(),
-              s->FindFirstParent<sem::SwitchStatement>());
+    EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
     EXPECT_EQ(s->Parent()->Parent()->Parent(),
+              s->FindFirstParent<sem::SwitchStatement>());
+    EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
               s->FindFirstParent<sem::FunctionBlockStatement>());
   }
 }
diff --git a/src/resolver/control_block_validation_test.cc b/src/resolver/control_block_validation_test.cc
index 890892e..41389dd 100644
--- a/src/resolver/control_block_validation_test.cc
+++ b/src/resolver/control_block_validation_test.cc
@@ -294,8 +294,8 @@
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
-            "12:34 error: a fallthrough statement must not appear as the last "
-            "statement in last clause of a switch");
+            "12:34 error: a fallthrough statement must not be used in the last "
+            "switch case");
 }
 
 TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) {
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index e490815..36ca107 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -874,17 +874,20 @@
   return nullptr;
 }
 
-sem::SwitchCaseBlockStatement* Resolver::CaseStatement(
-    const ast::CaseStatement* stmt) {
-  auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
-      stmt->body, current_compound_statement_, current_function_);
+sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
+  auto* sem = builder_->create<sem::CaseStatement>(
+      stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
-    builder_->Sem().Add(stmt->body, sem);
-    Mark(stmt->body);
     for (auto* sel : stmt->selectors) {
       Mark(sel);
     }
-    return Statements(stmt->body->statements);
+    Mark(stmt->body);
+    auto* body = BlockStatement(stmt->body);
+    if (!body) {
+      return false;
+    }
+    sem->SetBlock(body);
+    return true;
   });
 }
 
@@ -2361,7 +2364,9 @@
     const ast::FallthroughStatement* stmt) {
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
-  return StatementScope(stmt, sem, [&] { return true; });
+  return StatementScope(stmt, sem, [&] {
+    return ValidateFallthroughStatement(sem);
+  });
 }
 
 bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index f2a2fc8..84565ee 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -60,13 +60,13 @@
 class Array;
 class Atomic;
 class BlockStatement;
+class CaseStatement;
 class ElseStatement;
 class ForLoopStatement;
 class IfStatement;
 class Intrinsic;
 class LoopStatement;
 class Statement;
-class SwitchCaseBlockStatement;
 class SwitchStatement;
 class TypeConstructor;
 }  // namespace sem
@@ -209,7 +209,7 @@
   sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
   sem::Statement* BreakStatement(const ast::BreakStatement*);
   sem::Statement* CallStatement(const ast::CallStatement*);
-  sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*);
+  sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
   sem::Statement* ContinueStatement(const ast::ContinueStatement*);
   sem::Statement* DiscardStatement(const ast::DiscardStatement*);
   sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
@@ -238,14 +238,15 @@
   bool ValidateAtomicVariable(const sem::Variable* var);
   bool ValidateAssignment(const ast::AssignmentStatement* a);
   bool ValidateBreakStatement(const sem::Statement* stmt);
-  bool ValidateContinueStatement(const sem::Statement* stmt);
-  bool ValidateDiscardStatement(const sem::Statement* stmt);
   bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
                                  const sem::Type* storage_type,
                                  const bool is_input);
+  bool ValidateContinueStatement(const sem::Statement* stmt);
+  bool ValidateDiscardStatement(const sem::Statement* stmt);
   bool ValidateElseStatement(const sem::ElseStatement* stmt);
   bool ValidateEntryPoint(const sem::Function* func);
   bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
+  bool ValidateFallthroughStatement(const sem::Statement* stmt);
   bool ValidateFunction(const sem::Function* func);
   bool ValidateFunctionCall(const sem::Call* call);
   bool ValidateGlobalVariable(const sem::Variable* var);
diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc
index b7cb04e..eb2be86 100644
--- a/src/resolver/resolver_validation.cc
+++ b/src/resolver/resolver_validation.cc
@@ -1346,7 +1346,7 @@
 
 bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
   if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
-      !stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
+      !stmt->FindFirstParent<sem::CaseStatement>()) {
     AddError("break statement must be in a loop or switch case",
              stmt->Declaration()->source);
     return false;
@@ -1385,6 +1385,29 @@
   return true;
 }
 
+bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
+  if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
+    if (auto* c = As<sem::CaseStatement>(block->Parent())) {
+      if (block->Declaration()->Last() == stmt->Declaration()) {
+        if (auto* s = As<sem::SwitchStatement>(c->Parent())) {
+          if (c->Declaration() != s->Declaration()->body.back()) {
+            return true;
+          }
+          AddError(
+              "a fallthrough statement must not be used in the last switch "
+              "case",
+              stmt->Declaration()->source);
+          return false;
+        }
+      }
+    }
+  }
+  AddError(
+      "fallthrough must only be used as the last statement of a case block",
+      stmt->Declaration()->source);
+  return false;
+}
+
 bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
   if (auto* cond = stmt->Condition()) {
     auto* cond_ty = cond->Type()->UnwrapRef();
@@ -2231,18 +2254,6 @@
     return false;
   }
 
-  if (!s->body.empty()) {
-    auto* last_clause = s->body.back()->As<ast::CaseStatement>();
-    auto* last_stmt = last_clause->body->Last();
-    if (last_stmt && last_stmt->Is<ast::FallthroughStatement>()) {
-      AddError(
-          "a fallthrough statement must not appear as "
-          "the last statement in last clause of a switch",
-          last_stmt->source);
-      return false;
-    }
-  }
-
   return true;
 }
 
diff --git a/src/sem/switch_statement.cc b/src/sem/switch_statement.cc
index fe13c3e..9a911a2 100644
--- a/src/sem/switch_statement.cc
+++ b/src/sem/switch_statement.cc
@@ -16,8 +16,8 @@
 
 #include "src/program_builder.h"
 
+TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseStatement);
 TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchStatement);
-TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchCaseBlockStatement);
 
 namespace tint {
 namespace sem {
@@ -32,15 +32,22 @@
 
 SwitchStatement::~SwitchStatement() = default;
 
-SwitchCaseBlockStatement::SwitchCaseBlockStatement(
-    const ast::BlockStatement* declaration,
-    const CompoundStatement* parent,
-    const sem::Function* function)
+const ast::SwitchStatement* SwitchStatement::Declaration() const {
+  return static_cast<const ast::SwitchStatement*>(Base::Declaration());
+}
+
+CaseStatement::CaseStatement(const ast::CaseStatement* declaration,
+                             const CompoundStatement* parent,
+                             const sem::Function* function)
     : Base(declaration, parent, function) {
   TINT_ASSERT(Semantic, parent);
   TINT_ASSERT(Semantic, function);
 }
-SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default;
+CaseStatement::~CaseStatement() = default;
+
+const ast::CaseStatement* CaseStatement::Declaration() const {
+  return static_cast<const ast::CaseStatement*>(Base::Declaration());
+}
 
 }  // namespace sem
 }  // namespace tint
diff --git a/src/sem/switch_statement.h b/src/sem/switch_statement.h
index 8e5a2cd..49da6e9 100644
--- a/src/sem/switch_statement.h
+++ b/src/sem/switch_statement.h
@@ -20,6 +20,7 @@
 // Forward declarations
 namespace tint {
 namespace ast {
+class CaseStatement;
 class SwitchStatement;
 }  // namespace ast
 }  // namespace tint
@@ -40,22 +41,36 @@
 
   /// Destructor
   ~SwitchStatement() override;
+
+  /// @return the AST node for this statement
+  const ast::SwitchStatement* Declaration() const;
 };
 
-/// Holds semantic information about a switch case block
-class SwitchCaseBlockStatement
-    : public Castable<SwitchCaseBlockStatement, BlockStatement> {
+/// Holds semantic information about a switch case statement
+class CaseStatement : public Castable<CaseStatement, CompoundStatement> {
  public:
   /// Constructor
-  /// @param declaration the AST node for this block statement
+  /// @param declaration the AST node for this case statement
   /// @param parent the owning statement
   /// @param function the owning function
-  SwitchCaseBlockStatement(const ast::BlockStatement* declaration,
-                           const CompoundStatement* parent,
-                           const sem::Function* function);
+  CaseStatement(const ast::CaseStatement* declaration,
+                const CompoundStatement* parent,
+                const sem::Function* function);
 
   /// Destructor
-  ~SwitchCaseBlockStatement() override;
+  ~CaseStatement() override;
+
+  /// @return the AST node for this statement
+  const ast::CaseStatement* Declaration() const;
+
+  /// @param body the case body block statement
+  void SetBlock(const BlockStatement* body) { body_ = body; }
+
+  /// @returns the case body block statement
+  const BlockStatement* Body() const { return body_; }
+
+ private:
+  const BlockStatement* body_ = nullptr;
 };
 
 }  // namespace sem
diff --git a/src/writer/wgsl/generator_impl_fallthrough_test.cc b/src/writer/wgsl/generator_impl_fallthrough_test.cc
index 0325256..aa1d037 100644
--- a/src/writer/wgsl/generator_impl_fallthrough_test.cc
+++ b/src/writer/wgsl/generator_impl_fallthrough_test.cc
@@ -23,7 +23,9 @@
 
 TEST_F(WgslGeneratorImplTest, Emit_Fallthrough) {
   auto* f = create<ast::FallthroughStatement>();
-  WrapInFunction(f);
+  WrapInFunction(Switch(1,                        //
+                        Case(Expr(1), Block(f)),  //
+                        DefaultCase()));
 
   GeneratorImpl& gen = Build();