[tint][ir] Validate continue statements

• Ensure that no variables are declared after the 'continue' and used in the continuing block.
• Ensure that 'continue' is only used in the loop body.

Change-Id: Ie21c321dbaba1a18579580d070947e2f5b17577c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/186642
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 3e0e704..f862f95 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -41,6 +41,7 @@
 #include "src/tint/lang/core/ir/constant.h"
 #include "src/tint/lang/core/ir/construct.h"
 #include "src/tint/lang/core/ir/continue.h"
+#include "src/tint/lang/core/ir/control_instruction.h"
 #include "src/tint/lang/core/ir/convert.h"
 #include "src/tint/lang/core/ir/core_builtin_call.h"
 #include "src/tint/lang/core/ir/disassembly.h"
@@ -51,6 +52,7 @@
 #include "src/tint/lang/core/ir/function.h"
 #include "src/tint/lang/core/ir/function_param.h"
 #include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/instruction.h"
 #include "src/tint/lang/core/ir/instruction_result.h"
 #include "src/tint/lang/core/ir/let.h"
 #include "src/tint/lang/core/ir/load.h"
@@ -76,6 +78,7 @@
 #include "src/tint/lang/core/type/vector.h"
 #include "src/tint/lang/core/type/void.h"
 #include "src/tint/utils/containers/hashset.h"
+#include "src/tint/utils/containers/predicates.h"
 #include "src/tint/utils/containers/reverse.h"
 #include "src/tint/utils/containers/transform.h"
 #include "src/tint/utils/ice/ice.h"
@@ -97,6 +100,24 @@
 
 namespace {
 
+/// @returns the parent block of @p block
+const Block* ParentBlockOf(const Block* block) {
+    if (auto* parent = block->Parent()) {
+        return parent->Block();
+    }
+    return nullptr;
+}
+
+/// @returns true if @p block directly or transitively holds the instruction @p inst
+bool TransitivelyHolds(const Block* block, const Instruction* inst) {
+    for (auto* b = inst->Block(); b; b = ParentBlockOf(b)) {
+        if (b == block) {
+            return true;
+        }
+    }
+    return false;
+}
+
 /// @returns true if the type @p type is of, or indirectly references a type of type `T`.
 template <typename T>
 bool HoldsType(const type::Type* type) {
@@ -294,6 +315,10 @@
     /// @param l the loop to validate
     void CheckLoop(const Loop* l);
 
+    /// Validates the loop continuing block
+    /// @param l the loop to validate
+    void CheckLoopContinuing(const Loop* l);
+
     /// Validates the given switch
     /// @param s the switch to validate
     void CheckSwitch(const Switch* s);
@@ -306,6 +331,10 @@
     /// @param e the exit to validate
     void CheckExit(const Exit* e);
 
+    /// Validates the continue instruction
+    /// @param c the continue to validate
+    void CheckContinue(const Continue* c);
+
     /// Validates the given exit if
     /// @param e the exit if to validate
     void CheckExitIf(const ExitIf* e);
@@ -389,6 +418,7 @@
     diag::List diagnostics_;
     Hashset<const Function*, 4> all_functions_;
     Hashset<const Instruction*, 4> visited_instructions_;
+    Hashmap<const Loop*, const Continue*, 4> first_continues_;
     Vector<const ControlInstruction*, 8> control_stack_;
     Vector<const Block*, 8> block_stack_;
     ScopeStack scope_stack_;
@@ -1043,7 +1073,10 @@
 
 void Validator::CheckLoop(const Loop* l) {
     // Note: Tasks are queued in reverse order of their execution
-    tasks_.Push([this] { control_stack_.Pop(); });
+    tasks_.Push([this, l] {
+        first_continues_.Remove(l);  // No need for this any more. Free memory.
+        control_stack_.Pop();
+    });
     if (!l->Initializer()->IsEmpty()) {
         tasks_.Push([this] { EndBlock(); });
     }
@@ -1057,8 +1090,12 @@
     // ⎣    ⎣    [Continuing ]  ⎦⎦
 
     if (!l->Continuing()->IsEmpty()) {
-        tasks_.Push([this, l] { BeginBlock(l->Continuing()); });
+        tasks_.Push([this, l] {
+            CheckLoopContinuing(l);
+            BeginBlock(l->Continuing());
+        });
     }
+
     tasks_.Push([this, l] { BeginBlock(l->Body()); });
     if (!l->Initializer()->IsEmpty()) {
         tasks_.Push([this, l] { BeginBlock(l->Initializer()); });
@@ -1066,6 +1103,41 @@
     tasks_.Push([this, l] { control_stack_.Push(l); });
 }
 
+void Validator::CheckLoopContinuing(const Loop* loop) {
+    if (!loop->HasContinuing()) {
+        return;
+    }
+
+    // Ensure that values used in the loop continuing are not from the loop body, after a
+    // continue instruction.
+    if (auto* first_continue = first_continues_.GetOr(loop, nullptr)) {
+        // Find the instruction in the body block that is or holds the first continue instruction.
+        const Instruction* holds_continue = first_continue;
+        while (holds_continue && holds_continue->Block() &&
+               holds_continue->Block() != loop->Body()) {
+            holds_continue = holds_continue->Block()->Parent();
+        }
+
+        // Check that all subsequent instruction values are not used in the continuing block.
+        for (auto* inst = holds_continue; inst; inst = inst->next) {
+            for (auto* result : inst->Results()) {
+                result->ForEachUse([&](Usage use) {
+                    if (TransitivelyHolds(loop->Continuing(), use.instruction)) {
+                        AddError(use.instruction, use.operand_index)
+                            << NameOf(result)
+                            << " cannot be used in continuing block as it is declared after the "
+                               "first "
+                            << style::Instruction("continue") << " in the loop's body";
+                        AddDeclarationNote(result);
+                        AddNote(first_continue)
+                            << "loop body's first " << style::Instruction("continue");
+                    }
+                });
+            }
+        }
+    }
+}
+
 void Validator::CheckSwitch(const Switch* s) {
     tasks_.Push([this] { control_stack_.Pop(); });
 
@@ -1083,7 +1155,7 @@
     tint::Switch(
         b,                                                 //
         [&](const ir::BreakIf*) {},                        //
-        [&](const ir::Continue*) {},                       //
+        [&](const ir::Continue* c) { CheckContinue(c); },  //
         [&](const ir::Exit* e) { CheckExit(e); },          //
         [&](const ir::NextIteration*) {},                  //
         [&](const ir::Return* ret) { CheckReturn(ret); },  //
@@ -1092,6 +1164,23 @@
         [&](Default) { AddError(b) << "missing validation"; });
 }
 
+void Validator::CheckContinue(const Continue* c) {
+    auto* loop = c->Loop();
+    if (loop == nullptr) {
+        AddError(c) << "has no associated loop";
+        return;
+    }
+    if (!TransitivelyHolds(loop->Body(), c)) {
+        if (control_stack_.Any(Eq<const ControlInstruction*>(loop))) {
+            AddError(c) << "must only be called from loop body";
+        } else {
+            AddError(c) << "called outside of associated loop";
+        }
+    }
+
+    first_continues_.Add(loop, c);
+}
+
 void Validator::CheckExit(const Exit* e) {
     if (e->ControlInstruction() == nullptr) {
         AddError(e) << "has no parent control instruction";
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 2619338..8e158af 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -2603,6 +2603,206 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, ContinueOutsideOfLoop) {
+    auto* f = b.Function("my_func", ty.void_());
+    b.Append(f->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+        b.Continue(loop);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:8:5 error: continue: called outside of associated loop
+    continue  # -> $B3
+    ^^^^^^^^
+
+:2:3 note: in block
+  $B1: {
+  ^^^
+
+note: # Disassembly
+%my_func = func():void {
+  $B1: {
+    loop [b: $B2] {  # loop_1
+      $B2: {  # body
+        exit_loop  # loop_1
+      }
+    }
+    continue  # -> $B3
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopInit) {
+    auto* f = b.Function("my_func", ty.void_());
+    b.Append(f->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Initializer(), [&] { b.Continue(loop); });
+        b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:5:9 error: continue: must only be called from loop body
+        continue  # -> $B4
+        ^^^^^^^^
+
+:4:7 note: in block
+      $B2: {  # initializer
+      ^^^
+
+note: # Disassembly
+%my_func = func():void {
+  $B1: {
+    loop [i: $B2, b: $B3] {  # loop_1
+      $B2: {  # initializer
+        continue  # -> $B4
+      }
+      $B3: {  # body
+        exit_loop  # loop_1
+      }
+    }
+    ret
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopBody) {
+    auto* f = b.Function("my_func", ty.void_());
+    b.Append(f->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Body(), [&] { b.Continue(loop); });
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopContinuing) {
+    auto* f = b.Function("my_func", ty.void_());
+    b.Append(f->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+        b.Append(loop->Continuing(), [&] { b.Continue(loop); });
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:8:9 error: continue: must only be called from loop body
+        continue  # -> $B3
+        ^^^^^^^^
+
+:7:7 note: in block
+      $B3: {  # continuing
+      ^^^
+
+note: # Disassembly
+%my_func = func():void {
+  $B1: {
+    loop [b: $B2, c: $B3] {  # loop_1
+      $B2: {  # body
+        exit_loop  # loop_1
+      }
+      $B3: {  # continuing
+        continue  # -> $B3
+      }
+    }
+    ret
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinuingUseValueBeforeContinue) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* value = b.Let("value", 1_i);
+    b.Append(f->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Body(), [&] {
+            b.Append(value);
+            b.Append(b.If(true)->True(), [&] { b.Continue(loop); });
+            b.ExitLoop(loop);
+        });
+        b.Append(loop->Continuing(), [&] {
+            b.Let("use", value);
+            b.NextIteration(loop);
+        });
+        b.Return(f);
+    });
+
+    ASSERT_EQ(ir::Validate(mod), Success);
+}
+
+TEST_F(IR_ValidatorTest, ContinuingUseValueAfterContinue) {
+    auto* f = b.Function("my_func", ty.void_());
+    auto* value = b.Let("value", 1_i);
+    b.Append(f->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Body(), [&] {
+            b.Append(b.If(true)->True(), [&] { b.Continue(loop); });
+            b.Append(value);
+            b.ExitLoop(loop);
+        });
+        b.Append(loop->Continuing(), [&] {
+            b.Let("use", value);
+            b.NextIteration(loop);
+        });
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(
+        res.Failure().reason.Str(),
+        R"(:14:24 error: let: %value cannot be used in continuing block as it is declared after the first 'continue' in the loop's body
+        %use:i32 = let %value
+                       ^^^^^^
+
+:4:7 note: in block
+      $B2: {  # body
+      ^^^
+
+:10:9 note: %value declared here
+        %value:i32 = let 1i
+        ^^^^^^^^^^
+
+:7:13 note: loop body's first 'continue'
+            continue  # -> $B3
+            ^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void {
+  $B1: {
+    loop [b: $B2, c: $B3] {  # loop_1
+      $B2: {  # body
+        if true [t: $B4] {  # if_1
+          $B4: {  # true
+            continue  # -> $B3
+          }
+        }
+        %value:i32 = let 1i
+        exit_loop  # loop_1
+      }
+      $B3: {  # continuing
+        %use:i32 = let %value
+        next_iteration  # -> $B2
+      }
+    }
+    ret
+  }
+}
+)");
+}
+
 TEST_F(IR_ValidatorTest, ExitLoop) {
     auto* loop = b.Loop();
     loop->Continuing()->Append(b.NextIteration(loop));
diff --git a/src/tint/lang/core/ir/value.cc b/src/tint/lang/core/ir/value.cc
index c1a66e8..af66b0e 100644
--- a/src/tint/lang/core/ir/value.cc
+++ b/src/tint/lang/core/ir/value.cc
@@ -44,7 +44,7 @@
     flags_.Add(Flag::kDead);
 }
 
-void Value::ForEachUse(std::function<void(Usage use)> func) {
+void Value::ForEachUse(std::function<void(Usage use)> func) const {
     auto uses = uses_;
     for (auto& use : uses) {
         func(use);
diff --git a/src/tint/lang/core/ir/value.h b/src/tint/lang/core/ir/value.h
index 0603e87..e6f1bad 100644
--- a/src/tint/lang/core/ir/value.h
+++ b/src/tint/lang/core/ir/value.h
@@ -105,7 +105,7 @@
 
     /// Apply a function to all uses of the value that exist prior to calling this method.
     /// @param func the function will be applied to each use
-    void ForEachUse(std::function<void(Usage use)> func);
+    void ForEachUse(std::function<void(Usage use)> func) const;
 
     /// Replace all uses of the value.
     /// @param replacer a function which returns a replacement for a given use