Remove usage of deprecated WGSL IO in perf/unit/whitebox tests

Also drive-by fixes some other deprecated constructs (const -> let, and
a disabled test having ancient WGSL).

Bug: dawn:755

Change-Id: I924dfbcbd0a7d0478f3e9b3766585751a0392499
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/47620
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tests/perf_tests/DrawCallPerf.cpp b/src/tests/perf_tests/DrawCallPerf.cpp
index bbff208..a0456c0 100644
--- a/src/tests/perf_tests/DrawCallPerf.cpp
+++ b/src/tests/perf_tests/DrawCallPerf.cpp
@@ -33,10 +33,10 @@
     };
 
     constexpr char kVertexShader[] = R"(
-        [[location(0)]] var<in> pos : vec4<f32>;
-        [[builtin(position)]] var<out> Position : vec4<f32>;
-        [[stage(vertex)]] fn main() {
-            Position = pos;
+        [[stage(vertex)]] fn main(
+            [[location(0)]] pos : vec4<f32>
+        ) -> [[builtin(position)]] vec4<f32> {
+            return pos;
         })";
 
     constexpr char kFragmentShaderA[] = R"(
@@ -44,9 +44,8 @@
             color : vec3<f32>;
         };
         [[group(0), binding(0)]] var<uniform> uniforms : Uniforms;
-        [[location(0)]] var<out> fragColor : vec4<f32>;
-        [[stage(fragment)]] fn main() {
-            fragColor = vec4<f32>(uniforms.color * (1.0 / 5000.0), 1.0);
+        [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+            return vec4<f32>(uniforms.color * (1.0 / 5000.0), 1.0);
         })";
 
     constexpr char kFragmentShaderB[] = R"(
@@ -59,10 +58,8 @@
         [[group(0), binding(0)]] var<uniform> constants : Constants;
         [[group(1), binding(0)]] var<uniform> uniforms : Uniforms;
 
-        [[location(0)]] var<out> fragColor : vec4<f32>;
-
-        [[stage(fragment)]] fn main() {
-            fragColor = vec4<f32>((constants.color + uniforms.color) * (1.0 / 5000.0), 1.0);
+        [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+            return vec4<f32>((constants.color + uniforms.color) * (1.0 / 5000.0), 1.0);
         })";
 
     enum class Pipeline {
diff --git a/src/tests/perf_tests/SubresourceTrackingPerf.cpp b/src/tests/perf_tests/SubresourceTrackingPerf.cpp
index 3a49a1e..2b32621 100644
--- a/src/tests/perf_tests/SubresourceTrackingPerf.cpp
+++ b/src/tests/perf_tests/SubresourceTrackingPerf.cpp
@@ -71,17 +71,15 @@
 
         utils::ComboRenderPipelineDescriptor2 pipelineDesc;
         pipelineDesc.vertex.module = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+            [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+                return vec4<f32>(1.0, 0.0, 0.0, 1.0);
             }
         )");
         pipelineDesc.cFragment.module = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> FragColor : vec4<f32>;
             [[group(0), binding(0)]] var materials : texture_2d<f32>;
-            [[stage(fragment)]] fn main() {
-                const foo : vec2<i32> = textureDimensions(materials);
-                FragColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                let foo : vec2<i32> = textureDimensions(materials);
+                return vec4<f32>(1.0, 0.0, 0.0, 1.0);
             }
         )");
         mPipeline = device.CreateRenderPipeline2(&pipelineDesc);
diff --git a/src/tests/unittests/validation/DrawIndirectValidationTests.cpp b/src/tests/unittests/validation/DrawIndirectValidationTests.cpp
index 3870f1a..edb73a1 100644
--- a/src/tests/unittests/validation/DrawIndirectValidationTests.cpp
+++ b/src/tests/unittests/validation/DrawIndirectValidationTests.cpp
@@ -24,15 +24,13 @@
         ValidationTest::SetUp();
 
         wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+            [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+                return vec4<f32>(0.0, 0.0, 0.0, 0.0);
             })");
 
         wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32>{
+                return vec4<f32>(0.0, 0.0, 0.0, 0.0);
             })");
 
         // Set up render pipeline
diff --git a/src/tests/unittests/validation/IndexBufferValidationTests.cpp b/src/tests/unittests/validation/IndexBufferValidationTests.cpp
index 0320913..fd23ce5 100644
--- a/src/tests/unittests/validation/IndexBufferValidationTests.cpp
+++ b/src/tests/unittests/validation/IndexBufferValidationTests.cpp
@@ -23,15 +23,13 @@
     wgpu::RenderPipeline MakeTestPipeline(wgpu::IndexFormat format,
         wgpu::PrimitiveTopology primitiveTopology) {
         wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+            [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+                return vec4<f32>(0.0, 0.0, 0.0, 1.0);
             })");
 
         wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = vec4<f32>(0.0, 1.0, 0.0, 1.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return vec4<f32>(0.0, 1.0, 0.0, 1.0);
             })");
 
         utils::ComboRenderPipelineDescriptor2 descriptor;
diff --git a/src/tests/unittests/validation/QueueSubmitValidationTests.cpp b/src/tests/unittests/validation/QueueSubmitValidationTests.cpp
index 184994d..8bf0f4d 100644
--- a/src/tests/unittests/validation/QueueSubmitValidationTests.cpp
+++ b/src/tests/unittests/validation/QueueSubmitValidationTests.cpp
@@ -175,15 +175,13 @@
         };
 
         wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+            [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+                return vec4<f32>(0.0, 0.0, 0.0, 1.0);
             })");
 
         wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = vec4<f32>(0.0, 1.0, 0.0, 1.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return vec4<f32>(0.0, 1.0, 0.0, 1.0);
             })");
 
         utils::ComboRenderPipelineDescriptor2 descriptor;
diff --git a/src/tests/unittests/validation/RenderBundleValidationTests.cpp b/src/tests/unittests/validation/RenderBundleValidationTests.cpp
index 51e6d8b..cad2341 100644
--- a/src/tests/unittests/validation/RenderBundleValidationTests.cpp
+++ b/src/tests/unittests/validation/RenderBundleValidationTests.cpp
@@ -28,14 +28,12 @@
             ValidationTest::SetUp();
 
             vsModule = utils::CreateShaderModule(device, R"(
-                [[location(0)]] var<in> pos : vec2<f32>;
-
                 [[block]] struct S {
                     transform : mat2x2<f32>;
                 };
                 [[group(0), binding(0)]] var<uniform> uniforms : S;
 
-                [[stage(vertex)]] fn main() {
+                [[stage(vertex)]] fn main([[location(0)]] pos : vec2<f32>) {
                 })");
 
             fsModule = utils::CreateShaderModule(device, R"(
diff --git a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
index a4cfa43..54e500c 100644
--- a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
+++ b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
@@ -27,15 +27,13 @@
         ValidationTest::SetUp();
 
         vsModule = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+            [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+                return vec4<f32>(0.0, 0.0, 0.0, 1.0);
             })");
 
         fsModule = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = vec4<f32>(0.0, 1.0, 0.0, 1.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return vec4<f32>(0.0, 1.0, 0.0, 1.0);
             })");
     }
 
@@ -191,9 +189,11 @@
 
             std::ostringstream stream;
             stream << R"(
-                [[location(0)]] var<out> fragColor : vec4<)"
+                [[stage(fragment)]] fn main() -> [[location(0)]] vec4<)"
+                   << kScalarTypes[i] << R"(> {
+                    var result : vec4<)"
                    << kScalarTypes[i] << R"(>;
-                [[stage(fragment)]] fn main() {
+                    return result;
                 })";
             descriptor.cFragment.module = utils::CreateShaderModule(device, stream.str().c_str());
 
@@ -484,8 +484,7 @@
             data : array<u32, 100>;
         };
         [[group(0), binding(0)]] var<storage> dst : [[access(read_write)]] Dst;
-        [[builtin(vertex_index)]] var<in> VertexIndex : u32;
-        [[stage(vertex)]] fn main() {
+        [[stage(vertex)]] fn main([[builtin(vertex_index)]] VertexIndex : u32) {
             dst.data[VertexIndex] = 0x1234u;
         })");
 
@@ -591,16 +590,12 @@
 // Test that the entryPoint names must be present for the correct stage in the shader module.
 TEST_F(RenderPipelineValidationTest, EntryPointNameValidation) {
     wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
-        [[builtin(position)]] var<out> position : vec4<f32>;
-        [[stage(vertex)]] fn vertex_main() {
-            position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
-            return;
+        [[stage(vertex)]] fn vertex_main() -> [[builtin(position)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 1.0);
         }
 
-        [[location(0)]] var<out> color : vec4<f32>;
-        [[stage(fragment)]] fn fragment_main() {
-            color = vec4<f32>(1.0, 0.0, 0.0, 1.0);
-            return;
+        [[stage(fragment)]] fn fragment_main() -> [[location(0)]] vec4<f32> {
+            return vec4<f32>(1.0, 0.0, 0.0, 1.0);
         }
     )");
 
@@ -641,17 +636,13 @@
 // Test that vertex attrib validation is for the correct entryPoint
 TEST_F(RenderPipelineValidationTest, VertexAttribCorrectEntryPoint) {
     wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
-        [[builtin(position)]] var<out> position : vec4<f32>;
-        [[location(0)]] var<in> attrib0 : vec4<f32>;
-        [[location(1)]] var<in> attrib1 : vec4<f32>;
-
-        [[stage(vertex)]] fn vertex0() {
-            position = attrib0;
-            return;
+        [[stage(vertex)]] fn vertex0([[location(0)]] attrib0 : vec4<f32>)
+                                    -> [[builtin(position)]] vec4<f32> {
+            return attrib0;
         }
-        [[stage(vertex)]] fn vertex1() {
-            position = attrib1;
-            return;
+        [[stage(vertex)]] fn vertex1([[location(1)]] attrib1 : vec4<f32>)
+                                    -> [[builtin(position)]] vec4<f32> {
+            return attrib1;
         }
     )");
 
@@ -687,16 +678,11 @@
 // Test that fragment output validation is for the correct entryPoint
 TEST_F(RenderPipelineValidationTest, FragmentOutputCorrectEntryPoint) {
     wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
-        [[location(0)]] var<out> colorFloat : vec4<f32>;
-        [[location(0)]] var<out> colorUint : vec4<u32>;
-
-        [[stage(fragment)]] fn fragmentFloat() {
-            colorFloat = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-            return;
+        [[stage(fragment)]] fn fragmentFloat() -> [[location(0)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 0.0);
         }
-        [[stage(fragment)]] fn fragmentUint() {
-            colorUint = vec4<u32>(0u, 0u, 0u, 0u);
-            return;
+        [[stage(fragment)]] fn fragmentUint() -> [[location(0)]] vec4<u32> {
+            return vec4<u32>(0u, 0u, 0u, 0u);
         }
     )");
 
@@ -730,21 +716,15 @@
         [[block]] struct Uniforms {
             data : vec4<f32>;
         };
-        [[binding 0, set 0]] var<uniform> var0 : Uniforms;
-        [[binding 1, set 0]] var<uniform> var1 : Uniforms;
-        [[builtin(position)]] var<out> position : vec4<f32>;
+        [[group(0), binding(0)]] var<uniform> var0 : Uniforms;
+        [[group(0), binding(1)]] var<uniform> var1 : Uniforms;
 
-        fn vertex0() {
-            position = var0.data;
-            return;
+        [[stage(vertex)]] fn vertex0() -> [[builtin(position)]] vec4<f32> {
+            return var0.data;
         }
-        fn vertex1() {
-            position = var1.data;
-            return;
+        [[stage(vertex)]] fn vertex1() -> [[builtin(position)]] vec4<f32> {
+            return var1.data;
         }
-
-        entry_point vertex = vertex0;
-        entry_point vertex = vertex1;
     )");
 
     wgpu::BindGroupLayout bgl0 = utils::MakeBindGroupLayout(
diff --git a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
index 1bc39f6..f3a7a05 100644
--- a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
+++ b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -60,9 +60,9 @@
 // be compiled.
 TEST_F(ShaderModuleValidationTest, FragmentOutputLocationExceedsMaxColorAttachments) {
     std::ostringstream stream;
-    stream << "[[location(" << kMaxColorAttachments << R"()]] var<out> fragColor : vec4<f32>;
-        [[stage(fragment)]] fn main() {
-            fragColor = vec4<f32>(0.0, 1.0, 0.0, 1.0);
+    stream << "[[stage(fragment)]] fn main() -> [[location(" << kMaxColorAttachments
+           << R"()]]  vec4<f32> {
+            return vec4<f32>(0.0, 1.0, 0.0, 1.0);
         })";
     ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, stream.str().c_str()));
 }
@@ -162,12 +162,10 @@
     // is not the case on the wire.
     DAWN_SKIP_TEST_IF(UsesWire());
 
-    std::ostringstream stream;
-    stream << R"([[location(0)]] var<out> fragColor : vec4<f32>;
-        [[stage(fragment)]] fn main() {
-            fragColor = vec4<f32>(0.0, 1.0, 0.0, 1.0);
-        })";
-    wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, stream.str().c_str());
+    wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, R"(
+        [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+            return vec4<f32>(0.0, 1.0, 0.0, 1.0);
+        })");
 
     dawn_native::ShaderModuleBase* shaderModuleBase =
         reinterpret_cast<dawn_native::ShaderModuleBase*>(shaderModule.Get());
diff --git a/src/tests/unittests/validation/StorageTextureValidationTests.cpp b/src/tests/unittests/validation/StorageTextureValidationTests.cpp
index f705049..aa96915 100644
--- a/src/tests/unittests/validation/StorageTextureValidationTests.cpp
+++ b/src/tests/unittests/validation/StorageTextureValidationTests.cpp
@@ -24,14 +24,12 @@
         ValidationTest::SetUp();
 
         mDefaultVSModule = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+            [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+                return vec4<f32>(0.0, 0.0, 0.0, 1.0);
             })");
         mDefaultFSModule = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return vec4<f32>(1.0, 0.0, 0.0, 1.0);
             })");
     }
 
@@ -122,10 +120,10 @@
     {
         wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
             [[group(0), binding(0)]] var image0 : [[access(read)]] texture_storage_2d<rgba8unorm>;
-            [[builtin(vertex_index)]] var<in> VertexIndex : u32;
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = textureLoad(image0, vec2<i32>(i32(VertexIndex), 0));
+            [[stage(vertex)]] fn main(
+                [[builtin(vertex_index)]] VertexIndex : u32
+            ) -> [[builtin(position)]] vec4<f32> {
+                return textureLoad(image0, vec2<i32>(i32(VertexIndex), 0));
             })");
 
         utils::ComboRenderPipelineDescriptor2 descriptor;
@@ -139,10 +137,10 @@
     {
         wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
             [[group(0), binding(0)]] var image0 : [[access(read)]] texture_storage_2d<rgba8unorm>;
-            [[builtin(frag_coord)]] var<in> FragCoord : vec4<f32>;
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = textureLoad(image0, vec2<i32>(FragCoord.xy));
+            [[stage(fragment)]] fn main(
+                [[builtin(frag_coord)]] FragCoord : vec4<f32>
+            ) -> [[location(0)]] vec4<f32> {
+                return textureLoad(image0, vec2<i32>(FragCoord.xy));
             })");
 
         utils::ComboRenderPipelineDescriptor2 descriptor;
@@ -155,9 +153,8 @@
     // Write-only storage textures cannot be declared in a vertex shader.
     if ((false) /* TODO(https://crbug.com/tint/449) */) {
         wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
-            [[builtin(vertex_index)]] var<in> vertex_index : u32;
             [[group(0), binding(0)]] var image0 : [[access(write)]] texture_storage_2d<rgba8unorm>;
-            [[stage(vertex)]] fn main() {
+            [[stage(vertex)]] fn main([[builtin(vertex_index)]] vertex_index : u32) {
                 textureStore(image0, vec2<i32>(i32(vertex_index), 0), vec4<f32>(1.0, 0.0, 0.0, 1.0));
             })");
 
@@ -171,9 +168,8 @@
     // Write-only storage textures can be declared in a fragment shader.
     {
         wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
-            [[builtin(frag_coord)]] var<in> frag_coord : vec4<f32>;
             [[group(0), binding(0)]] var image0 : [[access(write)]] texture_storage_2d<rgba8unorm>;
-            [[stage(fragment)]] fn main() {
+            [[stage(fragment)]] fn main([[builtin(frag_coord)]] frag_coord : vec4<f32>) {
                 textureStore(image0, vec2<i32>(frag_coord.xy), vec4<f32>(1.0, 0.0, 0.0, 1.0));
             })");
 
@@ -192,14 +188,13 @@
     {
         wgpu::ShaderModule csModule = utils::CreateShaderModule(device, R"(
             [[group(0), binding(0)]] var image0 : [[access(read)]] texture_storage_2d<rgba8unorm>;
-            [[builtin(local_invocation_id)]] var<in> LocalInvocationID : vec3<u32>;
 
             [[block]] struct Buf {
                 data : f32;
             };
             [[group(0), binding(1)]] var<storage> buf : [[access(read_write)]] Buf;
 
-            [[stage(compute)]] fn main() {
+            [[stage(compute)]] fn main([[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {
                  buf.data = textureLoad(image0, vec2<i32>(LocalInvocationID.xy)).x;
             })");
 
@@ -215,9 +210,8 @@
     {
         wgpu::ShaderModule csModule = utils::CreateShaderModule(device, R"(
             [[group(0), binding(0)]] var image0 : [[access(write)]] texture_storage_2d<rgba8unorm>;
-            [[builtin(local_invocation_id)]] var<in> LocalInvocationID : vec3<u32>;
 
-            [[stage(compute)]] fn main() {
+            [[stage(compute)]] fn main([[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {
                 textureStore(image0, vec2<i32>(LocalInvocationID.xy), vec4<f32>(0.0, 0.0, 0.0, 0.0));
             })");
 
diff --git a/src/tests/unittests/validation/VertexBufferValidationTests.cpp b/src/tests/unittests/validation/VertexBufferValidationTests.cpp
index cbab0e7..c7163b6 100644
--- a/src/tests/unittests/validation/VertexBufferValidationTests.cpp
+++ b/src/tests/unittests/validation/VertexBufferValidationTests.cpp
@@ -26,9 +26,8 @@
         ValidationTest::SetUp();
 
         fsModule = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = vec4<f32>(0.0, 1.0, 0.0, 1.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return vec4<f32>(0.0, 1.0, 0.0, 1.0);
             })");
     }
 
@@ -42,13 +41,18 @@
 
     wgpu::ShaderModule MakeVertexShader(unsigned int bufferCount) {
         std::ostringstream vs;
+        vs << "[[stage(vertex)]] fn main(\n";
         for (unsigned int i = 0; i < bufferCount; ++i) {
-            vs << "[[location(" << i << ")]] var<in> a_position" << i << " : vec3<f32>;\n";
+            // TODO(cwallez@chromium.org): remove this special handling of 0 once Tint supports
+            // trailing commas in argument lists.
+            if (i != 0) {
+                vs << ", ";
+            }
+            vs << "[[location(" << i << ")]] a_position" << i << " : vec3<f32>\n";
         }
-        vs << "[[builtin(position)]] var<out> Position : vec4<f32>;";
-        vs << "[[stage(vertex)]] fn main() {\n";
+        vs << ") -> [[builtin(position)]] vec4<f32> {";
 
-        vs << "Position = vec4<f32>(";
+        vs << "return vec4<f32>(";
         for (unsigned int i = 0; i < bufferCount; ++i) {
             vs << "a_position" << i;
             if (i != bufferCount - 1) {
diff --git a/src/tests/unittests/validation/VertexStateValidationTests.cpp b/src/tests/unittests/validation/VertexStateValidationTests.cpp
index 2a4c601..9682b64 100644
--- a/src/tests/unittests/validation/VertexStateValidationTests.cpp
+++ b/src/tests/unittests/validation/VertexStateValidationTests.cpp
@@ -24,9 +24,8 @@
                         const char* vertexSource) {
         wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, vertexSource);
         wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> fragColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
-                fragColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return vec4<f32>(1.0, 0.0, 0.0, 1.0);
             }
         )");
 
@@ -45,9 +44,8 @@
     }
 
     const char* kDummyVertexShader = R"(
-        [[builtin(position)]] var<out> Position : vec4<f32>;
-        [[stage(vertex)]] fn main() {
-            Position = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 0.0);
         }
     )";
 };
@@ -99,29 +97,29 @@
 
     // Control case: pipeline with one input per attribute
     CreatePipeline(true, state, R"(
-        [[location(0)]] var<in> a : vec4<f32>;
-        [[location(1)]] var<in> b : vec4<f32>;
-        [[builtin(position)]] var<out> Position : vec4<f32>;
-        [[stage(vertex)]] fn main() {
-            Position = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        [[stage(vertex)]] fn main(
+            [[location(0)]] a : vec4<f32>,
+            [[location(1)]] b : vec4<f32>
+        ) -> [[builtin(position)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 0.0);
         }
     )");
 
     // Check it is valid for the pipeline to use a subset of the VertexState
     CreatePipeline(true, state, R"(
-        [[location(0)]] var<in> a : vec4<f32>;
-        [[builtin(position)]] var<out> Position : vec4<f32>;
-        [[stage(vertex)]] fn main() {
-            Position = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        [[stage(vertex)]] fn main(
+            [[location(0)]] a : vec4<f32>
+        ) -> [[builtin(position)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 0.0);
         }
     )");
 
     // Check for an error when the pipeline uses an attribute not in the vertex input
     CreatePipeline(false, state, R"(
-        [[location(2)]] var<in> a : vec4<f32>;
-        [[builtin(position)]] var<out> Position : vec4<f32>;
-        [[stage(vertex)]] fn main() {
-            Position = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        [[stage(vertex)]] fn main(
+            [[location(2)]] a : vec4<f32>
+        ) -> [[builtin(position)]] vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 0.0);
         }
     )");
 }
diff --git a/src/tests/white_box/D3D12DescriptorHeapTests.cpp b/src/tests/white_box/D3D12DescriptorHeapTests.cpp
index e4f6fe6..5393e22 100644
--- a/src/tests/white_box/D3D12DescriptorHeapTests.cpp
+++ b/src/tests/white_box/D3D12DescriptorHeapTests.cpp
@@ -40,16 +40,16 @@
         mD3DDevice = reinterpret_cast<Device*>(device.Get());
 
         mSimpleVSModule = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[builtin(vertex_index)]] var<in> VertexIndex : u32;
 
-            [[stage(vertex)]] fn main() {
+            [[stage(vertex)]] fn main(
+                [[builtin(vertex_index)]] VertexIndex : u32
+            ) -> [[builtin(position)]] vec4<f32> {
                 const pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
                     vec2<f32>(-1.0,  1.0),
                     vec2<f32>( 1.0,  1.0),
                     vec2<f32>(-1.0, -1.0)
                 );
-                Position = vec4<f32>(pos[VertexIndex], 0.0, 1.0);
+                return vec4<f32>(pos[VertexIndex], 0.0, 1.0);
             })");
 
         mSimpleFSModule = utils::CreateShaderModule(device, R"(
@@ -57,10 +57,9 @@
                 color : vec4<f32>;
             };
             [[group(0), binding(0)]] var<uniform> colorBuffer : U;
-            [[location(0)]] var<out> FragColor : vec4<f32>;
 
-            [[stage(fragment)]] fn main() {
-                FragColor = colorBuffer.color;
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return colorBuffer.color;
             })");
     }
 
@@ -177,17 +176,15 @@
     // sampler bindgroup each draw. After HEAP_SIZE + 1 draws, the heaps WILL NOT switch over
     // because the sampler heap allocations are de-duplicated.
     renderPipelineDescriptor.vertex.module = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+            [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+                return vec4<f32>(0.0, 0.0, 0.0, 1.0);
             })");
 
     renderPipelineDescriptor.cFragment.module = utils::CreateShaderModule(device, R"(
-            [[location(0)]] var<out> FragColor : vec4<f32>;
             [[group(0), binding(0)]] var sampler0 : sampler;
-            [[stage(fragment)]] fn main() {
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
                 let referenceSampler : sampler = sampler0;
-                FragColor = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+                return vec4<f32>(0.0, 0.0, 0.0, 0.0);
             })");
 
     wgpu::RenderPipeline renderPipeline = device.CreateRenderPipeline2(&renderPipelineDescriptor);
@@ -787,16 +784,16 @@
                 transform : mat2x2<f32>;
             };
             [[group(0), binding(0)]] var<uniform> buffer0 : U;
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[builtin(vertex_index)]] var<in> VertexIndex : u32;
 
-            [[stage(vertex)]] fn main() {
+            [[stage(vertex)]] fn main(
+                [[builtin(vertex_index)]] VertexIndex : u32
+            ) -> [[builtin(position)]] vec4<f32> {
                 const pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
                     vec2<f32>(-1.0,  1.0),
                     vec2<f32>( 1.0,  1.0),
                     vec2<f32>(-1.0, -1.0)
                 );
-                Position = vec4<f32>(buffer0.transform * (pos[VertexIndex]), 0.0, 1.0);
+                return vec4<f32>(buffer0.transform * (pos[VertexIndex]), 0.0, 1.0);
             })");
         pipelineDescriptor.cFragment.module = utils::CreateShaderModule(device, R"(
             [[block]] struct U {
@@ -806,11 +803,10 @@
             [[group(0), binding(2)]] var texture0 : texture_2d<f32>;
             [[group(0), binding(3)]] var<uniform> buffer0 : U;
 
-            [[location(0)]] var<out> FragColor : vec4<f32>;
-            [[builtin(frag_coord)]] var<in> FragCoord : vec4<f32>;
-
-            [[stage(fragment)]] fn main() {
-                FragColor = textureSample(texture0, sampler0, FragCoord.xy) + buffer0.color;
+            [[stage(fragment)]] fn main(
+                [[builtin(frag_coord)]] FragCoord : vec4<f32>
+            ) -> [[location(0)]] vec4<f32> {
+                return textureSample(texture0, sampler0, FragCoord.xy) + buffer0.color;
             })");
 
         utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
diff --git a/src/tests/white_box/D3D12ResidencyTests.cpp b/src/tests/white_box/D3D12ResidencyTests.cpp
index 1b5ce49..804d933 100644
--- a/src/tests/white_box/D3D12ResidencyTests.cpp
+++ b/src/tests/white_box/D3D12ResidencyTests.cpp
@@ -339,16 +339,15 @@
     // Fill in a view heap with "view only" bindgroups (1x view per group) by creating a
     // view bindgroup each draw. After HEAP_SIZE + 1 draws, the heaps must switch over.
     renderPipelineDescriptor.vertex.module = utils::CreateShaderModule(device, R"(
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[builtin(vertex_index)]] var<in> VertexIndex : u32;
-
-            [[stage(vertex)]] fn main() {
+            [[stage(vertex)]] fn main(
+                [[builtin(vertex_index)]] VertexIndex : u32
+            ) -> [[builtin(position)]] vec4<f32> {
                 const pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
                     vec2<f32>(-1.0,  1.0),
                     vec2<f32>( 1.0,  1.0),
                     vec2<f32>(-1.0, -1.0)
                 );
-                Position = vec4<f32>(pos[VertexIndex], 0.0, 1.0);
+                return vec4<f32>(pos[VertexIndex], 0.0, 1.0);
             })");
 
     renderPipelineDescriptor.cFragment.module = utils::CreateShaderModule(device, R"(
@@ -356,10 +355,9 @@
                 color : vec4<f32>;
             };
             [[group(0), binding(0)]] var<uniform> colorBuffer : U;
-            [[location(0)]] var<out> FragColor : vec4<f32>;
 
-            [[stage(fragment)]] fn main() {
-                FragColor = colorBuffer.color;
+            [[stage(fragment)]] fn main() -> [[location(0)]] vec4<f32> {
+                return colorBuffer.color;
             })");
 
     wgpu::RenderPipeline renderPipeline = device.CreateRenderPipeline2(&renderPipelineDescriptor);