tint: Show the reason for a uniformity requirement

When producing an error from the uniformity analysis, add notes to
show the underlying reason for the uniformity requirement.

For function calls that are required-to-be-uniform, show the innermost
builtin call that has the requirement.

For function parameters that are required-to-be-uniform, recurse into
that function to show where its requirement comes from.

Add some new tests to specifically test the error messages.

Bug: tint:880
Change-Id: Ib166fdeceaffb156a3afc50ebc5a4ad0860dc002
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89722
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index cf3ba17..521fec4 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -321,9 +321,9 @@
         // Look at which nodes are reachable from "RequiredToBeUniform".
         {
             utils::UniqueVector<Node*> reachable;
-            Traverse(current_function_->required_to_be_uniform, reachable);
+            Traverse(current_function_->required_to_be_uniform, &reachable);
             if (reachable.contains(current_function_->may_be_non_uniform)) {
-                MakeError();
+                MakeError(*current_function_, current_function_->may_be_non_uniform);
                 return false;
             }
             if (reachable.contains(current_function_->cf_start)) {
@@ -343,7 +343,7 @@
         // Look at which nodes are reachable from "CF_return"
         {
             utils::UniqueVector<Node*> reachable;
-            Traverse(current_function_->cf_return, reachable);
+            Traverse(current_function_->cf_return, &reachable);
             if (reachable.contains(current_function_->may_be_non_uniform)) {
                 current_function_->function_tag = SubsequentControlFlowMayBeNonUniform;
             }
@@ -362,7 +362,7 @@
         // If "Value_return" exists, look at which nodes are reachable from it
         if (current_function_->value_return) {
             utils::UniqueVector<Node*> reachable;
-            Traverse(current_function_->value_return, reachable);
+            Traverse(current_function_->value_return, &reachable);
             if (reachable.contains(current_function_->may_be_non_uniform)) {
                 current_function_->function_tag = ReturnValueMayBeNonUniform;
             }
@@ -388,7 +388,7 @@
             current_function_->ResetVisited();
 
             utils::UniqueVector<Node*> reachable;
-            Traverse(current_function_->parameters[i].pointer_return_value, reachable);
+            Traverse(current_function_->parameters[i].pointer_return_value, &reachable);
             if (reachable.contains(current_function_->may_be_non_uniform)) {
                 current_function_->parameters[i].pointer_may_become_non_uniform = true;
             }
@@ -1234,9 +1234,11 @@
     /// Recursively traverse a graph starting at `node`, inserting all nodes that are reached into
     /// `reachable`.
     /// @param node the starting node
-    /// @param reachable the set of reachable nodes to populate
-    void Traverse(Node* node, utils::UniqueVector<Node*>& reachable) {
-        reachable.add(node);
+    /// @param reachable the set of reachable nodes to populate, if required
+    void Traverse(Node* node, utils::UniqueVector<Node*>* reachable = nullptr) {
+        if (reachable) {
+            reachable->add(node);
+        }
         for (auto* to : node->edges) {
             if (to->visited_from == nullptr) {
                 to->visited_from = node;
@@ -1245,48 +1247,113 @@
         }
     }
 
-    /// Generate an error for a required_to_be_uniform->may_be_non_uniform path.
-    void MakeError() {
-        // Trace back to find a node that is required to be uniform that was reachable from a
-        // non-uniform value or control flow node.
-        Node* current = current_function_->may_be_non_uniform;
-        while (current) {
-            TINT_ASSERT(Resolver, current->visited_from);
-            if (current->visited_from == current_function_->required_to_be_uniform) {
+    /// Recursively descend through the function called by `call` and the functions that it calls in
+    /// order to find a call to a builtin function that requires uniformity.
+    const ast::CallExpression* FindBuiltinThatRequiresUniformity(const ast::CallExpression* call) {
+        auto* target = sem_.Get(call)->Target();
+        if (target->Is<sem::Builtin>()) {
+            // This is a call to a builtin, so we must be done.
+            return call;
+        } else if (auto* user = target->As<sem::Function>()) {
+            // This is a call to a user-defined function, so inspect the functions called by that
+            // function and look for one whose node has an edge from the RequiredToBeUniform node.
+            auto& target_info = functions_.at(user->Declaration());
+            for (auto* call_node : target_info.required_to_be_uniform->edges) {
+                if (call_node->arg_index == std::numeric_limits<uint32_t>::max()) {
+                    auto* child_call = call_node->ast->As<ast::CallExpression>();
+                    return FindBuiltinThatRequiresUniformity(child_call);
+                }
+            }
+            TINT_ASSERT(Resolver, false && "unable to find child call with uniformity requirement");
+        } else {
+            TINT_ASSERT(Resolver, false && "unexpected call expression type");
+        }
+        return nullptr;
+    }
+
+    /// Generate an error message for a uniformity issue.
+    /// @param function the function that the diagnostic is being produced for
+    /// @param source_node the node that has caused a uniformity issue in `function`
+    /// @param note `true` if the diagnostic should be emitted as a note
+    void MakeError(FunctionInfo& function, Node* source_node, bool note = false) {
+        // Helper to produce a diagnostic message with the severity required by this invocation of
+        // the `MakeError` function.
+        auto report = [&](Source source, std::string msg) {
+            // TODO(jrprice): Switch to error instead of warning when feedback has settled.
+            diag::Diagnostic error{};
+            error.severity = note ? diag::Severity::Note : diag::Severity::Warning;
+            error.system = diag::System::Resolver;
+            error.source = source;
+            error.message = msg;
+            diagnostics_.add(std::move(error));
+        };
+
+        // Traverse the graph to generate a path from RequiredToBeUniform to the source node.
+        function.ResetVisited();
+        Traverse(function.required_to_be_uniform);
+        TINT_ASSERT(Resolver, source_node->visited_from);
+
+        // Trace back through the graph to find a node that is required to be uniform that has
+        // a path to the source node.
+        Node* cause = source_node;
+        while (cause) {
+            if (cause->visited_from == function.required_to_be_uniform) {
                 break;
             }
-            current = current->visited_from;
+            cause = cause->visited_from;
         }
 
-        // The node will always have an corresponding call expression.
-        auto* call = current->ast->As<ast::CallExpression>();
+        // The node will always have a corresponding call expression.
+        auto* call = cause->ast->As<ast::CallExpression>();
         TINT_ASSERT(Resolver, call);
         auto* target = sem_.Get(call)->Target();
 
-        std::string name;
+        std::string func_name;
         if (auto* builtin = target->As<sem::Builtin>()) {
-            name = builtin->str();
+            func_name = builtin->str();
         } else if (auto* user = target->As<sem::Function>()) {
-            name = builder_->Symbols().NameFor(user->Declaration()->symbol);
+            func_name = builder_->Symbols().NameFor(user->Declaration()->symbol);
         }
 
-        // TODO(jrprice): Switch to error instead of warning when feedback has settled.
-        if (current->arg_index != std::numeric_limits<uint32_t>::max()) {
+        if (cause->arg_index != std::numeric_limits<uint32_t>::max()) {
             // The requirement was on a function parameter.
             auto param_name = builder_->Symbols().NameFor(
-                target->Parameters()[current->arg_index]->Declaration()->symbol);
-            diagnostics_.add_warning(
-                diag::System::Resolver,
-                "parameter '" + param_name + "' of '" + name + "' must be uniform",
-                call->args[current->arg_index]->source);
-            // TODO(jrprice): Show the reason why.
+                target->Parameters()[cause->arg_index]->Declaration()->symbol);
+            report(call->args[cause->arg_index]->source,
+                   "parameter '" + param_name + "' of '" + func_name + "' must be uniform");
+
+            // If this is a call to a user-defined function, add a note to show the reason that the
+            // parameter is required to be uniform.
+            if (auto* user = target->As<sem::Function>()) {
+                auto& next_function = functions_.at(user->Declaration());
+                Node* next_cause = next_function.parameters[cause->arg_index].init_value;
+                MakeError(next_function, next_cause, true);
+            }
         } else {
             // The requirement was on a function callsite.
-            diagnostics_.add_warning(diag::System::Resolver,
-                                     "'" + name + "' must only be called from uniform control flow",
-                                     call->source);
-            // TODO(jrprice): Show full call stack to the problematic builtin.
+            report(call->source,
+                   "'" + func_name + "' must only be called from uniform control flow");
+
+            // If this is a call to a user-defined function, add a note to show the builtin that
+            // causes the uniformity requirement.
+            auto* innermost_call = FindBuiltinThatRequiresUniformity(call);
+            if (innermost_call != call) {
+                // Determine whether the builtin is being called directly or indirectly.
+                bool indirect = false;
+                if (sem_.Get(call)->Target()->As<sem::Function>() !=
+                    sem_.Get(innermost_call)->Stmt()->Function()) {
+                    indirect = true;
+                }
+
+                auto* builtin = sem_.Get(innermost_call)->Target()->As<sem::Builtin>();
+                diagnostics_.add_note(diag::System::Resolver,
+                                      "'" + func_name + "' requires uniformity because it " +
+                                          (indirect ? "indirectly " : "") + "calls " +
+                                          builtin->str(),
+                                      innermost_call->source);
+            }
         }
+        // TODO(jrprice): Show the source of non-uniformity.
     }
 };
 
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index bb33f89..d81d7d6 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -466,6 +466,10 @@
               R"(test:11:7 warning: parameter 'i' of 'foo' must be uniform
   foo(rw);
       ^^
+
+test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
+    workgroupBarrier();
+    ^^^^^^^^^^^^^^^^
 )");
 }
 
@@ -3229,6 +3233,34 @@
 )");
 }
 
+TEST_F(UniformityAnalysisTest, LoadNonUniformThroughPointerParameter) {
+    auto src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, i32>) {
+  if (*p == 0) {
+    workgroupBarrier();
+  }
+}
+
+fn foo() {
+  var v = non_uniform;
+  bar(&v);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:12:7 warning: parameter 'p' of 'bar' must be uniform
+  bar(&v);
+      ^
+
+test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
+    workgroupBarrier();
+    ^^^^^^^^^^^^^^^^
+)");
+}
+
 TEST_F(UniformityAnalysisTest, LoadUniformThroughPointer) {
     auto src = R"(
 fn foo() {
@@ -3256,6 +3288,23 @@
     RunTest(src, true);
 }
 
+TEST_F(UniformityAnalysisTest, LoadUniformThroughPointerParameter) {
+    auto src = R"(
+fn bar(p : ptr<function, i32>) {
+  if (*p == 0) {
+    workgroupBarrier();
+  }
+}
+
+fn foo() {
+  var v = 42;
+  bar(&v);
+}
+)";
+
+    RunTest(src, true);
+}
+
 TEST_F(UniformityAnalysisTest, StoreNonUniformAfterCapturingPointer) {
     auto src = R"(
 @group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@@ -4884,5 +4933,114 @@
     RunTest(src, true);
 }
 
+////////////////////////////////////////////////////////////////////////////////
+/// Tests for the quality of the error messages produced by the analysis.
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinDirectly) {
+    auto src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+  workgroupBarrier();
+}
+
+fn main() {
+  if (non_uniform == 42) {
+    foo();
+  }
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:10:5 warning: 'foo' must only be called from uniform control flow
+    foo();
+    ^^^
+
+test:5:3 note: 'foo' requires uniformity because it calls workgroupBarrier
+  workgroupBarrier();
+  ^^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinIndirectly) {
+    auto src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn zoo() {
+  workgroupBarrier();
+}
+
+fn bar() {
+  zoo();
+}
+
+fn foo() {
+  bar();
+}
+
+fn main() {
+  if (non_uniform == 42) {
+    foo();
+  }
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:18:5 warning: 'foo' must only be called from uniform control flow
+    foo();
+    ^^^
+
+test:5:3 note: 'foo' requires uniformity because it indirectly calls workgroupBarrier
+  workgroupBarrier();
+  ^^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, Error_ParametersRequireUniformityInChain) {
+    auto src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn zoo(a : i32) {
+  if (a == 42) {
+    workgroupBarrier();
+  }
+}
+
+fn bar(b : i32) {
+  zoo(b);
+}
+
+fn foo(c : i32) {
+  bar(c);
+}
+
+fn main() {
+  foo(non_uniform);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:19:7 warning: parameter 'c' of 'foo' must be uniform
+  foo(non_uniform);
+      ^^^^^^^^^^^
+
+test:15:7 note: parameter 'b' of 'bar' must be uniform
+  bar(c);
+      ^
+
+test:11:7 note: parameter 'a' of 'zoo' must be uniform
+  zoo(b);
+      ^
+
+test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
+    workgroupBarrier();
+    ^^^^^^^^^^^^^^^^
+)");
+}
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/test/tint/bug/tint/943.spvasm.expected.glsl b/test/tint/bug/tint/943.spvasm.expected.glsl
index f581b5d..7edb269 100644
--- a/test/tint/bug/tint/943.spvasm.expected.glsl
+++ b/test/tint/bug/tint/943.spvasm.expected.glsl
@@ -1,4 +1,5 @@
 warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
+note: 'workgroupBarrier' must only be called from uniform control flow
 #version 310 es
 
 struct Uniforms {
diff --git a/test/tint/bug/tint/943.spvasm.expected.hlsl b/test/tint/bug/tint/943.spvasm.expected.hlsl
index 89accf2..ece0c20 100644
--- a/test/tint/bug/tint/943.spvasm.expected.hlsl
+++ b/test/tint/bug/tint/943.spvasm.expected.hlsl
@@ -1,4 +1,5 @@
 warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
+note: 'workgroupBarrier' must only be called from uniform control flow
 static int dimAOuter_1 = 0;
 cbuffer cbuffer_x_48 : register(b3, space0) {
   uint4 x_48[5];
diff --git a/test/tint/bug/tint/943.spvasm.expected.msl b/test/tint/bug/tint/943.spvasm.expected.msl
index 031bf63..48f3434 100644
--- a/test/tint/bug/tint/943.spvasm.expected.msl
+++ b/test/tint/bug/tint/943.spvasm.expected.msl
@@ -1,4 +1,5 @@
 warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
+note: 'workgroupBarrier' must only be called from uniform control flow
 #include <metal_stdlib>
 
 using namespace metal;
diff --git a/test/tint/bug/tint/943.spvasm.expected.spvasm b/test/tint/bug/tint/943.spvasm.expected.spvasm
index 228d7df..e9fc8c5 100644
--- a/test/tint/bug/tint/943.spvasm.expected.spvasm
+++ b/test/tint/bug/tint/943.spvasm.expected.spvasm
@@ -1,4 +1,5 @@
 warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
+note: 'workgroupBarrier' must only be called from uniform control flow
 ; SPIR-V
 ; Version: 1.3
 ; Generator: Google Tint Compiler; 0
diff --git a/test/tint/bug/tint/943.spvasm.expected.wgsl b/test/tint/bug/tint/943.spvasm.expected.wgsl
index aa383da..08be16a 100644
--- a/test/tint/bug/tint/943.spvasm.expected.wgsl
+++ b/test/tint/bug/tint/943.spvasm.expected.wgsl
@@ -1,4 +1,5 @@
 warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
+note: 'workgroupBarrier' must only be called from uniform control flow
 struct Uniforms {
   NAN : f32,
   @size(12)