Resolver: Validate usage of break
Also remove unused fields of Resolver (block_to_info_, block_infos_). We can put them back when they're actually needed.
Fixed: tint:190
Change-Id: I1a02a24eca7fba32b8e1120abb88040138a39c6a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44051
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index cfb3e19..da144c7 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -66,9 +66,8 @@
Resolver::~Resolver() = default;
Resolver::BlockInfo::BlockInfo(Resolver::BlockInfo::Type ty,
- Resolver::BlockInfo* p,
- const ast::BlockStatement* b)
- : type(ty), parent(p), block(b) {}
+ Resolver::BlockInfo* p)
+ : type(ty), parent(p) {}
Resolver::BlockInfo::~BlockInfo() = default;
@@ -150,12 +149,8 @@
}
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
- auto* block =
- block_infos_.Create(BlockInfo::Type::Generic, current_block_, stmt);
- block_to_info_[stmt] = block;
- ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
-
- return Statements(stmt->list());
+ return BlockScope(BlockInfo::Type::kGeneric,
+ [&] { return Statements(stmt->list()); });
}
bool Resolver::Statements(const ast::StatementList& stmts) {
@@ -219,18 +214,24 @@
return BlockStatement(b);
}
if (stmt->Is<ast::BreakStatement>()) {
+ if (!current_block_->FindFirstParent(BlockInfo::Type::kLoop) &&
+ !current_block_->FindFirstParent(BlockInfo::Type::kSwitchCase)) {
+ diagnostics_.add_error("break statement must be in a loop or switch case",
+ stmt->source());
+ return false;
+ }
return true;
}
if (auto* c = stmt->As<ast::CallStatement>()) {
return Expression(c->expr());
}
if (auto* c = stmt->As<ast::CaseStatement>()) {
- return BlockStatement(c->body());
+ return CaseStatement(c);
}
if (stmt->Is<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop
if (auto* loop_block =
- current_block_->FindFirstParent(BlockInfo::Type::Loop)) {
+ current_block_->FindFirstParent(BlockInfo::Type::kLoop)) {
if (loop_block->first_continue == size_t(~0)) {
loop_block->first_continue = loop_block->decls.size();
}
@@ -268,26 +269,20 @@
// these would make their BlockInfo siblings as in the AST, but we want the
// body BlockInfo to parent the continuing BlockInfo for semantics and
// validation. Also, we need to set their types differently.
- auto* block =
- block_infos_.Create(BlockInfo::Type::Loop, current_block_, l->body());
- block_to_info_[l->body()] = block;
- ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
-
- if (!Statements(l->body()->list())) {
- return false;
- }
-
- if (l->has_continuing()) {
- auto* cont_block = block_infos_.Create(BlockInfo::Type::LoopContinuing,
- current_block_, l->continuing());
- block_to_info_[l->continuing()] = cont_block;
- ScopedAssignment<BlockInfo*> scope_sa2(current_block_, cont_block);
-
- if (!Statements(l->continuing()->list())) {
+ return BlockScope(BlockInfo::Type::kLoop, [&] {
+ if (!Statements(l->body()->list())) {
return false;
}
- }
- return true;
+
+ if (l->has_continuing()) {
+ if (!BlockScope(BlockInfo::Type::kLoopContinuing,
+ [&] { return Statements(l->continuing()->list()); })) {
+ return false;
+ }
+ }
+
+ return true;
+ });
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return Expression(r->value());
@@ -297,7 +292,7 @@
return false;
}
for (auto* case_stmt : s->body()) {
- if (!Statement(case_stmt)) {
+ if (!CaseStatement(case_stmt)) {
return false;
}
}
@@ -316,6 +311,11 @@
return false;
}
+bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
+ return BlockScope(BlockInfo::Type::kSwitchCase,
+ [&] { return Statements(stmt->body()->list()); });
+}
+
bool Resolver::Expressions(const ast::ExpressionList& list) {
for (auto* expr : list) {
if (!Expression(expr)) {
@@ -395,8 +395,7 @@
} else if (auto* arr = parent_type->As<type::Array>()) {
if (!arr->type()->is_scalar()) {
// If we extract a non-scalar from an array then we also get a pointer. We
- // will generate a Function storage class variable to store this
- // into.
+ // will generate a Function storage class variable to store this into.
ret = builder_->create<type::Pointer>(ret, ast::StorageClass::kFunction);
}
}
@@ -573,9 +572,9 @@
// refer to a variable that is bypassed by a continue statement in the
// loop's body block.
if (auto* continuing_block =
- current_block_->FindFirstParent(BlockInfo::Type::LoopContinuing)) {
+ current_block_->FindFirstParent(BlockInfo::Type::kLoopContinuing)) {
auto* loop_block =
- continuing_block->FindFirstParent(BlockInfo::Type::Loop);
+ continuing_block->FindFirstParent(BlockInfo::Type::kLoop);
if (loop_block->first_continue != size_t(~0)) {
auto& decls = loop_block->decls;
// If our identifier is in loop_block->decls, make sure its index is
@@ -946,6 +945,13 @@
}
}
+template <typename F>
+bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) {
+ BlockInfo block_info(type, current_block_);
+ ScopedAssignment<BlockInfo*> sa(current_block_, &block_info);
+ return callback();
+}
+
Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
: declaration(decl), storage_class(decl->declared_storage_class()) {}
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index f07fab5..4a095be 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -33,6 +33,7 @@
class BinaryExpression;
class BitcastExpression;
class CallExpression;
+class CaseStatement;
class ConstructorExpression;
class Function;
class IdentifierExpression;
@@ -105,9 +106,9 @@
/// parent block and variables declared in the block.
/// Used to validate variable scoping rules.
struct BlockInfo {
- enum class Type { Generic, Loop, LoopContinuing };
+ enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
- BlockInfo(Type type, BlockInfo* parent, const ast::BlockStatement* block);
+ BlockInfo(Type type, BlockInfo* parent);
~BlockInfo();
template <typename Pred>
@@ -124,9 +125,8 @@
[ty](auto* block_info) { return block_info->type == ty; });
}
- const Type type;
- BlockInfo* parent;
- const ast::BlockStatement* block;
+ Type const type;
+ BlockInfo* const parent;
std::vector<const ast::Variable*> decls;
// first_continue is set to the index of the first variable in decls
@@ -134,9 +134,6 @@
constexpr static size_t kNoContinue = size_t(~0);
size_t first_continue = kNoContinue;
};
- std::unordered_map<const ast::BlockStatement*, BlockInfo*> block_to_info_;
- BlockAllocator<BlockInfo> block_infos_;
- BlockInfo* current_block_ = nullptr;
/// Resolves the program, without creating final the semantic nodes.
/// @returns true on success, false on error
@@ -200,6 +197,7 @@
bool Binary(ast::BinaryExpression* expr);
bool Bitcast(ast::BitcastExpression* expr);
bool Call(ast::CallExpression* expr);
+ bool CaseStatement(ast::CaseStatement* stmt);
bool Constructor(ast::ConstructorExpression* expr);
bool Identifier(ast::IdentifierExpression* expr);
bool IntrinsicCall(ast::CallExpression* call,
@@ -221,9 +219,16 @@
/// @param type the resolved type
void SetType(ast::Expression* expr, type::Type* type);
+ /// Constructs a new BlockInfo with the given type and with #current_block_ as
+ /// its parent, assigns this to #current_block_, and then calls `callback`.
+ /// The original #current_block_ is restored on exit.
+ template <typename F>
+ bool BlockScope(BlockInfo::Type type, F&& callback);
+
ProgramBuilder* const builder_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
diag::List diagnostics_;
+ BlockInfo* current_block_ = nullptr;
ScopeStack<VariableInfo*> variable_stack_;
std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
std::unordered_map<ast::Function*, FunctionInfo*> function_to_info_;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index d7c971d..d3613cd 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -17,6 +17,7 @@
#include "gmock/gmock.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/bitcast_expression.h"
+#include "src/ast/break_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/if_statement.h"
@@ -476,10 +477,37 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
+TEST_F(ResolverTest, Stmt_ContinueInLoop) {
+ WrapInFunction(Loop(Block(create<ast::ContinueStatement>(Source{{12, 34}}))));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
TEST_F(ResolverTest, Stmt_ContinueNotInLoop) {
- WrapInFunction(create<ast::ContinueStatement>());
+ WrapInFunction(create<ast::ContinueStatement>(Source{{12, 34}}));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "error: continue statement must be in a loop");
+ EXPECT_EQ(r()->error(), "12:34 error: continue statement must be in a loop");
+}
+
+TEST_F(ResolverTest, Stmt_BreakInLoop) {
+ WrapInFunction(Loop(Block(create<ast::BreakStatement>(Source{{12, 34}}))));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_BreakInSwitch) {
+ WrapInFunction(Loop(Block(create<ast::SwitchStatement>(
+ Expr(1), ast::CaseStatementList{
+ create<ast::CaseStatement>(
+ ast::CaseSelectorList{Literal(1)},
+ Block(create<ast::BreakStatement>(Source{{12, 34}}))),
+ }))));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_BreakNotInLoopOrSwitch) {
+ WrapInFunction(create<ast::BreakStatement>(Source{{12, 34}}));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: break statement must be in a loop or switch case");
}
TEST_F(ResolverTest, Stmt_Return) {