Add break-if support.
This CL adds support for the WGSL `break if` statement to Tint's GLSL, HLSL, MSL, SPIR-V and WGSL writers.
Bug: tint:1633, tint:1451
Change-Id: I30dfd62a3e09255624ff76ebe0cdd3a3c7cf9c5f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/106420
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: dan sinclair <dsinclair@google.com>
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index ab28bf7..fc9de73 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -712,6 +712,16 @@
return true;
}
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) {
+    // GLSL has no `break if`; lower `break if <cond>;` to `if (<cond>) { break; }`.
+    auto out = line();
+    out << "if (";
+    const bool emitted = EmitExpression(out, stmt->condition);
+    if (emitted) {
+        out << ") { break; }";
+    }
+    return emitted;
+}
+
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get<sem::Call>(expr);
auto* target = call->Target();
@@ -2616,6 +2626,7 @@
[&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
[&](const ast::BlockStatement* b) { return EmitBlock(b); },
[&](const ast::BreakStatement* b) { return EmitBreak(b); },
+ [&](const ast::BreakIfStatement* b) { return EmitBreakIf(b); },
[&](const ast::CallStatement* c) {
auto out = line();
if (!EmitCall(out, c->expr)) {
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index 2cd5d5d..889b406 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -139,6 +139,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully
bool EmitBreak(const ast::BreakStatement* stmt);
+ /// Handles a break-if statement, emitting `break if <cond>;` as the
+ /// GLSL statement `if (<cond>) { break; }`
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emitted successfully
+ bool EmitBreakIf(const ast::BreakIfStatement* stmt);
/// Handles generating a call expression
/// @param out the output of the expression stream
/// @param expr the call expression
diff --git a/src/tint/writer/glsl/generator_impl_loop_test.cc b/src/tint/writer/glsl/generator_impl_loop_test.cc
index e638eb3..7aa94be 100644
--- a/src/tint/writer/glsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/glsl/generator_impl_loop_test.cc
@@ -65,6 +65,31 @@
)");
}
+TEST_F(GlslGeneratorImplTest_Loop, Emit_LoopWithContinuing_BreakIf) {
+ // loop {
+ //   discard;
+ //   continuing {
+ //     a_statement();
+ //     break if true;
+ //   }
+ // }
+ // The `break if` is expected to be lowered to `if (true) { break; }`.
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* body = Block(create<ast::DiscardStatement>());
+ auto* continuing = Block(CallStmt(Call("a_statement")), BreakIf(true));
+ auto* l = Loop(body, continuing);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(l)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while (true) {
+ discard;
+ {
+ a_statement();
+ if (true) { break; }
+ }
+ }
+)");
+}
+
TEST_F(GlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
Func("a_statement", {}, ty.void_(), {});
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 1dc3b10..e42274d 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -942,6 +942,16 @@
return true;
}
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) {
+    // HLSL has no `break if`; lower `break if <cond>;` to `if (<cond>) { break; }`.
+    auto out = line();
+    out << "if (";
+    const bool emitted = EmitExpression(out, stmt->condition);
+    if (emitted) {
+        out << ") { break; }";
+    }
+    return emitted;
+}
+
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
auto* call = builder_.Sem().Get<sem::Call>(expr);
auto* target = call->Target();
@@ -3591,6 +3601,9 @@
[&](const ast::BreakStatement* b) { //
return EmitBreak(b);
},
+ [&](const ast::BreakIfStatement* b) { //
+ return EmitBreakIf(b);
+ },
[&](const ast::CallStatement* c) { //
auto out = line();
if (!EmitCall(out, c->expr)) {
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index faef168..2742f00 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -125,6 +125,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully
bool EmitBreak(const ast::BreakStatement* stmt);
+ /// Handles a break-if statement, emitting `break if <cond>;` as the
+ /// HLSL statement `if (<cond>) { break; }`
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emitted successfully
+ bool EmitBreakIf(const ast::BreakIfStatement* stmt);
/// Handles generating a call expression
/// @param out the output of the expression stream
/// @param expr the call expression
diff --git a/src/tint/writer/hlsl/generator_impl_loop_test.cc b/src/tint/writer/hlsl/generator_impl_loop_test.cc
index 238fdd2..3d8219b 100644
--- a/src/tint/writer/hlsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_loop_test.cc
@@ -65,6 +65,31 @@
)");
}
+TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing_BreakIf) {
+ // loop {
+ //   discard;
+ //   continuing {
+ //     a_statement();
+ //     break if true;
+ //   }
+ // }
+ // The `break if` is expected to be lowered to `if (true) { break; }`.
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* body = Block(create<ast::DiscardStatement>());
+ auto* continuing = Block(CallStmt(Call("a_statement")), BreakIf(true));
+ auto* l = Loop(body, continuing);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(l)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while (true) {
+ discard;
+ {
+ a_statement();
+ if (true) { break; }
+ }
+ }
+)");
+}
+
TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
Func("a_statement", {}, ty.void_(), {});
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index c7c7b63..28478b8 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -660,6 +660,16 @@
return true;
}
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) {
+    // MSL has no `break if`; lower `break if <cond>;` to `if (<cond>) { break; }`.
+    auto out = line();
+    out << "if (";
+    const bool emitted = EmitExpression(out, stmt->condition);
+    if (emitted) {
+        out << ") { break; }";
+    }
+    return emitted;
+}
+
bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
auto* call = program_->Sem().Get<sem::Call>(expr);
auto* target = call->Target();
@@ -2433,6 +2443,9 @@
[&](const ast::BreakStatement* b) { //
return EmitBreak(b);
},
+ [&](const ast::BreakIfStatement* b) { //
+ return EmitBreakIf(b);
+ },
[&](const ast::CallStatement* c) { //
auto out = line();
if (!EmitCall(out, c->expr)) { //
diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h
index 5e16fbe..188bfea 100644
--- a/src/tint/writer/msl/generator_impl.h
+++ b/src/tint/writer/msl/generator_impl.h
@@ -129,6 +129,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully
bool EmitBreak(const ast::BreakStatement* stmt);
+ /// Handles a break-if statement, emitting `break if <cond>;` as the
+ /// MSL statement `if (<cond>) { break; }`
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emitted successfully
+ bool EmitBreakIf(const ast::BreakIfStatement* stmt);
/// Handles generating a call expression
/// @param out the output of the expression stream
/// @param expr the call expression
diff --git a/src/tint/writer/msl/generator_impl_loop_test.cc b/src/tint/writer/msl/generator_impl_loop_test.cc
index 274ee941..1dd5430 100644
--- a/src/tint/writer/msl/generator_impl_loop_test.cc
+++ b/src/tint/writer/msl/generator_impl_loop_test.cc
@@ -65,6 +65,31 @@
)");
}
+TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing_BreakIf) {
+ // loop {
+ //   discard;
+ //   continuing {
+ //     a_statement();
+ //     break if true;
+ //   }
+ // }
+ // The `break if` is expected to be lowered to `if (true) { break; }`, and
+ // `discard` to `discard_fragment()`.
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* body = Block(create<ast::DiscardStatement>());
+ auto* continuing = Block(CallStmt(Call("a_statement")), BreakIf(true));
+ auto* l = Loop(body, continuing);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(l)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while (true) {
+ discard_fragment();
+ {
+ a_statement();
+ if (true) { break; }
+ }
+ }
+)");
+}
+
TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
Func("a_statement", {}, ty.void_(), utils::Empty);
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index c93a448..91b16b4 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -455,6 +455,19 @@
return true;
}
+bool Builder::GenerateBreakIfStatement(const ast::BreakIfStatement* stmt) {
+    // Per WGSL, `break if` may only be the last statement of a loop's
+    // continuing block, so both the backedge and continuing stacks must be
+    // populated for the enclosing loop.
+    TINT_ASSERT(Writer, !backedge_stack_.empty());
+    TINT_ASSERT(Writer, !continuing_stack_.empty());
+    // Guard explicitly (as GenerateContinueStatement does) rather than relying
+    // solely on TINT_ASSERT, so back() is never called on an empty stack.
+    if (backedge_stack_.empty() || continuing_stack_.empty()) {
+        error_ = "Attempted to break-if without a continuing block";
+        return false;
+    }
+
+    const auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
+    if (!cond_id) {
+        return false;
+    }
+
+    // Replace the backedge on top of backedge_stack_ with an OpBranchConditional:
+    // when the condition is true branch to the loop's break target, otherwise
+    // branch back to the loop header.
+    const ContinuingInfo& ci = continuing_stack_.back();
+    backedge_stack_.back() =
+        Backedge(spv::Op::OpBranchConditional,
+                 {Operand(cond_id), Operand(ci.break_target_id), Operand(ci.loop_header_id)});
+    return true;
+}
+
bool Builder::GenerateContinueStatement(const ast::ContinueStatement*) {
if (continue_stack_.empty()) {
error_ = "Attempted to continue without a continue block";
@@ -3400,6 +3413,8 @@
// continuing { ...
// if (cond) {} else {break;}
// }
+ //
+ // TODO(crbug.com/tint/1451): Remove this when the if break construct is made an error.
auto is_just_a_break = [](const ast::BlockStatement* block) {
return block && (block->statements.Length() == 1) &&
block->Last()->Is<ast::BreakStatement>();
@@ -3643,6 +3658,7 @@
stmt, [&](const ast::AssignmentStatement* a) { return GenerateAssignStatement(a); },
[&](const ast::BlockStatement* b) { return GenerateBlockStatement(b); },
[&](const ast::BreakStatement* b) { return GenerateBreakStatement(b); },
+ [&](const ast::BreakIfStatement* b) { return GenerateBreakIfStatement(b); },
[&](const ast::CallStatement* c) { return GenerateCallExpression(c->expr) != 0; },
[&](const ast::ContinueStatement* c) { return GenerateContinueStatement(c); },
[&](const ast::DiscardStatement* d) { return GenerateDiscardStatement(d); },
@@ -3659,7 +3675,7 @@
return true; // Not emitted
},
[&](Default) {
- error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
+ error_ = "unknown statement type: " + std::string(stmt->TypeInfo().name);
return false;
});
}
diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h
index 9e633f1..1e412fa 100644
--- a/src/tint/writer/spirv/builder.h
+++ b/src/tint/writer/spirv/builder.h
@@ -248,6 +248,10 @@
/// @param stmt the statement to generate
/// @returns true if the statement was successfully generated
bool GenerateBreakStatement(const ast::BreakStatement* stmt);
+ /// Generates a break-if statement by replacing the enclosing loop's backedge
+ /// with an OpBranchConditional (true: loop break target, false: loop header)
+ /// @param stmt the statement to generate
+ /// @returns true if the statement was successfully generated
+ bool GenerateBreakIfStatement(const ast::BreakIfStatement* stmt);
/// Generates a continue statement
/// @param stmt the statement to generate
/// @returns true if the statement was successfully generated
diff --git a/src/tint/writer/spirv/builder_loop_test.cc b/src/tint/writer/spirv/builder_loop_test.cc
index 3ab27ef..e27cac4 100644
--- a/src/tint/writer/spirv/builder_loop_test.cc
+++ b/src/tint/writer/spirv/builder_loop_test.cc
@@ -234,12 +234,11 @@
TEST_F(BuilderTest, Loop_WithContinuing_BreakIf) {
// loop {
// continuing {
- // if (true) { break; }
+ // break if (true);
// }
// }
- auto* if_stmt = If(Expr(true), Block(Break()));
- auto* continuing = Block(if_stmt);
+ auto* continuing = Block(BreakIf(true));
auto* loop = Loop(Block(), continuing);
WrapInFunction(loop);
@@ -267,11 +266,10 @@
TEST_F(BuilderTest, Loop_WithContinuing_BreakUnless) {
// loop {
// continuing {
- // if (true) {} else { break; }
+ // break if (false);
// }
// }
- auto* if_stmt = If(Expr(true), Block(), Else(Block(Break())));
- auto* continuing = Block(if_stmt);
+ auto* continuing = Block(BreakIf(false));
auto* loop = Loop(Block(), continuing);
WrapInFunction(loop);
@@ -281,7 +279,7 @@
EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
-%6 = OpConstantTrue %5
+%6 = OpConstantNull %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(OpBranch %1
@@ -291,7 +289,7 @@
%4 = OpLabel
OpBranch %3
%3 = OpLabel
-OpBranchConditional %6 %1 %2
+OpBranchConditional %6 %2 %1
%2 = OpLabel
)");
}
@@ -300,13 +298,12 @@
// loop {
// continuing {
// var cond = true;
- // if (cond) { break; }
+ // break if (cond);
// }
// }
auto* cond_var = Decl(Var("cond", Expr(true)));
- auto* if_stmt = If(Expr("cond"), Block(Break()));
- auto* continuing = Block(cond_var, if_stmt);
+ auto* continuing = Block(cond_var, BreakIf("cond"));
auto* loop = Loop(Block(), continuing);
WrapInFunction(loop);
@@ -379,19 +376,17 @@
// continuing {
// loop {
// continuing {
- // if (true) { break; }
+ // break if (true);
// }
// }
- // if (true) { break; }
+ // break if (true);
// }
// }
- auto* inner_if_stmt = If(Expr(true), Block(Break()));
- auto* inner_continuing = Block(inner_if_stmt);
+ auto* inner_continuing = Block(BreakIf(true));
auto* inner_loop = Loop(Block(), inner_continuing);
- auto* outer_if_stmt = If(Expr(true), Block(Break()));
- auto* outer_continuing = Block(inner_loop, outer_if_stmt);
+ auto* outer_continuing = Block(inner_loop, BreakIf(true));
auto* outer_loop = Loop(Block(), outer_continuing);
WrapInFunction(outer_loop);
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index c62fbf5..bf7ebd7 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -958,6 +958,7 @@
[&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
[&](const ast::BlockStatement* b) { return EmitBlock(b); },
[&](const ast::BreakStatement* b) { return EmitBreak(b); },
+ [&](const ast::BreakIfStatement* b) { return EmitBreakIf(b); },
[&](const ast::CallStatement* c) {
auto out = line();
if (!EmitCall(out, c->expr)) {
@@ -1023,6 +1024,17 @@
return true;
}
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) {
+    // Emit the statement back out in WGSL syntax: `break if <condition>;`.
+    auto out = line();
+    out << "break if ";
+    const bool emitted = EmitExpression(out, stmt->condition);
+    if (emitted) {
+        out << ";";
+    }
+    return emitted;
+}
+
bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
if (stmt->selectors.Length() == 1 && stmt->ContainsDefault()) {
line() << "default: {";
diff --git a/src/tint/writer/wgsl/generator_impl.h b/src/tint/writer/wgsl/generator_impl.h
index f4fc467..2b7f1c9 100644
--- a/src/tint/writer/wgsl/generator_impl.h
+++ b/src/tint/writer/wgsl/generator_impl.h
@@ -20,6 +20,7 @@
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/binary_expression.h"
#include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/break_if_statement.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/ast/compound_assignment_statement.h"
#include "src/tint/ast/continue_statement.h"
@@ -92,6 +93,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully
bool EmitBreak(const ast::BreakStatement* stmt);
+ /// Handles a break-if statement, emitting it as `break if <cond>;`
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emitted successfully
+ bool EmitBreakIf(const ast::BreakIfStatement* stmt);
/// Handles generating a call expression
/// @param out the output of the expression stream
/// @param expr the call expression
diff --git a/src/tint/writer/wgsl/generator_impl_loop_test.cc b/src/tint/writer/wgsl/generator_impl_loop_test.cc
index bf0eef6..48510bb 100644
--- a/src/tint/writer/wgsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_loop_test.cc
@@ -65,6 +65,32 @@
)");
}
+TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing_BreakIf) {
+ // loop {
+ //   discard;
+ //   continuing {
+ //     a_statement();
+ //     break if true;
+ //   }
+ // }
+ // The WGSL writer is expected to emit `break if` unchanged inside `continuing`.
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* body = Block(create<ast::DiscardStatement>());
+ auto* continuing = Block(CallStmt(Call("a_statement")), BreakIf(true));
+ auto* l = Loop(body, continuing);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(l)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( loop {
+ discard;
+
+ continuing {
+ a_statement();
+ break if true;
+ }
+ }
+)");
+}
+
TEST_F(WgslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
// var<workgroup> a : atomic<i32>;
// for({ignore(1i); ignore(2i);}; ; ) {