Replace Statement::(Is|As)* with Castable
Change-Id: I5520752a4b5844be0ecac7921616893d123b246a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34315
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/assignment_statement.cc b/src/ast/assignment_statement.cc
index cfad841..3fe9a8b 100644
--- a/src/ast/assignment_statement.cc
+++ b/src/ast/assignment_statement.cc
@@ -31,10 +31,6 @@
AssignmentStatement::~AssignmentStatement() = default;
-bool AssignmentStatement::IsAssign() const {
- return true;
-}
-
bool AssignmentStatement::IsValid() const {
if (lhs_ == nullptr || !lhs_->IsValid())
return false;
diff --git a/src/ast/assignment_statement.h b/src/ast/assignment_statement.h
index dd68ca0..08c972e 100644
--- a/src/ast/assignment_statement.h
+++ b/src/ast/assignment_statement.h
@@ -55,9 +55,6 @@
/// @returns the right side expression
Expression* rhs() const { return rhs_; }
- /// @returns true if this is an assignment statement
- bool IsAssign() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/assignment_statement_test.cc b/src/ast/assignment_statement_test.cc
index fa01c7e..ad430b7 100644
--- a/src/ast/assignment_statement_test.cc
+++ b/src/ast/assignment_statement_test.cc
@@ -47,7 +47,7 @@
auto* rhs = create<ast::IdentifierExpression>("rhs");
AssignmentStatement stmt(lhs, rhs);
- EXPECT_TRUE(stmt.IsAssign());
+ EXPECT_TRUE(stmt.Is<AssignmentStatement>());
}
TEST_F(AssignmentStatementTest, IsValid) {
diff --git a/src/ast/block_statement.cc b/src/ast/block_statement.cc
index a5aa5a9..7ab8d9b 100644
--- a/src/ast/block_statement.cc
+++ b/src/ast/block_statement.cc
@@ -25,10 +25,6 @@
BlockStatement::~BlockStatement() = default;
-bool BlockStatement::IsBlock() const {
- return true;
-}
-
bool BlockStatement::IsValid() const {
for (auto* stmt : *this) {
if (stmt == nullptr || !stmt->IsValid()) {
diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h
index 07189c7..0b8dac5 100644
--- a/src/ast/block_statement.h
+++ b/src/ast/block_statement.h
@@ -87,9 +87,6 @@
return statements_.end();
}
- /// @returns true if this is a block statement
- bool IsBlock() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/block_statement_test.cc b/src/ast/block_statement_test.cc
index 1c25885..f0941b7 100644
--- a/src/ast/block_statement_test.cc
+++ b/src/ast/block_statement_test.cc
@@ -65,7 +65,7 @@
TEST_F(BlockStatementTest, IsBlock) {
BlockStatement b;
- EXPECT_TRUE(b.IsBlock());
+ EXPECT_TRUE(b.Is<BlockStatement>());
}
TEST_F(BlockStatementTest, IsValid) {
diff --git a/src/ast/break_statement.cc b/src/ast/break_statement.cc
index ed70840..1894ba4 100644
--- a/src/ast/break_statement.cc
+++ b/src/ast/break_statement.cc
@@ -25,10 +25,6 @@
BreakStatement::~BreakStatement() = default;
-bool BreakStatement::IsBreak() const {
- return true;
-}
-
bool BreakStatement::IsValid() const {
return true;
}
diff --git a/src/ast/break_statement.h b/src/ast/break_statement.h
index 13a7d57..2a72c03 100644
--- a/src/ast/break_statement.h
+++ b/src/ast/break_statement.h
@@ -32,9 +32,6 @@
BreakStatement(BreakStatement&&);
~BreakStatement() override;
- /// @returns true if this is an break statement
- bool IsBreak() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/break_statement_test.cc b/src/ast/break_statement_test.cc
index e9e44cb..8021348 100644
--- a/src/ast/break_statement_test.cc
+++ b/src/ast/break_statement_test.cc
@@ -31,7 +31,7 @@
TEST_F(BreakStatementTest, IsBreak) {
BreakStatement stmt;
- EXPECT_TRUE(stmt.IsBreak());
+ EXPECT_TRUE(stmt.Is<BreakStatement>());
}
TEST_F(BreakStatementTest, IsValid) {
diff --git a/src/ast/call_statement.cc b/src/ast/call_statement.cc
index 0ff5188..c37c99a 100644
--- a/src/ast/call_statement.cc
+++ b/src/ast/call_statement.cc
@@ -27,10 +27,6 @@
CallStatement::~CallStatement() = default;
-bool CallStatement::IsCall() const {
- return true;
-}
-
bool CallStatement::IsValid() const {
return call_ != nullptr && call_->IsValid();
}
diff --git a/src/ast/call_statement.h b/src/ast/call_statement.h
index c74f156..1657388 100644
--- a/src/ast/call_statement.h
+++ b/src/ast/call_statement.h
@@ -42,9 +42,6 @@
/// @returns the call expression
CallExpression* expr() const { return call_; }
- /// @returns true if this is a call statement
- bool IsCall() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/call_statement_test.cc b/src/ast/call_statement_test.cc
index 363918d..528ecc1 100644
--- a/src/ast/call_statement_test.cc
+++ b/src/ast/call_statement_test.cc
@@ -34,7 +34,7 @@
TEST_F(CallStatementTest, IsCall) {
CallStatement c;
- EXPECT_TRUE(c.IsCall());
+ EXPECT_TRUE(c.Is<CallStatement>());
}
TEST_F(CallStatementTest, IsValid) {
diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc
index 84cda94..2d50185 100644
--- a/src/ast/case_statement.cc
+++ b/src/ast/case_statement.cc
@@ -31,10 +31,6 @@
CaseStatement::~CaseStatement() = default;
-bool CaseStatement::IsCase() const {
- return true;
-}
-
bool CaseStatement::IsValid() const {
return body_ != nullptr && body_->IsValid();
}
diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h
index aacd2f0..ef39af0 100644
--- a/src/ast/case_statement.h
+++ b/src/ast/case_statement.h
@@ -70,9 +70,6 @@
/// @returns the case body
BlockStatement* body() { return body_; }
- /// @returns true if this is a case statement
- bool IsCase() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/case_statement_test.cc b/src/ast/case_statement_test.cc
index 30c298e..d5c77b6 100644
--- a/src/ast/case_statement_test.cc
+++ b/src/ast/case_statement_test.cc
@@ -99,7 +99,7 @@
TEST_F(CaseStatementTest, IsCase) {
CaseStatement c(create<ast::BlockStatement>());
- EXPECT_TRUE(c.IsCase());
+ EXPECT_TRUE(c.Is<ast::CaseStatement>());
}
TEST_F(CaseStatementTest, IsValid) {
diff --git a/src/ast/continue_statement.cc b/src/ast/continue_statement.cc
index f66a958..1f9da2a 100644
--- a/src/ast/continue_statement.cc
+++ b/src/ast/continue_statement.cc
@@ -25,10 +25,6 @@
ContinueStatement::~ContinueStatement() = default;
-bool ContinueStatement::IsContinue() const {
- return true;
-}
-
bool ContinueStatement::IsValid() const {
return true;
}
diff --git a/src/ast/continue_statement.h b/src/ast/continue_statement.h
index 1eb93f9..b2a01f9 100644
--- a/src/ast/continue_statement.h
+++ b/src/ast/continue_statement.h
@@ -35,9 +35,6 @@
ContinueStatement(ContinueStatement&&);
~ContinueStatement() override;
- /// @returns true if this is an continue statement
- bool IsContinue() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/continue_statement_test.cc b/src/ast/continue_statement_test.cc
index 5b1c58e..577d50b 100644
--- a/src/ast/continue_statement_test.cc
+++ b/src/ast/continue_statement_test.cc
@@ -31,7 +31,7 @@
TEST_F(ContinueStatementTest, IsContinue) {
ContinueStatement stmt;
- EXPECT_TRUE(stmt.IsContinue());
+ EXPECT_TRUE(stmt.Is<ContinueStatement>());
}
TEST_F(ContinueStatementTest, IsValid) {
diff --git a/src/ast/discard_statement.cc b/src/ast/discard_statement.cc
index e2c17fa..b70856b 100644
--- a/src/ast/discard_statement.cc
+++ b/src/ast/discard_statement.cc
@@ -25,10 +25,6 @@
DiscardStatement::~DiscardStatement() = default;
-bool DiscardStatement::IsDiscard() const {
- return true;
-}
-
bool DiscardStatement::IsValid() const {
return true;
}
diff --git a/src/ast/discard_statement.h b/src/ast/discard_statement.h
index 6f2e42a..ba7b398 100644
--- a/src/ast/discard_statement.h
+++ b/src/ast/discard_statement.h
@@ -32,9 +32,6 @@
DiscardStatement(DiscardStatement&&);
~DiscardStatement() override;
- /// @returns true if this is a discard statement
- bool IsDiscard() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/discard_statement_test.cc b/src/ast/discard_statement_test.cc
index a60a9a4..fdf444f 100644
--- a/src/ast/discard_statement_test.cc
+++ b/src/ast/discard_statement_test.cc
@@ -43,7 +43,7 @@
TEST_F(DiscardStatementTest, IsDiscard) {
DiscardStatement stmt;
- EXPECT_TRUE(stmt.IsDiscard());
+ EXPECT_TRUE(stmt.Is<DiscardStatement>());
}
TEST_F(DiscardStatementTest, IsValid) {
diff --git a/src/ast/else_statement.cc b/src/ast/else_statement.cc
index 00279c3..937755b 100644
--- a/src/ast/else_statement.cc
+++ b/src/ast/else_statement.cc
@@ -34,10 +34,6 @@
ElseStatement::~ElseStatement() = default;
-bool ElseStatement::IsElse() const {
- return true;
-}
-
bool ElseStatement::IsValid() const {
if (body_ == nullptr || !body_->IsValid()) {
return false;
diff --git a/src/ast/else_statement.h b/src/ast/else_statement.h
index b75904d..60a675a 100644
--- a/src/ast/else_statement.h
+++ b/src/ast/else_statement.h
@@ -67,9 +67,6 @@
/// @returns the else body
BlockStatement* body() { return body_; }
- /// @returns true if this is a else statement
- bool IsElse() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/else_statement_test.cc b/src/ast/else_statement_test.cc
index 948a94b..f90d7b5 100644
--- a/src/ast/else_statement_test.cc
+++ b/src/ast/else_statement_test.cc
@@ -51,7 +51,7 @@
TEST_F(ElseStatementTest, IsElse) {
ElseStatement e(create<BlockStatement>());
- EXPECT_TRUE(e.IsElse());
+ EXPECT_TRUE(e.Is<ElseStatement>());
}
TEST_F(ElseStatementTest, HasCondition) {
diff --git a/src/ast/fallthrough_statement.cc b/src/ast/fallthrough_statement.cc
index 6c84ab3..cd5f00d 100644
--- a/src/ast/fallthrough_statement.cc
+++ b/src/ast/fallthrough_statement.cc
@@ -26,10 +26,6 @@
FallthroughStatement::~FallthroughStatement() = default;
-bool FallthroughStatement::IsFallthrough() const {
- return true;
-}
-
bool FallthroughStatement::IsValid() const {
return true;
}
diff --git a/src/ast/fallthrough_statement.h b/src/ast/fallthrough_statement.h
index f586666..5b0bc81 100644
--- a/src/ast/fallthrough_statement.h
+++ b/src/ast/fallthrough_statement.h
@@ -32,9 +32,6 @@
FallthroughStatement(FallthroughStatement&&);
~FallthroughStatement() override;
- /// @returns true if this is an fallthrough statement
- bool IsFallthrough() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/fallthrough_statement_test.cc b/src/ast/fallthrough_statement_test.cc
index c5a936d..a41f5be 100644
--- a/src/ast/fallthrough_statement_test.cc
+++ b/src/ast/fallthrough_statement_test.cc
@@ -39,7 +39,7 @@
TEST_F(FallthroughStatementTest, IsFallthrough) {
FallthroughStatement stmt;
- EXPECT_TRUE(stmt.IsFallthrough());
+ EXPECT_TRUE(stmt.Is<FallthroughStatement>());
}
TEST_F(FallthroughStatementTest, IsValid) {
diff --git a/src/ast/if_statement.cc b/src/ast/if_statement.cc
index bcf156c..c8406b7 100644
--- a/src/ast/if_statement.cc
+++ b/src/ast/if_statement.cc
@@ -31,10 +31,6 @@
IfStatement::~IfStatement() = default;
-bool IfStatement::IsIf() const {
- return true;
-}
-
bool IfStatement::IsValid() const {
if (condition_ == nullptr || !condition_->IsValid()) {
return false;
diff --git a/src/ast/if_statement.h b/src/ast/if_statement.h
index 4b8a0fb..c6bb79f 100644
--- a/src/ast/if_statement.h
+++ b/src/ast/if_statement.h
@@ -71,9 +71,6 @@
/// @returns true if there are else statements
bool has_else_statements() const { return !else_statements_.empty(); }
- /// @returns true if this is a if statement
- bool IsIf() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/if_statement_test.cc b/src/ast/if_statement_test.cc
index 5c6d04d..45e99ba 100644
--- a/src/ast/if_statement_test.cc
+++ b/src/ast/if_statement_test.cc
@@ -49,7 +49,7 @@
TEST_F(IfStatementTest, IsIf) {
IfStatement stmt(nullptr, create<BlockStatement>());
- EXPECT_TRUE(stmt.IsIf());
+ EXPECT_TRUE(stmt.Is<IfStatement>());
}
TEST_F(IfStatementTest, IsValid) {
diff --git a/src/ast/loop_statement.cc b/src/ast/loop_statement.cc
index fbffebe..220201f 100644
--- a/src/ast/loop_statement.cc
+++ b/src/ast/loop_statement.cc
@@ -29,10 +29,6 @@
LoopStatement::~LoopStatement() = default;
-bool LoopStatement::IsLoop() const {
- return true;
-}
-
bool LoopStatement::IsValid() const {
if (body_ == nullptr || !body_->IsValid()) {
return false;
diff --git a/src/ast/loop_statement.h b/src/ast/loop_statement.h
index f1ed192..a107e08 100644
--- a/src/ast/loop_statement.h
+++ b/src/ast/loop_statement.h
@@ -62,9 +62,6 @@
return continuing_ != nullptr && !continuing_->empty();
}
- /// @returns true if this is a loop statement
- bool IsLoop() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/loop_statement_test.cc b/src/ast/loop_statement_test.cc
index 6e242e3..12001de 100644
--- a/src/ast/loop_statement_test.cc
+++ b/src/ast/loop_statement_test.cc
@@ -57,7 +57,7 @@
TEST_F(LoopStatementTest, IsLoop) {
LoopStatement l(create<BlockStatement>(), create<BlockStatement>());
- EXPECT_TRUE(l.IsLoop());
+ EXPECT_TRUE(l.Is<LoopStatement>());
}
TEST_F(LoopStatementTest, HasContinuing_WithoutContinuing) {
diff --git a/src/ast/return_statement.cc b/src/ast/return_statement.cc
index e395dfc..138618f 100644
--- a/src/ast/return_statement.cc
+++ b/src/ast/return_statement.cc
@@ -30,10 +30,6 @@
ReturnStatement::~ReturnStatement() = default;
-bool ReturnStatement::IsReturn() const {
- return true;
-}
-
bool ReturnStatement::IsValid() const {
if (value_ != nullptr) {
return value_->IsValid();
diff --git a/src/ast/return_statement.h b/src/ast/return_statement.h
index a1177f7..f2426fa 100644
--- a/src/ast/return_statement.h
+++ b/src/ast/return_statement.h
@@ -51,9 +51,6 @@
/// @returns true if the return has a value
bool has_value() const { return value_ != nullptr; }
- /// @returns true if this is a return statement
- bool IsReturn() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/return_statement_test.cc b/src/ast/return_statement_test.cc
index 42261c4..384b49e 100644
--- a/src/ast/return_statement_test.cc
+++ b/src/ast/return_statement_test.cc
@@ -41,7 +41,7 @@
TEST_F(ReturnStatementTest, IsReturn) {
ReturnStatement r;
- EXPECT_TRUE(r.IsReturn());
+ EXPECT_TRUE(r.Is<ReturnStatement>());
}
TEST_F(ReturnStatementTest, HasValue_WithoutValue) {
diff --git a/src/ast/statement.cc b/src/ast/statement.cc
index f44286f..aa45876 100644
--- a/src/ast/statement.cc
+++ b/src/ast/statement.cc
@@ -42,247 +42,51 @@
Statement::~Statement() = default;
-bool Statement::IsAssign() const {
- return false;
-}
-
-bool Statement::IsBlock() const {
- return false;
-}
-
-bool Statement::IsBreak() const {
- return false;
-}
-
-bool Statement::IsCase() const {
- return false;
-}
-
-bool Statement::IsCall() const {
- return false;
-}
-
-bool Statement::IsContinue() const {
- return false;
-}
-
-bool Statement::IsDiscard() const {
- return false;
-}
-
-bool Statement::IsElse() const {
- return false;
-}
-
-bool Statement::IsFallthrough() const {
- return false;
-}
-
-bool Statement::IsIf() const {
- return false;
-}
-
-bool Statement::IsLoop() const {
- return false;
-}
-
-bool Statement::IsReturn() const {
- return false;
-}
-
-bool Statement::IsSwitch() const {
- return false;
-}
-
-bool Statement::IsVariableDecl() const {
- return false;
-}
-
const char* Statement::Name() const {
- if (IsAssign()) {
+ if (Is<AssignmentStatement>()) {
return "assignment statement";
}
- if (IsBlock()) {
+ if (Is<BlockStatement>()) {
return "block statement";
}
- if (IsBreak()) {
+ if (Is<BreakStatement>()) {
return "break statement";
}
- if (IsCase()) {
+ if (Is<CaseStatement>()) {
return "case statement";
}
- if (IsCall()) {
+ if (Is<CallStatement>()) {
return "function call";
}
- if (IsContinue()) {
+ if (Is<ContinueStatement>()) {
return "continue statement";
}
- if (IsDiscard()) {
+ if (Is<DiscardStatement>()) {
return "discard statement";
}
- if (IsElse()) {
+ if (Is<ElseStatement>()) {
return "else statement";
}
- if (IsFallthrough()) {
+ if (Is<FallthroughStatement>()) {
return "fallthrough statement";
}
- if (IsIf()) {
+ if (Is<IfStatement>()) {
return "if statement";
}
- if (IsLoop()) {
+ if (Is<LoopStatement>()) {
return "loop statement";
}
- if (IsReturn()) {
+ if (Is<ReturnStatement>()) {
return "return statement";
}
- if (IsSwitch()) {
+ if (Is<SwitchStatement>()) {
return "switch statement";
}
- if (IsVariableDecl()) {
+ if (Is<VariableDeclStatement>()) {
return "variable declaration";
}
return "statement";
}
-const AssignmentStatement* Statement::AsAssign() const {
- assert(IsAssign());
- return static_cast<const AssignmentStatement*>(this);
-}
-
-const BlockStatement* Statement::AsBlock() const {
- assert(IsBlock());
- return static_cast<const BlockStatement*>(this);
-}
-
-const BreakStatement* Statement::AsBreak() const {
- assert(IsBreak());
- return static_cast<const BreakStatement*>(this);
-}
-
-const CallStatement* Statement::AsCall() const {
- assert(IsCall());
- return static_cast<const CallStatement*>(this);
-}
-
-const CaseStatement* Statement::AsCase() const {
- assert(IsCase());
- return static_cast<const CaseStatement*>(this);
-}
-
-const ContinueStatement* Statement::AsContinue() const {
- assert(IsContinue());
- return static_cast<const ContinueStatement*>(this);
-}
-
-const DiscardStatement* Statement::AsDiscard() const {
- assert(IsDiscard());
- return static_cast<const DiscardStatement*>(this);
-}
-
-const ElseStatement* Statement::AsElse() const {
- assert(IsElse());
- return static_cast<const ElseStatement*>(this);
-}
-
-const FallthroughStatement* Statement::AsFallthrough() const {
- assert(IsFallthrough());
- return static_cast<const FallthroughStatement*>(this);
-}
-
-const IfStatement* Statement::AsIf() const {
- assert(IsIf());
- return static_cast<const IfStatement*>(this);
-}
-
-const LoopStatement* Statement::AsLoop() const {
- assert(IsLoop());
- return static_cast<const LoopStatement*>(this);
-}
-
-const ReturnStatement* Statement::AsReturn() const {
- assert(IsReturn());
- return static_cast<const ReturnStatement*>(this);
-}
-
-const SwitchStatement* Statement::AsSwitch() const {
- assert(IsSwitch());
- return static_cast<const SwitchStatement*>(this);
-}
-
-const VariableDeclStatement* Statement::AsVariableDecl() const {
- assert(IsVariableDecl());
- return static_cast<const VariableDeclStatement*>(this);
-}
-
-AssignmentStatement* Statement::AsAssign() {
- assert(IsAssign());
- return static_cast<AssignmentStatement*>(this);
-}
-
-BlockStatement* Statement::AsBlock() {
- assert(IsBlock());
- return static_cast<BlockStatement*>(this);
-}
-
-BreakStatement* Statement::AsBreak() {
- assert(IsBreak());
- return static_cast<BreakStatement*>(this);
-}
-
-CallStatement* Statement::AsCall() {
- assert(IsCall());
- return static_cast<CallStatement*>(this);
-}
-
-CaseStatement* Statement::AsCase() {
- assert(IsCase());
- return static_cast<CaseStatement*>(this);
-}
-
-ContinueStatement* Statement::AsContinue() {
- assert(IsContinue());
- return static_cast<ContinueStatement*>(this);
-}
-
-DiscardStatement* Statement::AsDiscard() {
- assert(IsDiscard());
- return static_cast<DiscardStatement*>(this);
-}
-
-ElseStatement* Statement::AsElse() {
- assert(IsElse());
- return static_cast<ElseStatement*>(this);
-}
-
-FallthroughStatement* Statement::AsFallthrough() {
- assert(IsFallthrough());
- return static_cast<FallthroughStatement*>(this);
-}
-
-IfStatement* Statement::AsIf() {
- assert(IsIf());
- return static_cast<IfStatement*>(this);
-}
-
-LoopStatement* Statement::AsLoop() {
- assert(IsLoop());
- return static_cast<LoopStatement*>(this);
-}
-
-ReturnStatement* Statement::AsReturn() {
- assert(IsReturn());
- return static_cast<ReturnStatement*>(this);
-}
-
-SwitchStatement* Statement::AsSwitch() {
- assert(IsSwitch());
- return static_cast<SwitchStatement*>(this);
-}
-
-VariableDeclStatement* Statement::AsVariableDecl() {
- assert(IsVariableDecl());
- return static_cast<VariableDeclStatement*>(this);
-}
-
} // namespace ast
} // namespace tint
diff --git a/src/ast/statement.h b/src/ast/statement.h
index 7e791ba..b38c74c 100644
--- a/src/ast/statement.h
+++ b/src/ast/statement.h
@@ -23,116 +23,14 @@
namespace tint {
namespace ast {
-class AssignmentStatement;
-class BlockStatement;
-class BreakStatement;
-class CallStatement;
-class CaseStatement;
-class ContinueStatement;
-class DiscardStatement;
-class ElseStatement;
-class FallthroughStatement;
-class IfStatement;
-class LoopStatement;
-class ReturnStatement;
-class SwitchStatement;
-class VariableDeclStatement;
-
/// Base statement class
class Statement : public Castable<Statement, Node> {
public:
~Statement() override;
- /// @returns true if this is an assign statement
- virtual bool IsAssign() const;
- /// @returns true if this is a block statement
- virtual bool IsBlock() const;
- /// @returns true if this is a break statement
- virtual bool IsBreak() const;
- /// @returns true if this is a call statement
- virtual bool IsCall() const;
- /// @returns true if this is a case statement
- virtual bool IsCase() const;
- /// @returns true if this is a continue statement
- virtual bool IsContinue() const;
- /// @returns true if this is a discard statement
- virtual bool IsDiscard() const;
- /// @returns true if this is an else statement
- virtual bool IsElse() const;
- /// @returns true if this is a fallthrough statement
- virtual bool IsFallthrough() const;
- /// @returns true if this is an if statement
- virtual bool IsIf() const;
- /// @returns true if this is a loop statement
- virtual bool IsLoop() const;
- /// @returns true if this is a return statement
- virtual bool IsReturn() const;
- /// @returns true if this is a switch statement
- virtual bool IsSwitch() const;
- /// @returns true if this is an variable statement
- virtual bool IsVariableDecl() const;
-
/// @returns the human readable name for the statement type.
const char* Name() const;
- /// @returns the statement as a const assign statement
- const AssignmentStatement* AsAssign() const;
- /// @returns the statement as a const block statement
- const BlockStatement* AsBlock() const;
- /// @returns the statement as a const break statement
- const BreakStatement* AsBreak() const;
- /// @returns the statement as a const call statement
- const CallStatement* AsCall() const;
- /// @returns the statement as a const case statement
- const CaseStatement* AsCase() const;
- /// @returns the statement as a const continue statement
- const ContinueStatement* AsContinue() const;
- /// @returns the statement as a const discard statement
- const DiscardStatement* AsDiscard() const;
- /// @returns the statement as a const else statement
- const ElseStatement* AsElse() const;
- /// @returns the statement as a const fallthrough statement
- const FallthroughStatement* AsFallthrough() const;
- /// @returns the statement as a const if statement
- const IfStatement* AsIf() const;
- /// @returns the statement as a const loop statement
- const LoopStatement* AsLoop() const;
- /// @returns the statement as a const return statement
- const ReturnStatement* AsReturn() const;
- /// @returns the statement as a const switch statement
- const SwitchStatement* AsSwitch() const;
- /// @returns the statement as a const variable statement
- const VariableDeclStatement* AsVariableDecl() const;
-
- /// @returns the statement as an assign statement
- AssignmentStatement* AsAssign();
- /// @returns the statement as a block statement
- BlockStatement* AsBlock();
- /// @returns the statement as a break statement
- BreakStatement* AsBreak();
- /// @returns the statement as a call statement
- CallStatement* AsCall();
- /// @returns the statement as a case statement
- CaseStatement* AsCase();
- /// @returns the statement as a continue statement
- ContinueStatement* AsContinue();
- /// @returns the statement as a discard statement
- DiscardStatement* AsDiscard();
- /// @returns the statement as a else statement
- ElseStatement* AsElse();
- /// @returns the statement as a fallthrough statement
- FallthroughStatement* AsFallthrough();
- /// @returns the statement as a if statement
- IfStatement* AsIf();
- /// @returns the statement as a loop statement
- LoopStatement* AsLoop();
- /// @returns the statement as a return statement
- ReturnStatement* AsReturn();
- /// @returns the statement as a switch statement
- SwitchStatement* AsSwitch();
- /// @returns the statement as an variable statement
- VariableDeclStatement* AsVariableDecl();
-
protected:
/// Constructor
Statement();
diff --git a/src/ast/switch_statement.cc b/src/ast/switch_statement.cc
index 65bb68f..ed0d730 100644
--- a/src/ast/switch_statement.cc
+++ b/src/ast/switch_statement.cc
@@ -29,10 +29,6 @@
CaseStatementList body)
: Base(source), condition_(condition), body_(body) {}
-bool SwitchStatement::IsSwitch() const {
- return true;
-}
-
SwitchStatement::SwitchStatement(SwitchStatement&&) = default;
SwitchStatement::~SwitchStatement() = default;
diff --git a/src/ast/switch_statement.h b/src/ast/switch_statement.h
index 656918d..df7d570 100644
--- a/src/ast/switch_statement.h
+++ b/src/ast/switch_statement.h
@@ -60,9 +60,6 @@
/// @returns the Switch body
const CaseStatementList& body() const { return body_; }
- /// @returns true if this is a switch statement
- bool IsSwitch() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/switch_statement_test.cc b/src/ast/switch_statement_test.cc
index 49427cf..eadc87c 100644
--- a/src/ast/switch_statement_test.cc
+++ b/src/ast/switch_statement_test.cc
@@ -57,7 +57,7 @@
TEST_F(SwitchStatementTest, IsSwitch) {
SwitchStatement stmt;
- EXPECT_TRUE(stmt.IsSwitch());
+ EXPECT_TRUE(stmt.Is<SwitchStatement>());
}
TEST_F(SwitchStatementTest, IsValid) {
diff --git a/src/ast/variable_decl_statement.cc b/src/ast/variable_decl_statement.cc
index 732c166..9b51bf4 100644
--- a/src/ast/variable_decl_statement.cc
+++ b/src/ast/variable_decl_statement.cc
@@ -30,10 +30,6 @@
VariableDeclStatement::~VariableDeclStatement() = default;
-bool VariableDeclStatement::IsVariableDecl() const {
- return true;
-}
-
bool VariableDeclStatement::IsValid() const {
return variable_ != nullptr && variable_->IsValid();
}
diff --git a/src/ast/variable_decl_statement.h b/src/ast/variable_decl_statement.h
index a233029..d03d5ad 100644
--- a/src/ast/variable_decl_statement.h
+++ b/src/ast/variable_decl_statement.h
@@ -48,9 +48,6 @@
/// @returns the variable
Variable* variable() const { return variable_; }
- /// @returns true if this is an variable statement
- bool IsVariableDecl() const override;
-
/// @returns true if the node is valid
bool IsValid() const override;
diff --git a/src/ast/variable_decl_statement_test.cc b/src/ast/variable_decl_statement_test.cc
index 8e96e12..9b5f907 100644
--- a/src/ast/variable_decl_statement_test.cc
+++ b/src/ast/variable_decl_statement_test.cc
@@ -44,7 +44,7 @@
TEST_F(VariableDeclStatementTest, IsVariableDecl) {
VariableDeclStatement s;
- EXPECT_TRUE(s.IsVariableDecl());
+ EXPECT_TRUE(s.Is<VariableDeclStatement>());
}
TEST_F(VariableDeclStatementTest, IsValid) {
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index bc2d5c7..133fac0 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -561,8 +561,8 @@
const auto& top = statements_stack_.back();
auto* cond = create<ast::IdentifierExpression>(guard_name);
auto* body = create<ast::BlockStatement>();
- auto* const guard_stmt =
- AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
+ auto* const guard_stmt = AddStatement(create<ast::IfStatement>(cond, body))
+ ->As<ast::IfStatement>();
PushNewStatementBlock(top.construct_, end_id,
[guard_stmt](StatementBlock* s) {
guard_stmt->set_body(s->statements_);
@@ -574,8 +574,8 @@
const auto& top = statements_stack_.back();
auto* cond = MakeTrue();
auto* body = create<ast::BlockStatement>();
- auto* const guard_stmt =
- AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
+ auto* const guard_stmt = AddStatement(create<ast::IfStatement>(cond, body))
+ ->As<ast::IfStatement>();
guard_stmt->set_condition(MakeTrue());
PushNewStatementBlock(top.construct_, end_id,
[guard_stmt](StatementBlock* s) {
@@ -2023,8 +2023,8 @@
block_info.basic_block->terminator()->GetSingleWordInOperand(0);
auto* cond = MakeExpression(condition_id).expr;
auto* body = create<ast::BlockStatement>();
- auto* const if_stmt =
- AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
+ auto* const if_stmt = AddStatement(create<ast::IfStatement>(cond, body))
+ ->As<ast::IfStatement>();
// Generate the code for the condition.
@@ -2137,7 +2137,7 @@
const auto* branch = block_info.basic_block->terminator();
auto* const switch_stmt =
- AddStatement(create<ast::SwitchStatement>())->AsSwitch();
+ AddStatement(create<ast::SwitchStatement>())->As<ast::SwitchStatement>();
const auto selector_id = branch->GetSingleWordInOperand(0);
// Generate the code for the selector.
auto selector = MakeExpression(selector_id);
@@ -2255,7 +2255,7 @@
auto* loop =
AddStatement(create<ast::LoopStatement>(create<ast::BlockStatement>(),
create<ast::BlockStatement>()))
- ->AsLoop();
+ ->As<ast::LoopStatement>();
PushNewStatementBlock(
construct, construct->end_id,
[loop](StatementBlock* s) { loop->set_body(s->statements_); });
@@ -2266,11 +2266,11 @@
// A continue construct has the same depth as its associated loop
// construct. Start a continue construct.
auto* loop_candidate = LastStatement();
- if (!loop_candidate->IsLoop()) {
+ if (!loop_candidate->Is<ast::LoopStatement>()) {
return Fail() << "internal error: starting continue construct, "
"expected loop on top of stack";
}
- auto* loop = loop_candidate->AsLoop();
+ auto* loop = loop_candidate->As<ast::LoopStatement>();
PushNewStatementBlock(
construct, construct->end_id,
[loop](StatementBlock* s) { loop->set_continuing(s->statements_); });
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index 1e32028..883a713 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -27,16 +27,21 @@
#include "src/ast/access_control.h"
#include "src/ast/array_decoration.h"
#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
#include "src/ast/builtin.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/constructor_expression.h"
+#include "src/ast/continue_statement.h"
#include "src/ast/else_statement.h"
+#include "src/ast/switch_statement.h"
#include "src/ast/function.h"
+#include "src/ast/if_statement.h"
#include "src/ast/literal.h"
#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
#include "src/ast/pipeline_stage.h"
+#include "src/ast/return_statement.h"
#include "src/ast/statement.h"
#include "src/ast/storage_class.h"
#include "src/ast/struct.h"
@@ -44,10 +49,11 @@
#include "src/ast/struct_member.h"
#include "src/ast/struct_member_decoration.h"
#include "src/ast/type/storage_texture_type.h"
+#include "src/ast/type/struct_type.h"
#include "src/ast/type/texture_type.h"
#include "src/ast/type/type.h"
-#include "src/ast/type/struct_type.h"
#include "src/ast/variable.h"
+#include "src/ast/variable_decl_statement.h"
#include "src/ast/variable_decoration.h"
#include "src/context.h"
#include "src/diagnostic/diagnostic.h"
diff --git a/src/reader/wgsl/parser_impl_assignment_stmt_test.cc b/src/reader/wgsl/parser_impl_assignment_stmt_test.cc
index b641ced..dfcadfe 100644
--- a/src/reader/wgsl/parser_impl_assignment_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_assignment_stmt_test.cc
@@ -36,7 +36,7 @@
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsAssign());
+ ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
ASSERT_NE(e->lhs(), nullptr);
ASSERT_NE(e->rhs(), nullptr);
@@ -61,7 +61,7 @@
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsAssign());
+ ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
ASSERT_NE(e->lhs(), nullptr);
ASSERT_NE(e->rhs(), nullptr);
diff --git a/src/reader/wgsl/parser_impl_body_stmt_test.cc b/src/reader/wgsl/parser_impl_body_stmt_test.cc
index e4e21ee..b5413ef 100644
--- a/src/reader/wgsl/parser_impl_body_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_body_stmt_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -30,8 +31,8 @@
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
ASSERT_EQ(e->size(), 2u);
- EXPECT_TRUE(e->get(0)->IsDiscard());
- EXPECT_TRUE(e->get(1)->IsReturn());
+ EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
+ EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, BodyStmt_Empty) {
diff --git a/src/reader/wgsl/parser_impl_break_stmt_test.cc b/src/reader/wgsl/parser_impl_break_stmt_test.cc
index fdd7438..1c7b0d7 100644
--- a/src/reader/wgsl/parser_impl_break_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_break_stmt_test.cc
@@ -28,7 +28,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsBreak());
+ ASSERT_TRUE(e->Is<ast::BreakStatement>());
}
} // namespace
diff --git a/src/reader/wgsl/parser_impl_call_stmt_test.cc b/src/reader/wgsl/parser_impl_call_stmt_test.cc
index cfacc7b..a0ce2ed 100644
--- a/src/reader/wgsl/parser_impl_call_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_call_stmt_test.cc
@@ -32,8 +32,8 @@
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsCall());
- auto* c = e->AsCall()->expr();
+ ASSERT_TRUE(e->Is<ast::CallStatement>());
+ auto* c = e->As<ast::CallStatement>()->expr();
ASSERT_TRUE(c->func()->IsIdentifier());
auto* func = c->func()->AsIdentifier();
@@ -50,8 +50,8 @@
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsCall());
- auto* c = e->AsCall()->expr();
+ ASSERT_TRUE(e->Is<ast::CallStatement>());
+ auto* c = e->As<ast::CallStatement>()->expr();
ASSERT_TRUE(c->func()->IsIdentifier());
auto* func = c->func()->AsIdentifier();
diff --git a/src/reader/wgsl/parser_impl_case_body_test.cc b/src/reader/wgsl/parser_impl_case_body_test.cc
index 4abb8a9..f3bc05c 100644
--- a/src/reader/wgsl/parser_impl_case_body_test.cc
+++ b/src/reader/wgsl/parser_impl_case_body_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -40,8 +41,8 @@
EXPECT_FALSE(e.errored);
EXPECT_TRUE(e.matched);
ASSERT_EQ(e->size(), 2u);
- EXPECT_TRUE(e->get(0)->IsVariableDecl());
- EXPECT_TRUE(e->get(1)->IsAssign());
+ EXPECT_TRUE(e->get(0)->Is<ast::VariableDeclStatement>());
+ EXPECT_TRUE(e->get(1)->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, CaseBody_InvalidStatement) {
@@ -60,7 +61,7 @@
EXPECT_FALSE(e.errored);
EXPECT_TRUE(e.matched);
ASSERT_EQ(e->size(), 1u);
- EXPECT_TRUE(e->get(0)->IsFallthrough());
+ EXPECT_TRUE(e->get(0)->Is<ast::FallthroughStatement>());
}
TEST_F(ParserImplTest, CaseBody_Fallthrough_MissingSemicolon) {
diff --git a/src/reader/wgsl/parser_impl_continue_stmt_test.cc b/src/reader/wgsl/parser_impl_continue_stmt_test.cc
index e2e6e01..e9b72a1 100644
--- a/src/reader/wgsl/parser_impl_continue_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_continue_stmt_test.cc
@@ -28,7 +28,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsContinue());
+ ASSERT_TRUE(e->Is<ast::ContinueStatement>());
}
} // namespace
diff --git a/src/reader/wgsl/parser_impl_continuing_stmt_test.cc b/src/reader/wgsl/parser_impl_continuing_stmt_test.cc
index 4e4a7d1..983d98f 100644
--- a/src/reader/wgsl/parser_impl_continuing_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_continuing_stmt_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -28,7 +29,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e->size(), 1u);
- ASSERT_TRUE(e->get(0)->IsDiscard());
+ ASSERT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, ContinuingStmt_InvalidBody) {
diff --git a/src/reader/wgsl/parser_impl_else_stmt_test.cc b/src/reader/wgsl/parser_impl_else_stmt_test.cc
index 98ad2b0..a0e86a0 100644
--- a/src/reader/wgsl/parser_impl_else_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_else_stmt_test.cc
@@ -29,7 +29,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsElse());
+ ASSERT_TRUE(e->Is<ast::ElseStatement>());
ASSERT_EQ(e->condition(), nullptr);
EXPECT_EQ(e->body()->size(), 2u);
}
diff --git a/src/reader/wgsl/parser_impl_elseif_stmt_test.cc b/src/reader/wgsl/parser_impl_elseif_stmt_test.cc
index 22ebd97..1c27a9b 100644
--- a/src/reader/wgsl/parser_impl_elseif_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_elseif_stmt_test.cc
@@ -30,7 +30,7 @@
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e.value.size(), 1u);
- ASSERT_TRUE(e.value[0]->IsElse());
+ ASSERT_TRUE(e.value[0]->Is<ast::ElseStatement>());
ASSERT_NE(e.value[0]->condition(), nullptr);
ASSERT_TRUE(e.value[0]->condition()->IsBinary());
EXPECT_EQ(e.value[0]->body()->size(), 2u);
@@ -44,12 +44,12 @@
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e.value.size(), 2u);
- ASSERT_TRUE(e.value[0]->IsElse());
+ ASSERT_TRUE(e.value[0]->Is<ast::ElseStatement>());
ASSERT_NE(e.value[0]->condition(), nullptr);
ASSERT_TRUE(e.value[0]->condition()->IsBinary());
EXPECT_EQ(e.value[0]->body()->size(), 2u);
- ASSERT_TRUE(e.value[1]->IsElse());
+ ASSERT_TRUE(e.value[1]->Is<ast::ElseStatement>());
ASSERT_NE(e.value[1]->condition(), nullptr);
ASSERT_TRUE(e.value[1]->condition()->IsIdentifier());
EXPECT_EQ(e.value[1]->body()->size(), 1u);
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc
index ca155f2..7ba09b3 100644
--- a/src/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -50,7 +50,7 @@
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
- EXPECT_TRUE(body->get(0)->IsReturn());
+ EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
@@ -86,7 +86,7 @@
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
- EXPECT_TRUE(body->get(0)->IsReturn());
+ EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleEntries) {
@@ -130,7 +130,7 @@
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
- EXPECT_TRUE(body->get(0)->IsReturn());
+ EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleLists) {
@@ -175,7 +175,7 @@
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
- EXPECT_TRUE(body->get(0)->IsReturn());
+ EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
diff --git a/src/reader/wgsl/parser_impl_if_stmt_test.cc b/src/reader/wgsl/parser_impl_if_stmt_test.cc
index 4fd989e..fc1ac37 100644
--- a/src/reader/wgsl/parser_impl_if_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_if_stmt_test.cc
@@ -31,7 +31,7 @@
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsIf());
+ ASSERT_TRUE(e->Is<ast::IfStatement>());
ASSERT_NE(e->condition(), nullptr);
ASSERT_TRUE(e->condition()->IsBinary());
EXPECT_EQ(e->body()->size(), 2u);
@@ -46,7 +46,7 @@
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsIf());
+ ASSERT_TRUE(e->Is<ast::IfStatement>());
ASSERT_NE(e->condition(), nullptr);
ASSERT_TRUE(e->condition()->IsBinary());
EXPECT_EQ(e->body()->size(), 2u);
diff --git a/src/reader/wgsl/parser_impl_loop_stmt_test.cc b/src/reader/wgsl/parser_impl_loop_stmt_test.cc
index 043afca..2a0ba20 100644
--- a/src/reader/wgsl/parser_impl_loop_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_loop_stmt_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -30,7 +31,7 @@
ASSERT_NE(e.value, nullptr);
ASSERT_EQ(e->body()->size(), 1u);
- EXPECT_TRUE(e->body()->get(0)->IsDiscard());
+ EXPECT_TRUE(e->body()->get(0)->Is<ast::DiscardStatement>());
EXPECT_EQ(e->continuing()->size(), 0u);
}
@@ -44,10 +45,10 @@
ASSERT_NE(e.value, nullptr);
ASSERT_EQ(e->body()->size(), 1u);
- EXPECT_TRUE(e->body()->get(0)->IsDiscard());
+ EXPECT_TRUE(e->body()->get(0)->Is<ast::DiscardStatement>());
EXPECT_EQ(e->continuing()->size(), 1u);
- EXPECT_TRUE(e->continuing()->get(0)->IsDiscard());
+ EXPECT_TRUE(e->continuing()->get(0)->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, LoopStmt_NoBodyNoContinuing) {
@@ -70,7 +71,7 @@
ASSERT_NE(e.value, nullptr);
ASSERT_EQ(e->body()->size(), 0u);
ASSERT_EQ(e->continuing()->size(), 1u);
- EXPECT_TRUE(e->continuing()->get(0)->IsDiscard());
+ EXPECT_TRUE(e->continuing()->get(0)->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, LoopStmt_MissingBracketLeft) {
diff --git a/src/reader/wgsl/parser_impl_statement_test.cc b/src/reader/wgsl/parser_impl_statement_test.cc
index 85489de..ec96357 100644
--- a/src/reader/wgsl/parser_impl_statement_test.cc
+++ b/src/reader/wgsl/parser_impl_statement_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
#include "src/ast/return_statement.h"
#include "src/ast/statement.h"
#include "src/reader/wgsl/parser_impl.h"
@@ -29,7 +30,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsReturn());
+ ASSERT_TRUE(e->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, Statement_Semicolon) {
@@ -44,8 +45,8 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsReturn());
- auto* ret = e->AsReturn();
+ ASSERT_TRUE(e->Is<ast::ReturnStatement>());
+ auto* ret = e->As<ast::ReturnStatement>();
ASSERT_EQ(ret->value(), nullptr);
}
@@ -56,8 +57,8 @@
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsReturn());
- auto* ret = e->AsReturn();
+ ASSERT_TRUE(e->Is<ast::ReturnStatement>());
+ auto* ret = e->As<ast::ReturnStatement>();
ASSERT_NE(ret->value(), nullptr);
EXPECT_TRUE(ret->value()->IsBinary());
}
@@ -88,7 +89,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsIf());
+ ASSERT_TRUE(e->Is<ast::IfStatement>());
}
TEST_F(ParserImplTest, Statement_If_Invalid) {
@@ -107,7 +108,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsVariableDecl());
+ ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
}
TEST_F(ParserImplTest, Statement_Variable_Invalid) {
@@ -136,7 +137,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsSwitch());
+ ASSERT_TRUE(e->Is<ast::SwitchStatement>());
}
TEST_F(ParserImplTest, Statement_Switch_Invalid) {
@@ -155,7 +156,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsLoop());
+ ASSERT_TRUE(e->Is<ast::LoopStatement>());
}
TEST_F(ParserImplTest, Statement_Loop_Invalid) {
@@ -174,7 +175,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsAssign());
+ ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, Statement_Assignment_Invalid) {
@@ -203,7 +204,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsBreak());
+ ASSERT_TRUE(e->Is<ast::BreakStatement>());
}
TEST_F(ParserImplTest, Statement_Break_MissingSemicolon) {
@@ -222,7 +223,7 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsContinue());
+ ASSERT_TRUE(e->Is<ast::ContinueStatement>());
}
TEST_F(ParserImplTest, Statement_Continue_MissingSemicolon) {
@@ -242,7 +243,7 @@
ASSERT_NE(e.value, nullptr);
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsDiscard());
+ ASSERT_TRUE(e->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, Statement_Discard_MissingSemicolon) {
@@ -261,8 +262,9 @@
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
- ASSERT_TRUE(e->IsBlock());
- EXPECT_TRUE(e->AsBlock()->get(0)->IsVariableDecl());
+ ASSERT_TRUE(e->Is<ast::BlockStatement>());
+ EXPECT_TRUE(
+ e->As<ast::BlockStatement>()->get(0)->Is<ast::VariableDeclStatement>());
}
TEST_F(ParserImplTest, Statement_Body_Invalid) {
diff --git a/src/reader/wgsl/parser_impl_statements_test.cc b/src/reader/wgsl/parser_impl_statements_test.cc
index 55bd919..88b6012 100644
--- a/src/reader/wgsl/parser_impl_statements_test.cc
+++ b/src/reader/wgsl/parser_impl_statements_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
+#include "src/ast/discard_statement.h"
#include "src/ast/statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -28,8 +29,8 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e->size(), 2u);
- EXPECT_TRUE(e->get(0)->IsDiscard());
- EXPECT_TRUE(e->get(1)->IsReturn());
+ EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
+ EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, Statements_Empty) {
diff --git a/src/reader/wgsl/parser_impl_switch_body_test.cc b/src/reader/wgsl/parser_impl_switch_body_test.cc
index 74ac8a3..73ab068 100644
--- a/src/reader/wgsl/parser_impl_switch_body_test.cc
+++ b/src/reader/wgsl/parser_impl_switch_body_test.cc
@@ -29,10 +29,10 @@
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsCase());
+ ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
ASSERT_EQ(e->body()->size(), 1u);
- EXPECT_TRUE(e->body()->get(0)->IsAssign());
+ EXPECT_TRUE(e->body()->get(0)->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
@@ -112,10 +112,10 @@
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsCase());
+ ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_TRUE(e->IsDefault());
ASSERT_EQ(e->body()->size(), 1u);
- EXPECT_TRUE(e->body()->get(0)->IsAssign());
+ EXPECT_TRUE(e->body()->get(0)->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, SwitchBody_Default_MissingColon) {
diff --git a/src/reader/wgsl/parser_impl_switch_stmt_test.cc b/src/reader/wgsl/parser_impl_switch_stmt_test.cc
index b4d3394..9d2bfbb 100644
--- a/src/reader/wgsl/parser_impl_switch_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_switch_stmt_test.cc
@@ -33,7 +33,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsSwitch());
+ ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body().size(), 2u);
EXPECT_FALSE(e->body()[0]->IsDefault());
EXPECT_FALSE(e->body()[1]->IsDefault());
@@ -46,7 +46,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsSwitch());
+ ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body().size(), 0u);
}
@@ -61,7 +61,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsSwitch());
+ ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body().size(), 3u);
ASSERT_FALSE(e->body()[0]->IsDefault());
diff --git a/src/reader/wgsl/parser_impl_variable_stmt_test.cc b/src/reader/wgsl/parser_impl_variable_stmt_test.cc
index fe6cc5a..8f83c31 100644
--- a/src/reader/wgsl/parser_impl_variable_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_variable_stmt_test.cc
@@ -30,7 +30,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsVariableDecl());
+ ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
ASSERT_NE(e->variable(), nullptr);
EXPECT_EQ(e->variable()->name(), "a");
@@ -49,7 +49,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsVariableDecl());
+ ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
ASSERT_NE(e->variable(), nullptr);
EXPECT_EQ(e->variable()->name(), "a");
@@ -89,7 +89,7 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->IsVariableDecl());
+ ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
ASSERT_EQ(e->source().range.begin.line, 1u);
ASSERT_EQ(e->source().range.begin.column, 7u);
diff --git a/src/transform/bound_array_accessors_transform.cc b/src/transform/bound_array_accessors_transform.cc
index 61e4a4a..1715c8c 100644
--- a/src/transform/bound_array_accessors_transform.cc
+++ b/src/transform/bound_array_accessors_transform.cc
@@ -21,10 +21,14 @@
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/block_statement.h"
+#include "src/ast/break_statement.h"
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
@@ -67,52 +71,47 @@
}
bool BoundArrayAccessorsTransform::ProcessStatement(ast::Statement* stmt) {
- if (stmt->IsAssign()) {
- auto* as = stmt->AsAssign();
+ if (auto* as = stmt->As<ast::AssignmentStatement>()) {
return ProcessExpression(as->lhs()) && ProcessExpression(as->rhs());
- } else if (stmt->IsBlock()) {
- for (auto* s : *(stmt->AsBlock())) {
+ } else if (auto* block = stmt->As<ast::BlockStatement>()) {
+ for (auto* s : *block) {
if (!ProcessStatement(s)) {
return false;
}
}
- } else if (stmt->IsBreak()) {
+ } else if (stmt->Is<ast::BreakStatement>()) {
/* nop */
- } else if (stmt->IsCall()) {
- return ProcessExpression(stmt->AsCall()->expr());
- } else if (stmt->IsCase()) {
- return ProcessStatement(stmt->AsCase()->body());
- } else if (stmt->IsContinue()) {
+ } else if (auto* call = stmt->As<ast::CallStatement>()) {
+ return ProcessExpression(call->expr());
+ } else if (auto* kase = stmt->As<ast::CaseStatement>()) {
+ return ProcessStatement(kase->body());
+ } else if (stmt->Is<ast::ContinueStatement>()) {
/* nop */
- } else if (stmt->IsDiscard()) {
+ } else if (stmt->Is<ast::DiscardStatement>()) {
/* nop */
- } else if (stmt->IsElse()) {
- auto* e = stmt->AsElse();
+ } else if (auto* e = stmt->As<ast::ElseStatement>()) {
return ProcessExpression(e->condition()) && ProcessStatement(e->body());
- } else if (stmt->IsFallthrough()) {
+ } else if (stmt->Is<ast::FallthroughStatement>()) {
/* nop */
- } else if (stmt->IsIf()) {
- auto* e = stmt->AsIf();
- if (!ProcessExpression(e->condition()) || !ProcessStatement(e->body())) {
+ } else if (auto* i = stmt->As<ast::IfStatement>()) {
+ if (!ProcessExpression(i->condition()) || !ProcessStatement(i->body())) {
return false;
}
- for (auto* s : e->else_statements()) {
+ for (auto* s : i->else_statements()) {
if (!ProcessStatement(s)) {
return false;
}
}
- } else if (stmt->IsLoop()) {
- auto* l = stmt->AsLoop();
+ } else if (auto* l = stmt->As<ast::LoopStatement>()) {
if (l->has_continuing() && !ProcessStatement(l->continuing())) {
return false;
}
return ProcessStatement(l->body());
- } else if (stmt->IsReturn()) {
- if (stmt->AsReturn()->has_value()) {
- return ProcessExpression(stmt->AsReturn()->value());
+ } else if (auto* r = stmt->As<ast::ReturnStatement>()) {
+ if (r->has_value()) {
+ return ProcessExpression(r->value());
}
- } else if (stmt->IsSwitch()) {
- auto* s = stmt->AsSwitch();
+ } else if (auto* s = stmt->As<ast::SwitchStatement>()) {
if (!ProcessExpression(s->condition())) {
return false;
}
@@ -122,8 +121,8 @@
return false;
}
}
- } else if (stmt->IsVariableDecl()) {
- auto* v = stmt->AsVariableDecl()->variable();
+ } else if (auto* vd = stmt->As<ast::VariableDeclStatement>()) {
+ auto* v = vd->variable();
if (v->has_constructor() && !ProcessExpression(v->constructor())) {
return false;
}
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 40f55a7..6a83e74 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -27,7 +27,9 @@
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic.h"
@@ -178,11 +180,11 @@
}
bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) {
- if (!stmt->IsVariableDecl()) {
+ if (!stmt->Is<ast::VariableDeclStatement>()) {
return true;
}
- auto* var = stmt->AsVariableDecl()->variable();
+ auto* var = stmt->As<ast::VariableDeclStatement>()->variable();
// Nothing to do for const
if (var->is_const()) {
return true;
@@ -203,39 +205,35 @@
}
bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
- if (stmt->IsAssign()) {
- auto* a = stmt->AsAssign();
+ if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return DetermineResultType(a->lhs()) && DetermineResultType(a->rhs());
}
- if (stmt->IsBlock()) {
- return DetermineStatements(stmt->AsBlock());
+ if (auto* b = stmt->As<ast::BlockStatement>()) {
+ return DetermineStatements(b);
}
- if (stmt->IsBreak()) {
+ if (stmt->Is<ast::BreakStatement>()) {
return true;
}
- if (stmt->IsCall()) {
- return DetermineResultType(stmt->AsCall()->expr());
+ if (auto* c = stmt->As<ast::CallStatement>()) {
+ return DetermineResultType(c->expr());
}
- if (stmt->IsCase()) {
- auto* c = stmt->AsCase();
+ if (auto* c = stmt->As<ast::CaseStatement>()) {
return DetermineStatements(c->body());
}
- if (stmt->IsContinue()) {
+ if (stmt->Is<ast::ContinueStatement>()) {
return true;
}
- if (stmt->IsDiscard()) {
+ if (stmt->Is<ast::DiscardStatement>()) {
return true;
}
- if (stmt->IsElse()) {
- auto* e = stmt->AsElse();
+ if (auto* e = stmt->As<ast::ElseStatement>()) {
return DetermineResultType(e->condition()) &&
DetermineStatements(e->body());
}
- if (stmt->IsFallthrough()) {
+ if (stmt->Is<ast::FallthroughStatement>()) {
return true;
}
- if (stmt->IsIf()) {
- auto* i = stmt->AsIf();
+ if (auto* i = stmt->As<ast::IfStatement>()) {
if (!DetermineResultType(i->condition()) ||
!DetermineStatements(i->body())) {
return false;
@@ -248,17 +246,14 @@
}
return true;
}
- if (stmt->IsLoop()) {
- auto* l = stmt->AsLoop();
+ if (auto* l = stmt->As<ast::LoopStatement>()) {
return DetermineStatements(l->body()) &&
DetermineStatements(l->continuing());
}
- if (stmt->IsReturn()) {
- auto* r = stmt->AsReturn();
+ if (auto* r = stmt->As<ast::ReturnStatement>()) {
return DetermineResultType(r->value());
}
- if (stmt->IsSwitch()) {
- auto* s = stmt->AsSwitch();
+ if (auto* s = stmt->As<ast::SwitchStatement>()) {
if (!DetermineResultType(s->condition())) {
return false;
}
@@ -269,8 +264,7 @@
}
return true;
}
- if (stmt->IsVariableDecl()) {
- auto* v = stmt->AsVariableDecl();
+ if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
variable_stack_.set(v->variable()->name(), v->variable());
return DetermineResultType(v->variable()->constructor());
}
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 6567f37..eea936e 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -19,6 +19,7 @@
#include <utility>
#include "src/ast/call_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/function.h"
#include "src/ast/int_literal.h"
#include "src/ast/intrinsic.h"
@@ -208,7 +209,7 @@
if (!current_function_->return_type()->Is<ast::type::VoidType>()) {
if (!func->get_last_statement() ||
- !func->get_last_statement()->IsReturn()) {
+ !func->get_last_statement()->Is<ast::ReturnStatement>()) {
add_error(func->source(), "v-0002",
"non-void function must end with a return statement");
return false;
@@ -284,29 +285,28 @@
if (!stmt) {
return false;
}
- if (stmt->IsVariableDecl()) {
- auto* v = stmt->AsVariableDecl();
+ if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
bool constructor_valid =
v->variable()->has_constructor()
? ValidateExpression(v->variable()->constructor())
: true;
- return constructor_valid && ValidateDeclStatement(stmt->AsVariableDecl());
+ return constructor_valid && ValidateDeclStatement(v);
}
- if (stmt->IsAssign()) {
- return ValidateAssign(stmt->AsAssign());
+ if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+ return ValidateAssign(a);
}
- if (stmt->IsReturn()) {
- return ValidateReturnStatement(stmt->AsReturn());
+ if (auto* r = stmt->As<ast::ReturnStatement>()) {
+ return ValidateReturnStatement(r);
}
- if (stmt->IsCall()) {
- return ValidateCallExpr(stmt->AsCall()->expr());
+ if (auto* c = stmt->As<ast::CallStatement>()) {
+ return ValidateCallExpr(c->expr());
}
- if (stmt->IsSwitch()) {
- return ValidateSwitch(stmt->AsSwitch());
+ if (auto* s = stmt->As<ast::SwitchStatement>()) {
+ return ValidateSwitch(s);
}
- if (stmt->IsCase()) {
- return ValidateCase(stmt->AsCase());
+ if (auto* c = stmt->As<ast::CaseStatement>()) {
+ return ValidateCase(c);
}
return true;
}
@@ -368,8 +368,10 @@
}
auto* last_clause = s->body().back();
- auto* last_stmt_of_last_clause = last_clause->AsCase()->body()->last();
- if (last_stmt_of_last_clause && last_stmt_of_last_clause->IsFallthrough()) {
+ auto* last_stmt_of_last_clause =
+ last_clause->As<ast::CaseStatement>()->body()->last();
+ if (last_stmt_of_last_clause &&
+ last_stmt_of_last_clause->Is<ast::FallthroughStatement>()) {
add_error(last_stmt_of_last_clause->source(), "v-0028",
"a fallthrough statement must not appear as "
"the last statement in last clause of a switch");
diff --git a/src/validator/validator_impl.h b/src/validator/validator_impl.h
index 9cb4854..3d8e229 100644
--- a/src/validator/validator_impl.h
+++ b/src/validator/validator_impl.h
@@ -26,7 +26,9 @@
#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/statement.h"
+#include "src/ast/switch_statement.h"
#include "src/ast/variable.h"
+#include "src/ast/variable_decl_statement.h"
#include "src/diagnostic/diagnostic.h"
#include "src/diagnostic/formatter.h"
#include "src/scope_stack.h"
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 38a974f..0ffbc06 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -29,6 +29,7 @@
#include "src/ast/case_statement.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
@@ -75,7 +76,8 @@
return false;
}
- return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
+ return stmts->last()->Is<ast::BreakStatement>() ||
+ stmts->last()->Is<ast::FallthroughStatement>();
}
std::string get_buffer_name(ast::Expression* expr) {
@@ -1601,11 +1603,10 @@
// the for loop into the continuing scope. Then, the variable declarations
// will be turned into assignments.
for (auto* s : *stmt->body()) {
- if (!s->IsVariableDecl()) {
- continue;
- }
- if (!EmitVariable(out, s->AsVariableDecl()->variable(), true)) {
- return false;
+ if (auto* v = s->As<ast::VariableDeclStatement>()) {
+ if (!EmitVariable(out, v->variable(), true)) {
+ return false;
+ }
}
}
}
@@ -1630,10 +1631,11 @@
for (auto* s : *(stmt->body())) {
// If we have a continuing block we've already emitted the variable
// declaration before the loop, so treat it as an assignment.
- if (s->IsVariableDecl() && stmt->has_continuing()) {
+ auto* decl = s->As<ast::VariableDeclStatement>();
+ if (decl != nullptr && stmt->has_continuing()) {
make_indent(out);
- auto* var = s->AsVariableDecl()->variable();
+ auto* var = decl->variable();
std::ostringstream pre;
std::ostringstream constructor_out;
@@ -1963,51 +1965,51 @@
}
bool GeneratorImpl::EmitStatement(std::ostream& out, ast::Statement* stmt) {
- if (stmt->IsAssign()) {
- return EmitAssign(out, stmt->AsAssign());
+ if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+ return EmitAssign(out, a);
}
- if (stmt->IsBlock()) {
- return EmitIndentedBlockAndNewline(out, stmt->AsBlock());
+ if (auto* b = stmt->As<ast::BlockStatement>()) {
+ return EmitIndentedBlockAndNewline(out, b);
}
- if (stmt->IsBreak()) {
- return EmitBreak(out, stmt->AsBreak());
+ if (auto* b = stmt->As<ast::BreakStatement>()) {
+ return EmitBreak(out, b);
}
- if (stmt->IsCall()) {
+ if (auto* c = stmt->As<ast::CallStatement>()) {
make_indent(out);
std::ostringstream pre;
std::ostringstream call_out;
- if (!EmitCall(pre, call_out, stmt->AsCall()->expr())) {
+ if (!EmitCall(pre, call_out, c->expr())) {
return false;
}
out << pre.str();
out << call_out.str() << ";" << std::endl;
return true;
}
- if (stmt->IsContinue()) {
- return EmitContinue(out, stmt->AsContinue());
+ if (auto* c = stmt->As<ast::ContinueStatement>()) {
+ return EmitContinue(out, c);
}
- if (stmt->IsDiscard()) {
- return EmitDiscard(out, stmt->AsDiscard());
+ if (auto* d = stmt->As<ast::DiscardStatement>()) {
+ return EmitDiscard(out, d);
}
- if (stmt->IsFallthrough()) {
+ if (auto* f = stmt->As<ast::FallthroughStatement>()) {
make_indent(out);
out << "/* fallthrough */" << std::endl;
return true;
}
- if (stmt->IsIf()) {
- return EmitIf(out, stmt->AsIf());
+ if (auto* i = stmt->As<ast::IfStatement>()) {
+ return EmitIf(out, i);
}
- if (stmt->IsLoop()) {
- return EmitLoop(out, stmt->AsLoop());
+ if (auto* l = stmt->As<ast::LoopStatement>()) {
+ return EmitLoop(out, l);
}
- if (stmt->IsReturn()) {
- return EmitReturn(out, stmt->AsReturn());
+ if (auto* r = stmt->As<ast::ReturnStatement>()) {
+ return EmitReturn(out, r);
}
- if (stmt->IsSwitch()) {
- return EmitSwitch(out, stmt->AsSwitch());
+ if (auto* s = stmt->As<ast::SwitchStatement>()) {
+ return EmitSwitch(out, s);
}
- if (stmt->IsVariableDecl()) {
- return EmitVariable(out, stmt->AsVariableDecl()->variable(), false);
+ if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+ return EmitVariable(out, v->variable(), false);
}
error_ = "unknown statement type: " + stmt->str();
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index a793f77..be2f31b 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -19,10 +19,19 @@
#include <unordered_map>
#include <unordered_set>
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
+#include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
+#include "src/ast/if_statement.h"
#include "src/ast/intrinsic.h"
#include "src/ast/literal.h"
+#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
+#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/switch_statement.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/context.h"
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 5fb958c..61aec3d 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -32,6 +32,7 @@
#include "src/ast/continue_statement.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
@@ -80,7 +81,8 @@
return false;
}
- return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
+ return stmts->last()->Is<ast::BreakStatement>() ||
+ stmts->last()->Is<ast::FallthroughStatement>();
}
uint32_t adjust_for_alignment(uint32_t count, uint32_t alignment) {
@@ -1540,11 +1542,10 @@
// the for loop into the continuing scope. Then, the variable declarations
// will be turned into assignments.
for (auto* s : *(stmt->body())) {
- if (!s->IsVariableDecl()) {
- continue;
- }
- if (!EmitVariable(s->AsVariableDecl()->variable(), true)) {
- return false;
+ if (auto* decl = s->As<ast::VariableDeclStatement>()) {
+ if (!EmitVariable(decl->variable(), true)) {
+ return false;
+ }
}
}
}
@@ -1569,10 +1570,11 @@
for (auto* s : *(stmt->body())) {
// If we have a continuing block we've already emitted the variable
// declaration before the loop, so treat it as an assignment.
- if (s->IsVariableDecl() && stmt->has_continuing()) {
+ auto* decl = s->As<ast::VariableDeclStatement>();
+ if (decl != nullptr && stmt->has_continuing()) {
make_indent();
- auto* var = s->AsVariableDecl()->variable();
+ auto* var = decl->variable();
out_ << var->name() << " = ";
if (var->constructor() != nullptr) {
if (!EmitExpression(var->constructor())) {
@@ -1716,48 +1718,48 @@
}
bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
- if (stmt->IsAssign()) {
- return EmitAssign(stmt->AsAssign());
+ if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+ return EmitAssign(a);
}
- if (stmt->IsBlock()) {
- return EmitIndentedBlockAndNewline(stmt->AsBlock());
+ if (auto* b = stmt->As<ast::BlockStatement>()) {
+ return EmitIndentedBlockAndNewline(b);
}
- if (stmt->IsBreak()) {
- return EmitBreak(stmt->AsBreak());
+ if (auto* b = stmt->As<ast::BreakStatement>()) {
+ return EmitBreak(b);
}
- if (stmt->IsCall()) {
+ if (auto* c = stmt->As<ast::CallStatement>()) {
make_indent();
- if (!EmitCall(stmt->AsCall()->expr())) {
+ if (!EmitCall(c->expr())) {
return false;
}
out_ << ";" << std::endl;
return true;
}
- if (stmt->IsContinue()) {
- return EmitContinue(stmt->AsContinue());
+ if (auto* c = stmt->As<ast::ContinueStatement>()) {
+ return EmitContinue(c);
}
- if (stmt->IsDiscard()) {
- return EmitDiscard(stmt->AsDiscard());
+ if (auto* d = stmt->As<ast::DiscardStatement>()) {
+ return EmitDiscard(d);
}
- if (stmt->IsFallthrough()) {
+ if (auto* f = stmt->As<ast::FallthroughStatement>()) {
make_indent();
out_ << "/* fallthrough */" << std::endl;
return true;
}
- if (stmt->IsIf()) {
- return EmitIf(stmt->AsIf());
+ if (auto* i = stmt->As<ast::IfStatement>()) {
+ return EmitIf(i);
}
- if (stmt->IsLoop()) {
- return EmitLoop(stmt->AsLoop());
+ if (auto* l = stmt->As<ast::LoopStatement>()) {
+ return EmitLoop(l);
}
- if (stmt->IsReturn()) {
- return EmitReturn(stmt->AsReturn());
+ if (auto* r = stmt->As<ast::ReturnStatement>()) {
+ return EmitReturn(r);
}
- if (stmt->IsSwitch()) {
- return EmitSwitch(stmt->AsSwitch());
+ if (auto* s = stmt->As<ast::SwitchStatement>()) {
+ return EmitSwitch(s);
}
- if (stmt->IsVariableDecl()) {
- return EmitVariable(stmt->AsVariableDecl()->variable(), false);
+ if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+ return EmitVariable(v->variable(), false);
}
error_ = "unknown statement type: " + stmt->str();
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 7cbd5d6..ca95151 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -19,10 +19,20 @@
#include <string>
#include <unordered_map>
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
+#include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
+#include "src/ast/else_statement.h"
+#include "src/ast/if_statement.h"
#include "src/ast/intrinsic.h"
#include "src/ast/literal.h"
+#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
+#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/switch_statement.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/scope_stack.h"
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 5487120..0f577c1 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -35,6 +35,7 @@
#include "src/ast/constructor_expression.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
@@ -109,7 +110,7 @@
}
bool LastIsFallthrough(const ast::BlockStatement* stmts) {
- return !stmts->empty() && stmts->last()->IsFallthrough();
+ return !stmts->empty() && stmts->last()->Is<ast::FallthroughStatement>();
}
// A terminator is anything which will case a SPIR-V terminator to be emitted.
@@ -121,8 +122,11 @@
}
auto* last = stmts->last();
- return last->IsBreak() || last->IsContinue() || last->IsDiscard() ||
- last->IsReturn() || last->IsFallthrough();
+ return last->Is<ast::BreakStatement>() ||
+ last->Is<ast::ContinueStatement>() ||
+ last->Is<ast::DiscardStatement>() ||
+ last->Is<ast::ReturnStatement>() ||
+ last->Is<ast::FallthroughStatement>();
}
uint32_t IndexFromName(char name) {
@@ -2359,42 +2363,42 @@
}
bool Builder::GenerateStatement(ast::Statement* stmt) {
- if (stmt->IsAssign()) {
- return GenerateAssignStatement(stmt->AsAssign());
+ if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+ return GenerateAssignStatement(a);
}
- if (stmt->IsBlock()) {
- return GenerateBlockStatement(stmt->AsBlock());
+ if (auto* b = stmt->As<ast::BlockStatement>()) {
+ return GenerateBlockStatement(b);
}
- if (stmt->IsBreak()) {
- return GenerateBreakStatement(stmt->AsBreak());
+ if (auto* b = stmt->As<ast::BreakStatement>()) {
+ return GenerateBreakStatement(b);
}
- if (stmt->IsCall()) {
- return GenerateCallExpression(stmt->AsCall()->expr()) != 0;
+ if (auto* c = stmt->As<ast::CallStatement>()) {
+ return GenerateCallExpression(c->expr()) != 0;
}
- if (stmt->IsContinue()) {
- return GenerateContinueStatement(stmt->AsContinue());
+ if (auto* c = stmt->As<ast::ContinueStatement>()) {
+ return GenerateContinueStatement(c);
}
- if (stmt->IsDiscard()) {
- return GenerateDiscardStatement(stmt->AsDiscard());
+ if (auto* d = stmt->As<ast::DiscardStatement>()) {
+ return GenerateDiscardStatement(d);
}
- if (stmt->IsFallthrough()) {
+ if (stmt->Is<ast::FallthroughStatement>()) {
// Do nothing here, the fallthrough gets handled by the switch code.
return true;
}
- if (stmt->IsIf()) {
- return GenerateIfStatement(stmt->AsIf());
+ if (auto* i = stmt->As<ast::IfStatement>()) {
+ return GenerateIfStatement(i);
}
- if (stmt->IsLoop()) {
- return GenerateLoopStatement(stmt->AsLoop());
+ if (auto* l = stmt->As<ast::LoopStatement>()) {
+ return GenerateLoopStatement(l);
}
- if (stmt->IsReturn()) {
- return GenerateReturnStatement(stmt->AsReturn());
+ if (auto* r = stmt->As<ast::ReturnStatement>()) {
+ return GenerateReturnStatement(r);
}
- if (stmt->IsSwitch()) {
- return GenerateSwitchStatement(stmt->AsSwitch());
+ if (auto* s = stmt->As<ast::SwitchStatement>()) {
+ return GenerateSwitchStatement(s);
}
- if (stmt->IsVariableDecl()) {
- return GenerateVariableDeclStatement(stmt->AsVariableDecl());
+ if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+ return GenerateVariableDeclStatement(v);
}
error_ = "Unknown statement: " + stmt->str();
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index f523b71..a8ed01c 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -22,11 +22,19 @@
#include <vector>
#include "spirv/unified1/spirv.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
#include "src/ast/builtin.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
+#include "src/ast/if_statement.h"
#include "src/ast/literal.h"
+#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
+#include "src/ast/return_statement.h"
#include "src/ast/struct_member.h"
+#include "src/ast/switch_statement.h"
#include "src/ast/type/access_control_type.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h"
@@ -35,6 +43,7 @@
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
+#include "src/ast/variable_decl_statement.h"
#include "src/context.h"
#include "src/scope_stack.h"
#include "src/writer/spirv/function.h"
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index bcff24f..5fa37cb 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -796,46 +796,46 @@
}
bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
- if (stmt->IsAssign()) {
- return EmitAssign(stmt->AsAssign());
+ if (auto* a = stmt->As<ast::AssignmentStatement>()) {
+ return EmitAssign(a);
}
- if (stmt->IsBlock()) {
- return EmitIndentedBlockAndNewline(stmt->AsBlock());
+ if (auto* b = stmt->As<ast::BlockStatement>()) {
+ return EmitIndentedBlockAndNewline(b);
}
- if (stmt->IsBreak()) {
- return EmitBreak(stmt->AsBreak());
+ if (auto* b = stmt->As<ast::BreakStatement>()) {
+ return EmitBreak(b);
}
- if (stmt->IsCall()) {
+ if (auto* c = stmt->As<ast::CallStatement>()) {
make_indent();
- if (!EmitCall(stmt->AsCall()->expr())) {
+ if (!EmitCall(c->expr())) {
return false;
}
out_ << ";" << std::endl;
return true;
}
- if (stmt->IsContinue()) {
- return EmitContinue(stmt->AsContinue());
+ if (auto* c = stmt->As<ast::ContinueStatement>()) {
+ return EmitContinue(c);
}
- if (stmt->IsDiscard()) {
- return EmitDiscard(stmt->AsDiscard());
+ if (auto* d = stmt->As<ast::DiscardStatement>()) {
+ return EmitDiscard(d);
}
- if (stmt->IsFallthrough()) {
- return EmitFallthrough(stmt->AsFallthrough());
+ if (auto* f = stmt->As<ast::FallthroughStatement>()) {
+ return EmitFallthrough(f);
}
- if (stmt->IsIf()) {
- return EmitIf(stmt->AsIf());
+ if (auto* i = stmt->As<ast::IfStatement>()) {
+ return EmitIf(i);
}
- if (stmt->IsLoop()) {
- return EmitLoop(stmt->AsLoop());
+ if (auto* l = stmt->As<ast::LoopStatement>()) {
+ return EmitLoop(l);
}
- if (stmt->IsReturn()) {
- return EmitReturn(stmt->AsReturn());
+ if (auto* r = stmt->As<ast::ReturnStatement>()) {
+ return EmitReturn(r);
}
- if (stmt->IsSwitch()) {
- return EmitSwitch(stmt->AsSwitch());
+ if (auto* s = stmt->As<ast::SwitchStatement>()) {
+ return EmitSwitch(s);
}
- if (stmt->IsVariableDecl()) {
- return EmitVariable(stmt->AsVariableDecl()->variable());
+ if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
+ return EmitVariable(v->variable());
}
error_ = "unknown statement type: " + stmt->str();
diff --git a/src/writer/wgsl/generator_impl.h b/src/writer/wgsl/generator_impl.h
index bd2451b..77fd16d 100644
--- a/src/writer/wgsl/generator_impl.h
+++ b/src/writer/wgsl/generator_impl.h
@@ -19,10 +19,20 @@
#include <string>
#include "src/ast/array_accessor_expression.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/break_statement.h"
+#include "src/ast/case_statement.h"
#include "src/ast/constructor_expression.h"
+#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/identifier_expression.h"
+#include "src/ast/if_statement.h"
+#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
+#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/switch_statement.h"
#include "src/ast/type/storage_texture_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/type.h"