resolver: Implement Behavior Analysis

This change implements the behavior analysis for expressions and
statements as described in:
https://www.w3.org/TR/WGSL/#behaviors

This CL makes no changes to the validation rules. This will be done as a
followup change.

Bug: tint:1302
Change-Id: If0a251a7982ea15ff5d93b54a5cc5ed03ba60608
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68408
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index e587fb5..7a3bb13 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -689,6 +689,7 @@
     resolver/pipeline_overridable_constant_test.cc
     resolver/ptr_ref_test.cc
     resolver/ptr_ref_validation_test.cc
+    resolver/resolver_behavior_test.cc
     resolver/resolver_constants_test.cc
     resolver/resolver_test_helper.cc
     resolver/resolver_test_helper.h
diff --git a/src/program_builder.h b/src/program_builder.h
index bc7a52c..4335126 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -32,6 +32,7 @@
 #include "src/ast/call_expression.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/case_statement.h"
+#include "src/ast/continue_statement.h"
 #include "src/ast/depth_multisampled_texture.h"
 #include "src/ast/depth_texture.h"
 #include "src/ast/disable_validation_decoration.h"
@@ -1864,6 +1865,19 @@
   /// @returns the break statement pointer
   const ast::BreakStatement* Break() { return create<ast::BreakStatement>(); }
 
+  /// Creates an ast::ContinueStatement
+  /// @param source the source information
+  /// @returns the continue statement pointer
+  const ast::ContinueStatement* Continue(const Source& source) {
+    return create<ast::ContinueStatement>(source);
+  }
+
+  /// Creates an ast::ContinueStatement
+  /// @returns the continue statement pointer
+  const ast::ContinueStatement* Continue() {
+    return create<ast::ContinueStatement>();
+  }
+
   /// Creates an ast::ReturnStatement with no return value
   /// @param source the source information
   /// @returns the return statement pointer
@@ -2041,6 +2055,13 @@
                                       body);
   }
 
+  /// Creates a ast::ElseStatement with no condition and body
+  /// @param body the else body
+  /// @returns the else statement pointer
+  const ast::ElseStatement* Else(const ast::BlockStatement* body) {
+    return create<ast::ElseStatement>(nullptr, body);
+  }
+
   /// Creates a ast::IfStatement with input condition, body, and optional
   /// variadic else statements
   /// @param condition the if statement condition expression
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 36ca107..6ac7835 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -657,11 +657,21 @@
           << "Resolver::Function() called with a current compound statement";
       return nullptr;
     }
-    if (!StatementScope(decl->body,
-                        builder_->create<sem::FunctionBlockStatement>(func),
-                        [&] { return Statements(decl->body->statements); })) {
+    auto* body = StatementScope(
+        decl->body, builder_->create<sem::FunctionBlockStatement>(func),
+        [&] { return Statements(decl->body->statements); });
+    if (!body) {
       return nullptr;
     }
+    func->Behaviors() = body->Behaviors();
+    if (func->Behaviors().Contains(sem::Behavior::kReturn)) {
+      // https://www.w3.org/TR/WGSL/#behaviors-rules
+      // We assign a behavior to each function: it is its body’s behavior
+      // (treating the body as a regular statement), with any "Return" replaced
+      // by "Next".
+      func->Behaviors().Remove(sem::Behavior::kReturn);
+      func->Behaviors().Add(sem::Behavior::kNext);
+    }
   }
 
   for (auto* deco : decl->decorations) {
@@ -797,13 +807,22 @@
 }
 
 bool Resolver::Statements(const ast::StatementList& stmts) {
+  sem::Behaviors behaviors{sem::Behavior::kNext};
+
   for (auto* stmt : stmts) {
     Mark(stmt);
     auto* sem = Statement(stmt);
     if (!sem) {
       return false;
     }
+    // s1 s2:(B1∖{Next}) ∪ B2
+    // ValidateStatements will ensure that statements can only follow a Next.
+    behaviors.Remove(sem::Behavior::kNext);
+    behaviors.Add(sem->Behaviors());
   }
+
+  current_statement_->Behaviors() = behaviors;
+
   if (!ValidateStatements(stmts)) {
     return false;
   }
@@ -887,6 +906,7 @@
       return false;
     }
     sem->SetBlock(body);
+    sem->Behaviors() = body->Behaviors();
     return true;
   });
 }
@@ -900,6 +920,8 @@
       return false;
     }
     sem->SetCondition(cond);
+    sem->Behaviors() = cond->Behaviors();
+    sem->Behaviors().Remove(sem::Behavior::kNext);
 
     Mark(stmt->body);
     auto* body = builder_->create<sem::BlockStatement>(
@@ -908,12 +930,23 @@
                         [&] { return Statements(stmt->body->statements); })) {
       return false;
     }
+    sem->Behaviors().Add(body->Behaviors());
 
     for (auto* else_stmt : stmt->else_statements) {
       Mark(else_stmt);
-      if (!ElseStatement(else_stmt)) {
+      auto* else_sem = ElseStatement(else_stmt);
+      if (!else_sem) {
         return false;
       }
+      sem->Behaviors().Add(else_sem->Behaviors());
+    }
+
+    if (stmt->else_statements.empty() ||
+        stmt->else_statements.back()->condition != nullptr) {
+      // https://www.w3.org/TR/WGSL/#behaviors-rules
+      // if statements without an else branch are treated as if they had an
+      // empty else branch (which adds Next to their behavior)
+      sem->Behaviors().Add(sem::Behavior::kNext);
     }
 
     return ValidateIfStatement(sem);
@@ -930,7 +963,12 @@
         return false;
       }
       sem->SetCondition(cond);
+      // https://www.w3.org/TR/WGSL/#behaviors-rules
+      // if statements with else if branches are treated as if they were nested
+      // simple if/else statements
+      sem->Behaviors() = cond->Behaviors();
     }
+    sem->Behaviors().Remove(sem::Behavior::kNext);
 
     Mark(stmt->body);
     auto* body = builder_->create<sem::BlockStatement>(
@@ -939,6 +977,7 @@
                         [&] { return Statements(stmt->body->statements); })) {
       return false;
     }
+    sem->Behaviors().Add(body->Behaviors());
 
     return ValidateElseStatement(sem);
   });
@@ -964,20 +1003,32 @@
       if (!Statements(stmt->body->statements)) {
         return false;
       }
+      auto& behaviors = sem->Behaviors();
+      behaviors = body->Behaviors();
 
       if (stmt->continuing) {
         Mark(stmt->continuing);
         if (!stmt->continuing->Empty()) {
-          auto* continuing =
+          auto* continuing = StatementScope(
+              stmt->continuing,
               builder_->create<sem::LoopContinuingBlockStatement>(
                   stmt->continuing, current_compound_statement_,
-                  current_function_);
-          return StatementScope(stmt->continuing, continuing, [&] {
-                   return Statements(stmt->continuing->statements);
-                 }) != nullptr;
+                  current_function_),
+              [&] { return Statements(stmt->continuing->statements); });
+          if (!continuing) {
+            return false;
+          }
+          behaviors.Add(continuing->Behaviors());
         }
       }
 
+      if (behaviors.Contains(sem::Behavior::kBreak)) {  // Does the loop exit?
+        behaviors.Add(sem::Behavior::kNext);
+      } else {
+        behaviors.Remove(sem::Behavior::kNext);
+      }
+      behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
+
       return true;
     });
   });
@@ -988,11 +1039,14 @@
   auto* sem = builder_->create<sem::ForLoopStatement>(
       stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
+    auto& behaviors = sem->Behaviors();
     if (auto* initializer = stmt->initializer) {
       Mark(initializer);
-      if (!Statement(initializer)) {
+      auto* init = Statement(initializer);
+      if (!init) {
         return false;
       }
+      behaviors.Add(init->Behaviors());
     }
 
     if (auto* cond_expr = stmt->condition) {
@@ -1001,13 +1055,16 @@
         return false;
       }
       sem->SetCondition(cond);
+      behaviors.Add(cond->Behaviors());
     }
 
     if (auto* continuing = stmt->continuing) {
       Mark(continuing);
-      if (!Statement(continuing)) {
+      auto* cont = Statement(continuing);
+      if (!cont) {
         return false;
       }
+      behaviors.Add(cont->Behaviors());
     }
 
     Mark(stmt->body);
@@ -1019,6 +1076,15 @@
       return false;
     }
 
+    behaviors.Add(body->Behaviors());
+    if (stmt->condition ||
+        behaviors.Contains(sem::Behavior::kBreak)) {  // Does the loop exit?
+      behaviors.Add(sem::Behavior::kNext);
+    } else {
+      behaviors.Remove(sem::Behavior::kNext);
+    }
+    behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
+
     return ValidateForLoopStatement(sem);
   });
 }
@@ -1072,6 +1138,19 @@
     if (!sem_expr) {
       return nullptr;
     }
+
+    // https://www.w3.org/TR/WGSL/#behaviors-rules
+    // an expression behavior is always either {Next} or {Next, Discard}
+    if (sem_expr->Behaviors() != sem::Behavior::kNext &&
+        sem_expr->Behaviors() != sem::Behaviors{sem::Behavior::kNext,  // NOLINT
+                                                sem::Behavior::kDiscard} &&
+        !IsCallStatement(expr)) {
+      TINT_ICE(Resolver, diagnostics_)
+          << expr->TypeInfo().name
+          << " behaviors are: " << sem_expr->Behaviors();
+      return nullptr;
+    }
+
     builder_->Sem().Add(expr, sem_expr);
     if (expr == root) {
       return sem_expr;
@@ -1084,52 +1163,57 @@
 
 sem::Expression* Resolver::IndexAccessor(
     const ast::IndexAccessorExpression* expr) {
-  auto* idx = expr->index;
-  auto* parent_raw_ty = TypeOf(expr->object);
-  auto* parent_ty = parent_raw_ty->UnwrapRef();
+  auto* idx = Sem(expr->index);
+  auto* obj = Sem(expr->object);
+  auto* obj_raw_ty = obj->Type();
+  auto* obj_ty = obj_raw_ty->UnwrapRef();
   const sem::Type* ty = nullptr;
-  if (auto* arr = parent_ty->As<sem::Array>()) {
+  if (auto* arr = obj_ty->As<sem::Array>()) {
     ty = arr->ElemType();
-  } else if (auto* vec = parent_ty->As<sem::Vector>()) {
+  } else if (auto* vec = obj_ty->As<sem::Vector>()) {
     ty = vec->type();
-  } else if (auto* mat = parent_ty->As<sem::Matrix>()) {
+  } else if (auto* mat = obj_ty->As<sem::Matrix>()) {
     ty = builder_->create<sem::Vector>(mat->type(), mat->rows());
   } else {
-    AddError("cannot index type '" + TypeNameOf(parent_ty) + "'", expr->source);
+    AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", expr->source);
     return nullptr;
   }
 
-  auto* idx_ty = TypeOf(idx)->UnwrapRef();
+  auto* idx_ty = idx->Type()->UnwrapRef();
   if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
     AddError("index must be of type 'i32' or 'u32', found: '" +
                  TypeNameOf(idx_ty) + "'",
-             idx->source);
+             idx->Declaration()->source);
     return nullptr;
   }
 
-  if (parent_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
-    if (!parent_raw_ty->Is<sem::Reference>()) {
+  if (obj_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
+    if (!obj_raw_ty->Is<sem::Reference>()) {
       // TODO(bclayton): expand this to allow any const_expr expression
       // https://github.com/gpuweb/gpuweb/issues/1272
-      if (!idx->As<ast::IntLiteralExpression>()) {
+      if (!idx->Declaration()->As<ast::IntLiteralExpression>()) {
         AddError("index must be signed or unsigned integer literal",
-                 idx->source);
+                 idx->Declaration()->source);
         return nullptr;
       }
     }
   }
 
   // If we're extracting from a reference, we return a reference.
-  if (auto* ref = parent_raw_ty->As<sem::Reference>()) {
+  if (auto* ref = obj_raw_ty->As<sem::Reference>()) {
     ty = builder_->create<sem::Reference>(ty, ref->StorageClass(),
                                           ref->Access());
   }
 
   auto val = EvaluateConstantValue(expr, ty);
-  return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+  auto* sem =
+      builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+  sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
+  return sem;
 }
 
 sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
+  auto* inner = Sem(expr->expr);
   auto* ty = Type(expr->type);
   if (!ty) {
     return nullptr;
@@ -1140,12 +1224,17 @@
   }
 
   auto val = EvaluateConstantValue(expr, ty);
-  return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+  auto* sem =
+      builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+  sem->Behaviors() = inner->Behaviors();
+  return sem;
 }
 
 sem::Call* Resolver::Call(const ast::CallExpression* expr) {
   std::vector<const sem::Expression*> args(expr->args.size());
   std::vector<const sem::Type*> arg_tys(args.size());
+  sem::Behaviors arg_behaviors;
+
   for (size_t i = 0; i < expr->args.size(); i++) {
     auto* arg = Sem(expr->args[i]);
     if (!arg) {
@@ -1153,8 +1242,11 @@
     }
     args[i] = arg;
     arg_tys[i] = args[i]->Type();
+    arg_behaviors.Add(arg->Behaviors());
   }
 
+  arg_behaviors.Remove(sem::Behavior::kNext);
+
   auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
     // The call has resolved to a type constructor or cast.
     if (args.size() == 1) {
@@ -1192,7 +1284,7 @@
   }
 
   if (auto* fn = As<sem::Function>(resolved)) {
-    return FunctionCall(expr, fn, std::move(args));
+    return FunctionCall(expr, fn, std::move(args), arg_behaviors);
   }
 
   auto name = builder_->Symbols().NameFor(ident->symbol);
@@ -1247,7 +1339,8 @@
 sem::Call* Resolver::FunctionCall(
     const ast::CallExpression* expr,
     sem::Function* target,
-    const std::vector<const sem::Expression*> args) {
+    const std::vector<const sem::Expression*> args,
+    sem::Behaviors arg_behaviors) {
   auto sym = expr->target.name->symbol;
   auto name = builder_->Symbols().NameFor(sym);
 
@@ -1272,6 +1365,8 @@
 
   target->AddCallSite(call);
 
+  call->Behaviors() = arg_behaviors + target->Behaviors();
+
   if (!ValidateFunctionCall(call)) {
     return nullptr;
   }
@@ -1285,14 +1380,9 @@
                                     const sem::Type* source) {
   // It is not valid to have a type-cast call expression inside a call
   // statement.
-  if (current_statement_) {
-    if (auto* stmt =
-            current_statement_->Declaration()->As<ast::CallStatement>()) {
-      if (stmt->expr == expr) {
-        AddError("type cast evaluated but not used", expr->source);
-        return nullptr;
-      }
-    }
+  if (IsCallStatement(expr)) {
+    AddError("type cast evaluated but not used", expr->source);
+    return nullptr;
   }
 
   auto* call_target = utils::GetOrCreate(
@@ -1349,14 +1439,9 @@
     const std::vector<const sem::Type*> arg_tys) {
   // It is not valid to have a type-constructor call expression as a call
   // statement.
-  if (current_statement_) {
-    if (auto* stmt =
-            current_statement_->Declaration()->As<ast::CallStatement>()) {
-      if (stmt->expr == expr) {
-        AddError("type constructor evaluated but not used", expr->source);
-        return nullptr;
-      }
-    }
+  if (IsCallStatement(expr)) {
+    AddError("type constructor evaluated but not used", expr->source);
+    return nullptr;
   }
 
   auto* call_target = utils::GetOrCreate(
@@ -1619,8 +1704,11 @@
   using Matrix = sem::Matrix;
   using Vector = sem::Vector;
 
-  auto* lhs_ty = TypeOf(expr->lhs)->UnwrapRef();
-  auto* rhs_ty = TypeOf(expr->rhs)->UnwrapRef();
+  auto* lhs = Sem(expr->lhs);
+  auto* rhs = Sem(expr->rhs);
+
+  auto* lhs_ty = lhs->Type()->UnwrapRef();
+  auto* rhs_ty = rhs->Type()->UnwrapRef();
 
   auto* lhs_vec = lhs_ty->As<Vector>();
   auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
@@ -1636,7 +1724,10 @@
 
   auto build = [&](const sem::Type* ty) {
     auto val = EvaluateConstantValue(expr, ty);
-    return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+    auto* sem =
+        builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+    sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
+    return sem;
   };
 
   // Binary logical expressions
@@ -1798,7 +1889,8 @@
 }
 
 sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
-  auto* expr_ty = TypeOf(unary->expr);
+  auto* expr = Sem(unary->expr);
+  auto* expr_ty = expr->Type();
   if (!expr_ty) {
     return nullptr;
   }
@@ -1880,7 +1972,10 @@
   }
 
   auto val = EvaluateConstantValue(unary, ty);
-  return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
+  auto* sem =
+      builder_->create<sem::Expression>(unary, ty, current_statement_, val);
+  sem->Behaviors() = expr->Behaviors();
+  return sem;
 }
 
 sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
@@ -2248,10 +2343,15 @@
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
+    auto& behaviors = current_statement_->Behaviors();
+    behaviors = sem::Behavior::kReturn;
+
     if (auto* value = stmt->value) {
-      if (!Expression(value)) {
+      auto* expr = Expression(value);
+      if (!expr) {
         return false;
       }
+      behaviors.Add(expr->Behaviors() - sem::Behavior::kNext);
     }
 
     // Validate after processing the return value expression so that its type is
@@ -2265,17 +2365,28 @@
   auto* sem = builder_->create<sem::SwitchStatement>(
       stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
-    if (!Expression(stmt->condition)) {
+    auto& behaviors = sem->Behaviors();
+
+    auto* cond = Expression(stmt->condition);
+    if (!cond) {
       return false;
     }
+    behaviors = cond->Behaviors() - sem::Behavior::kNext;
 
     for (auto* case_stmt : stmt->body) {
       Mark(case_stmt);
-      if (!CaseStatement(case_stmt)) {
+      auto* c = CaseStatement(case_stmt);
+      if (!c) {
         return false;
       }
+      behaviors.Add(c->Behaviors());
     }
 
+    if (behaviors.Contains(sem::Behavior::kBreak)) {
+      behaviors.Add(sem::Behavior::kNext);
+    }
+    behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
+
     return ValidateSwitch(stmt);
   });
 }
@@ -2304,6 +2415,10 @@
       current_block_->AddDecl(stmt->variable);
     }
 
+    if (auto* ctor = var->Constructor()) {
+      sem->Behaviors() = ctor->Behaviors();
+    }
+
     return ValidateVariable(var);
   });
 }
@@ -2313,10 +2428,22 @@
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
-    if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) {
+    auto* lhs = Expression(stmt->lhs);
+    if (!lhs) {
       return false;
     }
 
+    auto* rhs = Expression(stmt->rhs);
+    if (!rhs) {
+      return false;
+    }
+
+    auto& behaviors = sem->Behaviors();
+    behaviors = rhs->Behaviors();
+    if (!stmt->lhs->Is<ast::PhonyExpression>()) {
+      behaviors.Add(lhs->Behaviors());
+    }
+
     return ValidateAssignment(stmt);
   });
 }
@@ -2324,13 +2451,23 @@
 sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
-  return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); });
+  return StatementScope(stmt, sem, [&] {
+    sem->Behaviors() = sem::Behavior::kBreak;
+
+    return ValidateBreakStatement(sem);
+  });
 }
 
 sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
-  return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); });
+  return StatementScope(stmt, sem, [&] {
+    if (auto* expr = Expression(stmt->expr)) {
+      sem->Behaviors() = expr->Behaviors();
+      return true;
+    }
+    return false;
+  });
 }
 
 sem::Statement* Resolver::ContinueStatement(
@@ -2338,6 +2475,8 @@
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
+    sem->Behaviors() = sem::Behavior::kContinue;
+
     // Set if we've hit the first continue statement in our parent loop
     if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
       if (!block->FirstContinue()) {
@@ -2354,6 +2493,7 @@
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
+    sem->Behaviors() = sem::Behavior::kDiscard;
     current_function_->SetHasDiscard();
 
     return ValidateDiscardStatement(sem);
@@ -2365,6 +2505,8 @@
   auto* sem = builder_->create<sem::Statement>(
       stmt, current_compound_statement_, current_function_);
   return StatementScope(stmt, sem, [&] {
+    sem->Behaviors() = sem::Behavior::kFallthrough;
+
     return ValidateFallthroughStatement(sem);
   });
 }
@@ -2512,6 +2654,12 @@
   return sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone;
 }
 
+bool Resolver::IsCallStatement(const ast::Expression* expr) const {
+  return current_statement_ &&
+         Is<ast::CallStatement>(current_statement_->Declaration(),
+                                [&](auto* stmt) { return stmt->expr == expr; });
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Resolver::TypeConversionSig
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 84565ee..336e183 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -185,7 +185,8 @@
   sem::Function* Function(const ast::Function*);
   sem::Call* FunctionCall(const ast::CallExpression*,
                           sem::Function* target,
-                          const std::vector<const sem::Expression*> args);
+                          const std::vector<const sem::Expression*> args,
+                          sem::Behaviors arg_behaviors);
   sem::Expression* Identifier(const ast::IdentifierExpression*);
   sem::Call* IntrinsicCall(const ast::CallExpression*,
                            sem::IntrinsicType,
@@ -460,6 +461,9 @@
   /// function.
   bool IsIntrinsic(Symbol) const;
 
+  /// @returns true if `expr` is the current CallStatement's CallExpression
+  bool IsCallStatement(const ast::Expression* expr) const;
+
   /// @returns the resolved symbol (function, type or variable) for the given
   /// ast::Identifier or ast::TypeName cast to the given semantic type.
   template <typename SEM = sem::Node>
diff --git a/src/resolver/resolver_behavior_test.cc b/src/resolver/resolver_behavior_test.cc
new file mode 100644
index 0000000..6ee656b
--- /dev/null
+++ b/src/resolver/resolver_behavior_test.cc
@@ -0,0 +1,687 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/resolver/resolver.h"
+
+#include "gtest/gtest.h"
+#include "src/resolver/resolver_test_helper.h"
+#include "src/sem/expression.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+
+class ResolverBehaviorTest : public ResolverTest {
+ protected:
+  void SetUp() override {
+    // Create a function called 'DiscardOrNext' which returns an i32, and has
+    // the behavior of {Discard, Return}, which when called, will have the
+    // behavior {Discard, Next}.
+    Func("DiscardOrNext", {}, ty.i32(),
+         {
+             If(true, Block(Discard())),
+             Return(1),
+         });
+  }
+};
+
+TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) {
+  auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1)));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) {
+  auto* stmt = Decl(Var("lhs", ty.i32(), Add(1, Call("DiscardOrNext"))));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprBitcastOp) {
+  auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast<u32>(Call("DiscardOrNext"))));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprIndex_Arr) {
+  Func("ArrayDiscardOrNext", {}, ty.array<i32, 4>(),
+       {
+           If(true, Block(Discard())),
+           Return(Construct(ty.array<i32, 4>())),
+       });
+
+  auto* stmt =
+      Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1)));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprIndex_Idx) {
+  auto* stmt =
+      Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext"))));
+  WrapInFunction(Decl(Var("arr", ty.array<i32, 4>())),  //
+                 stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, ExprUnaryOp) {
+  auto* stmt = Decl(Var("lhs", ty.i32(),
+                        create<ast::UnaryOpExpression>(
+                            ast::UnaryOp::kComplement, Call("DiscardOrNext"))));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign) {
+  auto* stmt = Assign("lhs", "rhs");
+  WrapInFunction(Decl(Var("lhs", ty.i32())),  //
+                 Decl(Var("rhs", ty.i32())),  //
+                 stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) {
+  auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1);
+  WrapInFunction(Decl(Var("lhs", ty.array<i32, 4>())),  //
+                 stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) {
+  auto* stmt = Assign("lhs", Call("DiscardOrNext"));
+  WrapInFunction(Decl(Var("lhs", ty.i32())),  //
+                 stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtBlockEmpty) {
+  auto* stmt = Block();
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) {
+  auto* stmt = Block(Discard());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallReturn) {
+  Func("f", {}, ty.void_(), {Return()});
+  auto* stmt = CallStmt(Call("f"));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) {
+  Func("f", {}, ty.void_(), {Discard()});
+  auto* stmt = CallStmt(Call("f"));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) {
+  auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+                   nullptr, Block(Break()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtBreak) {
+  auto* stmt = Break();
+  WrapInFunction(Loop(Block(stmt)));
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kBreak);
+}
+
+TEST_F(ResolverBehaviorTest, StmtContinue) {
+  auto* stmt = Continue();
+  WrapInFunction(Loop(Block(stmt)));
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kContinue);
+}
+
+TEST_F(ResolverBehaviorTest, StmtDiscard) {
+  auto* stmt = Discard();
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty) {
+  auto* stmt = For(nullptr, nullptr, nullptr, Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak) {
+  auto* stmt = For(nullptr, nullptr, nullptr, Block(Break()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopContinue) {
+  auto* stmt = For(nullptr, nullptr, nullptr, Block(Continue()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) {
+  auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopReturn) {
+  auto* stmt = For(nullptr, nullptr, nullptr, Block(Return()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) {
+  auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+                   nullptr, Block(Break()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) {
+  auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr,
+                   nullptr, Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondTrue) {
+  auto* stmt = For(nullptr, true, nullptr, Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) {
+  auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1), nullptr, Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopBreak_ContCallFuncMayDiscard) {
+  auto* stmt =
+      For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block(Break()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_ContCallFuncMayDiscard) {
+  auto* stmt = For(nullptr, nullptr, CallStmt(Call("DiscardOrNext")), Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) {
+  auto* stmt = If(true, Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) {
+  auto* stmt = If(true, Block(Discard()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) {
+  auto* stmt = If(true, Block(), Else(Block(Discard())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) {
+  auto* stmt = If(true, Block(Discard()), Else(Block(Discard())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) {
+  auto* stmt = If(Equal(Call("DiscardOrNext"), 1), Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) {
+  auto* stmt = If(true, Block(),  //
+                  Else(Equal(Call("DiscardOrNext"), 1), Block()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtLetDecl) {
+  auto* stmt = Decl(Const("v", ty.i32(), Expr(1)));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) {
+  auto* stmt = Decl(Const("lhs", ty.i32(), Call("DiscardOrNext")));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty) {
+  auto* stmt = Loop(Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopBreak) {
+  auto* stmt = Loop(Block(Break()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopContinue) {
+  auto* stmt = Loop(Block(Continue()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopDiscard) {
+  auto* stmt = Loop(Block(Discard()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopReturn) {
+  auto* stmt = Loop(Block(Return()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContEmpty) {
+  auto* stmt = Loop(Block(), Block());
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_TRUE(sem->Behaviors().Empty());
+}
+
+TEST_F(ResolverBehaviorTest, StmtLoopEmpty_ContBreak) {
+  auto* stmt = Loop(Block(), Block(Break()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtReturn) {
+  auto* stmt = Return();
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) {
+  auto* stmt = Return(Call("DiscardOrNext"));
+  Func("F", {}, ty.i32(), {stmt});
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kReturn, sem::Behavior::kDiscard));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondTrue_DefaultEmpty) {
+  auto* stmt = Switch(1, DefaultCase(Block()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultEmpty) {
+  auto* stmt = Switch(1, DefaultCase(Block()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) {
+  auto* stmt = Switch(1, DefaultCase(Block(Discard())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultReturn) {
+  auto* stmt = Switch(1, DefaultCase(Block(Return())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kReturn);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) {
+  auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) {
+  auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Discard())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kDiscard));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) {
+  auto* stmt = Switch(1, Case(Expr(0), Block()), DefaultCase(Block(Return())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kNext, sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
+  auto* stmt = Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest,
+       StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) {
+  auto* stmt =
+      Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Discard())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kDiscard);
+}
+
+TEST_F(ResolverBehaviorTest,
+       StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) {
+  auto* stmt =
+      Switch(1, Case(Expr(0), Block(Discard())), DefaultCase(Block(Return())));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest,
+       StmtSwitch_CondLiteral_Case0Discard_Case1Return_DefaultEmpty) {
+  auto* stmt = Switch(1,                                //
+                      Case(Expr(0), Block(Discard())),  //
+                      Case(Expr(1), Block(Return())),   //
+                      DefaultCase(Block()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext,
+                           sem::Behavior::kReturn));
+}
+
+TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) {
+  auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtVarDecl) {
+  auto* stmt = Decl(Var("v", ty.i32()));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) {
+  auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext")));
+  WrapInFunction(stmt);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem = Sem().Get(stmt);
+  EXPECT_EQ(sem->Behaviors(),
+            sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+}  // namespace
+}  // namespace resolver
+}  // namespace tint
diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc
index eb2be86..ef365ce 100644
--- a/src/resolver/resolver_validation.cc
+++ b/src/resolver/resolver_validation.cc
@@ -1042,6 +1042,18 @@
     }
   }
 
+  // https://www.w3.org/TR/WGSL/#behaviors-rules
+  // a function behavior is always one of {}, {Next}, {Discard}, or
+  // {Next, Discard}.
+  if (func->Behaviors() != sem::Behaviors{} &&  // NOLINT: bad warning
+      func->Behaviors() != sem::Behavior::kNext &&
+      func->Behaviors() != sem::Behavior::kDiscard &&
+      func->Behaviors() != sem::Behaviors{sem::Behavior::kNext,  //
+                                          sem::Behavior::kDiscard}) {
+    TINT_ICE(Resolver, diagnostics_)
+        << "function '" << name << "' behaviors are: " << func->Behaviors();
+  }
+
   return true;
 }
 
diff --git a/src/sem/expression.h b/src/sem/expression.h
index 06ae10b..b2ff4ac 100644
--- a/src/sem/expression.h
+++ b/src/sem/expression.h
@@ -68,7 +68,7 @@
   const sem::Type* const type_;
   const Statement* const statement_;
   const Constant constant_;
-  sem::Behaviors behaviors_;
+  sem::Behaviors behaviors_{sem::Behavior::kNext};
 };
 
 }  // namespace sem
diff --git a/src/sem/function.h b/src/sem/function.h
index ea834a7..6d980c5 100644
--- a/src/sem/function.h
+++ b/src/sem/function.h
@@ -240,6 +240,12 @@
   /// @returns true if this function has a discard statement
   bool HasDiscard() const { return has_discard_; }
 
+  /// @return the behaviors of this function
+  const sem::Behaviors& Behaviors() const { return behaviors_; }
+
+  /// @return the behaviors of this function
+  sem::Behaviors& Behaviors() { return behaviors_; }
+
  private:
   VariableBindings TransitivelyReferencedSamplerVariablesImpl(
       ast::SamplerKind kind) const;
@@ -257,6 +263,7 @@
   std::vector<const Call*> callsites_;
   std::vector<const Function*> ancestor_entry_points_;
   bool has_discard_ = false;
+  sem::Behaviors behaviors_{sem::Behavior::kNext};
 };
 
 }  // namespace sem
diff --git a/src/sem/statement.h b/src/sem/statement.h
index e1c5160..1468da9 100644
--- a/src/sem/statement.h
+++ b/src/sem/statement.h
@@ -110,8 +110,7 @@
   const ast::Statement* const declaration_;
   const CompoundStatement* const parent_;
   const sem::Function* const function_;
-
-  sem::Behaviors behaviors_;
+  sem::Behaviors behaviors_{sem::Behavior::kNext};
 };
 
 /// CompoundStatement is the base class of statements that can hold other
diff --git a/test/BUILD.gn b/test/BUILD.gn
index bfaa498..c2cc482 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -253,6 +253,7 @@
     "../src/resolver/pipeline_overridable_constant_test.cc",
     "../src/resolver/ptr_ref_test.cc",
     "../src/resolver/ptr_ref_validation_test.cc",
+    "../src/resolver/resolver_behavior_test.cc",
     "../src/resolver/resolver_constants_test.cc",
     "../src/resolver/resolver_test.cc",
     "../src/resolver/resolver_test_helper.cc",