Resolver: Validation for continuing blocks
Check they do not contain returns, discards
Check they do not directly contain continues, however a nested loop can have its own continue.
Bug: chromium:1229976
Change-Id: Ia3c4ac118ffdaa6cca6025366c19f9897718c930
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58384
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/program_builder.h b/src/program_builder.h
index 70b356d..d4a25cf 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -105,6 +105,14 @@
/// To construct a Program, populate the builder and then `std::move` it to a
/// Program.
class ProgramBuilder {
+ /// A helper used to disable overloads if the first type in `TYPES` is a
+ /// Source. Used to avoid ambiguities in overloads that take a Source as the
+ /// first parameter and those that perfectly-forward the first argument.
+ template <typename... TYPES>
+ using DisableIfSource = traits::EnableIfIsNotType<
+ traits::Decay<traits::NthTypeOf<0, TYPES..., void>>,
+ Source>;
+
/// VarOptionals is a helper for accepting a number of optional, extra
/// arguments for Var() and Global().
struct VarOptionals {
@@ -1383,7 +1391,7 @@
/// global variable with the ast::Module.
template <typename NAME,
typename... OPTIONAL,
- traits::EnableIfIsNotType<traits::Decay<NAME>, Source>* = nullptr>
+ typename = DisableIfSource<NAME>>
ast::Variable* Global(NAME&& name,
const ast::Type* type,
OPTIONAL&&... optional) {
@@ -1504,9 +1512,7 @@
/// @param args the function call arguments
/// @returns a `ast::CallExpression` to the function `func`, with the
/// arguments of `args` converted to `ast::Expression`s using `Expr()`.
- template <typename NAME,
- typename... ARGS,
- traits::EnableIfIsNotType<traits::Decay<NAME>, Source>* = nullptr>
+ template <typename NAME, typename... ARGS, typename = DisableIfSource<NAME>>
ast::CallExpression* Call(NAME&& func, ARGS&&... args) {
return create<ast::CallExpression>(Expr(func),
ExprList(std::forward<ARGS>(args)...));
@@ -1781,7 +1787,7 @@
/// Creates an ast::ReturnStatement with the given return value
/// @param val the return value
/// @returns the return statement pointer
- template <typename EXPR>
+ template <typename EXPR, typename = DisableIfSource<EXPR>>
ast::ReturnStatement* Return(EXPR&& val) {
return create<ast::ReturnStatement>(Expr(std::forward<EXPR>(val)));
}
@@ -1886,12 +1892,22 @@
}
/// Creates a ast::BlockStatement with input statements
+ /// @param source the source information for the block
/// @param statements statements of block
/// @returns the block statement pointer
template <typename... Statements>
- ast::BlockStatement* Block(Statements&&... statements) {
+ ast::BlockStatement* Block(const Source& source, Statements&&... statements) {
return create<ast::BlockStatement>(
- ast::StatementList{std::forward<Statements>(statements)...});
+ source, ast::StatementList{std::forward<Statements>(statements)...});
+ }
+
+ /// Creates a ast::BlockStatement with input statements
+ /// @param statements statements of block
+ /// @returns the block statement pointer
+ template <typename... STATEMENTS, typename = DisableIfSource<STATEMENTS...>>
+ ast::BlockStatement* Block(STATEMENTS&&... statements) {
+ return create<ast::BlockStatement>(
+ ast::StatementList{std::forward<STATEMENTS>(statements)...});
}
/// Creates a ast::ElseStatement with input condition and body
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 8288a3f..ca5ea27 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2062,11 +2062,18 @@
}
if (stmt->Is<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop
- if (auto* loop_block =
- current_block_->FindFirstParent<sem::LoopBlockStatement>()) {
- if (loop_block->FirstContinue() == size_t(~0)) {
- const_cast<sem::LoopBlockStatement*>(loop_block)
- ->SetFirstContinue(loop_block->Decls().size());
+ if (auto* block =
+ current_block_->FindFirstParent<
+ sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) {
+ if (auto* loop_block = block->As<sem::LoopBlockStatement>()) {
+ if (loop_block->FirstContinue() == size_t(~0)) {
+ const_cast<sem::LoopBlockStatement*>(loop_block)
+ ->SetFirstContinue(loop_block->Decls().size());
+ }
+ } else {
+ AddError("continuing blocks must not contain a continue statement",
+ stmt->source());
+ return false;
}
} else {
AddError("continue statement must be in a loop", stmt->source());
@@ -2076,6 +2083,17 @@
return true;
}
if (stmt->Is<ast::DiscardStatement>()) {
+ if (auto* continuing =
+ sem_statement
+ ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
+ AddError("continuing blocks must not contain a discard statement",
+ stmt->source());
+ if (continuing != sem_statement->Parent()) {
+ AddNote("see continuing block here",
+ continuing->Declaration()->source());
+ }
+ return false;
+ }
return true;
}
if (stmt->Is<ast::FallthroughStatement>()) {
@@ -4110,6 +4128,17 @@
return false;
}
+ auto* sem = builder_->Sem().Get(ret);
+ if (auto* continuing =
+ sem->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
+ AddError("continuing blocks must not contain a return statement",
+ ret->source());
+ if (continuing != sem->Parent()) {
+ AddNote("see continuing block here", continuing->Declaration()->source());
+ }
+ return false;
+ }
+
return true;
}
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 573629a..6bb426f 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -20,6 +20,7 @@
#include "src/ast/break_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/continue_statement.h"
+#include "src/ast/discard_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic_texture_helper_test.h"
#include "src/ast/loop_statement.h"
@@ -650,6 +651,123 @@
EXPECT_TRUE(r()->Resolve());
}
+TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Direct) {
+ // loop {
+ // continuing {
+ // return;
+ // }
+ // }
+
+ WrapInFunction(Loop( // loop
+ Block(), // loop block
+ Block( // loop continuing block
+ Return(Source{{12, 34}}))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a return statement)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Indirect) {
+ // loop {
+ // continuing {
+ // loop {
+ // return;
+ // }
+ // }
+ // }
+
+ WrapInFunction(Loop( // outer loop
+ Block(), // outer loop block
+ Block(Source{{56, 78}}, // outer loop continuing block
+ Loop( // inner loop
+ Block( // inner loop block
+ Return(Source{{12, 34}}))))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a return statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Direct) {
+ // loop {
+ // continuing {
+ // discard;
+ // }
+ // }
+
+ WrapInFunction(Loop( // loop
+ Block(), // loop block
+ Block( // loop continuing block
+ create<ast::DiscardStatement>(Source{{12, 34}}))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a discard statement)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Indirect) {
+ // loop {
+ // continuing {
+ // loop { discard; }
+ // }
+ // }
+
+ WrapInFunction(Loop( // outer loop
+ Block(), // outer loop block
+ Block(Source{{56, 78}}, // outer loop continuing block
+ Loop( // inner loop
+ Block( // inner loop block
+ create<ast::DiscardStatement>(Source{{12, 34}}))))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: continuing blocks must not contain a discard statement
+56:78 note: see continuing block here)");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Direct) {
+ // loop {
+ // continuing {
+ // continue;
+ // }
+ // }
+
+ WrapInFunction(Loop( // loop
+ Block(), // loop block
+ Block(Source{{56, 78}}, // loop continuing block
+ create<ast::ContinueStatement>(Source{{12, 34}}))));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: continuing blocks must not contain a continue statement");
+}
+
+TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Indirect) {
+ // loop {
+ // continuing {
+ // loop {
+ // continue;
+ // }
+ // }
+ // }
+
+ WrapInFunction(Loop( // outer loop
+ Block(), // outer loop block
+ Block( // outer loop continuing block
+ Loop( // inner loop
+ Block( // inner loop block
+ create<ast::ContinueStatement>(Source{{12, 34}}))))));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) {
// for (; 1.0f; ) {
// }
diff --git a/src/sem/statement.h b/src/sem/statement.h
index 821a8a5..0449c3a 100644
--- a/src/sem/statement.h
+++ b/src/sem/statement.h
@@ -34,6 +34,30 @@
/// Forward declaration
class CompoundStatement;
+namespace detail {
+/// FindFirstParentReturn is a traits helper for determining the return type for
+/// the template member function Statement::FindFirstParent().
+/// For zero or multiple template arguments, FindFirstParentReturn::type
+/// resolves to CompoundStatement.
+template <typename... TYPES>
+struct FindFirstParentReturn {
+ /// The pointer type returned by Statement::FindFirstParent()
+ using type = CompoundStatement;
+};
+
+/// A specialization of FindFirstParentReturn for a single template argument.
+/// FindFirstParentReturn::type resolves to the single template argument.
+template <typename T>
+struct FindFirstParentReturn<T> {
+ /// The pointer type returned by Statement::FindFirstParent()
+ using type = T;
+};
+
+template <typename... TYPES>
+using FindFirstParentReturnType =
+ typename FindFirstParentReturn<TYPES...>::type;
+} // namespace detail
+
/// Statement holds the semantic information for a statement.
class Statement : public Castable<Statement, Node> {
public:
@@ -49,16 +73,18 @@
const CompoundStatement* Parent() const { return parent_; }
/// @returns the closest enclosing parent that satisfies the given predicate,
- /// which may be the statement itself, or nullptr if no match is found
+ /// which may be the statement itself, or nullptr if no match is found.
/// @param pred a predicate that the resulting block must satisfy
template <typename Pred>
const CompoundStatement* FindFirstParent(Pred&& pred) const;
- /// @returns the statement itself if it matches the template type `T`,
- /// otherwise the nearest enclosing statement that matches `T`, or nullptr if
- /// there is none.
- template <typename T>
- const T* FindFirstParent() const;
+ /// @returns the closest enclosing parent that is of one of the types in
+ /// `TYPES`, which may be the statement itself, or nullptr if no match is
+ /// found. If `TYPES` is a single template argument, the return type is a
+ /// pointer to that template argument type, otherwise a CompoundStatement
+ /// pointer is returned.
+ template <typename... TYPES>
+ const detail::FindFirstParentReturnType<TYPES...>* FindFirstParent() const;
/// @return the closest enclosing block for this statement
const BlockStatement* Block() const;
@@ -99,17 +125,32 @@
return curr;
}
-template <typename T>
-const T* Statement::FindFirstParent() const {
- if (auto* p = As<T>()) {
- return p;
- }
- const auto* curr = parent_;
- while (curr) {
- if (auto* p = curr->As<T>()) {
+template <typename... TYPES>
+const detail::FindFirstParentReturnType<TYPES...>* Statement::FindFirstParent()
+ const {
+ using ReturnType = detail::FindFirstParentReturnType<TYPES...>;
+ if (sizeof...(TYPES) == 1) {
+ if (auto* p = As<ReturnType>()) {
return p;
}
- curr = curr->Parent();
+ const auto* curr = parent_;
+ while (curr) {
+ if (auto* p = curr->As<ReturnType>()) {
+ return p;
+ }
+ curr = curr->Parent();
+ }
+ } else {
+ if (IsAnyOf<TYPES...>()) {
+ return As<ReturnType>();
+ }
+ const auto* curr = parent_;
+ while (curr) {
+ if (curr->IsAnyOf<TYPES...>()) {
+ return curr->As<ReturnType>();
+ }
+ curr = curr->Parent();
+ }
}
return nullptr;
}
diff --git a/src/writer/hlsl/generator_impl_loop_test.cc b/src/writer/hlsl/generator_impl_loop_test.cc
index b9d006f..87b4a3b 100644
--- a/src/writer/hlsl/generator_impl_loop_test.cc
+++ b/src/writer/hlsl/generator_impl_loop_test.cc
@@ -41,8 +41,10 @@
}
TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) {
+ Func("a_statement", {}, ty.void_(), {});
+
auto* body = Block(create<ast::DiscardStatement>());
- auto* continuing = Block(Return());
+ auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* l = Loop(body, continuing);
WrapInFunction(l);
@@ -55,18 +57,20 @@
EXPECT_EQ(gen.result(), R"( while (true) {
discard;
{
- return;
+ a_statement();
}
}
)");
}
TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
+ Func("a_statement", {}, ty.void_(), {});
+
Global("lhs", ty.f32(), ast::StorageClass::kPrivate);
Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
auto* body = Block(create<ast::DiscardStatement>());
- auto* continuing = Block(Return());
+ auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* inner = Loop(body, continuing);
body = Block(inner);
@@ -88,7 +92,7 @@
while (true) {
discard;
{
- return;
+ a_statement();
}
}
{
@@ -153,7 +157,10 @@
// return;
// }
- auto* f = For(nullptr, nullptr, nullptr, Block(Return()));
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* f = For(nullptr, nullptr, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -163,7 +170,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
for(; ; ) {
- return;
+ a_statement();
}
}
)");
@@ -174,7 +181,10 @@
// return;
// }
- auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return()));
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -184,7 +194,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
for(int i = 0; ; ) {
- return;
+ a_statement();
}
}
)");
@@ -194,10 +204,12 @@
// for(var b = true && false; ; ) {
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false));
auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr,
- Block(Return()));
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -212,7 +224,7 @@
}
bool b = (tint_tmp);
for(; ; ) {
- return;
+ a_statement();
}
}
)");
@@ -223,7 +235,10 @@
// return;
// }
- auto* f = For(nullptr, true, nullptr, Block(Return()));
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* f = For(nullptr, true, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -233,7 +248,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
for(; true; ) {
- return;
+ a_statement();
}
}
)");
@@ -244,9 +259,12 @@
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false));
- auto* f = For(nullptr, multi_stmt, nullptr, Block(Return()));
+ auto* f = For(nullptr, multi_stmt, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -261,7 +279,7 @@
tint_tmp = false;
}
if (!((tint_tmp))) { break; }
- return;
+ a_statement();
}
}
)");
@@ -272,8 +290,11 @@
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* v = Decl(Var("i", ty.i32()));
- auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return()));
+ auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)),
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(v, f);
GeneratorImpl& gen = Build();
@@ -283,7 +304,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
for(; ; i = (i + 1)) {
- return;
+ a_statement();
}
}
)");
@@ -294,10 +315,13 @@
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false));
auto* v = Decl(Var("i", ty.bool_()));
- auto* f = For(nullptr, nullptr, Assign("i", multi_stmt), Block(Return()));
+ auto* f = For(nullptr, nullptr, Assign("i", multi_stmt),
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(v, f);
GeneratorImpl& gen = Build();
@@ -307,7 +331,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
while (true) {
- return;
+ a_statement();
bool tint_tmp = true;
if (tint_tmp) {
tint_tmp = false;
@@ -323,8 +347,10 @@
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)),
- Block(Return()));
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -334,7 +360,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
for(int i = 0; true; i = (i + 1)) {
- return;
+ a_statement();
}
}
)");
@@ -344,6 +370,8 @@
// for(var i = true && false; true && false; i = true && false) {
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* multi_stmt_a = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false));
auto* multi_stmt_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
@@ -352,7 +380,8 @@
Expr(true), Expr(false));
auto* f = For(Decl(Var("i", nullptr, multi_stmt_a)), multi_stmt_b,
- Assign("i", multi_stmt_c), Block(Return()));
+ Assign("i", multi_stmt_c),
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -372,7 +401,7 @@
tint_tmp_1 = false;
}
if (!((tint_tmp_1))) { break; }
- return;
+ a_statement();
bool tint_tmp_2 = true;
if (tint_tmp_2) {
tint_tmp_2 = false;
diff --git a/src/writer/msl/generator_impl_loop_test.cc b/src/writer/msl/generator_impl_loop_test.cc
index 5c32822..178df07 100644
--- a/src/writer/msl/generator_impl_loop_test.cc
+++ b/src/writer/msl/generator_impl_loop_test.cc
@@ -40,8 +40,10 @@
}
TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) {
+ Func("a_statement", {}, ty.void_(), {});
+
auto* body = Block(create<ast::DiscardStatement>());
- auto* continuing = Block(Return());
+ auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* l = Loop(body, continuing);
WrapInFunction(l);
@@ -53,18 +55,20 @@
EXPECT_EQ(gen.result(), R"( while (true) {
discard_fragment();
{
- return;
+ a_statement();
}
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
+ Func("a_statement", {}, ty.void_(), {});
+
Global("lhs", ty.f32(), ast::StorageClass::kPrivate);
Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
auto* body = Block(create<ast::DiscardStatement>());
- auto* continuing = Block(Return());
+ auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* inner = Loop(body, continuing);
body = Block(inner);
@@ -83,7 +87,7 @@
while (true) {
discard_fragment();
{
- return;
+ a_statement();
}
}
{
@@ -146,7 +150,10 @@
// return;
// }
- auto* f = For(nullptr, nullptr, nullptr, Block(Return()));
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* f = For(nullptr, nullptr, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -155,7 +162,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(; ; ) {
- return;
+ a_statement();
}
)");
}
@@ -165,7 +172,10 @@
// return;
// }
- auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return()));
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -174,7 +184,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(int i = 0; ; ) {
- return;
+ a_statement();
}
)");
}
@@ -184,9 +194,13 @@
// for({ignore(1); ignore(2);}; ; ) {
// return;
// }
+
+ Func("a_statement", {}, ty.void_(), {});
+
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = Block(Ignore(1), Ignore(2));
- auto* f = For(multi_stmt, nullptr, nullptr, Block(Return()));
+ auto* f = For(multi_stmt, nullptr, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -200,7 +214,7 @@
(void) 2;
}
for(; ; ) {
- return;
+ a_statement();
}
}
)");
@@ -211,7 +225,10 @@
// return;
// }
- auto* f = For(nullptr, true, nullptr, Block(Return()));
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* f = For(nullptr, true, nullptr,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -220,7 +237,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(; true; ) {
- return;
+ a_statement();
}
)");
}
@@ -230,8 +247,11 @@
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* v = Decl(Var("i", ty.i32()));
- auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return()));
+ auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)),
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(v, f);
GeneratorImpl& gen = Build();
@@ -240,7 +260,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(; ; i = (i + 1)) {
- return;
+ a_statement();
}
)");
}
@@ -251,9 +271,12 @@
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = Block(Ignore(1), Ignore(2));
- auto* f = For(nullptr, nullptr, multi_stmt, Block(Return()));
+ auto* f = For(nullptr, nullptr, multi_stmt,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -262,7 +285,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while (true) {
- return;
+ a_statement();
{
(void) 1;
(void) 2;
@@ -276,8 +299,10 @@
// return;
// }
+ Func("a_statement", {}, ty.void_(), {});
+
auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)),
- Block(Return()));
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -286,7 +311,7 @@
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(int i = 0; true; i = (i + 1)) {
- return;
+ a_statement();
}
)");
}
@@ -296,10 +321,14 @@
// for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) {
// return;
// }
+
+ Func("a_statement", {}, ty.void_(), {});
+
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt_a = Block(Ignore(1), Ignore(2));
auto* multi_stmt_b = Block(Ignore(3), Ignore(4));
- auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, Block(Return()));
+ auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b,
+ Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@@ -314,7 +343,7 @@
}
while (true) {
if (!(true)) { break; }
- return;
+ a_statement();
{
(void) 3;
(void) 4;
diff --git a/src/writer/wgsl/generator_impl_loop_test.cc b/src/writer/wgsl/generator_impl_loop_test.cc
index 6c413a8..1a28857 100644
--- a/src/writer/wgsl/generator_impl_loop_test.cc
+++ b/src/writer/wgsl/generator_impl_loop_test.cc
@@ -40,8 +40,10 @@
}
TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) {
+ Func("a_statement", {}, ty.void_(), {});
+
auto* body = Block(create<ast::DiscardStatement>());
- auto* continuing = Block(Return());
+ auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* l = Loop(body, continuing);
WrapInFunction(l);
@@ -55,7 +57,7 @@
discard;
continuing {
- return;
+ a_statement();
}
}
)");