WGSL: Replace last uses of var<in> and var<out>

Bug: dawn:755

Change-Id: Idaca6965fd2b5d0f2e0028d8edfff6c507050a45
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/48240
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Brandon Jones <bajones@chromium.org>
Commit-Queue: Brandon Jones <bajones@chromium.org>
diff --git a/src/dawn_native/CopyTextureForBrowserHelper.cpp b/src/dawn_native/CopyTextureForBrowserHelper.cpp
index 386127f..ac730dc 100644
--- a/src/dawn_native/CopyTextureForBrowserHelper.cpp
+++ b/src/dawn_native/CopyTextureForBrowserHelper.cpp
@@ -38,16 +38,23 @@
                 u_scale : vec2<f32>;
                 u_offset : vec2<f32>;
             };
+            [[binding(0), group(0)]] var<uniform> uniforms : Uniforms;
+
             const texcoord : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
                 vec2<f32>(-0.5, 0.0),
                 vec2<f32>( 1.5, 0.0),
                 vec2<f32>( 0.5, 2.0));
-            [[location(0)]] var<out> v_texcoord: vec2<f32>;
-            [[builtin(position)]] var<out> Position : vec4<f32>;
-            [[builtin(vertex_index)]] var<in> VertexIndex : u32;
-            [[binding(0), group(0)]] var<uniform> uniforms : Uniforms;
-            [[stage(vertex)]] fn main() {
-                Position = vec4<f32>((texcoord[VertexIndex] * 2.0 - vec2<f32>(1.0, 1.0)), 0.0, 1.0);
+
+            struct VertexOutputs {
+                [[location(0)]] texcoords : vec2<f32>;
+                [[builtin(position)]] position : vec4<f32>;
+            };
+
+            [[stage(vertex)]] fn main(
+                [[builtin(vertex_index)]] VertexIndex : u32
+            ) -> VertexOutputs {
+                var output : VertexOutputs;
+                output.position = vec4<f32>((texcoord[VertexIndex] * 2.0 - vec2<f32>(1.0, 1.0)), 0.0, 1.0);
 
                 // Y component of scale is calculated by the copySizeHeight / textureHeight. Only
                 // flipY case can get negative number.
@@ -59,33 +66,38 @@
                     // We need to get the mirror positions(mirrored based on y = 0.5) on flip cases.
                     // Adopt transform to src texture and then mapping it to triangle coord which
                     // do a +1 shift on Y dimension will help us got that mirror position perfectly.
-                    v_texcoord = (texcoord[VertexIndex] * uniforms.u_scale + uniforms.u_offset) *
-                                  vec2<f32>(1.0, -1.0) + vec2<f32>(0.0, 1.0);
+                    output.texcoords = (texcoord[VertexIndex] * uniforms.u_scale + uniforms.u_offset) *
+                        vec2<f32>(1.0, -1.0) + vec2<f32>(0.0, 1.0);
                 } else {
                     // For the normal case, we need to get the exact position.
                     // So mapping texture to triangle firstly then adopt the transform.
-                    v_texcoord = (texcoord[VertexIndex] *
-                                  vec2<f32>(1.0, -1.0) + vec2<f32>(0.0, 1.0)) *
-                                  uniforms.u_scale + uniforms.u_offset;
+                    output.texcoords = (texcoord[VertexIndex] *
+                        vec2<f32>(1.0, -1.0) + vec2<f32>(0.0, 1.0)) *
+                        uniforms.u_scale + uniforms.u_offset;
                 }
+
+                return output;
             }
         )";
 
         static const char sCopyTextureForBrowserFragment[] = R"(
             [[binding(1), group(0)]] var mySampler: sampler;
             [[binding(2), group(0)]] var myTexture: texture_2d<f32>;
-            [[location(0)]] var<in> v_texcoord : vec2<f32>;
-            [[location(0)]] var<out> outputColor : vec4<f32>;
-            [[stage(fragment)]] fn main() {
+
+            [[stage(fragment)]] fn main(
+                [[location(0)]] texcoord : vec2<f32>
+            ) -> [[location(0)]] vec4<f32> {
                 // Clamp the texcoord and discard the out-of-bound pixels.
                 var clampedTexcoord : vec2<f32> =
-                    clamp(v_texcoord, vec2<f32>(0.0, 0.0), vec2<f32>(1.0, 1.0));
-                if (all(clampedTexcoord == v_texcoord)) {
-                    var srcColor : vec4<f32> = textureSample(myTexture, mySampler, v_texcoord);
-                    // Swizzling of texture formats when sampling / rendering is handled by the
-                    // hardware so we don't need special logic in this shader. This is covered by tests.
-                    outputColor = srcColor;
+                    clamp(texcoord, vec2<f32>(0.0, 0.0), vec2<f32>(1.0, 1.0));
+                if (!all(clampedTexcoord == texcoord)) {
+                    discard;
                 }
+
+                var srcColor : vec4<f32> = textureSample(myTexture, mySampler, texcoord);
+                // Swizzling of texture formats when sampling / rendering is handled by the
+                // hardware so we don't need special logic in this shader. This is covered by tests.
+                return srcColor;
             }
         )";
 
diff --git a/src/dawn_native/QueryHelper.cpp b/src/dawn_native/QueryHelper.cpp
index 177dcc0..efc851b 100644
--- a/src/dawn_native/QueryHelper.cpp
+++ b/src/dawn_native/QueryHelper.cpp
@@ -58,12 +58,11 @@
                 var<storage> availability : [[access(read)]] AvailabilityArr;
             [[group(0), binding(2)]] var<uniform> params : TimestampParams;
 
-            [[builtin(global_invocation_id)]] var<in> GlobalInvocationID : vec3<u32>;
 
             const sizeofTimestamp : u32 = 8u;
 
             [[stage(compute), workgroup_size(8, 1, 1)]]
-            fn main() {
+            fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
                 if (GlobalInvocationID.x >= params.count) { return; }
 
                 var index : u32 = GlobalInvocationID.x + params.offset / sizeofTimestamp;
diff --git a/src/tests/end2end/DeprecatedAPITests.cpp b/src/tests/end2end/DeprecatedAPITests.cpp
index b78fb79..341c1f3 100644
--- a/src/tests/end2end/DeprecatedAPITests.cpp
+++ b/src/tests/end2end/DeprecatedAPITests.cpp
@@ -538,16 +538,15 @@
   protected:
     // Runs the test
     void DoTest(const wgpu::VertexFormat vertexFormat, bool deprecated) {
-        std::string attribute = "[[location(0)]] var<in> a : ";
+        std::string attribute = "[[location(0)]] a : ";
         attribute += dawn::GetWGSLVertexFormatType(vertexFormat);
-        attribute += ";";
 
         std::string attribAccess = dawn::VertexFormatNumComponents(vertexFormat) > 1
                                        ? "vec4<f32>(f32(a.x), 0.0, 0.0, 1.0)"
                                        : "vec4<f32>(f32(a), 0.0, 0.0, 1.0)";
 
-        wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, (attribute + R"(
-                [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+        wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, (R"(
+                [[stage(vertex)]] fn main()" + attribute + R"() -> [[builtin(position)]] vec4<f32> {
                     return )" + attribAccess + R"(;
                 }
             )")
diff --git a/src/tests/end2end/DepthStencilSamplingTests.cpp b/src/tests/end2end/DepthStencilSamplingTests.cpp
index c03349b..761fdd8 100644
--- a/src/tests/end2end/DepthStencilSamplingTests.cpp
+++ b/src/tests/end2end/DepthStencilSamplingTests.cpp
@@ -73,6 +73,7 @@
         utils::ComboRenderPipelineDescriptor2 pipelineDescriptor;
 
         std::ostringstream shaderSource;
+        std::ostringstream shaderOutputStruct;
         std::ostringstream shaderBody;
 
         uint32_t index = 0;
@@ -82,10 +83,10 @@
                     shaderSource << "[[group(0), binding(" << index << ")]] var tex" << index
                                  << " : texture_2d<f32>;\n";
 
-                    shaderSource << "[[location(" << index << ")]] var<out> result" << index
-                                 << " : f32;\n";
+                    shaderOutputStruct << "  [[location(" << index << ")]] result" << index
+                                       << " : f32;\n";
 
-                    shaderBody << "\nresult" << index << " = textureLoad(tex" << index
+                    shaderBody << "\n  output.result" << index << " = textureLoad(tex" << index
                                << ", vec2<i32>(0, 0), 0)[" << componentIndex << "];\n";
                     pipelineDescriptor.cTargets[index].format = wgpu::TextureFormat::R32Float;
                     break;
@@ -93,10 +94,10 @@
                     shaderSource << "[[group(0), binding(" << index << ")]] var tex" << index
                                  << " : texture_2d<u32>;\n";
 
-                    shaderSource << "[[location(" << index << ")]] var<out> result" << index
-                                 << " : u32;\n";
+                    shaderOutputStruct << "  [[location(" << index << ")]] result" << index
+                                       << " : u32;\n";
 
-                    shaderBody << "\nresult" << index << " = textureLoad(tex" << index
+                    shaderBody << "\n  output.result" << index << " = textureLoad(tex" << index
                                << ", vec2<i32>(0, 0), 0)[" << componentIndex << "];\n";
                     pipelineDescriptor.cTargets[index].format = wgpu::TextureFormat::R8Uint;
                     break;
@@ -105,7 +106,10 @@
             index++;
         }
 
-        shaderSource << "[[stage(fragment)]] fn main() { " << shaderBody.str() << "\n}";
+        shaderSource << "struct FragOutputs {\n" << shaderOutputStruct.str() << "};\n";
+        shaderSource << "[[stage(fragment)]] fn main() -> FragOutputs {\n";
+        shaderSource << "  var output : FragOutputs;\n"
+                     << shaderBody.str() << "  return output;\n}";
 
         wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, shaderSource.str().c_str());
         pipelineDescriptor.vertex.module = vsModule;
diff --git a/src/tests/end2end/FirstIndexOffsetTests.cpp b/src/tests/end2end/FirstIndexOffsetTests.cpp
index 46d2d6b..680cd17 100644
--- a/src/tests/end2end/FirstIndexOffsetTests.cpp
+++ b/src/tests/end2end/FirstIndexOffsetTests.cpp
@@ -86,77 +86,68 @@
                                      uint32_t firstVertex,
                                      uint32_t firstInstance) {
     using wgpu::operator&;
-    std::stringstream vertexShader;
-    std::stringstream fragmentShader;
+
+    std::stringstream vertexInputs;
+    std::stringstream vertexOutputs;
+    std::stringstream vertexBody;
+    std::stringstream fragmentInputs;
+    std::stringstream fragmentBody;
+
+    vertexInputs << "  [[location(0)]] position : vec4<f32>;\n";
+    vertexOutputs << "  [[builtin(position)]] position : vec4<f32>;\n";
 
     if ((checkIndex & CheckIndex::Vertex) != 0) {
-        vertexShader << R"(
-        [[builtin(vertex_index)]] var<in> vertex_index : u32;
-        [[location(1)]] var<out> out_vertex_index : u32;
-        )";
-        fragmentShader << R"(
-        [[location(1)]] var<in> in_vertex_index : u32;
-    )";
+        vertexInputs << "  [[builtin(vertex_index)]] vertex_index : u32;\n";
+        vertexOutputs << "  [[location(1)]] vertex_index : u32;\n";
+        vertexBody << "  output.vertex_index = input.vertex_index;\n";
+
+        fragmentInputs << "  [[location(1)]] vertex_index : u32;\n";
+        fragmentBody << "  idx_vals.vertex_index = input.vertex_index;\n";
     }
     if ((checkIndex & CheckIndex::Instance) != 0) {
-        vertexShader << R"(
-            [[builtin(instance_index)]] var<in> instance_index : u32;
-            [[location(2)]] var<out> out_instance_index : u32;
-            )";
-        fragmentShader << R"(
-            [[location(2)]] var<in> in_instance_index : u32;
-        )";
+        vertexInputs << "  [[builtin(instance_index)]] instance_index : u32;\n";
+        vertexOutputs << "  [[location(2)]] instance_index : u32;\n";
+        vertexBody << "  output.instance_index = input.instance_index;\n";
+
+        fragmentInputs << "  [[location(2)]] instance_index : u32;\n";
+        fragmentBody << "  idx_vals.instance_index = input.instance_index;\n";
     }
 
-    vertexShader << R"(
-        [[builtin(position)]] var<out> position : vec4<f32>;
-        [[location(0)]] var<in> pos : vec4<f32>;
+    std::string vertexShader = R"(
+struct VertexInputs {
+)" + vertexInputs.str() + R"(
+};
+struct VertexOutputs {
+)" + vertexOutputs.str() + R"(
+};
+[[stage(vertex)]] fn main(input : VertexInputs) -> VertexOutputs {
+  var output : VertexOutputs;
+)" + vertexBody.str() + R"(
+  output.position = input.position;
+  return output;
+})";
 
-        [[stage(vertex)]] fn main() {)";
-    fragmentShader << R"(
-         [[block]] struct IndexVals {
-             vertex_index : u32;
-             instance_index : u32;
-         };
+    std::string fragmentShader = R"(
+[[block]] struct IndexVals {
+  vertex_index : u32;
+  instance_index : u32;
+};
+[[group(0), binding(0)]] var<storage> idx_vals : [[access(read_write)]] IndexVals;
 
-        [[group(0), binding(0)]] var<storage> idx_vals : [[access(read_write)]] IndexVals;
-
-        [[stage(fragment)]] fn main() {
-        )";
-
-    if ((checkIndex & CheckIndex::Vertex) != 0) {
-        vertexShader << R"(
-            out_vertex_index = vertex_index;
-            )";
-        fragmentShader << R"(
-            idx_vals.vertex_index = in_vertex_index;
-            )";
-    }
-    if ((checkIndex & CheckIndex::Instance) != 0) {
-        vertexShader << R"(
-            out_instance_index = instance_index;
-            )";
-        fragmentShader << R"(
-            idx_vals.instance_index = in_instance_index;
-            )";
-    }
-
-    vertexShader << R"(
-            position = pos;
-            return;
-        })";
-
-    fragmentShader << R"(
-            return;
-        })";
+struct FragInputs {
+)" + fragmentInputs.str() + R"(
+};
+[[stage(fragment)]] fn main(input : FragInputs) {
+)" + fragmentBody.str() + R"(
+})";
 
     utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
 
     constexpr uint32_t kComponentsPerVertex = 4;
 
     utils::ComboRenderPipelineDescriptor2 pipelineDesc;
-    pipelineDesc.vertex.module = utils::CreateShaderModule(device, vertexShader.str().c_str());
-    pipelineDesc.cFragment.module = utils::CreateShaderModule(device, fragmentShader.str().c_str());
+    pipelineDesc.vertex.module = utils::CreateShaderModule(device, vertexShader.c_str());
+    pipelineDesc.cFragment.module = utils::CreateShaderModule(device, fragmentShader.c_str());
     pipelineDesc.primitive.topology = wgpu::PrimitiveTopology::PointList;
     pipelineDesc.vertex.bufferCount = 1;
     pipelineDesc.cBuffers[0].arrayStride = kComponentsPerVertex * sizeof(float);
diff --git a/src/tests/end2end/VertexBufferRobustnessTests.cpp b/src/tests/end2end/VertexBufferRobustnessTests.cpp
index 7253560..1bc99ee 100644
--- a/src/tests/end2end/VertexBufferRobustnessTests.cpp
+++ b/src/tests/end2end/VertexBufferRobustnessTests.cpp
@@ -30,10 +30,12 @@
     // Creates a vertex module that tests an expression with given attributes. If successful, the
     // point drawn would be moved out of the viewport. On failure, the point is kept inside the
     // viewport.
-    wgpu::ShaderModule CreateVertexModule(const std::string& attributes,
+    wgpu::ShaderModule CreateVertexModule(const std::string& attribute,
                                           const std::string& successExpression) {
-        return utils::CreateShaderModule(device, (attributes + R"(
-                [[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
+        return utils::CreateShaderModule(device, (R"(
+                [[stage(vertex)]] fn main(
+                    )" + attribute + R"(
+                ) -> [[builtin(position)]] vec4<f32> {
                     if ()" + successExpression + R"() {
                         // Success case, move the vertex out of the viewport
                         return vec4<f32>(-10.0, 0.0, 0.0, 1.0);
@@ -102,7 +104,7 @@
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
                                                             wgpu::BufferUsage::Vertex);
 
-    DoTest("[[location(0)]] var<in> a : f32;", "a == 473.0", vertexState, vertexBuffer, 0, false);
+    DoTest("[[location(0)]] a : f32", "a == 473.0", vertexState, vertexBuffer, 0, false);
 }
 
 TEST_P(VertexBufferRobustnessTest, FloatClamp) {
@@ -119,7 +121,7 @@
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
                                                             wgpu::BufferUsage::Vertex);
 
-    DoTest("[[location(0)]] var<in> a : f32;", "a == 473.0", vertexState, vertexBuffer, 4, true);
+    DoTest("[[location(0)]] a : f32", "a == 473.0", vertexState, vertexBuffer, 4, true);
 }
 
 TEST_P(VertexBufferRobustnessTest, IntClamp) {
@@ -136,7 +138,7 @@
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
                                                             wgpu::BufferUsage::Vertex);
 
-    DoTest("[[location(0)]] var<in> a : i32;", "a == 473", vertexState, vertexBuffer, 4, true);
+    DoTest("[[location(0)]] a : i32", "a == 473", vertexState, vertexBuffer, 4, true);
 }
 
 TEST_P(VertexBufferRobustnessTest, UIntClamp) {
@@ -153,7 +155,7 @@
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
                                                             wgpu::BufferUsage::Vertex);
 
-    DoTest("[[location(0)]] var<in> a : u32;", "a == 473u", vertexState, vertexBuffer, 4, true);
+    DoTest("[[location(0)]] a : u32", "a == 473u", vertexState, vertexBuffer, 4, true);
 }
 
 TEST_P(VertexBufferRobustnessTest, Float2Clamp) {
@@ -170,7 +172,7 @@
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
                                                             wgpu::BufferUsage::Vertex);
 
-    DoTest("[[location(0)]] var<in> a : vec2<f32>;", "a[0] == 473.0 && a[1] == 473.0",
+    DoTest("[[location(0)]] a : vec2<f32>", "a[0] == 473.0 && a[1] == 473.0",
            std::move(vertexState), vertexBuffer, 8, true);
 }
 
@@ -188,7 +190,7 @@
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
                                                             wgpu::BufferUsage::Vertex);
 
-    DoTest("[[location(0)]] var<in> a : vec3<f32>;",
+    DoTest("[[location(0)]] a : vec3<f32>",
            "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0", vertexState, vertexBuffer, 12, true);
 }
 
@@ -206,7 +208,7 @@
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
                                                             wgpu::BufferUsage::Vertex);
 
-    DoTest("[[location(0)]] var<in> a : vec4<f32>;",
+    DoTest("[[location(0)]] a : vec4<f32>",
            "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0 && a[3] == 473.0", vertexState,
            vertexBuffer, 16, true);
 }
diff --git a/src/tests/white_box/D3D12DescriptorHeapTests.cpp b/src/tests/white_box/D3D12DescriptorHeapTests.cpp
index 7310941..ffd2ccc 100644
--- a/src/tests/white_box/D3D12DescriptorHeapTests.cpp
+++ b/src/tests/white_box/D3D12DescriptorHeapTests.cpp
@@ -448,10 +448,9 @@
             heapSize : f32;
         };
         [[group(0), binding(0)]] var<uniform> buffer0 : U;
-        [[location(0)]] var<out> FragColor : f32;
 
-        [[stage(fragment)]] fn main() {
-            FragColor = buffer0.heapSize;
+        [[stage(fragment)]] fn main() -> [[location(0)]] f32 {
+            return buffer0.heapSize;
         })");
 
     wgpu::BlendState blend;