Import Tint changes from Dawn

Changes:
  - bd30d9e594226e96c07ec3cd46dbdc68928aef41 tint: uniformity: control flow reconverges after short-ci... by Antonio Maiorano <amaiorano@google.com>
  - a4666888a446e9bd058d08b14a0e7d1e1e0eb430 tint: Fix include layering violation by James Price <jrprice@google.com>
  - b29892be0999aba1f958aafd7256dbf7c6d2fa28 Update src/tint unittests to new @stage format. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: bd30d9e594226e96c07ec3cd46dbdc68928aef41
Change-Id: I1f1cfa3eaea5d2bb74cca21190c38a5311bbe016
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/93060
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/include/tint/tint.h b/include/tint/tint.h
index 2b8430e..c397cdc 100644
--- a/include/tint/tint.h
+++ b/include/tint/tint.h
@@ -33,6 +33,7 @@
 #include "src/tint/transform/robustness.h"
 #include "src/tint/transform/single_entry_point.h"
 #include "src/tint/transform/vertex_pulling.h"
+#include "src/tint/writer/flatten_bindings.h"
 #include "src/tint/writer/writer.h"
 
 #if TINT_BUILD_SPV_READER
diff --git a/src/tint/ast/module_clone_test.cc b/src/tint/ast/module_clone_test.cc
index 65c2406..bd96b26 100644
--- a/src/tint/ast/module_clone_test.cc
+++ b/src/tint/ast/module_clone_test.cc
@@ -99,7 +99,7 @@
   return 1.0;
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   f1(1.0, 2);
 }
diff --git a/src/tint/cmd/main.cc b/src/tint/cmd/main.cc
index c067c39..438e4a2 100644
--- a/src/tint/cmd/main.cc
+++ b/src/tint/cmd/main.cc
@@ -34,7 +34,6 @@
 #include "src/tint/utils/io/command.h"
 #include "src/tint/utils/string.h"
 #include "src/tint/val/val.h"
-#include "src/tint/writer/flatten_bindings.h"
 #include "tint/tint.h"
 
 namespace {
diff --git a/src/tint/fuzzers/tint_regex_fuzzer/regex_fuzzer_tests.cc b/src/tint/fuzzers/tint_regex_fuzzer/regex_fuzzer_tests.cc
index b344a69..891a111 100644
--- a/src/tint/fuzzers/tint_regex_fuzzer/regex_fuzzer_tests.cc
+++ b/src/tint/fuzzers/tint_regex_fuzzer/regex_fuzzer_tests.cc
@@ -172,16 +172,16 @@
         R"(fn clamp_0acf8f() {
         var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>());
       }
-      @stage(vertex)
+      @vertex
       fn vertex_main() -> @builtin(position) vec4<f32> {
          clamp_0acf8f();"
          return vec4<f32>();
       }
-      @stage(fragment)
+      @fragment
       fn fragment_main() {
         clamp_0acf8f();
       }
-      @stage(compute) @workgroup_size(1)
+      @compute @workgroup_size(1)
       fn compute_main() {"
         var<private> foo: f32 = 0.0;
         clamp_0acf8f();
@@ -192,13 +192,13 @@
     std::vector<std::pair<size_t, size_t>> ground_truth = {
         std::make_pair(3, 12),   std::make_pair(28, 3),  std::make_pair(37, 4),
         std::make_pair(49, 5),   std::make_pair(60, 3),  std::make_pair(68, 4),
-        std::make_pair(81, 4),   std::make_pair(110, 5), std::make_pair(130, 2),
-        std::make_pair(140, 4),  std::make_pair(151, 7), std::make_pair(169, 4),
-        std::make_pair(190, 12), std::make_pair(216, 6), std::make_pair(228, 3),
-        std::make_pair(251, 5),  std::make_pair(273, 2), std::make_pair(285, 4),
-        std::make_pair(302, 12), std::make_pair(333, 5), std::make_pair(349, 14),
-        std::make_pair(373, 2),  std::make_pair(384, 4), std::make_pair(402, 3),
-        std::make_pair(415, 3),  std::make_pair(420, 3), std::make_pair(439, 12)};
+        std::make_pair(81, 4),   std::make_pair(110, 6), std::make_pair(123, 2),
+        std::make_pair(133, 4),  std::make_pair(144, 7), std::make_pair(162, 4),
+        std::make_pair(183, 12), std::make_pair(209, 6), std::make_pair(221, 3),
+        std::make_pair(244, 8),  std::make_pair(259, 2), std::make_pair(271, 4),
+        std::make_pair(288, 12), std::make_pair(319, 7), std::make_pair(328, 14),
+        std::make_pair(352, 2),  std::make_pair(363, 4), std::make_pair(381, 3),
+        std::make_pair(394, 3),  std::make_pair(399, 3), std::make_pair(418, 12)};
 
     ASSERT_EQ(ground_truth, identifiers_pos);
 }
@@ -208,17 +208,17 @@
         R"(fn clamp_0acf8f() {
         var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>());
       }
-      @stage(vertex)
+      @vertex
       fn vertex_main() -> @builtin(position) vec4<f32> {
         clamp_0acf8f();
         var foo_1: i32 = 3;
         return vec4<f32>();
       }
-      @stage(fragment)
+      @fragment
       fn fragment_main() {
         clamp_0acf8f();
       }
-      @stage(compute) @workgroup_size(1)
+      @compute @workgroup_size(1)
       fn compute_main() {
         var<private> foo: f32 = 0.0;
         var foo_2: i32 = 10;
@@ -249,17 +249,17 @@
           var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>());
           }
         }
-        @stage(vertex)
+        @vertex
         fn vertex_main() -> @builtin(position) vec4<f32> {
           clamp_0acf8f();
           var foo_1: i32 = 3;
           return vec4<f32>();
         }
-        @stage(fragment)
+        @fragment
         fn fragment_main() {
           clamp_0acf8f();
         }
-        @stage(compute) @workgroup_size(1)
+        @compute @workgroup_size(1)
         fn compute_main() {
           var<private> foo: f32 = 0.0;
           var foo_2: i32 = 10;
@@ -295,17 +295,17 @@
         var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>());
         }
       }
-      @stage(vertex)
+      @vertex
       fn vertex_main() -> @builtin(position) vec4<f32> {
         clamp_0acf8f();
         var foo_1: i32 = 3;
         return vec4<f32>();
       }
-      @stage(fragment)
+      @fragment
       fn fragment_main() {
         clamp_0acf8f();
       }
-      @stage(compute) @workgroup_size(1)
+      @compute @workgroup_size(1)
       fn compute_main() {
         var<private> foo: f32 = 0.0;
         var foo_2: i32 = 10;
@@ -334,17 +334,17 @@
         R"(fn clamp_0acf8f() {
         var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>());
       }
-      @stage(vertex)
+      @vertex
       fn vertex_main() -> @builtin(position) vec4<f32> {
         clamp_0acf8f();
         var foo_1: i32 = 3;
         return vec4<f32>();
       }
-      @stage(fragment)
+      @fragment
       fn fragment_main() {
         clamp_0acf8f();
       }
-      @stage(compute) @workgroup_size(1)
+      @compute @workgroup_size(1)
       fn compute_main() {
         var<private> foo: f32 = 0.0;
         var foo_2: i32 = 10;
@@ -367,17 +367,17 @@
         R"(fn clamp_0acf8f() {
         var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>());return true;
       }
-      @stage(vertex)
+      @vertex
       fn vertex_main() -> @builtin(position) vec4<f32> {
         clamp_0acf8f();
         var foo_1: i32 = 3;
         return vec4<f32>();
       }
-      @stage(fragment)
+      @fragment
       fn fragment_main() {
         clamp_0acf8f();
       }
-      @stage(compute) @workgroup_size(1)
+      @compute @workgroup_size(1)
       fn compute_main() {
         var<private> foo: f32 = 0.0;
         var foo_2: i32 = 10;
@@ -394,17 +394,17 @@
         R"(fn clamp_0acf8f() {
           var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>());
         }
-        @stage(vertex)
+        @vertex
         fn vertex_main() -> @builtin(position) vec4<f32> {
           clamp_0acf8f();
           var foo_1: i32 = 3;
           return vec4<f32>();
         }
-        @stage(fragment)
+        @fragment
         fn fragment_main() {
           clamp_0acf8f();
         }
-        @stage(compute) @workgroup_size(1)
+        @compute @workgroup_size(1)
         fn compute_main() {
           var<private> foo: f32 = 0.0;
           var foo_2: i32 = 10;
@@ -419,7 +419,7 @@
         var foo_3 : i32 = -20;)";
 
     std::vector<size_t> function_positions = GetFunctionBodyPositions(wgsl_code);
-    std::vector<size_t> expected_positions = {187, 607};
+    std::vector<size_t> expected_positions = {180, 586};
     ASSERT_EQ(expected_positions, function_positions);
 }
 
@@ -428,17 +428,17 @@
         R"(fn clamp_0acf8f() {
           var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>())
         }
-        @stage(vertex)
+        @vertex
         fn vertex_main() -> @builtin(position) vec4<f32> {
           clamp_0acf8f()
           var foo_1: i32 = 3
           return vec4<f32>()
         }
-        @stage(fragment)
+        @fragment
         fn fragment_main() {
           clamp_0acf8f();
         }
-        @stage(compute) @workgroup_size(1)
+        @compute @workgroup_size(1)
         fn compute_main() {
           var<private> foo: f32 = 0.0;
           var foo_2: i32 = 10;
@@ -461,17 +461,17 @@
         R"(fn clamp_0acf8f() {
           var res: vec2<f32> = clamp(vec2<f32>(), vec2<f32>(), vec2<f32>())
         }
-        @stage(vertex)
+        @vertex
         fn vertex_main() -> @builtin(position) vec4<f32> {
           clamp_0acf8f()
           var foo_1: i32 = 3
           return vec4<f32>()
         }
-        @stage(fragment)
+        @fragment
         fn fragment_main() {
           clamp_0acf8f();
         }
-        @stage(compute) @workgroup_size(1)
+        @compute @workgroup_size(1)
         fn compute_main() {
           var<private> foo: f32 = 0.0;
           var foo_2: i32 = 10;
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index 033f5b8..986e169 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -2485,7 +2485,7 @@
 
 TEST_F(InspectorGetSamplerTextureUsesTest, None) {
     std::string shader = R"(
-@stage(fragment)
+@fragment
 fn main() {
 })";
 
@@ -2501,7 +2501,7 @@
 @group(0) @binding(1) var mySampler: sampler;
 @group(0) @binding(2) var myTexture: texture_2d<f32>;
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return textureSample(myTexture, mySampler, fragUV) * fragPosition;
@@ -2524,7 +2524,7 @@
 @group(0) @binding(1) var mySampler: sampler;
 @group(0) @binding(2) var myTexture: texture_2d<f32>;
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return textureSample(myTexture, mySampler, fragUV) * fragPosition;
@@ -2540,7 +2540,7 @@
 @group(0) @binding(1) var mySampler: sampler;
 @group(0) @binding(2) var myTexture: texture_2d<f32>;
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return textureSample(myTexture, mySampler, fragUV) * fragPosition;
@@ -2565,7 +2565,7 @@
   return textureSample(t, s, uv);
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return doSample(myTexture, mySampler, fragUV) * fragPosition;
@@ -2592,7 +2592,7 @@
   return textureSample(myTexture, s, uv);
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return doSample(mySampler, fragUV) * fragPosition;
@@ -2619,7 +2619,7 @@
   return textureSample(t, mySampler, uv);
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return doSample(myTexture, fragUV) * fragPosition;
@@ -2646,7 +2646,7 @@
   return textureSample(myTexture, mySampler, uv);
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return doSample(fragUV) * fragPosition;
@@ -2686,19 +2686,19 @@
   return X(t, s, uv) + Y(t, s, uv);
 }
 
-@stage(fragment)
+@fragment
 fn via_call(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return Z(myTexture, mySampler, fragUV) * fragPosition;
 }
 
-@stage(fragment)
+@fragment
 fn via_ptr(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return textureSample(myTexture, mySampler, fragUV) + fragPosition;
 }
 
-@stage(fragment)
+@fragment
 fn direct(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return textureSample(myTexture, mySampler, fragUV) + fragPosition;
@@ -2836,7 +2836,7 @@
 // Test calling GetUsedExtensionNames on a shader with no extension.
 TEST_F(InspectorGetUsedExtensionNamesTest, None) {
     std::string shader = R"(
-@stage(fragment)
+@fragment
 fn main() {
 })";
 
@@ -2851,7 +2851,7 @@
     std::string shader = R"(
 enable f16;
 
-@stage(fragment)
+@fragment
 fn main() {
 })";
 
@@ -2869,7 +2869,7 @@
 enable f16;
 enable f16;
 
-@stage(fragment)
+@fragment
 fn main() {
 })";
 
@@ -2893,7 +2893,7 @@
 // Test calling GetEnableDirectives on a shader with no extension.
 TEST_F(InspectorGetEnableDirectivesTest, None) {
     std::string shader = R"(
-@stage(fragment)
+@fragment
 fn main() {
 })";
 
@@ -2908,7 +2908,7 @@
     std::string shader = R"(
 enable f16;
 
-@stage(fragment)
+@fragment
 fn main() {
 })";
 
@@ -2927,7 +2927,7 @@
 enable f16;
 
 enable f16;
-@stage(fragment)
+@fragment
 fn main() {
 })";
 
@@ -2952,7 +2952,7 @@
   return textureSample(t, s, uv);
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(0) fragUV: vec2<f32>,
         @location(1) fragPosition: vec4<f32>) -> @location(0) vec4<f32> {
   return doSample(myTexture, mySampler, fragUV) * fragPosition;
diff --git a/src/tint/reader/spirv/function_call_test.cc b/src/tint/reader/spirv/function_call_test.cc
index aa1789b..584ad16 100644
--- a/src/tint/reader/spirv/function_call_test.cc
+++ b/src/tint/reader/spirv/function_call_test.cc
@@ -59,7 +59,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn x_100() {
   x_100_1();
 }
@@ -185,7 +185,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn x_100() {
   x_100_1();
 }
diff --git a/src/tint/reader/spirv/function_memory_test.cc b/src/tint/reader/spirv/function_memory_test.cc
index ed39e21..2bb98e3 100644
--- a/src/tint/reader/spirv/function_memory_test.cc
+++ b/src/tint/reader/spirv/function_memory_test.cc
@@ -810,7 +810,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   main_1();
 }
@@ -850,7 +850,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   main_1();
 }
diff --git a/src/tint/reader/spirv/parser_impl_function_decl_test.cc b/src/tint/reader/spirv/parser_impl_function_decl_test.cc
index fe6b1e0..459460d 100644
--- a/src/tint/reader/spirv/parser_impl_function_decl_test.cc
+++ b/src/tint/reader/spirv/parser_impl_function_decl_test.cc
@@ -126,7 +126,7 @@
 )")) << program_ast;
 
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(vertex)
+@vertex
 fn main() -> main_out {
 )"));
 }
@@ -144,7 +144,7 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(fragment)
+@fragment
 fn main() {
 )"));
 }
@@ -162,7 +162,7 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(compute) @workgroup_size(1i, 1i, 1i)
+@compute @workgroup_size(1i, 1i, 1i)
 fn main() {
 )"));
 }
@@ -182,11 +182,11 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(fragment)
+@fragment
 fn first_shader() {
 )"));
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(fragment)
+@fragment
 fn second_shader() {
 )"));
 }
@@ -208,7 +208,7 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(compute) @workgroup_size(2i, 4i, 8i)
+@compute @workgroup_size(2i, 4i, 8i)
 fn comp_main() {
 )")) << program_ast;
 }
@@ -233,7 +233,7 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(compute) @workgroup_size(3i, 5i, 7i)
+@compute @workgroup_size(3i, 5i, 7i)
 fn comp_main() {
 )")) << program_ast;
 }
@@ -262,7 +262,7 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(compute) @workgroup_size(3i, 5i, 7i)
+@compute @workgroup_size(3i, 5i, 7i)
 fn comp_main() {
 )")) << program_ast;
 }
@@ -290,7 +290,7 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(compute) @workgroup_size(3i, 5i, 7i)
+@compute @workgroup_size(3i, 5i, 7i)
 fn comp_main() {
 )")) << program_ast;
 }
@@ -323,7 +323,7 @@
     Program program = p->program();
     const auto program_ast = test::ToString(program);
     EXPECT_THAT(program_ast, HasSubstr(R"(
-@stage(compute) @workgroup_size(3i, 5i, 7i)
+@compute @workgroup_size(3i, 5i, 7i)
 fn comp_main() {
 )")) << program_ast;
 }
@@ -409,7 +409,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn x_100() {
   x_100_1();
 }
diff --git a/src/tint/reader/spirv/parser_impl_module_var_test.cc b/src/tint/reader/spirv/parser_impl_module_var_test.cc
index 622d383..e4ac8cb 100644
--- a/src/tint/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/tint/reader/spirv/parser_impl_module_var_test.cc
@@ -460,7 +460,7 @@
   gl_Position : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(gl_Position);
@@ -519,7 +519,7 @@
   gl_Position : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(gl_Position);
@@ -575,7 +575,7 @@
   gl_Position : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(gl_Position);
@@ -634,7 +634,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_2);
@@ -690,7 +690,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_2);
@@ -746,7 +746,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_2);
@@ -782,7 +782,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_2);
@@ -1650,7 +1650,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_index) x_1_param : u32) {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -1763,7 +1763,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_index) x_1_param : u32) {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -1816,7 +1816,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_index) x_1_param : u32) {
   x_1 = x_1_param;
   main_1();
@@ -1846,7 +1846,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_index) x_1_param : u32) {
   x_1 = x_1_param;
   main_1();
@@ -1875,7 +1875,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_index) x_1_param : u32) {
   x_1 = x_1_param;
   main_1();
@@ -1989,7 +1989,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = x_1_param;
   main_1();
@@ -2021,7 +2021,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = x_1_param;
   main_1();
@@ -2053,7 +2053,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = x_1_param;
   main_1();
@@ -2084,7 +2084,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = bitcast<i32>(x_1_param);
   main_1();
@@ -2116,7 +2116,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = bitcast<i32>(x_1_param);
   main_1();
@@ -2148,7 +2148,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = bitcast<i32>(x_1_param);
   main_1();
@@ -2202,7 +2202,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1[0i]);
@@ -2239,7 +2239,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1[0i]);
@@ -2276,7 +2276,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1[0i]);
@@ -2312,7 +2312,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(bitcast<u32>(x_1[0i]));
@@ -2349,7 +2349,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(bitcast<u32>(x_1[0i]));
@@ -2386,7 +2386,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(bitcast<u32>(x_1[0i]));
@@ -2425,7 +2425,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = x_1_param;
   main_1();
@@ -2469,7 +2469,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1[0i]);
@@ -2527,7 +2527,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) x_1_param : u32) -> main_out {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -2565,7 +2565,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) x_1_param : u32) -> main_out {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -2602,7 +2602,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) x_1_param : u32) -> main_out {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -2638,7 +2638,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) x_1_param : u32) -> main_out {
   x_1 = x_1_param;
   main_1();
@@ -2676,7 +2676,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) x_1_param : u32) -> main_out {
   x_1 = x_1_param;
   main_1();
@@ -2713,7 +2713,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) x_1_param : u32) -> main_out {
   x_1 = x_1_param;
   main_1();
@@ -2797,7 +2797,7 @@
   position_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -2835,7 +2835,7 @@
   position_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -2872,7 +2872,7 @@
   position_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -2931,7 +2931,7 @@
   position_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = x_1_param;
   main_1();
@@ -2969,7 +2969,7 @@
   position_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = x_1_param;
   main_1();
@@ -3006,7 +3006,7 @@
   position_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = x_1_param;
   main_1();
@@ -3149,7 +3149,7 @@
   return;
 }
 
-@stage(compute) @workgroup_size(1i, 1i, 1i)
+@compute @workgroup_size(1i, 1i, 1i)
 fn main(@builtin(${wgsl_builtin}) x_1_param : ${unsigned_wgsl_type}) {
   x_1 = ${assignment_value};
   main_1();
@@ -3195,7 +3195,7 @@
   return;
 }
 
-@stage(compute) @workgroup_size(1i, 1i, 1i)
+@compute @workgroup_size(1i, 1i, 1i)
 fn main(@builtin(${wgsl_builtin}) x_1_param : ${unsigned_wgsl_type}) {
   x_1 = ${assignment_value};
   main_1();
@@ -3240,7 +3240,7 @@
   return;
 }
 
-@stage(compute) @workgroup_size(1i, 1i, 1i)
+@compute @workgroup_size(1i, 1i, 1i)
 fn main(@builtin(${wgsl_builtin}) x_1_param : ${unsigned_wgsl_type}) {
   x_1 = ${assignment_value};
   main_1();
@@ -3552,7 +3552,7 @@
   x_4_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(0) @interpolate(flat) x_1_param : u32, @location(30) @interpolate(flat) x_3_param : u32) -> main_out {
   x_1 = x_1_param;
   x_3 = x_3_param;
@@ -3603,7 +3603,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = x_1_param;
   main_1();
@@ -3652,7 +3652,7 @@
   x_4_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) x_1_param : u32) -> main_out {
   x_1 = bitcast<i32>(x_1_param);
   main_1();
@@ -3692,7 +3692,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = x_1_param;
   main_1();
@@ -3730,7 +3730,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_mask) x_1_param : u32) {
   x_1[0i] = bitcast<i32>(x_1_param);
   main_1();
@@ -3776,7 +3776,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1[0i]);
@@ -3822,7 +3822,7 @@
   x_1_1 : u32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(bitcast<u32>(x_1[0i]));
@@ -3865,7 +3865,7 @@
   x_1_1 : f32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1);
@@ -3899,7 +3899,7 @@
   gl_Position : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(gl_Position);
@@ -3968,7 +3968,7 @@
   gl_Position : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(gl_Position);
@@ -4025,7 +4025,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@location(4) x_1_param : f32, @location(5) x_1_param_1 : f32, @location(6) x_1_param_2 : f32) -> main_out {
   x_1[0i] = x_1_param;
   x_1[1i] = x_1_param_1;
@@ -4083,7 +4083,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@location(9) x_1_param : vec4<f32>, @location(10) x_1_param_1 : vec4<f32>) -> main_out {
   x_1[0i] = x_1_param;
   x_1[1i] = x_1_param_1;
@@ -4150,7 +4150,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@location(9) x_1_param : f32, @location(10) x_1_param_1 : vec4<f32>) -> main_out {
   x_1.alice = x_1_param;
   x_1.bob = x_1_param_1;
@@ -4209,7 +4209,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@location(7) x_1_param : vec4<f32>, @location(8) x_1_param_1 : vec4<f32>, @location(9) x_1_param_2 : vec4<f32>, @location(10) x_1_param_3 : vec4<f32>) -> main_out {
   x_1[0i][0i] = x_1_param;
   x_1[0i][1i] = x_1_param_1;
@@ -4276,7 +4276,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_1[0i], x_1[1i], x_1[2i], x_2);
@@ -4335,7 +4335,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_1[0i], x_1[1i], x_2);
@@ -4404,7 +4404,7 @@
   x_2_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_1.alice, x_1.bob, x_2);
@@ -4479,7 +4479,7 @@
   x_3_2 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@location(9) x_1_param : f32, @location(11) x_1_param_1 : vec4<f32>) -> main_out {
   x_1.alice = x_1_param;
   x_1.bob = x_1_param_1;
@@ -4558,7 +4558,7 @@
   x_10_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main(@location(1) @interpolate(flat) x_1_param : u32, @location(2) @interpolate(flat) x_2_param : vec2<u32>, @location(3) @interpolate(flat) x_3_param : i32, @location(4) @interpolate(flat) x_4_param : vec2<i32>, @location(5) @interpolate(flat) x_5_param : f32, @location(6) @interpolate(flat) x_6_param : vec2<f32>) -> main_out {
   x_1 = x_1_param;
   x_2 = x_2_param;
@@ -4653,7 +4653,7 @@
   x_10_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_1, x_2, x_3, x_4, x_5, x_6, x_10);
@@ -4703,7 +4703,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(1) @interpolate(flat) x_1_param : f32, @location(2) @interpolate(flat) x_1_param_1 : f32, @location(5) @interpolate(flat) x_2_param : f32, @location(6) @interpolate(flat) x_2_param_1 : f32) {
   x_1[0i] = x_1_param;
   x_1[1i] = x_1_param_1;
@@ -4777,7 +4777,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(1) x_1_param : f32, @location(2) @interpolate(perspective, centroid) x_2_param : f32, @location(3) @interpolate(perspective, sample) x_3_param : f32, @location(4) @interpolate(linear) x_4_param : f32, @location(5) @interpolate(linear, centroid) x_5_param : f32, @location(6) @interpolate(linear, sample) x_6_param : f32) {
   x_1 = x_1_param;
   x_2 = x_2_param;
@@ -4844,7 +4844,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(1) x_1_param : f32, @location(2) @interpolate(perspective, centroid) x_1_param_1 : f32, @location(3) @interpolate(perspective, sample) x_1_param_2 : f32, @location(4) @interpolate(linear) x_1_param_3 : f32, @location(5) @interpolate(linear, centroid) x_1_param_4 : f32, @location(6) @interpolate(linear, sample) x_1_param_5 : f32) {
   x_1.field0 = x_1_param;
   x_1.field1 = x_1_param_1;
@@ -4935,7 +4935,7 @@
   x_6_1 : f32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1, x_2, x_3, x_4, x_5, x_6);
@@ -5014,7 +5014,7 @@
   x_1_6 : f32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> main_out {
   main_1();
   return main_out(x_1.field0, x_1.field1, x_1.field2, x_1.field3, x_1.field4, x_1.field5);
@@ -5100,7 +5100,7 @@
   x_10_1 : vec4<f32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> main_out {
   main_1();
   return main_out(x_1, x_2, x_3, x_4, x_5, x_6, x_10);
@@ -5163,7 +5163,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn main(@location(1) @interpolate(flat) x_1_param : u32, @location(2) @interpolate(flat) x_2_param : vec2<u32>, @location(3) @interpolate(flat) x_3_param : i32, @location(4) @interpolate(flat) x_4_param : vec2<i32>, @location(5) x_5_param : f32, @location(6) x_6_param : vec2<f32>) {
   x_1 = x_1_param;
   x_2 = x_2_param;
diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
index 1234ae4..ab97d94 100644
--- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
@@ -306,6 +306,7 @@
 )");
 }
 
+// TODO(crbug.com/tint/1503): Remove this when @stage is removed
 TEST_F(ParserImplErrorTest, FunctionDeclStageMissingLParen) {
     EXPECT("@stage vertex) fn f() {}",
            R"(test.wgsl:1:8 error: expected '(' for stage attribute
@@ -566,10 +567,10 @@
 }
 
 TEST_F(ParserImplErrorTest, GlobalDeclInvalidAttribute) {
-    EXPECT("@stage(vertex) x;",
-           R"(test.wgsl:1:16 error: expected declaration after attributes
-@stage(vertex) x;
-               ^
+    EXPECT("@vertex x;",
+           R"(test.wgsl:1:9 error: expected declaration after attributes
+@vertex x;
+        ^
 )");
 }
 
diff --git a/src/tint/reader/wgsl/parser_impl_function_attribute_list_test.cc b/src/tint/reader/wgsl/parser_impl_function_attribute_list_test.cc
index 83ebeac..0773ffd 100644
--- a/src/tint/reader/wgsl/parser_impl_function_attribute_list_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_function_attribute_list_test.cc
@@ -18,6 +18,7 @@
 namespace tint::reader::wgsl {
 namespace {
 
+// TODO(crbug.com/tint/1503): Remove this when @stage is removed
 TEST_F(ParserImplTest, AttributeList_Parses_Stage) {
     auto p = parser("@workgroup_size(2) @stage(compute)");
     auto attrs = p->attribute_list();
diff --git a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
index 15be222..44f3d78 100644
--- a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
@@ -212,6 +212,7 @@
     EXPECT_EQ(p->error(), "1:22: expected workgroup_size z parameter");
 }
 
+// TODO(crbug.com/tint/1503): Remove when @stage is removed
 TEST_F(ParserImplTest, Attribute_Stage) {
     auto p = parser("stage(compute)");
     auto attr = p->attribute();
diff --git a/src/tint/reader/wgsl/parser_impl_function_decl_test.cc b/src/tint/reader/wgsl/parser_impl_function_decl_test.cc
index 181501a..4d7857b 100644
--- a/src/tint/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_function_decl_test.cc
@@ -137,7 +137,7 @@
 
 TEST_F(ParserImplTest, FunctionDecl_AttributeList_MultipleEntries) {
     auto p = parser(R"(
-@workgroup_size(2, 3, 4) @stage(compute)
+@workgroup_size(2, 3, 4) @compute
 fn main() { return; })");
     auto attrs = p->attribute_list();
     EXPECT_FALSE(p->has_error()) << p->error();
@@ -186,7 +186,7 @@
 TEST_F(ParserImplTest, FunctionDecl_AttributeList_MultipleLists) {
     auto p = parser(R"(
 @workgroup_size(2, 3, 4)
-@stage(compute)
+@compute
 fn main() { return; })");
     auto attributes = p->attribute_list();
     EXPECT_FALSE(p->has_error()) << p->error();
diff --git a/src/tint/reader/wgsl/parser_impl_test.cc b/src/tint/reader/wgsl/parser_impl_test.cc
index 66291c5..99ca25d 100644
--- a/src/tint/reader/wgsl/parser_impl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_test.cc
@@ -24,7 +24,7 @@
 
 TEST_F(ParserImplTest, Parses) {
     auto p = parser(R"(
-@stage(fragment)
+@fragment
 fn main() -> @location(0) vec4<f32> {
   return vec4<f32>(.4, .2, .3, 1);
 }
@@ -112,7 +112,7 @@
  * /* I can nest /**/ comments. */
  * // I can nest line comments too.
  **/
-@stage(fragment) // This is the stage
+@fragment // This is the stage
 fn main(/*
 no
 parameters
@@ -126,7 +126,7 @@
 
 TEST_F(ParserImplTest, Comments_UnterminatedBlockComment) {
     auto p = parser(R"(
-@stage(fragment)
+@fragment
 fn main() -> @location(0) vec4<f32> {
   return vec4<f32>(.4, .2, .3, 1);
 } /* unterminated block comments are invalid ...)");
diff --git a/src/tint/reader/wgsl/parser_test.cc b/src/tint/reader/wgsl/parser_test.cc
index d343c7c..1f0e206 100644
--- a/src/tint/reader/wgsl/parser_test.cc
+++ b/src/tint/reader/wgsl/parser_test.cc
@@ -32,7 +32,7 @@
 
 TEST_F(ParserTest, Parses) {
     Source::File file("test.wgsl", R"(
-@stage(fragment)
+@fragment
 fn main() -> @location(0) vec4<f32> {
   return vec4<f32>(.4, .2, .3, 1.);
 }
diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc
index 770d8d0..ab40296 100644
--- a/src/tint/resolver/builtin_validation_test.cc
+++ b/src/tint/resolver/builtin_validation_test.cc
@@ -34,7 +34,7 @@
 }
 
 TEST_F(ResolverBuiltinValidationTest, InvalidPipelineStageDirect) {
-    // @stage(compute) @workgroup_size(1) fn func { return dpdx(1.0); }
+    // @compute @workgroup_size(1) fn func { return dpdx(1.0); }
 
     auto* dpdx =
         create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"), ast::ExpressionList{Expr(1_f)});
@@ -49,7 +49,7 @@
     // fn f0 { return dpdx(1.0); }
     // fn f1 { f0(); }
     // fn f2 { f1(); }
-    // @stage(compute) @workgroup_size(1) fn main { return f2(); }
+    // @compute @workgroup_size(1) fn main { return f2(); }
 
     auto* dpdx =
         create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"), ast::ExpressionList{Expr(1_f)});
diff --git a/src/tint/resolver/builtins_validation_test.cc b/src/tint/resolver/builtins_validation_test.cc
index 9744e9b..0c59485 100644
--- a/src/tint/resolver/builtins_validation_test.cc
+++ b/src/tint/resolver/builtins_validation_test.cc
@@ -126,7 +126,7 @@
                          testing::ValuesIn(cases));
 
 TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) {
-    // @stage(fragment)
+    // @fragment
     // fn fs_main(
     //   @builtin(frag_depth) fd: f32,
     // ) -> @location(0) f32 { return 1.0; }
@@ -144,7 +144,7 @@
     // struct MyInputs {
     //   @builtin(frag_depth) ff: f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader(arg: MyInputs) -> @location(0) f32 { return 1.0; }
 
     auto* s = Structure(
@@ -165,7 +165,7 @@
     // struct S {
     //   @builtin(vertex_index) idx: u32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader() { var s : S; }
 
     Structure("S", {Member("idx", ty.u32(), {Builtin(ast::Builtin::kVertexIndex)})});
@@ -181,7 +181,7 @@
     // struct MyInputs {
     //   @builtin(kPosition) p: vec4<u32>;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
 
     auto* m = Member("position", ty.vec4<u32>(),
@@ -195,7 +195,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_ReturnType_Fail) {
-    // @stage(vertex)
+    // @vertex
     // fn main() -> @builtin(position) f32 { return 1.0; }
     Func("main", {}, ty.f32(), {Return(1_f)}, {Stage(ast::PipelineStage::kVertex)},
          {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
@@ -208,7 +208,7 @@
     // struct MyInputs {
     //   @builtin(kFragDepth) p: i32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
 
     auto* m = Member("frag_depth", ty.i32(),
@@ -225,7 +225,7 @@
     // struct MyInputs {
     //   @builtin(sample_mask) m: f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
 
     auto* s = Structure(
@@ -240,7 +240,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_ReturnType_Fail) {
-    // @stage(fragment)
+    // @fragment
     // fn main() -> @builtin(sample_mask) i32 { return 1; }
     Func("main", {}, ty.i32(), {Return(1_i)}, {Stage(ast::PipelineStage::kFragment)},
          {Builtin(Source{{12, 34}}, ast::Builtin::kSampleMask)});
@@ -250,7 +250,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, SampleMaskIsNotU32_Fail) {
-    // @stage(fragment)
+    // @fragment
     // fn fs_main(
     //   @builtin(sample_mask) arg: bool
     // ) -> @location(0) f32 { return 1.0; }
@@ -266,7 +266,7 @@
     // struct MyInputs {
     //   @builtin(sample_index) m: f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
 
     auto* s = Structure(
@@ -281,7 +281,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Fail) {
-    // @stage(fragment)
+    // @fragment
     // fn fs_main(
     //   @builtin(sample_index) arg: bool
     // ) -> @location(0) f32 { return 1.0; }
@@ -294,7 +294,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, PositionIsNotF32_Fail) {
-    // @stage(fragment)
+    // @fragment
     // fn fs_main(
     //   @builtin(kPosition) p: vec3<f32>,
     // ) -> @location(0) f32 { return 1.0; }
@@ -307,7 +307,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) {
-    // @stage(fragment)
+    // @fragment
     // fn fs_main() -> @builtin(kFragDepth) f32 { var fd: i32; return fd; }
     auto* fd = Var("fd", ty.i32());
     Func("fs_main", {}, ty.i32(), {Decl(fd), Return(fd)},
@@ -318,7 +318,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, VertexIndexIsNotU32_Fail) {
-    // @stage(vertex)
+    // @vertex
     // fn main(
     //   @builtin(kVertexIndex) vi : f32,
     //   @builtin(kPosition) p :vec4<f32>
@@ -334,7 +334,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, InstanceIndexIsNotU32) {
-    // @stage(vertex)
+    // @vertex
     // fn main(
     //   @builtin(kInstanceIndex) ii : f32,
     //   @builtin(kPosition) p :vec4<f32>
@@ -350,7 +350,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltin_Pass) {
-    // @stage(fragment)
+    // @fragment
     // fn fs_main(
     //   @builtin(kPosition) p: vec4<f32>,
     //   @builtin(front_facing) ff: bool,
@@ -369,7 +369,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) {
-    // @stage(vertex)
+    // @vertex
     // fn main(
     //   @builtin(vertex_index) vi : u32,
     //   @builtin(instance_index) ii : u32,
@@ -392,7 +392,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, ComputeBuiltin_Pass) {
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn main(
     //   @builtin(local_invocationId) li_id: vec3<u32>,
     //   @builtin(local_invocationIndex) li_index: u32,
@@ -493,7 +493,7 @@
     //   @builtin(sample_index) si: u32;
     //   @builtin(sample_mask) sm : u32;;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader(arg: MyInputs) -> @location(0) f32 { return 1.0; }
 
     auto* s = Structure(
@@ -509,7 +509,7 @@
 }
 
 TEST_F(ResolverBuiltinsValidationTest, FrontFacingParamIsNotBool_Fail) {
-    // @stage(fragment)
+    // @fragment
     // fn fs_main(
     //   @builtin(front_facing) is_front: i32;
     // ) -> @location(0) f32 { return 1.0; }
@@ -528,7 +528,7 @@
     // struct MyInputs {
     //   @builtin(front_facing) pos: f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn fragShader(is_front: MyInputs) -> @location(0) f32 { return 1.0; }
 
     auto* s = Structure(
diff --git a/src/tint/resolver/call_validation_test.cc b/src/tint/resolver/call_validation_test.cc
index 84d2a51..4aa6350 100644
--- a/src/tint/resolver/call_validation_test.cc
+++ b/src/tint/resolver/call_validation_test.cc
@@ -163,7 +163,7 @@
     // fn bar(p: ptr<function, i32>) {
     // foo(p);
     // }
-    // @stage(fragment)
+    // @fragment
     // fn main() {
     //   var v: i32;
     //   bar(&v);
@@ -185,7 +185,7 @@
 
 TEST_F(ResolverCallValidationTest, LetPointer) {
     // fn x(p : ptr<function, i32>) -> i32 {}
-    // @stage(fragment)
+    // @fragment
     // fn main() {
     //   var v: i32;
     //   let p: ptr<function, i32> = &v;
@@ -214,7 +214,7 @@
     // let p: ptr<private, i32> = &v;
     // fn foo(p : ptr<private, i32>) -> i32 {}
     // var v: i32;
-    // @stage(fragment)
+    // @fragment
     // fn main() {
     //   var c: i32 = foo(p);
     // }
diff --git a/src/tint/resolver/entry_point_validation_test.cc b/src/tint/resolver/entry_point_validation_test.cc
index 3028d9b..5e5df15 100644
--- a/src/tint/resolver/entry_point_validation_test.cc
+++ b/src/tint/resolver/entry_point_validation_test.cc
@@ -47,7 +47,7 @@
 class ResolverEntryPointValidationTest : public TestHelper, public testing::Test {};
 
 TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Location) {
-    // @stage(fragment)
+    // @fragment
     // fn main() -> @location(0) f32 { return 1.0; }
     Func(Source{{12, 34}}, "main", {}, ty.f32(), {Return(1_f)},
          {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
@@ -56,7 +56,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Builtin) {
-    // @stage(vertex)
+    // @vertex
     // fn main() -> @builtin(position) vec4<f32> { return vec4<f32>(); }
     Func(Source{{12, 34}}, "main", {}, ty.vec4<f32>(), {Return(Construct(ty.vec4<f32>()))},
          {Stage(ast::PipelineStage::kVertex)}, {Builtin(ast::Builtin::kPosition)});
@@ -65,7 +65,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Missing) {
-    // @stage(vertex)
+    // @vertex
     // fn main() -> f32 {
     //   return 1.0;
     // }
@@ -77,7 +77,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Multiple) {
-    // @stage(vertex)
+    // @vertex
     // fn main() -> @location(0) @builtin(position) vec4<f32> {
     //   return vec4<f32>();
     // }
@@ -95,7 +95,7 @@
     //   @location(0) a : f32;
     //   @builtin(frag_depth) b : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output {
     //   return Output();
     // }
@@ -112,7 +112,7 @@
     // struct Output {
     //   @location(0) @builtin(frag_depth) a : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output {
     //   return Output();
     // }
@@ -134,7 +134,7 @@
     //   @location(0) a : f32;
     //   b : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output {
     //   return Output();
     // }
@@ -154,7 +154,7 @@
     //   @builtin(frag_depth) a : f32;
     //   @builtin(frag_depth) b : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output {
     //   return Output();
     // }
@@ -172,7 +172,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) {
-    // @stage(fragment)
+    // @fragment
     // fn main(@location(0) param : f32) {}
     auto* param = Param("param", ty.f32(), {Location(0)});
     Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)});
@@ -181,7 +181,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) {
-    // @stage(fragment)
+    // @fragment
     // fn main(param : f32) {}
     auto* param = Param(Source{{13, 43}}, "param", ty.vec4<f32>());
     Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)});
@@ -191,7 +191,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) {
-    // @stage(fragment)
+    // @fragment
     // fn main(@location(0) @builtin(sample_index) param : u32) {}
     auto* param = Param(
         "param", ty.u32(),
@@ -208,7 +208,7 @@
     //   @location(0) a : f32;
     //   @builtin(sample_index) b : u32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(param : Input) {}
     auto* input =
         Structure("Input", {Member("a", ty.f32(), {Location(0)}),
@@ -223,7 +223,7 @@
     // struct Input {
     //   @location(0) @builtin(sample_index) a : u32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(param : Input) {}
     auto* input =
         Structure("Input", {Member("a", ty.u32(),
@@ -243,7 +243,7 @@
     //   @location(0) a : f32;
     //   b : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(param : Input) {}
     auto* input = Structure("Input", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)}),
                                       Member(Source{{14, 52}}, "b", ty.f32(), {})});
@@ -256,7 +256,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, Parameter_DuplicateBuiltins) {
-    // @stage(fragment)
+    // @fragment
     // fn main(@builtin(sample_index) param_a : u32,
     //         @builtin(sample_index) param_b : u32) {}
     auto* param_a = Param("param_a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)});
@@ -277,7 +277,7 @@
     // struct InputB {
     //   @builtin(sample_index) a : u32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(param_a : InputA, param_b : InputB) {}
     auto* input_a =
         Structure("InputA", {Member("a", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})});
@@ -296,7 +296,7 @@
 }
 
 TEST_F(ResolverEntryPointValidationTest, VertexShaderMustReturnPosition) {
-    // @stage(vertex)
+    // @vertex
     // fn main() {}
     Func(Source{{12, 34}}, "main", {}, ty.void_(), {}, {Stage(ast::PipelineStage::kVertex)});
 
@@ -337,7 +337,7 @@
 };
 
 TEST_P(TypeValidationTest, BareInputs) {
-    // @stage(fragment)
+    // @fragment
     // fn main(@location(0) @interpolate(flat) a : *) {}
     auto params = GetParam();
     auto* a = Param("a", params.create_ast_type(*this), {Location(0), Flat()});
@@ -354,7 +354,7 @@
     // struct Input {
     //   @location(0) @interpolate(flat) a : *;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(a : Input) {}
     auto params = GetParam();
     auto* input =
@@ -370,7 +370,7 @@
 }
 
 TEST_P(TypeValidationTest, BareOutputs) {
-    // @stage(fragment)
+    // @fragment
     // fn main() -> @location(0) * {
     //   return *();
     // }
@@ -390,7 +390,7 @@
     // struct Output {
     //   @location(0) a : *;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output {
     //   return Output();
     // }
@@ -416,7 +416,7 @@
 using LocationAttributeTests = ResolverTest;
 
 TEST_F(LocationAttributeTests, Pass) {
-    // @stage(fragment)
+    // @fragment
     // fn frag_main(@location(0) @interpolate(flat) a: i32) {}
 
     auto* p = Param(Source{{12, 34}}, "a", ty.i32(), {Location(0), Flat()});
@@ -426,7 +426,7 @@
 }
 
 TEST_F(LocationAttributeTests, BadType_Input_bool) {
-    // @stage(fragment)
+    // @fragment
     // fn frag_main(@location(0) a: bool) {}
 
     auto* p = Param(Source{{12, 34}}, "a", ty.bool_(), {Location(Source{{34, 56}}, 0)});
@@ -441,7 +441,7 @@
 }
 
 TEST_F(LocationAttributeTests, BadType_Output_Array) {
-    // @stage(fragment)
+    // @fragment
     // fn frag_main()->@location(0) array<f32, 2> { return array<f32, 2>(); }
 
     Func(Source{{12, 34}}, "frag_main", {}, ty.array<f32, 2>(),
@@ -460,7 +460,7 @@
     // struct Input {
     //   a : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(@location(0) param : Input) {}
     auto* input = Structure("Input", {Member("a", ty.f32())});
     auto* param = Param(Source{{12, 34}}, "param", ty.Of(input), {Location(Source{{13, 43}}, 0)});
@@ -481,7 +481,7 @@
     // struct Input {
     //   a : Inner;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(param : Input) {}
     auto* inner = Structure("Inner", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)})});
     auto* input = Structure("Input", {Member(Source{{14, 52}}, "a", ty.Of(inner))});
@@ -498,7 +498,7 @@
     // struct Input {
     //   @location(0) a : array<f32>;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(param : Input) {}
     auto* input =
         Structure("Input", {Member(Source{{13, 43}}, "a", ty.array<f32>(), {Location(0)})});
@@ -515,7 +515,7 @@
 
 TEST_F(LocationAttributeTests, BadMemberType_Input) {
     // struct S { @location(0) m: array<i32>; };
-    // @stage(fragment)
+    // @fragment
     // fn frag_main( a: S) {}
 
     auto* m = Member(Source{{34, 56}}, "m", ty.array<i32>(),
@@ -535,7 +535,7 @@
 
 TEST_F(LocationAttributeTests, BadMemberType_Output) {
     // struct S { @location(0) m: atomic<i32>; };
-    // @stage(fragment)
+    // @fragment
     // fn frag_main() -> S {}
     auto* m = Member(Source{{34, 56}}, "m", ty.atomic<i32>(),
                      ast::AttributeList{Location(Source{{12, 34}}, 0u)});
@@ -572,7 +572,7 @@
     //   @location(0) a : f32;
     //   @builtin(frag_depth) b : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output {
     //   return Output();
     // }
@@ -589,7 +589,7 @@
     // struct Output {
     //   a : f32;
     // };
-    // @stage(vertex)
+    // @vertex
     // fn main() -> @location(0) Output {
     //   return Output();
     // }
@@ -612,7 +612,7 @@
     // struct Output {
     //   a : Inner;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output { return Output(); }
     auto* inner = Structure("Inner", {Member(Source{{13, 43}}, "a", ty.f32(), {Location(0)})});
     auto* output = Structure("Output", {Member(Source{{14, 52}}, "a", ty.Of(inner))});
@@ -629,7 +629,7 @@
     // struct Output {
     //   @location(0) a : array<f32>;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main() -> Output {
     //   return Output();
     // }
@@ -695,7 +695,7 @@
 }
 
 TEST_F(LocationAttributeTests, Duplicate_input) {
-    // @stage(fragment)
+    // @fragment
     // fn main(@location(1) param_a : f32,
     //         @location(1) param_b : f32) {}
     auto* param_a = Param("param_a", ty.f32(), {Location(1)});
@@ -714,7 +714,7 @@
     // struct InputB {
     //   @location(1) a : f32;
     // };
-    // @stage(fragment)
+    // @fragment
     // fn main(param_a : InputA, param_b : InputB) {}
     auto* input_a = Structure("InputA", {Member("a", ty.f32(), {Location(1)})});
     auto* input_b = Structure("InputB", {Member("a", ty.f32(), {Location(Source{{34, 56}}, 1)})});
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 3be06a9..cf8ddab 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -342,7 +342,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) {
-    // @stage(compute) @workgroup_size(1) fn entrypoint() {}
+    // @compute @workgroup_size(1) fn entrypoint() {}
     // fn func() { return entrypoint(); }
     Func("entrypoint", {}, ty.void_(), {},
          {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_i)});
@@ -359,8 +359,8 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, PipelineStage_MustBeUnique_Fail) {
-    // @stage(fragment)
-    // @stage(vertex)
+    // @fragment
+    // @vertex
     // fn main() { return; }
     Func(Source{{12, 34}}, "main", {}, ty.void_(),
          {
@@ -425,7 +425,7 @@
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
     // let x = 4u;
     // let x = 8u;
-    // @stage(compute) @workgroup_size(x, y, 16u)
+    // @compute @workgroup_size(x, y, 16u)
     // fn main() {}
     auto* x = GlobalConst("x", ty.u32(), Expr(4_u));
     auto* y = GlobalConst("y", ty.u32(), Expr(8_u));
@@ -447,7 +447,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
-    // @stage(compute) @workgroup_size(1i, 2i, 3i)
+    // @compute @workgroup_size(1i, 2i, 3i)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -457,7 +457,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
-    // @stage(compute) @workgroup_size(1u, 2u, 3u)
+    // @compute @workgroup_size(1u, 2u, 3u)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -467,7 +467,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32_AInt) {
-    // @stage(compute) @workgroup_size(1, 2i, 3)
+    // @compute @workgroup_size(1, 2i, 3)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -477,7 +477,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) {
-    // @stage(compute) @workgroup_size(1u, 2, 3u)
+    // @compute @workgroup_size(1u, 2, 3u)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -487,7 +487,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
-    // @stage(compute) @workgroup_size(1u, 2, 3_i)
+    // @compute @workgroup_size(1u, 2, 3_i)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -499,7 +499,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_I32) {
-    // @stage(compute) @workgroup_size(1_i, 2u, 3)
+    // @compute @workgroup_size(1_i, 2u, 3)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -512,7 +512,7 @@
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
     // let x = 64u;
-    // @stage(compute) @workgroup_size(1i, x)
+    // @compute @workgroup_size(1i, x)
     // fn main() {}
     GlobalConst("x", ty.u32(), Expr(64_u));
     Func("main", {}, ty.void_(), {},
@@ -526,7 +526,7 @@
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
     // let x = 64u;
     // let y = 32i;
-    // @stage(compute) @workgroup_size(x, y)
+    // @compute @workgroup_size(x, y)
     // fn main() {}
     GlobalConst("x", ty.u32(), Expr(64_u));
     GlobalConst("y", ty.i32(), Expr(32_i));
@@ -540,7 +540,7 @@
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
     // let x = 4u;
     // let x = 8u;
-    // @stage(compute) @workgroup_size(x, y, 16i)
+    // @compute @workgroup_size(x, y, 16i)
     // fn main() {}
     GlobalConst("x", ty.u32(), Expr(4_u));
     GlobalConst("y", ty.u32(), Expr(8_u));
@@ -553,7 +553,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
-    // @stage(compute) @workgroup_size(64.0)
+    // @compute @workgroup_size(64.0)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -566,7 +566,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
-    // @stage(compute) @workgroup_size(-2i)
+    // @compute @workgroup_size(-2i)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -577,7 +577,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Zero) {
-    // @stage(compute) @workgroup_size(0i)
+    // @compute @workgroup_size(0i)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
@@ -589,7 +589,7 @@
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) {
     // let x = 64.0;
-    // @stage(compute) @workgroup_size(x)
+    // @compute @workgroup_size(x)
     // fn main() {}
     GlobalConst("x", ty.f32(), Expr(64_f));
     Func("main", {}, ty.void_(), {},
@@ -603,7 +603,7 @@
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
     // let x = -2i;
-    // @stage(compute) @workgroup_size(x)
+    // @compute @workgroup_size(x)
     // fn main() {}
     GlobalConst("x", ty.i32(), Expr(-2_i));
     Func("main", {}, ty.void_(), {},
@@ -615,7 +615,7 @@
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Zero) {
     // let x = 0i;
-    // @stage(compute) @workgroup_size(x)
+    // @compute @workgroup_size(x)
     // fn main() {}
     GlobalConst("x", ty.i32(), Expr(0_i));
     Func("main", {}, ty.void_(), {},
@@ -627,7 +627,7 @@
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_NestedZeroValueConstructor) {
     // let x = i32(i32(i32()));
-    // @stage(compute) @workgroup_size(x)
+    // @compute @workgroup_size(x)
     // fn main() {}
     GlobalConst("x", ty.i32(), Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32()))));
     Func("main", {}, ty.void_(), {},
@@ -639,7 +639,7 @@
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
     // var<private> x = 64i;
-    // @stage(compute) @workgroup_size(x)
+    // @compute @workgroup_size(x)
     // fn main() {}
     Global("x", ty.i32(), ast::StorageClass::kPrivate, Expr(64_i));
     Func("main", {}, ty.void_(), {},
@@ -652,7 +652,7 @@
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) {
-    // @stage(compute) @workgroup_size(i32(1))
+    // @compute @workgroup_size(i32(1))
     // fn main() {}
     Func("main", {}, ty.void_(), {},
          {Stage(ast::PipelineStage::kCompute),
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index d62ece9..22d5606 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -143,7 +143,7 @@
     kSwitchCaseWithAbstractCase,
 
     // @workgroup_size(target_expr, abstract_expr, 123)
-    // @stage(compute)
+    // @compute
     // fn f() {}
     kWorkgroupSize
 };
@@ -608,7 +608,7 @@
     kSwitch,
 
     // @workgroup_size(abstract_expr)
-    // @stage(compute)
+    // @compute
     // fn f() {}
     kWorkgroupSize,
 
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 07c661c..16725ba 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -912,7 +912,7 @@
 }
 
 TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn main() {}
     auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});
 
@@ -930,7 +930,7 @@
 }
 
 TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
-    // @stage(compute) @workgroup_size(8, 2, 3)
+    // @compute @workgroup_size(8, 2, 3)
     // fn main() {}
     auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
                       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8_i, 2_i, 3_i)});
@@ -952,7 +952,7 @@
     // let width = 16i;
     // let height = 8i;
     // let depth = 2i;
-    // @stage(compute) @workgroup_size(width, height, depth)
+    // @compute @workgroup_size(width, height, depth)
     // fn main() {}
     GlobalConst("width", ty.i32(), Expr(16_i));
     GlobalConst("height", ty.i32(), Expr(8_i));
@@ -977,7 +977,7 @@
 TEST_F(ResolverTest, Function_WorkgroupSize_Consts_NestedInitializer) {
     // let width = i32(i32(i32(8i)));
     // let height = i32(i32(i32(4i)));
-    // @stage(compute) @workgroup_size(width, height)
+    // @compute @workgroup_size(width, height)
     // fn main() {}
     GlobalConst("width", ty.i32(),
                 Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 8_i))));
@@ -1003,7 +1003,7 @@
     // @id(0) override width = 16i;
     // @id(1) override height = 8i;
     // @id(2) override depth = 2i;
-    // @stage(compute) @workgroup_size(width, height, depth)
+    // @compute @workgroup_size(width, height, depth)
     // fn main() {}
     auto* width = Override("width", ty.i32(), Expr(16_i), {Id(0)});
     auto* height = Override("height", ty.i32(), Expr(8_i), {Id(1)});
@@ -1029,7 +1029,7 @@
     // @id(0) override width : i32;
     // @id(1) override height : i32;
     // @id(2) override depth : i32;
-    // @stage(compute) @workgroup_size(width, height, depth)
+    // @compute @workgroup_size(width, height, depth)
     // fn main() {}
     auto* width = Override("width", ty.i32(), nullptr, {Id(0)});
     auto* height = Override("height", ty.i32(), nullptr, {Id(1)});
@@ -1054,7 +1054,7 @@
 TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
     // @id(1) override height = 2i;
     // let depth = 3i;
-    // @stage(compute) @workgroup_size(8, height, depth)
+    // @compute @workgroup_size(8, height, depth)
     // fn main() {}
     auto* height = Override("height", ty.i32(), Expr(2_i), {Id(0)});
     GlobalConst("depth", ty.i32(), Expr(3_i));
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
index ff1b57b..a5e68ce 100644
--- a/src/tint/resolver/type_validation_test.cc
+++ b/src/tint/resolver/type_validation_test.cc
@@ -369,7 +369,7 @@
 }
 
 TEST_F(ResolverTypeValidationTest, RuntimeArrayInFunction_Fail) {
-    /// @stage(vertex)
+    /// @vertex
     // fn func() { var a : array<i32>; }
 
     auto* var = Var(Source{{12, 34}}, "a", ty.array<i32>(), ast::StorageClass::kNone);
@@ -552,7 +552,7 @@
 
 TEST_F(ResolverTypeValidationTest, RuntimeArrayAsParameter_Fail) {
     // fn func(a : array<u32>) {}
-    // @stage(vertex) fn main() {}
+    // @vertex fn main() {}
 
     auto* param = Param(Source{{12, 34}}, "a", ty.array<i32>());
 
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 6188521..97612a4 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -996,6 +996,10 @@
                     v1_cf->AddEdge(v1);
 
                     auto [cf2, v2] = ProcessExpression(v1_cf, b->rhs);
+
+                    if (sem_.Get(b)->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
+                        return std::pair<Node*, Node*>(cf, v2);
+                    }
                     return std::pair<Node*, Node*>(cf2, v2);
                 } else {
                     auto [cf1, v1] = ProcessExpression(cf, b->lhs);
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index 4a71045..7ed4a6b 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -683,7 +683,7 @@
                        public ::testing::TestWithParam<BuiltinEntry> {};
 TEST_P(ComputeBuiltin, AsParam) {
     std::string src = R"(
-@stage(compute) @workgroup_size(64)
+@compute @workgroup_size(64)
 fn main(@builtin()" + GetParam().name +
                       R"() b : )" + GetParam().type + R"() {
   if (all(vec3(b) == vec3(0u))) {
@@ -719,7 +719,7 @@
                       R"() b : )" + GetParam().type + R"(
 }
 
-@stage(compute) @workgroup_size(64)
+@compute @workgroup_size(64)
 fn main(s : S) {
   if (all(vec3(s.b) == vec3(0u))) {
     workgroupBarrier();
@@ -767,7 +767,7 @@
   @builtin(local_invocation_index) idx : u32,
 }
 
-@stage(compute) @workgroup_size(64)
+@compute @workgroup_size(64)
 fn main(s : S) {
   if (s.num_groups.x == 0u) {
     workgroupBarrier();
@@ -795,7 +795,7 @@
                         public ::testing::TestWithParam<BuiltinEntry> {};
 TEST_P(FragmentBuiltin, AsParam) {
     std::string src = R"(
-@stage(fragment)
+@fragment
 fn main(@builtin()" + GetParam().name +
                       R"() b : )" + GetParam().type + R"() {
   if (u32(vec4(b).x) == 0u) {
@@ -830,7 +830,7 @@
                       R"() b : )" + GetParam().type + R"(
 }
 
-@stage(fragment)
+@fragment
 fn main(s : S) {
   if (u32(vec4(s.b).x) == 0u) {
     dpdx(0.5);
@@ -869,7 +869,7 @@
 
 TEST_F(UniformityAnalysisTest, FragmentLocation) {
     std::string src = R"(
-@stage(fragment)
+@fragment
 fn main(@location(0) l : f32) {
   if (l == 0.0) {
     dpdx(0.5);
@@ -899,7 +899,7 @@
   @location(0) l : f32
 }
 
-@stage(fragment)
+@fragment
 fn main(s : S) {
   if (s.l == 0.0) {
     dpdx(0.5);
@@ -5598,7 +5598,7 @@
   }
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   foo();
 }
@@ -5786,34 +5786,189 @@
 )");
 }
 
-TEST_F(UniformityAnalysisTest, ShortCircuiting_CausesNonUniformControlFlow) {
+TEST_F(UniformityAnalysisTest, ShortCircuiting_NoReconvergeLHS) {
     std::string src = R"(
 @group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
 
 var<private> p : i32;
 
+fn non_uniform_discard_func() -> bool {
+  if (non_uniform_global == 42) {
+    discard;
+  }
+  return false;
+}
+
 fn main() {
-  let b = (non_uniform_global == 42) && false;
+  let b = non_uniform_discard_func() && false;
   workgroupBarrier();
 }
 )";
 
     RunTest(src, false);
     EXPECT_EQ(error_,
-              R"(test:8:3 warning: 'workgroupBarrier' must only be called from uniform control flow
+              R"(test:15:3 warning: 'workgroupBarrier' must only be called from uniform control flow
   workgroupBarrier();
   ^^^^^^^^^^^^^^^^
 
-test:7:38 note: control flow depends on non-uniform value
-  let b = (non_uniform_global == 42) && false;
-                                     ^^
+test:14:11 note: calling 'non_uniform_discard_func' may cause subsequent control flow to be non-uniform
+  let b = non_uniform_discard_func() && false;
+          ^^^^^^^^^^^^^^^^^^^^^^^^
 
-test:7:12 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
-  let b = (non_uniform_global == 42) && false;
-           ^^^^^^^^^^^^^^^^^^
+test:7:3 note: control flow depends on non-uniform value
+  if (non_uniform_global == 42) {
+  ^^
+
+test:7:7 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
+  if (non_uniform_global == 42) {
+      ^^^^^^^^^^^^^^^^^^
 )");
 }
 
+TEST_F(UniformityAnalysisTest, ShortCircuiting_NoReconvergeRHS) {
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+var<private> p : i32;
+
+fn non_uniform_discard_func() -> bool {
+  if (non_uniform_global == 42) {
+    discard;
+  }
+  return false;
+}
+
+fn main() {
+  let b = false && non_uniform_discard_func();
+  workgroupBarrier();
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:15:3 warning: 'workgroupBarrier' must only be called from uniform control flow
+  workgroupBarrier();
+  ^^^^^^^^^^^^^^^^
+
+test:14:20 note: calling 'non_uniform_discard_func' may cause subsequent control flow to be non-uniform
+  let b = false && non_uniform_discard_func();
+                   ^^^^^^^^^^^^^^^^^^^^^^^^
+
+test:7:3 note: control flow depends on non-uniform value
+  if (non_uniform_global == 42) {
+  ^^
+
+test:7:7 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
+  if (non_uniform_global == 42) {
+      ^^^^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, ShortCircuiting_NoReconvergeBoth) {
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+var<private> p : i32;
+
+fn non_uniform_discard_func() -> bool {
+  if (non_uniform_global == 42) {
+    discard;
+  }
+  return false;
+}
+
+fn main() {
+  let b = non_uniform_discard_func() && non_uniform_discard_func();
+  workgroupBarrier();
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:15:3 warning: 'workgroupBarrier' must only be called from uniform control flow
+  workgroupBarrier();
+  ^^^^^^^^^^^^^^^^
+
+test:14:41 note: calling 'non_uniform_discard_func' may cause subsequent control flow to be non-uniform
+  let b = non_uniform_discard_func() && non_uniform_discard_func();
+                                        ^^^^^^^^^^^^^^^^^^^^^^^^
+
+test:7:3 note: control flow depends on non-uniform value
+  if (non_uniform_global == 42) {
+  ^^
+
+test:7:7 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
+  if (non_uniform_global == 42) {
+      ^^^^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, ShortCircuiting_ReconvergeLHS) {
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+var<private> p : i32;
+
+fn uniform_discard_func() -> bool {
+  if (true) {
+    discard;
+  }
+  return false;
+}
+
+fn main() {
+  let b = uniform_discard_func() && false;
+  workgroupBarrier();
+}
+)";
+
+    RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, ShortCircuiting_ReconvergeRHS) {
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+var<private> p : i32;
+
+fn uniform_discard_func() -> bool {
+  if (true) {
+    discard;
+  }
+  return false;
+}
+
+fn main() {
+  let b = false && uniform_discard_func();
+  workgroupBarrier();
+}
+)";
+
+    RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, ShortCircuiting_ReconvergeBoth) {
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+var<private> p : i32;
+
+fn uniform_discard_func() -> bool {
+  if (true) {
+    discard;
+  }
+  return false;
+}
+
+fn main() {
+  let b = uniform_discard_func() && uniform_discard_func();
+  workgroupBarrier();
+}
+)";
+
+    RunTest(src, true);
+}
+
 TEST_F(UniformityAnalysisTest, DeadCode_AfterReturn) {
     // Dead code after a return statement shouldn't cause uniformity errors.
     std::string src = R"(
@@ -6228,8 +6383,15 @@
 @group(0) @binding(0) var<uniform> uniform_value : i32;
 @group(0) @binding(1) var<storage, read_write> non_uniform_value : i32;
 
+fn non_uniform_discard_func() -> bool {
+  if (non_uniform_value == 42) {
+    discard;
+  }
+  return false;
+}
+
 fn main() {
-  let b = (non_uniform_value == 0) && true;
+  let b = non_uniform_discard_func() && true;
   if (uniform_value == 42) {
     workgroupBarrier();
   }
@@ -6238,17 +6400,21 @@
 
     RunTest(src, false);
     EXPECT_EQ(error_,
-              R"(test:8:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+              R"(test:15:5 warning: 'workgroupBarrier' must only be called from uniform control flow
     workgroupBarrier();
     ^^^^^^^^^^^^^^^^
 
-test:6:36 note: control flow depends on non-uniform value
-  let b = (non_uniform_value == 0) && true;
-                                   ^^
+test:13:11 note: calling 'non_uniform_discard_func' may cause subsequent control flow to be non-uniform
+  let b = non_uniform_discard_func() && true;
+          ^^^^^^^^^^^^^^^^^^^^^^^^
 
-test:6:12 note: reading from read_write storage buffer 'non_uniform_value' may result in a non-uniform value
-  let b = (non_uniform_value == 0) && true;
-           ^^^^^^^^^^^^^^^^^
+test:6:3 note: control flow depends on non-uniform value
+  if (non_uniform_value == 42) {
+  ^^
+
+test:6:7 note: reading from read_write storage buffer 'non_uniform_value' may result in a non-uniform value
+  if (non_uniform_value == 42) {
+      ^^^^^^^^^^^^^^^^^
 )");
 }
 
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index 9270065..4fe5b2b 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -80,7 +80,7 @@
     // var<workgroup> dst : vec4<f32>;
     // fn f2(){ dst = wg; }
     // fn f1() { f2(); }
-    // @stage(fragment)
+    // @fragment
     // fn f0() {
     //  f1();
     //}
diff --git a/src/tint/transform/add_empty_entry_point_test.cc b/src/tint/transform/add_empty_entry_point_test.cc
index 2f8a7ad..44f9005 100644
--- a/src/tint/transform/add_empty_entry_point_test.cc
+++ b/src/tint/transform/add_empty_entry_point_test.cc
@@ -31,7 +31,7 @@
 
 TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn existing() {}
 )";
 
@@ -42,7 +42,7 @@
     auto* src = R"()";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn unused_entry_point() {
 }
 )";
@@ -54,7 +54,7 @@
 
 TEST_F(AddEmptyEntryPointTest, ExistingEntryPoint) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main() {
 }
 )";
@@ -70,7 +70,7 @@
     auto* src = R"(var<private> unused_entry_point : f32;)";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn unused_entry_point_1() {
 }
 
diff --git a/src/tint/transform/add_spirv_block_attribute_test.cc b/src/tint/transform/add_spirv_block_attribute_test.cc
index 14ba929..455be60 100644
--- a/src/tint/transform/add_spirv_block_attribute_test.cc
+++ b/src/tint/transform/add_spirv_block_attribute_test.cc
@@ -41,7 +41,7 @@
 
 var<private> p : S;
 
-@stage(fragment)
+@fragment
 fn main() {
   p.f = 1.0;
 }
@@ -60,7 +60,7 @@
   f : f32,
 }
 
-@stage(fragment)
+@fragment
 fn main() -> S {
   return S();
 }
@@ -77,7 +77,7 @@
 @group(0) @binding(0)
 var<uniform> u : f32;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u;
 }
@@ -90,7 +90,7 @@
 
 @group(0) @binding(0) var<uniform> u : u_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u.inner;
 }
@@ -106,7 +106,7 @@
 @group(0) @binding(0)
 var<uniform> u : array<vec4<f32>, 4u>;
 
-@stage(fragment)
+@fragment
 fn main() {
   let a = u;
 }
@@ -119,7 +119,7 @@
 
 @group(0) @binding(0) var<uniform> u : u_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let a = u.inner;
 }
@@ -137,7 +137,7 @@
 @group(0) @binding(0)
 var<uniform> u : Numbers;
 
-@stage(fragment)
+@fragment
 fn main() {
   let a = u;
 }
@@ -152,7 +152,7 @@
 
 @group(0) @binding(0) var<uniform> u : u_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let a = u.inner;
 }
@@ -172,7 +172,7 @@
 @group(0) @binding(0)
 var<uniform> u : S;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u.f;
 }
@@ -185,7 +185,7 @@
 
 @group(0) @binding(0) var<uniform> u : S;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u.f;
 }
@@ -209,7 +209,7 @@
 @group(0) @binding(0)
 var<uniform> u : Outer;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u.i.f;
 }
@@ -226,7 +226,7 @@
 
 @group(0) @binding(0) var<uniform> u : Outer;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u.i.f;
 }
@@ -253,7 +253,7 @@
 @group(0) @binding(1)
 var<uniform> u1 : Inner;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.f;
@@ -278,7 +278,7 @@
 
 @group(0) @binding(1) var<uniform> u1 : u1_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.inner.f;
@@ -305,7 +305,7 @@
 @group(0) @binding(1)
 var<uniform> u : Inner;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = p.i.f;
   let f1 = u.f;
@@ -329,7 +329,7 @@
 
 @group(0) @binding(1) var<uniform> u : u_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = p.i.f;
   let f1 = u.inner.f;
@@ -360,7 +360,7 @@
 @group(0) @binding(2)
 var<uniform> u2 : Inner;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.f;
@@ -388,7 +388,7 @@
 
 @group(0) @binding(2) var<uniform> u2 : u1_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.inner.f;
@@ -410,7 +410,7 @@
 @group(0) @binding(0)
 var<uniform> u : S;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u.f;
   let a = array<S, 4>();
@@ -428,7 +428,7 @@
 
 @group(0) @binding(0) var<uniform> u : u_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f = u.inner.f;
   let a = array<S, 4>();
@@ -452,7 +452,7 @@
 @group(0) @binding(1)
 var<uniform> u1 : S;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.f;
   let f1 = u1.f;
@@ -473,7 +473,7 @@
 
 @group(0) @binding(1) var<uniform> u1 : u0_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.inner.f;
   let f1 = u1.inner.f;
@@ -506,7 +506,7 @@
 @group(0) @binding(1)
 var<uniform> u1 : MyInner;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.f;
@@ -535,7 +535,7 @@
 
 @group(0) @binding(1) var<uniform> u1 : u1_block;
 
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.inner.f;
@@ -549,7 +549,7 @@
 
 TEST_F(AddSpirvBlockAttributeTest, Aliases_Nested_OuterBuffer_InnerBuffer_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.f;
@@ -574,7 +574,7 @@
 };
 )";
     auto* expect = R"(
-@stage(fragment)
+@fragment
 fn main() {
   let f0 = u0.i.f;
   let f1 = u1.inner.f;
diff --git a/src/tint/transform/array_length_from_uniform_test.cc b/src/tint/transform/array_length_from_uniform_test.cc
index ee0d4b1..109904c 100644
--- a/src/tint/transform/array_length_from_uniform_test.cc
+++ b/src/tint/transform/array_length_from_uniform_test.cc
@@ -40,7 +40,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
 }
 )";
@@ -57,7 +57,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb.arr);
 }
@@ -75,7 +75,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb.arr);
 }
@@ -94,7 +94,7 @@
     auto* src = R"(
 @group(0) @binding(0) var<storage, read> sb : array<i32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb);
 }
@@ -109,7 +109,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : array<i32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = (tint_symbol_1.buffer_size[0u][0u] / 4u);
 }
@@ -137,7 +137,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb.arr);
 }
@@ -157,7 +157,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
 }
@@ -197,7 +197,7 @@
 @group(3) @binding(2) var<storage, read> sb4 : SB4;
 @group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = arrayLength(&(sb1.arr1));
   var len2 : u32 = arrayLength(&(sb2.arr2));
@@ -240,7 +240,7 @@
 
 @group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
   var len2 : u32 = ((tint_symbol_1.buffer_size[0u][1u] - 16u) / 16u);
@@ -289,7 +289,7 @@
 @group(3) @binding(2) var<storage, read> sb4 : SB4;
 @group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = arrayLength(&(sb1.arr1));
   var len3 : u32 = arrayLength(&sb3);
@@ -329,7 +329,7 @@
 
 @group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
   var len3 : u32 = (tint_symbol_1.buffer_size[0u][2u] / 16u);
@@ -363,7 +363,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = &(sb.arr);
 }
@@ -397,7 +397,7 @@
 
 @group(1) @binding(2) var<storage, read> sb2 : SB2;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = arrayLength(&(sb1.arr1));
   var len2 : u32 = arrayLength(&(sb2.arr2));
@@ -426,7 +426,7 @@
 
 @group(1) @binding(2) var<storage, read> sb2 : SB2;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
   var len2 : u32 = arrayLength(&(sb2.arr2));
@@ -449,7 +449,7 @@
 
 TEST_F(ArrayLengthFromUniformTest, OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb.arr);
 }
@@ -469,7 +469,7 @@
 
 @group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
 }
diff --git a/src/tint/transform/binding_remapper_test.cc b/src/tint/transform/binding_remapper_test.cc
index 29a96c3..3274886 100644
--- a/src/tint/transform/binding_remapper_test.cc
+++ b/src/tint/transform/binding_remapper_test.cc
@@ -74,7 +74,7 @@
 
 @group(3) @binding(2) var<storage, read> b : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -99,7 +99,7 @@
 
 @group(3) @binding(2) var<storage, read> b : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -113,7 +113,7 @@
 
 @group(3) @binding(2) var<storage, read> b : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -143,7 +143,7 @@
 
 @group(4) @binding(3) var<storage, read> c : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -159,7 +159,7 @@
 
 @group(4) @binding(3) var<storage, read> c : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -187,7 +187,7 @@
 
 @group(3) @binding(2) var<storage, read> b : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -201,7 +201,7 @@
 
 @group(6) @binding(7) var<storage, write> b : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -235,7 +235,7 @@
 
 @group(5) @binding(4) var<storage, read> d : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   let x : i32 = (((a.i + b.i) + c.i) + d.i);
 }
@@ -254,7 +254,7 @@
 
 @internal(disable_validation__binding_point_collision) @group(5) @binding(4) var<storage, read> d : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   let x : i32 = (((a.i + b.i) + c.i) + d.i);
 }
@@ -287,12 +287,12 @@
 
 @group(5) @binding(4) var<storage, read> d : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f1() {
   let x : i32 = (a.i + c.i);
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f2() {
   let x : i32 = (b.i + d.i);
 }
@@ -311,12 +311,12 @@
 
 @group(5) @binding(4) var<storage, read> d : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f1() {
   let x : i32 = (a.i + c.i);
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f2() {
   let x : i32 = (b.i + d.i);
 }
@@ -345,7 +345,7 @@
 
 @group(3) @binding(2) var<storage, read> b : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
diff --git a/src/tint/transform/calculate_array_length_test.cc b/src/tint/transform/calculate_array_length_test.cc
index 9c7c3ac..e2674b0 100644
--- a/src/tint/transform/calculate_array_length_test.cc
+++ b/src/tint/transform/calculate_array_length_test.cc
@@ -38,7 +38,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
 }
 )";
@@ -55,7 +55,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb.arr);
 }
@@ -68,7 +68,7 @@
     auto* src = R"(
 @group(0) @binding(0) var<storage, read> sb : array<i32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb);
 }
@@ -80,7 +80,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : array<i32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(sb, &(tint_symbol_1));
@@ -103,7 +103,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len : u32 = arrayLength(&sb.arr);
 }
@@ -120,7 +120,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(sb, &(tint_symbol_1));
@@ -142,7 +142,7 @@
 
 @group(0) @binding(0) var<storage, read> arr : array<S>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let len = arrayLength(&arr);
 }
@@ -157,7 +157,7 @@
 
 @group(0) @binding(0) var<storage, read> arr : array<S>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(arr, &(tint_symbol_1));
@@ -179,7 +179,7 @@
 
 @group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let len = arrayLength(&arr);
 }
@@ -194,7 +194,7 @@
 
 @group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(arr, &(tint_symbol_1));
@@ -212,7 +212,7 @@
     auto* src = R"(
 @group(0) @binding(0) var<storage, read> sb : array<i32>;;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : u32 = arrayLength(&sb);
   var b : u32 = arrayLength(&sb);
@@ -226,7 +226,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : array<i32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(sb, &(tint_symbol_1));
@@ -251,7 +251,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : u32 = arrayLength(&sb.arr);
   var b : u32 = arrayLength(&sb.arr);
@@ -270,7 +270,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(sb, &(tint_symbol_1));
@@ -295,7 +295,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   if (true) {
     var len : u32 = arrayLength(&sb.arr);
@@ -318,7 +318,7 @@
 
 @group(0) @binding(0) var<storage, read> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   if (true) {
     var tint_symbol_1 : u32 = 0u;
@@ -359,7 +359,7 @@
 
 @group(0) @binding(2) var<storage, read> sb3 : array<i32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = arrayLength(&(sb1.arr1));
   var len2 : u32 = arrayLength(&(sb2.arr2));
@@ -394,7 +394,7 @@
 
 @group(0) @binding(2) var<storage, read> sb3 : array<i32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(sb1, &(tint_symbol_1));
@@ -427,7 +427,7 @@
 @group(0) @binding(0) var<storage, read> a : SB;
 @group(0) @binding(1) var<storage, read> b : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x = &a;
   var a : u32 = arrayLength(&a.arr);
@@ -451,7 +451,7 @@
 
 @group(0) @binding(1) var<storage, read> b : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(a, &(tint_symbol_1));
@@ -473,7 +473,7 @@
 
 TEST_F(CalculateArrayLengthTest, OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var len1 : u32 = arrayLength(&(sb1.arr1));
   var len2 : u32 = arrayLength(&(sb2.arr2));
@@ -508,7 +508,7 @@
 @internal(intrinsic_buffer_size)
 fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var tint_symbol_1 : u32 = 0u;
   tint_symbol(sb1, &(tint_symbol_1));
diff --git a/src/tint/transform/canonicalize_entry_point_io_test.cc b/src/tint/transform/canonicalize_entry_point_io_test.cc
index 68ccb8b..f17c5f5 100644
--- a/src/tint/transform/canonicalize_entry_point_io_test.cc
+++ b/src/tint/transform/canonicalize_entry_point_io_test.cc
@@ -38,11 +38,11 @@
     // Test that we do not introduce wrapper functions when there is no shader IO
     // to process.
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main() {
 }
 )";
@@ -58,7 +58,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Parameters_Spirv) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(@location(1) loc1 : f32,
              @location(2) @interpolate(flat) loc2 : vec4<u32>,
              @builtin(position) coord : vec4<f32>) {
@@ -77,7 +77,7 @@
   var col : f32 = (coord.x * loc1);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   frag_main_inner(loc1_1, loc2_1, coord_1);
 }
@@ -92,7 +92,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Parameters_Msl) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(@location(1) loc1 : f32,
              @location(2) @interpolate(flat) loc2 : vec4<u32>,
              @builtin(position) coord : vec4<f32>) {
@@ -112,7 +112,7 @@
   var col : f32 = (coord.x * loc1);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(@builtin(position) coord : vec4<f32>, tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, coord);
 }
@@ -127,7 +127,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Parameters_Hlsl) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(@location(1) loc1 : f32,
              @location(2) @interpolate(flat) loc2 : vec4<u32>,
              @builtin(position) coord : vec4<f32>) {
@@ -149,7 +149,7 @@
   var col : f32 = (coord.x * loc1);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc1, tint_symbol.loc2, tint_symbol.coord);
 }
@@ -166,7 +166,7 @@
     auto* src = R"(
 type myf32 = f32;
 
-@stage(fragment)
+@fragment
 fn frag_main(@location(1) loc1 : myf32) {
   var x : myf32 = loc1;
 }
@@ -184,7 +184,7 @@
   var x : myf32 = loc1;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc1);
 }
@@ -199,7 +199,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Parameter_TypeAlias_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(@location(1) loc1 : myf32) {
   var x : myf32 = loc1;
 }
@@ -217,7 +217,7 @@
   var x : myf32 = loc1;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc1);
 }
@@ -242,7 +242,7 @@
   @location(2) @interpolate(flat) loc2 : vec4<u32>,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main(@location(0) loc0 : f32,
              locations : FragLocations,
              builtins : FragBuiltins) {
@@ -272,7 +272,7 @@
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   frag_main_inner(loc0_1, FragLocations(loc1_1, loc2_1), FragBuiltins(coord_1));
 }
@@ -287,7 +287,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Spirv_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(@location(0) loc0 : f32,
              locations : FragLocations,
              builtins : FragBuiltins) {
@@ -316,7 +316,7 @@
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   frag_main_inner(loc0_1, FragLocations(loc1_1, loc2_1), FragBuiltins(coord_1));
 }
@@ -348,7 +348,7 @@
   @location(2) @interpolate(flat) loc2 : vec4<u32>,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main(@location(0) loc0 : f32,
              locations : FragLocations,
              builtins : FragBuiltins) {
@@ -379,7 +379,7 @@
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(@builtin(position) coord : vec4<f32>, tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(coord));
 }
@@ -394,7 +394,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, StructParameters_kMsl_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(@location(0) loc0 : f32,
              locations : FragLocations,
              builtins : FragBuiltins) {
@@ -424,7 +424,7 @@
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(@builtin(position) coord : vec4<f32>, tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(coord));
 }
@@ -456,7 +456,7 @@
   @location(2) @interpolate(flat) loc2 : vec4<u32>,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main(@location(0) loc0 : f32,
              locations : FragLocations,
              builtins : FragBuiltins) {
@@ -489,7 +489,7 @@
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(tint_symbol.coord));
 }
@@ -504,7 +504,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, StructParameters_Hlsl_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(@location(0) loc0 : f32,
              locations : FragLocations,
              builtins : FragBuiltins) {
@@ -536,7 +536,7 @@
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) {
   frag_main_inner(tint_symbol.loc0, FragLocations(tint_symbol.loc1, tint_symbol.loc2), FragBuiltins(tint_symbol.coord));
 }
@@ -560,7 +560,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Spirv) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> @builtin(frag_depth) f32 {
   return 1.0;
 }
@@ -573,7 +573,7 @@
   return 1.0;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   let inner_result = frag_main_inner();
   value = inner_result;
@@ -589,7 +589,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Msl) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> @builtin(frag_depth) f32 {
   return 1.0;
 }
@@ -605,7 +605,7 @@
   return 1.0;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -623,7 +623,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Return_NonStruct_Hlsl) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> @builtin(frag_depth) f32 {
   return 1.0;
 }
@@ -639,7 +639,7 @@
   return 1.0;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -663,7 +663,7 @@
   @builtin(sample_mask) mask : u32,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main() -> FragOutput {
   var output : FragOutput;
   output.depth = 1.0;
@@ -694,7 +694,7 @@
   return output;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   let inner_result = frag_main_inner();
   color_1 = inner_result.color;
@@ -712,7 +712,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Spirv_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> FragOutput {
   var output : FragOutput;
   output.depth = 1.0;
@@ -743,7 +743,7 @@
   return output;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   let inner_result = frag_main_inner();
   color_1 = inner_result.color;
@@ -773,7 +773,7 @@
   @builtin(sample_mask) mask : u32,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main() -> FragOutput {
   var output : FragOutput;
   output.depth = 1.0;
@@ -807,7 +807,7 @@
   return output;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -827,7 +827,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Msl_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> FragOutput {
   var output : FragOutput;
   output.depth = 1.0;
@@ -861,7 +861,7 @@
   return output;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -893,7 +893,7 @@
   @builtin(sample_mask) mask : u32,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main() -> FragOutput {
   var output : FragOutput;
   output.depth = 1.0;
@@ -927,7 +927,7 @@
   return output;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -947,7 +947,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Return_Struct_Hlsl_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> FragOutput {
   var output : FragOutput;
   output.depth = 1.0;
@@ -981,7 +981,7 @@
   return output;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -1016,12 +1016,12 @@
   return x.value * x.mul;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
@@ -1049,7 +1049,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1() {
   frag_main1_inner(FragmentInput(value_1, mul_1));
 }
@@ -1058,7 +1058,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2() {
   frag_main2_inner(FragmentInput(value_2, mul_2));
 }
@@ -1073,12 +1073,12 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Spirv_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
@@ -1106,7 +1106,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1() {
   frag_main1_inner(FragmentInput(value_1, mul_1));
 }
@@ -1115,7 +1115,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2() {
   frag_main2_inner(FragmentInput(value_2, mul_2));
 }
@@ -1148,12 +1148,12 @@
   return x.value * x.mul;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
@@ -1180,7 +1180,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(tint_symbol : tint_symbol_1) {
   frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
 }
@@ -1196,7 +1196,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(tint_symbol_2 : tint_symbol_3) {
   frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
 }
@@ -1211,12 +1211,12 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Msl_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
@@ -1243,7 +1243,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(tint_symbol : tint_symbol_1) {
   frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
 }
@@ -1259,7 +1259,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(tint_symbol_2 : tint_symbol_3) {
   frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
 }
@@ -1292,12 +1292,12 @@
   return x.value * x.mul;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
@@ -1324,7 +1324,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(tint_symbol : tint_symbol_1) {
   frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
 }
@@ -1340,7 +1340,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(tint_symbol_2 : tint_symbol_3) {
   frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
 }
@@ -1355,12 +1355,12 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, StructParameters_SharedDeviceFunction_Hlsl_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(inputs : FragmentInput) {
   var x : f32 = foo(inputs);
 }
@@ -1387,7 +1387,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(tint_symbol : tint_symbol_1) {
   frag_main1_inner(FragmentInput(tint_symbol.value, tint_symbol.mul));
 }
@@ -1403,7 +1403,7 @@
   var x : f32 = foo(inputs);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2(tint_symbol_2 : tint_symbol_3) {
   frag_main2_inner(FragmentInput(tint_symbol_2.value, tint_symbol_2.mul));
 }
@@ -1442,7 +1442,7 @@
   return global_inputs.col2 * 2.0;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
  global_inputs = inputs;
  var r : f32 = foo();
@@ -1479,7 +1479,7 @@
   var g : f32 = bar();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(tint_symbol : tint_symbol_1) {
   frag_main1_inner(FragmentInput(tint_symbol.col1, tint_symbol.col2));
 }
@@ -1494,7 +1494,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Struct_ModuleScopeVariable_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main1(inputs : FragmentInput) {
  global_inputs = inputs;
  var r : f32 = foo();
@@ -1531,7 +1531,7 @@
   var g : f32 = bar();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1(tint_symbol : tint_symbol_1) {
   frag_main1_inner(FragmentInput(tint_symbol.col1, tint_symbol.col2));
 }
@@ -1581,7 +1581,7 @@
   return x.col1;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(inputs : MyFragmentInput) -> MyFragmentOutput {
   var x : myf32 = foo(inputs);
   return MyFragmentOutput(x, inputs.col2);
@@ -1628,7 +1628,7 @@
   return MyFragmentOutput(x, inputs.col2);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = frag_main_inner(MyFragmentInput(tint_symbol.col1, tint_symbol.col2));
   var wrapper_result : tint_symbol_2;
@@ -1647,7 +1647,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Struct_TypeAliases_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(inputs : MyFragmentInput) -> MyFragmentOutput {
   var x : myf32 = foo(inputs);
   return MyFragmentOutput(x, inputs.col2);
@@ -1694,7 +1694,7 @@
   return MyFragmentOutput(x, inputs.col2);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = frag_main_inner(MyFragmentInput(tint_symbol.col1, tint_symbol.col2));
   var wrapper_result : tint_symbol_2;
@@ -1745,12 +1745,12 @@
   @location(2) @interpolate(linear, sample) loc2 : f32,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main() -> VertexOut {
   return VertexOut();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(inputs : FragmentIn,
              @location(3) @interpolate(perspective, centroid) loc3 : f32) {
   let x = inputs.loc1 + inputs.loc2 + loc3;
@@ -1785,7 +1785,7 @@
   return VertexOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> tint_symbol {
   let inner_result = vert_main_inner();
   var wrapper_result : tint_symbol;
@@ -1809,7 +1809,7 @@
   let x = ((inputs.loc1 + inputs.loc2) + loc3);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol_1 : tint_symbol_2) {
   frag_main_inner(FragmentIn(tint_symbol_1.loc1, tint_symbol_1.loc2), tint_symbol_1.loc3);
 }
@@ -1824,13 +1824,13 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, InterpolateAttributes_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(inputs : FragmentIn,
              @location(3) @interpolate(perspective, centroid) loc3 : f32) {
   let x = inputs.loc1 + inputs.loc2 + loc3;
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> VertexOut {
   return VertexOut();
 }
@@ -1862,7 +1862,7 @@
   let x = ((inputs.loc1 + inputs.loc2) + loc3);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) {
   frag_main_inner(FragmentIn(tint_symbol.loc1, tint_symbol.loc2), tint_symbol.loc3);
 }
@@ -1882,7 +1882,7 @@
   return VertexOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> tint_symbol_2 {
   let inner_result = vert_main_inner();
   var wrapper_result : tint_symbol_2;
@@ -1939,12 +1939,12 @@
   @location(3) @interpolate(flat) vu : vec4<u32>,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main(in : VertexIn) -> VertexOut {
   return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(inputs : FragmentInterface) -> FragmentInterface {
   return inputs;
 }
@@ -2012,7 +2012,7 @@
   return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() {
   let inner_result = vert_main_inner(VertexIn(i_1, u_1, vi_1, vu_1));
   i_2 = inner_result.i;
@@ -2026,7 +2026,7 @@
   return inputs;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   let inner_result_1 = frag_main_inner(FragmentInterface(i_3, u_3, vi_3, vu_3));
   i_4 = inner_result_1.i;
@@ -2047,12 +2047,12 @@
     // Test that we add a Flat attribute to integers that are vertex outputs and
     // fragment inputs, but not vertex inputs or fragment outputs.
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main(in : VertexIn) -> VertexOut {
   return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(inputs : FragmentInterface) -> FragmentInterface {
   return inputs;
 }
@@ -2120,7 +2120,7 @@
   return VertexOut(in.i, in.u, in.vi, in.vu, vec4<f32>());
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() {
   let inner_result = vert_main_inner(VertexIn(i_1, u_1, vi_1, vu_1));
   i_2 = inner_result.i;
@@ -2134,7 +2134,7 @@
   return inputs;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   let inner_result_1 = frag_main_inner(FragmentInterface(i_3, u_3, vi_3, vu_3));
   i_4 = inner_result_1.i;
@@ -2179,12 +2179,12 @@
   @builtin(position) @invariant pos : vec4<f32>,
 };
 
-@stage(vertex)
+@vertex
 fn main1() -> VertexOut {
   return VertexOut();
 }
 
-@stage(vertex)
+@vertex
 fn main2() -> @builtin(position) @invariant vec4<f32> {
   return vec4<f32>();
 }
@@ -2204,7 +2204,7 @@
   return VertexOut();
 }
 
-@stage(vertex)
+@vertex
 fn main1() -> tint_symbol {
   let inner_result = main1_inner();
   var wrapper_result : tint_symbol;
@@ -2221,7 +2221,7 @@
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn main2() -> tint_symbol_1 {
   let inner_result_1 = main2_inner();
   var wrapper_result_1 : tint_symbol_1;
@@ -2239,12 +2239,12 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, InvariantAttributes_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main1() -> VertexOut {
   return VertexOut();
 }
 
-@stage(vertex)
+@vertex
 fn main2() -> @builtin(position) @invariant vec4<f32> {
   return vec4<f32>();
 }
@@ -2264,7 +2264,7 @@
   return VertexOut();
 }
 
-@stage(vertex)
+@vertex
 fn main1() -> tint_symbol {
   let inner_result = main1_inner();
   var wrapper_result : tint_symbol;
@@ -2281,7 +2281,7 @@
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn main2() -> tint_symbol_1 {
   let inner_result_1 = main2_inner();
   var wrapper_result_1 : tint_symbol_1;
@@ -2313,7 +2313,7 @@
   @size(16) @location(1) @interpolate(flat) value : f32,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main(inputs : FragmentInput) -> FragmentOutput {
   return FragmentOutput(inputs.coord.x * inputs.value + inputs.loc0);
 }
@@ -2352,7 +2352,7 @@
   return FragmentOutput(((inputs.coord.x * inputs.value) + inputs.loc0));
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = frag_main_inner(FragmentInput(tint_symbol.value, tint_symbol.coord, tint_symbol.loc0));
   var wrapper_result : tint_symbol_2;
@@ -2370,7 +2370,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, Struct_LayoutAttributes_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main(inputs : FragmentInput) -> FragmentOutput {
   return FragmentOutput(inputs.coord.x * inputs.value + inputs.loc0);
 }
@@ -2405,7 +2405,7 @@
   return FragmentOutput(((inputs.coord.x * inputs.value) + inputs.loc0));
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = frag_main_inner(FragmentInput(tint_symbol.value, tint_symbol.coord, tint_symbol.loc0));
   var wrapper_result : tint_symbol_2;
@@ -2451,12 +2451,12 @@
   @location(0) a : f32,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main() -> VertexOutput {
   return VertexOutput();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(@builtin(front_facing) ff : bool,
              @location(2) @interpolate(flat) c : i32,
              inputs : FragmentInputExtra,
@@ -2496,7 +2496,7 @@
   return VertexOutput();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> tint_symbol {
   let inner_result = vert_main_inner();
   var wrapper_result : tint_symbol;
@@ -2526,7 +2526,7 @@
 fn frag_main_inner(ff : bool, c : i32, inputs : FragmentInputExtra, b : u32) {
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol_1 : tint_symbol_2) {
   frag_main_inner(tint_symbol_1.ff, tint_symbol_1.c, FragmentInputExtra(tint_symbol_1.d, tint_symbol_1.pos, tint_symbol_1.a), tint_symbol_1.b);
 }
@@ -2541,12 +2541,12 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, SortedMembers_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main() -> VertexOutput {
   return VertexOutput();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(@builtin(front_facing) ff : bool,
              @location(2) @interpolate(flat) c : i32,
              inputs : FragmentInputExtra,
@@ -2586,7 +2586,7 @@
   return VertexOutput();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> tint_symbol {
   let inner_result = vert_main_inner();
   var wrapper_result : tint_symbol;
@@ -2616,7 +2616,7 @@
 fn frag_main_inner(ff : bool, c : i32, inputs : FragmentInputExtra, b : u32) {
 }
 
-@stage(fragment)
+@fragment
 fn frag_main(tint_symbol_1 : tint_symbol_2) {
   frag_main_inner(tint_symbol_1.ff, tint_symbol_1.c, FragmentInputExtra(tint_symbol_1.d, tint_symbol_1.pos, tint_symbol_1.a), tint_symbol_1.b);
 }
@@ -2645,7 +2645,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, DontRenameSymbols) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn tint_symbol_1(@location(0) col : f32) {
 }
 )";
@@ -2659,7 +2659,7 @@
 fn tint_symbol_1_inner(col : f32) {
 }
 
-@stage(fragment)
+@fragment
 fn tint_symbol_1(tint_symbol : tint_symbol_2) {
   tint_symbol_1_inner(tint_symbol.col);
 }
@@ -2674,7 +2674,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_VoidNoReturn) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
 }
 )";
@@ -2688,7 +2688,7 @@
 fn frag_main_inner() {
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -2706,7 +2706,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_VoidWithReturn) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   return;
 }
@@ -2722,7 +2722,7 @@
   return;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -2740,7 +2740,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_WithAuthoredMask) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> @builtin(sample_mask) u32 {
   return 7u;
 }
@@ -2756,7 +2756,7 @@
   return 7u;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -2774,7 +2774,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_WithoutAuthoredMask) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> @location(0) f32 {
   return 1.0;
 }
@@ -2792,7 +2792,7 @@
   return 1.0;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -2817,7 +2817,7 @@
   @location(0) value : f32,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main() -> Output {
   return Output(0.5, 7u, 1.0);
 }
@@ -2843,7 +2843,7 @@
   return Output(0.5, 7u, 1.0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -2863,7 +2863,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithAuthoredMask_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> Output {
   return Output(0.5, 7u, 1.0);
 }
@@ -2889,7 +2889,7 @@
   return Output(0.5, 7u, 1.0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -2920,7 +2920,7 @@
   @location(0) value : f32,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main() -> Output {
   return Output(0.5, 1.0);
 }
@@ -2945,7 +2945,7 @@
   return Output(0.5, 1.0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -2965,7 +2965,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_StructWithoutAuthoredMask_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() -> Output {
   return Output(0.5, 1.0);
 }
@@ -2990,7 +2990,7 @@
   return Output(0.5, 1.0);
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -3015,22 +3015,22 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, FixedSampleMask_MultipleShaders) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn frag_main1() -> @builtin(sample_mask) u32 {
   return 7u;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2() -> @location(0) f32 {
   return 1.0;
 }
 
-@stage(vertex)
+@vertex
 fn vert_main1() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
 }
 )";
@@ -3045,7 +3045,7 @@
   return 7u;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main1() -> tint_symbol {
   let inner_result = frag_main1_inner();
   var wrapper_result : tint_symbol;
@@ -3064,7 +3064,7 @@
   return 1.0;
 }
 
-@stage(fragment)
+@fragment
 fn frag_main2() -> tint_symbol_1 {
   let inner_result_1 = frag_main2_inner();
   var wrapper_result_1 : tint_symbol_1;
@@ -3082,7 +3082,7 @@
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main1() -> tint_symbol_2 {
   let inner_result_2 = vert_main1_inner();
   var wrapper_result_2 : tint_symbol_2;
@@ -3090,7 +3090,7 @@
   return wrapper_result_2;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
 }
 )";
@@ -3109,7 +3109,7 @@
   @location(1) fixed_sample_mask_1 : vec4<f32>,
 };
 
-@stage(fragment)
+@fragment
 fn frag_main() -> FragOut {
   return FragOut();
 }
@@ -3134,7 +3134,7 @@
   return FragOut();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() -> tint_symbol {
   let inner_result = frag_main_inner();
   var wrapper_result : tint_symbol;
@@ -3154,7 +3154,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnNonStruct_Spirv) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -3169,7 +3169,7 @@
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() {
   let inner_result = vert_main_inner();
   value = inner_result;
@@ -3187,7 +3187,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnNonStruct_Msl) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -3205,7 +3205,7 @@
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> tint_symbol {
   let inner_result = vert_main_inner();
   var wrapper_result : tint_symbol;
@@ -3229,7 +3229,7 @@
   @builtin(position) pos : vec4<f32>,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main() -> VertOut {
   return VertOut();
 }
@@ -3248,7 +3248,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() {
   let inner_result = vert_main_inner();
   pos_1 = inner_result.pos;
@@ -3266,7 +3266,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Spirv_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main() -> VertOut {
   return VertOut();
 }
@@ -3285,7 +3285,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() {
   let inner_result = vert_main_inner();
   pos_1 = inner_result.pos;
@@ -3311,7 +3311,7 @@
   @builtin(position) pos : vec4<f32>,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main() -> VertOut {
   return VertOut();
 }
@@ -3333,7 +3333,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> tint_symbol {
   let inner_result = vert_main_inner();
   var wrapper_result : tint_symbol;
@@ -3353,7 +3353,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_ReturnStruct_Msl_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main() -> VertOut {
   return VertOut();
 }
@@ -3375,7 +3375,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() -> tint_symbol {
   let inner_result = vert_main_inner();
   var wrapper_result : tint_symbol;
@@ -3416,7 +3416,7 @@
   @builtin(position) vertex_point_size_1 : vec4<f32>,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
   let x = collide.collide + collide_1.collide;
   return VertOut();
@@ -3458,7 +3458,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() {
   let inner_result = vert_main_inner(VertIn1(collide_2), VertIn2(collide_3));
   vertex_point_size_3 = inner_result.vertex_point_size;
@@ -3477,7 +3477,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Spirv_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
   let x = collide.collide + collide_1.collide;
   return VertOut();
@@ -3517,7 +3517,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main() {
   let inner_result = vert_main_inner(VertIn1(collide_2), VertIn2(collide_3));
   vertex_point_size_3 = inner_result.vertex_point_size;
@@ -3568,7 +3568,7 @@
   @builtin(position) vertex_point_size_1 : vec4<f32>,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
   let x = collide.collide + collide_1.collide;
   return VertOut();
@@ -3610,7 +3610,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
   var wrapper_result : tint_symbol_2;
@@ -3631,7 +3631,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Msl_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
   let x = collide.collide + collide_1.collide;
   return VertOut();
@@ -3673,7 +3673,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
   var wrapper_result : tint_symbol_2;
@@ -3720,7 +3720,7 @@
   @builtin(position) vertex_point_size_1 : vec4<f32>,
 };
 
-@stage(vertex)
+@vertex
 fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
   let x = collide.collide + collide_1.collide;
   return VertOut();
@@ -3762,7 +3762,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
   var wrapper_result : tint_symbol_2;
@@ -3783,7 +3783,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, EmitVertexPointSize_AvoidNameClash_Hlsl_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main(collide : VertIn1, collide_1 : VertIn2) -> VertOut {
   let x = collide.collide + collide_1.collide;
   return VertOut();
@@ -3825,7 +3825,7 @@
   return VertOut();
 }
 
-@stage(vertex)
+@vertex
 fn vert_main(tint_symbol : tint_symbol_1) -> tint_symbol_2 {
   let inner_result = vert_main_inner(VertIn1(tint_symbol.collide), VertIn2(tint_symbol.collide_2));
   var wrapper_result : tint_symbol_2;
@@ -3859,7 +3859,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main(@builtin(sample_index) sample_index : u32,
         @builtin(sample_mask) mask_in : u32
         ) -> @builtin(sample_mask) u32 {
@@ -3878,7 +3878,7 @@
   return mask_in;
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   let inner_result = main_inner(sample_index_1, mask_in_1[0i]);
   value[0i] = inner_result;
@@ -3894,7 +3894,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, GLSLSampleMaskBuiltins) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn fragment_main(@builtin(sample_index) sample_index : u32,
                  @builtin(sample_mask) mask_in : u32
                  ) -> @builtin(sample_mask) u32 {
@@ -3913,7 +3913,7 @@
   return mask_in;
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   let inner_result = fragment_main(bitcast<u32>(gl_SampleID), bitcast<u32>(gl_SampleMaskIn[0i]));
   gl_SampleMask[0i] = bitcast<i32>(inner_result);
@@ -3929,7 +3929,7 @@
 
 TEST_F(CanonicalizeEntryPointIOTest, GLSLVertexInstanceIndexBuiltins) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vertex_main(@builtin(vertex_index) vertexID : u32,
                @builtin(instance_index) instanceID : u32
                ) -> @builtin(position) vec4<f32> {
@@ -3948,7 +3948,7 @@
   return vec4<f32>((f32(vertexID) + f32(instanceID)));
 }
 
-@stage(vertex)
+@vertex
 fn main() {
   let inner_result = vertex_main(bitcast<u32>(gl_VertexID), bitcast<u32>(gl_InstanceID));
   gl_Position = inner_result;
diff --git a/src/tint/transform/decompose_memory_access_test.cc b/src/tint/transform/decompose_memory_access_test.cc
index 19f8b2e..4b96bcb 100644
--- a/src/tint/transform/decompose_memory_access_test.cc
+++ b/src/tint/transform/decompose_memory_access_test.cc
@@ -78,7 +78,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = sb.a;
   var b : u32 = sb.b;
@@ -213,7 +213,7 @@
   return arr;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = tint_symbol(sb, 0u);
   var b : u32 = tint_symbol_1(sb, 4u);
@@ -247,7 +247,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = sb.a;
   var b : u32 = sb.b;
@@ -382,7 +382,7 @@
   return arr;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = tint_symbol(sb, 0u);
   var b : u32 = tint_symbol_1(sb, 4u);
@@ -470,7 +470,7 @@
 
 @group(0) @binding(0) var<uniform> ub : UB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = ub.a;
   var b : u32 = ub.b;
@@ -605,7 +605,7 @@
   return arr;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = tint_symbol(ub, 0u);
   var b : u32 = tint_symbol_1(ub, 4u);
@@ -639,7 +639,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, UB_BasicLoad_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = ub.a;
   var b : u32 = ub.b;
@@ -774,7 +774,7 @@
   return arr;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a : i32 = tint_symbol(ub, 0u);
   var b : u32 = tint_symbol_1(ub, 4u);
@@ -862,7 +862,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   sb.a = i32();
   sb.b = u32();
@@ -1014,7 +1014,7 @@
   }
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   tint_symbol(sb, 0u, i32());
   tint_symbol_1(sb, 4u, u32());
@@ -1048,7 +1048,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, SB_BasicStore_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   sb.a = i32();
   sb.b = u32();
@@ -1200,7 +1200,7 @@
   }
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   tint_symbol(sb, 0u, i32());
   tint_symbol_1(sb, 4u, u32());
@@ -1288,7 +1288,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : SB = sb;
 }
@@ -1406,7 +1406,7 @@
   return SB(tint_symbol_1(buffer, (offset + 0u)), tint_symbol_2(buffer, (offset + 4u)), tint_symbol_3(buffer, (offset + 8u)), tint_symbol_4(buffer, (offset + 16u)), tint_symbol_5(buffer, (offset + 24u)), tint_symbol_6(buffer, (offset + 32u)), tint_symbol_7(buffer, (offset + 48u)), tint_symbol_8(buffer, (offset + 64u)), tint_symbol_9(buffer, (offset + 80u)), tint_symbol_10(buffer, (offset + 96u)), tint_symbol_11(buffer, (offset + 112u)), tint_symbol_12(buffer, (offset + 128u)), tint_symbol_13(buffer, (offset + 144u)), tint_symbol_14(buffer, (offset + 160u)), tint_symbol_15(buffer, (offset + 192u)), tint_symbol_16(buffer, (offset + 224u)), tint_symbol_17(buffer, (offset + 256u)), tint_symbol_18(buffer, (offset + 304u)), tint_symbol_19(buffer, (offset + 352u)), tint_symbol_20(buffer, (offset + 384u)), tint_symbol_21(buffer, (offset + 448u)), tint_symbol_22(buffer, (offset + 512u)));
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : SB = tint_symbol(sb, 0u);
 }
@@ -1419,7 +1419,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, LoadStructure_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : SB = sb;
 }
@@ -1537,7 +1537,7 @@
   return SB(tint_symbol_1(buffer, (offset + 0u)), tint_symbol_2(buffer, (offset + 4u)), tint_symbol_3(buffer, (offset + 8u)), tint_symbol_4(buffer, (offset + 16u)), tint_symbol_5(buffer, (offset + 24u)), tint_symbol_6(buffer, (offset + 32u)), tint_symbol_7(buffer, (offset + 48u)), tint_symbol_8(buffer, (offset + 64u)), tint_symbol_9(buffer, (offset + 80u)), tint_symbol_10(buffer, (offset + 96u)), tint_symbol_11(buffer, (offset + 112u)), tint_symbol_12(buffer, (offset + 128u)), tint_symbol_13(buffer, (offset + 144u)), tint_symbol_14(buffer, (offset + 160u)), tint_symbol_15(buffer, (offset + 192u)), tint_symbol_16(buffer, (offset + 224u)), tint_symbol_17(buffer, (offset + 256u)), tint_symbol_18(buffer, (offset + 304u)), tint_symbol_19(buffer, (offset + 352u)), tint_symbol_20(buffer, (offset + 384u)), tint_symbol_21(buffer, (offset + 448u)), tint_symbol_22(buffer, (offset + 512u)));
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : SB = tint_symbol(sb, 0u);
 }
@@ -1604,7 +1604,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   sb = SB();
 }
@@ -1760,7 +1760,7 @@
   tint_symbol_22(buffer, (offset + 512u), value.v);
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   tint_symbol(sb, 0u, SB());
 }
@@ -1773,7 +1773,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, StoreStructure_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   sb = SB();
 }
@@ -1929,7 +1929,7 @@
   tint_symbol_22(buffer, (offset + 512u), value.v);
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   tint_symbol(sb, 0u, SB());
 }
@@ -1993,7 +1993,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : f32 = sb.b[4].b[1].b.z;
 }
@@ -2030,7 +2030,7 @@
 @internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
 fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : f32 = tint_symbol(sb, 712u);
 }
@@ -2043,7 +2043,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, ComplexStaticAccessChain_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : f32 = sb.b[4].b[1].b.z;
 }
@@ -2080,7 +2080,7 @@
 @internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
 fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var x : f32 = tint_symbol(sb, 712u);
 }
@@ -2133,7 +2133,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2166,7 +2166,7 @@
 @internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
 fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2182,7 +2182,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChain_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2215,7 +2215,7 @@
 @internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
 fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2279,7 +2279,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2320,7 +2320,7 @@
 @internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
 fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2336,7 +2336,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChainWithAliases_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2377,7 +2377,7 @@
 @internal(intrinsic_load_storage_f32) @internal(disable_validation__function_has_no_body)
 fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32) -> f32
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var i : i32 = 4;
   var j : u32 = 1u;
@@ -2429,7 +2429,7 @@
 
 @group(0) @binding(0) var<storage, read_write> sb : SB;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   atomicStore(&sb.a, 123);
   atomicLoad(&sb.a);
@@ -2542,7 +2542,7 @@
 @internal(intrinsic_atomic_compare_exchange_weak_storage_u32) @internal(disable_validation__function_has_no_body)
 fn tint_atomicCompareExchangeWeak_1(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32, param_1 : u32, param_2 : u32) -> atomic_compare_exchange_weak_ret_type_1
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   tint_atomicStore(sb, 16u, 123);
   tint_atomicLoad(sb, 16u);
@@ -2576,7 +2576,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, StorageBufferAtomics_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   atomicStore(&sb.a, 123);
   atomicLoad(&sb.a);
@@ -2689,7 +2689,7 @@
 @internal(intrinsic_atomic_compare_exchange_weak_storage_u32) @internal(disable_validation__function_has_no_body)
 fn tint_atomicCompareExchangeWeak_1(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, offset : u32, param_1 : u32, param_2 : u32) -> atomic_compare_exchange_weak_ret_type_1
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   tint_atomicStore(sb, 16u, 123);
   tint_atomicLoad(sb, 16u);
@@ -2739,7 +2739,7 @@
 
 var<workgroup> w : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   atomicStore(&(w.a), 123);
   atomicLoad(&(w.a));
@@ -2775,7 +2775,7 @@
 
 TEST_F(DecomposeMemoryAccessTest, WorkgroupBufferAtomics_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   atomicStore(&(w.a), 123);
   atomicLoad(&(w.a));
diff --git a/src/tint/transform/decompose_strided_array_test.cc b/src/tint/transform/decompose_strided_array_test.cc
index e911df3..ffc6695 100644
--- a/src/tint/transform/decompose_strided_array_test.cc
+++ b/src/tint/transform/decompose_strided_array_test.cc
@@ -71,7 +71,7 @@
 TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
     // var<private> arr : @stride(4) array<f32, 4u>
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : @stride(4) array<f32, 4u> = a;
     //   let b : f32 = arr[1];
@@ -92,7 +92,7 @@
     auto* expect = R"(
 var<private> arr : array<f32, 4u>;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : array<f32, 4u> = arr;
   let b : f32 = arr[1i];
@@ -107,7 +107,7 @@
 TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
     // var<private> arr : @stride(32) array<f32, 4u>
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : @stride(32) array<f32, 4u> = a;
     //   let b : f32 = arr[1];
@@ -133,7 +133,7 @@
 
 var<private> arr : array<strided_arr, 4u>;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : array<strided_arr, 4u> = arr;
   let b : f32 = arr[1i].el;
@@ -151,7 +151,7 @@
     // };
     // @group(0) @binding(0) var<uniform> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : @stride(32) array<f32, 4u> = s.a;
     //   let b : f32 = s.a[1];
@@ -181,7 +181,7 @@
 
 @group(0) @binding(0) var<uniform> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : array<strided_arr, 4u> = s.a;
   let b : f32 = s.a[1i].el;
@@ -199,7 +199,7 @@
     // };
     // @group(0) @binding(0) var<uniform> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : @stride(16) array<vec4<f32>, 4u> = s.a;
     //   let b : f32 = s.a[1][2];
@@ -227,7 +227,7 @@
 
 @group(0) @binding(0) var<uniform> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : array<vec4<f32>, 4u> = s.a;
   let b : f32 = s.a[1i][2i];
@@ -245,7 +245,7 @@
     // };
     // @group(0) @binding(0) var<storage> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : @stride(32) array<f32, 4u> = s.a;
     //   let b : f32 = s.a[1];
@@ -275,7 +275,7 @@
 
 @group(0) @binding(0) var<storage> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : array<strided_arr, 4u> = s.a;
   let b : f32 = s.a[1i].el;
@@ -293,7 +293,7 @@
     // };
     // @group(0) @binding(0) var<storage> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : @stride(4) array<f32, 4u> = s.a;
     //   let b : f32 = s.a[1];
@@ -318,7 +318,7 @@
 
 @group(0) @binding(0) var<storage> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : array<f32, 4u> = s.a;
   let b : f32 = s.a[1i];
@@ -336,7 +336,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   s.a = @stride(32) array<f32, 4u>();
     //   s.a = @stride(32) array<f32, 4u>(1.0, 2.0, 3.0, 4.0);
@@ -371,7 +371,7 @@
 
 @group(0) @binding(0) var<storage, read_write> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   s.a = array<strided_arr, 4u>();
   s.a = array<strided_arr, 4u>(strided_arr(1.0f), strided_arr(2.0f), strided_arr(3.0f), strided_arr(4.0f));
@@ -390,7 +390,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   s.a = @stride(4) array<f32, 4u>();
     //   s.a = @stride(4) array<f32, 4u>(1.0, 2.0, 3.0, 4.0);
@@ -420,7 +420,7 @@
 
 @group(0) @binding(0) var<storage, read_write> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   s.a = array<f32, 4u>();
   s.a = array<f32, 4u>(1.0f, 2.0f, 3.0f, 4.0f);
@@ -439,7 +439,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a = &s.a;
     //   let b = &*&*(a);
@@ -479,7 +479,7 @@
 
 @group(0) @binding(0) var<storage, read_write> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let c = s.a;
   let d = s.a[1i].el;
@@ -500,7 +500,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : ARR = s.a;
     //   let b : f32 = s.a[1];
@@ -541,7 +541,7 @@
 
 @group(0) @binding(0) var<storage, read_write> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : ARR = s.a;
   let b : f32 = s.a[1i].el;
@@ -564,7 +564,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a : ARR_B = s.a;
     //   let b : array<@stride(8) array<f32, 2u>, 3u> = s.a[3];
@@ -641,7 +641,7 @@
 
 @group(0) @binding(0) var<storage, read_write> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let a : ARR_B = s.a;
   let b : array<ARR_A, 3u> = s.a[3i].el;
diff --git a/src/tint/transform/decompose_strided_matrix_test.cc b/src/tint/transform/decompose_strided_matrix_test.cc
index 2367cf1..06169d7 100644
--- a/src/tint/transform/decompose_strided_matrix_test.cc
+++ b/src/tint/transform/decompose_strided_matrix_test.cc
@@ -62,7 +62,7 @@
     // };
     // @group(0) @binding(0) var<uniform> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let x : mat2x2<f32> = s.m;
     // }
@@ -99,7 +99,7 @@
   return mat2x2<f32>(arr[0u], arr[1u]);
 }
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
 }
@@ -118,7 +118,7 @@
     // };
     // @group(0) @binding(0) var<uniform> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let x : vec2<f32> = s.m[1];
     // }
@@ -152,7 +152,7 @@
 
 @group(0) @binding(0) var<uniform> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let x : vec2<f32> = s.m[1i];
 }
@@ -171,7 +171,7 @@
     // };
     // @group(0) @binding(0) var<uniform> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let x : mat2x2<f32> = s.m;
     // }
@@ -205,7 +205,7 @@
 
 @group(0) @binding(0) var<uniform> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let x : mat2x2<f32> = s.m;
 }
@@ -224,7 +224,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let x : mat2x2<f32> = s.m;
     // }
@@ -262,7 +262,7 @@
   return mat2x2<f32>(arr[0u], arr[1u]);
 }
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
 }
@@ -281,7 +281,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let x : vec2<f32> = s.m[1];
     // }
@@ -316,7 +316,7 @@
 
 @group(0) @binding(0) var<storage, read_write> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let x : vec2<f32> = s.m[1i];
 }
@@ -335,7 +335,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
     // }
@@ -374,7 +374,7 @@
   return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
 }
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0f, 2.0f), vec2<f32>(3.0f, 4.0f)));
 }
@@ -393,7 +393,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   s.m[1] = vec2<f32>(1.0, 2.0);
     // }
@@ -427,7 +427,7 @@
 
 @group(0) @binding(0) var<storage, read_write> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   s.m[1i] = vec2<f32>(1.0f, 2.0f);
 }
@@ -446,7 +446,7 @@
     // };
     // @group(0) @binding(0) var<storage, read_write> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let a = &s.m;
     //   let b = &*&*(a);
@@ -500,7 +500,7 @@
   return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
 }
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let x = arr_to_mat2x2_stride_32(s.m);
   let y = s.m[1i];
@@ -523,7 +523,7 @@
     // };
     // var<private> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   let x : mat2x2<f32> = s.m;
     // }
@@ -557,7 +557,7 @@
 
 var<private> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   let x : mat2x2<f32> = s.m;
 }
@@ -576,7 +576,7 @@
     // };
     // var<private> s : S;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn f() {
     //   s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
     // }
@@ -611,7 +611,7 @@
 
 var<private> s : S;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn f() {
   s.m = mat2x2<f32>(vec2<f32>(1.0f, 2.0f), vec2<f32>(3.0f, 4.0f));
 }
diff --git a/src/tint/transform/disable_uniformity_analysis_test.cc b/src/tint/transform/disable_uniformity_analysis_test.cc
index a502442..3bcac0e 100644
--- a/src/tint/transform/disable_uniformity_analysis_test.cc
+++ b/src/tint/transform/disable_uniformity_analysis_test.cc
@@ -54,7 +54,7 @@
     auto* src = R"(
 @group(0) @binding(0) var<storage, read> global : i32;
 
-@stage(compute) @workgroup_size(64)
+@compute @workgroup_size(64)
 fn main() {
   if ((global == 42)) {
     workgroupBarrier();
diff --git a/src/tint/transform/first_index_offset_test.cc b/src/tint/transform/first_index_offset_test.cc
index a467c17..c159261 100644
--- a/src/tint/transform/first_index_offset_test.cc
+++ b/src/tint/transform/first_index_offset_test.cc
@@ -33,7 +33,7 @@
 
 TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn entry() {
   return;
 }
@@ -44,7 +44,7 @@
 
 TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -70,7 +70,7 @@
 
 TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -95,7 +95,7 @@
   return vert_idx;
 }
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   test(vert_idx);
   return vec4<f32>();
@@ -114,7 +114,7 @@
   return vert_idx;
 }
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   test((vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
@@ -135,7 +135,7 @@
 
 TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   test(vert_idx);
   return vec4<f32>();
@@ -154,7 +154,7 @@
 
 @binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   test((vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
@@ -183,7 +183,7 @@
   return inst_idx;
 }
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   test(inst_idx);
   return vec4<f32>();
@@ -202,7 +202,7 @@
   return inst_idx;
 }
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   test((inst_idx + tint_symbol_1.first_instance_index));
   return vec4<f32>();
@@ -223,7 +223,7 @@
 
 TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   test(inst_idx);
   return vec4<f32>();
@@ -242,7 +242,7 @@
 
 @binding(1) @group(7) var<uniform> tint_symbol_1 : tint_symbol;
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   test((inst_idx + tint_symbol_1.first_instance_index));
   return vec4<f32>();
@@ -276,7 +276,7 @@
   @builtin(vertex_index) vert_idx : u32,
 };
 
-@stage(vertex)
+@vertex
 fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
   test(inputs.instance_idx, inputs.vert_idx);
   return vec4<f32>();
@@ -302,7 +302,7 @@
   vert_idx : u32,
 }
 
-@stage(vertex)
+@vertex
 fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
   test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
@@ -323,7 +323,7 @@
 
 TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
   test(inputs.instance_idx, inputs.vert_idx);
   return vec4<f32>();
@@ -347,7 +347,7 @@
 
 @binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
 
-@stage(vertex)
+@vertex
 fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
   test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
@@ -387,7 +387,7 @@
   return func1(vert_idx);
 }
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func2(vert_idx);
   return vec4<f32>();
@@ -410,7 +410,7 @@
   return func1(vert_idx);
 }
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func2((vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
@@ -431,7 +431,7 @@
 
 TEST_F(FirstIndexOffsetTest, NestedCalls_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func2(vert_idx);
   return vec4<f32>();
@@ -454,7 +454,7 @@
 
 @binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func2((vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
@@ -487,19 +487,19 @@
   return i;
 }
 
-@stage(vertex)
+@vertex
 fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func(vert_idx);
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func(vert_idx + inst_idx);
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func(inst_idx);
   return vec4<f32>();
@@ -518,19 +518,19 @@
   return i;
 }
 
-@stage(vertex)
+@vertex
 fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func((vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func((inst_idx + tint_symbol_1.first_instance_index));
   return vec4<f32>();
@@ -551,19 +551,19 @@
 
 TEST_F(FirstIndexOffsetTest, MultipleEntryPoints_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func(vert_idx);
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func(vert_idx + inst_idx);
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func(inst_idx);
   return vec4<f32>();
@@ -582,19 +582,19 @@
 
 @binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
 
-@stage(vertex)
+@vertex
 fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
   func((vert_idx + tint_symbol_1.first_vertex_index));
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
   return vec4<f32>();
 }
 
-@stage(vertex)
+@vertex
 fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
   func((inst_idx + tint_symbol_1.first_instance_index));
   return vec4<f32>();
diff --git a/src/tint/transform/localize_struct_array_assignment_test.cc b/src/tint/transform/localize_struct_array_assignment_test.cc
index ee6df9f..e85a600 100644
--- a/src/tint/transform/localize_struct_array_assignment_test.cc
+++ b/src/tint/transform/localize_struct_array_assignment_test.cc
@@ -46,7 +46,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -69,7 +69,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -88,7 +88,7 @@
 
 TEST_F(LocalizeStructArrayAssignmentTest, StructArray_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -111,7 +111,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -162,7 +162,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -189,7 +189,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -208,7 +208,7 @@
 
 TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -235,7 +235,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -287,7 +287,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -311,7 +311,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -348,7 +348,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -375,7 +375,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -413,7 +413,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s : OuterS;
@@ -441,7 +441,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s : OuterS;
@@ -488,7 +488,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s : OuterS;
@@ -523,7 +523,7 @@
 
 @group(1) @binding(4) var<uniform> uniforms : Uniforms;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s : OuterS;
@@ -545,7 +545,7 @@
 
 TEST_F(LocalizeStructArrayAssignmentTest, IndexingWithSideEffectFunc_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s : OuterS;
@@ -579,7 +579,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s : OuterS;
@@ -643,7 +643,7 @@
   (*p).a1[uniforms.i] = v;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var s1 : OuterS;
   f(&s1);
@@ -675,7 +675,7 @@
   }
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var s1 : OuterS;
   f(&(s1));
@@ -688,7 +688,7 @@
 
 TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var s1 : OuterS;
   f(&s1);
@@ -714,7 +714,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var s1 : OuterS;
   f(&(s1));
@@ -769,7 +769,7 @@
   *(p) = v;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -797,7 +797,7 @@
   *(p) = v;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var v : InnerS;
   var s1 : OuterS;
@@ -831,7 +831,7 @@
   return (i + 1u);
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var s1 : OuterS;
   var v : vec3<f32>;
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param_test.cc b/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
index 580e695..9e81d31 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
+++ b/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
@@ -42,14 +42,14 @@
 var<private> p : f32;
 var<workgroup> w : f32;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   w = p;
 }
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol : f32;
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol_1 : f32;
@@ -64,7 +64,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Basic_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   w = p;
 }
@@ -74,7 +74,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol : f32;
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol_1 : f32;
@@ -106,7 +106,7 @@
   no_uses();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo(1.0);
 }
@@ -127,7 +127,7 @@
   no_uses();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol_4 : f32;
   @internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol_5 : f32;
@@ -142,7 +142,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo(1.0);
 }
@@ -166,7 +166,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
   @internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol_1 : f32;
@@ -198,14 +198,14 @@
 var<private> a : f32 = 1.0;
 var<private> b : f32 = f32();
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x : f32 = a + b;
 }
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32 = 1.0;
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol_1 : f32 = f32();
@@ -220,7 +220,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Constructors_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x : f32 = a + b;
 }
@@ -230,7 +230,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32 = 1.0;
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol_1 : f32 = f32();
@@ -248,7 +248,7 @@
 var<private> p : f32;
 var<workgroup> w : f32;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let p_ptr : ptr<private, f32> = &p;
   let w_ptr : ptr<workgroup, f32> = &w;
@@ -258,7 +258,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
   @internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol_1 : f32;
@@ -276,7 +276,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let p_ptr : ptr<private, f32> = &p;
   let w_ptr : ptr<workgroup, f32> = &w;
@@ -289,7 +289,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
   @internal(disable_validation__ignore_storage_class) var<workgroup> tint_symbol_1 : f32;
@@ -317,7 +317,7 @@
   bar(&v);
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo();
 }
@@ -332,7 +332,7 @@
   bar(tint_symbol);
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol_1 : f32;
   foo(&(tint_symbol_1));
@@ -346,7 +346,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, FoldAddressOfDeref_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo();
 }
@@ -363,7 +363,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   @internal(disable_validation__ignore_storage_class) var<private> tint_symbol : f32;
   foo(&(tint_symbol));
@@ -394,7 +394,7 @@
 @group(0) @binding(1)
 var<storage> s : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = u;
   _ = s;
@@ -406,7 +406,7 @@
   a : f32,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_1 : ptr<storage, S>) {
   _ = *(tint_symbol);
   _ = *(tint_symbol_1);
@@ -420,7 +420,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = u;
   _ = s;
@@ -436,7 +436,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_1 : ptr<storage, S>) {
   _ = *(tint_symbol);
   _ = *(tint_symbol_1);
@@ -457,7 +457,7 @@
 @group(0) @binding(0)
 var<storage> buffer : array<f32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = buffer[0];
 }
@@ -468,7 +468,7 @@
   arr : array<f32>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<storage, tint_symbol_1>) {
   _ = (*(tint_symbol)).arr[0];
 }
@@ -481,7 +481,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = buffer[0];
 }
@@ -495,7 +495,7 @@
   arr : array<f32>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<storage, tint_symbol_1>) {
   _ = (*(tint_symbol)).arr[0];
 }
@@ -515,7 +515,7 @@
   _ = buffer[0];
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo();
 }
@@ -530,7 +530,7 @@
   _ = (*(tint_symbol))[0];
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_1 : ptr<storage, tint_symbol_2>) {
   foo(&((*(tint_symbol_1)).arr));
 }
@@ -543,7 +543,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo();
 }
@@ -560,7 +560,7 @@
   arr : array<f32>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<storage, tint_symbol_1>) {
   foo(&((*(tint_symbol)).arr));
 }
@@ -582,7 +582,7 @@
 @group(0) @binding(0)
 var<storage> buffer : myarray;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = buffer[0];
 }
@@ -595,7 +595,7 @@
 
 type myarray = array<f32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<storage, tint_symbol_1>) {
   _ = (*(tint_symbol)).arr[0];
 }
@@ -608,7 +608,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = buffer[0];
 }
@@ -623,7 +623,7 @@
   arr : array<f32>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<storage, tint_symbol_1>) {
   _ = (*(tint_symbol)).arr[0];
 }
@@ -645,7 +645,7 @@
 @group(0) @binding(0)
 var<storage> buffer : array<S>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = buffer[0];
 }
@@ -660,7 +660,7 @@
   arr : array<S>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<storage, tint_symbol_1>) {
   _ = (*(tint_symbol)).arr[0];
 }
@@ -673,7 +673,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_ArrayOfStruct_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = buffer[0];
 }
@@ -694,7 +694,7 @@
   arr : array<S>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<storage, tint_symbol_1>) {
   _ = (*(tint_symbol)).arr[0];
 }
@@ -731,7 +731,7 @@
   no_uses();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo(1.0);
 }
@@ -757,7 +757,7 @@
   no_uses();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_4 : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_5 : ptr<storage, S>) {
   foo(1.0, tint_symbol_4, tint_symbol_5);
 }
@@ -770,7 +770,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo(1.0);
 }
@@ -801,7 +801,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol : ptr<uniform, S>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) @internal(disable_validation__ignore_storage_class) tint_symbol_1 : ptr<storage, S>) {
   foo(1.0, tint_symbol, tint_symbol_1);
 }
@@ -836,7 +836,7 @@
 @group(0) @binding(0) var t : texture_2d<f32>;
 @group(0) @binding(1) var s : sampler;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   _ = t;
   _ = s;
@@ -844,7 +844,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_1 : sampler) {
   _ = tint_symbol;
   _ = tint_symbol_1;
@@ -876,7 +876,7 @@
   no_uses();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo(1.0);
 }
@@ -898,7 +898,7 @@
   no_uses();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol_4 : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_5 : sampler) {
   foo(1.0, tint_symbol_4, tint_symbol_5);
 }
@@ -911,7 +911,7 @@
 
 TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   foo(1.0);
 }
@@ -936,7 +936,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@group(0) @binding(0) @internal(disable_validation__entry_point_parameter) tint_symbol : texture_2d<f32>, @group(0) @binding(1) @internal(disable_validation__entry_point_parameter) tint_symbol_1 : sampler) {
   foo(1.0, tint_symbol, tint_symbol_1);
 }
@@ -966,7 +966,7 @@
     auto* src = R"(
 var<workgroup> m : mat2x2<f32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x = m;
 }
@@ -977,7 +977,7 @@
   m : mat2x2<f32>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@internal(disable_validation__entry_point_parameter) tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
   let tint_symbol : ptr<workgroup, mat2x2<f32>> = &((*(tint_symbol_1)).m);
   let x = *(tint_symbol);
@@ -999,7 +999,7 @@
 };
 var<workgroup> m : array<S2, 4>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x = m;
 }
@@ -1018,7 +1018,7 @@
   m : array<S2, 4u>,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@internal(disable_validation__entry_point_parameter) tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
   let tint_symbol : ptr<workgroup, array<S2, 4u>> = &((*(tint_symbol_1)).m);
   let x = *(tint_symbol);
@@ -1042,7 +1042,7 @@
 
 var<workgroup> b : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x = a;
   let y = b;
@@ -1059,7 +1059,7 @@
   b : S,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@internal(disable_validation__entry_point_parameter) tint_symbol_1 : ptr<workgroup, tint_symbol_3>) {
   let tint_symbol : ptr<workgroup, S> = &((*(tint_symbol_1)).a);
   let tint_symbol_2 : ptr<workgroup, S> = &((*(tint_symbol_1)).b);
@@ -1077,7 +1077,7 @@
 // variables that are promoted to threadgroup memory arguments.
 TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x = a;
   let y = b;
@@ -1101,7 +1101,7 @@
   b : S,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@internal(disable_validation__entry_point_parameter) tint_symbol_1 : ptr<workgroup, tint_symbol_3>) {
   let tint_symbol : ptr<workgroup, S> = &((*(tint_symbol_1)).a);
   let tint_symbol_2 : ptr<workgroup, S> = &((*(tint_symbol_1)).b);
@@ -1132,7 +1132,7 @@
 @group(0) @binding(2) var t : texture_2d<f32>;
 @group(0) @binding(3) var s : sampler;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
 }
 )";
@@ -1142,7 +1142,7 @@
   a : f32,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
 }
 )";
diff --git a/src/tint/transform/multiplanar_external_texture_test.cc b/src/tint/transform/multiplanar_external_texture_test.cc
index 39c4602..63d12f1 100644
--- a/src/tint/transform/multiplanar_external_texture_test.cc
+++ b/src/tint/transform/multiplanar_external_texture_test.cc
@@ -56,7 +56,7 @@
 @group(0) @binding(0) var s : sampler;
 @group(0) @binding(1) var ext_tex : texture_external;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleLevel(ext_tex, s, coord.xy);
 }
@@ -74,7 +74,7 @@
 @group(0) @binding(0) var s : sampler;
 @group(0) @binding(1) var ext_tex : texture_external;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleLevel(ext_tex, s, coord.xy);
 }
@@ -96,7 +96,7 @@
     auto* src = R"(
 @group(0) @binding(0) var ext_tex : texture_external;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   var dim : vec2<i32>;
   dim = textureDimensions(ext_tex);
@@ -130,7 +130,7 @@
 
 @group(0) @binding(0) var ext_tex : texture_2d<f32>;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   var dim : vec2<i32>;
   dim = textureDimensions(ext_tex);
@@ -148,7 +148,7 @@
 // Tests that the transform works with a textureDimensions call.
 TEST_F(MultiplanarExternalTextureTest, Dimensions_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   var dim : vec2<i32>;
   dim = textureDimensions(ext_tex);
@@ -182,7 +182,7 @@
 
 @group(0) @binding(2) var<uniform> ext_tex_params : ExternalTextureParams;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   var dim : vec2<i32>;
   dim = textureDimensions(ext_tex);
@@ -205,7 +205,7 @@
 @group(0) @binding(0) var s : sampler;
 @group(0) @binding(1) var ext_tex : texture_external;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleLevel(ext_tex, s, coord.xy);
 }
@@ -259,7 +259,7 @@
   return vec4<f32>(color, 1.0f);
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params);
 }
@@ -275,7 +275,7 @@
 // Test that the transform works with a textureSampleLevel call.
 TEST_F(MultiplanarExternalTextureTest, BasicTextureSampleLevel_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleLevel(ext_tex, s, coord.xy);
 }
@@ -328,7 +328,7 @@
   return vec4<f32>(color, 1.0f);
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params);
 }
@@ -350,7 +350,7 @@
     auto* src = R"(
 @group(0) @binding(0) var ext_tex : texture_external;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureLoad(ext_tex, vec2<i32>(1, 1));
 }
@@ -402,7 +402,7 @@
   return vec4<f32>(color, 1.0f);
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1, 1), ext_tex_params);
 }
@@ -418,7 +418,7 @@
 // Tests that the transform works with a textureLoad call.
 TEST_F(MultiplanarExternalTextureTest, BasicTextureLoad_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureLoad(ext_tex, vec2<i32>(1, 1));
 }
@@ -470,7 +470,7 @@
   return vec4<f32>(color, 1.0f);
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1, 1), ext_tex_params);
 }
@@ -492,7 +492,7 @@
 @group(0) @binding(0) var s : sampler;
 @group(0) @binding(1) var ext_tex : texture_external;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleLevel(ext_tex, s, coord.xy) + textureLoad(ext_tex, vec2<i32>(1, 1));
 }
@@ -559,7 +559,7 @@
   return vec4<f32>(color, 1.0f);
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return (textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params) + textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1, 1), ext_tex_params));
 }
@@ -576,7 +576,7 @@
 // call.
 TEST_F(MultiplanarExternalTextureTest, TextureSampleAndTextureLoad_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleLevel(ext_tex, s, coord.xy) + textureLoad(ext_tex, vec2<i32>(1, 1));
 }
@@ -642,7 +642,7 @@
   return vec4<f32>(color, 1.0f);
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return (textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params) + textureLoadExternal(ext_tex, ext_tex_plane_1, vec2<i32>(1, 1), ext_tex_params));
 }
@@ -668,7 +668,7 @@
 @group(0) @binding(3) var ext_tex_2 : texture_external;
 @group(1) @binding(0) var ext_tex_3 : texture_external;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return textureSampleLevel(ext_tex, s, coord.xy) + textureSampleLevel(ext_tex_1, s, coord.xy) + textureSampleLevel(ext_tex_2, s, coord.xy) + textureSampleLevel(ext_tex_3, s, coord.xy);
 }
@@ -740,7 +740,7 @@
   return vec4<f32>(color, 1.0f);
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
   return (((textureSampleExternal(ext_tex, ext_tex_plane_1, s, coord.xy, ext_tex_params) + textureSampleExternal(ext_tex_1, ext_tex_plane_1_1, s, coord.xy, ext_tex_params_1)) + textureSampleExternal(ext_tex_2, ext_tex_plane_1_2, s, coord.xy, ext_tex_params_2)) + textureSampleExternal(ext_tex_3, ext_tex_plane_1_3, s, coord.xy, ext_tex_params_3));
 }
@@ -768,7 +768,7 @@
 @group(0) @binding(0) var ext_tex : texture_external;
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp);
 }
@@ -826,7 +826,7 @@
 
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
 }
@@ -843,7 +843,7 @@
 // correct output.
 TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParam_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp);
 }
@@ -880,7 +880,7 @@
 
 @group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
 }
@@ -932,7 +932,7 @@
 @group(0) @binding(0) var ext_tex : texture_external;
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(smp, ext_tex);
 }
@@ -990,7 +990,7 @@
 
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(smp, ext_tex, ext_tex_plane_1, ext_tex_params);
 }
@@ -1016,7 +1016,7 @@
 @group(0) @binding(1) var smp : sampler;
 @group(0) @binding(2) var ext_tex2 : texture_external;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp, ext_tex2);
 }
@@ -1081,7 +1081,7 @@
 
 @group(0) @binding(2) var ext_tex2 : texture_2d<f32>;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp, ext_tex2, ext_tex_plane_1_1, ext_tex_params_1);
 }
@@ -1099,7 +1099,7 @@
 // correct output.
 TEST_F(MultiplanarExternalTextureTest, ExternalTexturePassedAsParamMultiple_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp, ext_tex2);
 }
@@ -1143,7 +1143,7 @@
 
 @group(0) @binding(6) var<uniform> ext_tex_params_1 : ExternalTextureParams;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp, ext_tex2, ext_tex_plane_1_1, ext_tex_params_1);
 }
@@ -1203,7 +1203,7 @@
 @group(0) @binding(0) var ext_tex : texture_external;
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp);
 }
@@ -1265,7 +1265,7 @@
 
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
 }
@@ -1293,7 +1293,7 @@
 @group(0) @binding(0) var ext_tex : texture_external;
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp);
 }
@@ -1355,7 +1355,7 @@
 
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
 }
@@ -1421,7 +1421,7 @@
 @group(0) @binding(0) var ext_tex : ET;
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp);
 }
@@ -1481,7 +1481,7 @@
 
 @group(0) @binding(1) var smp : sampler;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
 }
@@ -1497,7 +1497,7 @@
 // Tests that the the transform handles aliases to external textures
 TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias_OutOfOrder) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, smp);
 }
@@ -1536,7 +1536,7 @@
 
 @group(0) @binding(3) var<uniform> ext_tex_params : ExternalTextureParams;
 
-@stage(fragment)
+@fragment
 fn main() {
   f(ext_tex, ext_tex_plane_1, ext_tex_params, smp);
 }
diff --git a/src/tint/transform/num_workgroups_from_uniform_test.cc b/src/tint/transform/num_workgroups_from_uniform_test.cc
index ffc0ca8..8562c01 100644
--- a/src/tint/transform/num_workgroups_from_uniform_test.cc
+++ b/src/tint/transform/num_workgroups_from_uniform_test.cc
@@ -33,7 +33,7 @@
 
 TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
 }
 )";
@@ -43,7 +43,7 @@
 
 TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
 }
 )";
@@ -61,7 +61,7 @@
 
 TEST_F(NumWorkgroupsFromUniformTest, Basic) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
   let groups_x = num_wgs.x;
   let groups_y = num_wgs.y;
@@ -82,7 +82,7 @@
   let groups_z = num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   main_inner(tint_symbol_3.num_workgroups);
 }
@@ -101,7 +101,7 @@
   @builtin(num_workgroups) num_wgs : vec3<u32>,
 };
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(in : Builtins) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
@@ -126,7 +126,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   main_inner(Builtins(tint_symbol_3.num_workgroups));
 }
@@ -141,7 +141,7 @@
 
 TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(in : Builtins) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
@@ -166,7 +166,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   main_inner(Builtins(tint_symbol_3.num_workgroups));
 }
@@ -191,7 +191,7 @@
   @builtin(workgroup_id) wgid : vec3<u32>,
 };
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(in : Builtins) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
@@ -225,7 +225,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(tint_symbol : tint_symbol_1) {
   main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
 }
@@ -240,7 +240,7 @@
 
 TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(in : Builtins) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
@@ -275,7 +275,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(tint_symbol : tint_symbol_1) {
   main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
 }
@@ -306,21 +306,21 @@
   @builtin(workgroup_id) wgid : vec3<u32>,
 };
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main1(in : Builtins1) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main2(in : Builtins2) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
   let groups_x = num_wgs.x;
   let groups_y = num_wgs.y;
@@ -351,7 +351,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main1() {
   main1_inner(Builtins1(tint_symbol_7.num_workgroups));
 }
@@ -369,7 +369,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main2(tint_symbol_2 : tint_symbol_3) {
   main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
 }
@@ -380,7 +380,7 @@
   let groups_z = num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main3() {
   main3_inner(tint_symbol_7.num_workgroups);
 }
@@ -400,7 +400,7 @@
   @builtin(workgroup_id) wgid : vec3<u32>,
 };
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(in : Builtins) {
 }
 )";
@@ -421,7 +421,7 @@
 fn main_inner(in : Builtins) {
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main(tint_symbol : tint_symbol_1) {
   main_inner(Builtins(tint_symbol.gid, tint_symbol.wgid));
 }
@@ -448,21 +448,21 @@
   @builtin(workgroup_id) wgid : vec3<u32>,
 };
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main1(in : Builtins1) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main2(in : Builtins2) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
   let groups_x = num_wgs.x;
   let groups_y = num_wgs.y;
@@ -493,7 +493,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main1() {
   main1_inner(Builtins1(tint_symbol_7.num_workgroups));
 }
@@ -511,7 +511,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main2(tint_symbol_2 : tint_symbol_3) {
   main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
 }
@@ -522,7 +522,7 @@
   let groups_z = num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main3() {
   main3_inner(tint_symbol_7.num_workgroups);
 }
@@ -572,7 +572,7 @@
 @group(1) @binding(3) var<storage, read> g9 : S0;
 @group(3) @binding(2) var<storage, read_write> g10 : S0;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main1(in : Builtins1) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
@@ -580,14 +580,14 @@
   g8.m0 = 1u;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main2(in : Builtins2) {
   let groups_x = in.num_wgs.x;
   let groups_y = in.num_wgs.y;
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
   let groups_x = num_wgs.x;
   let groups_y = num_wgs.y;
@@ -647,7 +647,7 @@
   g8.m0 = 1u;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main1() {
   main1_inner(Builtins1(tint_symbol_7.num_workgroups));
 }
@@ -665,7 +665,7 @@
   let groups_z = in.num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main2(tint_symbol_2 : tint_symbol_3) {
   main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
 }
@@ -676,7 +676,7 @@
   let groups_z = num_wgs.z;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main3() {
   main3_inner(tint_symbol_7.num_workgroups);
 }
diff --git a/src/tint/transform/renamer_test.cc b/src/tint/transform/renamer_test.cc
index f56971e..516b164 100644
--- a/src/tint/transform/renamer_test.cc
+++ b/src/tint/transform/renamer_test.cc
@@ -50,7 +50,7 @@
   return vert_idx;
 }
 
-@stage(vertex)
+@vertex
 fn entry(@builtin(vertex_index) vert_idx : u32
         ) -> @builtin(position) vec4<f32>  {
   _ = test(vert_idx);
@@ -63,7 +63,7 @@
   return tint_symbol_1;
 }
 
-@stage(vertex)
+@vertex
 fn tint_symbol_2(@builtin(vertex_index) tint_symbol_1 : u32) -> @builtin(position) vec4<f32> {
   _ = tint_symbol(tint_symbol_1);
   return vec4<f32>();
@@ -87,7 +87,7 @@
 
 TEST_F(RenamerTest, PreserveSwizzles) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry() -> @builtin(position) vec4<f32> {
   var v : vec4<f32>;
   var rgba : f32;
@@ -97,7 +97,7 @@
 )";
 
     auto* expect = R"(
-@stage(vertex)
+@vertex
 fn tint_symbol() -> @builtin(position) vec4<f32> {
   var tint_symbol_1 : vec4<f32>;
   var tint_symbol_2 : f32;
@@ -124,7 +124,7 @@
 
 TEST_F(RenamerTest, PreserveBuiltins) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry() -> @builtin(position) vec4<f32> {
   var blah : vec4<f32>;
   return abs(blah);
@@ -132,7 +132,7 @@
 )";
 
     auto* expect = R"(
-@stage(vertex)
+@vertex
 fn tint_symbol() -> @builtin(position) vec4<f32> {
   var tint_symbol_1 : vec4<f32>;
   return abs(tint_symbol_1);
@@ -155,7 +155,7 @@
 
 TEST_F(RenamerTest, PreserveBuiltinTypes) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn entry() {
   var a = modf(1.0).whole;
   var b = modf(1.0).fract;
@@ -165,7 +165,7 @@
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn tint_symbol() {
   var tint_symbol_1 = modf(1.0).whole;
   var tint_symbol_2 = modf(1.0).fract;
@@ -190,7 +190,7 @@
 
 TEST_F(RenamerTest, PreserveUnicode) {
     auto src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   var )" + std::string(kUnicodeIdentifier) +
                R"( : i32;
@@ -209,7 +209,7 @@
 
 TEST_F(RenamerTest, AttemptSymbolCollision) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn entry() -> @builtin(position) vec4<f32> {
   var tint_symbol : vec4<f32>;
   var tint_symbol_2 : vec4<f32>;
@@ -219,7 +219,7 @@
 )";
 
     auto* expect = R"(
-@stage(vertex)
+@vertex
 fn tint_symbol() -> @builtin(position) vec4<f32> {
   var tint_symbol_1 : vec4<f32>;
   var tint_symbol_2 : vec4<f32>;
@@ -252,7 +252,7 @@
     auto keyword = GetParam();
 
     auto src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   var )" + keyword +
                R"( : i32;
@@ -260,7 +260,7 @@
 )";
 
     auto* expect = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   var tint_symbol : i32;
 }
@@ -278,7 +278,7 @@
     auto keyword = GetParam();
 
     auto src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   var )" + keyword +
                R"( : i32;
@@ -286,7 +286,7 @@
 )";
 
     auto* expect = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   var tint_symbol : i32;
 }
@@ -304,7 +304,7 @@
     auto keyword = GetParam();
 
     auto src = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   var )" + keyword +
                R"( : i32;
@@ -312,7 +312,7 @@
 )";
 
     auto* expect = R"(
-@stage(fragment)
+@fragment
 fn frag_main() {
   var tint_symbol : i32;
 }
diff --git a/src/tint/transform/simplify_pointers_test.cc b/src/tint/transform/simplify_pointers_test.cc
index f5658de..9848ff3 100644
--- a/src/tint/transform/simplify_pointers_test.cc
+++ b/src/tint/transform/simplify_pointers_test.cc
@@ -236,7 +236,7 @@
   return 1;
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   var arr = array<f32, 4>();
   for (let a = &arr[foo()]; ;) {
@@ -251,7 +251,7 @@
   return 1;
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   var arr = array<f32, 4>();
   let a_save = foo();
@@ -337,7 +337,7 @@
     auto* src = R"(
 var<private> a : array<i32, 2>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   let x = &a;
   var a : i32 = (*x)[0];
@@ -350,7 +350,7 @@
     auto* expect = R"(
 var<private> a : array<i32, 2>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
   var a_1 : i32 = a[0];
   {
diff --git a/src/tint/transform/single_entry_point_test.cc b/src/tint/transform/single_entry_point_test.cc
index 8044621..8445f61 100644
--- a/src/tint/transform/single_entry_point_test.cc
+++ b/src/tint/transform/single_entry_point_test.cc
@@ -47,7 +47,7 @@
 
 TEST_F(SingleEntryPointTest, Error_InvalidEntryPoint) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -68,7 +68,7 @@
     auto* src = R"(
 fn foo() {}
 
-@stage(fragment)
+@fragment
 fn main() {}
 )";
 
@@ -85,7 +85,7 @@
 
 TEST_F(SingleEntryPointTest, SingleEntryPoint) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn main() {
 }
 )";
@@ -101,26 +101,26 @@
 
 TEST_F(SingleEntryPointTest, MultipleEntryPoints) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn vert_main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main2() {
 }
 )";
 
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
 }
 )";
@@ -144,23 +144,23 @@
 
 var<private> d : f32;
 
-@stage(vertex)
+@vertex
 fn vert_main() -> @builtin(position) vec4<f32> {
   a = 0.0;
   return vec4<f32>();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   b = 0.0;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   c = 0.0;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main2() {
   d = 0.0;
 }
@@ -169,7 +169,7 @@
     auto* expect = R"(
 var<private> c : f32;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   c = 0.0;
 }
@@ -194,23 +194,23 @@
 
 let d : f32 = 1.0;
 
-@stage(vertex)
+@vertex
 fn vert_main() -> @builtin(position) vec4<f32> {
   let local_a : f32 = a;
   return vec4<f32>();
 }
 
-@stage(fragment)
+@fragment
 fn frag_main() {
   let local_b : f32 = b;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   let local_c : f32 = c;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main2() {
   let local_d : f32 = d;
 }
@@ -219,7 +219,7 @@
     auto* expect = R"(
 let c : f32 = 1.0;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   let local_c : f32 = c;
 }
@@ -238,7 +238,7 @@
     auto* src = R"(
 let size : i32 = 1;
 
-@stage(compute) @workgroup_size(size)
+@compute @workgroup_size(size)
 fn main() {
 }
 )";
@@ -261,27 +261,27 @@
 @id(0)    override c3 : u32 = 1u;
 @id(9999) override c4 : u32 = 1u;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
     let local_d = c1;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main2() {
     let local_d = c2;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main3() {
     let local_d = c3;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main4() {
     let local_d = c4;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main5() {
     let local_d = 1u;
 }
@@ -292,7 +292,7 @@
         auto* expect = R"(
 @id(1001) override c1 : u32 = 1u;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   let local_d = c1;
 }
@@ -310,7 +310,7 @@
         auto* expect = R"(
 @id(1) override c2 : u32 = 1u;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main2() {
   let local_d = c2;
 }
@@ -326,7 +326,7 @@
         auto* expect = R"(
 @id(0) override c3 : u32 = 1u;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main3() {
   let local_d = c3;
 }
@@ -342,7 +342,7 @@
         auto* expect = R"(
 @id(9999) override c4 : u32 = 1u;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main4() {
   let local_d = c4;
 }
@@ -356,7 +356,7 @@
     {
         SingleEntryPoint::Config cfg("comp_main5");
         auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main5() {
   let local_d = 1u;
 }
@@ -389,12 +389,12 @@
   inner_shared();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   outer1();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main2() {
   outer2();
 }
@@ -412,7 +412,7 @@
   inner_shared();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   outer1();
 }
@@ -463,12 +463,12 @@
   outer2_var = 0.0;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   outer1();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main2() {
   outer2();
 }
@@ -495,7 +495,7 @@
   outer1_var = 0.0;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn comp_main1() {
   outer1();
 }
diff --git a/src/tint/transform/unwind_discard_functions_test.cc b/src/tint/transform/unwind_discard_functions_test.cc
index 70b4218..481df9d 100644
--- a/src/tint/transform/unwind_discard_functions_test.cc
+++ b/src/tint/transform/unwind_discard_functions_test.cc
@@ -102,7 +102,7 @@
   let marker1 = 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   f();
   let marker1 = 0;
@@ -122,7 +122,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   f();
   if (tint_discard) {
@@ -156,7 +156,7 @@
   return s;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   f();
@@ -186,7 +186,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   f();
@@ -230,7 +230,7 @@
   return 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   h();
@@ -275,7 +275,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   h();
@@ -311,7 +311,7 @@
   let marker1 = 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   f();
@@ -348,7 +348,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   f();
@@ -381,7 +381,7 @@
 
 TEST_F(UnwindDiscardFunctionsTest, Call_DiscardFuncDeclaredBelow) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   f();
   let marker1 = 0;
@@ -400,7 +400,7 @@
 
 var<private> tint_discard : bool = false;
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   f();
   if (tint_discard) {
@@ -433,7 +433,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   if (f() == 42) {
     let marker1 = 0;
@@ -456,7 +456,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let tint_symbol = f();
   if (tint_discard) {
@@ -485,7 +485,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   if (true) {
     let marker1 = 0;
@@ -512,7 +512,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   if (true) {
     let marker1 = 0;
@@ -547,7 +547,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   var a = 0;
@@ -573,7 +573,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   var a = 0;
@@ -605,7 +605,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   for (f(); ; ) {
@@ -630,7 +630,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   var tint_symbol = f();
@@ -661,7 +661,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   for (let i = f(); ; ) {
@@ -686,7 +686,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   var tint_symbol = f();
@@ -717,7 +717,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   for (; f() == 42; ) {
@@ -742,7 +742,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   loop {
@@ -778,7 +778,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   for (; ; f()) {
@@ -809,7 +809,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   switch (f()) {
     case 0: {
@@ -843,7 +843,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   var tint_symbol = f();
   if (tint_discard) {
@@ -893,7 +893,7 @@
   return f();
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   g();
@@ -929,7 +929,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let marker1 = 0;
   g();
@@ -956,7 +956,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   var a = f();
   let marker1 = 0;
@@ -978,7 +978,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   var a = f();
   if (tint_discard) {
@@ -1005,7 +1005,7 @@
   return 42;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   var a : i32;
   a = f();
@@ -1028,7 +1028,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   var a : i32;
   a = f();
@@ -1056,7 +1056,7 @@
   return 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   var b = array<i32, 10>();
   b[f()] = 10;
@@ -1079,7 +1079,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   var b = array<i32, 10>();
   let tint_symbol = f();
@@ -1115,7 +1115,7 @@
   return 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   var b = array<i32, 10>();
   b[f()] = g();
@@ -1146,7 +1146,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   var b = array<i32, 10>();
   let tint_symbol = g();
@@ -1194,7 +1194,7 @@
   return 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   if ((f() + g() + h()) == 0) {
     let marker1 = 0;
@@ -1233,7 +1233,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let tint_symbol = f();
   if (tint_discard) {
@@ -1286,7 +1286,7 @@
   return 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   if (f() == 1 && g() == 2 && h() == 3) {
     let marker1 = 0;
@@ -1325,7 +1325,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   let tint_symbol_2 = f();
   if (tint_discard) {
@@ -1373,7 +1373,7 @@
   let marker1 = 0;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
   f();
   let marker1 = 0;
@@ -1397,7 +1397,7 @@
   discard;
 }
 
-@stage(fragment)
+@fragment
 fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
   f();
   if (tint_discard_1) {
diff --git a/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc b/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc
index 2b1c098..1ee9d33 100644
--- a/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc
+++ b/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc
@@ -53,7 +53,7 @@
     }
 
     std::string src = R"(
-@stage(fragment)
+@fragment
 fn main() {
   let m = ${matrix}(42.0);
 }
@@ -64,7 +64,7 @@
   return ${matrix}(${values});
 }
 
-@stage(fragment)
+@fragment
 fn main() {
   let m = build_${matrix_no_type}(42.0);
 }
@@ -107,7 +107,7 @@
     }
 
     std::string tmpl = R"(
-@stage(fragment)
+@fragment
 fn main() {
   let m = ${matrix}(${values});
 }
@@ -137,7 +137,7 @@
     }
 
     std::string tmpl = R"(
-@stage(fragment)
+@fragment
 fn main() {
   let m = ${matrix}(${columns});
 }
diff --git a/src/tint/transform/vertex_pulling_test.cc b/src/tint/transform/vertex_pulling_test.cc
index ef37631..5fb8b1c 100644
--- a/src/tint/transform/vertex_pulling_test.cc
+++ b/src/tint/transform/vertex_pulling_test.cc
@@ -37,7 +37,7 @@
 
 TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -57,7 +57,7 @@
 
 TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
     auto* src = R"(
-@stage(fragment)
+@fragment
 fn main() {}
 )";
 
@@ -75,7 +75,7 @@
 
 TEST_F(VertexPullingTest, Error_BadStride) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
   return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
@@ -98,7 +98,7 @@
 
 TEST_F(VertexPullingTest, BasicModule) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -109,7 +109,7 @@
   tint_vertex_data : array<u32>,
 }
 
-@stage(vertex)
+@vertex
 fn main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
@@ -127,7 +127,7 @@
 
 TEST_F(VertexPullingTest, OneAttribute) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
   return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
@@ -140,7 +140,7 @@
 
 @binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var var_a : f32;
   {
@@ -164,7 +164,7 @@
 
 TEST_F(VertexPullingTest, OneInstancedAttribute) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
   return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
@@ -177,7 +177,7 @@
 
 @binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(instance_index) tint_pulling_instance_index : u32) -> @builtin(position) vec4<f32> {
   var var_a : f32;
   {
@@ -201,7 +201,7 @@
 
 TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : f32) -> @builtin(position) vec4<f32> {
   return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
@@ -214,7 +214,7 @@
 
 @binding(0) @group(5) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var var_a : f32;
   {
@@ -243,7 +243,7 @@
   @location(0) var_a : f32,
 };
 
-@stage(vertex)
+@vertex
 fn main(inputs : Inputs) -> @builtin(position) vec4<f32> {
   return vec4<f32>(inputs.var_a, 0.0, 0.0, 1.0);
 }
@@ -261,7 +261,7 @@
   var_a : f32,
 }
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var inputs : Inputs;
   {
@@ -286,7 +286,7 @@
 // We expect the transform to use an existing builtin variables if it finds them
 TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : f32,
         @location(1) var_b : f32,
         @builtin(vertex_index) custom_vertex_index : u32,
@@ -305,7 +305,7 @@
 
 @binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) custom_vertex_index : u32, @builtin(instance_index) custom_instance_index : u32) -> @builtin(position) vec4<f32> {
   var var_a : f32;
   var var_b : f32;
@@ -350,7 +350,7 @@
   @builtin(instance_index) custom_instance_index : u32,
 };
 
-@stage(vertex)
+@vertex
 fn main(inputs : Inputs) -> @builtin(position) vec4<f32> {
   return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
 }
@@ -383,7 +383,7 @@
   custom_instance_index : u32,
 }
 
-@stage(vertex)
+@vertex
 fn main(tint_symbol_1 : tint_symbol) -> @builtin(position) vec4<f32> {
   var inputs : Inputs;
   inputs.custom_vertex_index = tint_symbol_1.custom_vertex_index;
@@ -422,7 +422,7 @@
 
 TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(inputs : Inputs) -> @builtin(position) vec4<f32> {
   return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
 }
@@ -451,7 +451,7 @@
   custom_instance_index : u32,
 }
 
-@stage(vertex)
+@vertex
 fn main(tint_symbol_1 : tint_symbol) -> @builtin(position) vec4<f32> {
   var inputs : Inputs;
   inputs.custom_vertex_index = tint_symbol_1.custom_vertex_index;
@@ -511,7 +511,7 @@
   @builtin(instance_index) custom_instance_index : u32,
 };
 
-@stage(vertex)
+@vertex
 fn main(inputs : Inputs, indices : Indices) -> @builtin(position) vec4<f32> {
   return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
 }
@@ -540,7 +540,7 @@
   custom_instance_index : u32,
 }
 
-@stage(vertex)
+@vertex
 fn main(indices : Indices) -> @builtin(position) vec4<f32> {
   var inputs : Inputs;
   {
@@ -577,7 +577,7 @@
 
 TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct_OutOfOrder) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(inputs : Inputs, indices : Indices) -> @builtin(position) vec4<f32> {
   return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
 }
@@ -602,7 +602,7 @@
 
 @binding(1) @group(4) var<storage, read> tint_pulling_vertex_buffer_1 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(indices : Indices) -> @builtin(position) vec4<f32> {
   var inputs : Inputs;
   {
@@ -653,7 +653,7 @@
 
 TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : f32,
         @location(1) var_b : vec4<f32>) -> @builtin(position) vec4<f32> {
   return vec4<f32>();
@@ -667,7 +667,7 @@
 
 @binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var var_a : f32;
   var var_b : vec4<f32>;
@@ -695,7 +695,7 @@
 
 TEST_F(VertexPullingTest, FloatVectorAttributes) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : vec2<f32>,
         @location(1) var_b : vec3<f32>,
         @location(2) var_c : vec4<f32>
@@ -715,7 +715,7 @@
 
 @binding(2) @group(4) var<storage, read> tint_pulling_vertex_buffer_2 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var var_a : vec2<f32>;
   var var_b : vec3<f32>;
@@ -749,7 +749,7 @@
 
 TEST_F(VertexPullingTest, AttemptSymbolCollision) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(@location(0) var_a : f32,
         @location(1) var_b : vec4<f32>) -> @builtin(position) vec4<f32> {
   var tint_pulling_vertex_index : i32;
@@ -767,7 +767,7 @@
 
 @binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0_1 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index_1 : u32) -> @builtin(position) vec4<f32> {
   var var_a : f32;
   var var_b : vec4<f32>;
@@ -799,7 +799,7 @@
 
 TEST_F(VertexPullingTest, FormatsAligned) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(
     @location(0) uint8x2 : vec2<u32>,
     @location(1) uint8x4 : vec4<u32>,
@@ -843,7 +843,7 @@
 
 @binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var uint8x2 : vec2<u32>;
   var uint8x4 : vec4<u32>;
@@ -944,7 +944,7 @@
 
 TEST_F(VertexPullingTest, FormatsStrideUnaligned) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(
     @location(0) uint8x2 : vec2<u32>,
     @location(1) uint8x4 : vec4<u32>,
@@ -989,7 +989,7 @@
 
 @binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var uint8x2 : vec2<u32>;
   var uint8x4 : vec4<u32>;
@@ -1090,7 +1090,7 @@
 
 TEST_F(VertexPullingTest, FormatsWithVectorsResized) {
     auto* src = R"(
-@stage(vertex)
+@vertex
 fn main(
     @location(0) uint8x2 : vec3<u32>,
     @location(1) uint8x4 : vec2<u32>,
@@ -1134,7 +1134,7 @@
 
 @binding(0) @group(4) var<storage, read> tint_pulling_vertex_buffer_0 : TintVertexData;
 
-@stage(vertex)
+@vertex
 fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(position) vec4<f32> {
   var uint8x2 : vec3<u32>;
   var uint8x4 : vec2<u32>;
diff --git a/src/tint/transform/zero_init_workgroup_memory_test.cc b/src/tint/transform/zero_init_workgroup_memory_test.cc
index c846d55..93f3933 100644
--- a/src/tint/transform/zero_init_workgroup_memory_test.cc
+++ b/src/tint/transform/zero_init_workgroup_memory_test.cc
@@ -81,7 +81,7 @@
   b = c;
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 )";
@@ -94,7 +94,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
 }
 
@@ -119,7 +119,7 @@
     auto* src = R"(
 var<workgroup> v : i32;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   _ = v; // Initialization should be inserted above this statement
 }
@@ -127,7 +127,7 @@
     auto* expect = R"(
 var<workgroup> v : i32;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   {
     v = i32();
@@ -144,7 +144,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   _ = v; // Initialization should be inserted above this statement
 }
@@ -152,7 +152,7 @@
 var<workgroup> v : i32;
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   {
     v = i32();
@@ -177,7 +177,7 @@
   @builtin(local_invocation_index) local_idx : u32,
 };
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(params : Params) {
   _ = v; // Initialization should be inserted above this statement
 }
@@ -190,7 +190,7 @@
   local_idx : u32,
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(params : Params) {
   {
     v = i32();
@@ -207,7 +207,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndexInStruct_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(params : Params) {
   _ = v; // Initialization should be inserted above this statement
 }
@@ -219,7 +219,7 @@
 var<workgroup> v : i32;
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(params : Params) {
   {
     v = i32();
@@ -245,7 +245,7 @@
     auto* src = R"(
 var<workgroup> v : i32;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   _ = v; // Initialization should be inserted above this statement
 }
@@ -253,7 +253,7 @@
     auto* expect = R"(
 var<workgroup> v : i32;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     v = i32();
@@ -270,7 +270,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   _ = v; // Initialization should be inserted above this statement
 }
@@ -278,7 +278,7 @@
 var<workgroup> v : i32;
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     v = i32();
@@ -308,7 +308,7 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   _ = a; // Initialization should be inserted above this statement
   _ = b;
@@ -327,7 +327,7 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   {
     a = i32();
@@ -360,7 +360,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex_Size1_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   _ = a; // Initialization should be inserted above this statement
   _ = b;
@@ -379,7 +379,7 @@
 };
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   {
     a = i32();
@@ -434,7 +434,7 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(2, 3)
+@compute @workgroup_size(2, 3)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   _ = a; // Initialization should be inserted above this statement
   _ = b;
@@ -453,7 +453,7 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(2, 3)
+@compute @workgroup_size(2, 3)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   if ((local_idx < 1u)) {
     a = i32();
@@ -499,7 +499,7 @@
 
 @id(1) override X : i32;
 
-@stage(compute) @workgroup_size(2, 3, X)
+@compute @workgroup_size(2, 3, X)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   _ = a; // Initialization should be inserted above this statement
   _ = b;
@@ -521,7 +521,7 @@
 
 @id(1) override X : i32;
 
-@stage(compute) @workgroup_size(2, 3, X)
+@compute @workgroup_size(2, 3, X)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (u32(X) * 6u))) {
     a = i32();
@@ -568,7 +568,7 @@
 
 @id(1) override X : u32;
 
-@stage(compute) @workgroup_size(5u, X, 10u)
+@compute @workgroup_size(5u, X, 10u)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   _ = a; // Initialization should be inserted above this statement
   _ = b;
@@ -591,7 +591,7 @@
 
 @id(1) override X : u32;
 
-@stage(compute) @workgroup_size(5u, X, 10u)
+@compute @workgroup_size(5u, X, 10u)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (X * 50u))) {
     a = i32();
@@ -654,7 +654,7 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
   _ = a; // Initialization should be inserted above this statement
   _ = b;
@@ -673,7 +673,7 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index : u32) {
   {
     a = i32();
@@ -706,7 +706,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
   _ = a; // Initialization should be inserted above this statement
   _ = b;
@@ -725,7 +725,7 @@
 };
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index : u32) {
   {
     a = i32();
@@ -780,18 +780,18 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f1() {
   _ = a; // Initialization should be inserted above this statement
   _ = c;
 }
 
-@stage(compute) @workgroup_size(1, 2, 3)
+@compute @workgroup_size(1, 2, 3)
 fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
   _ = b; // Initialization should be inserted above this statement
 }
 
-@stage(compute) @workgroup_size(4, 5, 6)
+@compute @workgroup_size(4, 5, 6)
 fn f3() {
   _ = c; // Initialization should be inserted above this statement
   _ = a;
@@ -809,7 +809,7 @@
 
 var<workgroup> c : array<S, 32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f1(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     a = i32();
@@ -828,7 +828,7 @@
   _ = c;
 }
 
-@stage(compute) @workgroup_size(1, 2, 3)
+@compute @workgroup_size(1, 2, 3)
 fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index_1 : u32) {
   if ((local_invocation_index_1 < 1u)) {
     b.x = i32();
@@ -841,7 +841,7 @@
   _ = b;
 }
 
-@stage(compute) @workgroup_size(4, 5, 6)
+@compute @workgroup_size(4, 5, 6)
 fn f3(@builtin(local_invocation_index) local_invocation_index_2 : u32) {
   if ((local_invocation_index_2 < 1u)) {
     a = i32();
@@ -868,18 +868,18 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f1() {
   _ = a; // Initialization should be inserted above this statement
   _ = c;
 }
 
-@stage(compute) @workgroup_size(1, 2, 3)
+@compute @workgroup_size(1, 2, 3)
 fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>) {
   _ = b; // Initialization should be inserted above this statement
 }
 
-@stage(compute) @workgroup_size(4, 5, 6)
+@compute @workgroup_size(4, 5, 6)
 fn f3() {
   _ = c; // Initialization should be inserted above this statement
   _ = a;
@@ -897,7 +897,7 @@
 };
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f1(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     a = i32();
@@ -916,7 +916,7 @@
   _ = c;
 }
 
-@stage(compute) @workgroup_size(1, 2, 3)
+@compute @workgroup_size(1, 2, 3)
 fn f2(@builtin(local_invocation_id) local_invocation_id : vec3<u32>, @builtin(local_invocation_index) local_invocation_index_1 : u32) {
   if ((local_invocation_index_1 < 1u)) {
     b.x = i32();
@@ -929,7 +929,7 @@
   _ = b;
 }
 
-@stage(compute) @workgroup_size(4, 5, 6)
+@compute @workgroup_size(4, 5, 6)
 fn f3(@builtin(local_invocation_index) local_invocation_index_2 : u32) {
   if ((local_invocation_index_2 < 1u)) {
     a = i32();
@@ -977,7 +977,7 @@
   use_v();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   call_use_v(); // Initialization should be inserted above this statement
 }
@@ -993,7 +993,7 @@
   use_v();
 }
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   {
     v = i32();
@@ -1010,7 +1010,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   call_use_v(); // Initialization should be inserted above this statement
 }
@@ -1026,7 +1026,7 @@
 var<workgroup> v : i32;
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_idx : u32) {
   {
     v = i32();
@@ -1056,7 +1056,7 @@
 var<workgroup> i : atomic<i32>;
 var<workgroup> u : atomic<u32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   atomicLoad(&(i)); // Initialization should be inserted above this statement
   atomicLoad(&(u));
@@ -1067,7 +1067,7 @@
 
 var<workgroup> u : atomic<u32>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     atomicStore(&(i), i32());
@@ -1086,7 +1086,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   atomicLoad(&(i)); // Initialization should be inserted above this statement
   atomicLoad(&(u));
@@ -1096,7 +1096,7 @@
 var<workgroup> u : atomic<u32>;
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     atomicStore(&(i), i32());
@@ -1129,7 +1129,7 @@
 
 var<workgroup> w : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   _ = w.a; // Initialization should be inserted above this statement
 }
@@ -1145,7 +1145,7 @@
 
 var<workgroup> w : S;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     w.a = i32();
@@ -1166,7 +1166,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   _ = w.a; // Initialization should be inserted above this statement
 }
@@ -1182,7 +1182,7 @@
 };
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   {
     w.a = i32();
@@ -1215,7 +1215,7 @@
     auto* src = R"(
 var<workgroup> w : array<atomic<u32>, 4>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   atomicLoad(&w[0]); // Initialization should be inserted above this statement
 }
@@ -1223,7 +1223,7 @@
     auto* expect = R"(
 var<workgroup> w : array<atomic<u32>, 4>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
     let i : u32 = idx;
@@ -1241,7 +1241,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   atomicLoad(&w[0]); // Initialization should be inserted above this statement
 }
@@ -1249,7 +1249,7 @@
 var<workgroup> w : array<atomic<u32>, 4>;
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
     let i : u32 = idx;
@@ -1279,7 +1279,7 @@
 
 var<workgroup> w : array<S, 4>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   _ = w[0].a; // Initialization should be inserted above this statement
 }
@@ -1295,7 +1295,7 @@
 
 var<workgroup> w : array<S, 4>;
 
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
     let i_1 : u32 = idx;
@@ -1317,7 +1317,7 @@
 
 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics_OutOfOrder) {
     auto* src = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f() {
   _ = w[0].a; // Initialization should be inserted above this statement
 }
@@ -1333,7 +1333,7 @@
 };
 )";
     auto* expect = R"(
-@stage(compute) @workgroup_size(1)
+@compute @workgroup_size(1)
 fn f(@builtin(local_invocation_index) local_invocation_index : u32) {
   for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
     let i_1 : u32 = idx;
diff --git a/src/tint/writer/glsl/generator_impl_function_test.cc b/src/tint/writer/glsl/generator_impl_function_test.cc
index 201e757..c5bccf2 100644
--- a/src/tint/writer/glsl/generator_impl_function_test.cc
+++ b/src/tint/writer/glsl/generator_impl_function_test.cc
@@ -877,13 +877,13 @@
     // };
     // @binding(0) @group(0) var<storage> data : Data;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn a() {
     //   var v = data.d;
     //   return;
     // }
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn b() {
     //   var v = data.d;
     //   return;
diff --git a/src/tint/writer/hlsl/generator_impl_function_test.cc b/src/tint/writer/hlsl/generator_impl_function_test.cc
index cbd95ec..c994b35 100644
--- a/src/tint/writer/hlsl/generator_impl_function_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_function_test.cc
@@ -816,13 +816,13 @@
     // };
     // @binding(0) @group(0) var<storage> data : Data;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn a() {
     //   var v = data.d;
     //   return;
     // }
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn b() {
     //   var v = data.d;
     //   return;
diff --git a/src/tint/writer/msl/generator_impl_function_test.cc b/src/tint/writer/msl/generator_impl_function_test.cc
index 585e8fc..3c78522 100644
--- a/src/tint/writer/msl/generator_impl_function_test.cc
+++ b/src/tint/writer/msl/generator_impl_function_test.cc
@@ -615,12 +615,12 @@
     // };
     // @binding(0) @group(0) var<storage> data : Data;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn a() {
     //   return;
     // }
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn b() {
     //   return;
     // }
diff --git a/src/tint/writer/spirv/builder_entry_point_test.cc b/src/tint/writer/spirv/builder_entry_point_test.cc
index 50398a3..a407aa7 100644
--- a/src/tint/writer/spirv/builder_entry_point_test.cc
+++ b/src/tint/writer/spirv/builder_entry_point_test.cc
@@ -37,7 +37,7 @@
 using BuilderTest = TestHelper;
 
 TEST_F(BuilderTest, EntryPoint_Parameters) {
-    // @stage(fragment)
+    // @fragment
     // fn frag_main(@builtin(position) coord : vec4<f32>,
     //              @location(1) loc1 : f32) {
     //   var col : f32 = (coord.x * loc1);
@@ -105,7 +105,7 @@
 }
 
 TEST_F(BuilderTest, EntryPoint_ReturnValue) {
-    // @stage(fragment)
+    // @fragment
     // fn frag_main(@location(0) @interpolate(flat) loc_in : u32)
     //     -> @location(0) f32 {
     //   if (loc_in > 10) {
@@ -187,12 +187,12 @@
     //   @builtin(position) pos : vec4<f32>;
     // };
     //
-    // @stage(vertex)
+    // @vertex
     // fn vert_main() -> Interface {
     //   return Interface(42.0, vec4<f32>());
     // }
     //
-    // @stage(fragment)
+    // @fragment
     // fn frag_main(inputs : Interface) -> @builtin(frag_depth) f32 {
     //   return inputs.value;
     // }
diff --git a/src/tint/writer/spirv/builder_function_test.cc b/src/tint/writer/spirv/builder_function_test.cc
index fdb6a34..2fd1c32 100644
--- a/src/tint/writer/spirv/builder_function_test.cc
+++ b/src/tint/writer/spirv/builder_function_test.cc
@@ -186,12 +186,12 @@
     // };
     // @binding(0) @group(0) var<storage> data : Data;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn a() {
     //   return;
     // }
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn b() {
     //   return;
     // }
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index c71cd49..5e0ce8c 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -710,7 +710,7 @@
                 return true;
             },
             [&](const ast::StageAttribute* stage) {
-                out << "stage(" << stage->stage << ")";
+                out << stage->stage;
                 return true;
             },
             [&](const ast::BindingAttribute* binding) {
diff --git a/src/tint/writer/wgsl/generator_impl_function_test.cc b/src/tint/writer/wgsl/generator_impl_function_test.cc
index 4c43a59..e34292b 100644
--- a/src/tint/writer/wgsl/generator_impl_function_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_function_test.cc
@@ -73,7 +73,7 @@
     gen.increment_indent();
 
     ASSERT_TRUE(gen.EmitFunction(func));
-    EXPECT_EQ(gen.result(), R"(  @stage(compute) @workgroup_size(2i, 4i, 6i)
+    EXPECT_EQ(gen.result(), R"(  @compute @workgroup_size(2i, 4i, 6i)
   fn my_func() {
     return;
   }
@@ -93,7 +93,7 @@
     gen.increment_indent();
 
     ASSERT_TRUE(gen.EmitFunction(func));
-    EXPECT_EQ(gen.result(), R"(  @stage(compute) @workgroup_size(2i, height)
+    EXPECT_EQ(gen.result(), R"(  @compute @workgroup_size(2i, height)
   fn my_func() {
     return;
   }
@@ -114,7 +114,7 @@
     gen.increment_indent();
 
     ASSERT_TRUE(gen.EmitFunction(func));
-    EXPECT_EQ(gen.result(), R"(  @stage(fragment)
+    EXPECT_EQ(gen.result(), R"(  @fragment
   fn frag_main(@builtin(position) coord : vec4<f32>, @location(1) loc1 : f32) {
   }
 )");
@@ -137,7 +137,7 @@
     gen.increment_indent();
 
     ASSERT_TRUE(gen.EmitFunction(func));
-    EXPECT_EQ(gen.result(), R"(  @stage(fragment)
+    EXPECT_EQ(gen.result(), R"(  @fragment
   fn frag_main() -> @location(1) f32 {
     return 1.0f;
   }
@@ -151,12 +151,12 @@
     // };
     // @binding(0) @group(0) var<storage> data : Data;
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn a() {
     //   return;
     // }
     //
-    // @stage(compute) @workgroup_size(1)
+    // @compute @workgroup_size(1)
     // fn b() {
     //   return;
     // }
@@ -206,13 +206,13 @@
 
 @binding(0) @group(0) var<storage, read_write> data : Data;
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn a() {
   var v : f32 = data.d;
   return;
 }
 
-@stage(compute) @workgroup_size(1i)
+@compute @workgroup_size(1i)
 fn b() {
   var v : f32 = data.d;
   return;
diff --git a/src/tint/writer/wgsl/generator_impl_global_decl_test.cc b/src/tint/writer/wgsl/generator_impl_global_decl_test.cc
index 8e1c9a6..48f1276 100644
--- a/src/tint/writer/wgsl/generator_impl_global_decl_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_global_decl_test.cc
@@ -35,7 +35,7 @@
     gen.increment_indent();
 
     ASSERT_TRUE(gen.Generate()) << gen.error();
-    EXPECT_EQ(gen.result(), R"(  @stage(compute) @workgroup_size(1i, 1i, 1i)
+    EXPECT_EQ(gen.result(), R"(  @compute @workgroup_size(1i, 1i, 1i)
   fn test_function() {
     var a : f32;
   }
@@ -91,7 +91,7 @@
     a : i32,
   }
 
-  @stage(compute) @workgroup_size(1i)
+  @compute @workgroup_size(1i)
   fn main() {
     var s0 : S0;
     var s1 : S1;