tint: Support @diagnostic on switch statements

Bug: tint:1809
Change-Id: I9dc0ff97aef1914d53799259339b7fc2cd01f615
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124061
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/ast/switch_statement.cc b/src/tint/ast/switch_statement.cc
index 1cdf8a4..0445e48 100644
--- a/src/tint/ast/switch_statement.cc
+++ b/src/tint/ast/switch_statement.cc
@@ -26,14 +26,19 @@
                                  NodeID nid,
                                  const Source& src,
                                  const Expression* cond,
-                                 utils::VectorRef<const CaseStatement*> b)
-    : Base(pid, nid, src), condition(cond), body(std::move(b)) {
+                                 utils::VectorRef<const CaseStatement*> b,
+                                 utils::VectorRef<const Attribute*> attrs)
+    : Base(pid, nid, src), condition(cond), body(std::move(b)), attributes(std::move(attrs)) {
     TINT_ASSERT(AST, condition);
     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
     for (auto* stmt : body) {
         TINT_ASSERT(AST, stmt);
         TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, stmt, program_id);
     }
+    for (auto* attr : attributes) {
+        TINT_ASSERT(AST, attr);
+        TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, attr, program_id);
+    }
 }
 
 SwitchStatement::SwitchStatement(SwitchStatement&&) = default;
@@ -45,7 +50,8 @@
     auto src = ctx->Clone(source);
     auto* cond = ctx->Clone(condition);
     auto b = ctx->Clone(body);
-    return ctx->dst->create<SwitchStatement>(src, cond, std::move(b));
+    auto attrs = ctx->Clone(attributes);
+    return ctx->dst->create<SwitchStatement>(src, cond, std::move(b), std::move(attrs));
 }
 
 }  // namespace tint::ast
diff --git a/src/tint/ast/switch_statement.h b/src/tint/ast/switch_statement.h
index eb083f0..e918535 100644
--- a/src/tint/ast/switch_statement.h
+++ b/src/tint/ast/switch_statement.h
@@ -29,11 +29,13 @@
     /// @param src the source of this node
     /// @param condition the switch condition
     /// @param body the switch body
+    /// @param attributes the switch statement attributes
     SwitchStatement(ProgramID pid,
                     NodeID nid,
                     const Source& src,
                     const Expression* condition,
-                    utils::VectorRef<const CaseStatement*> body);
+                    utils::VectorRef<const CaseStatement*> body,
+                    utils::VectorRef<const Attribute*> attributes);
     /// Move constructor
     SwitchStatement(SwitchStatement&&);
     ~SwitchStatement() override;
@@ -50,6 +52,9 @@
     /// The Switch body
     const utils::Vector<const CaseStatement*, 4> body;
     SwitchStatement(const SwitchStatement&) = delete;
+
+    /// The attribute list
+    const utils::Vector<const Attribute*, 1> attributes;
 };
 
 }  // namespace tint::ast
diff --git a/src/tint/ast/switch_statement_test.cc b/src/tint/ast/switch_statement_test.cc
index e6101c8..aab2604 100644
--- a/src/tint/ast/switch_statement_test.cc
+++ b/src/tint/ast/switch_statement_test.cc
@@ -14,6 +14,7 @@
 
 #include "src/tint/ast/switch_statement.h"
 
+#include "gmock/gmock.h"
 #include "gtest/gtest-spi.h"
 #include "src/tint/ast/test_helper.h"
 
@@ -29,7 +30,7 @@
     auto* ident = Expr("ident");
     utils::Vector body{case_stmt};
 
-    auto* stmt = create<SwitchStatement>(ident, body);
+    auto* stmt = create<SwitchStatement>(ident, body, utils::Empty);
     EXPECT_EQ(stmt->condition, ident);
     ASSERT_EQ(stmt->body.Length(), 1u);
     EXPECT_EQ(stmt->body[0], case_stmt);
@@ -37,18 +38,28 @@
 
 TEST_F(SwitchStatementTest, Creation_WithSource) {
     auto* ident = Expr("ident");
-    auto* stmt = create<SwitchStatement>(Source{Source::Location{20, 2}}, ident, utils::Empty);
+    auto* stmt =
+        create<SwitchStatement>(Source{Source::Location{20, 2}}, ident, utils::Empty, utils::Empty);
     auto src = stmt->source;
     EXPECT_EQ(src.range.begin.line, 20u);
     EXPECT_EQ(src.range.begin.column, 2u);
 }
 
+TEST_F(SwitchStatementTest, Creation_WithAttributes) {
+    auto* attr1 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "foo");
+    auto* attr2 = DiagnosticAttribute(builtin::DiagnosticSeverity::kOff, "bar");
+    auto* ident = Expr("ident");
+    auto* stmt = create<SwitchStatement>(ident, utils::Empty, utils::Vector{attr1, attr2});
+
+    EXPECT_THAT(stmt->attributes, testing::ElementsAre(attr1, attr2));
+}
+
 TEST_F(SwitchStatementTest, IsSwitch) {
     utils::Vector lit{CaseSelector(2_i)};
     auto* ident = Expr("ident");
     utils::Vector body{create<CaseStatement>(lit, Block())};
 
-    auto* stmt = create<SwitchStatement>(ident, body);
+    auto* stmt = create<SwitchStatement>(ident, body, utils::Empty);
     EXPECT_TRUE(stmt->Is<SwitchStatement>());
 }
 
@@ -60,7 +71,7 @@
             CaseStatementList cases;
             cases.Push(
                 b.create<CaseStatement>(utils::Vector{b.CaseSelector(b.Expr(1_i))}, b.Block()));
-            b.create<SwitchStatement>(nullptr, cases);
+            b.create<SwitchStatement>(nullptr, cases, utils::Empty);
         },
         "internal compiler error");
 }
@@ -70,7 +81,7 @@
     EXPECT_FATAL_FAILURE(
         {
             ProgramBuilder b;
-            b.create<SwitchStatement>(b.Expr(true), CaseStatementList{nullptr});
+            b.create<SwitchStatement>(b.Expr(true), CaseStatementList{nullptr}, utils::Empty);
         },
         "internal compiler error");
 }
@@ -80,13 +91,15 @@
         {
             ProgramBuilder b1;
             ProgramBuilder b2;
-            b1.create<SwitchStatement>(b2.Expr(true), utils::Vector{
-                                                          b1.create<CaseStatement>(
-                                                              utils::Vector{
-                                                                  b1.CaseSelector(b1.Expr(1_i)),
-                                                              },
-                                                              b1.Block()),
-                                                      });
+            b1.create<SwitchStatement>(b2.Expr(true),
+                                       utils::Vector{
+                                           b1.create<CaseStatement>(
+                                               utils::Vector{
+                                                   b1.CaseSelector(b1.Expr(1_i)),
+                                               },
+                                               b1.Block()),
+                                       },
+                                       utils::Empty);
         },
         "internal compiler error");
 }
@@ -96,13 +109,15 @@
         {
             ProgramBuilder b1;
             ProgramBuilder b2;
-            b1.create<SwitchStatement>(b1.Expr(true), utils::Vector{
-                                                          b2.create<CaseStatement>(
-                                                              utils::Vector{
-                                                                  b2.CaseSelector(b2.Expr(1_i)),
-                                                              },
-                                                              b2.Block()),
-                                                      });
+            b1.create<SwitchStatement>(b1.Expr(true),
+                                       utils::Vector{
+                                           b2.create<CaseStatement>(
+                                               utils::Vector{
+                                                   b2.CaseSelector(b2.Expr(1_i)),
+                                               },
+                                               b2.Block()),
+                                       },
+                                       utils::Empty);
         },
         "internal compiler error");
 }
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 4c77d75..a2d4f48 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -3352,14 +3352,15 @@
     /// @param condition the condition expression initializer
     /// @param cases case statements
     /// @returns the switch statement pointer
-    template <typename ExpressionInit, typename... Cases>
+    template <typename ExpressionInit, typename... Cases, typename = DisableIfVectorLike<Cases...>>
     const ast::SwitchStatement* Switch(const Source& source,
                                        ExpressionInit&& condition,
                                        Cases&&... cases) {
         return create<ast::SwitchStatement>(
             source, Expr(std::forward<ExpressionInit>(condition)),
             utils::Vector<const ast::CaseStatement*, sizeof...(cases)>{
-                std::forward<Cases>(cases)...});
+                std::forward<Cases>(cases)...},
+            utils::Empty);
     }
 
     /// Creates a ast::SwitchStatement with input expression and cases
@@ -3368,12 +3369,44 @@
     /// @returns the switch statement pointer
     template <typename ExpressionInit,
               typename... Cases,
-              typename = DisableIfSource<ExpressionInit>>
+              typename = DisableIfSource<ExpressionInit>,
+              typename = DisableIfVectorLike<Cases...>>
     const ast::SwitchStatement* Switch(ExpressionInit&& condition, Cases&&... cases) {
         return create<ast::SwitchStatement>(
             Expr(std::forward<ExpressionInit>(condition)),
             utils::Vector<const ast::CaseStatement*, sizeof...(cases)>{
-                std::forward<Cases>(cases)...});
+                std::forward<Cases>(cases)...},
+            utils::Empty);
+    }
+
+    /// Creates a ast::SwitchStatement with input expression, cases, and optional attributes
+    /// @param source the source information
+    /// @param condition the condition expression initializer
+    /// @param cases case statements
+    /// @param attributes optional attributes
+    /// @returns the switch statement pointer
+    template <typename ExpressionInit>
+    const ast::SwitchStatement* Switch(
+        const Source& source,
+        ExpressionInit&& condition,
+        utils::VectorRef<const ast::CaseStatement*> cases,
+        utils::VectorRef<const ast::Attribute*> attributes = utils::Empty) {
+        return create<ast::SwitchStatement>(source, Expr(std::forward<ExpressionInit>(condition)),
+                                            cases, std::move(attributes));
+    }
+
+    /// Creates a ast::SwitchStatement with input expression, cases, and optional attributes
+    /// @param condition the condition expression initializer
+    /// @param cases case statements
+    /// @param attributes optional attributes
+    /// @returns the switch statement pointer
+    template <typename ExpressionInit, typename = DisableIfSource<ExpressionInit>>
+    const ast::SwitchStatement* Switch(
+        ExpressionInit&& condition,
+        utils::VectorRef<const ast::CaseStatement*> cases,
+        utils::VectorRef<const ast::Attribute*> attributes = utils::Empty) {
+        return create<ast::SwitchStatement>(Expr(std::forward<ExpressionInit>(condition)), cases,
+                                            std::move(attributes));
     }
 
     /// Creates a ast::CaseStatement with input list of selectors, and body
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index b78719c..7ec46e7 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -706,8 +706,8 @@
         auto reversed_cases = cases;
         std::reverse(reversed_cases.begin(), reversed_cases.end());
 
-        return builder->create<ast::SwitchStatement>(Source{}, condition,
-                                                     std::move(reversed_cases));
+        return builder->create<ast::SwitchStatement>(Source{}, condition, std::move(reversed_cases),
+                                                     utils::Empty);
     }
 
     /// Switch statement condition
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 35a6d15..5f901e0 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -1272,7 +1272,7 @@
         return stmt_if.value;
     }
 
-    auto sw = switch_statement();
+    auto sw = switch_statement(attrs.value);
     if (sw.errored) {
         return Failure::kErrored;
     }
@@ -1604,8 +1604,8 @@
 }
 
 // switch_statement
-//   : SWITCH expression BRACKET_LEFT switch_body+ BRACKET_RIGHT
-Maybe<const ast::SwitchStatement*> ParserImpl::switch_statement() {
+//   : attribute* SWITCH expression BRACKET_LEFT switch_body+ BRACKET_RIGHT
+Maybe<const ast::SwitchStatement*> ParserImpl::switch_statement(AttributeList& attrs) {
     Source source;
     if (!match(Token::Type::kSwitch, &source)) {
         return Failure::kNoMatch;
@@ -1643,7 +1643,8 @@
         return Failure::kErrored;
     }
 
-    return create<ast::SwitchStatement>(source, condition.value, body.value);
+    TINT_DEFER(attrs.Clear());
+    return create<ast::SwitchStatement>(source, condition.value, body.value, std::move(attrs));
 }
 
 // switch_body
diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h
index 8f039b7..dd0a96c 100644
--- a/src/tint/reader/wgsl/parser_impl.h
+++ b/src/tint/reader/wgsl/parser_impl.h
@@ -493,8 +493,9 @@
     /// @returns the parsed statement or nullptr
     Maybe<const ast::IfStatement*> if_statement(AttributeList& attrs);
     /// Parses a `switch_statement` grammar element
+    /// @param attrs the list of attributes for the statement
     /// @returns the parsed statement or nullptr
-    Maybe<const ast::SwitchStatement*> switch_statement();
+    Maybe<const ast::SwitchStatement*> switch_statement(AttributeList& attrs);
     /// Parses a `switch_body` grammar element
     /// @returns the parsed statement or nullptr
     Maybe<const ast::CaseStatement*> switch_body();
diff --git a/src/tint/reader/wgsl/parser_impl_statement_test.cc b/src/tint/reader/wgsl/parser_impl_statement_test.cc
index dce5114..8b50d10 100644
--- a/src/tint/reader/wgsl/parser_impl_statement_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_statement_test.cc
@@ -338,6 +338,18 @@
     EXPECT_EQ(s->attributes.Length(), 1u);
 }
 
+TEST_F(ParserImplTest, Statement_ConsumedAttributes_Switch) {
+    auto p = parser("@diagnostic(off, derivative_uniformity) switch (0) { default{} }");
+    auto e = p->statement();
+    ASSERT_FALSE(p->has_error()) << p->error();
+    EXPECT_TRUE(e.matched);
+    EXPECT_FALSE(e.errored);
+
+    auto* s = As<ast::SwitchStatement>(e.value);
+    ASSERT_NE(s, nullptr);
+    EXPECT_EQ(s->attributes.Length(), 1u);
+}
+
 TEST_F(ParserImplTest, Statement_UnexpectedAttributes) {
     auto p = parser("@diagnostic(off, derivative_uniformity) return;");
     auto e = p->statement();
diff --git a/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc
index a5e3dd3..ef0b1b4 100644
--- a/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc
@@ -22,7 +22,8 @@
   case 1: {}
   case 2: {}
 })");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_TRUE(e.matched);
     EXPECT_FALSE(e.errored);
     EXPECT_FALSE(p->has_error()) << p->error();
@@ -35,7 +36,8 @@
 
 TEST_F(ParserImplTest, SwitchStmt_Empty) {
     auto p = parser("switch a { }");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_TRUE(e.matched);
     EXPECT_FALSE(e.errored);
     EXPECT_FALSE(p->has_error()) << p->error();
@@ -50,7 +52,8 @@
   default: {}
   case 2: {}
 })");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_TRUE(e.matched);
     EXPECT_FALSE(e.errored);
     EXPECT_FALSE(p->has_error()) << p->error();
@@ -67,7 +70,8 @@
     auto p = parser(R"(switch a {
   case 1, default, 2: {}
 })");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_TRUE(e.matched);
     EXPECT_FALSE(e.errored);
     EXPECT_FALSE(p->has_error()) << p->error();
@@ -80,7 +84,8 @@
 
 TEST_F(ParserImplTest, SwitchStmt_WithParens) {
     auto p = parser("switch(a+b) { }");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_TRUE(e.matched);
     EXPECT_FALSE(e.errored);
     EXPECT_FALSE(p->has_error()) << p->error();
@@ -89,9 +94,25 @@
     ASSERT_EQ(e->body.Length(), 0u);
 }
 
+TEST_F(ParserImplTest, SwitchStmt_WithAttributes) {
+    auto p = parser(R"(@diagnostic(off, derivative_uniformity) switch a { default{} })");
+    auto a = p->attribute_list();
+    auto e = p->switch_statement(a.value);
+    EXPECT_TRUE(e.matched);
+    EXPECT_FALSE(e.errored);
+    EXPECT_FALSE(p->has_error()) << p->error();
+    ASSERT_NE(e.value, nullptr);
+    ASSERT_TRUE(e->Is<ast::SwitchStatement>());
+
+    EXPECT_TRUE(a->IsEmpty());
+    ASSERT_EQ(e->attributes.Length(), 1u);
+    EXPECT_TRUE(e->attributes[0]->Is<ast::DiagnosticAttribute>());
+}
+
 TEST_F(ParserImplTest, SwitchStmt_InvalidExpression) {
     auto p = parser("switch a=b {}");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_FALSE(e.matched);
     EXPECT_TRUE(e.errored);
     EXPECT_EQ(e.value, nullptr);
@@ -101,7 +122,8 @@
 
 TEST_F(ParserImplTest, SwitchStmt_MissingExpression) {
     auto p = parser("switch {}");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_FALSE(e.matched);
     EXPECT_TRUE(e.errored);
     EXPECT_EQ(e.value, nullptr);
@@ -111,7 +133,8 @@
 
 TEST_F(ParserImplTest, SwitchStmt_MissingBracketLeft) {
     auto p = parser("switch a }");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_FALSE(e.matched);
     EXPECT_TRUE(e.errored);
     EXPECT_EQ(e.value, nullptr);
@@ -121,7 +144,8 @@
 
 TEST_F(ParserImplTest, SwitchStmt_MissingBracketRight) {
     auto p = parser("switch a {");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_FALSE(e.matched);
     EXPECT_TRUE(e.errored);
     EXPECT_EQ(e.value, nullptr);
@@ -133,7 +157,8 @@
     auto p = parser(R"(switch a {
   case: {}
 })");
-    auto e = p->switch_statement();
+    ParserImpl::AttributeList attrs;
+    auto e = p->switch_statement(attrs);
     EXPECT_FALSE(e.matched);
     EXPECT_TRUE(e.errored);
     EXPECT_EQ(e.value, nullptr);
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 3430d9d..1a8de3b 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -1045,6 +1045,39 @@
 12:34 note: first attribute declared here)");
 }
 
+using SwitchStatementAttributeTest = TestWithParams;
+TEST_P(SwitchStatementAttributeTest, IsValid) {
+    auto& params = GetParam();
+
+    WrapInFunction(Switch(Expr(0_a), utils::Vector{DefaultCase()},
+                          createAttributes(Source{{12, 34}}, *this, params.kind)));
+
+    if (params.should_pass) {
+        EXPECT_TRUE(r()->Resolve()) << r()->error();
+    } else {
+        EXPECT_FALSE(r()->Resolve());
+        EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for switch statements");
+    }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
+                         SwitchStatementAttributeTest,
+                         testing::Values(TestParams{AttributeKind::kAlign, false},
+                                         TestParams{AttributeKind::kBinding, false},
+                                         TestParams{AttributeKind::kBuiltin, false},
+                                         TestParams{AttributeKind::kDiagnostic, true},
+                                         TestParams{AttributeKind::kGroup, false},
+                                         TestParams{AttributeKind::kId, false},
+                                         TestParams{AttributeKind::kInterpolate, false},
+                                         TestParams{AttributeKind::kInvariant, false},
+                                         TestParams{AttributeKind::kLocation, false},
+                                         TestParams{AttributeKind::kMustUse, false},
+                                         TestParams{AttributeKind::kOffset, false},
+                                         TestParams{AttributeKind::kSize, false},
+                                         TestParams{AttributeKind::kStage, false},
+                                         TestParams{AttributeKind::kStride, false},
+                                         TestParams{AttributeKind::kWorkgroup, false},
+                                         TestParams{AttributeKind::kBindingAndGroup, false}));
+
 using IfStatementAttributeTest = TestWithParams;
 TEST_P(IfStatementAttributeTest, IsValid) {
     auto& params = GetParam();
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index ff053c3..1b5c71a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -4275,6 +4275,9 @@
                 return handle_attributes(block, sem, "block statements");
             },
             [&](const ast::IfStatement* i) { return handle_attributes(i, sem, "if statements"); },
+            [&](const ast::SwitchStatement* s) {
+                return handle_attributes(s, sem, "switch statements");
+            },
             [&](Default) { return true; })) {
         return nullptr;
     }
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index 7fa32f6..e7ee8cd 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -8452,6 +8452,58 @@
     }
 }
 
+TEST_P(UniformityAnalysisDiagnosticFilterTest, AttributeOnSwitchStatement_CallInCondition) {
+    auto& param = GetParam();
+    utils::StringStream ss;
+    ss << R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+fn foo() {
+  )"
+       << "@diagnostic(" << param << ", derivative_uniformity)"
+       << R"(switch (i32(non_uniform == 42 && dpdx(1.0) > 0.0)) {
+    default {}
+  }
+}
+)";
+
+    RunTest(ss.str(), param != builtin::DiagnosticSeverity::kError);
+    if (param == builtin::DiagnosticSeverity::kOff) {
+        EXPECT_TRUE(error_.empty());
+    } else {
+        utils::StringStream err;
+        err << ToStr(param) << ": 'dpdx' must only be called";
+        EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
+    }
+}
+
+TEST_P(UniformityAnalysisDiagnosticFilterTest, AttributeOnSwitchStatement_CallInBody) {
+    auto& param = GetParam();
+    utils::StringStream ss;
+    ss << R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+@group(0) @binding(1) var t : texture_2d<f32>;
+@group(0) @binding(2) var s : sampler;
+fn foo() {
+  )"
+       << "@diagnostic(" << param << ", derivative_uniformity)"
+       << R"(switch (non_uniform) {
+    default {
+      let color = textureSample(t, s, vec2(0, 0));
+    }
+  }
+}
+)";
+
+    RunTest(ss.str(), param != builtin::DiagnosticSeverity::kError);
+    if (param == builtin::DiagnosticSeverity::kOff) {
+        EXPECT_TRUE(error_.empty());
+    } else {
+        utils::StringStream err;
+        err << ToStr(param) << ": 'textureSample' must only be called";
+        EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
+    }
+}
+
 INSTANTIATE_TEST_SUITE_P(UniformityAnalysisTest,
                          UniformityAnalysisDiagnosticFilterTest,
                          ::testing::Values(builtin::DiagnosticSeverity::kError,
diff --git a/src/tint/sem/diagnostic_severity_test.cc b/src/tint/sem/diagnostic_severity_test.cc
index b3b0b3a..328e589 100644
--- a/src/tint/sem/diagnostic_severity_test.cc
+++ b/src/tint/sem/diagnostic_severity_test.cc
@@ -40,6 +40,16 @@
         //       return;
         //     }
         //     return;
+        //
+        //     @diagnostic(error, chromium_unreachable_code)
+        //     switch (42) {
+        //       case 0 @diagnostic(warning, chromium_unreachable_code) {
+        //         return;
+        //       }
+        //       default {
+        //         return;
+        //       }
+        //     }
         //   }
         // }
         //
@@ -52,6 +62,8 @@
         auto if_severity = builtin::DiagnosticSeverity::kError;
         auto if_body_severity = builtin::DiagnosticSeverity::kWarning;
         auto else_body_severity = builtin::DiagnosticSeverity::kInfo;
+        auto switch_severity = builtin::DiagnosticSeverity::kError;
+        auto case_severity = builtin::DiagnosticSeverity::kWarning;
         auto attr = [&](auto severity) {
             return utils::Vector{DiagnosticAttribute(severity, "chromium_unreachable_code")};
         };
@@ -60,11 +72,17 @@
         auto* return_foo_elseif = Return();
         auto* return_foo_else = Return();
         auto* return_foo_block = Return();
+        auto* return_foo_case = Return();
+        auto* return_foo_default = Return();
         auto* else_stmt = Block(utils::Vector{return_foo_else}, attr(else_body_severity));
         auto* elseif = If(Expr(false), Block(return_foo_elseif), Else(else_stmt));
         auto* if_foo = If(Expr(true), Block(utils::Vector{return_foo_if}, attr(if_body_severity)),
                           Else(elseif), attr(if_severity));
-        auto* block_1 = Block(utils::Vector{if_foo, return_foo_block}, attr(block_severity));
+        auto* case_stmt =
+            Case(CaseSelector(0_a), Block(utils::Vector{return_foo_case}, attr(case_severity)));
+        auto* swtch = Switch(42_a, utils::Vector{case_stmt, DefaultCase(Block(return_foo_default))},
+                             attr(switch_severity));
+        auto* block_1 = Block(utils::Vector{if_foo, return_foo_block, swtch}, attr(block_severity));
         auto* func_attr = DiagnosticAttribute(func_severity, "chromium_unreachable_code");
         auto* foo = Func("foo", {}, ty.void_(), utils::Vector{block_1}, utils::Vector{func_attr});
 
@@ -86,6 +104,12 @@
         EXPECT_EQ(p.Sem().DiagnosticSeverity(return_foo_elseif, rule), if_severity);
         EXPECT_EQ(p.Sem().DiagnosticSeverity(else_stmt, rule), else_body_severity);
         EXPECT_EQ(p.Sem().DiagnosticSeverity(return_foo_else, rule), else_body_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(swtch, rule), switch_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(swtch->condition, rule), switch_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(case_stmt, rule), switch_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(case_stmt->body, rule), case_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(return_foo_case, rule), case_severity);
+        EXPECT_EQ(p.Sem().DiagnosticSeverity(return_foo_default, rule), switch_severity);
 
         EXPECT_EQ(p.Sem().DiagnosticSeverity(bar, rule), global_severity);
         EXPECT_EQ(p.Sem().DiagnosticSeverity(return_bar, rule), global_severity);
diff --git a/src/tint/writer/glsl/generator_impl_switch_test.cc b/src/tint/writer/glsl/generator_impl_switch_test.cc
index b9ff49d..8e2f739 100644
--- a/src/tint/writer/glsl/generator_impl_switch_test.cc
+++ b/src/tint/writer/glsl/generator_impl_switch_test.cc
@@ -31,7 +31,7 @@
     auto* case_stmt = create<ast::CaseStatement>(utils::Vector{CaseSelector(5_i)}, case_body);
 
     auto* cond = Expr("cond");
-    auto* s = create<ast::SwitchStatement>(cond, utils::Vector{case_stmt, def});
+    auto* s = create<ast::SwitchStatement>(cond, utils::Vector{case_stmt, def}, utils::Empty);
     WrapInFunction(s);
 
     GeneratorImpl& gen = Build();
@@ -58,7 +58,7 @@
                                            def_body);
 
     auto* cond = Expr("cond");
-    auto* s = create<ast::SwitchStatement>(cond, utils::Vector{def});
+    auto* s = create<ast::SwitchStatement>(cond, utils::Vector{def}, utils::Empty);
     WrapInFunction(s);
 
     GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 76075f1..9b47f10 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -1170,6 +1170,14 @@
 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
     {
         auto out = line();
+
+        if (!stmt->attributes.IsEmpty()) {
+            if (!EmitAttributes(out, stmt->attributes)) {
+                return false;
+            }
+            out << " ";
+        }
+
         out << "switch(";
         if (!EmitExpression(out, stmt->condition)) {
             return false;
diff --git a/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl
new file mode 100644
index 0000000..34e342e
--- /dev/null
+++ b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl
@@ -0,0 +1,11 @@
+@group(0) @binding(1) var t : texture_2d<f32>;
+@group(0) @binding(2) var s : sampler;
+
+@fragment
+fn main(@location(0) x : f32) {
+  @diagnostic(warning, derivative_uniformity)
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+    default {
+    }
+  }
+}
diff --git a/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.dxc.hlsl b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..267eab7
--- /dev/null
+++ b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.dxc.hlsl
@@ -0,0 +1,28 @@
+diagnostic_filtering/switch_statement_attribute.wgsl:7:27 warning: 'dpdx' must only be called from uniform control flow
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                          ^^^^^^^^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:24 note: control flow depends on possibly non-uniform value
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                       ^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:15 note: user-defined input 'x' of 'main' may be non-uniform
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+              ^
+
+Texture2D<float4> t : register(t1, space0);
+SamplerState s : register(s2, space0);
+
+struct tint_symbol_1 {
+  float x : TEXCOORD0;
+};
+
+void main_inner(float x) {
+  do {
+  } while (false);
+}
+
+void main(tint_symbol_1 tint_symbol) {
+  main_inner(tint_symbol.x);
+  return;
+}
diff --git a/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.fxc.hlsl b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..267eab7
--- /dev/null
+++ b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.fxc.hlsl
@@ -0,0 +1,28 @@
+diagnostic_filtering/switch_statement_attribute.wgsl:7:27 warning: 'dpdx' must only be called from uniform control flow
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                          ^^^^^^^^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:24 note: control flow depends on possibly non-uniform value
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                       ^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:15 note: user-defined input 'x' of 'main' may be non-uniform
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+              ^
+
+Texture2D<float4> t : register(t1, space0);
+SamplerState s : register(s2, space0);
+
+struct tint_symbol_1 {
+  float x : TEXCOORD0;
+};
+
+void main_inner(float x) {
+  do {
+  } while (false);
+}
+
+void main(tint_symbol_1 tint_symbol) {
+  main_inner(tint_symbol.x);
+  return;
+}
diff --git a/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.glsl b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.glsl
new file mode 100644
index 0000000..608302f
--- /dev/null
+++ b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.glsl
@@ -0,0 +1,32 @@
+diagnostic_filtering/switch_statement_attribute.wgsl:7:27 warning: 'dpdx' must only be called from uniform control flow
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                          ^^^^^^^^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:24 note: control flow depends on possibly non-uniform value
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                       ^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:15 note: user-defined input 'x' of 'main' may be non-uniform
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+              ^
+
+#version 310 es
+precision highp float;
+
+layout(location = 0) in float x_1;
+void tint_symbol(float x) {
+  bool tint_tmp = (x == 0.0f);
+  if (tint_tmp) {
+    tint_tmp = (dFdx(1.0f) == 0.0f);
+  }
+  switch(int((tint_tmp))) {
+    default: {
+      break;
+    }
+  }
+}
+
+void main() {
+  tint_symbol(x_1);
+  return;
+}
diff --git a/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.msl b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.msl
new file mode 100644
index 0000000..c75db53
--- /dev/null
+++ b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.msl
@@ -0,0 +1,32 @@
+diagnostic_filtering/switch_statement_attribute.wgsl:7:27 warning: 'dpdx' must only be called from uniform control flow
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                          ^^^^^^^^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:24 note: control flow depends on possibly non-uniform value
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                       ^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:15 note: user-defined input 'x' of 'main' may be non-uniform
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+              ^
+
+#include <metal_stdlib>
+
+using namespace metal;
+struct tint_symbol_2 {
+  float x [[user(locn0)]];
+};
+
+void tint_symbol_inner(float x) {
+  switch(int(((x == 0.0f) && (dfdx(1.0f) == 0.0f)))) {
+    default: {
+      break;
+    }
+  }
+}
+
+fragment void tint_symbol(tint_symbol_2 tint_symbol_1 [[stage_in]]) {
+  tint_symbol_inner(tint_symbol_1.x);
+  return;
+}
+
diff --git a/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.spvasm b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.spvasm
new file mode 100644
index 0000000..6f4c569
--- /dev/null
+++ b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.spvasm
@@ -0,0 +1,76 @@
+diagnostic_filtering/switch_statement_attribute.wgsl:7:27 warning: 'dpdx' must only be called from uniform control flow
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                          ^^^^^^^^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:24 note: control flow depends on possibly non-uniform value
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                       ^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:15 note: user-defined input 'x' of 'main' may be non-uniform
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+              ^
+
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 35
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %x_1
+               OpExecutionMode %main OriginUpperLeft
+               OpName %x_1 "x_1"
+               OpName %t "t"
+               OpName %s "s"
+               OpName %main_inner "main_inner"
+               OpName %x "x"
+               OpName %main "main"
+               OpDecorate %x_1 Location 0
+               OpDecorate %t DescriptorSet 0
+               OpDecorate %t Binding 1
+               OpDecorate %s DescriptorSet 0
+               OpDecorate %s Binding 2
+      %float = OpTypeFloat 32
+%_ptr_Input_float = OpTypePointer Input %float
+        %x_1 = OpVariable %_ptr_Input_float Input
+          %6 = OpTypeImage %float 2D 0 0 0 1 Unknown
+%_ptr_UniformConstant_6 = OpTypePointer UniformConstant %6
+          %t = OpVariable %_ptr_UniformConstant_6 UniformConstant
+          %9 = OpTypeSampler
+%_ptr_UniformConstant_9 = OpTypePointer UniformConstant %9
+          %s = OpVariable %_ptr_UniformConstant_9 UniformConstant
+       %void = OpTypeVoid
+         %10 = OpTypeFunction %void %float
+        %int = OpTypeInt 32 1
+         %18 = OpConstantNull %float
+       %bool = OpTypeBool
+    %float_1 = OpConstant %float 1
+      %int_0 = OpConstant %int 0
+      %int_1 = OpConstant %int 1
+         %30 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %10
+          %x = OpFunctionParameter %float
+         %14 = OpLabel
+         %19 = OpFOrdEqual %bool %x %18
+               OpSelectionMerge %21 None
+               OpBranchConditional %19 %22 %21
+         %22 = OpLabel
+         %23 = OpDPdx %float %float_1
+         %25 = OpFOrdEqual %bool %23 %18
+               OpBranch %21
+         %21 = OpLabel
+         %26 = OpPhi %bool %19 %14 %25 %22
+         %16 = OpSelect %int %26 %int_1 %int_0
+               OpSelectionMerge %15 None
+               OpSwitch %16 %29
+         %29 = OpLabel
+               OpBranch %15
+         %15 = OpLabel
+               OpReturn
+               OpFunctionEnd
+       %main = OpFunction %void None %30
+         %32 = OpLabel
+         %34 = OpLoad %float %x_1
+         %33 = OpFunctionCall %void %main_inner %34
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.wgsl b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.wgsl
new file mode 100644
index 0000000..1ed3a61
--- /dev/null
+++ b/test/tint/diagnostic_filtering/switch_statement_attribute.wgsl.expected.wgsl
@@ -0,0 +1,23 @@
+diagnostic_filtering/switch_statement_attribute.wgsl:7:27 warning: 'dpdx' must only be called from uniform control flow
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                          ^^^^^^^^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:24 note: control flow depends on possibly non-uniform value
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+                       ^^
+
+diagnostic_filtering/switch_statement_attribute.wgsl:7:15 note: user-defined input 'x' of 'main' may be non-uniform
+  switch (i32(x == 0.0 && dpdx(1.0) == 0.0)) {
+              ^
+
+@group(0) @binding(1) var t : texture_2d<f32>;
+
+@group(0) @binding(2) var s : sampler;
+
+@fragment
+fn main(@location(0) x : f32) {
+  @diagnostic(warning, derivative_uniformity) switch(i32(((x == 0.0) && (dpdx(1.0) == 0.0)))) {
+    default: {
+    }
+  }
+}