[tint][ir][ToProgram] Implement loops
Change-Id: I28218132a864a538a366d91a61aec3df5ab34f9d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/138282
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 3ff8811..8121bab 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -701,6 +701,9 @@
SetTerminator(builder_.ExitLoop(loop_inst));
}
+ EmitStatements(stmt->body->statements);
+
+ // The current block didn't `break`, `return` or `continue`, go to the continuing block.
if (NeedTerminator()) {
SetTerminator(builder_.Continue(loop_inst));
}
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index 81738d6..c5384bc 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -681,7 +681,7 @@
ASSERT_EQ(1u, m.functions.Length());
EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
- EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -696,7 +696,7 @@
exit_loop # loop_1
}
}
- continue %b3
+ ret
}
%b3 = block { # continuing
next_iteration %b2
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 495f608..d792849 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -23,13 +23,17 @@
#include "src/tint/ir/block.h"
#include "src/tint/ir/call.h"
#include "src/tint/ir/constant.h"
+#include "src/tint/ir/continue.h"
#include "src/tint/ir/exit_if.h"
+#include "src/tint/ir/exit_loop.h"
#include "src/tint/ir/exit_switch.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/load.h"
+#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/multi_in_block.h"
+#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
@@ -159,13 +163,17 @@
[&](ir::Call* i) { Call(i); }, //
[&](ir::ExitIf*) {}, //
[&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
+ [&](ir::ExitLoop* i) { ExitLoop(i); }, //
[&](ir::If* i) { If(i); }, //
[&](ir::Load* l) { Load(l); }, //
+ [&](ir::Loop* l) { Loop(l); }, //
[&](ir::Return* i) { Return(i); }, //
[&](ir::Store* i) { Store(i); }, //
[&](ir::Switch* i) { Switch(i); }, //
[&](ir::Unary* u) { Unary(u); }, //
[&](ir::Var* i) { Var(i); }, //
+ [&](ir::NextIteration*) {}, //
+ [&](ir::Continue*) {}, //
[&](Default) { UNHANDLED_CASE(inst); });
}
@@ -174,7 +182,7 @@
auto true_stmts = Statements(if_->True());
auto false_stmts = Statements(if_->False());
- if (IsShortCircuit(if_, true_stmts, false_stmts)) {
+ if (AsShortCircuit(if_, true_stmts, false_stmts)) {
return;
}
@@ -197,6 +205,57 @@
Append(b.If(cond, true_block, b.Else(false_block)));
}
+ void Loop(ir::Loop* l) {
+ auto init_stmts = Statements(l->Initializer());
+ auto* init = init_stmts.Length() == 1 ? init_stmts.Front()->As<ast::VariableDeclStatement>()
+ : nullptr;
+
+ const ast::Expression* cond = nullptr;
+
+ StatementList body_stmts;
+ {
+ TINT_SCOPED_ASSIGNMENT(statements_, &body_stmts);
+ for (auto* inst : *l->Body()) {
+ if (body_stmts.IsEmpty()) {
+ if (auto* if_ = inst->As<ir::If>()) {
+ if (!if_->HasResults() && //
+ if_->True()->Length() == 1 && //
+ if_->False()->Length() == 1 && //
+ tint::Is<ir::ExitIf>(if_->True()->Front()) && //
+ tint::Is<ir::ExitLoop>(if_->False()->Front())) {
+ cond = Expr(if_->Condition());
+ continue;
+ }
+ }
+ }
+
+ Instruction(inst);
+ }
+ }
+
+ auto cont_stmts = Statements(l->Continuing());
+ auto* cont = cont_stmts.Length() == 1 ? cont_stmts.Front() : nullptr;
+
+ auto* body = b.Block(std::move(body_stmts));
+
+ const ast::Statement* loop = nullptr;
+ if (cond) {
+ if (init || cont) {
+ loop = b.For(init, cond, cont, body);
+ } else {
+ loop = b.While(cond, body);
+ }
+ } else {
+ loop = cont_stmts.IsEmpty() ? b.Loop(body) //
+ : b.Loop(body, b.Block(std::move(cont_stmts)));
+ if (!init_stmts.IsEmpty()) {
+ init_stmts.Push(loop);
+ loop = b.Block(std::move(init_stmts));
+ }
+ }
+ statements_->Push(loop);
+ }
+
void Switch(ir::Switch* s) {
SCOPED_NESTING();
@@ -232,6 +291,8 @@
Append(b.Break());
}
+ void ExitLoop(const ir::ExitLoop*) { Append(b.Break()); }
+
void Return(ir::Return* ret) {
if (ret->Args().IsEmpty()) {
// Return has no arguments.
@@ -575,7 +636,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////
// Helpers
////////////////////////////////////////////////////////////////////////////////////////////////
- bool IsShortCircuit(ir::If* i,
+ bool AsShortCircuit(ir::If* i,
const StatementList& true_stmts,
const StatementList& false_stmts) {
if (!i->HasResults()) {
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 2d91267..0fda9be 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -663,8 +663,7 @@
fn a() {
}
-fn f() {
- var cond : bool = true;
+fn f(cond : bool) {
if (cond) {
a();
}
@@ -674,8 +673,7 @@
TEST_F(IRToProgramRoundtripTest, If_Return) {
Test(R"(
-fn f() {
- var cond : bool = true;
+fn f(cond : bool) {
if (cond) {
return;
}
@@ -703,8 +701,7 @@
fn b() {
}
-fn f() {
- var cond : bool = true;
+fn f(cond : bool) {
if (cond) {
a();
} else {
@@ -915,5 +912,270 @@
)");
}
+////////////////////////////////////////////////////////////////////////////////
+// For
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_F(IRToProgramRoundtripTest, For_Empty) {
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; (i < 5i); i = (i + 1i)) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_Empty_NoInit) {
+ Test(R"(
+fn f() {
+ var i : i32 = 0i;
+ for(; (i < 5i); i = (i + 1i)) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_Empty_NoCond) {
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; ; i = (i + 1i)) {
+ break;
+ }
+}
+)",
+ R"(
+fn f() {
+ {
+ var i : i32 = 0i;
+ loop {
+ break;
+
+ continuing {
+ i = (i + 1i);
+ }
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_Empty_NoCont) {
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; (i < 5i); ) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; (i < 5i); i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody_NoInit) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ var i : i32 = 0i;
+ for(; (i < 5i); i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody_NoCond) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; ; i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+}
+)",
+ R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ {
+ var i : i32 = 0i;
+ loop {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+
+ continuing {
+ i = (i + 1i);
+ }
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody_NoCont) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; (i < 5i); ) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// While
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_F(IRToProgramRoundtripTest, While_Empty) {
+ Test(R"(
+fn f() {
+ while(true) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_Cond) {
+ Test(R"(
+fn f(cond : bool) {
+ while(cond) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_Break) {
+ Test(R"(
+fn f() {
+ while(true) {
+ break;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_IfBreak) {
+ Test(R"(
+fn f(cond : bool) {
+ while(true) {
+ if (cond) {
+ break;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_IfReturn) {
+ Test(R"(
+fn f(cond : bool) {
+ while(true) {
+ if (cond) {
+ return;
+ }
+ }
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Loop
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_F(IRToProgramRoundtripTest, Loop_Break) {
+ Test(R"(
+fn f() {
+ loop {
+ break;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Loop_IfBreak) {
+ Test(R"(
+fn f(cond : bool) {
+ loop {
+ if (cond) {
+ break;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Loop_IfReturn) {
+ Test(R"(
+fn f(cond : bool) {
+ loop {
+ if (cond) {
+ return;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Loop_IfContinuing) {
+ Test(R"(
+fn f() {
+ var cond : bool = false;
+ loop {
+ if (cond) {
+ return;
+ }
+
+ continuing {
+ cond = true;
+ }
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir