tint/sem: Record diagnostic severity modifications

Each sem::Node has a map which stores the diagnostic modifications
applied to that node. The sem::Info class provides a query to get the
diagnostic severity for a given AST node, by walking up the semantic
tree to find the tightest diagnostic severity modification. The
default severity is used if it was not overridden.

This allows components outside of the Resolver/Validator to determine
the diagnostic severity while walking the AST, which is required for
the uniformity analysis.

Bug: tint:1809
Change-Id: I4caf99d7412fb22fb1183b2c8cfde349da2fefd3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/117601
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 8b32d93..0473d49 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1335,6 +1335,7 @@
   tint_unittests_source_set("tint_unittests_sem_src") {
     sources = [
       "sem/builtin_test.cc",
+      "sem/diagnostic_severity_test.cc",
       "sem/expression_test.cc",
       "sem/struct_test.cc",
     ]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 845ed65..58d28cf 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -961,6 +961,7 @@
     resolver/variable_validation_test.cc
     scope_stack_test.cc
     sem/builtin_test.cc
+    sem/diagnostic_severity_test.cc
     sem/expression_test.cc
     sem/struct_test.cc
     source_test.cc
diff --git a/src/tint/ast/diagnostic_control.h b/src/tint/ast/diagnostic_control.h
index 2571265..a4b209f 100644
--- a/src/tint/ast/diagnostic_control.h
+++ b/src/tint/ast/diagnostic_control.h
@@ -25,6 +25,7 @@
 
 #include <ostream>
 #include <string>
+#include <unordered_map>
 
 #include "src/tint/ast/node.h"
 
@@ -84,6 +85,9 @@
 /// Convert a DiagnosticSeverity to the corresponding diag::Severity.
 diag::Severity ToSeverity(DiagnosticSeverity sc);
 
+/// DiagnosticRuleSeverities is a map from diagnostic rule to diagnostic severity.
+using DiagnosticRuleSeverities = std::unordered_map<DiagnosticRule, DiagnosticSeverity>;
+
 /// A diagnostic control used for diagnostic directives and attributes.
 class DiagnosticControl : public Castable<DiagnosticControl, Node> {
   public:
diff --git a/src/tint/ast/diagnostic_control.h.tmpl b/src/tint/ast/diagnostic_control.h.tmpl
index 5018783..49946c2 100644
--- a/src/tint/ast/diagnostic_control.h.tmpl
+++ b/src/tint/ast/diagnostic_control.h.tmpl
@@ -15,6 +15,7 @@
 
 #include <ostream>
 #include <string>
+#include <unordered_map>
 
 #include "src/tint/ast/node.h"
 
@@ -34,6 +35,9 @@
 /// Convert a DiagnosticSeverity to the corresponding diag::Severity.
 diag::Severity ToSeverity(DiagnosticSeverity sc);
 
+/// DiagnosticRuleSeverities is a map from diagnostic rule to diagnostic severity.
+using DiagnosticRuleSeverities = std::unordered_map<DiagnosticRule, DiagnosticSeverity>;
+
 /// A diagnostic control used for diagnostic directives and attributes.
 class DiagnosticControl : public Castable<DiagnosticControl, Node> {
   public:
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 357da57..742599e 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -136,9 +136,11 @@
         return false;
     }
 
-    // Create the semantic module
-    builder_->Sem().SetModule(builder_->create<sem::Module>(
-        std::move(dependencies_.ordered_globals), std::move(enabled_extensions_)));
+    // Create the semantic module.
+    auto* mod = builder_->create<sem::Module>(std::move(dependencies_.ordered_globals),
+                                              std::move(enabled_extensions_));
+    ApplyDiagnosticSeverities(mod);
+    builder_->Sem().SetModule(mod);
 
     return result;
 }
@@ -1073,6 +1075,7 @@
 
     auto* func =
         builder_->create<sem::Function>(decl, return_type, return_location, std::move(parameters));
+    ApplyDiagnosticSeverities(func);
     builder_->Sem().Add(decl, func);
 
     TINT_SCOPED_ASSIGNMENT(current_function_, func);
@@ -3878,6 +3881,13 @@
     return false;
 }
 
+template <typename NODE>
+void Resolver::ApplyDiagnosticSeverities(NODE* node) {
+    for (auto itr : validator_.DiagnosticFilters().Top()) {
+        node->SetDiagnosticSeverity(itr.key, itr.value);
+    }
+}
+
 void Resolver::AddError(const std::string& msg, const Source& source) const {
     diagnostics_.add_error(diag::System::Resolver, msg, source);
 }
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index d6d7ba9..519cea8 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -414,6 +414,11 @@
     /// @returns true on success, false on error
     bool Mark(const ast::Node* node);
 
+    /// Applies the diagnostic severities from the current scope to a semantic node.
+    /// @param node the semantic node to apply the diagnostic severities to
+    template <typename NODE>
+    void ApplyDiagnosticSeverities(NODE* node);
+
     /// Adds the given error message to the diagnostics
     void AddError(const std::string& msg, const Source& source) const;
 
diff --git a/src/tint/sem/diagnostic_severity_test.cc b/src/tint/sem/diagnostic_severity_test.cc
new file mode 100644
index 0000000..7ede2bc
--- /dev/null
+++ b/src/tint/sem/diagnostic_severity_test.cc
@@ -0,0 +1,68 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/sem/test_helper.h"
+
+#include "src/tint/sem/module.h"
+
+using namespace tint::number_suffixes;  // NOLINT
+
+namespace tint::sem {
+namespace {
+
+class DiagnosticSeverityTest : public TestHelper {
+  protected:
+    /// Create a program with two functions, setting the severity for "chromium_unreachable_code"
+    /// using an attribute. Test that we correctly track the severity of the filter for the
+    /// functions and the statements with them.
+    /// @param global_severity the global severity of the "chromium_unreachable_code" filter
+    void Run(ast::DiagnosticSeverity global_severity) {
+        // @diagnostic(off, chromium_unreachable_code)
+        // fn foo() {
+        //   return;
+        // }
+        //
+        // fn bar() {
+        //   return;
+        // }
+        auto rule = ast::DiagnosticRule::kChromiumUnreachableCode;
+        auto func_severity = ast::DiagnosticSeverity::kOff;
+
+        auto* return_1 = Return();
+        auto* return_2 = Return();
+        auto* func_attr = DiagnosticAttribute(func_severity, Expr("chromium_unreachable_code"));
+        auto* foo = Func("foo", {}, ty.void_(), utils::Vector{return_1}, utils::Vector{func_attr});
+        auto* bar = Func("bar", {}, ty.void_(), utils::Vector{return_2});
+
+        auto p = Build();
+        EXPECT_TRUE(p.IsValid()) << p.Diagnostics().str();
+
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(foo, rule), func_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(return_1, rule), func_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(bar, rule), global_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(return_2, rule), global_severity);
+    }
+};
+
+TEST_F(DiagnosticSeverityTest, WithDirective) {
+    DiagnosticDirective(ast::DiagnosticSeverity::kError, Expr("chromium_unreachable_code"));
+    Run(ast::DiagnosticSeverity::kError);
+}
+
+TEST_F(DiagnosticSeverityTest, WithoutDirective) {
+    Run(ast::DiagnosticSeverity::kWarning);
+}
+
+}  // namespace
+}  // namespace tint::sem
diff --git a/src/tint/sem/function.h b/src/tint/sem/function.h
index 32468aa..25bacda 100644
--- a/src/tint/sem/function.h
+++ b/src/tint/sem/function.h
@@ -20,6 +20,7 @@
 #include <utility>
 #include <vector>
 
+#include "src/tint/ast/diagnostic_control.h"
 #include "src/tint/ast/variable.h"
 #include "src/tint/sem/call.h"
 #include "src/tint/utils/unique_vector.h"
@@ -256,6 +257,18 @@
     /// @return the location for the return, if provided
     std::optional<uint32_t> ReturnLocation() const { return return_location_; }
 
+    /// Modifies the severity of a specific diagnostic rule for this function.
+    /// @param rule the diagnostic rule
+    /// @param severity the new diagnostic severity
+    void SetDiagnosticSeverity(ast::DiagnosticRule rule, ast::DiagnosticSeverity severity) {
+        diagnostic_severities_[rule] = severity;
+    }
+
+    /// @returns the diagnostic severity modifications applied to this function
+    const ast::DiagnosticRuleSeverities& DiagnosticSeverities() const {
+        return diagnostic_severities_;
+    }
+
   private:
     Function(const Function&) = delete;
     Function(Function&&) = delete;
@@ -276,6 +289,7 @@
     std::vector<const Function*> ancestor_entry_points_;
     const Statement* discard_stmt_ = nullptr;
     sem::Behaviors behaviors_{sem::Behavior::kNext};
+    ast::DiagnosticRuleSeverities diagnostic_severities_;
 
     std::optional<uint32_t> return_location_;
 };
diff --git a/src/tint/sem/info.cc b/src/tint/sem/info.cc
index edeab7e..a3f5b48 100644
--- a/src/tint/sem/info.cc
+++ b/src/tint/sem/info.cc
@@ -14,6 +14,11 @@
 
 #include "src/tint/sem/info.h"
 
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/statement.h"
+
 namespace tint::sem {
 
 Info::Info() = default;
@@ -24,4 +29,61 @@
 
 Info& Info::operator=(Info&&) = default;
 
+ast::DiagnosticSeverity Info::DiagnosticSeverity(const ast::Node* ast_node,
+                                                 ast::DiagnosticRule rule) const {
+    // Get the diagnostic severity modification for a node.
+    auto check = [&](auto* node) {
+        auto& severities = node->DiagnosticSeverities();
+        auto itr = severities.find(rule);
+        if (itr != severities.end()) {
+            return itr->second;
+        }
+        return ast::DiagnosticSeverity::kUndefined;
+    };
+
+    // Get the diagnostic severity modification for a function.
+    auto check_func = [&](const sem::Function* func) {
+        auto severity = check(func);
+        if (severity != ast::DiagnosticSeverity::kUndefined) {
+            return severity;
+        }
+
+        // No severity set on the function, so check the module instead.
+        return check(module_);
+    };
+
+    // Get the diagnostic severity modification for a statement.
+    auto check_stmt = [&](const sem::Statement* stmt) {
+        // Walk up the statement hierarchy, checking for diagnostic severity modifications.
+        while (true) {
+            auto severity = check(stmt);
+            if (severity != ast::DiagnosticSeverity::kUndefined) {
+                return severity;
+            }
+            if (!stmt->Parent()) {
+                break;
+            }
+            stmt = stmt->Parent();
+        }
+
+        // No severity set on the statement, so check the function instead.
+        return check_func(stmt->Function());
+    };
+
+    // Query the diagnostic severity from the semantic node that corresponds to the AST node.
+    auto* sem = Get(ast_node);
+    TINT_ASSERT(Resolver, sem != nullptr);
+    auto severity = Switch(
+        sem,  //
+        [&](const sem::Expression* expr) { return check_stmt(expr->Stmt()); },
+        [&](const sem::Statement* stmt) { return check_stmt(stmt); },
+        [&](const sem::Function* func) { return check_func(func); },
+        [&](Default) {
+            // Use the global severity set on the module.
+            return check(module_);
+        });
+    TINT_ASSERT(Resolver, severity != ast::DiagnosticSeverity::kUndefined);
+    return severity;
+}
+
 }  // namespace tint::sem
diff --git a/src/tint/sem/info.h b/src/tint/sem/info.h
index 246510c..c97b014 100644
--- a/src/tint/sem/info.h
+++ b/src/tint/sem/info.h
@@ -20,6 +20,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "src/tint/ast/diagnostic_control.h"
 #include "src/tint/ast/node.h"
 #include "src/tint/debug.h"
 #include "src/tint/sem/node.h"
@@ -144,6 +145,13 @@
         return &referenced_overrides_.at(from);
     }
 
+    /// Determines the severity of a filterable diagnostic rule for the AST node `ast_node`.
+    /// @param ast_node the AST node
+    /// @param rule the diagnostic rule
+    /// @returns the severity of the rule for that AST node
+    ast::DiagnosticSeverity DiagnosticSeverity(const ast::Node* ast_node,
+                                               ast::DiagnosticRule rule) const;
+
   private:
     // AST node index to semantic node
     std::vector<const CastableBase*> nodes_;
diff --git a/src/tint/sem/module.h b/src/tint/sem/module.h
index 216a2c4..21b08f2 100644
--- a/src/tint/sem/module.h
+++ b/src/tint/sem/module.h
@@ -15,6 +15,7 @@
 #ifndef SRC_TINT_SEM_MODULE_H_
 #define SRC_TINT_SEM_MODULE_H_
 
+#include "src/tint/ast/diagnostic_control.h"
 #include "src/tint/ast/extension.h"
 #include "src/tint/sem/node.h"
 #include "src/tint/utils/vector.h"
@@ -46,9 +47,22 @@
     /// @returns the list of enabled extensions in the module
     const ast::Extensions& Extensions() const { return extensions_; }
 
+    /// Modifies the severity of a specific diagnostic rule for this module.
+    /// @param rule the diagnostic rule
+    /// @param severity the new diagnostic severity
+    void SetDiagnosticSeverity(ast::DiagnosticRule rule, ast::DiagnosticSeverity severity) {
+        diagnostic_severities_[rule] = severity;
+    }
+
+    /// @returns the diagnostic severity modifications applied to this module
+    const ast::DiagnosticRuleSeverities& DiagnosticSeverities() const {
+        return diagnostic_severities_;
+    }
+
   private:
     const utils::Vector<const ast::Node*, 64> dep_ordered_decls_;
     ast::Extensions extensions_;
+    ast::DiagnosticRuleSeverities diagnostic_severities_;
 };
 
 }  // namespace tint::sem
diff --git a/src/tint/sem/statement.h b/src/tint/sem/statement.h
index 09112d5..2d17295 100644
--- a/src/tint/sem/statement.h
+++ b/src/tint/sem/statement.h
@@ -15,6 +15,7 @@
 #ifndef SRC_TINT_SEM_STATEMENT_H_
 #define SRC_TINT_SEM_STATEMENT_H_
 
+#include "src/tint/ast/diagnostic_control.h"
 #include "src/tint/sem/behavior.h"
 #include "src/tint/sem/node.h"
 #include "src/tint/symbol.h"
@@ -109,12 +110,25 @@
     /// according to the behavior analysis
     void SetIsReachable(bool is_reachable) { is_reachable_ = is_reachable; }
 
+    /// Modifies the severity of a specific diagnostic rule for this statement.
+    /// @param rule the diagnostic rule
+    /// @param severity the new diagnostic severity
+    void SetDiagnosticSeverity(ast::DiagnosticRule rule, ast::DiagnosticSeverity severity) {
+        diagnostic_severities_[rule] = severity;
+    }
+
+    /// @returns the diagnostic severity modifications applied to this statement
+    const ast::DiagnosticRuleSeverities& DiagnosticSeverities() const {
+        return diagnostic_severities_;
+    }
+
   private:
     const ast::Statement* const declaration_;
     const CompoundStatement* const parent_;
     const sem::Function* const function_;
     sem::Behaviors behaviors_{sem::Behavior::kNext};
     bool is_reachable_ = true;
+    ast::DiagnosticRuleSeverities diagnostic_severities_;
 };
 
 /// CompoundStatement is the base class of statements that can hold other