resolver: Validate pipeline stage use for intrinsics

Use the new [[stage()]] decorations in intrinsics.def to validate that intrinsics are only called from the correct pipeline stages.

Fixed: tint:657
Change-Id: I9efda26369c45c6f816bdaa53408d3909db403a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53084
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 2ac42f9..7b59942 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -237,6 +237,10 @@
     }
   }
 
+  if (!ValidatePipelineStages()) {
+    return false;
+  }
+
   bool result = true;
 
   for (auto* node : builder_->ASTNodes().Objects()) {
@@ -1129,6 +1133,10 @@
 bool Resolver::Function(ast::Function* func) {
   auto* info = function_infos_.Create<FunctionInfo>(func);
 
+  if (func->IsEntryPoint()) {
+    entry_points_.emplace_back(info);
+  }
+
   TINT_SCOPED_ASSIGNMENT(current_function_, info);
 
   variable_stack_.push_scope();
@@ -1707,6 +1715,10 @@
   builder_->Sem().Add(
       call, builder_->create<sem::Call>(call, result, current_statement_));
   SetType(call, result->ReturnType());
+
+  current_function_->intrinsic_calls.emplace_back(
+      IntrinsicCallInfo{call, result});
+
   return true;
 }
 
@@ -2460,25 +2472,76 @@
   expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_});
 }
 
+bool Resolver::ValidatePipelineStages() {
+  // Returns the declared name of the function described by `fn`.
+  auto name_of = [&](FunctionInfo* fn) {
+    return builder_->Symbols().NameFor(fn->declaration->symbol());
+  };
+  // Checks that every intrinsic called directly by `fn` supports the pipeline
+  // stage of entry point `ep`. On the first unsupported call, emits an error
+  // (plus call-chain notes when fn != ep) and returns false.
+  auto validate_calls = [&](FunctionInfo* fn, FunctionInfo* ep) -> bool {
+    const auto ep_stage = ep->declaration->pipeline_stage();
+    for (auto& info : fn->intrinsic_calls) {
+      if (info.intrinsic->SupportedStages().Contains(ep_stage)) {
+        continue;
+      }
+      std::stringstream msg;
+      msg << "built-in cannot be used by " << ep_stage << " pipeline stage";
+      diagnostics_.add_error(msg.str(), info.call->source());
+      if (fn == ep) {
+        return false;  // Called straight from the entry point: no chain.
+      }
+      TraverseCallChain(ep, fn, [&](FunctionInfo* caller) {
+        diagnostics_.add_note("called by function '" + name_of(caller) + "'",
+                              caller->declaration->source());
+      });
+      diagnostics_.add_note("called by entry point '" + name_of(ep) + "'",
+                            ep->declaration->source());
+      return false;
+    }
+    return true;
+  };
+  // Validate each entry point itself, then everything it transitively calls.
+  for (auto* ep : entry_points_) {
+    if (!validate_calls(ep, ep)) {
+      return false;
+    }
+    for (auto* callee : ep->transitive_calls) {
+      if (!validate_calls(callee, ep)) return false;
+    }
+  }
+  return true;
+}
+
+template <typename CALLBACK>
+void Resolver::TraverseCallChain(FunctionInfo* from,
+                                 FunctionInfo* to,
+                                 CALLBACK&& callback) const {
+  // Find a callee of `from` that is, or transitively calls, `to`. `callback`
+  // fires on `to` first, then on each intermediate function while unwinding.
+  for (auto* callee : from->transitive_calls) {
+    const bool is_target = callee == to;
+    if (!is_target && !callee->transitive_calls.contains(to)) {
+      continue;  // `to` is not reachable through this callee.
+    }
+    if (!is_target) TraverseCallChain(callee, to, callback);
+    callback(callee);
+    return;
+  }
+  TINT_ICE(diagnostics_)
+      << "TraverseCallChain() 'from' does not transitively call 'to'";
+}
+
 void Resolver::CreateSemanticNodes() const {
   auto& sem = builder_->Sem();
 
   // Collate all the 'ancestor_entry_points' - this is a map of function
   // symbol to all the entry points that transitively call the function.
   std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
-  for (auto* func : builder_->AST().Functions()) {
-    auto it = function_to_info_.find(func);
-    if (it == function_to_info_.end()) {
-      continue;  // Resolver has likely errored. Process what we can.
-    }
-
-    auto* info = it->second;
-    if (!func->IsEntryPoint()) {
-      continue;
-    }
-    for (auto* call : info->transitive_calls) {
+  for (auto* entry_point : entry_points_) {
+    for (auto* call : entry_point->transitive_calls) {
       auto& vec = ancestor_entry_points[call->declaration->symbol()];
-      vec.emplace_back(func->symbol());
+      vec.emplace_back(entry_point->declaration->symbol());
     }
   }