sem: Replace SwitchCaseBlockStatement with CaseStatement
The SwitchCaseBlockStatement was bound to the BlockStatement of an ast::CaseStatement, but we had nothing that mapped to the actual ast::CaseStatement.
sem::CaseStatement replaces sem::SwitchCaseBlockStatement, and has a Block() accessor, providing a superset of the old behavior.
With this, we can now easily validate the `fallthrough` rules directly, instead of scanning the switch case. This keeps the validation more tigtly coupled to the ast / sem nodes.
Change-Id: I0f22eba37bb164b9e071a6166c7a41fc1a5ac532
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71460
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/docs/compound_statements.md b/docs/compound_statements.md
index 7d28b1d..a113cce 100644
--- a/docs/compound_statements.md
+++ b/docs/compound_statements.md
@@ -105,11 +105,15 @@
```
sem::SwitchStatement {
sem::Expression condition
- sem::SwitchCaseBlockStatement {
- sem::Statement statement_a
+ sem::CaseStatement {
+ sem::BlockStatement {
+ sem::Statement statement_a
+ }
}
- sem::SwitchCaseBlockStatement {
- sem::Statement statement_b
+ sem::CaseStatement {
+ sem::BlockStatement {
+ sem::Statement statement_b
+ }
}
}
```
diff --git a/src/resolver/compound_statement_test.cc b/src/resolver/compound_statement_test.cc
index db91d7c..7de565c 100644
--- a/src/resolver/compound_statement_test.cc
+++ b/src/resolver/compound_statement_test.cc
@@ -343,31 +343,34 @@
{
auto* s = Sem().Get(stmt_a);
ASSERT_NE(s, nullptr);
- EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block());
- EXPECT_EQ(s->Parent()->Parent(),
- s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
}
{
auto* s = Sem().Get(stmt_b);
ASSERT_NE(s, nullptr);
- EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block());
- EXPECT_EQ(s->Parent()->Parent(),
- s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
}
{
auto* s = Sem().Get(stmt_c);
ASSERT_NE(s, nullptr);
- EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block());
- EXPECT_EQ(s->Parent()->Parent(),
- s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(),
+ s->FindFirstParent<sem::SwitchStatement>());
+ EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
}
}
diff --git a/src/resolver/control_block_validation_test.cc b/src/resolver/control_block_validation_test.cc
index 890892e..41389dd 100644
--- a/src/resolver/control_block_validation_test.cc
+++ b/src/resolver/control_block_validation_test.cc
@@ -294,8 +294,8 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: a fallthrough statement must not appear as the last "
- "statement in last clause of a switch");
+ "12:34 error: a fallthrough statement must not be used in the last "
+ "switch case");
}
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) {
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index e490815..36ca107 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -874,17 +874,20 @@
return nullptr;
}
-sem::SwitchCaseBlockStatement* Resolver::CaseStatement(
- const ast::CaseStatement* stmt) {
- auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
- stmt->body, current_compound_statement_, current_function_);
+sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
+ auto* sem = builder_->create<sem::CaseStatement>(
+ stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
- builder_->Sem().Add(stmt->body, sem);
- Mark(stmt->body);
for (auto* sel : stmt->selectors) {
Mark(sel);
}
- return Statements(stmt->body->statements);
+ Mark(stmt->body);
+ auto* body = BlockStatement(stmt->body);
+ if (!body) {
+ return false;
+ }
+ sem->SetBlock(body);
+ return true;
});
}
@@ -2361,7 +2364,9 @@
const ast::FallthroughStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
- return StatementScope(stmt, sem, [&] { return true; });
+ return StatementScope(stmt, sem, [&] {
+ return ValidateFallthroughStatement(sem);
+ });
}
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index f2a2fc8..84565ee 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -60,13 +60,13 @@
class Array;
class Atomic;
class BlockStatement;
+class CaseStatement;
class ElseStatement;
class ForLoopStatement;
class IfStatement;
class Intrinsic;
class LoopStatement;
class Statement;
-class SwitchCaseBlockStatement;
class SwitchStatement;
class TypeConstructor;
} // namespace sem
@@ -209,7 +209,7 @@
sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
sem::Statement* BreakStatement(const ast::BreakStatement*);
sem::Statement* CallStatement(const ast::CallStatement*);
- sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*);
+ sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
sem::Statement* ContinueStatement(const ast::ContinueStatement*);
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
@@ -238,14 +238,15 @@
bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBreakStatement(const sem::Statement* stmt);
- bool ValidateContinueStatement(const sem::Statement* stmt);
- bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
+ bool ValidateContinueStatement(const sem::Statement* stmt);
+ bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateElseStatement(const sem::ElseStatement* stmt);
bool ValidateEntryPoint(const sem::Function* func);
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
+ bool ValidateFallthroughStatement(const sem::Statement* stmt);
bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var);
diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc
index b7cb04e..eb2be86 100644
--- a/src/resolver/resolver_validation.cc
+++ b/src/resolver/resolver_validation.cc
@@ -1346,7 +1346,7 @@
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
- !stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
+ !stmt->FindFirstParent<sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source);
return false;
@@ -1385,6 +1385,29 @@
return true;
}
+bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
+ if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
+ if (auto* c = As<sem::CaseStatement>(block->Parent())) {
+ if (block->Declaration()->Last() == stmt->Declaration()) {
+ if (auto* s = As<sem::SwitchStatement>(c->Parent())) {
+ if (c->Declaration() != s->Declaration()->body.back()) {
+ return true;
+ }
+ AddError(
+ "a fallthrough statement must not be used in the last switch "
+ "case",
+ stmt->Declaration()->source);
+ return false;
+ }
+ }
+ }
+ }
+ AddError(
+ "fallthrough must only be used as the last statement of a case block",
+ stmt->Declaration()->source);
+ return false;
+}
+
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
@@ -2231,18 +2254,6 @@
return false;
}
- if (!s->body.empty()) {
- auto* last_clause = s->body.back()->As<ast::CaseStatement>();
- auto* last_stmt = last_clause->body->Last();
- if (last_stmt && last_stmt->Is<ast::FallthroughStatement>()) {
- AddError(
- "a fallthrough statement must not appear as "
- "the last statement in last clause of a switch",
- last_stmt->source);
- return false;
- }
- }
-
return true;
}
diff --git a/src/sem/switch_statement.cc b/src/sem/switch_statement.cc
index fe13c3e..9a911a2 100644
--- a/src/sem/switch_statement.cc
+++ b/src/sem/switch_statement.cc
@@ -16,8 +16,8 @@
#include "src/program_builder.h"
+TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchStatement);
-TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchCaseBlockStatement);
namespace tint {
namespace sem {
@@ -32,15 +32,22 @@
SwitchStatement::~SwitchStatement() = default;
-SwitchCaseBlockStatement::SwitchCaseBlockStatement(
- const ast::BlockStatement* declaration,
- const CompoundStatement* parent,
- const sem::Function* function)
+const ast::SwitchStatement* SwitchStatement::Declaration() const {
+ return static_cast<const ast::SwitchStatement*>(Base::Declaration());
+}
+
+CaseStatement::CaseStatement(const ast::CaseStatement* declaration,
+ const CompoundStatement* parent,
+ const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
-SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default;
+CaseStatement::~CaseStatement() = default;
+
+const ast::CaseStatement* CaseStatement::Declaration() const {
+ return static_cast<const ast::CaseStatement*>(Base::Declaration());
+}
} // namespace sem
} // namespace tint
diff --git a/src/sem/switch_statement.h b/src/sem/switch_statement.h
index 8e5a2cd..49da6e9 100644
--- a/src/sem/switch_statement.h
+++ b/src/sem/switch_statement.h
@@ -20,6 +20,7 @@
// Forward declarations
namespace tint {
namespace ast {
+class CaseStatement;
class SwitchStatement;
} // namespace ast
} // namespace tint
@@ -40,22 +41,36 @@
/// Destructor
~SwitchStatement() override;
+
+ /// @return the AST node for this statement
+ const ast::SwitchStatement* Declaration() const;
};
-/// Holds semantic information about a switch case block
-class SwitchCaseBlockStatement
- : public Castable<SwitchCaseBlockStatement, BlockStatement> {
+/// Holds semantic information about a switch case statement
+class CaseStatement : public Castable<CaseStatement, CompoundStatement> {
public:
/// Constructor
- /// @param declaration the AST node for this block statement
+ /// @param declaration the AST node for this case statement
/// @param parent the owning statement
/// @param function the owning function
- SwitchCaseBlockStatement(const ast::BlockStatement* declaration,
- const CompoundStatement* parent,
- const sem::Function* function);
+ CaseStatement(const ast::CaseStatement* declaration,
+ const CompoundStatement* parent,
+ const sem::Function* function);
/// Destructor
- ~SwitchCaseBlockStatement() override;
+ ~CaseStatement() override;
+
+ /// @return the AST node for this statement
+ const ast::CaseStatement* Declaration() const;
+
+ /// @param body the case body block statement
+ void SetBlock(const BlockStatement* body) { body_ = body; }
+
+ /// @returns the case body block statement
+ const BlockStatement* Body() const { return body_; }
+
+ private:
+ const BlockStatement* body_ = nullptr;
};
} // namespace sem
diff --git a/src/writer/wgsl/generator_impl_fallthrough_test.cc b/src/writer/wgsl/generator_impl_fallthrough_test.cc
index 0325256..aa1d037 100644
--- a/src/writer/wgsl/generator_impl_fallthrough_test.cc
+++ b/src/writer/wgsl/generator_impl_fallthrough_test.cc
@@ -23,7 +23,9 @@
TEST_F(WgslGeneratorImplTest, Emit_Fallthrough) {
auto* f = create<ast::FallthroughStatement>();
- WrapInFunction(f);
+ WrapInFunction(Switch(1, //
+ Case(Expr(1), Block(f)), //
+ DefaultCase()));
GeneratorImpl& gen = Build();