[tint][ir] Validate continue statements
• Ensure that no variables are declared after the 'continue' and used in the continuing block.
• Ensure that 'continue' is only used in the loop body.
Change-Id: Ie21c321dbaba1a18579580d070947e2f5b17577c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/186642
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 3e0e704..f862f95 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -41,6 +41,7 @@
#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
+#include "src/tint/lang/core/ir/control_instruction.h"
#include "src/tint/lang/core/ir/convert.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/disassembly.h"
@@ -51,6 +52,7 @@
#include "src/tint/lang/core/ir/function.h"
#include "src/tint/lang/core/ir/function_param.h"
#include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/instruction.h"
#include "src/tint/lang/core/ir/instruction_result.h"
#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
@@ -76,6 +78,7 @@
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/core/type/void.h"
#include "src/tint/utils/containers/hashset.h"
+#include "src/tint/utils/containers/predicates.h"
#include "src/tint/utils/containers/reverse.h"
#include "src/tint/utils/containers/transform.h"
#include "src/tint/utils/ice/ice.h"
@@ -97,6 +100,24 @@
namespace {
+/// @returns the parent block of @p block
+const Block* ParentBlockOf(const Block* block) {
+ if (auto* parent = block->Parent()) {
+ return parent->Block();
+ }
+ return nullptr;
+}
+
+/// @returns true if @p block directly or transitively holds the instruction @p inst
+bool TransitivelyHolds(const Block* block, const Instruction* inst) {
+ for (auto* b = inst->Block(); b; b = ParentBlockOf(b)) {
+ if (b == block) {
+ return true;
+ }
+ }
+ return false;
+}
+
/// @returns true if the type @p type is of, or indirectly references a type of type `T`.
template <typename T>
bool HoldsType(const type::Type* type) {
@@ -294,6 +315,10 @@
/// @param l the loop to validate
void CheckLoop(const Loop* l);
+ /// Validates the loop continuing block
+ /// @param l the loop to validate
+ void CheckLoopContinuing(const Loop* l);
+
/// Validates the given switch
/// @param s the switch to validate
void CheckSwitch(const Switch* s);
@@ -306,6 +331,10 @@
/// @param e the exit to validate
void CheckExit(const Exit* e);
+ /// Validates the continue instruction
+ /// @param c the continue to validate
+ void CheckContinue(const Continue* c);
+
/// Validates the given exit if
/// @param e the exit if to validate
void CheckExitIf(const ExitIf* e);
@@ -389,6 +418,7 @@
diag::List diagnostics_;
Hashset<const Function*, 4> all_functions_;
Hashset<const Instruction*, 4> visited_instructions_;
+ Hashmap<const Loop*, const Continue*, 4> first_continues_;
Vector<const ControlInstruction*, 8> control_stack_;
Vector<const Block*, 8> block_stack_;
ScopeStack scope_stack_;
@@ -1043,7 +1073,10 @@
void Validator::CheckLoop(const Loop* l) {
// Note: Tasks are queued in reverse order of their execution
- tasks_.Push([this] { control_stack_.Pop(); });
+ tasks_.Push([this, l] {
+ first_continues_.Remove(l); // No need for this any more. Free memory.
+ control_stack_.Pop();
+ });
if (!l->Initializer()->IsEmpty()) {
tasks_.Push([this] { EndBlock(); });
}
@@ -1057,8 +1090,12 @@
// ⎣ ⎣ [Continuing ] ⎦⎦
if (!l->Continuing()->IsEmpty()) {
- tasks_.Push([this, l] { BeginBlock(l->Continuing()); });
+ tasks_.Push([this, l] {
+ CheckLoopContinuing(l);
+ BeginBlock(l->Continuing());
+ });
}
+
tasks_.Push([this, l] { BeginBlock(l->Body()); });
if (!l->Initializer()->IsEmpty()) {
tasks_.Push([this, l] { BeginBlock(l->Initializer()); });
@@ -1066,6 +1103,41 @@
tasks_.Push([this, l] { control_stack_.Push(l); });
}
+void Validator::CheckLoopContinuing(const Loop* loop) {
+ if (!loop->HasContinuing()) {
+ return;
+ }
+
+ // Ensure that values used in the loop continuing are not from the loop body, after a
+ // continue instruction.
+ if (auto* first_continue = first_continues_.GetOr(loop, nullptr)) {
+ // Find the instruction in the body block that is or holds the first continue instruction.
+ const Instruction* holds_continue = first_continue;
+ while (holds_continue && holds_continue->Block() &&
+ holds_continue->Block() != loop->Body()) {
+ holds_continue = holds_continue->Block()->Parent();
+ }
+
+ // Check that all subsequent instruction values are not used in the continuing block.
+ for (auto* inst = holds_continue; inst; inst = inst->next) {
+ for (auto* result : inst->Results()) {
+ result->ForEachUse([&](Usage use) {
+ if (TransitivelyHolds(loop->Continuing(), use.instruction)) {
+ AddError(use.instruction, use.operand_index)
+ << NameOf(result)
+ << " cannot be used in continuing block as it is declared after the "
+ "first "
+ << style::Instruction("continue") << " in the loop's body";
+ AddDeclarationNote(result);
+ AddNote(first_continue)
+ << "loop body's first " << style::Instruction("continue");
+ }
+ });
+ }
+ }
+ }
+}
+
void Validator::CheckSwitch(const Switch* s) {
tasks_.Push([this] { control_stack_.Pop(); });
@@ -1083,7 +1155,7 @@
tint::Switch(
b, //
[&](const ir::BreakIf*) {}, //
- [&](const ir::Continue*) {}, //
+ [&](const ir::Continue* c) { CheckContinue(c); }, //
[&](const ir::Exit* e) { CheckExit(e); }, //
[&](const ir::NextIteration*) {}, //
[&](const ir::Return* ret) { CheckReturn(ret); }, //
@@ -1092,6 +1164,23 @@
[&](Default) { AddError(b) << "missing validation"; });
}
+void Validator::CheckContinue(const Continue* c) {
+ auto* loop = c->Loop();
+ if (loop == nullptr) {
+ AddError(c) << "has no associated loop";
+ return;
+ }
+ if (!TransitivelyHolds(loop->Body(), c)) {
+ if (control_stack_.Any(Eq<const ControlInstruction*>(loop))) {
+ AddError(c) << "must only be called from loop body";
+ } else {
+ AddError(c) << "called outside of associated loop";
+ }
+ }
+
+ first_continues_.Add(loop, c);
+}
+
void Validator::CheckExit(const Exit* e) {
if (e->ControlInstruction() == nullptr) {
AddError(e) << "has no parent control instruction";
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 2619338..8e158af 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -2603,6 +2603,206 @@
)");
}
+TEST_F(IR_ValidatorTest, ContinueOutsideOfLoop) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Continue(loop);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:8:5 error: continue: called outside of associated loop
+ continue # -> $B3
+ ^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2] { # loop_1
+ $B2: { # body
+ exit_loop # loop_1
+ }
+ }
+ continue # -> $B3
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopInit) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] { b.Continue(loop); });
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:5:9 error: continue: must only be called from loop body
+ continue # -> $B4
+ ^^^^^^^^
+
+:4:7 note: in block
+ $B2: { # initializer
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ continue # -> $B4
+ }
+ $B3: { # body
+ exit_loop # loop_1
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopBody) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.Continue(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopContinuing) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Append(loop->Continuing(), [&] { b.Continue(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:8:9 error: continue: must only be called from loop body
+ continue # -> $B3
+ ^^^^^^^^
+
+:7:7 note: in block
+ $B3: { # continuing
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2, c: $B3] { # loop_1
+ $B2: { # body
+ exit_loop # loop_1
+ }
+ $B3: { # continuing
+ continue # -> $B3
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinuingUseValueBeforeContinue) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* value = b.Let("value", 1_i);
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] {
+ b.Append(value);
+ b.Append(b.If(true)->True(), [&] { b.Continue(loop); });
+ b.ExitLoop(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ b.Let("use", value);
+ b.NextIteration(loop);
+ });
+ b.Return(f);
+ });
+
+ ASSERT_EQ(ir::Validate(mod), Success);
+}
+
+TEST_F(IR_ValidatorTest, ContinuingUseValueAfterContinue) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* value = b.Let("value", 1_i);
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] {
+ b.Append(b.If(true)->True(), [&] { b.Continue(loop); });
+ b.Append(value);
+ b.ExitLoop(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ b.Let("use", value);
+ b.NextIteration(loop);
+ });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(
+ res.Failure().reason.Str(),
+ R"(:14:24 error: let: %value cannot be used in continuing block as it is declared after the first 'continue' in the loop's body
+ %use:i32 = let %value
+ ^^^^^^
+
+:4:7 note: in block
+ $B2: { # body
+ ^^^
+
+:10:9 note: %value declared here
+ %value:i32 = let 1i
+ ^^^^^^^^^^
+
+:7:13 note: loop body's first 'continue'
+ continue # -> $B3
+ ^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2, c: $B3] { # loop_1
+ $B2: { # body
+ if true [t: $B4] { # if_1
+ $B4: { # true
+ continue # -> $B3
+ }
+ }
+ %value:i32 = let 1i
+ exit_loop # loop_1
+ }
+ $B3: { # continuing
+ %use:i32 = let %value
+ next_iteration # -> $B2
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidatorTest, ExitLoop) {
auto* loop = b.Loop();
loop->Continuing()->Append(b.NextIteration(loop));
diff --git a/src/tint/lang/core/ir/value.cc b/src/tint/lang/core/ir/value.cc
index c1a66e8..af66b0e 100644
--- a/src/tint/lang/core/ir/value.cc
+++ b/src/tint/lang/core/ir/value.cc
@@ -44,7 +44,7 @@
flags_.Add(Flag::kDead);
}
-void Value::ForEachUse(std::function<void(Usage use)> func) {
+void Value::ForEachUse(std::function<void(Usage use)> func) const {
auto uses = uses_;
for (auto& use : uses) {
func(use);
diff --git a/src/tint/lang/core/ir/value.h b/src/tint/lang/core/ir/value.h
index 0603e87..e6f1bad 100644
--- a/src/tint/lang/core/ir/value.h
+++ b/src/tint/lang/core/ir/value.h
@@ -105,7 +105,7 @@
/// Apply a function to all uses of the value that exist prior to calling this method.
/// @param func the function will be applied to each use
- void ForEachUse(std::function<void(Usage use)> func);
+ void ForEachUse(std::function<void(Usage use)> func) const;
/// Replace all uses of the value.
/// @param replacer a function which returns a replacement for a given use