Import Tint changes from Dawn
Changes:
- 651732fc0e32bc2c41d5ca6ecdef44c25c92f3d0 [tint][ast] Fix RemovePhonies transform with short-circui... by Ben Clayton <bclayton@google.com>
- 12f92c5271b61be7e3ae122c1dd0b2ba5b3f2d25 [tint][ir] Split BreakIf arguments into two lists by Ben Clayton <bclayton@google.com>
- 203ef75874bd4e045a1a26bac6a6b99db9a7600b [tint][resolver] Fix evaluation stage of function calls by Ben Clayton <bclayton@google.com>
- cabf62259d57b046b38d305cceac1d84893d3329 [tint][ir] Validate NextIteration instructions by Ben Clayton <bclayton@google.com>
- 7b35ff1d2a59d182701a98c906b61ef8756d8840 [tint][ir] Validate continue statements by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 651732fc0e32bc2c41d5ca6ecdef44c25c92f3d0
Change-Id: I208ec105b4bf372971ec0bd199c5ae33999c82bb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/187881
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index eb320b2..c422433 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -341,6 +341,11 @@
}
inst_out->SetResults(std::move(results));
+ if (inst_in.has_break_if()) {
+ static_cast<BreakIf*>(inst_out)->SetNumNextIterValues(
+ inst_in.break_if().num_next_iter_values());
+ }
+
return inst_out;
}
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index 70305bd..74d00ac 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -253,7 +253,10 @@
void InstructionBitcast(pb::InstructionBitcast&, const ir::Bitcast*) {}
- void InstructionBreakIf(pb::InstructionBreakIf&, const ir::BreakIf*) {}
+ void InstructionBreakIf(pb::InstructionBreakIf& breakif_out, const ir::BreakIf* breakif_in) {
+ auto num_next_iter_values = static_cast<uint32_t>(breakif_in->NextIterValues().Length());
+ breakif_out.set_num_next_iter_values(num_next_iter_values);
+ }
void InstructionBuiltinCall(pb::InstructionBuiltinCall& call_out,
const ir::CoreBuiltinCall* call_in) {
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 19f3b0a..31c2821 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -354,7 +354,9 @@
message InstructionContinue {}
-message InstructionBreakIf {}
+message InstructionBreakIf {
+ uint32 num_next_iter_values = 1;
+}
message InstructionUnreachable {}
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 52a516b..76f92b8 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -698,7 +698,10 @@
TEST_F(IRBinaryRoundtripTest, LoopBlockParams) {
auto* fn = b.Function("Function", ty.void_());
b.Append(fn->Block(), [&] {
+ auto* loop_res_a = b.InstructionResult(ty.i32());
+ auto* loop_res_b = b.InstructionResult(ty.f32());
auto* loop = b.Loop();
+ loop->SetResults(Vector{loop_res_a, loop_res_b});
b.Append(loop->Initializer(), [&] {
b.Let("L", 1_i);
b.NextIteration(loop);
@@ -710,7 +713,12 @@
auto* z = b.BlockParam<u32>("z");
auto* w = b.BlockParam<bool>("w");
loop->Continuing()->SetParams({z, w});
- b.Append(loop->Continuing(), [&] { b.BreakIf(loop, false, 3_i, 4_f); });
+ b.Append(loop->Continuing(), [&] {
+ b.BreakIf(loop,
+ /* condition */ false,
+ /* next iter */ b.Values(3_i, 4_f),
+ /* exit */ b.Values(5_u, 6_i));
+ });
b.Return(fn);
});
RUN_TEST();
diff --git a/src/tint/lang/core/ir/break_if.cc b/src/tint/lang/core/ir/break_if.cc
index 53a7145..efb2990 100644
--- a/src/tint/lang/core/ir/break_if.cc
+++ b/src/tint/lang/core/ir/break_if.cc
@@ -42,11 +42,16 @@
BreakIf::BreakIf() = default;
-BreakIf::BreakIf(Value* condition, ir::Loop* loop, VectorRef<Value*> args) : loop_(loop) {
+BreakIf::BreakIf(Value* condition,
+ ir::Loop* loop,
+ VectorRef<Value*> next_iter_values /* = tint::Empty */,
+ VectorRef<Value*> exit_values /* = tint::Empty */)
+ : loop_(loop), num_next_iter_values_(next_iter_values.Length()) {
TINT_ASSERT(loop_);
AddOperand(BreakIf::kConditionOperandOffset, condition);
- AddOperands(BreakIf::kArgsOperandOffset, std::move(args));
+ AddOperands(BreakIf::kArgsOperandOffset, std::move(next_iter_values));
+ AddOperands(BreakIf::kArgsOperandOffset + num_next_iter_values_, std::move(exit_values));
if (loop_) {
loop_->Body()->AddInboundSiblingBranch(this);
diff --git a/src/tint/lang/core/ir/break_if.h b/src/tint/lang/core/ir/break_if.h
index 85508e2..090f22a 100644
--- a/src/tint/lang/core/ir/break_if.h
+++ b/src/tint/lang/core/ir/break_if.h
@@ -57,8 +57,14 @@
/// Constructor
/// @param condition the break condition
/// @param loop the loop containing the break-if
- /// @param args the MultiInBlock arguments
- BreakIf(Value* condition, ir::Loop* loop, VectorRef<Value*> args = tint::Empty);
+ /// @param next_iter_values the arguments passed to the loop body MultiInBlock, if the break
+ /// condition evaluates to `false`.
+ /// @param exit_values the values returned by the loop, if the break condition evaluates to
+ /// `true`.
+ BreakIf(Value* condition,
+ ir::Loop* loop,
+ VectorRef<Value*> next_iter_values = tint::Empty,
+ VectorRef<Value*> exit_values = tint::Empty);
~BreakIf() override;
/// @copydoc Instruction::Clone()
@@ -85,8 +91,39 @@
/// @returns the friendly name for the instruction
std::string FriendlyName() const override { return "break_if"; }
+ /// @returns the arguments passed to the loop body MultiInBlock, if the break condition
+ /// evaluates to `false`.
+ Slice<Value* const> NextIterValues() {
+ return operands_.Slice().Offset(kArgsOperandOffset).Truncate(num_next_iter_values_);
+ }
+
+ /// @returns the arguments passed to the loop body MultiInBlock, if the break condition
+ /// evaluates to `false`.
+ Slice<const Value* const> NextIterValues() const {
+ return operands_.Slice().Offset(kArgsOperandOffset).Truncate(num_next_iter_values_);
+ }
+
+ /// @returns the values returned by the loop, if the break condition evaluates to `true`.
+ Slice<Value* const> ExitValues() {
+ return operands_.Slice().Offset(kArgsOperandOffset + num_next_iter_values_);
+ }
+
+ /// @returns the values returned by the loop, if the break condition evaluates to `true`.
+ Slice<const Value* const> ExitValues() const {
+ return operands_.Slice().Offset(kArgsOperandOffset + num_next_iter_values_);
+ }
+
+ /// Sets the number of operands used as the next iterator values.
+ /// The first @p num operands after kArgsOperandOffset are used as next iterator values,
+ /// subsequent operators are used as exit values.
+ void SetNumNextIterValues(size_t num) {
+ TINT_ASSERT(operands_.Length() >= num + kArgsOperandOffset);
+ num_next_iter_values_ = num;
+ }
+
private:
ConstPropagatingPtr<ir::Loop> loop_;
+ size_t num_next_iter_values_ = 0;
};
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 84fc0ed..5fbbb14 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -1304,14 +1304,32 @@
/// Creates a loop break-if instruction
/// @param condition the break condition
/// @param loop the loop being iterated
- /// @param args the arguments for the target MultiInBlock
/// @returns the instruction
- template <typename CONDITION, typename... ARGS>
- ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition, ARGS&&... args) {
- CheckForNonDeterministicEvaluation<CONDITION, ARGS...>();
+ template <typename CONDITION>
+ ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition) {
+ CheckForNonDeterministicEvaluation<CONDITION>();
+ auto* cond_val = Value(std::forward<CONDITION>(condition));
+ return Append(ir.allocators.instructions.Create<ir::BreakIf>(cond_val, loop));
+ }
+
+ /// Creates a loop break-if instruction
+ /// @param condition the break condition
+ /// @param loop the loop being iterated
+ /// @param next_iter_values the arguments passed to the loop body MultiInBlock, if the break
+ /// condition evaluates to `false`.
+ /// @param exit_values the values returned by the loop, if the break condition evaluates to
+ /// `true`.
+ /// @returns the instruction
+ template <typename CONDITION, typename NEXT_ITER_VALUES, typename EXIT_VALUES>
+ ir::BreakIf* BreakIf(ir::Loop* loop,
+ CONDITION&& condition,
+ NEXT_ITER_VALUES&& next_iter_values,
+ EXIT_VALUES&& exit_values) {
+ CheckForNonDeterministicEvaluation<CONDITION, NEXT_ITER_VALUES, EXIT_VALUES>();
auto* cond_val = Value(std::forward<CONDITION>(condition));
return Append(ir.allocators.instructions.Create<ir::BreakIf>(
- cond_val, loop, Values(std::forward<ARGS>(args)...)));
+ cond_val, loop, Values(std::forward<NEXT_ITER_VALUES>(next_iter_values)),
+ Values(std::forward<EXIT_VALUES>(exit_values))));
}
/// Creates a continue instruction
diff --git a/src/tint/lang/core/ir/disassembly.cc b/src/tint/lang/core/ir/disassembly.cc
index df0cc4e..1a8a9a6 100644
--- a/src/tint/lang/core/ir/disassembly.cc
+++ b/src/tint/lang/core/ir/disassembly.cc
@@ -26,7 +26,10 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/core/ir/disassembly.h"
+
+#include <algorithm>
#include <memory>
+#include <optional>
#include <string_view>
#include "src//tint/lang/core/ir/unary.h"
@@ -577,6 +580,16 @@
}
}
+void Disassembly::EmitOperandList(const Instruction* inst, size_t start_index, size_t count) {
+ size_t n = std::min(count, inst->Operands().Length());
+ for (size_t i = start_index; i < n; i++) {
+ if (i != start_index) {
+ out_ << ", ";
+ }
+ EmitOperand(inst, i);
+ }
+}
+
void Disassembly::EmitIf(const If* if_) {
SourceMarker sm(this);
if (auto results = if_->Results(); !results.IsEmpty()) {
@@ -730,52 +743,74 @@
out_ << "}";
}
-void Disassembly::EmitTerminator(const Terminator* b) {
+void Disassembly::EmitTerminator(const Terminator* term) {
SourceMarker sm(this);
- size_t args_offset = 0;
- tint::Switch(
- b,
+ auto args_offset = tint::Switch<std::optional<size_t>>(
+ term,
[&](const ir::Return*) {
out_ << StyleInstruction("ret");
- args_offset = ir::Return::kArgsOperandOffset;
+ return ir::Return::kArgsOperandOffset;
},
[&](const ir::Continue*) {
out_ << StyleInstruction("continue");
- args_offset = ir::Continue::kArgsOperandOffset;
+ return ir::Continue::kArgsOperandOffset;
},
[&](const ir::ExitIf*) {
out_ << StyleInstruction("exit_if");
- args_offset = ir::ExitIf::kArgsOperandOffset;
+ return ir::ExitIf::kArgsOperandOffset;
},
[&](const ir::ExitSwitch*) {
out_ << StyleInstruction("exit_switch");
- args_offset = ir::ExitSwitch::kArgsOperandOffset;
+ return ir::ExitSwitch::kArgsOperandOffset;
},
[&](const ir::ExitLoop*) {
out_ << StyleInstruction("exit_loop");
- args_offset = ir::ExitLoop::kArgsOperandOffset;
+ return ir::ExitLoop::kArgsOperandOffset;
},
[&](const ir::NextIteration*) {
out_ << StyleInstruction("next_iteration");
- args_offset = ir::NextIteration::kArgsOperandOffset;
+ return ir::NextIteration::kArgsOperandOffset;
},
- [&](const ir::Unreachable*) { out_ << StyleInstruction("unreachable"); },
+ [&](const ir::Unreachable*) {
+ out_ << StyleInstruction("unreachable");
+ return std::nullopt;
+ },
[&](const ir::BreakIf* bi) {
- out_ << StyleInstruction("break_if") << " ";
+ out_ << StyleInstruction("break_if");
+ out_ << " ";
EmitValue(bi->Condition());
- args_offset = ir::BreakIf::kArgsOperandOffset;
+ auto next_iter_values = bi->NextIterValues();
+ auto exit_values = bi->ExitValues();
+ if (!next_iter_values.IsEmpty()) {
+ out_ << " " << StyleLabel("next_iteration") << ": [ ";
+ EmitOperandList(bi, ir::BreakIf::kArgsOperandOffset, next_iter_values.Length());
+ out_ << " ]";
+ }
+ if (!exit_values.IsEmpty()) {
+ out_ << " " << StyleLabel("exit_loop") << ": [ ";
+ EmitOperandList(bi, ir::BreakIf::kArgsOperandOffset + next_iter_values.Length());
+ out_ << " ]";
+ }
+ return std::nullopt;
},
- [&](const ir::TerminateInvocation*) { out_ << StyleInstruction("terminate_invocation"); },
- [&](Default) { out_ << StyleError("unknown terminator ", b->TypeInfo().name); });
+ [&](const ir::TerminateInvocation*) {
+ out_ << StyleInstruction("terminate_invocation");
+ return std::nullopt;
+ },
+ [&](Default) {
+ out_ << StyleError("unknown terminator ", term->TypeInfo().name);
+ return std::nullopt;
+ });
- if (!b->Args().IsEmpty()) {
+ if (args_offset && !term->Args().IsEmpty()) {
out_ << " ";
- EmitOperandList(b, args_offset);
+ EmitOperandList(term, *args_offset);
}
- sm.Store(b);
+
+ sm.Store(term);
tint::Switch(
- b, //
+ term, //
[&](const ir::BreakIf* bi) {
out_ << " "
<< StyleComment("# -> [t: exit_loop ", NameOf(bi->Loop()),
diff --git a/src/tint/lang/core/ir/disassembly.h b/src/tint/lang/core/ir/disassembly.h
index 70c48b4..99dfced 100644
--- a/src/tint/lang/core/ir/disassembly.h
+++ b/src/tint/lang/core/ir/disassembly.h
@@ -247,6 +247,7 @@
void EmitLine();
void EmitOperand(const Instruction* inst, size_t index);
void EmitOperandList(const Instruction* inst, size_t start_index = 0);
+ void EmitOperandList(const Instruction* inst, size_t start_index, size_t count);
void EmitInstructionName(const Instruction* inst);
const Module& mod_;
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 3e0e704..891fc09 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -41,6 +41,7 @@
#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
+#include "src/tint/lang/core/ir/control_instruction.h"
#include "src/tint/lang/core/ir/convert.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/disassembly.h"
@@ -51,6 +52,7 @@
#include "src/tint/lang/core/ir/function.h"
#include "src/tint/lang/core/ir/function_param.h"
#include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/instruction.h"
#include "src/tint/lang/core/ir/instruction_result.h"
#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
@@ -76,6 +78,7 @@
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/core/type/void.h"
#include "src/tint/utils/containers/hashset.h"
+#include "src/tint/utils/containers/predicates.h"
#include "src/tint/utils/containers/reverse.h"
#include "src/tint/utils/containers/transform.h"
#include "src/tint/utils/ice/ice.h"
@@ -97,6 +100,24 @@
namespace {
+/// @returns the parent block of @p block
+const Block* ParentBlockOf(const Block* block) {
+ if (auto* parent = block->Parent()) {
+ return parent->Block();
+ }
+ return nullptr;
+}
+
+/// @returns true if @p block directly or transitively holds the instruction @p inst
+bool TransitivelyHolds(const Block* block, const Instruction* inst) {
+ for (auto* b = inst->Block(); b; b = ParentBlockOf(b)) {
+ if (b == block) {
+ return true;
+ }
+ }
+ return false;
+}
+
/// @returns true if the type @p type is of, or indirectly references a type of type `T`.
template <typename T>
bool HoldsType(const type::Type* type) {
@@ -294,6 +315,10 @@
/// @param l the loop to validate
void CheckLoop(const Loop* l);
+ /// Validates the loop continuing block
+ /// @param l the loop to validate
+ void CheckLoopContinuing(const Loop* l);
+
/// Validates the given switch
/// @param s the switch to validate
void CheckSwitch(const Switch* s);
@@ -302,10 +327,18 @@
/// @param b the terminator to validate
void CheckTerminator(const Terminator* b);
+ /// Validates the continue instruction
+ /// @param c the continue to validate
+ void CheckContinue(const Continue* c);
+
/// Validates the given exit
/// @param e the exit to validate
void CheckExit(const Exit* e);
+ /// Validates the next iteration instruction
+ /// @param n the next iteration to validate
+ void CheckNextIteration(const NextIteration* n);
+
/// Validates the given exit if
/// @param e the exit if to validate
void CheckExitIf(const ExitIf* e);
@@ -389,6 +422,7 @@
diag::List diagnostics_;
Hashset<const Function*, 4> all_functions_;
Hashset<const Instruction*, 4> visited_instructions_;
+ Hashmap<const Loop*, const Continue*, 4> first_continues_;
Vector<const ControlInstruction*, 8> control_stack_;
Vector<const Block*, 8> block_stack_;
ScopeStack scope_stack_;
@@ -1043,7 +1077,10 @@
void Validator::CheckLoop(const Loop* l) {
// Note: Tasks are queued in reverse order of their execution
- tasks_.Push([this] { control_stack_.Pop(); });
+ tasks_.Push([this, l] {
+ first_continues_.Remove(l); // No need for this any more. Free memory.
+ control_stack_.Pop();
+ });
if (!l->Initializer()->IsEmpty()) {
tasks_.Push([this] { EndBlock(); });
}
@@ -1057,8 +1094,12 @@
// ⎣ ⎣ [Continuing ] ⎦⎦
if (!l->Continuing()->IsEmpty()) {
- tasks_.Push([this, l] { BeginBlock(l->Continuing()); });
+ tasks_.Push([this, l] {
+ CheckLoopContinuing(l);
+ BeginBlock(l->Continuing());
+ });
}
+
tasks_.Push([this, l] { BeginBlock(l->Body()); });
if (!l->Initializer()->IsEmpty()) {
tasks_.Push([this, l] { BeginBlock(l->Initializer()); });
@@ -1066,6 +1107,41 @@
tasks_.Push([this, l] { control_stack_.Push(l); });
}
+void Validator::CheckLoopContinuing(const Loop* loop) {
+ if (!loop->HasContinuing()) {
+ return;
+ }
+
+ // Ensure that values used in the loop continuing are not from the loop body, after a
+ // continue instruction.
+ if (auto* first_continue = first_continues_.GetOr(loop, nullptr)) {
+ // Find the instruction in the body block that is or holds the first continue instruction.
+ const Instruction* holds_continue = first_continue;
+ while (holds_continue && holds_continue->Block() &&
+ holds_continue->Block() != loop->Body()) {
+ holds_continue = holds_continue->Block()->Parent();
+ }
+
+ // Check that all subsequent instruction values are not used in the continuing block.
+ for (auto* inst = holds_continue; inst; inst = inst->next) {
+ for (auto* result : inst->Results()) {
+ result->ForEachUse([&](Usage use) {
+ if (TransitivelyHolds(loop->Continuing(), use.instruction)) {
+ AddError(use.instruction, use.operand_index)
+ << NameOf(result)
+ << " cannot be used in continuing block as it is declared after the "
+ "first "
+ << style::Instruction("continue") << " in the loop's body";
+ AddDeclarationNote(result);
+ AddNote(first_continue)
+ << "loop body's first " << style::Instruction("continue");
+ }
+ });
+ }
+ }
+ }
+}
+
void Validator::CheckSwitch(const Switch* s) {
tasks_.Push([this] { control_stack_.Pop(); });
@@ -1081,17 +1157,34 @@
// DemoteToHelper) so we can't add validation.
tint::Switch(
- b, //
- [&](const ir::BreakIf*) {}, //
- [&](const ir::Continue*) {}, //
- [&](const ir::Exit* e) { CheckExit(e); }, //
- [&](const ir::NextIteration*) {}, //
- [&](const ir::Return* ret) { CheckReturn(ret); }, //
- [&](const ir::TerminateInvocation*) {}, //
- [&](const ir::Unreachable*) {}, //
+ b, //
+ [&](const ir::BreakIf*) {}, //
+ [&](const ir::Continue* c) { CheckContinue(c); }, //
+ [&](const ir::Exit* e) { CheckExit(e); }, //
+ [&](const ir::NextIteration* n) { CheckNextIteration(n); }, //
+ [&](const ir::Return* ret) { CheckReturn(ret); }, //
+ [&](const ir::TerminateInvocation*) {}, //
+ [&](const ir::Unreachable*) {}, //
[&](Default) { AddError(b) << "missing validation"; });
}
+void Validator::CheckContinue(const Continue* c) {
+ auto* loop = c->Loop();
+ if (loop == nullptr) {
+ AddError(c) << "has no associated loop";
+ return;
+ }
+ if (!TransitivelyHolds(loop->Body(), c)) {
+ if (control_stack_.Any(Eq<const ControlInstruction*>(loop))) {
+ AddError(c) << "must only be called from loop body";
+ } else {
+ AddError(c) << "called outside of associated loop";
+ }
+ }
+
+ first_continues_.Add(loop, c);
+}
+
void Validator::CheckExit(const Exit* e) {
if (e->ControlInstruction() == nullptr) {
AddError(e) << "has no parent control instruction";
@@ -1130,6 +1223,21 @@
[&](Default) { AddError(e) << "missing validation"; });
}
+void Validator::CheckNextIteration(const NextIteration* n) {
+ auto* loop = n->Loop();
+ if (loop == nullptr) {
+ AddError(n) << "has no associated loop";
+ return;
+ }
+ if (!TransitivelyHolds(loop->Initializer(), n) && !TransitivelyHolds(loop->Continuing(), n)) {
+ if (control_stack_.Any(Eq<const ControlInstruction*>(loop))) {
+ AddError(n) << "must only be called from loop initializer or continuing";
+ } else {
+ AddError(n) << "called outside of associated loop";
+ }
+ }
+}
+
void Validator::CheckExitIf(const ExitIf* e) {
if (control_stack_.Back() != e->If()) {
AddError(e) << "if target jumps over other control instructions";
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 2619338..b14f6c3 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -2603,6 +2603,298 @@
)");
}
+TEST_F(IR_ValidatorTest, ContinueOutsideOfLoop) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Continue(loop);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:8:5 error: continue: called outside of associated loop
+ continue # -> $B3
+ ^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2] { # loop_1
+ $B2: { # body
+ exit_loop # loop_1
+ }
+ }
+ continue # -> $B3
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopInit) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] { b.Continue(loop); });
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:5:9 error: continue: must only be called from loop body
+ continue # -> $B4
+ ^^^^^^^^
+
+:4:7 note: in block
+ $B2: { # initializer
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ continue # -> $B4
+ }
+ $B3: { # body
+ exit_loop # loop_1
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopBody) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.Continue(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, ContinueInLoopContinuing) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Append(loop->Continuing(), [&] { b.Continue(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:8:9 error: continue: must only be called from loop body
+ continue # -> $B3
+ ^^^^^^^^
+
+:7:7 note: in block
+ $B3: { # continuing
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2, c: $B3] { # loop_1
+ $B2: { # body
+ exit_loop # loop_1
+ }
+ $B3: { # continuing
+ continue # -> $B3
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, NextIterationOutsideOfLoop) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.NextIteration(loop);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:8:5 error: next_iteration: called outside of associated loop
+ next_iteration # -> $B2
+ ^^^^^^^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2] { # loop_1
+ $B2: { # body
+ exit_loop # loop_1
+ }
+ }
+ next_iteration # -> $B2
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, NextIterationInLoopInit) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] { b.NextIteration(loop); });
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, NextIterationInLoopBody) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.NextIteration(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:5:9 error: next_iteration: must only be called from loop initializer or continuing
+ next_iteration # -> $B2
+ ^^^^^^^^^^^^^^
+
+:4:7 note: in block
+ $B2: { # body
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2] { # loop_1
+ $B2: { # body
+ next_iteration # -> $B2
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, NextIterationInLoopContinuing) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Append(loop->Continuing(), [&] { b.NextIteration(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success);
+}
+
+TEST_F(IR_ValidatorTest, ContinuingUseValueBeforeContinue) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* value = b.Let("value", 1_i);
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] {
+ b.Append(value);
+ b.Append(b.If(true)->True(), [&] { b.Continue(loop); });
+ b.ExitLoop(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ b.Let("use", value);
+ b.NextIteration(loop);
+ });
+ b.Return(f);
+ });
+
+ ASSERT_EQ(ir::Validate(mod), Success);
+}
+
+TEST_F(IR_ValidatorTest, ContinuingUseValueAfterContinue) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* value = b.Let("value", 1_i);
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] {
+ b.Append(b.If(true)->True(), [&] { b.Continue(loop); });
+ b.Append(value);
+ b.ExitLoop(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ b.Let("use", value);
+ b.NextIteration(loop);
+ });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(
+ res.Failure().reason.Str(),
+ R"(:14:24 error: let: %value cannot be used in continuing block as it is declared after the first 'continue' in the loop's body
+ %use:i32 = let %value
+ ^^^^^^
+
+:4:7 note: in block
+ $B2: { # body
+ ^^^
+
+:10:9 note: %value declared here
+ %value:i32 = let 1i
+ ^^^^^^^^^^
+
+:7:13 note: loop body's first 'continue'
+ continue # -> $B3
+ ^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2, c: $B3] { # loop_1
+ $B2: { # body
+ if true [t: $B4] { # if_1
+ $B4: { # true
+ continue # -> $B3
+ }
+ }
+ %value:i32 = let 1i
+ exit_loop # loop_1
+ }
+ $B3: { # continuing
+ %use:i32 = let %value
+ next_iteration # -> $B2
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidatorTest, ExitLoop) {
auto* loop = b.Loop();
loop->Continuing()->Append(b.NextIteration(loop));
diff --git a/src/tint/lang/core/ir/value.cc b/src/tint/lang/core/ir/value.cc
index c1a66e8..af66b0e 100644
--- a/src/tint/lang/core/ir/value.cc
+++ b/src/tint/lang/core/ir/value.cc
@@ -44,7 +44,7 @@
flags_.Add(Flag::kDead);
}
-void Value::ForEachUse(std::function<void(Usage use)> func) {
+void Value::ForEachUse(std::function<void(Usage use)> func) const {
auto uses = uses_;
for (auto& use : uses) {
func(use);
diff --git a/src/tint/lang/core/ir/value.h b/src/tint/lang/core/ir/value.h
index 0603e87..e6f1bad 100644
--- a/src/tint/lang/core/ir/value.h
+++ b/src/tint/lang/core/ir/value.h
@@ -105,7 +105,7 @@
/// Apply a function to all uses of the value that exist prior to calling this method.
/// @param func the function will be applied to each use
- void ForEachUse(std::function<void(Usage use)> func);
+ void ForEachUse(std::function<void(Usage use)> func) const;
/// Replace all uses of the value.
/// @param replacer a function which returns a replacement for a given use
diff --git a/src/tint/lang/spirv/writer/loop_test.cc b/src/tint/lang/spirv/writer/loop_test.cc
index 633eb9b..1fb4f9e 100644
--- a/src/tint/lang/spirv/writer/loop_test.cc
+++ b/src/tint/lang/spirv/writer/loop_test.cc
@@ -358,7 +358,7 @@
loop->Continuing()->SetParams({cont_param});
b.Append(loop->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
- b.BreakIf(loop, cmp, cont_param);
+ b.BreakIf(loop, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
});
b.Return(func);
@@ -467,7 +467,7 @@
loop->Continuing()->SetParams({cont_param});
b.Append(loop->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
- b.BreakIf(loop, cmp, cont_param);
+ b.BreakIf(loop, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
});
b.Return(func);
@@ -533,7 +533,7 @@
outer->Continuing()->SetParams({cont_param});
b.Append(outer->Continuing(), [&] {
auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
- b.BreakIf(outer, cmp, cont_param);
+ b.BreakIf(outer, cmp, /* next_iter */ Vector{cont_param}, /* exit */ Empty);
});
b.Return(func);
diff --git a/src/tint/lang/spirv/writer/raise/merge_return_test.cc b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
index 40d44b0..9cb2f0a 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return_test.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
@@ -1698,7 +1698,7 @@
b.Append(loop->Continuing(), [&] {
b.Store(global, 1_i);
- b.BreakIf(loop, true, 4_i);
+ b.BreakIf(loop, true, /* next_iter */ b.Values(4_i), /* exit */ Empty);
});
b.Store(global, 3_i);
diff --git a/src/tint/lang/wgsl/ast/transform/remove_phonies.cc b/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
index ff33316c..2994e09 100644
--- a/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
+++ b/src/tint/lang/wgsl/ast/transform/remove_phonies.cc
@@ -32,6 +32,7 @@
#include <utility>
#include <vector>
+#include "src/tint/lang/core/evaluation_stage.h"
#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
@@ -85,7 +86,8 @@
return TraverseAction::Skip;
}
if (call->Target()->IsAnyOf<sem::Function, sem::BuiltinFn>() &&
- call->HasSideEffects()) {
+ call->HasSideEffects() &&
+ call->Stage() != core::EvaluationStage::kNotEvaluated) {
side_effects.push_back(expr);
return TraverseAction::Skip;
}
diff --git a/src/tint/lang/wgsl/ast/transform/remove_phonies.h b/src/tint/lang/wgsl/ast/transform/remove_phonies.h
index ecb6d6e..665dc7f 100644
--- a/src/tint/lang/wgsl/ast/transform/remove_phonies.h
+++ b/src/tint/lang/wgsl/ast/transform/remove_phonies.h
@@ -38,6 +38,8 @@
/// RemovePhonies is a Transform that removes all phony-assignment statements,
/// while preserving function call expressions in the RHS of the assignment that
/// may have side-effects. It also removes calls to builtins that return a constant value.
+/// @note RemovePhonies must be run after the PromoteSideEffectsToDecl transform, otherwise `f` in
+/// `_ = cond && f()` may get hoisted to a call statement without the short-circuiting conditional.
class RemovePhonies final : public Castable<RemovePhonies, Transform> {
public:
/// Constructor
diff --git a/src/tint/lang/wgsl/ast/transform/remove_phonies_test.cc b/src/tint/lang/wgsl/ast/transform/remove_phonies_test.cc
index af6c01d..dc1dbcd 100644
--- a/src/tint/lang/wgsl/ast/transform/remove_phonies_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/remove_phonies_test.cc
@@ -496,5 +496,30 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(RemovePhoniesTest, ConstShortCircuit) {
+ auto* src = R"(
+fn a(v : i32) -> i32 {
+ return v;
+}
+
+fn b() {
+ _ = false && (a(4294967295) < a(a(4294967295)));
+}
+)";
+
+ auto* expect = R"(
+fn a(v : i32) -> i32 {
+ return v;
+}
+
+fn b() {
+}
+)";
+
+ auto got = Run<RemovePhonies>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc b/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc
index 3a3714a..7873d8f 100644
--- a/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc
+++ b/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc
@@ -29,6 +29,7 @@
#include "gmock/gmock.h"
#include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
+#include "src/tint/utils/containers/slice.h"
namespace tint::resolver {
namespace {
@@ -355,5 +356,58 @@
EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
}
+TEST_F(ResolverEvaluationStageTest, FnCall_Runtime) {
+ // fn f() -> bool { return true; }
+ // let l = false
+ // let result = l && f();
+ Func("f", Empty, ty.bool_(), Vector{Return(true)});
+ auto* let = Let("l", Expr(false));
+ auto* lhs = Expr(let);
+ auto* rhs = Call("f");
+ auto* binary = LogicalAnd(lhs, rhs);
+ auto* result = Let("result", binary);
+ WrapInFunction(let, result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(rhs)->Stage(), core::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverEvaluationStageTest, FnCall_NotEvaluated) {
+ // fn f() -> bool { return true; }
+ // let result = false && f();
+ Func("f", Empty, ty.bool_(), Vector{Return(true)});
+ auto* rhs = Call("f");
+ auto* lhs = Expr(false);
+ auto* binary = LogicalAnd(lhs, rhs);
+ auto* result = Let("result", binary);
+ WrapInFunction(result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+}
+
+TEST_F(ResolverEvaluationStageTest, NestedFnCall_NotEvaluated) {
+ // fn f(b : bool) -> bool { return b; }
+ // let result = false && f(f(f(1 == 0)));
+ Func("f", Vector{Param("b", ty.bool_())}, ty.bool_(), Vector{Return("b")});
+ auto* cmp = Equal(0_i, 1_i);
+ auto* rhs_0 = Call("f", cmp);
+ auto* rhs_1 = Call("f", rhs_0);
+ auto* rhs_2 = Call("f", rhs_1);
+ auto* lhs = Expr(false);
+ auto* binary = LogicalAnd(lhs, rhs_2);
+ auto* result = Let("result", binary);
+ WrapInFunction(result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(cmp)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(rhs_0)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(rhs_1)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(rhs_2)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index a400979..34d3af3 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -3072,10 +3072,12 @@
return nullptr;
}
+ auto stage = skip_const_eval_.Contains(expr) ? core::EvaluationStage::kNotEvaluated
+ : core::EvaluationStage::kRuntime;
+
// TODO(crbug.com/tint/1420): For now, assume all function calls have side effects.
bool has_side_effects = true;
- auto* call = b.create<sem::Call>(expr, target, core::EvaluationStage::kRuntime, std::move(args),
- current_statement_,
+ auto* call = b.create<sem::Call>(expr, target, stage, std::move(args), current_statement_,
/* constant_value */ nullptr, has_side_effects);
target->AddCallSite(call);
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc
index 5e601cc..7b9b3cd 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc
@@ -2093,7 +2093,7 @@
b.Append(if2->True(), [&] { b.Return(fn, 1_i); });
b.Append(if2->False(), [&] { b.Return(fn, 2_i); });
- b.NextIteration(loop);
+ b.Continue(loop);
});
});