resolver: Refactor Statement handling
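Break up Resolver::Statement() into multiple resolver functions, one per
statement kind. Each now returns the sem node it resolved, or nullptr on
failure, instead of a bool.
Replace Resolver::Scope() with Resolver::StatementScope(), which also adds the
AST -> SEM mapping.
Move simple statement validation out to resolver_validation.cc.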
Change-Id: Ifa29433af0a9afa39a66ac3e4f7ca376351adfbf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71102
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 173220c..5d4575d 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -654,9 +654,9 @@
<< "Resolver::Function() called with a current compound statement";
return nullptr;
}
- auto* sem_block = builder_->create<sem::FunctionBlockStatement>(func);
- builder_->Sem().Add(decl->body, sem_block);
- if (!Scope(sem_block, [&] { return Statements(decl->body->statements); })) {
+ if (!StatementScope(decl->body,
+ builder_->create<sem::FunctionBlockStatement>(func),
+ [&] { return Statements(decl->body->statements); })) {
return nullptr;
}
}
@@ -796,7 +796,8 @@
bool Resolver::Statements(const ast::StatementList& stmts) {
for (auto* stmt : stmts) {
Mark(stmt);
- if (!Statement(stmt)) {
+ auto* sem = Statement(stmt);
+ if (!sem) {
return false;
}
}
@@ -807,18 +808,18 @@
return true;
}
-bool Resolver::Statement(const ast::Statement* stmt) {
+sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
if (stmt->Is<ast::CaseStatement>()) {
AddError("case statement can only be used inside a switch statement",
stmt->source);
- return false;
+ return nullptr;
}
if (stmt->Is<ast::ElseStatement>()) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::Statement() encountered an Else statement. Else "
"statements are embedded in If statements, so should never be "
"encountered as top-level statements";
- return false;
+ return nullptr;
}
// Compound statements. These create their own sem::CompoundStatement
@@ -840,69 +841,26 @@
}
// Non-Compound statements
- sem::Statement* sem_statement = builder_->create<sem::Statement>(
- stmt, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt, sem_statement);
- TINT_SCOPED_ASSIGNMENT(current_statement_, sem_statement);
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
- return Assignment(a);
+ return AssignmentStatement(a);
}
- if (stmt->Is<ast::BreakStatement>()) {
- if (!sem_statement->FindFirstParent<sem::LoopBlockStatement>() &&
- !sem_statement->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
- AddError("break statement must be in a loop or switch case",
- stmt->source);
- return false;
- }
- return true;
+ if (auto* b = stmt->As<ast::BreakStatement>()) {
+ return BreakStatement(b);
}
if (auto* c = stmt->As<ast::CallStatement>()) {
- if (!Expression(c->expr)) {
- return false;
- }
- return true;
+ return CallStatement(c);
}
if (auto* c = stmt->As<ast::ContinueStatement>()) {
- // Set if we've hit the first continue statement in our parent loop
- if (auto* block =
- current_block_->FindFirstParent<
- sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) {
- if (auto* loop_block = block->As<sem::LoopBlockStatement>()) {
- if (!loop_block->FirstContinue()) {
- const_cast<sem::LoopBlockStatement*>(loop_block)
- ->SetFirstContinue(c, loop_block->Decls().size());
- }
- } else {
- AddError("continuing blocks must not contain a continue statement",
- stmt->source);
- return false;
- }
- } else {
- AddError("continue statement must be in a loop", stmt->source);
- return false;
- }
-
- return true;
+ return ContinueStatement(c);
}
- if (stmt->Is<ast::DiscardStatement>()) {
- if (auto* continuing =
- sem_statement
- ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
- AddError("continuing blocks must not contain a discard statement",
- stmt->source);
- if (continuing != sem_statement->Parent()) {
- AddNote("see continuing block here", continuing->Declaration()->source);
- }
- return false;
- }
- current_function_->SetHasDiscard();
- return true;
+ if (auto* d = stmt->As<ast::DiscardStatement>()) {
+ return DiscardStatement(d);
}
- if (stmt->Is<ast::FallthroughStatement>()) {
- return true;
+ if (auto* f = stmt->As<ast::FallthroughStatement>()) {
+ return FallthroughStatement(f);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
- return Return(r);
+ return ReturnStatement(r);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return VariableDeclStatement(v);
@@ -910,43 +868,38 @@
AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source);
- return false;
+ return nullptr;
}
-bool Resolver::CaseStatement(const ast::CaseStatement* stmt) {
+sem::SwitchCaseBlockStatement* Resolver::CaseStatement(
+ const ast::CaseStatement* stmt) {
auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt, sem);
- builder_->Sem().Add(stmt->body, sem);
- Mark(stmt->body);
- for (auto* sel : stmt->selectors) {
- Mark(sel);
- }
- return Scope(sem, [&] { return Statements(stmt->body->statements); });
+ return StatementScope(stmt, sem, [&] {
+ builder_->Sem().Add(stmt->body, sem);
+ Mark(stmt->body);
+ for (auto* sel : stmt->selectors) {
+ Mark(sel);
+ }
+ return Statements(stmt->body->statements);
+ });
}
-bool Resolver::IfStatement(const ast::IfStatement* stmt) {
+sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
auto* sem = builder_->create<sem::IfStatement>(
stmt, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt, sem);
- return Scope(sem, [&] {
- if (!Expression(stmt->condition)) {
+ return StatementScope(stmt, sem, [&] {
+ auto* cond = Expression(stmt->condition);
+ if (!cond) {
return false;
}
-
- auto* cond_type = TypeOf(stmt->condition)->UnwrapRef();
- if (!cond_type->Is<sem::Bool>()) {
- AddError(
- "if statement condition must be bool, got " + TypeNameOf(cond_type),
- stmt->condition->source);
- return false;
- }
+ sem->SetCondition(cond);
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
stmt->body, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt->body, body);
- if (!Scope(body, [&] { return Statements(stmt->body->statements); })) {
+ if (!StatementScope(stmt->body, body,
+ [&] { return Statements(stmt->body->statements); })) {
return false;
}
@@ -956,59 +909,56 @@
return false;
}
}
- return true;
+
+ return ValidateIfStatement(sem);
});
}
-bool Resolver::ElseStatement(const ast::ElseStatement* stmt) {
+sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
auto* sem = builder_->create<sem::ElseStatement>(
stmt, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt, sem);
- return Scope(sem, [&] {
- if (auto* cond = stmt->condition) {
- if (!Expression(cond)) {
+ return StatementScope(stmt, sem, [&] {
+ if (auto* cond_expr = stmt->condition) {
+ auto* cond = Expression(cond_expr);
+ if (!cond) {
return false;
}
-
- auto* else_cond_type = TypeOf(cond)->UnwrapRef();
- if (!else_cond_type->Is<sem::Bool>()) {
- AddError("else statement condition must be bool, got " +
- TypeNameOf(else_cond_type),
- cond->source);
- return false;
- }
+ sem->SetCondition(cond);
}
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
stmt->body, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt->body, body);
- return Scope(body, [&] { return Statements(stmt->body->statements); });
+ if (!StatementScope(stmt->body, body,
+ [&] { return Statements(stmt->body->statements); })) {
+ return false;
+ }
+
+ return ValidateElseStatement(sem);
});
}
-bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
+sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) {
auto* sem = builder_->create<sem::BlockStatement>(
stmt->As<ast::BlockStatement>(), current_compound_statement_,
current_function_);
- builder_->Sem().Add(stmt, sem);
- return Scope(sem, [&] { return Statements(stmt->statements); });
+ return StatementScope(stmt, sem,
+ [&] { return Statements(stmt->statements); });
}
-bool Resolver::LoopStatement(const ast::LoopStatement* stmt) {
+sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
auto* sem = builder_->create<sem::LoopStatement>(
stmt, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt, sem);
- return Scope(sem, [&] {
+ return StatementScope(stmt, sem, [&] {
Mark(stmt->body);
auto* body = builder_->create<sem::LoopBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt->body, body);
- return Scope(body, [&] {
+ return StatementScope(stmt->body, body, [&] {
if (!Statements(stmt->body->statements)) {
return false;
}
+
if (stmt->continuing) {
Mark(stmt->continuing);
if (!stmt->continuing->Empty()) {
@@ -1016,24 +966,22 @@
builder_->create<sem::LoopContinuingBlockStatement>(
stmt->continuing, current_compound_statement_,
current_function_);
- builder_->Sem().Add(stmt->continuing, continuing);
- if (!Scope(continuing, [&] {
- return Statements(stmt->continuing->statements);
- })) {
- return false;
- }
+ return StatementScope(stmt->continuing, continuing, [&] {
+ return Statements(stmt->continuing->statements);
+ }) != nullptr;
}
}
+
return true;
});
});
}
-bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
+sem::ForLoopStatement* Resolver::ForLoopStatement(
+ const ast::ForLoopStatement* stmt) {
auto* sem = builder_->create<sem::ForLoopStatement>(
stmt, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt, sem);
- return Scope(sem, [&] {
+ return StatementScope(stmt, sem, [&] {
if (auto* initializer = stmt->initializer) {
Mark(initializer);
if (!Statement(initializer)) {
@@ -1041,17 +989,12 @@
}
}
- if (auto* condition = stmt->condition) {
- if (!Expression(condition)) {
+ if (auto* cond_expr = stmt->condition) {
+ auto* cond = Expression(cond_expr);
+ if (!cond) {
return false;
}
-
- auto* cond_ty = TypeOf(condition)->UnwrapRef();
- if (!cond_ty->Is<sem::Bool>()) {
- AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
- condition->source);
- return false;
- }
+ sem->SetCondition(cond);
}
if (auto* continuing = stmt->continuing) {
@@ -1065,8 +1008,12 @@
auto* body = builder_->create<sem::LoopBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt->body, body);
- return Scope(body, [&] { return Statements(stmt->body->statements); });
+ if (!StatementScope(stmt->body, body,
+ [&] { return Statements(stmt->body->statements); })) {
+ return false;
+ }
+
+ return ValidateForLoopStatement(sem);
});
}
@@ -1930,33 +1877,6 @@
return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
}
-bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
- Mark(stmt->variable);
-
- auto* var = Variable(stmt->variable, VariableKind::kLocal);
- if (!var) {
- return false;
- }
-
- for (auto* deco : stmt->variable->decorations) {
- Mark(deco);
- if (!deco->Is<ast::InternalDecoration>()) {
- AddError("decorations are not valid on local variables", deco->source);
- return false;
- }
- }
-
- if (current_block_) { // Not all statements are inside a block
- current_block_->AddDecl(stmt->variable);
- }
-
- if (!ValidateVariable(var)) {
- return false;
- }
-
- return true;
-}
-
sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
sem::Type* result = nullptr;
if (auto* alias = named_type->As<ast::Alias>()) {
@@ -2318,45 +2238,127 @@
return out;
}
-bool Resolver::Return(const ast::ReturnStatement* ret) {
- if (auto* value = ret->value) {
- if (!Expression(value)) {
- return false;
+sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ if (auto* value = stmt->value) {
+ if (!Expression(value)) {
+ return false;
+ }
}
- }
- // Validate after processing the return value expression so that its type is
- // available for validation.
- return ValidateReturn(ret);
+ // Validate after processing the return value expression so that its type is
+ // available for validation.
+ return ValidateReturn(stmt);
+ });
}
-bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
+sem::SwitchStatement* Resolver::SwitchStatement(
+ const ast::SwitchStatement* stmt) {
auto* sem = builder_->create<sem::SwitchStatement>(
stmt, current_compound_statement_, current_function_);
- builder_->Sem().Add(stmt, sem);
- return Scope(sem, [&] {
+ return StatementScope(stmt, sem, [&] {
if (!Expression(stmt->condition)) {
return false;
}
+
for (auto* case_stmt : stmt->body) {
Mark(case_stmt);
if (!CaseStatement(case_stmt)) {
return false;
}
}
- if (!ValidateSwitch(stmt)) {
- return false;
- }
- return true;
+
+ return ValidateSwitch(stmt);
});
}
-bool Resolver::Assignment(const ast::AssignmentStatement* a) {
- if (!Expression(a->lhs) || !Expression(a->rhs)) {
- return false;
- }
+sem::Statement* Resolver::VariableDeclStatement(
+ const ast::VariableDeclStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ Mark(stmt->variable);
- return ValidateAssignment(a);
+ auto* var = Variable(stmt->variable, VariableKind::kLocal);
+ if (!var) {
+ return false;
+ }
+
+ for (auto* deco : stmt->variable->decorations) {
+ Mark(deco);
+ if (!deco->Is<ast::InternalDecoration>()) {
+ AddError("decorations are not valid on local variables", deco->source);
+ return false;
+ }
+ }
+
+ if (current_block_) { // Not all statements are inside a block
+ current_block_->AddDecl(stmt->variable);
+ }
+
+ return ValidateVariable(var);
+ });
+}
+
+sem::Statement* Resolver::AssignmentStatement(
+ const ast::AssignmentStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) {
+ return false;
+ }
+
+ return ValidateAssignment(stmt);
+ });
+}
+
+sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); });
+}
+
+sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); });
+}
+
+sem::Statement* Resolver::ContinueStatement(
+ const ast::ContinueStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ // Set if we've hit the first continue statement in our parent loop
+ if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
+ if (!block->FirstContinue()) {
+ const_cast<sem::LoopBlockStatement*>(block)->SetFirstContinue(
+ stmt, block->Decls().size());
+ }
+ }
+
+ return ValidateContinueStatement(sem);
+ });
+}
+
+sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ current_function_->SetHasDiscard();
+
+ return ValidateDiscardStatement(sem);
+ });
+}
+
+sem::Statement* Resolver::FallthroughStatement(
+ const ast::FallthroughStatement* stmt) {
+ auto* sem = builder_->create<sem::Statement>(
+ stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] { return true; });
}
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
@@ -2399,22 +2401,28 @@
return true;
}
-template <typename F>
-bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) {
- auto* prev_current_statement = current_statement_;
- auto* prev_current_compound_statement = current_compound_statement_;
- auto* prev_current_block = current_block_;
- current_statement_ = stmt;
- current_compound_statement_ = stmt;
- current_block_ = stmt->As<sem::BlockStatement>();
+template <typename SEM, typename F>
+SEM* Resolver::StatementScope(const ast::Statement* ast,
+ SEM* sem,
+ F&& callback) {
+ builder_->Sem().Add(ast, sem);
- TINT_DEFER({
- current_block_ = prev_current_block;
- current_compound_statement_ = prev_current_compound_statement;
- current_statement_ = prev_current_statement;
- });
+ auto* as_compound =
+ As<sem::CompoundStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
+ auto* as_block =
+ As<sem::BlockStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
- return callback();
+ TINT_SCOPED_ASSIGNMENT(current_statement_, sem);
+ TINT_SCOPED_ASSIGNMENT(
+ current_compound_statement_,
+ as_compound ? as_compound : current_compound_statement_);
+ TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_);
+
+ if (!callback()) {
+ return nullptr;
+ }
+
+ return sem;
}
std::string Resolver::VectorPretty(uint32_t size,
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 08a9c94..f2a2fc8 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -59,8 +59,15 @@
namespace sem {
class Array;
class Atomic;
+class BlockStatement;
+class ElseStatement;
+class ForLoopStatement;
+class IfStatement;
class Intrinsic;
+class LoopStatement;
class Statement;
+class SwitchCaseBlockStatement;
+class SwitchStatement;
class TypeConstructor;
} // namespace sem
@@ -198,20 +205,26 @@
// Statement resolving methods
// Each return true on success, false on failure.
- bool Assignment(const ast::AssignmentStatement* a);
- bool BlockStatement(const ast::BlockStatement*);
- bool CaseStatement(const ast::CaseStatement*);
- bool ElseStatement(const ast::ElseStatement*);
- bool ForLoopStatement(const ast::ForLoopStatement*);
- bool Parameter(const ast::Variable* param);
- bool GlobalVariable(const ast::Variable* var);
- bool IfStatement(const ast::IfStatement*);
- bool LoopStatement(const ast::LoopStatement*);
- bool Return(const ast::ReturnStatement* ret);
- bool Statement(const ast::Statement*);
+ sem::Statement* AssignmentStatement(const ast::AssignmentStatement*);
+ sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
+ sem::Statement* BreakStatement(const ast::BreakStatement*);
+ sem::Statement* CallStatement(const ast::CallStatement*);
+ sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*);
+ sem::Statement* ContinueStatement(const ast::ContinueStatement*);
+ sem::Statement* DiscardStatement(const ast::DiscardStatement*);
+ sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
+ sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
+ sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
+ sem::Statement* Parameter(const ast::Variable*);
+ sem::IfStatement* IfStatement(const ast::IfStatement*);
+ sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
+ sem::Statement* ReturnStatement(const ast::ReturnStatement*);
+ sem::Statement* Statement(const ast::Statement*);
+ sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
+ sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&);
- bool SwitchStatement(const ast::SwitchStatement* s);
- bool VariableDeclStatement(const ast::VariableDeclStatement*);
+
+ bool GlobalVariable(const ast::Variable*);
// AST and Type validation methods
// Each return true on success, false on failure.
@@ -224,13 +237,19 @@
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a);
+ bool ValidateBreakStatement(const sem::Statement* stmt);
+ bool ValidateContinueStatement(const sem::Statement* stmt);
+ bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
+ bool ValidateElseStatement(const sem::ElseStatement* stmt);
bool ValidateEntryPoint(const sem::Function* func);
+ bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var);
+ bool ValidateIfStatement(const sem::IfStatement* stmt);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
bool ValidateIntrinsicCall(const sem::Call* call);
@@ -369,14 +388,19 @@
/// @param lit the literal
sem::Type* TypeOf(const ast::LiteralExpression* lit);
- /// Assigns `stmt` to #current_statement_, #current_compound_statement_, and
- /// possibly #current_block_, pushes the variable scope, then calls
- /// `callback`. Before returning #current_statement_,
- /// #current_compound_statement_, and #current_block_ are restored to their
- /// original values, and the variable scope is popped.
- /// @returns the value returned by callback
- template <typename F>
- bool Scope(sem::CompoundStatement* stmt, F&& callback);
+ /// StatementScope() does the following:
+ /// * Creates the AST -> SEM mapping.
+ /// * Assigns `sem` to #current_statement_
+ /// * Assigns `sem` to #current_compound_statement_ if `sem` derives from
+ /// sem::CompoundStatement.
+ /// * Assigns `sem` to #current_block_ if `sem` derives from
+ /// sem::BlockStatement.
+ /// * Then calls `callback`.
+ /// * Before returning, #current_statement_, #current_compound_statement_,
+ /// and #current_block_ are restored to their original values.
+ /// @returns `sem` if `callback` returns a truthy value, otherwise `nullptr`.
+ template <typename SEM, typename F>
+ SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc
index 5698138..b7cb04e 100644
--- a/src/resolver/resolver_validation.cc
+++ b/src/resolver/resolver_validation.cc
@@ -1344,6 +1344,82 @@
return true;
}
+bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
+ if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
+ !stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
+ AddError("break statement must be in a loop or switch case",
+ stmt->Declaration()->source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
+ if (auto* block =
+ stmt->FindFirstParent<sem::LoopBlockStatement,
+ sem::LoopContinuingBlockStatement>()) {
+ if (block->Is<sem::LoopContinuingBlockStatement>()) {
+ AddError("continuing blocks must not contain a continue statement",
+ stmt->Declaration()->source);
+ return false;
+ }
+ } else {
+ AddError("continue statement must be in a loop",
+ stmt->Declaration()->source);
+ return false;
+ }
+
+ return true;
+}
+
+bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
+ if (auto* continuing =
+ stmt->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
+ AddError("continuing blocks must not contain a discard statement",
+ stmt->Declaration()->source);
+ if (continuing != stmt->Parent()) {
+ AddNote("see continuing block here", continuing->Declaration()->source);
+ }
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
+ if (auto* cond = stmt->Condition()) {
+ auto* cond_ty = cond->Type()->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError(
+ "else statement condition must be bool, got " + TypeNameOf(cond_ty),
+ stmt->Condition()->Declaration()->source);
+ return false;
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
+ if (auto* cond = stmt->Condition()) {
+ auto* cond_ty = cond->Type()->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
+ stmt->Condition()->Declaration()->source);
+ return false;
+ }
+ }
+ return true;
+}
+
+bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
+ auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
+ stmt->Condition()->Declaration()->source);
+ return false;
+ }
+ return true;
+}
+
bool Resolver::ValidateIntrinsicCall(const sem::Call* call) {
if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false;
@@ -2103,8 +2179,8 @@
}
bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
- auto* cond_type = TypeOf(s->condition)->UnwrapRef();
- if (!cond_type->is_integer_scalar()) {
+ auto* cond_ty = TypeOf(s->condition)->UnwrapRef();
+ if (!cond_ty->is_integer_scalar()) {
AddError(
"switch statement selector expression must be of a "
"scalar integer type",
@@ -2127,7 +2203,7 @@
}
for (auto* selector : case_stmt->selectors) {
- if (cond_type != TypeOf(selector)) {
+ if (cond_ty != TypeOf(selector)) {
AddError(
"the case selector values must have the same "
"type as the selector expression.",
diff --git a/src/sem/for_loop_statement.h b/src/sem/for_loop_statement.h
index ff89241..2f287bb 100644
--- a/src/sem/for_loop_statement.h
+++ b/src/sem/for_loop_statement.h
@@ -21,6 +21,9 @@
namespace ast {
class ForLoopStatement;
} // namespace ast
+namespace sem {
+class Expression;
+} // namespace sem
} // namespace tint
namespace tint {
@@ -39,6 +42,16 @@
/// Destructor
~ForLoopStatement() override;
+
+ /// @returns the for-loop condition expression
+ const Expression* Condition() const { return condition_; }
+
+ /// Sets the for-loop condition expression
+ /// @param condition the for-loop condition expression
+ void SetCondition(const Expression* condition) { condition_ = condition; }
+
+ private:
+ const Expression* condition_ = nullptr;
};
} // namespace sem
diff --git a/src/sem/if_statement.h b/src/sem/if_statement.h
index 6c25fca..a8c9c2e 100644
--- a/src/sem/if_statement.h
+++ b/src/sem/if_statement.h
@@ -23,6 +23,9 @@
class IfStatement;
class ElseStatement;
} // namespace ast
+namespace sem {
+class Expression;
+} // namespace sem
} // namespace tint
namespace tint {
@@ -41,6 +44,16 @@
/// Destructor
~IfStatement() override;
+
+ /// @returns the if-statement condition expression
+ const Expression* Condition() const { return condition_; }
+
+ /// Sets the if-statement condition expression
+ /// @param condition the if condition expression
+ void SetCondition(const Expression* condition) { condition_ = condition; }
+
+ private:
+ const Expression* condition_ = nullptr;
};
/// Holds semantic information about an else statement
@@ -56,6 +69,16 @@
/// Destructor
~ElseStatement() override;
+
+ /// @returns the else-statement condition expression
+ const Expression* Condition() const { return condition_; }
+
+ /// Sets the else-statement condition expression
+ /// @param condition the else condition expression
+ void SetCondition(const Expression* condition) { condition_ = condition; }
+
+ private:
+ const Expression* condition_ = nullptr;
};
} // namespace sem
diff --git a/src/sem/member_accessor_expression.h b/src/sem/member_accessor_expression.h
index fcc6c6f..6d444f0 100644
--- a/src/sem/member_accessor_expression.h
+++ b/src/sem/member_accessor_expression.h
@@ -82,7 +82,7 @@
/// Constructor
/// @param declaration the AST node
/// @param type the resolved type of the expression
- /// @param statement the statement that
+ /// @param statement the statement that owns this expression
/// @param indices the swizzle indices
Swizzle(const ast::MemberAccessorExpression* declaration,
const sem::Type* type,
diff --git a/src/sem/variable.h b/src/sem/variable.h
index a389eed..ac7dac1 100644
--- a/src/sem/variable.h
+++ b/src/sem/variable.h
@@ -234,7 +234,7 @@
const sem::Variable* Variable() const { return variable_; }
private:
- sem::Variable const* const variable_;
+ const sem::Variable* const variable_;
};
} // namespace sem