sem: Split BlockStatement up into subclasses Allows us to put block-type-specific data on the specific subtype instead of littering a common base class Change-Id: If4a327a8ee52d5911308f38b518ec07c3ceebcb7 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51367 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: David Neto <dneto@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 2450f14..d05d058 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc
@@ -1251,8 +1251,8 @@ << "Resolver::Function() called with a current statement"; return false; } - sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( - func->body(), nullptr, sem::BlockStatement::Type::kGeneric); + auto* sem_block = + builder_->create<sem::BlockStatement>(func->body(), nullptr); builder_->Sem().Add(func->body(), sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); if (!BlockScope(func->body(), @@ -1379,8 +1379,7 @@ sem::Statement* sem_statement; if (stmt->As<ast::BlockStatement>()) { sem_statement = builder_->create<sem::BlockStatement>( - stmt->As<ast::BlockStatement>(), current_statement_, - sem::BlockStatement::Type::kGeneric); + stmt->As<ast::BlockStatement>(), current_statement_); } else { sem_statement = builder_->create<sem::Statement>(stmt, current_statement_); } @@ -1403,9 +1402,8 @@ return BlockScope(b, [&] { return Statements(b->list()); }); } if (stmt->Is<ast::BreakStatement>()) { - if (!current_block_->FindFirstParent(sem::BlockStatement::Type::kLoop) && - !current_block_->FindFirstParent( - sem::BlockStatement::Type::kSwitchCase)) { + if (!current_block_->FindFirstParent<sem::LoopBlockStatement>() && + !current_block_->FindFirstParent<sem::SwitchCaseBlockStatement>()) { diagnostics_.add_error("break statement must be in a loop or switch case", stmt->source()); return false; @@ -1422,9 +1420,9 @@ if (stmt->Is<ast::ContinueStatement>()) { // Set if we've hit the first continue statement in our parent loop if (auto* loop_block = - current_block_->FindFirstParent(sem::BlockStatement::Type::kLoop)) { + current_block_->FindFirstParent<sem::LoopBlockStatement>()) { if (loop_block->FirstContinue() == size_t(~0)) { - const_cast<sem::BlockStatement*>(loop_block) + const_cast<sem::LoopBlockStatement*>(loop_block) ->SetFirstContinue(loop_block->Decls().size()); } } else { @@ -1468,8 +1466,8 @@ for (auto* sel : stmt->selectors()) { Mark(sel); } - sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( - stmt->body(), current_statement_, sem::BlockStatement::Type::kSwitchCase); + auto* sem_block = builder_->create<sem::SwitchCaseBlockStatement>( + stmt->body(), current_statement_); builder_->Sem().Add(stmt->body(), sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); return BlockScope(stmt->body(), @@ -1492,8 +1490,8 @@ Mark(stmt->body()); { - sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( - stmt->body(), current_statement_, sem::BlockStatement::Type::kGeneric); + auto* sem_block = + builder_->create<sem::BlockStatement>(stmt->body(), current_statement_); builder_->Sem().Add(stmt->body(), sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); if (!BlockScope(stmt->body(), @@ -1525,9 +1523,8 @@ } Mark(else_stmt->body()); { - sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( - else_stmt->body(), current_statement_, - sem::BlockStatement::Type::kGeneric); + auto* sem_block = builder_->create<sem::BlockStatement>( + else_stmt->body(), current_statement_); builder_->Sem().Add(else_stmt->body(), sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); if (!BlockScope(else_stmt->body(), @@ -1546,8 +1543,8 @@ // validation. Also, we need to set their types differently. Mark(stmt->body()); - auto* sem_block_body = builder_->create<sem::BlockStatement>( - stmt->body(), current_statement_, sem::BlockStatement::Type::kLoop); + auto* sem_block_body = builder_->create<sem::LoopBlockStatement>( + stmt->body(), current_statement_); builder_->Sem().Add(stmt->body(), sem_block_body); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_body); return BlockScope(stmt->body(), [&] { @@ -1558,9 +1555,9 @@ Mark(stmt->continuing()); } if (stmt->has_continuing()) { - auto* sem_block_continuing = builder_->create<sem::BlockStatement>( - stmt->continuing(), current_statement_, - sem::BlockStatement::Type::kLoopContinuing); + auto* sem_block_continuing = + builder_->create<sem::LoopContinuingBlockStatement>( + stmt->continuing(), current_statement_); builder_->Sem().Add(stmt->continuing(), sem_block_continuing); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_continuing); if (!BlockScope(stmt->continuing(), @@ -1909,10 +1906,11 @@ // If identifier is part of a loop continuing block, make sure it // doesn't refer to a variable that is bypassed by a continue statement // in the loop's body block. - if (auto* continuing_block = current_block_->FindFirstParent( - sem::BlockStatement::Type::kLoopContinuing)) { + if (auto* continuing_block = + current_block_ + ->FindFirstParent<sem::LoopContinuingBlockStatement>()) { auto* loop_block = - continuing_block->FindFirstParent(sem::BlockStatement::Type::kLoop); + continuing_block->FindFirstParent<sem::LoopBlockStatement>(); if (loop_block->FirstContinue() != size_t(~0)) { auto& decls = loop_block->Decls(); // If our identifier is in loop_block->decls, make sure its index is
diff --git a/src/sem/block_statement.cc b/src/sem/block_statement.cc index 8d63f0d..bc4c845 100644 --- a/src/sem/block_statement.cc +++ b/src/sem/block_statement.cc
@@ -17,35 +17,47 @@ #include "src/ast/block_statement.h" TINT_INSTANTIATE_TYPEINFO(tint::sem::BlockStatement); +TINT_INSTANTIATE_TYPEINFO(tint::sem::LoopBlockStatement); +TINT_INSTANTIATE_TYPEINFO(tint::sem::LoopContinuingBlockStatement); +TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchCaseBlockStatement); namespace tint { namespace sem { BlockStatement::BlockStatement(const ast::BlockStatement* declaration, - const Statement* parent, - Type type) - : Base(declaration, parent), type_(type) {} + const Statement* parent) + : Base(declaration, parent) {} BlockStatement::~BlockStatement() = default; -const BlockStatement* BlockStatement::FindFirstParent( - BlockStatement::Type ty) const { - return FindFirstParent( - [ty](auto* block_info) { return block_info->type_ == ty; }); -} - const ast::BlockStatement* BlockStatement::Declaration() const { return Base::Declaration()->As<ast::BlockStatement>(); } -void BlockStatement::SetFirstContinue(size_t first_continue) { - TINT_ASSERT(type_ == Type::kLoop); - first_continue_ = first_continue; -} - void BlockStatement::AddDecl(ast::Variable* var) { decls_.push_back(var); } +LoopBlockStatement::LoopBlockStatement(const ast::BlockStatement* declaration, + const Statement* parent) + : Base(declaration, parent) {} +LoopBlockStatement::~LoopBlockStatement() = default; + +void LoopBlockStatement::SetFirstContinue(size_t first_continue) { + first_continue_ = first_continue; +} + +LoopContinuingBlockStatement::LoopContinuingBlockStatement( + const ast::BlockStatement* declaration, + const Statement* parent) + : Base(declaration, parent) {} +LoopContinuingBlockStatement::~LoopContinuingBlockStatement() = default; + +SwitchCaseBlockStatement::SwitchCaseBlockStatement( + const ast::BlockStatement* declaration, + const Statement* parent) + : Base(declaration, parent) {} +SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default; + } // namespace sem } // namespace tint
diff --git a/src/sem/block_statement.h b/src/sem/block_statement.h index 5e0f481..7989df5 100644 --- a/src/sem/block_statement.h +++ b/src/sem/block_statement.h
@@ -17,7 +17,6 @@ #include <vector> -#include "src/debug.h" #include "src/sem/statement.h" namespace tint { @@ -34,15 +33,11 @@ /// declared in the block. class BlockStatement : public Castable<BlockStatement, Statement> { public: - enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase }; - /// Constructor /// @param declaration the AST node for this block statement /// @param parent the owning statement - /// @param type the type of block this is BlockStatement(const ast::BlockStatement* declaration, - const Statement* parent, - Type type); + const Statement* parent); /// Destructor ~BlockStatement() override; @@ -52,7 +47,7 @@ const ast::BlockStatement* Declaration() const; /// @returns the closest enclosing block that satisfies the given predicate, - /// which may be the block itself, or nullptr if no match is found + /// which may be the block itself, or nullptr if no match is found /// @param pred a predicate that the resulting block must satisfy template <typename Pred> const BlockStatement* FindFirstParent(Pred&& pred) const { @@ -63,21 +58,47 @@ return curr; } - /// @returns the closest enclosing block that matches the given type, which - /// may be the block itself, or nullptr if no match is found - /// @param ty the type of block to be searched for - const BlockStatement* FindFirstParent(BlockStatement::Type ty) const; + /// @returns the statement itself if it matches the template type `T`, + /// otherwise the nearest enclosing block that matches `T`, or nullptr if + /// there is none. + template <typename T> + const T* FindFirstParent() const { + const BlockStatement* curr = this; + while (curr) { + if (auto* block = curr->As<T>()) { + return block; + } + curr = curr->Block(); + } + return nullptr; + } /// @returns the declarations associated with this block const std::vector<const ast::Variable*>& Decls() const { return decls_; } - /// Requires that this is a loop block. + /// Associates a declaration with this block. + /// @param var a variable declaration to be added to the block + void AddDecl(ast::Variable* var); + + private: + std::vector<const ast::Variable*> decls_; +}; + +/// Holds semantic information about a loop block +class LoopBlockStatement : public Castable<LoopBlockStatement, BlockStatement> { + public: + /// Constructor + /// @param declaration the AST node for this block statement + /// @param parent the owning statement + LoopBlockStatement(const ast::BlockStatement* declaration, + const Statement* parent); + + /// Destructor + ~LoopBlockStatement() override; + /// @returns the index of the first variable declared after the first continue - /// statement - size_t FirstContinue() const { - TINT_ASSERT(type_ == Type::kLoop); - return first_continue_; - } + /// statement + size_t FirstContinue() const { return first_continue_; } /// Requires that this is a loop block. /// Allows the resolver to set the index of the first variable declared after @@ -85,20 +106,41 @@ /// @param first_continue index of the relevant variable void SetFirstContinue(size_t first_continue); - /// Allows the resolver to associate a declaration with this block. - /// @param var a variable declaration to be added to the block - void AddDecl(ast::Variable* var); - private: - Type const type_; - std::vector<const ast::Variable*> decls_; - // first_continue is set to the index of the first variable in decls // declared after the first continue statement in a loop block, if any. constexpr static size_t kNoContinue = size_t(~0); size_t first_continue_ = kNoContinue; }; +/// Holds semantic information about a loop continuing block +class LoopContinuingBlockStatement + : public Castable<LoopContinuingBlockStatement, BlockStatement> { + public: + /// Constructor + /// @param declaration the AST node for this block statement + /// @param parent the owning statement + LoopContinuingBlockStatement(const ast::BlockStatement* declaration, + const Statement* parent); + + /// Destructor + ~LoopContinuingBlockStatement() override; +}; + +/// Holds semantic information about a switch case block +class SwitchCaseBlockStatement + : public Castable<SwitchCaseBlockStatement, BlockStatement> { + public: + /// Constructor + /// @param declaration the AST node for this block statement + /// @param parent the owning statement + SwitchCaseBlockStatement(const ast::BlockStatement* declaration, + const Statement* parent); + + /// Destructor + ~SwitchCaseBlockStatement() override; +}; + } // namespace sem } // namespace tint