Resolver: Validation for continuing blocks

Check they do not contain returns, discards
Check they do not directly contain continues, however a nested loop can have its own continue.

Bug: chromium:1229976
Change-Id: Ia3c4ac118ffdaa6cca6025366c19f9897718c930
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58384
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/program_builder.h b/src/program_builder.h
index 70b356d..d4a25cf 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -105,6 +105,14 @@
 /// To construct a Program, populate the builder and then `std::move` it to a
 /// Program.
 class ProgramBuilder {
+  /// A helper used to disable overloads if the first type in `TYPES` is a
+  /// Source. Used to avoid ambiguities in overloads that take a Source as the
+  /// first parameter and those that perfectly-forward the first argument.
+  template <typename... TYPES>
+  using DisableIfSource = traits::EnableIfIsNotType<
+      traits::Decay<traits::NthTypeOf<0, TYPES..., void>>,
+      Source>;
+
   /// VarOptionals is a helper for accepting a number of optional, extra
   /// arguments for Var() and Global().
   struct VarOptionals {
@@ -1383,7 +1391,7 @@
   /// global variable with the ast::Module.
   template <typename NAME,
             typename... OPTIONAL,
-            traits::EnableIfIsNotType<traits::Decay<NAME>, Source>* = nullptr>
+            typename = DisableIfSource<NAME>>
   ast::Variable* Global(NAME&& name,
                         const ast::Type* type,
                         OPTIONAL&&... optional) {
@@ -1504,9 +1512,7 @@
   /// @param args the function call arguments
   /// @returns a `ast::CallExpression` to the function `func`, with the
   /// arguments of `args` converted to `ast::Expression`s using `Expr()`.
-  template <typename NAME,
-            typename... ARGS,
-            traits::EnableIfIsNotType<traits::Decay<NAME>, Source>* = nullptr>
+  template <typename NAME, typename... ARGS, typename = DisableIfSource<NAME>>
   ast::CallExpression* Call(NAME&& func, ARGS&&... args) {
     return create<ast::CallExpression>(Expr(func),
                                        ExprList(std::forward<ARGS>(args)...));
@@ -1781,7 +1787,7 @@
   /// Creates an ast::ReturnStatement with the given return value
   /// @param val the return value
   /// @returns the return statement pointer
-  template <typename EXPR>
+  template <typename EXPR, typename = DisableIfSource<EXPR>>
   ast::ReturnStatement* Return(EXPR&& val) {
     return create<ast::ReturnStatement>(Expr(std::forward<EXPR>(val)));
   }
@@ -1886,12 +1892,22 @@
   }
 
   /// Creates a ast::BlockStatement with input statements
+  /// @param source the source information for the block
   /// @param statements statements of block
   /// @returns the block statement pointer
   template <typename... Statements>
-  ast::BlockStatement* Block(Statements&&... statements) {
+  ast::BlockStatement* Block(const Source& source, Statements&&... statements) {
     return create<ast::BlockStatement>(
-        ast::StatementList{std::forward<Statements>(statements)...});
+        source, ast::StatementList{std::forward<Statements>(statements)...});
+  }
+
+  /// Creates a ast::BlockStatement with input statements
+  /// @param statements statements of block
+  /// @returns the block statement pointer
+  template <typename... STATEMENTS, typename = DisableIfSource<STATEMENTS...>>
+  ast::BlockStatement* Block(STATEMENTS&&... statements) {
+    return create<ast::BlockStatement>(
+        ast::StatementList{std::forward<STATEMENTS>(statements)...});
   }
 
   /// Creates a ast::ElseStatement with input condition and body
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 8288a3f..ca5ea27 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2062,11 +2062,18 @@
   }
   if (stmt->Is<ast::ContinueStatement>()) {
     // Set if we've hit the first continue statement in our parent loop
-    if (auto* loop_block =
-            current_block_->FindFirstParent<sem::LoopBlockStatement>()) {
-      if (loop_block->FirstContinue() == size_t(~0)) {
-        const_cast<sem::LoopBlockStatement*>(loop_block)
-            ->SetFirstContinue(loop_block->Decls().size());
+    if (auto* block =
+            current_block_->FindFirstParent<
+                sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) {
+      if (auto* loop_block = block->As<sem::LoopBlockStatement>()) {
+        if (loop_block->FirstContinue() == size_t(~0)) {
+          const_cast<sem::LoopBlockStatement*>(loop_block)
+              ->SetFirstContinue(loop_block->Decls().size());
+        }
+      } else {
+        AddError("continuing blocks must not contain a continue statement",
+                 stmt->source());
+        return false;
       }
     } else {
       AddError("continue statement must be in a loop", stmt->source());
@@ -2076,6 +2083,17 @@
     return true;
   }
   if (stmt->Is<ast::DiscardStatement>()) {
+    if (auto* continuing =
+            sem_statement
+                ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
+      AddError("continuing blocks must not contain a discard statement",
+               stmt->source());
+      if (continuing != sem_statement->Parent()) {
+        AddNote("see continuing block here",
+                continuing->Declaration()->source());
+      }
+      return false;
+    }
     return true;
   }
   if (stmt->Is<ast::FallthroughStatement>()) {
@@ -4110,6 +4128,17 @@
     return false;
   }
 
+  auto* sem = builder_->Sem().Get(ret);
+  if (auto* continuing =
+          sem->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
+    AddError("continuing blocks must not contain a return statement",
+             ret->source());
+    if (continuing != sem->Parent()) {
+      AddNote("see continuing block here", continuing->Declaration()->source());
+    }
+    return false;
+  }
+
   return true;
 }
 
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 573629a..6bb426f 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -20,6 +20,7 @@
 #include "src/ast/break_statement.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/intrinsic_texture_helper_test.h"
 #include "src/ast/loop_statement.h"
@@ -650,6 +651,123 @@
   EXPECT_TRUE(r()->Resolve());
 }
 
+TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Direct) {
+  // loop  {
+  //   continuing {
+  //     return;
+  //   }
+  // }
+
+  WrapInFunction(Loop(  // loop
+      Block(),          //   loop block
+      Block(            //   loop continuing block
+          Return(Source{{12, 34}}))));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      R"(12:34 error: continuing blocks must not contain a return statement)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Indirect) {
+  // loop {
+  //   continuing {
+  //     loop {
+  //       return;
+  //     }
+  //   }
+  // }
+
+  WrapInFunction(Loop(         // outer loop
+      Block(),                 //   outer loop block
+      Block(Source{{56, 78}},  //   outer loop continuing block
+            Loop(              //     inner loop
+                Block(         //       inner loop block
+                    Return(Source{{12, 34}}))))));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      R"(12:34 error: continuing blocks must not contain a return statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Direct) {
+  // loop  {
+  //   continuing {
+  //     discard;
+  //   }
+  // }
+
+  WrapInFunction(Loop(  // loop
+      Block(),          //   loop block
+      Block(            //   loop continuing block
+          create<ast::DiscardStatement>(Source{{12, 34}}))));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      R"(12:34 error: continuing blocks must not contain a discard statement)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Indirect) {
+  // loop {
+  //   continuing {
+  //     loop { discard; }
+  //   }
+  // }
+
+  WrapInFunction(Loop(         // outer loop
+      Block(),                 //   outer loop block
+      Block(Source{{56, 78}},  //   outer loop continuing block
+            Loop(              //     inner loop
+                Block(         //       inner loop block
+                    create<ast::DiscardStatement>(Source{{12, 34}}))))));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      R"(12:34 error: continuing blocks must not contain a discard statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Direct) {
+  // loop  {
+  //     continuing {
+  //         continue;
+  //     }
+  // }
+
+  WrapInFunction(Loop(         // loop
+      Block(),                 //   loop block
+      Block(Source{{56, 78}},  //   loop continuing block
+            create<ast::ContinueStatement>(Source{{12, 34}}))));
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: continuing blocks must not contain a continue statement");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Indirect) {
+  // loop {
+  //   continuing {
+  //     loop {
+  //       continue;
+  //     }
+  //   }
+  // }
+
+  WrapInFunction(Loop(  // outer loop
+      Block(),          //   outer loop block
+      Block(            //   outer loop continuing block
+          Loop(         //     inner loop
+              Block(    //       inner loop block
+                  create<ast::ContinueStatement>(Source{{12, 34}}))))));
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
 TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) {
   // for (; 1.0f; ) {
   // }
diff --git a/src/sem/statement.h b/src/sem/statement.h
index 821a8a5..0449c3a 100644
--- a/src/sem/statement.h
+++ b/src/sem/statement.h
@@ -34,6 +34,30 @@
 /// Forward declaration
 class CompoundStatement;
 
+namespace detail {
+/// FindFirstParentReturn is a traits helper for determining the return type for
+/// the template member function Statement::FindFirstParent().
+/// For zero or multiple template arguments, FindFirstParentReturn::type
+/// resolves to CompoundStatement.
+template <typename... TYPES>
+struct FindFirstParentReturn {
+  /// The pointer type returned by Statement::FindFirstParent()
+  using type = CompoundStatement;
+};
+
+/// A specialization of FindFirstParentReturn for a single template argument.
+/// FindFirstParentReturn::type resolves to the single template argument.
+template <typename T>
+struct FindFirstParentReturn<T> {
+  /// The pointer type returned by Statement::FindFirstParent()
+  using type = T;
+};
+
+template <typename... TYPES>
+using FindFirstParentReturnType =
+    typename FindFirstParentReturn<TYPES...>::type;
+}  // namespace detail
+
 /// Statement holds the semantic information for a statement.
 class Statement : public Castable<Statement, Node> {
  public:
@@ -49,16 +73,18 @@
   const CompoundStatement* Parent() const { return parent_; }
 
   /// @returns the closest enclosing parent that satisfies the given predicate,
-  /// which may be the statement itself, or nullptr if no match is found
+  /// which may be the statement itself, or nullptr if no match is found.
   /// @param pred a predicate that the resulting block must satisfy
   template <typename Pred>
   const CompoundStatement* FindFirstParent(Pred&& pred) const;
 
-  /// @returns the statement itself if it matches the template type `T`,
-  /// otherwise the nearest enclosing statement that matches `T`, or nullptr if
-  /// there is none.
-  template <typename T>
-  const T* FindFirstParent() const;
+  /// @returns the closest enclosing parent that is of one of the types in
+  /// `TYPES`, which may be the statement itself, or nullptr if no match is
+  /// found. If `TYPES` is a single template argument, the return type is a
+  /// pointer to that template argument type, otherwise a CompoundStatement
+  /// pointer is returned.
+  template <typename... TYPES>
+  const detail::FindFirstParentReturnType<TYPES...>* FindFirstParent() const;
 
   /// @return the closest enclosing block for this statement
   const BlockStatement* Block() const;
@@ -99,17 +125,32 @@
   return curr;
 }
 
-template <typename T>
-const T* Statement::FindFirstParent() const {
-  if (auto* p = As<T>()) {
-    return p;
-  }
-  const auto* curr = parent_;
-  while (curr) {
-    if (auto* p = curr->As<T>()) {
+template <typename... TYPES>
+const detail::FindFirstParentReturnType<TYPES...>* Statement::FindFirstParent()
+    const {
+  using ReturnType = detail::FindFirstParentReturnType<TYPES...>;
+  if (sizeof...(TYPES) == 1) {
+    if (auto* p = As<ReturnType>()) {
       return p;
     }
-    curr = curr->Parent();
+    const auto* curr = parent_;
+    while (curr) {
+      if (auto* p = curr->As<ReturnType>()) {
+        return p;
+      }
+      curr = curr->Parent();
+    }
+  } else {
+    if (IsAnyOf<TYPES...>()) {
+      return As<ReturnType>();
+    }
+    const auto* curr = parent_;
+    while (curr) {
+      if (curr->IsAnyOf<TYPES...>()) {
+        return curr->As<ReturnType>();
+      }
+      curr = curr->Parent();
+    }
   }
   return nullptr;
 }
diff --git a/src/writer/hlsl/generator_impl_loop_test.cc b/src/writer/hlsl/generator_impl_loop_test.cc
index b9d006f..87b4a3b 100644
--- a/src/writer/hlsl/generator_impl_loop_test.cc
+++ b/src/writer/hlsl/generator_impl_loop_test.cc
@@ -41,8 +41,10 @@
 }
 
 TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) {
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* body = Block(create<ast::DiscardStatement>());
-  auto* continuing = Block(Return());
+  auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
   auto* l = Loop(body, continuing);
 
   WrapInFunction(l);
@@ -55,18 +57,20 @@
   EXPECT_EQ(gen.result(), R"(  while (true) {
     discard;
     {
-      return;
+      a_statement();
     }
   }
 )");
 }
 
 TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
+  Func("a_statement", {}, ty.void_(), {});
+
   Global("lhs", ty.f32(), ast::StorageClass::kPrivate);
   Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
 
   auto* body = Block(create<ast::DiscardStatement>());
-  auto* continuing = Block(Return());
+  auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
   auto* inner = Loop(body, continuing);
 
   body = Block(inner);
@@ -88,7 +92,7 @@
     while (true) {
       discard;
       {
-        return;
+        a_statement();
       }
     }
     {
@@ -153,7 +157,10 @@
   //   return;
   // }
 
-  auto* f = For(nullptr, nullptr, nullptr, Block(Return()));
+  Func("a_statement", {}, ty.void_(), {});
+
+  auto* f = For(nullptr, nullptr, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -163,7 +170,7 @@
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
     for(; ; ) {
-      return;
+      a_statement();
     }
   }
 )");
@@ -174,7 +181,10 @@
   //   return;
   // }
 
-  auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return()));
+  Func("a_statement", {}, ty.void_(), {});
+
+  auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -184,7 +194,7 @@
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
     for(int i = 0; ; ) {
-      return;
+      a_statement();
     }
   }
 )");
@@ -194,10 +204,12 @@
   // for(var b = true && false; ; ) {
   //   return;
   // }
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                    Expr(true), Expr(false));
   auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr,
-                Block(Return()));
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -212,7 +224,7 @@
     }
     bool b = (tint_tmp);
     for(; ; ) {
-      return;
+      a_statement();
     }
   }
 )");
@@ -223,7 +235,10 @@
   //   return;
   // }
 
-  auto* f = For(nullptr, true, nullptr, Block(Return()));
+  Func("a_statement", {}, ty.void_(), {});
+
+  auto* f = For(nullptr, true, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -233,7 +248,7 @@
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
     for(; true; ) {
-      return;
+      a_statement();
     }
   }
 )");
@@ -244,9 +259,12 @@
   //   return;
   // }
 
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                    Expr(true), Expr(false));
-  auto* f = For(nullptr, multi_stmt, nullptr, Block(Return()));
+  auto* f = For(nullptr, multi_stmt, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -261,7 +279,7 @@
         tint_tmp = false;
       }
       if (!((tint_tmp))) { break; }
-      return;
+      a_statement();
     }
   }
 )");
@@ -272,8 +290,11 @@
   //   return;
   // }
 
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* v = Decl(Var("i", ty.i32()));
-  auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return()));
+  auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)),
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(v, f);
 
   GeneratorImpl& gen = Build();
@@ -283,7 +304,7 @@
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
     for(; ; i = (i + 1)) {
-      return;
+      a_statement();
     }
   }
 )");
@@ -294,10 +315,13 @@
   //   return;
   // }
 
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                    Expr(true), Expr(false));
   auto* v = Decl(Var("i", ty.bool_()));
-  auto* f = For(nullptr, nullptr, Assign("i", multi_stmt), Block(Return()));
+  auto* f = For(nullptr, nullptr, Assign("i", multi_stmt),
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(v, f);
 
   GeneratorImpl& gen = Build();
@@ -307,7 +331,7 @@
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
     while (true) {
-      return;
+      a_statement();
       bool tint_tmp = true;
       if (tint_tmp) {
         tint_tmp = false;
@@ -323,8 +347,10 @@
   //   return;
   // }
 
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)),
-                Block(Return()));
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -334,7 +360,7 @@
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
     for(int i = 0; true; i = (i + 1)) {
-      return;
+      a_statement();
     }
   }
 )");
@@ -344,6 +370,8 @@
   // for(var i = true && false; true && false; i = true && false) {
   //   return;
   // }
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* multi_stmt_a = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
                                                      Expr(true), Expr(false));
   auto* multi_stmt_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
@@ -352,7 +380,8 @@
                                                      Expr(true), Expr(false));
 
   auto* f = For(Decl(Var("i", nullptr, multi_stmt_a)), multi_stmt_b,
-                Assign("i", multi_stmt_c), Block(Return()));
+                Assign("i", multi_stmt_c),
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -372,7 +401,7 @@
         tint_tmp_1 = false;
       }
       if (!((tint_tmp_1))) { break; }
-      return;
+      a_statement();
       bool tint_tmp_2 = true;
       if (tint_tmp_2) {
         tint_tmp_2 = false;
diff --git a/src/writer/msl/generator_impl_loop_test.cc b/src/writer/msl/generator_impl_loop_test.cc
index 5c32822..178df07 100644
--- a/src/writer/msl/generator_impl_loop_test.cc
+++ b/src/writer/msl/generator_impl_loop_test.cc
@@ -40,8 +40,10 @@
 }
 
 TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) {
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* body = Block(create<ast::DiscardStatement>());
-  auto* continuing = Block(Return());
+  auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
   auto* l = Loop(body, continuing);
   WrapInFunction(l);
 
@@ -53,18 +55,20 @@
   EXPECT_EQ(gen.result(), R"(  while (true) {
     discard_fragment();
     {
-      return;
+      a_statement();
     }
   }
 )");
 }
 
 TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
+  Func("a_statement", {}, ty.void_(), {});
+
   Global("lhs", ty.f32(), ast::StorageClass::kPrivate);
   Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
 
   auto* body = Block(create<ast::DiscardStatement>());
-  auto* continuing = Block(Return());
+  auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
   auto* inner = Loop(body, continuing);
 
   body = Block(inner);
@@ -83,7 +87,7 @@
     while (true) {
       discard_fragment();
       {
-        return;
+        a_statement();
       }
     }
     {
@@ -146,7 +150,10 @@
   //   return;
   // }
 
-  auto* f = For(nullptr, nullptr, nullptr, Block(Return()));
+  Func("a_statement", {}, ty.void_(), {});
+
+  auto* f = For(nullptr, nullptr, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -155,7 +162,7 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  for(; ; ) {
-    return;
+    a_statement();
   }
 )");
 }
@@ -165,7 +172,10 @@
   //   return;
   // }
 
-  auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return()));
+  Func("a_statement", {}, ty.void_(), {});
+
+  auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -174,7 +184,7 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  for(int i = 0; ; ) {
-    return;
+    a_statement();
   }
 )");
 }
@@ -184,9 +194,13 @@
   // for({ignore(1); ignore(2);}; ; ) {
   //   return;
   // }
+
+  Func("a_statement", {}, ty.void_(), {});
+
   Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
   auto* multi_stmt = Block(Ignore(1), Ignore(2));
-  auto* f = For(multi_stmt, nullptr, nullptr, Block(Return()));
+  auto* f = For(multi_stmt, nullptr, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -200,7 +214,7 @@
       (void) 2;
     }
     for(; ; ) {
-      return;
+      a_statement();
     }
   }
 )");
@@ -211,7 +225,10 @@
   //   return;
   // }
 
-  auto* f = For(nullptr, true, nullptr, Block(Return()));
+  Func("a_statement", {}, ty.void_(), {});
+
+  auto* f = For(nullptr, true, nullptr,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -220,7 +237,7 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  for(; true; ) {
-    return;
+    a_statement();
   }
 )");
 }
@@ -230,8 +247,11 @@
   //   return;
   // }
 
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* v = Decl(Var("i", ty.i32()));
-  auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return()));
+  auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)),
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(v, f);
 
   GeneratorImpl& gen = Build();
@@ -240,7 +260,7 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  for(; ; i = (i + 1)) {
-    return;
+    a_statement();
   }
 )");
 }
@@ -251,9 +271,12 @@
   //   return;
   // }
 
+  Func("a_statement", {}, ty.void_(), {});
+
   Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
   auto* multi_stmt = Block(Ignore(1), Ignore(2));
-  auto* f = For(nullptr, nullptr, multi_stmt, Block(Return()));
+  auto* f = For(nullptr, nullptr, multi_stmt,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -262,7 +285,7 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  while (true) {
-    return;
+    a_statement();
     {
       (void) 1;
       (void) 2;
@@ -276,8 +299,10 @@
   //   return;
   // }
 
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)),
-                Block(Return()));
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -286,7 +311,7 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  for(int i = 0; true; i = (i + 1)) {
-    return;
+    a_statement();
   }
 )");
 }
@@ -296,10 +321,14 @@
   // for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) {
   //   return;
   // }
+
+  Func("a_statement", {}, ty.void_(), {});
+
   Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
   auto* multi_stmt_a = Block(Ignore(1), Ignore(2));
   auto* multi_stmt_b = Block(Ignore(3), Ignore(4));
-  auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, Block(Return()));
+  auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b,
+                Block(create<ast::CallStatement>(Call("a_statement"))));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -314,7 +343,7 @@
     }
     while (true) {
       if (!(true)) { break; }
-      return;
+      a_statement();
       {
         (void) 3;
         (void) 4;
diff --git a/src/writer/wgsl/generator_impl_loop_test.cc b/src/writer/wgsl/generator_impl_loop_test.cc
index 6c413a8..1a28857 100644
--- a/src/writer/wgsl/generator_impl_loop_test.cc
+++ b/src/writer/wgsl/generator_impl_loop_test.cc
@@ -40,8 +40,10 @@
 }
 
 TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) {
+  Func("a_statement", {}, ty.void_(), {});
+
   auto* body = Block(create<ast::DiscardStatement>());
-  auto* continuing = Block(Return());
+  auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
   auto* l = Loop(body, continuing);
 
   WrapInFunction(l);
@@ -55,7 +57,7 @@
     discard;
 
     continuing {
-      return;
+      a_statement();
     }
   }
 )");