Replace Statement::(Is|As)* with Castable

Change-Id: I5520752a4b5844be0ecac7921616893d123b246a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34315
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/assignment_statement.cc b/src/ast/assignment_statement.cc
index cfad841..3fe9a8b 100644
--- a/src/ast/assignment_statement.cc
+++ b/src/ast/assignment_statement.cc
@@ -31,10 +31,6 @@
 
 AssignmentStatement::~AssignmentStatement() = default;
 
-bool AssignmentStatement::IsAssign() const {
-  return true;
-}
-
 bool AssignmentStatement::IsValid() const {
   if (lhs_ == nullptr || !lhs_->IsValid())
     return false;
diff --git a/src/ast/assignment_statement.h b/src/ast/assignment_statement.h
index dd68ca0..08c972e 100644
--- a/src/ast/assignment_statement.h
+++ b/src/ast/assignment_statement.h
@@ -55,9 +55,6 @@
   /// @returns the right side expression
   Expression* rhs() const { return rhs_; }
 
-  /// @returns true if this is an assignment statement
-  bool IsAssign() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/assignment_statement_test.cc b/src/ast/assignment_statement_test.cc
index fa01c7e..ad430b7 100644
--- a/src/ast/assignment_statement_test.cc
+++ b/src/ast/assignment_statement_test.cc
@@ -47,7 +47,7 @@
   auto* rhs = create<ast::IdentifierExpression>("rhs");
 
   AssignmentStatement stmt(lhs, rhs);
-  EXPECT_TRUE(stmt.IsAssign());
+  EXPECT_TRUE(stmt.Is<AssignmentStatement>());
 }
 
 TEST_F(AssignmentStatementTest, IsValid) {
diff --git a/src/ast/block_statement.cc b/src/ast/block_statement.cc
index a5aa5a9..7ab8d9b 100644
--- a/src/ast/block_statement.cc
+++ b/src/ast/block_statement.cc
@@ -25,10 +25,6 @@
 
 BlockStatement::~BlockStatement() = default;
 
-bool BlockStatement::IsBlock() const {
-  return true;
-}
-
 bool BlockStatement::IsValid() const {
   for (auto* stmt : *this) {
     if (stmt == nullptr || !stmt->IsValid()) {
diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h
index 07189c7..0b8dac5 100644
--- a/src/ast/block_statement.h
+++ b/src/ast/block_statement.h
@@ -87,9 +87,6 @@
     return statements_.end();
   }
 
-  /// @returns true if this is a block statement
-  bool IsBlock() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/block_statement_test.cc b/src/ast/block_statement_test.cc
index 1c25885..f0941b7 100644
--- a/src/ast/block_statement_test.cc
+++ b/src/ast/block_statement_test.cc
@@ -65,7 +65,7 @@
 
 TEST_F(BlockStatementTest, IsBlock) {
   BlockStatement b;
-  EXPECT_TRUE(b.IsBlock());
+  EXPECT_TRUE(b.Is<BlockStatement>());
 }
 
 TEST_F(BlockStatementTest, IsValid) {
diff --git a/src/ast/break_statement.cc b/src/ast/break_statement.cc
index ed70840..1894ba4 100644
--- a/src/ast/break_statement.cc
+++ b/src/ast/break_statement.cc
@@ -25,10 +25,6 @@
 
 BreakStatement::~BreakStatement() = default;
 
-bool BreakStatement::IsBreak() const {
-  return true;
-}
-
 bool BreakStatement::IsValid() const {
   return true;
 }
diff --git a/src/ast/break_statement.h b/src/ast/break_statement.h
index 13a7d57..2a72c03 100644
--- a/src/ast/break_statement.h
+++ b/src/ast/break_statement.h
@@ -32,9 +32,6 @@
   BreakStatement(BreakStatement&&);
   ~BreakStatement() override;
 
-  /// @returns true if this is an break statement
-  bool IsBreak() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/break_statement_test.cc b/src/ast/break_statement_test.cc
index e9e44cb..8021348 100644
--- a/src/ast/break_statement_test.cc
+++ b/src/ast/break_statement_test.cc
@@ -31,7 +31,7 @@
 
 TEST_F(BreakStatementTest, IsBreak) {
   BreakStatement stmt;
-  EXPECT_TRUE(stmt.IsBreak());
+  EXPECT_TRUE(stmt.Is<BreakStatement>());
 }
 
 TEST_F(BreakStatementTest, IsValid) {
diff --git a/src/ast/call_statement.cc b/src/ast/call_statement.cc
index 0ff5188..c37c99a 100644
--- a/src/ast/call_statement.cc
+++ b/src/ast/call_statement.cc
@@ -27,10 +27,6 @@
 
 CallStatement::~CallStatement() = default;
 
-bool CallStatement::IsCall() const {
-  return true;
-}
-
 bool CallStatement::IsValid() const {
   return call_ != nullptr && call_->IsValid();
 }
diff --git a/src/ast/call_statement.h b/src/ast/call_statement.h
index c74f156..1657388 100644
--- a/src/ast/call_statement.h
+++ b/src/ast/call_statement.h
@@ -42,9 +42,6 @@
   /// @returns the call expression
   CallExpression* expr() const { return call_; }
 
-  /// @returns true if this is a call statement
-  bool IsCall() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/call_statement_test.cc b/src/ast/call_statement_test.cc
index 363918d..528ecc1 100644
--- a/src/ast/call_statement_test.cc
+++ b/src/ast/call_statement_test.cc
@@ -34,7 +34,7 @@
 
 TEST_F(CallStatementTest, IsCall) {
   CallStatement c;
-  EXPECT_TRUE(c.IsCall());
+  EXPECT_TRUE(c.Is<CallStatement>());
 }
 
 TEST_F(CallStatementTest, IsValid) {
diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc
index 84cda94..2d50185 100644
--- a/src/ast/case_statement.cc
+++ b/src/ast/case_statement.cc
@@ -31,10 +31,6 @@
 
 CaseStatement::~CaseStatement() = default;
 
-bool CaseStatement::IsCase() const {
-  return true;
-}
-
 bool CaseStatement::IsValid() const {
   return body_ != nullptr && body_->IsValid();
 }
diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h
index aacd2f0..ef39af0 100644
--- a/src/ast/case_statement.h
+++ b/src/ast/case_statement.h
@@ -70,9 +70,6 @@
   /// @returns the case body
   BlockStatement* body() { return body_; }
 
-  /// @returns true if this is a case statement
-  bool IsCase() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/case_statement_test.cc b/src/ast/case_statement_test.cc
index 30c298e..d5c77b6 100644
--- a/src/ast/case_statement_test.cc
+++ b/src/ast/case_statement_test.cc
@@ -99,7 +99,7 @@
 
 TEST_F(CaseStatementTest, IsCase) {
   CaseStatement c(create<ast::BlockStatement>());
-  EXPECT_TRUE(c.IsCase());
+  EXPECT_TRUE(c.Is<ast::CaseStatement>());
 }
 
 TEST_F(CaseStatementTest, IsValid) {
diff --git a/src/ast/continue_statement.cc b/src/ast/continue_statement.cc
index f66a958..1f9da2a 100644
--- a/src/ast/continue_statement.cc
+++ b/src/ast/continue_statement.cc
@@ -25,10 +25,6 @@
 
 ContinueStatement::~ContinueStatement() = default;
 
-bool ContinueStatement::IsContinue() const {
-  return true;
-}
-
 bool ContinueStatement::IsValid() const {
   return true;
 }
diff --git a/src/ast/continue_statement.h b/src/ast/continue_statement.h
index 1eb93f9..b2a01f9 100644
--- a/src/ast/continue_statement.h
+++ b/src/ast/continue_statement.h
@@ -35,9 +35,6 @@
   ContinueStatement(ContinueStatement&&);
   ~ContinueStatement() override;
 
-  /// @returns true if this is an continue statement
-  bool IsContinue() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/continue_statement_test.cc b/src/ast/continue_statement_test.cc
index 5b1c58e..577d50b 100644
--- a/src/ast/continue_statement_test.cc
+++ b/src/ast/continue_statement_test.cc
@@ -31,7 +31,7 @@
 
 TEST_F(ContinueStatementTest, IsContinue) {
   ContinueStatement stmt;
-  EXPECT_TRUE(stmt.IsContinue());
+  EXPECT_TRUE(stmt.Is<ContinueStatement>());
 }
 
 TEST_F(ContinueStatementTest, IsValid) {
diff --git a/src/ast/discard_statement.cc b/src/ast/discard_statement.cc
index e2c17fa..b70856b 100644
--- a/src/ast/discard_statement.cc
+++ b/src/ast/discard_statement.cc
@@ -25,10 +25,6 @@
 
 DiscardStatement::~DiscardStatement() = default;
 
-bool DiscardStatement::IsDiscard() const {
-  return true;
-}
-
 bool DiscardStatement::IsValid() const {
   return true;
 }
diff --git a/src/ast/discard_statement.h b/src/ast/discard_statement.h
index 6f2e42a..ba7b398 100644
--- a/src/ast/discard_statement.h
+++ b/src/ast/discard_statement.h
@@ -32,9 +32,6 @@
   DiscardStatement(DiscardStatement&&);
   ~DiscardStatement() override;
 
-  /// @returns true if this is a discard statement
-  bool IsDiscard() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/discard_statement_test.cc b/src/ast/discard_statement_test.cc
index a60a9a4..fdf444f 100644
--- a/src/ast/discard_statement_test.cc
+++ b/src/ast/discard_statement_test.cc
@@ -43,7 +43,7 @@
 
 TEST_F(DiscardStatementTest, IsDiscard) {
   DiscardStatement stmt;
-  EXPECT_TRUE(stmt.IsDiscard());
+  EXPECT_TRUE(stmt.Is<DiscardStatement>());
 }
 
 TEST_F(DiscardStatementTest, IsValid) {
diff --git a/src/ast/else_statement.cc b/src/ast/else_statement.cc
index 00279c3..937755b 100644
--- a/src/ast/else_statement.cc
+++ b/src/ast/else_statement.cc
@@ -34,10 +34,6 @@
 
 ElseStatement::~ElseStatement() = default;
 
-bool ElseStatement::IsElse() const {
-  return true;
-}
-
 bool ElseStatement::IsValid() const {
   if (body_ == nullptr || !body_->IsValid()) {
     return false;
diff --git a/src/ast/else_statement.h b/src/ast/else_statement.h
index b75904d..60a675a 100644
--- a/src/ast/else_statement.h
+++ b/src/ast/else_statement.h
@@ -67,9 +67,6 @@
   /// @returns the else body
   BlockStatement* body() { return body_; }
 
-  /// @returns true if this is a else statement
-  bool IsElse() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/else_statement_test.cc b/src/ast/else_statement_test.cc
index 948a94b..f90d7b5 100644
--- a/src/ast/else_statement_test.cc
+++ b/src/ast/else_statement_test.cc
@@ -51,7 +51,7 @@
 
 TEST_F(ElseStatementTest, IsElse) {
   ElseStatement e(create<BlockStatement>());
-  EXPECT_TRUE(e.IsElse());
+  EXPECT_TRUE(e.Is<ElseStatement>());
 }
 
 TEST_F(ElseStatementTest, HasCondition) {
diff --git a/src/ast/fallthrough_statement.cc b/src/ast/fallthrough_statement.cc
index 6c84ab3..cd5f00d 100644
--- a/src/ast/fallthrough_statement.cc
+++ b/src/ast/fallthrough_statement.cc
@@ -26,10 +26,6 @@
 
 FallthroughStatement::~FallthroughStatement() = default;
 
-bool FallthroughStatement::IsFallthrough() const {
-  return true;
-}
-
 bool FallthroughStatement::IsValid() const {
   return true;
 }
diff --git a/src/ast/fallthrough_statement.h b/src/ast/fallthrough_statement.h
index f586666..5b0bc81 100644
--- a/src/ast/fallthrough_statement.h
+++ b/src/ast/fallthrough_statement.h
@@ -32,9 +32,6 @@
   FallthroughStatement(FallthroughStatement&&);
   ~FallthroughStatement() override;
 
-  /// @returns true if this is an fallthrough statement
-  bool IsFallthrough() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/fallthrough_statement_test.cc b/src/ast/fallthrough_statement_test.cc
index c5a936d..a41f5be 100644
--- a/src/ast/fallthrough_statement_test.cc
+++ b/src/ast/fallthrough_statement_test.cc
@@ -39,7 +39,7 @@
 
 TEST_F(FallthroughStatementTest, IsFallthrough) {
   FallthroughStatement stmt;
-  EXPECT_TRUE(stmt.IsFallthrough());
+  EXPECT_TRUE(stmt.Is<FallthroughStatement>());
 }
 
 TEST_F(FallthroughStatementTest, IsValid) {
diff --git a/src/ast/if_statement.cc b/src/ast/if_statement.cc
index bcf156c..c8406b7 100644
--- a/src/ast/if_statement.cc
+++ b/src/ast/if_statement.cc
@@ -31,10 +31,6 @@
 
 IfStatement::~IfStatement() = default;
 
-bool IfStatement::IsIf() const {
-  return true;
-}
-
 bool IfStatement::IsValid() const {
   if (condition_ == nullptr || !condition_->IsValid()) {
     return false;
diff --git a/src/ast/if_statement.h b/src/ast/if_statement.h
index 4b8a0fb..c6bb79f 100644
--- a/src/ast/if_statement.h
+++ b/src/ast/if_statement.h
@@ -71,9 +71,6 @@
   /// @returns true if there are else statements
   bool has_else_statements() const { return !else_statements_.empty(); }
 
-  /// @returns true if this is a if statement
-  bool IsIf() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/if_statement_test.cc b/src/ast/if_statement_test.cc
index 5c6d04d..45e99ba 100644
--- a/src/ast/if_statement_test.cc
+++ b/src/ast/if_statement_test.cc
@@ -49,7 +49,7 @@
 
 TEST_F(IfStatementTest, IsIf) {
   IfStatement stmt(nullptr, create<BlockStatement>());
-  EXPECT_TRUE(stmt.IsIf());
+  EXPECT_TRUE(stmt.Is<IfStatement>());
 }
 
 TEST_F(IfStatementTest, IsValid) {
diff --git a/src/ast/loop_statement.cc b/src/ast/loop_statement.cc
index fbffebe..220201f 100644
--- a/src/ast/loop_statement.cc
+++ b/src/ast/loop_statement.cc
@@ -29,10 +29,6 @@
 
 LoopStatement::~LoopStatement() = default;
 
-bool LoopStatement::IsLoop() const {
-  return true;
-}
-
 bool LoopStatement::IsValid() const {
   if (body_ == nullptr || !body_->IsValid()) {
     return false;
diff --git a/src/ast/loop_statement.h b/src/ast/loop_statement.h
index f1ed192..a107e08 100644
--- a/src/ast/loop_statement.h
+++ b/src/ast/loop_statement.h
@@ -62,9 +62,6 @@
     return continuing_ != nullptr && !continuing_->empty();
   }
 
-  /// @returns true if this is a loop statement
-  bool IsLoop() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/loop_statement_test.cc b/src/ast/loop_statement_test.cc
index 6e242e3..12001de 100644
--- a/src/ast/loop_statement_test.cc
+++ b/src/ast/loop_statement_test.cc
@@ -57,7 +57,7 @@
 
 TEST_F(LoopStatementTest, IsLoop) {
   LoopStatement l(create<BlockStatement>(), create<BlockStatement>());
-  EXPECT_TRUE(l.IsLoop());
+  EXPECT_TRUE(l.Is<LoopStatement>());
 }
 
 TEST_F(LoopStatementTest, HasContinuing_WithoutContinuing) {
diff --git a/src/ast/return_statement.cc b/src/ast/return_statement.cc
index e395dfc..138618f 100644
--- a/src/ast/return_statement.cc
+++ b/src/ast/return_statement.cc
@@ -30,10 +30,6 @@
 
 ReturnStatement::~ReturnStatement() = default;
 
-bool ReturnStatement::IsReturn() const {
-  return true;
-}
-
 bool ReturnStatement::IsValid() const {
   if (value_ != nullptr) {
     return value_->IsValid();
diff --git a/src/ast/return_statement.h b/src/ast/return_statement.h
index a1177f7..f2426fa 100644
--- a/src/ast/return_statement.h
+++ b/src/ast/return_statement.h
@@ -51,9 +51,6 @@
   /// @returns true if the return has a value
   bool has_value() const { return value_ != nullptr; }
 
-  /// @returns true if this is a return statement
-  bool IsReturn() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/return_statement_test.cc b/src/ast/return_statement_test.cc
index 42261c4..384b49e 100644
--- a/src/ast/return_statement_test.cc
+++ b/src/ast/return_statement_test.cc
@@ -41,7 +41,7 @@
 
 TEST_F(ReturnStatementTest, IsReturn) {
   ReturnStatement r;
-  EXPECT_TRUE(r.IsReturn());
+  EXPECT_TRUE(r.Is<ReturnStatement>());
 }
 
 TEST_F(ReturnStatementTest, HasValue_WithoutValue) {
diff --git a/src/ast/statement.cc b/src/ast/statement.cc
index f44286f..aa45876 100644
--- a/src/ast/statement.cc
+++ b/src/ast/statement.cc
@@ -42,247 +42,51 @@
 
 Statement::~Statement() = default;
 
-bool Statement::IsAssign() const {
-  return false;
-}
-
-bool Statement::IsBlock() const {
-  return false;
-}
-
-bool Statement::IsBreak() const {
-  return false;
-}
-
-bool Statement::IsCase() const {
-  return false;
-}
-
-bool Statement::IsCall() const {
-  return false;
-}
-
-bool Statement::IsContinue() const {
-  return false;
-}
-
-bool Statement::IsDiscard() const {
-  return false;
-}
-
-bool Statement::IsElse() const {
-  return false;
-}
-
-bool Statement::IsFallthrough() const {
-  return false;
-}
-
-bool Statement::IsIf() const {
-  return false;
-}
-
-bool Statement::IsLoop() const {
-  return false;
-}
-
-bool Statement::IsReturn() const {
-  return false;
-}
-
-bool Statement::IsSwitch() const {
-  return false;
-}
-
-bool Statement::IsVariableDecl() const {
-  return false;
-}
-
 const char* Statement::Name() const {
-  if (IsAssign()) {
+  if (Is<AssignmentStatement>()) {
     return "assignment statement";
   }
-  if (IsBlock()) {
+  if (Is<BlockStatement>()) {
     return "block statement";
   }
-  if (IsBreak()) {
+  if (Is<BreakStatement>()) {
     return "break statement";
   }
-  if (IsCase()) {
+  if (Is<CaseStatement>()) {
     return "case statement";
   }
-  if (IsCall()) {
+  if (Is<CallStatement>()) {
     return "function call";
   }
-  if (IsContinue()) {
+  if (Is<ContinueStatement>()) {
     return "continue statement";
   }
-  if (IsDiscard()) {
+  if (Is<DiscardStatement>()) {
     return "discard statement";
   }
-  if (IsElse()) {
+  if (Is<ElseStatement>()) {
     return "else statement";
   }
-  if (IsFallthrough()) {
+  if (Is<FallthroughStatement>()) {
     return "fallthrough statement";
   }
-  if (IsIf()) {
+  if (Is<IfStatement>()) {
     return "if statement";
   }
-  if (IsLoop()) {
+  if (Is<LoopStatement>()) {
     return "loop statement";
   }
-  if (IsReturn()) {
+  if (Is<ReturnStatement>()) {
     return "return statement";
   }
-  if (IsSwitch()) {
+  if (Is<SwitchStatement>()) {
     return "switch statement";
   }
-  if (IsVariableDecl()) {
+  if (Is<VariableDeclStatement>()) {
     return "variable declaration";
   }
   return "statement";
 }
 
-const AssignmentStatement* Statement::AsAssign() const {
-  assert(IsAssign());
-  return static_cast<const AssignmentStatement*>(this);
-}
-
-const BlockStatement* Statement::AsBlock() const {
-  assert(IsBlock());
-  return static_cast<const BlockStatement*>(this);
-}
-
-const BreakStatement* Statement::AsBreak() const {
-  assert(IsBreak());
-  return static_cast<const BreakStatement*>(this);
-}
-
-const CallStatement* Statement::AsCall() const {
-  assert(IsCall());
-  return static_cast<const CallStatement*>(this);
-}
-
-const CaseStatement* Statement::AsCase() const {
-  assert(IsCase());
-  return static_cast<const CaseStatement*>(this);
-}
-
-const ContinueStatement* Statement::AsContinue() const {
-  assert(IsContinue());
-  return static_cast<const ContinueStatement*>(this);
-}
-
-const DiscardStatement* Statement::AsDiscard() const {
-  assert(IsDiscard());
-  return static_cast<const DiscardStatement*>(this);
-}
-
-const ElseStatement* Statement::AsElse() const {
-  assert(IsElse());
-  return static_cast<const ElseStatement*>(this);
-}
-
-const FallthroughStatement* Statement::AsFallthrough() const {
-  assert(IsFallthrough());
-  return static_cast<const FallthroughStatement*>(this);
-}
-
-const IfStatement* Statement::AsIf() const {
-  assert(IsIf());
-  return static_cast<const IfStatement*>(this);
-}
-
-const LoopStatement* Statement::AsLoop() const {
-  assert(IsLoop());
-  return static_cast<const LoopStatement*>(this);
-}
-
-const ReturnStatement* Statement::AsReturn() const {
-  assert(IsReturn());
-  return static_cast<const ReturnStatement*>(this);
-}
-
-const SwitchStatement* Statement::AsSwitch() const {
-  assert(IsSwitch());
-  return static_cast<const SwitchStatement*>(this);
-}
-
-const VariableDeclStatement* Statement::AsVariableDecl() const {
-  assert(IsVariableDecl());
-  return static_cast<const VariableDeclStatement*>(this);
-}
-
-AssignmentStatement* Statement::AsAssign() {
-  assert(IsAssign());
-  return static_cast<AssignmentStatement*>(this);
-}
-
-BlockStatement* Statement::AsBlock() {
-  assert(IsBlock());
-  return static_cast<BlockStatement*>(this);
-}
-
-BreakStatement* Statement::AsBreak() {
-  assert(IsBreak());
-  return static_cast<BreakStatement*>(this);
-}
-
-CallStatement* Statement::AsCall() {
-  assert(IsCall());
-  return static_cast<CallStatement*>(this);
-}
-
-CaseStatement* Statement::AsCase() {
-  assert(IsCase());
-  return static_cast<CaseStatement*>(this);
-}
-
-ContinueStatement* Statement::AsContinue() {
-  assert(IsContinue());
-  return static_cast<ContinueStatement*>(this);
-}
-
-DiscardStatement* Statement::AsDiscard() {
-  assert(IsDiscard());
-  return static_cast<DiscardStatement*>(this);
-}
-
-ElseStatement* Statement::AsElse() {
-  assert(IsElse());
-  return static_cast<ElseStatement*>(this);
-}
-
-FallthroughStatement* Statement::AsFallthrough() {
-  assert(IsFallthrough());
-  return static_cast<FallthroughStatement*>(this);
-}
-
-IfStatement* Statement::AsIf() {
-  assert(IsIf());
-  return static_cast<IfStatement*>(this);
-}
-
-LoopStatement* Statement::AsLoop() {
-  assert(IsLoop());
-  return static_cast<LoopStatement*>(this);
-}
-
-ReturnStatement* Statement::AsReturn() {
-  assert(IsReturn());
-  return static_cast<ReturnStatement*>(this);
-}
-
-SwitchStatement* Statement::AsSwitch() {
-  assert(IsSwitch());
-  return static_cast<SwitchStatement*>(this);
-}
-
-VariableDeclStatement* Statement::AsVariableDecl() {
-  assert(IsVariableDecl());
-  return static_cast<VariableDeclStatement*>(this);
-}
-
 }  // namespace ast
 }  // namespace tint
diff --git a/src/ast/statement.h b/src/ast/statement.h
index 7e791ba..b38c74c 100644
--- a/src/ast/statement.h
+++ b/src/ast/statement.h
@@ -23,116 +23,14 @@
 namespace tint {
 namespace ast {
 
-class AssignmentStatement;
-class BlockStatement;
-class BreakStatement;
-class CallStatement;
-class CaseStatement;
-class ContinueStatement;
-class DiscardStatement;
-class ElseStatement;
-class FallthroughStatement;
-class IfStatement;
-class LoopStatement;
-class ReturnStatement;
-class SwitchStatement;
-class VariableDeclStatement;
-
 /// Base statement class
 class Statement : public Castable<Statement, Node> {
  public:
   ~Statement() override;
 
-  /// @returns true if this is an assign statement
-  virtual bool IsAssign() const;
-  /// @returns true if this is a block statement
-  virtual bool IsBlock() const;
-  /// @returns true if this is a break statement
-  virtual bool IsBreak() const;
-  /// @returns true if this is a call statement
-  virtual bool IsCall() const;
-  /// @returns true if this is a case statement
-  virtual bool IsCase() const;
-  /// @returns true if this is a continue statement
-  virtual bool IsContinue() const;
-  /// @returns true if this is a discard statement
-  virtual bool IsDiscard() const;
-  /// @returns true if this is an else statement
-  virtual bool IsElse() const;
-  /// @returns true if this is a fallthrough statement
-  virtual bool IsFallthrough() const;
-  /// @returns true if this is an if statement
-  virtual bool IsIf() const;
-  /// @returns true if this is a loop statement
-  virtual bool IsLoop() const;
-  /// @returns true if this is a return statement
-  virtual bool IsReturn() const;
-  /// @returns true if this is a switch statement
-  virtual bool IsSwitch() const;
-  /// @returns true if this is an variable statement
-  virtual bool IsVariableDecl() const;
-
   /// @returns the human readable name for the statement type.
   const char* Name() const;
 
-  /// @returns the statement as a const assign statement
-  const AssignmentStatement* AsAssign() const;
-  /// @returns the statement as a const block statement
-  const BlockStatement* AsBlock() const;
-  /// @returns the statement as a const break statement
-  const BreakStatement* AsBreak() const;
-  /// @returns the statement as a const call statement
-  const CallStatement* AsCall() const;
-  /// @returns the statement as a const case statement
-  const CaseStatement* AsCase() const;
-  /// @returns the statement as a const continue statement
-  const ContinueStatement* AsContinue() const;
-  /// @returns the statement as a const discard statement
-  const DiscardStatement* AsDiscard() const;
-  /// @returns the statement as a const else statement
-  const ElseStatement* AsElse() const;
-  /// @returns the statement as a const fallthrough statement
-  const FallthroughStatement* AsFallthrough() const;
-  /// @returns the statement as a const if statement
-  const IfStatement* AsIf() const;
-  /// @returns the statement as a const loop statement
-  const LoopStatement* AsLoop() const;
-  /// @returns the statement as a const return statement
-  const ReturnStatement* AsReturn() const;
-  /// @returns the statement as a const switch statement
-  const SwitchStatement* AsSwitch() const;
-  /// @returns the statement as a const variable statement
-  const VariableDeclStatement* AsVariableDecl() const;
-
-  /// @returns the statement as an assign statement
-  AssignmentStatement* AsAssign();
-  /// @returns the statement as a block statement
-  BlockStatement* AsBlock();
-  /// @returns the statement as a break statement
-  BreakStatement* AsBreak();
-  /// @returns the statement as a call statement
-  CallStatement* AsCall();
-  /// @returns the statement as a case statement
-  CaseStatement* AsCase();
-  /// @returns the statement as a continue statement
-  ContinueStatement* AsContinue();
-  /// @returns the statement as a discard statement
-  DiscardStatement* AsDiscard();
-  /// @returns the statement as a else statement
-  ElseStatement* AsElse();
-  /// @returns the statement as a fallthrough statement
-  FallthroughStatement* AsFallthrough();
-  /// @returns the statement as a if statement
-  IfStatement* AsIf();
-  /// @returns the statement as a loop statement
-  LoopStatement* AsLoop();
-  /// @returns the statement as a return statement
-  ReturnStatement* AsReturn();
-  /// @returns the statement as a switch statement
-  SwitchStatement* AsSwitch();
-  /// @returns the statement as an variable statement
-  VariableDeclStatement* AsVariableDecl();
-
  protected:
   /// Constructor
   Statement();
diff --git a/src/ast/switch_statement.cc b/src/ast/switch_statement.cc
index 65bb68f..ed0d730 100644
--- a/src/ast/switch_statement.cc
+++ b/src/ast/switch_statement.cc
@@ -29,10 +29,6 @@
                                  CaseStatementList body)
     : Base(source), condition_(condition), body_(body) {}
 
-bool SwitchStatement::IsSwitch() const {
-  return true;
-}
-
 SwitchStatement::SwitchStatement(SwitchStatement&&) = default;
 
 SwitchStatement::~SwitchStatement() = default;
diff --git a/src/ast/switch_statement.h b/src/ast/switch_statement.h
index 656918d..df7d570 100644
--- a/src/ast/switch_statement.h
+++ b/src/ast/switch_statement.h
@@ -60,9 +60,6 @@
   /// @returns the Switch body
   const CaseStatementList& body() const { return body_; }
 
-  /// @returns true if this is a switch statement
-  bool IsSwitch() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/switch_statement_test.cc b/src/ast/switch_statement_test.cc
index 49427cf..eadc87c 100644
--- a/src/ast/switch_statement_test.cc
+++ b/src/ast/switch_statement_test.cc
@@ -57,7 +57,7 @@
 
 TEST_F(SwitchStatementTest, IsSwitch) {
   SwitchStatement stmt;
-  EXPECT_TRUE(stmt.IsSwitch());
+  EXPECT_TRUE(stmt.Is<SwitchStatement>());
 }
 
 TEST_F(SwitchStatementTest, IsValid) {
diff --git a/src/ast/variable_decl_statement.cc b/src/ast/variable_decl_statement.cc
index 732c166..9b51bf4 100644
--- a/src/ast/variable_decl_statement.cc
+++ b/src/ast/variable_decl_statement.cc
@@ -30,10 +30,6 @@
 
 VariableDeclStatement::~VariableDeclStatement() = default;
 
-bool VariableDeclStatement::IsVariableDecl() const {
-  return true;
-}
-
 bool VariableDeclStatement::IsValid() const {
   return variable_ != nullptr && variable_->IsValid();
 }
diff --git a/src/ast/variable_decl_statement.h b/src/ast/variable_decl_statement.h
index a233029..d03d5ad 100644
--- a/src/ast/variable_decl_statement.h
+++ b/src/ast/variable_decl_statement.h
@@ -48,9 +48,6 @@
   /// @returns the variable
   Variable* variable() const { return variable_; }
 
-  /// @returns true if this is an variable statement
-  bool IsVariableDecl() const override;
-
   /// @returns true if the node is valid
   bool IsValid() const override;
 
diff --git a/src/ast/variable_decl_statement_test.cc b/src/ast/variable_decl_statement_test.cc
index 8e96e12..9b5f907 100644
--- a/src/ast/variable_decl_statement_test.cc
+++ b/src/ast/variable_decl_statement_test.cc
@@ -44,7 +44,7 @@
 
 TEST_F(VariableDeclStatementTest, IsVariableDecl) {
   VariableDeclStatement s;
-  EXPECT_TRUE(s.IsVariableDecl());
+  EXPECT_TRUE(s.Is<VariableDeclStatement>());
 }
 
 TEST_F(VariableDeclStatementTest, IsValid) {
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index bc2d5c7..133fac0 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -561,8 +561,8 @@
   const auto& top = statements_stack_.back();
   auto* cond = create<ast::IdentifierExpression>(guard_name);
   auto* body = create<ast::BlockStatement>();
-  auto* const guard_stmt =
-      AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
+  auto* const guard_stmt = AddStatement(create<ast::IfStatement>(cond, body))
+                               ->As<ast::IfStatement>();
   PushNewStatementBlock(top.construct_, end_id,
                         [guard_stmt](StatementBlock* s) {
                           guard_stmt->set_body(s->statements_);
@@ -574,8 +574,8 @@
   const auto& top = statements_stack_.back();
   auto* cond = MakeTrue();
   auto* body = create<ast::BlockStatement>();
-  auto* const guard_stmt =
-      AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
+  auto* const guard_stmt = AddStatement(create<ast::IfStatement>(cond, body))
+                               ->As<ast::IfStatement>();
   guard_stmt->set_condition(MakeTrue());
   PushNewStatementBlock(top.construct_, end_id,
                         [guard_stmt](StatementBlock* s) {
@@ -2023,8 +2023,8 @@
       block_info.basic_block->terminator()->GetSingleWordInOperand(0);
   auto* cond = MakeExpression(condition_id).expr;
   auto* body = create<ast::BlockStatement>();
-  auto* const if_stmt =
-      AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
+  auto* const if_stmt = AddStatement(create<ast::IfStatement>(cond, body))
+                            ->As<ast::IfStatement>();
 
   // Generate the code for the condition.
 
@@ -2137,7 +2137,7 @@
   const auto* branch = block_info.basic_block->terminator();
 
   auto* const switch_stmt =
-      AddStatement(create<ast::SwitchStatement>())->AsSwitch();
+      AddStatement(create<ast::SwitchStatement>())->As<ast::SwitchStatement>();
   const auto selector_id = branch->GetSingleWordInOperand(0);
   // Generate the code for the selector.
   auto selector = MakeExpression(selector_id);
@@ -2255,7 +2255,7 @@
   auto* loop =
       AddStatement(create<ast::LoopStatement>(create<ast::BlockStatement>(),
                                               create<ast::BlockStatement>()))
-          ->AsLoop();
+          ->As<ast::LoopStatement>();
   PushNewStatementBlock(
       construct, construct->end_id,
       [loop](StatementBlock* s) { loop->set_body(s->statements_); });
@@ -2266,11 +2266,11 @@
   // A continue construct has the same depth as its associated loop
   // construct. Start a continue construct.
   auto* loop_candidate = LastStatement();
-  if (!loop_candidate->IsLoop()) {
+  if (!loop_candidate->Is<ast::LoopStatement>()) {
     return Fail() << "internal error: starting continue construct, "
                      "expected loop on top of stack";
   }
-  auto* loop = loop_candidate->AsLoop();
+  auto* loop = loop_candidate->As<ast::LoopStatement>();
   PushNewStatementBlock(
       construct, construct->end_id,
       [loop](StatementBlock* s) { loop->set_continuing(s->statements_); });
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index 1e32028..883a713 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -27,16 +27,21 @@
 #include "src/ast/access_control.h"
 #include "src/ast/array_decoration.h"
 #include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
 #include "src/ast/builtin.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/case_statement.h"
 #include "src/ast/constructor_expression.h"
+#include "src/ast/continue_statement.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/switch_statement.h"
 #include "src/ast/function.h"
+#include "src/ast/if_statement.h"
 #include "src/ast/literal.h"
 #include "src/ast/loop_statement.h"
 #include "src/ast/module.h"
 #include "src/ast/pipeline_stage.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/statement.h"
 #include "src/ast/storage_class.h"
 #include "src/ast/struct.h"
@@ -44,10 +49,11 @@
 #include "src/ast/struct_member.h"
 #include "src/ast/struct_member_decoration.h"
 #include "src/ast/type/storage_texture_type.h"
+#include "src/ast/type/struct_type.h"
 #include "src/ast/type/texture_type.h"
 #include "src/ast/type/type.h"
-#include "src/ast/type/struct_type.h"
 #include "src/ast/variable.h"
+#include "src/ast/variable_decl_statement.h"
 #include "src/ast/variable_decoration.h"
 #include "src/context.h"
 #include "src/diagnostic/diagnostic.h"
diff --git a/src/reader/wgsl/parser_impl_assignment_stmt_test.cc b/src/reader/wgsl/parser_impl_assignment_stmt_test.cc
index b641ced..dfcadfe 100644
--- a/src/reader/wgsl/parser_impl_assignment_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_assignment_stmt_test.cc
@@ -36,7 +36,7 @@
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
 
-  ASSERT_TRUE(e->IsAssign());
+  ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
   ASSERT_NE(e->lhs(), nullptr);
   ASSERT_NE(e->rhs(), nullptr);
 
@@ -61,7 +61,7 @@
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
 
-  ASSERT_TRUE(e->IsAssign());
+  ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
   ASSERT_NE(e->lhs(), nullptr);
   ASSERT_NE(e->rhs(), nullptr);
 
diff --git a/src/reader/wgsl/parser_impl_body_stmt_test.cc b/src/reader/wgsl/parser_impl_body_stmt_test.cc
index e4e21ee..b5413ef 100644
--- a/src/reader/wgsl/parser_impl_body_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_body_stmt_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
 #include "src/reader/wgsl/parser_impl.h"
 #include "src/reader/wgsl/parser_impl_test_helper.h"
 
@@ -30,8 +31,8 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   ASSERT_FALSE(e.errored);
   ASSERT_EQ(e->size(), 2u);
-  EXPECT_TRUE(e->get(0)->IsDiscard());
-  EXPECT_TRUE(e->get(1)->IsReturn());
+  EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
+  EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, BodyStmt_Empty) {
diff --git a/src/reader/wgsl/parser_impl_break_stmt_test.cc b/src/reader/wgsl/parser_impl_break_stmt_test.cc
index fdd7438..1c7b0d7 100644
--- a/src/reader/wgsl/parser_impl_break_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_break_stmt_test.cc
@@ -28,7 +28,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsBreak());
+  ASSERT_TRUE(e->Is<ast::BreakStatement>());
 }
 
 }  // namespace
diff --git a/src/reader/wgsl/parser_impl_call_stmt_test.cc b/src/reader/wgsl/parser_impl_call_stmt_test.cc
index cfacc7b..a0ce2ed 100644
--- a/src/reader/wgsl/parser_impl_call_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_call_stmt_test.cc
@@ -32,8 +32,8 @@
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
 
-  ASSERT_TRUE(e->IsCall());
-  auto* c = e->AsCall()->expr();
+  ASSERT_TRUE(e->Is<ast::CallStatement>());
+  auto* c = e->As<ast::CallStatement>()->expr();
 
   ASSERT_TRUE(c->func()->IsIdentifier());
   auto* func = c->func()->AsIdentifier();
@@ -50,8 +50,8 @@
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
 
-  ASSERT_TRUE(e->IsCall());
-  auto* c = e->AsCall()->expr();
+  ASSERT_TRUE(e->Is<ast::CallStatement>());
+  auto* c = e->As<ast::CallStatement>()->expr();
 
   ASSERT_TRUE(c->func()->IsIdentifier());
   auto* func = c->func()->AsIdentifier();
diff --git a/src/reader/wgsl/parser_impl_case_body_test.cc b/src/reader/wgsl/parser_impl_case_body_test.cc
index 4abb8a9..f3bc05c 100644
--- a/src/reader/wgsl/parser_impl_case_body_test.cc
+++ b/src/reader/wgsl/parser_impl_case_body_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/reader/wgsl/parser_impl.h"
 #include "src/reader/wgsl/parser_impl_test_helper.h"
 
@@ -40,8 +41,8 @@
   EXPECT_FALSE(e.errored);
   EXPECT_TRUE(e.matched);
   ASSERT_EQ(e->size(), 2u);
-  EXPECT_TRUE(e->get(0)->IsVariableDecl());
-  EXPECT_TRUE(e->get(1)->IsAssign());
+  EXPECT_TRUE(e->get(0)->Is<ast::VariableDeclStatement>());
+  EXPECT_TRUE(e->get(1)->Is<ast::AssignmentStatement>());
 }
 
 TEST_F(ParserImplTest, CaseBody_InvalidStatement) {
@@ -60,7 +61,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_TRUE(e.matched);
   ASSERT_EQ(e->size(), 1u);
-  EXPECT_TRUE(e->get(0)->IsFallthrough());
+  EXPECT_TRUE(e->get(0)->Is<ast::FallthroughStatement>());
 }
 
 TEST_F(ParserImplTest, CaseBody_Fallthrough_MissingSemicolon) {
diff --git a/src/reader/wgsl/parser_impl_continue_stmt_test.cc b/src/reader/wgsl/parser_impl_continue_stmt_test.cc
index e2e6e01..e9b72a1 100644
--- a/src/reader/wgsl/parser_impl_continue_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_continue_stmt_test.cc
@@ -28,7 +28,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsContinue());
+  ASSERT_TRUE(e->Is<ast::ContinueStatement>());
 }
 
 }  // namespace
diff --git a/src/reader/wgsl/parser_impl_continuing_stmt_test.cc b/src/reader/wgsl/parser_impl_continuing_stmt_test.cc
index 4e4a7d1..983d98f 100644
--- a/src/reader/wgsl/parser_impl_continuing_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_continuing_stmt_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
 #include "src/reader/wgsl/parser_impl.h"
 #include "src/reader/wgsl/parser_impl_test_helper.h"
 
@@ -28,7 +29,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_EQ(e->size(), 1u);
-  ASSERT_TRUE(e->get(0)->IsDiscard());
+  ASSERT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
 }
 
 TEST_F(ParserImplTest, ContinuingStmt_InvalidBody) {
diff --git a/src/reader/wgsl/parser_impl_else_stmt_test.cc b/src/reader/wgsl/parser_impl_else_stmt_test.cc
index 98ad2b0..a0e86a0 100644
--- a/src/reader/wgsl/parser_impl_else_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_else_stmt_test.cc
@@ -29,7 +29,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsElse());
+  ASSERT_TRUE(e->Is<ast::ElseStatement>());
   ASSERT_EQ(e->condition(), nullptr);
   EXPECT_EQ(e->body()->size(), 2u);
 }
diff --git a/src/reader/wgsl/parser_impl_elseif_stmt_test.cc b/src/reader/wgsl/parser_impl_elseif_stmt_test.cc
index 22ebd97..1c27a9b 100644
--- a/src/reader/wgsl/parser_impl_elseif_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_elseif_stmt_test.cc
@@ -30,7 +30,7 @@
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_EQ(e.value.size(), 1u);
 
-  ASSERT_TRUE(e.value[0]->IsElse());
+  ASSERT_TRUE(e.value[0]->Is<ast::ElseStatement>());
   ASSERT_NE(e.value[0]->condition(), nullptr);
   ASSERT_TRUE(e.value[0]->condition()->IsBinary());
   EXPECT_EQ(e.value[0]->body()->size(), 2u);
@@ -44,12 +44,12 @@
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_EQ(e.value.size(), 2u);
 
-  ASSERT_TRUE(e.value[0]->IsElse());
+  ASSERT_TRUE(e.value[0]->Is<ast::ElseStatement>());
   ASSERT_NE(e.value[0]->condition(), nullptr);
   ASSERT_TRUE(e.value[0]->condition()->IsBinary());
   EXPECT_EQ(e.value[0]->body()->size(), 2u);
 
-  ASSERT_TRUE(e.value[1]->IsElse());
+  ASSERT_TRUE(e.value[1]->Is<ast::ElseStatement>());
   ASSERT_NE(e.value[1]->condition(), nullptr);
   ASSERT_TRUE(e.value[1]->condition()->IsIdentifier());
   EXPECT_EQ(e.value[1]->body()->size(), 1u);
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc
index ca155f2..7ba09b3 100644
--- a/src/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -50,7 +50,7 @@
 
   auto* body = f->body();
   ASSERT_EQ(body->size(), 1u);
-  EXPECT_TRUE(body->get(0)->IsReturn());
+  EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
@@ -86,7 +86,7 @@
 
   auto* body = f->body();
   ASSERT_EQ(body->size(), 1u);
-  EXPECT_TRUE(body->get(0)->IsReturn());
+  EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleEntries) {
@@ -130,7 +130,7 @@
 
   auto* body = f->body();
   ASSERT_EQ(body->size(), 1u);
-  EXPECT_TRUE(body->get(0)->IsReturn());
+  EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleLists) {
@@ -175,7 +175,7 @@
 
   auto* body = f->body();
   ASSERT_EQ(body->size(), 1u);
-  EXPECT_TRUE(body->get(0)->IsReturn());
+  EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
diff --git a/src/reader/wgsl/parser_impl_if_stmt_test.cc b/src/reader/wgsl/parser_impl_if_stmt_test.cc
index 4fd989e..fc1ac37 100644
--- a/src/reader/wgsl/parser_impl_if_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_if_stmt_test.cc
@@ -31,7 +31,7 @@
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
 
-  ASSERT_TRUE(e->IsIf());
+  ASSERT_TRUE(e->Is<ast::IfStatement>());
   ASSERT_NE(e->condition(), nullptr);
   ASSERT_TRUE(e->condition()->IsBinary());
   EXPECT_EQ(e->body()->size(), 2u);
@@ -46,7 +46,7 @@
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
 
-  ASSERT_TRUE(e->IsIf());
+  ASSERT_TRUE(e->Is<ast::IfStatement>());
   ASSERT_NE(e->condition(), nullptr);
   ASSERT_TRUE(e->condition()->IsBinary());
   EXPECT_EQ(e->body()->size(), 2u);
diff --git a/src/reader/wgsl/parser_impl_loop_stmt_test.cc b/src/reader/wgsl/parser_impl_loop_stmt_test.cc
index 043afca..2a0ba20 100644
--- a/src/reader/wgsl/parser_impl_loop_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_loop_stmt_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
 #include "src/reader/wgsl/parser_impl.h"
 #include "src/reader/wgsl/parser_impl_test_helper.h"
 
@@ -30,7 +31,7 @@
   ASSERT_NE(e.value, nullptr);
 
   ASSERT_EQ(e->body()->size(), 1u);
-  EXPECT_TRUE(e->body()->get(0)->IsDiscard());
+  EXPECT_TRUE(e->body()->get(0)->Is<ast::DiscardStatement>());
 
   EXPECT_EQ(e->continuing()->size(), 0u);
 }
@@ -44,10 +45,10 @@
   ASSERT_NE(e.value, nullptr);
 
   ASSERT_EQ(e->body()->size(), 1u);
-  EXPECT_TRUE(e->body()->get(0)->IsDiscard());
+  EXPECT_TRUE(e->body()->get(0)->Is<ast::DiscardStatement>());
 
   EXPECT_EQ(e->continuing()->size(), 1u);
-  EXPECT_TRUE(e->continuing()->get(0)->IsDiscard());
+  EXPECT_TRUE(e->continuing()->get(0)->Is<ast::DiscardStatement>());
 }
 
 TEST_F(ParserImplTest, LoopStmt_NoBodyNoContinuing) {
@@ -70,7 +71,7 @@
   ASSERT_NE(e.value, nullptr);
   ASSERT_EQ(e->body()->size(), 0u);
   ASSERT_EQ(e->continuing()->size(), 1u);
-  EXPECT_TRUE(e->continuing()->get(0)->IsDiscard());
+  EXPECT_TRUE(e->continuing()->get(0)->Is<ast::DiscardStatement>());
 }
 
 TEST_F(ParserImplTest, LoopStmt_MissingBracketLeft) {
diff --git a/src/reader/wgsl/parser_impl_statement_test.cc b/src/reader/wgsl/parser_impl_statement_test.cc
index 85489de..ec96357 100644
--- a/src/reader/wgsl/parser_impl_statement_test.cc
+++ b/src/reader/wgsl/parser_impl_statement_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/statement.h"
 #include "src/reader/wgsl/parser_impl.h"
@@ -29,7 +30,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsReturn());
+  ASSERT_TRUE(e->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Semicolon) {
@@ -44,8 +45,8 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsReturn());
-  auto* ret = e->AsReturn();
+  ASSERT_TRUE(e->Is<ast::ReturnStatement>());
+  auto* ret = e->As<ast::ReturnStatement>();
   ASSERT_EQ(ret->value(), nullptr);
 }
 
@@ -56,8 +57,8 @@
 
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsReturn());
-  auto* ret = e->AsReturn();
+  ASSERT_TRUE(e->Is<ast::ReturnStatement>());
+  auto* ret = e->As<ast::ReturnStatement>();
   ASSERT_NE(ret->value(), nullptr);
   EXPECT_TRUE(ret->value()->IsBinary());
 }
@@ -88,7 +89,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsIf());
+  ASSERT_TRUE(e->Is<ast::IfStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_If_Invalid) {
@@ -107,7 +108,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsVariableDecl());
+  ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Variable_Invalid) {
@@ -136,7 +137,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsSwitch());
+  ASSERT_TRUE(e->Is<ast::SwitchStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Switch_Invalid) {
@@ -155,7 +156,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsLoop());
+  ASSERT_TRUE(e->Is<ast::LoopStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Loop_Invalid) {
@@ -174,7 +175,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsAssign());
+  ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Assignment_Invalid) {
@@ -203,7 +204,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsBreak());
+  ASSERT_TRUE(e->Is<ast::BreakStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Break_MissingSemicolon) {
@@ -222,7 +223,7 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsContinue());
+  ASSERT_TRUE(e->Is<ast::ContinueStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Continue_MissingSemicolon) {
@@ -242,7 +243,7 @@
   ASSERT_NE(e.value, nullptr);
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsDiscard());
+  ASSERT_TRUE(e->Is<ast::DiscardStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Discard_MissingSemicolon) {
@@ -261,8 +262,9 @@
   ASSERT_FALSE(p->has_error()) << p->error();
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
-  ASSERT_TRUE(e->IsBlock());
-  EXPECT_TRUE(e->AsBlock()->get(0)->IsVariableDecl());
+  ASSERT_TRUE(e->Is<ast::BlockStatement>());
+  EXPECT_TRUE(
+      e->As<ast::BlockStatement>()->get(0)->Is<ast::VariableDeclStatement>());
 }
 
 TEST_F(ParserImplTest, Statement_Body_Invalid) {
diff --git a/src/reader/wgsl/parser_impl_statements_test.cc b/src/reader/wgsl/parser_impl_statements_test.cc
index 55bd919..88b6012 100644
--- a/src/reader/wgsl/parser_impl_statements_test.cc
+++ b/src/reader/wgsl/parser_impl_statements_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
 #include "src/ast/statement.h"
 #include "src/reader/wgsl/parser_impl.h"
 #include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -28,8 +29,8 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_EQ(e->size(), 2u);
-  EXPECT_TRUE(e->get(0)->IsDiscard());
-  EXPECT_TRUE(e->get(1)->IsReturn());
+  EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
+  EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, Statements_Empty) {
diff --git a/src/reader/wgsl/parser_impl_switch_body_test.cc b/src/reader/wgsl/parser_impl_switch_body_test.cc
index 74ac8a3..73ab068 100644
--- a/src/reader/wgsl/parser_impl_switch_body_test.cc
+++ b/src/reader/wgsl/parser_impl_switch_body_test.cc
@@ -29,10 +29,10 @@
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsCase());
+  ASSERT_TRUE(e->Is<ast::CaseStatement>());
   EXPECT_FALSE(e->IsDefault());
   ASSERT_EQ(e->body()->size(), 1u);
-  EXPECT_TRUE(e->body()->get(0)->IsAssign());
+  EXPECT_TRUE(e->body()->get(0)->Is<ast::AssignmentStatement>());
 }
 
 TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
@@ -112,10 +112,10 @@
   EXPECT_TRUE(e.matched);
   EXPECT_FALSE(e.errored);
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsCase());
+  ASSERT_TRUE(e->Is<ast::CaseStatement>());
   EXPECT_TRUE(e->IsDefault());
   ASSERT_EQ(e->body()->size(), 1u);
-  EXPECT_TRUE(e->body()->get(0)->IsAssign());
+  EXPECT_TRUE(e->body()->get(0)->Is<ast::AssignmentStatement>());
 }
 
 TEST_F(ParserImplTest, SwitchBody_Default_MissingColon) {
diff --git a/src/reader/wgsl/parser_impl_switch_stmt_test.cc b/src/reader/wgsl/parser_impl_switch_stmt_test.cc
index b4d3394..9d2bfbb 100644
--- a/src/reader/wgsl/parser_impl_switch_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_switch_stmt_test.cc
@@ -33,7 +33,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsSwitch());
+  ASSERT_TRUE(e->Is<ast::SwitchStatement>());
   ASSERT_EQ(e->body().size(), 2u);
   EXPECT_FALSE(e->body()[0]->IsDefault());
   EXPECT_FALSE(e->body()[1]->IsDefault());
@@ -46,7 +46,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsSwitch());
+  ASSERT_TRUE(e->Is<ast::SwitchStatement>());
   ASSERT_EQ(e->body().size(), 0u);
 }
 
@@ -61,7 +61,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsSwitch());
+  ASSERT_TRUE(e->Is<ast::SwitchStatement>());
 
   ASSERT_EQ(e->body().size(), 3u);
   ASSERT_FALSE(e->body()[0]->IsDefault());
diff --git a/src/reader/wgsl/parser_impl_variable_stmt_test.cc b/src/reader/wgsl/parser_impl_variable_stmt_test.cc
index fe6cc5a..8f83c31 100644
--- a/src/reader/wgsl/parser_impl_variable_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_variable_stmt_test.cc
@@ -30,7 +30,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsVariableDecl());
+  ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
   ASSERT_NE(e->variable(), nullptr);
   EXPECT_EQ(e->variable()->name(), "a");
 
@@ -49,7 +49,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsVariableDecl());
+  ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
   ASSERT_NE(e->variable(), nullptr);
   EXPECT_EQ(e->variable()->name(), "a");
 
@@ -89,7 +89,7 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_NE(e.value, nullptr);
-  ASSERT_TRUE(e->IsVariableDecl());
+  ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
 
   ASSERT_EQ(e->source().range.begin.line, 1u);
   ASSERT_EQ(e->source().range.begin.column, 7u);
diff --git a/src/transform/bound_array_accessors_transform.cc b/src/transform/bound_array_accessors_transform.cc
index 61e4a4a..1715c8c 100644
--- a/src/transform/bound_array_accessors_transform.cc
+++ b/src/transform/bound_array_accessors_transform.cc
@@ -21,10 +21,14 @@
 #include "src/ast/binary_expression.h"
 #include "src/ast/bitcast_expression.h"
 #include "src/ast/block_statement.h"
+#include "src/ast/break_statement.h"
 #include "src/ast/call_expression.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/loop_statement.h"
 #include "src/ast/member_accessor_expression.h"
@@ -67,52 +71,47 @@
 }
 
 bool BoundArrayAccessorsTransform::ProcessStatement(ast::Statement* stmt) {
-  if (stmt->IsAssign()) {
-    auto* as = stmt->AsAssign();
+  if (auto* as = stmt->As<ast::AssignmentStatement>()) {
     return ProcessExpression(as->lhs()) && ProcessExpression(as->rhs());
-  } else if (stmt->IsBlock()) {
-    for (auto* s : *(stmt->AsBlock())) {
+  } else if (auto* block = stmt->As<ast::BlockStatement>()) {
+    for (auto* s : *block) {
       if (!ProcessStatement(s)) {
         return false;
       }
     }
-  } else if (stmt->IsBreak()) {
+  } else if (stmt->Is<ast::BreakStatement>()) {
     /* nop */
-  } else if (stmt->IsCall()) {
-    return ProcessExpression(stmt->AsCall()->expr());
-  } else if (stmt->IsCase()) {
-    return ProcessStatement(stmt->AsCase()->body());
-  } else if (stmt->IsContinue()) {
+  } else if (auto* call = stmt->As<ast::CallStatement>()) {
+    return ProcessExpression(call->expr());
+  } else if (auto* kase = stmt->As<ast::CaseStatement>()) {
+    return ProcessStatement(kase->body());
+  } else if (stmt->Is<ast::ContinueStatement>()) {
     /* nop */
-  } else if (stmt->IsDiscard()) {
+  } else if (stmt->Is<ast::DiscardStatement>()) {
     /* nop */
-  } else if (stmt->IsElse()) {
-    auto* e = stmt->AsElse();
+  } else if (auto* e = stmt->As<ast::ElseStatement>()) {
     return ProcessExpression(e->condition()) && ProcessStatement(e->body());
-  } else if (stmt->IsFallthrough()) {
+  } else if (stmt->Is<ast::FallthroughStatement>()) {
     /* nop */
-  } else if (stmt->IsIf()) {
-    auto* e = stmt->AsIf();
-    if (!ProcessExpression(e->condition()) || !ProcessStatement(e->body())) {
+  } else if (auto* i = stmt->As<ast::IfStatement>()) {
+    if (!ProcessExpression(i->condition()) || !ProcessStatement(i->body())) {
       return false;
     }
-    for (auto* s : e->else_statements()) {
+    for (auto* s : i->else_statements()) {
       if (!ProcessStatement(s)) {
         return false;
       }
     }
-  } else if (stmt->IsLoop()) {
-    auto* l = stmt->AsLoop();
+  } else if (auto* l = stmt->As<ast::LoopStatement>()) {
     if (l->has_continuing() && !ProcessStatement(l->continuing())) {
       return false;
     }
     return ProcessStatement(l->body());
-  } else if (stmt->IsReturn()) {
-    if (stmt->AsReturn()->has_value()) {
-      return ProcessExpression(stmt->AsReturn()->value());
+  } else if (auto* r = stmt->As<ast::ReturnStatement>()) {
+    if (r->has_value()) {
+      return ProcessExpression(r->value());
     }
-  } else if (stmt->IsSwitch()) {
-    auto* s = stmt->AsSwitch();
+  } else if (auto* s = stmt->As<ast::SwitchStatement>()) {
     if (!ProcessExpression(s->condition())) {
       return false;
     }
@@ -122,8 +121,8 @@
         return false;
       }
     }
-  } else if (stmt->IsVariableDecl()) {
-    auto* v = stmt->AsVariableDecl()->variable();
+  } else if (auto* vd = stmt->As<ast::VariableDeclStatement>()) {
+    auto* v = vd->variable();
     if (v->has_constructor() && !ProcessExpression(v->constructor())) {
       return false;
     }
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 40f55a7..6a83e74 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -27,7 +27,9 @@
 #include "src/ast/call_statement.h"
 #include "src/ast/case_statement.h"
 #include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/intrinsic.h"
@@ -178,11 +180,11 @@
 }
 
 bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) {
-  if (!stmt->IsVariableDecl()) {
+  if (!stmt->Is<ast::VariableDeclStatement>()) {
     return true;
   }
 
-  auto* var = stmt->AsVariableDecl()->variable();
+  auto* var = stmt->As<ast::VariableDeclStatement>()->variable();
   // Nothing to do for const
   if (var->is_const()) {
     return true;
@@ -203,39 +205,35 @@
 }
 
 bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
-  if (stmt->IsAssign()) {
-    auto* a = stmt->AsAssign();
+  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
     return DetermineResultType(a->lhs()) && DetermineResultType(a->rhs());
   }
-  if (stmt->IsBlock()) {
-    return DetermineStatements(stmt->AsBlock());
+  if (auto* b = stmt->As<ast::BlockStatement>()) {
+    return DetermineStatements(b);
   }
-  if (stmt->IsBreak()) {
+  if (stmt->Is<ast::BreakStatement>()) {
     return true;
   }
-  if (stmt->IsCall()) {
-    return DetermineResultType(stmt->AsCall()->expr());
+  if (auto* c = stmt->As<ast::CallStatement>()) {
+    return DetermineResultType(c->expr());
   }
-  if (stmt->IsCase()) {
-    auto* c = stmt->AsCase();
+  if (auto* c = stmt->As<ast::CaseStatement>()) {
     return DetermineStatements(c->body());
   }
-  if (stmt->IsContinue()) {
+  if (stmt->Is<ast::ContinueStatement>()) {
     return true;
   }
-  if (stmt->IsDiscard()) {
+  if (stmt->Is<ast::DiscardStatement>()) {
     return true;
   }
-  if (stmt->IsElse()) {
-    auto* e = stmt->AsElse();
+  if (auto* e = stmt->As<ast::ElseStatement>()) {
     return DetermineResultType(e->condition()) &&
            DetermineStatements(e->body());
   }
-  if (stmt->IsFallthrough()) {
+  if (stmt->Is<ast::FallthroughStatement>()) {
     return true;
   }
-  if (stmt->IsIf()) {
-    auto* i = stmt->AsIf();
+  if (auto* i = stmt->As<ast::IfStatement>()) {
     if (!DetermineResultType(i->condition()) ||
         !DetermineStatements(i->body())) {
       return false;
@@ -248,17 +246,14 @@
     }
     return true;
   }
-  if (stmt->IsLoop()) {
-    auto* l = stmt->AsLoop();
+  if (auto* l = stmt->As<ast::LoopStatement>()) {
     return DetermineStatements(l->body()) &&
            DetermineStatements(l->continuing());
   }
-  if (stmt->IsReturn()) {
-    auto* r = stmt->AsReturn();
+  if (auto* r = stmt->As<ast::ReturnStatement>()) {
     return DetermineResultType(r->value());
   }
-  if (stmt->IsSwitch()) {
-    auto* s = stmt->AsSwitch();
+  if (auto* s = stmt->As<ast::SwitchStatement>()) {
     if (!DetermineResultType(s->condition())) {
       return false;
     }
@@ -269,8 +264,7 @@
     }
     return true;
   }
-  if (stmt->IsVariableDecl()) {
-    auto* v = stmt->AsVariableDecl();
+  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
     variable_stack_.set(v->variable()->name(), v->variable());
     return DetermineResultType(v->variable()->constructor());
   }
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 6567f37..eea936e 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -19,6 +19,7 @@
 #include <utility>
 
 #include "src/ast/call_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/function.h"
 #include "src/ast/int_literal.h"
 #include "src/ast/intrinsic.h"
@@ -208,7 +209,7 @@
 
   if (!current_function_->return_type()->Is<ast::type::VoidType>()) {
     if (!func->get_last_statement() ||
-        !func->get_last_statement()->IsReturn()) {
+        !func->get_last_statement()->Is<ast::ReturnStatement>()) {
       add_error(func->source(), "v-0002",
                 "non-void function must end with a return statement");
       return false;
@@ -284,29 +285,28 @@
   if (!stmt) {
     return false;
   }
-  if (stmt->IsVariableDecl()) {
-    auto* v = stmt->AsVariableDecl();
+  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
     bool constructor_valid =
         v->variable()->has_constructor()
             ? ValidateExpression(v->variable()->constructor())
             : true;
 
-    return constructor_valid && ValidateDeclStatement(stmt->AsVariableDecl());
+    return constructor_valid && ValidateDeclStatement(v);
   }
-  if (stmt->IsAssign()) {
-    return ValidateAssign(stmt->AsAssign());
+  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+    return ValidateAssign(a);
   }
-  if (stmt->IsReturn()) {
-    return ValidateReturnStatement(stmt->AsReturn());
+  if (auto* r = stmt->As<ast::ReturnStatement>()) {
+    return ValidateReturnStatement(r);
   }
-  if (stmt->IsCall()) {
-    return ValidateCallExpr(stmt->AsCall()->expr());
+  if (auto* c = stmt->As<ast::CallStatement>()) {
+    return ValidateCallExpr(c->expr());
   }
-  if (stmt->IsSwitch()) {
-    return ValidateSwitch(stmt->AsSwitch());
+  if (auto* s = stmt->As<ast::SwitchStatement>()) {
+    return ValidateSwitch(s);
   }
-  if (stmt->IsCase()) {
-    return ValidateCase(stmt->AsCase());
+  if (auto* c = stmt->As<ast::CaseStatement>()) {
+    return ValidateCase(c);
   }
   return true;
 }
@@ -368,8 +368,10 @@
   }
 
   auto* last_clause = s->body().back();
-  auto* last_stmt_of_last_clause = last_clause->AsCase()->body()->last();
-  if (last_stmt_of_last_clause && last_stmt_of_last_clause->IsFallthrough()) {
+  auto* last_stmt_of_last_clause =
+      last_clause->As<ast::CaseStatement>()->body()->last();
+  if (last_stmt_of_last_clause &&
+      last_stmt_of_last_clause->Is<ast::FallthroughStatement>()) {
     add_error(last_stmt_of_last_clause->source(), "v-0028",
               "a fallthrough statement must not appear as "
               "the last statement in last clause of a switch");
diff --git a/src/validator/validator_impl.h b/src/validator/validator_impl.h
index 9cb4854..3d8e229 100644
--- a/src/validator/validator_impl.h
+++ b/src/validator/validator_impl.h
@@ -26,7 +26,9 @@
 #include "src/ast/module.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/statement.h"
+#include "src/ast/switch_statement.h"
 #include "src/ast/variable.h"
+#include "src/ast/variable_decl_statement.h"
 #include "src/diagnostic/diagnostic.h"
 #include "src/diagnostic/formatter.h"
 #include "src/scope_stack.h"
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 38a974f..0ffbc06 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -29,6 +29,7 @@
 #include "src/ast/case_statement.h"
 #include "src/ast/decorated_variable.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
@@ -75,7 +76,8 @@
     return false;
   }
 
-  return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
+  return stmts->last()->Is<ast::BreakStatement>() ||
+         stmts->last()->Is<ast::FallthroughStatement>();
 }
 
 std::string get_buffer_name(ast::Expression* expr) {
@@ -1601,11 +1603,10 @@
     // the for loop into the continuing scope. Then, the variable declarations
     // will be turned into assignments.
     for (auto* s : *stmt->body()) {
-      if (!s->IsVariableDecl()) {
-        continue;
-      }
-      if (!EmitVariable(out, s->AsVariableDecl()->variable(), true)) {
-        return false;
+      if (auto* v = s->As<ast::VariableDeclStatement>()) {
+        if (!EmitVariable(out, v->variable(), true)) {
+          return false;
+        }
       }
     }
   }
@@ -1630,10 +1631,11 @@
   for (auto* s : *(stmt->body())) {
     // If we have a continuing block we've already emitted the variable
     // declaration before the loop, so treat it as an assignment.
-    if (s->IsVariableDecl() && stmt->has_continuing()) {
+    auto* decl = s->As<ast::VariableDeclStatement>();
+    if (decl != nullptr && stmt->has_continuing()) {
       make_indent(out);
 
-      auto* var = s->AsVariableDecl()->variable();
+      auto* var = decl->variable();
 
       std::ostringstream pre;
       std::ostringstream constructor_out;
@@ -1963,51 +1965,51 @@
 }
 
 bool GeneratorImpl::EmitStatement(std::ostream& out, ast::Statement* stmt) {
-  if (stmt->IsAssign()) {
-    return EmitAssign(out, stmt->AsAssign());
+  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+    return EmitAssign(out, a);
   }
-  if (stmt->IsBlock()) {
-    return EmitIndentedBlockAndNewline(out, stmt->AsBlock());
+  if (auto* b = stmt->As<ast::BlockStatement>()) {
+    return EmitIndentedBlockAndNewline(out, b);
   }
-  if (stmt->IsBreak()) {
-    return EmitBreak(out, stmt->AsBreak());
+  if (auto* b = stmt->As<ast::BreakStatement>()) {
+    return EmitBreak(out, b);
   }
-  if (stmt->IsCall()) {
+  if (auto* c = stmt->As<ast::CallStatement>()) {
     make_indent(out);
     std::ostringstream pre;
     std::ostringstream call_out;
-    if (!EmitCall(pre, call_out, stmt->AsCall()->expr())) {
+    if (!EmitCall(pre, call_out, c->expr())) {
       return false;
     }
     out << pre.str();
     out << call_out.str() << ";" << std::endl;
     return true;
   }
-  if (stmt->IsContinue()) {
-    return EmitContinue(out, stmt->AsContinue());
+  if (auto* c = stmt->As<ast::ContinueStatement>()) {
+    return EmitContinue(out, c);
   }
-  if (stmt->IsDiscard()) {
-    return EmitDiscard(out, stmt->AsDiscard());
+  if (auto* d = stmt->As<ast::DiscardStatement>()) {
+    return EmitDiscard(out, d);
   }
-  if (stmt->IsFallthrough()) {
+  if (auto* f = stmt->As<ast::FallthroughStatement>()) {
     make_indent(out);
     out << "/* fallthrough */" << std::endl;
     return true;
   }
-  if (stmt->IsIf()) {
-    return EmitIf(out, stmt->AsIf());
+  if (auto* i = stmt->As<ast::IfStatement>()) {
+    return EmitIf(out, i);
   }
-  if (stmt->IsLoop()) {
-    return EmitLoop(out, stmt->AsLoop());
+  if (auto* l = stmt->As<ast::LoopStatement>()) {
+    return EmitLoop(out, l);
   }
-  if (stmt->IsReturn()) {
-    return EmitReturn(out, stmt->AsReturn());
+  if (auto* r = stmt->As<ast::ReturnStatement>()) {
+    return EmitReturn(out, r);
   }
-  if (stmt->IsSwitch()) {
-    return EmitSwitch(out, stmt->AsSwitch());
+  if (auto* s = stmt->As<ast::SwitchStatement>()) {
+    return EmitSwitch(out, s);
   }
-  if (stmt->IsVariableDecl()) {
-    return EmitVariable(out, stmt->AsVariableDecl()->variable(), false);
+  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+    return EmitVariable(out, v->variable(), false);
   }
 
   error_ = "unknown statement type: " + stmt->str();
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index a793f77..be2f31b 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -19,10 +19,19 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
+#include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
+#include "src/ast/if_statement.h"
 #include "src/ast/intrinsic.h"
 #include "src/ast/literal.h"
+#include "src/ast/loop_statement.h"
 #include "src/ast/module.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/switch_statement.h"
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/context.h"
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 5fb958c..61aec3d 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -32,6 +32,7 @@
 #include "src/ast/continue_statement.h"
 #include "src/ast/decorated_variable.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/function.h"
 #include "src/ast/identifier_expression.h"
@@ -80,7 +81,8 @@
     return false;
   }
 
-  return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
+  return stmts->last()->Is<ast::BreakStatement>() ||
+         stmts->last()->Is<ast::FallthroughStatement>();
 }
 
 uint32_t adjust_for_alignment(uint32_t count, uint32_t alignment) {
@@ -1540,11 +1542,10 @@
     // the for loop into the continuing scope. Then, the variable declarations
     // will be turned into assignments.
     for (auto* s : *(stmt->body())) {
-      if (!s->IsVariableDecl()) {
-        continue;
-      }
-      if (!EmitVariable(s->AsVariableDecl()->variable(), true)) {
-        return false;
+      if (auto* decl = s->As<ast::VariableDeclStatement>()) {
+        if (!EmitVariable(decl->variable(), true)) {
+          return false;
+        }
       }
     }
   }
@@ -1569,10 +1570,11 @@
   for (auto* s : *(stmt->body())) {
     // If we have a continuing block we've already emitted the variable
     // declaration before the loop, so treat it as an assignment.
-    if (s->IsVariableDecl() && stmt->has_continuing()) {
+    auto* decl = s->As<ast::VariableDeclStatement>();
+    if (decl != nullptr && stmt->has_continuing()) {
       make_indent();
 
-      auto* var = s->AsVariableDecl()->variable();
+      auto* var = decl->variable();
       out_ << var->name() << " = ";
       if (var->constructor() != nullptr) {
         if (!EmitExpression(var->constructor())) {
@@ -1716,48 +1718,48 @@
 }
 
 bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
-  if (stmt->IsAssign()) {
-    return EmitAssign(stmt->AsAssign());
+  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+    return EmitAssign(a);
   }
-  if (stmt->IsBlock()) {
-    return EmitIndentedBlockAndNewline(stmt->AsBlock());
+  if (auto* b = stmt->As<ast::BlockStatement>()) {
+    return EmitIndentedBlockAndNewline(b);
   }
-  if (stmt->IsBreak()) {
-    return EmitBreak(stmt->AsBreak());
+  if (auto* b = stmt->As<ast::BreakStatement>()) {
+    return EmitBreak(b);
   }
-  if (stmt->IsCall()) {
+  if (auto* c = stmt->As<ast::CallStatement>()) {
     make_indent();
-    if (!EmitCall(stmt->AsCall()->expr())) {
+    if (!EmitCall(c->expr())) {
       return false;
     }
     out_ << ";" << std::endl;
     return true;
   }
-  if (stmt->IsContinue()) {
-    return EmitContinue(stmt->AsContinue());
+  if (auto* c = stmt->As<ast::ContinueStatement>()) {
+    return EmitContinue(c);
   }
-  if (stmt->IsDiscard()) {
-    return EmitDiscard(stmt->AsDiscard());
+  if (auto* d = stmt->As<ast::DiscardStatement>()) {
+    return EmitDiscard(d);
   }
-  if (stmt->IsFallthrough()) {
+  if (auto* f = stmt->As<ast::FallthroughStatement>()) {
     make_indent();
     out_ << "/* fallthrough */" << std::endl;
     return true;
   }
-  if (stmt->IsIf()) {
-    return EmitIf(stmt->AsIf());
+  if (auto* i = stmt->As<ast::IfStatement>()) {
+    return EmitIf(i);
   }
-  if (stmt->IsLoop()) {
-    return EmitLoop(stmt->AsLoop());
+  if (auto* l = stmt->As<ast::LoopStatement>()) {
+    return EmitLoop(l);
   }
-  if (stmt->IsReturn()) {
-    return EmitReturn(stmt->AsReturn());
+  if (auto* r = stmt->As<ast::ReturnStatement>()) {
+    return EmitReturn(r);
   }
-  if (stmt->IsSwitch()) {
-    return EmitSwitch(stmt->AsSwitch());
+  if (auto* s = stmt->As<ast::SwitchStatement>()) {
+    return EmitSwitch(s);
   }
-  if (stmt->IsVariableDecl()) {
-    return EmitVariable(stmt->AsVariableDecl()->variable(), false);
+  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+    return EmitVariable(v->variable(), false);
   }
 
   error_ = "unknown statement type: " + stmt->str();
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 7cbd5d6..ca95151 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -19,10 +19,20 @@
 #include <string>
 #include <unordered_map>
 
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
+#include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
+#include "src/ast/else_statement.h"
+#include "src/ast/if_statement.h"
 #include "src/ast/intrinsic.h"
 #include "src/ast/literal.h"
+#include "src/ast/loop_statement.h"
 #include "src/ast/module.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/switch_statement.h"
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/scope_stack.h"
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 5487120..0f577c1 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -35,6 +35,7 @@
 #include "src/ast/constructor_expression.h"
 #include "src/ast/decorated_variable.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
@@ -109,7 +110,7 @@
 }
 
 bool LastIsFallthrough(const ast::BlockStatement* stmts) {
-  return !stmts->empty() && stmts->last()->IsFallthrough();
+  return !stmts->empty() && stmts->last()->Is<ast::FallthroughStatement>();
 }
 
 // A terminator is anything which will case a SPIR-V terminator to be emitted.
@@ -121,8 +122,11 @@
   }
 
   auto* last = stmts->last();
-  return last->IsBreak() || last->IsContinue() || last->IsDiscard() ||
-         last->IsReturn() || last->IsFallthrough();
+  return last->Is<ast::BreakStatement>() ||
+         last->Is<ast::ContinueStatement>() ||
+         last->Is<ast::DiscardStatement>() ||
+         last->Is<ast::ReturnStatement>() ||
+         last->Is<ast::FallthroughStatement>();
 }
 
 uint32_t IndexFromName(char name) {
@@ -2359,42 +2363,42 @@
 }
 
 bool Builder::GenerateStatement(ast::Statement* stmt) {
-  if (stmt->IsAssign()) {
-    return GenerateAssignStatement(stmt->AsAssign());
+  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+    return GenerateAssignStatement(a);
   }
-  if (stmt->IsBlock()) {
-    return GenerateBlockStatement(stmt->AsBlock());
+  if (auto* b = stmt->As<ast::BlockStatement>()) {
+    return GenerateBlockStatement(b);
   }
-  if (stmt->IsBreak()) {
-    return GenerateBreakStatement(stmt->AsBreak());
+  if (auto* b = stmt->As<ast::BreakStatement>()) {
+    return GenerateBreakStatement(b);
   }
-  if (stmt->IsCall()) {
-    return GenerateCallExpression(stmt->AsCall()->expr()) != 0;
+  if (auto* c = stmt->As<ast::CallStatement>()) {
+    return GenerateCallExpression(c->expr()) != 0;
   }
-  if (stmt->IsContinue()) {
-    return GenerateContinueStatement(stmt->AsContinue());
+  if (auto* c = stmt->As<ast::ContinueStatement>()) {
+    return GenerateContinueStatement(c);
   }
-  if (stmt->IsDiscard()) {
-    return GenerateDiscardStatement(stmt->AsDiscard());
+  if (auto* d = stmt->As<ast::DiscardStatement>()) {
+    return GenerateDiscardStatement(d);
   }
-  if (stmt->IsFallthrough()) {
+  if (stmt->Is<ast::FallthroughStatement>()) {
     // Do nothing here, the fallthrough gets handled by the switch code.
     return true;
   }
-  if (stmt->IsIf()) {
-    return GenerateIfStatement(stmt->AsIf());
+  if (auto* i = stmt->As<ast::IfStatement>()) {
+    return GenerateIfStatement(i);
   }
-  if (stmt->IsLoop()) {
-    return GenerateLoopStatement(stmt->AsLoop());
+  if (auto* l = stmt->As<ast::LoopStatement>()) {
+    return GenerateLoopStatement(l);
   }
-  if (stmt->IsReturn()) {
-    return GenerateReturnStatement(stmt->AsReturn());
+  if (auto* r = stmt->As<ast::ReturnStatement>()) {
+    return GenerateReturnStatement(r);
   }
-  if (stmt->IsSwitch()) {
-    return GenerateSwitchStatement(stmt->AsSwitch());
+  if (auto* s = stmt->As<ast::SwitchStatement>()) {
+    return GenerateSwitchStatement(s);
   }
-  if (stmt->IsVariableDecl()) {
-    return GenerateVariableDeclStatement(stmt->AsVariableDecl());
+  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+    return GenerateVariableDeclStatement(v);
   }
 
   error_ = "Unknown statement: " + stmt->str();
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index f523b71..a8ed01c 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -22,11 +22,19 @@
 #include <vector>
 
 #include "spirv/unified1/spirv.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
 #include "src/ast/builtin.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/if_statement.h"
 #include "src/ast/literal.h"
+#include "src/ast/loop_statement.h"
 #include "src/ast/module.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/struct_member.h"
+#include "src/ast/switch_statement.h"
 #include "src/ast/type/access_control_type.h"
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/matrix_type.h"
@@ -35,6 +43,7 @@
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
+#include "src/ast/variable_decl_statement.h"
 #include "src/context.h"
 #include "src/scope_stack.h"
 #include "src/writer/spirv/function.h"
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index bcff24f..5fa37cb 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -796,46 +796,46 @@
 }
 
 bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
-  if (stmt->IsAssign()) {
-    return EmitAssign(stmt->AsAssign());
+  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+    return EmitAssign(a);
   }
-  if (stmt->IsBlock()) {
-    return EmitIndentedBlockAndNewline(stmt->AsBlock());
+  if (auto* b = stmt->As<ast::BlockStatement>()) {
+    return EmitIndentedBlockAndNewline(b);
   }
-  if (stmt->IsBreak()) {
-    return EmitBreak(stmt->AsBreak());
+  if (auto* b = stmt->As<ast::BreakStatement>()) {
+    return EmitBreak(b);
   }
-  if (stmt->IsCall()) {
+  if (auto* c = stmt->As<ast::CallStatement>()) {
     make_indent();
-    if (!EmitCall(stmt->AsCall()->expr())) {
+    if (!EmitCall(c->expr())) {
       return false;
     }
     out_ << ";" << std::endl;
     return true;
   }
-  if (stmt->IsContinue()) {
-    return EmitContinue(stmt->AsContinue());
+  if (auto* c = stmt->As<ast::ContinueStatement>()) {
+    return EmitContinue(c);
   }
-  if (stmt->IsDiscard()) {
-    return EmitDiscard(stmt->AsDiscard());
+  if (auto* d = stmt->As<ast::DiscardStatement>()) {
+    return EmitDiscard(d);
   }
-  if (stmt->IsFallthrough()) {
-    return EmitFallthrough(stmt->AsFallthrough());
+  if (auto* f = stmt->As<ast::FallthroughStatement>()) {
+    return EmitFallthrough(f);
   }
-  if (stmt->IsIf()) {
-    return EmitIf(stmt->AsIf());
+  if (auto* i = stmt->As<ast::IfStatement>()) {
+    return EmitIf(i);
   }
-  if (stmt->IsLoop()) {
-    return EmitLoop(stmt->AsLoop());
+  if (auto* l = stmt->As<ast::LoopStatement>()) {
+    return EmitLoop(l);
   }
-  if (stmt->IsReturn()) {
-    return EmitReturn(stmt->AsReturn());
+  if (auto* r = stmt->As<ast::ReturnStatement>()) {
+    return EmitReturn(r);
   }
-  if (stmt->IsSwitch()) {
-    return EmitSwitch(stmt->AsSwitch());
+  if (auto* s = stmt->As<ast::SwitchStatement>()) {
+    return EmitSwitch(s);
   }
-  if (stmt->IsVariableDecl()) {
-    return EmitVariable(stmt->AsVariableDecl()->variable());
+  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+    return EmitVariable(v->variable());
   }
 
   error_ = "unknown statement type: " + stmt->str();
diff --git a/src/writer/wgsl/generator_impl.h b/src/writer/wgsl/generator_impl.h
index bd2451b..77fd16d 100644
--- a/src/writer/wgsl/generator_impl.h
+++ b/src/writer/wgsl/generator_impl.h
@@ -19,10 +19,20 @@
 #include <string>
 
 #include "src/ast/array_accessor_expression.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
+#include "src/ast/case_statement.h"
 #include "src/ast/constructor_expression.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/identifier_expression.h"
+#include "src/ast/if_statement.h"
+#include "src/ast/loop_statement.h"
 #include "src/ast/module.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/switch_statement.h"
 #include "src/ast/type/storage_texture_type.h"
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type/type.h"