[tint][ir] Validate blocks without using recursion
Fixes stack overflows of deeply nested blocks (easily done with long if-else chains)
Change-Id: I4b315c8922dd2e8b16c63fbbf81094a3e35accd9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/186641
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 31d0cb8..21d8213 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -32,12 +32,13 @@
#include <string>
#include <utility>
-#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/intrinsic/table.h"
#include "src/tint/lang/core/ir/access.h"
#include "src/tint/lang/core/ir/binary.h"
#include "src/tint/lang/core/ir/bitcast.h"
+#include "src/tint/lang/core/ir/block_param.h"
#include "src/tint/lang/core/ir/break_if.h"
+#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
#include "src/tint/lang/core/ir/convert.h"
@@ -48,7 +49,9 @@
#include "src/tint/lang/core/ir/exit_loop.h"
#include "src/tint/lang/core/ir/exit_switch.h"
#include "src/tint/lang/core/ir/function.h"
+#include "src/tint/lang/core/ir/function_param.h"
#include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/instruction_result.h"
#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
#include "src/tint/lang/core/ir/load_vector_element.h"
@@ -72,8 +75,11 @@
#include "src/tint/lang/core/type/type.h"
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/core/type/void.h"
+#include "src/tint/utils/containers/hashset.h"
#include "src/tint/utils/containers/reverse.h"
#include "src/tint/utils/containers/transform.h"
+#include "src/tint/utils/ice/ice.h"
+#include "src/tint/utils/macros/defer.h"
#include "src/tint/utils/macros/scoped_assignment.h"
#include "src/tint/utils/rtti/switch.h"
#include "src/tint/utils/text/styled_text.h"
@@ -235,10 +241,6 @@
/// @param func the function validate
void CheckFunction(const Function* func);
- /// Validates the given block
- /// @param blk the block to validate
- void CheckBlock(const Block* blk);
-
/// Validates the given instruction
/// @param inst the instruction to validate
void CheckInstruction(const Instruction* inst);
@@ -338,15 +340,33 @@
/// @returns the vector pointer type for the given instruction operand
const core::type::Type* GetVectorPtrElementType(const Instruction* inst, size_t idx);
- private:
+ /// Executes all the pending tasks
+ void ProcessTasks();
+
+ /// Queues the block to be validated with ProcessTasks()
+ /// @param blk the block to validate
+ void QueueBlock(const Block* blk);
+
+ /// Queues the list of instructions starting with @p inst to be validated
+ /// @param inst the first instruction
+ void QueueInstructions(const Instruction* inst);
+
+ /// Begins validation of the block @p blk, and its instructions.
+ /// Must be paired with a call to EndBlock().
+ void BeginBlock(const Block* blk);
+
+ /// Ends validation of the block opened with BeginBlock().
+ void EndBlock();
+
const Module& mod_;
Capabilities capabilities_;
std::optional<ir::Disassembly> disassembly_; // Use Disassembly()
diag::List diagnostics_;
- const Block* current_block_ = nullptr;
Hashset<const Function*, 4> all_functions_;
Hashset<const Instruction*, 4> visited_instructions_;
Vector<const ControlInstruction*, 8> control_stack_;
+ Vector<const Block*, 8> block_stack_;
+ Vector<std::function<void()>, 16> tasks_;
};
Validator::Validator(const Module& mod, Capabilities capabilities)
@@ -354,7 +374,7 @@
Validator::~Validator() = default;
-ir::Disassembly& Validator::Disassembly() {
+Disassembly& Validator::Disassembly() {
if (!disassembly_) {
disassembly_.emplace(Disassemble(mod_));
}
@@ -362,6 +382,11 @@
}
Result<SuccessType> Validator::Run() {
+ TINT_DEFER({
+ TINT_ASSERT(tasks_.IsEmpty());
+ TINT_ASSERT(control_stack_.IsEmpty());
+ TINT_ASSERT(block_stack_.IsEmpty());
+ });
CheckRootBlock(mod_.root_block);
for (auto& func : mod_.functions) {
@@ -396,8 +421,8 @@
auto src = Disassembly().InstructionSource(inst);
auto& diag = AddError(src) << inst->FriendlyName() << ": ";
- if (current_block_) {
- AddNote(current_block_) << "in block";
+ if (!block_stack_.IsEmpty()) {
+ AddNote(block_stack_.Back()) << "in block";
}
return diag;
}
@@ -407,10 +432,9 @@
Disassembly().OperandSource(Disassembly::IndexedValue{inst, static_cast<uint32_t>(idx)});
auto& diag = AddError(src) << inst->FriendlyName() << ": ";
- if (current_block_) {
- AddNote(current_block_) << "in block";
+ if (!block_stack_.IsEmpty()) {
+ AddNote(block_stack_.Back()) << "in block";
}
-
return diag;
}
@@ -419,8 +443,8 @@
Disassembly().ResultSource(Disassembly::IndexedValue{inst, static_cast<uint32_t>(idx)});
auto& diag = AddError(src) << inst->FriendlyName() << ": ";
- if (current_block_) {
- AddNote(current_block_) << "in block";
+ if (!block_stack_.IsEmpty()) {
+ AddNote(block_stack_.Back()) << "in block";
}
return diag;
}
@@ -504,7 +528,8 @@
}
void Validator::CheckRootBlock(const Block* blk) {
- TINT_SCOPED_ASSIGNMENT(current_block_, blk);
+ block_stack_.Push(blk);
+ TINT_DEFER(block_stack_.Pop());
for (auto* inst : *blk) {
if (inst->Block() != blk) {
@@ -521,8 +546,6 @@
}
void Validator::CheckFunction(const Function* func) {
- CheckBlock(func->Block());
-
for (auto* param : func->Params()) {
if (!param->Alive()) {
AddError(param) << "destroyed parameter found in function parameter list";
@@ -545,10 +568,24 @@
if (HoldsType<type::Reference>(func->ReturnType())) {
AddError(func) << "references are not permitted as return types";
}
+
+ QueueBlock(func->Block());
+ ProcessTasks();
}
-void Validator::CheckBlock(const Block* blk) {
- TINT_SCOPED_ASSIGNMENT(current_block_, blk);
+void Validator::ProcessTasks() {
+ while (!tasks_.IsEmpty()) {
+ tasks_.Pop()();
+ }
+}
+
+void Validator::QueueBlock(const Block* blk) {
+ tasks_.Push([this] { EndBlock(); });
+ tasks_.Push([this, blk] { BeginBlock(blk); });
+}
+
+void Validator::BeginBlock(const Block* blk) {
+ block_stack_.Push(blk);
if (auto* mb = blk->As<MultiInBlock>()) {
for (auto* param : mb->Params()) {
@@ -568,22 +605,39 @@
}
if (!blk->Terminator()) {
- AddError(blk) << "block: does not end in a terminator instruction";
+ AddError(blk) << "block does not end in a terminator instruction";
}
+ // Validate the instructions w.r.t. the parent block
for (auto* inst : *blk) {
if (inst->Block() != blk) {
AddError(inst) << "block instruction does not have same block as parent";
- AddNote(current_block_) << "in block";
+ AddNote(blk) << "in block";
continue;
}
if (inst->Is<ir::Terminator>() && inst != blk->Terminator()) {
- AddError(inst) << "block: terminator which isn't the final instruction";
+ AddError(inst) << "block terminator which isn't the final instruction";
continue;
}
-
- CheckInstruction(inst);
}
+
+ // Enqueue validation of the instructions of the block
+ if (!blk->IsEmpty()) {
+ QueueInstructions(blk->Instructions());
+ }
+}
+
+void Validator::EndBlock() {
+ block_stack_.Pop();
+}
+
+void Validator::QueueInstructions(const Instruction* inst) {
+ tasks_.Push([this, inst] {
+ CheckInstruction(inst);
+ if (inst->next) {
+ QueueInstructions(inst->next);
+ }
+ });
}
void Validator::CheckInstruction(const Instruction* inst) {
@@ -624,10 +678,10 @@
// for `nullptr` here.
if (!op->Alive()) {
AddError(inst, i) << "operand is not alive";
- }
-
- if (!op->HasUsage(inst, i)) {
+ } else if (!op->HasUsage(inst, i)) {
AddError(inst, i) << "operand missing usage";
+ } else if (auto fn = op->As<Function>(); fn && !all_functions_.Contains(fn)) {
+ AddError(inst, i) << NameOf(op) << " is not part of the module";
}
if (!capabilities_.Contains(Capability::kAllowRefTypes)) {
@@ -713,10 +767,6 @@
}
void Validator::CheckUserCall(const UserCall* call) {
- if (!all_functions_.Contains(call->Target())) {
- AddError(call, UserCall::kFunctionOperandOffset) << "call target is not part of the module";
- }
-
if (call->Target()->Stage() != Function::PipelineStage::kUndefined) {
AddError(call, UserCall::kFunctionOperandOffset)
<< "call target must not have a pipeline stage";
@@ -904,36 +954,50 @@
AddError(if_, If::kConditionOperandOffset) << "condition must be a `bool` type";
}
- control_stack_.Push(if_);
- TINT_DEFER(control_stack_.Pop());
+ tasks_.Push([this] { control_stack_.Pop(); });
- CheckBlock(if_->True());
if (!if_->False()->IsEmpty()) {
- CheckBlock(if_->False());
+ QueueBlock(if_->False());
}
+
+ QueueBlock(if_->True());
+
+ tasks_.Push([this, if_] { control_stack_.Push(if_); });
}
void Validator::CheckLoop(const Loop* l) {
- control_stack_.Push(l);
- TINT_DEFER(control_stack_.Pop());
-
+ // Note: Tasks are queued in reverse order of their execution
+ tasks_.Push([this] { control_stack_.Pop(); });
if (!l->Initializer()->IsEmpty()) {
- CheckBlock(l->Initializer());
+ tasks_.Push([this] { EndBlock(); });
}
- CheckBlock(l->Body());
+ tasks_.Push([this] { EndBlock(); });
+ if (!l->Continuing()->IsEmpty()) {
+ tasks_.Push([this] { EndBlock(); });
+ }
+
+ // ⎡Initializer ⎤
+ // ⎢ ⎡Body ⎤⎥
+ // ⎣ ⎣ [Continuing ] ⎦⎦
if (!l->Continuing()->IsEmpty()) {
- CheckBlock(l->Continuing());
+ tasks_.Push([this, l] { BeginBlock(l->Continuing()); });
}
+ tasks_.Push([this, l] { BeginBlock(l->Body()); });
+ if (!l->Initializer()->IsEmpty()) {
+ tasks_.Push([this, l] { BeginBlock(l->Initializer()); });
+ }
+ tasks_.Push([this, l] { control_stack_.Push(l); });
}
void Validator::CheckSwitch(const Switch* s) {
- control_stack_.Push(s);
- TINT_DEFER(control_stack_.Pop());
+ tasks_.Push([this] { control_stack_.Pop(); });
for (auto& cse : s->Cases()) {
- CheckBlock(cse.block);
+ QueueBlock(cse.block);
}
+
+ tasks_.Push([this, s] { control_stack_.Push(s); });
}
void Validator::CheckTerminator(const Terminator* b) {
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 9c321a7..7c03112 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -253,7 +253,7 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:3:20 error: call: call target is not part of the module
+ R"(:3:20 error: call: %g is not part of the module
%2:void = call %g
^^
@@ -424,7 +424,7 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:2:3 error: block: does not end in a terminator instruction
+ R"(:2:3 error: block does not end in a terminator instruction
$B1: {
^^^
@@ -1123,7 +1123,7 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:3:5 error: return: block: terminator which isn't the final instruction
+ R"(:3:5 error: return: block terminator which isn't the final instruction
ret
^^^
@@ -1165,7 +1165,7 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:4:7 error: block: does not end in a terminator instruction
+ R"(:4:7 error: block does not end in a terminator instruction
$B2: { # true
^^^
@@ -1321,7 +1321,7 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:4:7 error: block: does not end in a terminator instruction
+ R"(:4:7 error: block does not end in a terminator instruction
$B2: { # body
^^^