resolver: Implement Behavior Analysis
This change implements the behavior analysis for expressions and
statements as described in:
https://www.w3.org/TR/WGSL/#behaviors
This CL makes no changes to the validation rules. This will be done as a
followup change.
Bug: tint:1302
Change-Id: If0a251a7982ea15ff5d93b54a5cc5ed03ba60608
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68408
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index e587fb5..7a3bb13 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -689,6 +689,7 @@
resolver/pipeline_overridable_constant_test.cc
resolver/ptr_ref_test.cc
resolver/ptr_ref_validation_test.cc
+ resolver/resolver_behavior_test.cc
resolver/resolver_constants_test.cc
resolver/resolver_test_helper.cc
resolver/resolver_test_helper.h
diff --git a/src/program_builder.h b/src/program_builder.h
index bc7a52c..4335126 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -32,6 +32,7 @@
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
#include "src/ast/depth_multisampled_texture.h"
#include "src/ast/depth_texture.h"
#include "src/ast/disable_validation_decoration.h"
@@ -1864,6 +1865,19 @@
/// @returns the break statement pointer
const ast::BreakStatement* Break() { return create<ast::BreakStatement>(); }
+ /// Creates an ast::ContinueStatement
+ /// @param source the source information
+ /// @returns the continue statement pointer
+ const ast::ContinueStatement* Continue(const Source& source) {
+ return create<ast::ContinueStatement>(source);
+ }
+
+ /// Creates an ast::ContinueStatement
+ /// @returns the continue statement pointer
+ const ast::ContinueStatement* Continue() {
+ return create<ast::ContinueStatement>();
+ }
+
/// Creates an ast::ReturnStatement with no return value
/// @param source the source information
/// @returns the return statement pointer
@@ -2041,6 +2055,13 @@
body);
}
+ /// Creates a ast::ElseStatement with no condition and body
+ /// @param body the else body
+ /// @returns the else statement pointer
+ const ast::ElseStatement* Else(const ast::BlockStatement* body) {
+ return create<ast::ElseStatement>(nullptr, body);
+ }
+
/// Creates a ast::IfStatement with input condition, body, and optional
/// variadic else statements
/// @param condition the if statement condition expression
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 36ca107..6ac7835 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -657,11 +657,21 @@
<< "Resolver::Function() called with a current compound statement";
return nullptr;
}
- if (!StatementScope(decl->body,
- builder_->create<sem::FunctionBlockStatement>(func),
- [&] { return Statements(decl->body->statements); })) {
+ auto* body = StatementScope(
+ decl->body, builder_->create<sem::FunctionBlockStatement>(func),
+ [&] { return Statements(decl->body->statements); });
+ if (!body) {
return nullptr;
}
+ func->Behaviors() = body->Behaviors();
+ if (func->Behaviors().Contains(sem::Behavior::kReturn)) {
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // We assign a behavior to each function: it is its body’s behavior
+ // (treating the body as a regular statement), with any "Return" replaced
+ // by "Next".
+ func->Behaviors().Remove(sem::Behavior::kReturn);
+ func->Behaviors().Add(sem::Behavior::kNext);
+ }
}
for (auto* deco : decl->decorations) {
@@ -797,13 +807,22 @@
}
bool Resolver::Statements(const ast::StatementList& stmts) {
+ sem::Behaviors behaviors{sem::Behavior::kNext};
+
for (auto* stmt : stmts) {
Mark(stmt);
auto* sem = Statement(stmt);
if (!sem) {
return false;
}
+ // s1 s2:(B1∖{Next}) ∪ B2
+ // ValidateStatements will ensure that statements can only follow a Next.
+ behaviors.Remove(sem::Behavior::kNext);
+ behaviors.Add(sem->Behaviors());
}
+
+ current_statement_->Behaviors() = behaviors;
+
if (!ValidateStatements(stmts)) {
return false;
}
@@ -887,6 +906,7 @@
return false;
}
sem->SetBlock(body);
+ sem->Behaviors() = body->Behaviors();
return true;
});
}
@@ -900,6 +920,8 @@
return false;
}
sem->SetCondition(cond);
+ sem->Behaviors() = cond->Behaviors();
+ sem->Behaviors().Remove(sem::Behavior::kNext);
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
@@ -908,12 +930,23 @@
[&] { return Statements(stmt->body->statements); })) {
return false;
}
+ sem->Behaviors().Add(body->Behaviors());
for (auto* else_stmt : stmt->else_statements) {
Mark(else_stmt);
- if (!ElseStatement(else_stmt)) {
+ auto* else_sem = ElseStatement(else_stmt);
+ if (!else_sem) {
return false;
}
+ sem->Behaviors().Add(else_sem->Behaviors());
+ }
+
+ if (stmt->else_statements.empty() ||
+ stmt->else_statements.back()->condition != nullptr) {
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // if statements without an else branch are treated as if they had an
+ // empty else branch (which adds Next to their behavior)
+ sem->Behaviors().Add(sem::Behavior::kNext);
}
return ValidateIfStatement(sem);
@@ -930,7 +963,12 @@
return false;
}
sem->SetCondition(cond);
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // if statements with else if branches are treated as if they were nested
+ // simple if/else statements
+ sem->Behaviors() = cond->Behaviors();
}
+ sem->Behaviors().Remove(sem::Behavior::kNext);
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
@@ -939,6 +977,7 @@
[&] { return Statements(stmt->body->statements); })) {
return false;
}
+ sem->Behaviors().Add(body->Behaviors());
return ValidateElseStatement(sem);
});
@@ -964,20 +1003,32 @@
if (!Statements(stmt->body->statements)) {
return false;
}
+ auto& behaviors = sem->Behaviors();
+ behaviors = body->Behaviors();
if (stmt->continuing) {
Mark(stmt->continuing);
if (!stmt->continuing->Empty()) {
- auto* continuing =
+ auto* continuing = StatementScope(
+ stmt->continuing,
builder_->create<sem::LoopContinuingBlockStatement>(
stmt->continuing, current_compound_statement_,
- current_function_);
- return StatementScope(stmt->continuing, continuing, [&] {
- return Statements(stmt->continuing->statements);
- }) != nullptr;
+ current_function_),
+ [&] { return Statements(stmt->continuing->statements); });
+ if (!continuing) {
+ return false;
+ }
+ behaviors.Add(continuing->Behaviors());
}
}
+ if (behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
+ behaviors.Add(sem::Behavior::kNext);
+ } else {
+ behaviors.Remove(sem::Behavior::kNext);
+ }
+ behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
+
return true;
});
});
@@ -988,11 +1039,14 @@
auto* sem = builder_->create<sem::ForLoopStatement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
+ auto& behaviors = sem->Behaviors();
if (auto* initializer = stmt->initializer) {
Mark(initializer);
- if (!Statement(initializer)) {
+ auto* init = Statement(initializer);
+ if (!init) {
return false;
}
+ behaviors.Add(init->Behaviors());
}
if (auto* cond_expr = stmt->condition) {
@@ -1001,13 +1055,16 @@
return false;
}
sem->SetCondition(cond);
+ behaviors.Add(cond->Behaviors());
}
if (auto* continuing = stmt->continuing) {
Mark(continuing);
- if (!Statement(continuing)) {
+ auto* cont = Statement(continuing);
+ if (!cont) {
return false;
}
+ behaviors.Add(cont->Behaviors());
}
Mark(stmt->body);
@@ -1019,6 +1076,15 @@
return false;
}
+ behaviors.Add(body->Behaviors());
+ if (stmt->condition ||
+ behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
+ behaviors.Add(sem::Behavior::kNext);
+ } else {
+ behaviors.Remove(sem::Behavior::kNext);
+ }
+ behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
+
return ValidateForLoopStatement(sem);
});
}
@@ -1072,6 +1138,19 @@
if (!sem_expr) {
return nullptr;
}
+
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // an expression behavior is always either {Next} or {Next, Discard}
+ if (sem_expr->Behaviors() != sem::Behavior::kNext &&
+ sem_expr->Behaviors() != sem::Behaviors{sem::Behavior::kNext, // NOLINT
+ sem::Behavior::kDiscard} &&
+ !IsCallStatement(expr)) {
+ TINT_ICE(Resolver, diagnostics_)
+ << expr->TypeInfo().name
+ << " behaviors are: " << sem_expr->Behaviors();
+ return nullptr;
+ }
+
builder_->Sem().Add(expr, sem_expr);
if (expr == root) {
return sem_expr;
@@ -1084,52 +1163,57 @@
sem::Expression* Resolver::IndexAccessor(
const ast::IndexAccessorExpression* expr) {
- auto* idx = expr->index;
- auto* parent_raw_ty = TypeOf(expr->object);
- auto* parent_ty = parent_raw_ty->UnwrapRef();
+ auto* idx = Sem(expr->index);
+ auto* obj = Sem(expr->object);
+ auto* obj_raw_ty = obj->Type();
+ auto* obj_ty = obj_raw_ty->UnwrapRef();
const sem::Type* ty = nullptr;
- if (auto* arr = parent_ty->As<sem::Array>()) {
+ if (auto* arr = obj_ty->As<sem::Array>()) {
ty = arr->ElemType();
- } else if (auto* vec = parent_ty->As<sem::Vector>()) {
+ } else if (auto* vec = obj_ty->As<sem::Vector>()) {
ty = vec->type();
- } else if (auto* mat = parent_ty->As<sem::Matrix>()) {
+ } else if (auto* mat = obj_ty->As<sem::Matrix>()) {
ty = builder_->create<sem::Vector>(mat->type(), mat->rows());
} else {
- AddError("cannot index type '" + TypeNameOf(parent_ty) + "'", expr->source);
+ AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", expr->source);
return nullptr;
}
- auto* idx_ty = TypeOf(idx)->UnwrapRef();
+ auto* idx_ty = idx->Type()->UnwrapRef();
if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
AddError("index must be of type 'i32' or 'u32', found: '" +
TypeNameOf(idx_ty) + "'",
- idx->source);
+ idx->Declaration()->source);
return nullptr;
}
- if (parent_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
- if (!parent_raw_ty->Is<sem::Reference>()) {
+ if (obj_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
+ if (!obj_raw_ty->Is<sem::Reference>()) {
// TODO(bclayton): expand this to allow any const_expr expression
// https://github.com/gpuweb/gpuweb/issues/1272
- if (!idx->As<ast::IntLiteralExpression>()) {
+ if (!idx->Declaration()->As<ast::IntLiteralExpression>()) {
AddError("index must be signed or unsigned integer literal",
- idx->source);
+ idx->Declaration()->source);
return nullptr;
}
}
}
// If we're extracting from a reference, we return a reference.
- if (auto* ref = parent_raw_ty->As<sem::Reference>()) {
+ if (auto* ref = obj_raw_ty->As<sem::Reference>()) {
ty = builder_->create<sem::Reference>(ty, ref->StorageClass(),
ref->Access());
}
auto val = EvaluateConstantValue(expr, ty);
- return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ auto* sem =
+ builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
+ return sem;
}
sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
+ auto* inner = Sem(expr->expr);
auto* ty = Type(expr->type);
if (!ty) {
return nullptr;
@@ -1140,12 +1224,17 @@
}
auto val = EvaluateConstantValue(expr, ty);
- return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ auto* sem =
+ builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ sem->Behaviors() = inner->Behaviors();
+ return sem;
}
sem::Call* Resolver::Call(const ast::CallExpression* expr) {
std::vector<const sem::Expression*> args(expr->args.size());
std::vector<const sem::Type*> arg_tys(args.size());
+ sem::Behaviors arg_behaviors;
+
for (size_t i = 0; i < expr->args.size(); i++) {
auto* arg = Sem(expr->args[i]);
if (!arg) {
@@ -1153,8 +1242,11 @@
}
args[i] = arg;
arg_tys[i] = args[i]->Type();
+ arg_behaviors.Add(arg->Behaviors());
}
+ arg_behaviors.Remove(sem::Behavior::kNext);
+
auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
// The call has resolved to a type constructor or cast.
if (args.size() == 1) {
@@ -1192,7 +1284,7 @@
}
if (auto* fn = As<sem::Function>(resolved)) {
- return FunctionCall(expr, fn, std::move(args));
+ return FunctionCall(expr, fn, std::move(args), arg_behaviors);
}
auto name = builder_->Symbols().NameFor(ident->symbol);
@@ -1247,7 +1339,8 @@
sem::Call* Resolver::FunctionCall(
const ast::CallExpression* expr,
sem::Function* target,
- const std::vector<const sem::Expression*> args) {
+ const std::vector<const sem::Expression*> args,
+ sem::Behaviors arg_behaviors) {
auto sym = expr->target.name->symbol;
auto name = builder_->Symbols().NameFor(sym);
@@ -1272,6 +1365,8 @@
target->AddCallSite(call);
+ call->Behaviors() = arg_behaviors + target->Behaviors();
+
if (!ValidateFunctionCall(call)) {
return nullptr;
}
@@ -1285,14 +1380,9 @@
const sem::Type* source) {
// It is not valid to have a type-cast call expression inside a call
// statement.
- if (current_statement_) {
- if (auto* stmt =
- current_statement_->Declaration()->As<ast::CallStatement>()) {
- if (stmt->expr == expr) {
- AddError("type cast evaluated but not used", expr->source);
- return nullptr;
- }
- }
+ if (IsCallStatement(expr)) {
+ AddError("type cast evaluated but not used", expr->source);
+ return nullptr;
}
auto* call_target = utils::GetOrCreate(
@@ -1349,14 +1439,9 @@
const std::vector<const sem::Type*> arg_tys) {
// It is not valid to have a type-constructor call expression as a call
// statement.
- if (current_statement_) {
- if (auto* stmt =
- current_statement_->Declaration()->As<ast::CallStatement>()) {
- if (stmt->expr == expr) {
- AddError("type constructor evaluated but not used", expr->source);
- return nullptr;
- }
- }
+ if (IsCallStatement(expr)) {
+ AddError("type constructor evaluated but not used", expr->source);
+ return nullptr;
}
auto* call_target = utils::GetOrCreate(
@@ -1619,8 +1704,11 @@
using Matrix = sem::Matrix;
using Vector = sem::Vector;
- auto* lhs_ty = TypeOf(expr->lhs)->UnwrapRef();
- auto* rhs_ty = TypeOf(expr->rhs)->UnwrapRef();
+ auto* lhs = Sem(expr->lhs);
+ auto* rhs = Sem(expr->rhs);
+
+ auto* lhs_ty = lhs->Type()->UnwrapRef();
+ auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* lhs_vec = lhs_ty->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
@@ -1636,7 +1724,10 @@
auto build = [&](const sem::Type* ty) {
auto val = EvaluateConstantValue(expr, ty);
- return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ auto* sem =
+ builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
+ return sem;
};
// Binary logical expressions
@@ -1798,7 +1889,8 @@
}
sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
- auto* expr_ty = TypeOf(unary->expr);
+ auto* expr = Sem(unary->expr);
+ auto* expr_ty = expr->Type();
if (!expr_ty) {
return nullptr;
}
@@ -1880,7 +1972,10 @@
}
auto val = EvaluateConstantValue(unary, ty);
- return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
+ auto* sem =
+ builder_->create<sem::Expression>(unary, ty, current_statement_, val);
+ sem->Behaviors() = expr->Behaviors();
+ return sem;
}
sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
@@ -2248,10 +2343,15 @@
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
+ auto& behaviors = current_statement_->Behaviors();
+ behaviors = sem::Behavior::kReturn;
+
if (auto* value = stmt->value) {
- if (!Expression(value)) {
+ auto* expr = Expression(value);
+ if (!expr) {
return false;
}
+ behaviors.Add(expr->Behaviors() - sem::Behavior::kNext);
}
// Validate after processing the return value expression so that its type is
@@ -2265,17 +2365,28 @@
auto* sem = builder_->create<sem::SwitchStatement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
- if (!Expression(stmt->condition)) {
+ auto& behaviors = sem->Behaviors();
+
+ auto* cond = Expression(stmt->condition);
+ if (!cond) {
return false;
}
+ behaviors = cond->Behaviors() - sem::Behavior::kNext;
for (auto* case_stmt : stmt->body) {
Mark(case_stmt);
- if (!CaseStatement(case_stmt)) {
+ auto* c = CaseStatement(case_stmt);
+ if (!c) {
return false;
}
+ behaviors.Add(c->Behaviors());
}
+ if (behaviors.Contains(sem::Behavior::kBreak)) {
+ behaviors.Add(sem::Behavior::kNext);
+ }
+ behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
+
return ValidateSwitch(stmt);
});
}
@@ -2304,6 +2415,10 @@
current_block_->AddDecl(stmt->variable);
}
+ if (auto* ctor = var->Constructor()) {
+ sem->Behaviors() = ctor->Behaviors();
+ }
+
return ValidateVariable(var);
});
}
@@ -2313,10 +2428,22 @@
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
- if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) {
+ auto* lhs = Expression(stmt->lhs);
+ if (!lhs) {
return false;
}
+ auto* rhs = Expression(stmt->rhs);
+ if (!rhs) {
+ return false;
+ }
+
+ auto& behaviors = sem->Behaviors();
+ behaviors = rhs->Behaviors();
+ if (!stmt->lhs->Is<ast::PhonyExpression>()) {
+ behaviors.Add(lhs->Behaviors());
+ }
+
return ValidateAssignment(stmt);
});
}
@@ -2324,13 +2451,23 @@
sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
- return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); });
+ return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kBreak;
+
+ return ValidateBreakStatement(sem);
+ });
}
sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
- return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); });
+ return StatementScope(stmt, sem, [&] {
+ if (auto* expr = Expression(stmt->expr)) {
+ sem->Behaviors() = expr->Behaviors();
+ return true;
+ }
+ return false;
+ });
}
sem::Statement* Resolver::ContinueStatement(
@@ -2338,6 +2475,8 @@
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kContinue;
+
// Set if we've hit the first continue statement in our parent loop
if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
if (!block->FirstContinue()) {
@@ -2354,6 +2493,7 @@
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kDiscard;
current_function_->SetHasDiscard();
return ValidateDiscardStatement(sem);
@@ -2365,6 +2505,8 @@
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
+ sem->Behaviors() = sem::Behavior::kFallthrough;
+
return ValidateFallthroughStatement(sem);
});
}
@@ -2512,6 +2654,12 @@
return sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone;
}
+bool Resolver::IsCallStatement(const ast::Expression* expr) const {
+ return current_statement_ &&
+ Is<ast::CallStatement>(current_statement_->Declaration(),
+ [&](auto* stmt) { return stmt->expr == expr; });
+}
+
////////////////////////////////////////////////////////////////////////////////
// Resolver::TypeConversionSig
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 84565ee..336e183 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -185,7 +185,8 @@
sem::Function* Function(const ast::Function*);
sem::Call* FunctionCall(const ast::CallExpression*,
sem::Function* target,
- const std::vector<const sem::Expression*> args);
+ const std::vector<const sem::Expression*> args,
+ sem::Behaviors arg_behaviors);
sem::Expression* Identifier(const ast::IdentifierExpression*);
sem::Call* IntrinsicCall(const ast::CallExpression*,
sem::IntrinsicType,
@@ -460,6 +461,9 @@
/// function.
bool IsIntrinsic(Symbol) const;
+ /// @returns true if `expr` is the current CallStatement's CallExpression
+ bool IsCallStatement(const ast::Expression* expr) const;
+
/// @returns the resolved symbol (function, type or variable) for the given
/// ast::Identifier or ast::TypeName cast to the given semantic type.
template <typename SEM = sem::Node>
diff --git a/src/resolver/resolver_behavior_test.cc b/src/resolver/resolver_behavior_test.cc
new file mode 100644
index 0000000..6ee656b
--- /dev/null
+++ b/src/resolver/resolver_behavior_test.cc
@@ -0,0 +1,687 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/resolver/resolver.h"
+
+#include "gtest/gtest.h"
+#include "src/resolver/resolver_test_helper.h"
+#include "src/sem/expression.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+class ResolverBehaviorTest : public ResolverTest {
+ protected:
+ void SetUp() override {
+ // Create a function called 'DiscardOrNext' which returns an i32, and has
+ // the behavior of {Discard, Return}, which when called, will have the
+ // behavior {Discard, Next}.
+ Func("DiscardOrNext", {}, ty.i32(),
+ {
+ If(true, Block(Discard())),
+ Return(1),
+ });
+ }
+};
+
+TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) {
+ auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1)));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) {
+ auto* stmt = Decl(Var("lhs", ty.i32(), Add(1, Call("DiscardOrNext"))));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprBitcastOp) {
+ auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast<u32>(Call("DiscardOrNext"))));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprIndex_Arr) {
+ Func("ArrayDiscardOrNext", {}, ty.array<i32, 4>(),
+ {
+ If(true, Block(Discard())),
+ Return(Construct(ty.array<i32, 4>())),
+ });
+
+ auto* stmt =
+ Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1)));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprIndex_Idx) {
+ auto* stmt =
+ Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext"))));
+ WrapInFunction(Decl(Var("arr", ty.array<i32, 4>())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprUnaryOp) {
+ auto* stmt = Decl(Var("lhs", ty.i32(),
+ create<ast::UnaryOpExpression>(
+ ast::UnaryOp::kComplement, Call("DiscardOrNext"))));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign) {
+ auto* stmt = Assign("lhs", "rhs");
+ WrapInFunction(Decl(Var("lhs", ty.i32())), //
+ Decl(Var("rhs", ty.i32())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) {
+ auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1);
+ WrapInFunction(Decl(Var("lhs", ty.array<i32, 4>())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) {
+ auto* stmt = Assign("lhs", Call("DiscardOrNext"));
+ WrapInFunction(Decl(Var("lhs", ty.i32())), //
+ stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtBlockEmpty) {
+ auto* stmt = Block();
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) {
+ auto* stmt = Block(Discard());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallReturn) {
+ Func("f", {}, ty.void_(), {Return()});
+ auto* stmt = CallStmt(Call("f"));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) {
+ Func("f", {}, ty.void_(), {Discard()});
+ auto* stmt = CallStmt(Call("f"));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) {
+ auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+ nullptr, Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtBreak) {
+ auto* stmt = Break();
+ WrapInFunction(Loop(Block(stmt)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kBreak);
+}
+
+TEST_F(ResolverBehaviorTest, StmtContinue) {
+ auto* stmt = Continue();
+ WrapInFunction(Loop(Block(stmt)));
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kContinue);
+}
+
+TEST_F(ResolverBehaviorTest, StmtDiscard) {
+ auto* stmt = Discard();
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopContinue) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block(Continue()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopReturn) {
+ auto* stmt = For(nullptr, nullptr, nullptr, Block(Return()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) {
+ auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+ nullptr, Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) {
+ auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+ nullptr, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondTrue) {
+ auto* stmt = For(nullptr, true, nullptr, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) {
+ auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1), nullptr, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak_ContCallFuncMayDiscard) {
+ auto* stmt =
+ For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_ContCallFuncMayDiscard) {
+ auto* stmt = For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) {
+ auto* stmt = If(true, Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) {
+ auto* stmt = If(true, Block(Discard()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) {
+ auto* stmt = If(true, Block(), Else(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) {
+ auto* stmt = If(true, Block(Discard()), Else(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) {
+ auto* stmt = If(Equal(Call("DiscardOrNext"), 1), Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) {
+ auto* stmt = If(true, Block(), //
+ Else(Equal(Call("DiscardOrNext"), 1), Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtLetDecl) {
+ auto* stmt = Decl(Const("v", ty.i32(), Expr(1)));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) {
+ auto* stmt = Decl(Const("lhs", ty.i32(), Call("DiscardOrNext")));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty) {
+ auto* stmt = Loop(Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopBreak) {
+ auto* stmt = Loop(Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopContinue) {
+ auto* stmt = Loop(Block(Continue()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopDiscard) {
+ auto* stmt = Loop(Block(Discard()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopReturn) {
+ auto* stmt = Loop(Block(Return()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContEmpty) {
+ auto* stmt = Loop(Block(), Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContBreak) {
+ auto* stmt = Loop(Block(), Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtReturn) {
+ auto* stmt = Return();
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) {
+ auto* stmt = Return(Call("DiscardOrNext"));
+ Func("F", {}, ty.i32(), {stmt});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kReturn, sem::Behavior::kDiscard));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondTrue_DefaultEmpty) {
+ auto* stmt = Switch(1, DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultEmpty) {
+ auto* stmt = Switch(1, DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) {
+ auto* stmt = Switch(1, DefaultCase(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultReturn) {
+ auto* stmt = Switch(1, DefaultCase(Block(Return())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) {
+ auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) {
+ auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kDiscard));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) {
+ auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Return())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
+ auto* stmt = Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest,
+ StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) {
+ auto* stmt =
+ Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Discard())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest,
+ StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) {
+ auto* stmt =
+ Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Return())));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest,
+ StmtSwitch_CondLiteral_Case0Discard_Case1Return_DefaultEmpty) {
+ auto* stmt = Switch(1, //
+ Case(Expr(0), Block(Discard())), //
+ Case(Expr(1), Block(Return())), //
+ DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext,
+ sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) {
+ auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtVarDecl) {
+ auto* stmt = Decl(Var("v", ty.i32()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) {
+ auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext")));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(),
+ sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc
index eb2be86..ef365ce 100644
--- a/src/resolver/resolver_validation.cc
+++ b/src/resolver/resolver_validation.cc
@@ -1042,6 +1042,18 @@
}
}
+ // https://www.w3.org/TR/WGSL/#behaviors-rules
+ // a function behavior is always one of {}, {Next}, {Discard}, or
+ // {Next, Discard}.
+ if (func->Behaviors() != sem::Behaviors{} && // NOLINT: bad warning
+ func->Behaviors() != sem::Behavior::kNext &&
+ func->Behaviors() != sem::Behavior::kDiscard &&
+ func->Behaviors() != sem::Behaviors{sem::Behavior::kNext, //
+ sem::Behavior::kDiscard}) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "function '" << name << "' behaviors are: " << func->Behaviors();
+ }
+
return true;
}
diff --git a/src/sem/expression.h b/src/sem/expression.h
index 06ae10b..b2ff4ac 100644
--- a/src/sem/expression.h
+++ b/src/sem/expression.h
@@ -68,7 +68,7 @@
const sem::Type* const type_;
const Statement* const statement_;
const Constant constant_;
- sem::Behaviors behaviors_;
+ sem::Behaviors behaviors_{sem::Behavior::kNext};
};
} // namespace sem
diff --git a/src/sem/function.h b/src/sem/function.h
index ea834a7..6d980c5 100644
--- a/src/sem/function.h
+++ b/src/sem/function.h
@@ -240,6 +240,12 @@
/// @returns true if this function has a discard statement
bool HasDiscard() const { return has_discard_; }
+ /// @return the behaviors of this function
+ const sem::Behaviors& Behaviors() const { return behaviors_; }
+
+ /// @return the behaviors of this function
+ sem::Behaviors& Behaviors() { return behaviors_; }
+
private:
VariableBindings TransitivelyReferencedSamplerVariablesImpl(
ast::SamplerKind kind) const;
@@ -257,6 +263,7 @@
std::vector<const Call*> callsites_;
std::vector<const Function*> ancestor_entry_points_;
bool has_discard_ = false;
+ sem::Behaviors behaviors_{sem::Behavior::kNext};
};
} // namespace sem
diff --git a/src/sem/statement.h b/src/sem/statement.h
index e1c5160..1468da9 100644
--- a/src/sem/statement.h
+++ b/src/sem/statement.h
@@ -110,8 +110,7 @@
const ast::Statement* const declaration_;
const CompoundStatement* const parent_;
const sem::Function* const function_;
-
- sem::Behaviors behaviors_;
+ sem::Behaviors behaviors_{sem::Behavior::kNext};
};
/// CompoundStatement is the base class of statements that can hold other
diff --git a/test/BUILD.gn b/test/BUILD.gn
index bfaa498..c2cc482 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -253,6 +253,7 @@
"../src/resolver/pipeline_overridable_constant_test.cc",
"../src/resolver/ptr_ref_test.cc",
"../src/resolver/ptr_ref_validation_test.cc",
+ "../src/resolver/resolver_behavior_test.cc",
"../src/resolver/resolver_constants_test.cc",
"../src/resolver/resolver_test.cc",
"../src/resolver/resolver_test_helper.cc",