resolver: Validate pipeline stage use for intrinsics
Use the new [[stage()]] decorations in intrinsics.def to validate that intrinsics are only called from the correct pipeline stages.
Fixed: tint:657
Change-Id: I9efda26369c45c6f816bdaa53408d3909db403a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53084
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 2ac42f9..7b59942 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -237,6 +237,10 @@
}
}
+ if (!ValidatePipelineStages()) {
+ return false;
+ }
+
bool result = true;
for (auto* node : builder_->ASTNodes().Objects()) {
@@ -1129,6 +1133,10 @@
bool Resolver::Function(ast::Function* func) {
auto* info = function_infos_.Create<FunctionInfo>(func);
+ if (func->IsEntryPoint()) {
+ entry_points_.emplace_back(info);
+ }
+
TINT_SCOPED_ASSIGNMENT(current_function_, info);
variable_stack_.push_scope();
@@ -1707,6 +1715,10 @@
builder_->Sem().Add(
call, builder_->create<sem::Call>(call, result, current_statement_));
SetType(call, result->ReturnType());
+
+ current_function_->intrinsic_calls.emplace_back(
+ IntrinsicCallInfo{call, result});
+
return true;
}
@@ -2460,25 +2472,76 @@
expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_});
}
+// Checks that the intrinsics used by each entry point (directly or through its
+// callees) support its pipeline stage. Returns false on the first violation.
+bool Resolver::ValidatePipelineStages() {
+  // Returns false, after emitting diagnostics, on an unsupported call in `fn`.
+  auto validate_calls = [&](FunctionInfo* fn, FunctionInfo* ep) {
+    const auto ep_stage = ep->declaration->pipeline_stage();
+    for (auto& info : fn->intrinsic_calls) {
+      if (info.intrinsic->SupportedStages().Contains(ep_stage)) {
+        continue;
+      }
+      std::stringstream msg;
+      msg << "built-in cannot be used by " << ep_stage << " pipeline stage";
+      diagnostics_.add_error(msg.str(), info.call->source());
+      if (fn != ep) {
+        // One note per function on the path, innermost (`fn`) first.
+        TraverseCallChain(ep, fn, [&](FunctionInfo* f) {
+          auto name = builder_->Symbols().NameFor(f->declaration->symbol());
+          diagnostics_.add_note("called by function '" + name + "'",
+                                f->declaration->source());
+        });
+        auto ep_name = builder_->Symbols().NameFor(ep->declaration->symbol());
+        diagnostics_.add_note("called by entry point '" + ep_name + "'",
+                              ep->declaration->source());
+      }
+      return false;
+    }
+    return true;
+  };
+
+  for (auto* ep : entry_points_) {
+    bool ok = validate_calls(ep, ep);  // The entry point's own calls first.
+    for (auto* callee : ep->transitive_calls) {
+      ok = ok && validate_calls(callee, ep);
+    }
+    if (!ok) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Calls `callback` for each function on a call path from `from` to `to`,
+// innermost (`to`) first, excluding `from`. ICEs if no such path exists.
+template <typename CALLBACK>
+void Resolver::TraverseCallChain(FunctionInfo* from,
+                                 FunctionInfo* to,
+                                 CALLBACK&& callback) const {
+  for (auto* callee : from->transitive_calls) {
+    if (callee == to || callee->transitive_calls.contains(to)) {
+      if (callee != to) {
+        TraverseCallChain(callee, to, callback);  // Deeper hops first.
+      }
+      callback(callee);
+      return;
+    }
+  }
+  TINT_ICE(diagnostics_)
+      << "TraverseCallChain() 'from' does not transitively call 'to'";
+}
+
void Resolver::CreateSemanticNodes() const {
auto& sem = builder_->Sem();
// Collate all the 'ancestor_entry_points' - this is a map of function
// symbol to all the entry points that transitively call the function.
std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
- for (auto* func : builder_->AST().Functions()) {
- auto it = function_to_info_.find(func);
- if (it == function_to_info_.end()) {
- continue; // Resolver has likely errored. Process what we can.
- }
-
- auto* info = it->second;
- if (!func->IsEntryPoint()) {
- continue;
- }
- for (auto* call : info->transitive_calls) {
+ for (auto* entry_point : entry_points_) {
+ for (auto* call : entry_point->transitive_calls) {
auto& vec = ancestor_entry_points[call->declaration->symbol()];
- vec.emplace_back(func->symbol());
+ vec.emplace_back(entry_point->declaration->symbol());
}
}