[ir] Strip interpolation attributes when invalid

In the ShaderIO transform base, remove interpolation attributes when
they are not placed on vertex outputs or fragment inputs, as most
backends do not support them in other places.

Bug: tint:1718
Change-Id: I4e57aeff699d52163f41fc31480a09269d0d2668
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/151584
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/transform/shader_io.cc b/src/tint/lang/core/ir/transform/shader_io.cc
index f279b86..1d072af 100644
--- a/src/tint/lang/core/ir/transform/shader_io.cc
+++ b/src/tint/lang/core/ir/transform/shader_io.cc
@@ -135,8 +135,12 @@
             if (auto* str = param->Type()->As<core::type::Struct>()) {
                 for (auto* member : str->Members()) {
                     auto name = str->Name().Name() + "_" + member->Name().Name();
-                    backend->AddInput(ir->symbols.Register(name), member->Type(),
-                                      member->Attributes());
+                    auto attributes = member->Attributes();
+                    if (attributes.interpolation &&
+                        func->Stage() != Function::PipelineStage::kFragment) {
+                        attributes.interpolation = {};
+                    }
+                    backend->AddInput(ir->symbols.Register(name), member->Type(), attributes);
                     members_to_strip.Add(member);
                 }
             } else {
@@ -144,7 +148,7 @@
                 core::type::StructMemberAttributes attributes;
                 if (auto loc = param->Location()) {
                     attributes.location = loc->value;
-                    if (loc->interpolation) {
+                    if (loc->interpolation && func->Stage() == Function::PipelineStage::kFragment) {
                         attributes.interpolation = *loc->interpolation;
                     }
                     param->ClearLocation();
@@ -170,8 +174,11 @@
         if (auto* str = func->ReturnType()->As<core::type::Struct>()) {
             for (auto* member : str->Members()) {
                 auto name = str->Name().Name() + "_" + member->Name().Name();
-                backend->AddOutput(ir->symbols.Register(name), member->Type(),
-                                   member->Attributes());
+                auto attributes = member->Attributes();
+                if (attributes.interpolation && func->Stage() != Function::PipelineStage::kVertex) {
+                    attributes.interpolation = {};
+                }
+                backend->AddOutput(ir->symbols.Register(name), member->Type(), attributes);
                 members_to_strip.Add(member);
             }
         } else {
@@ -179,6 +186,9 @@
             core::type::StructMemberAttributes attributes;
             if (auto loc = func->ReturnLocation()) {
                 attributes.location = loc->value;
+                if (loc->interpolation && func->Stage() == Function::PipelineStage::kVertex) {
+                    attributes.interpolation = *loc->interpolation;
+                }
                 func->ClearReturnLocation();
             } else if (auto builtin = func->ReturnBuiltin()) {
                 attributes.builtin = ReturnBuiltin(*builtin);
diff --git a/src/tint/lang/spirv/writer/raise/shader_io_test.cc b/src/tint/lang/spirv/writer/raise/shader_io_test.cc
index 8814aee..0f01cf5 100644
--- a/src/tint/lang/spirv/writer/raise/shader_io_test.cc
+++ b/src/tint/lang/spirv/writer/raise/shader_io_test.cc
@@ -837,6 +837,149 @@
     EXPECT_EQ(expect, str());
 }
 
+// Test that interpolation attributes are stripped from vertex inputs and fragment outputs.
+TEST_F(SpirvWriter_ShaderIOTest, InterpolationOnVertexInputOrFragmentOutput) {
+    auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"),
+                             {
+                                 {
+                                     mod.symbols.New("color"),
+                                     ty.f32(),
+                                     {1u,
+                                      {},
+                                      {},
+                                      core::Interpolation{core::InterpolationType::kLinear,
+                                                          core::InterpolationSampling::kSample},
+                                      false},
+                                 },
+                             });
+
+    // Vertex shader.
+    {
+        auto* ep = b.Function("vert", ty.vec4<f32>());
+        ep->SetReturnBuiltin(core::ir::Function::ReturnBuiltin::kPosition);
+        ep->SetReturnInvariant(true);
+        ep->SetStage(core::ir::Function::PipelineStage::kVertex);
+
+        auto* str_param = b.FunctionParam("input", str_ty);
+        auto* ival = b.FunctionParam("ival", ty.i32());
+        ival->SetLocation(1, core::Interpolation{core::InterpolationType::kFlat});
+        ep->SetParams({str_param, ival});
+
+        b.Append(ep->Block(), [&] {  //
+            b.Return(ep, b.Construct(ty.vec4<f32>(), 0.5_f));
+        });
+    }
+
+    // Fragment shader with struct output.
+    {
+        auto* ep = b.Function("frag1", str_ty);
+        ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+
+        b.Append(ep->Block(), [&] {  //
+            b.Return(ep, b.Construct(str_ty, 0.5_f));
+        });
+    }
+
+    // Fragment shader with non-struct output.
+    {
+        auto* ep = b.Function("frag2", ty.i32());
+        ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+        ep->SetReturnLocation(0, core::Interpolation{core::InterpolationType::kFlat});
+
+        b.Append(ep->Block(), [&] {  //
+            b.Return(ep, b.Constant(42_i));
+        });
+    }
+
+    auto* src = R"(
+MyStruct = struct @align(4) {
+  color:f32 @offset(0), @location(1), @interpolate(linear, sample)
+}
+
+%vert = @vertex func(%input:MyStruct, %ival:i32 [@location(1), @interpolate(flat)]):vec4<f32> [@invariant, @position] -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct 0.5f
+    ret %4
+  }
+}
+%frag1 = @fragment func():MyStruct -> %b2 {
+  %b2 = block {
+    %6:MyStruct = construct 0.5f
+    ret %6
+  }
+}
+%frag2 = @fragment func():i32 [@location(0), @interpolate(flat)] -> %b3 {
+  %b3 = block {
+    ret 42i
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+MyStruct = struct @align(4) {
+  color:f32 @offset(0)
+}
+
+%b1 = block {  # root
+  %vert_loc1_Input:ptr<__in, f32, read> = var @location(1)
+  %vert_loc1_Input_1:ptr<__in, i32, read> = var @location(1)  # %vert_loc1_Input_1: 'vert_loc1_Input'
+  %vert_position_Output:ptr<__out, vec4<f32>, write> = var @invariant @builtin(position)
+  %frag1_loc1_Output:ptr<__out, f32, write> = var @location(1)
+  %frag2_loc0_Output:ptr<__out, i32, write> = var @location(0)
+}
+
+%vert_inner = func(%input:MyStruct, %ival:i32):vec4<f32> -> %b2 {
+  %b2 = block {
+    %9:vec4<f32> = construct 0.5f
+    ret %9
+  }
+}
+%frag1_inner = func():MyStruct -> %b3 {
+  %b3 = block {
+    %11:MyStruct = construct 0.5f
+    ret %11
+  }
+}
+%frag2_inner = func():i32 -> %b4 {
+  %b4 = block {
+    ret 42i
+  }
+}
+%vert = @vertex func():void -> %b5 {
+  %b5 = block {
+    %14:f32 = load %vert_loc1_Input
+    %15:MyStruct = construct %14
+    %16:i32 = load %vert_loc1_Input_1
+    %17:vec4<f32> = call %vert_inner, %15, %16
+    store %vert_position_Output, %17
+    ret
+  }
+}
+%frag1 = @fragment func():void -> %b6 {
+  %b6 = block {
+    %19:MyStruct = call %frag1_inner
+    %20:f32 = access %19, 0u
+    store %frag1_loc1_Output, %20
+    ret
+  }
+}
+%frag2 = @fragment func():void -> %b7 {
+  %b7 = block {
+    %22:i32 = call %frag2_inner
+    store %frag2_loc0_Output, %22
+    ret
+  }
+}
+)";
+
+    ShaderIOConfig config;
+    config.clamp_frag_depth = false;
+    Run(ShaderIO, config);
+
+    EXPECT_EQ(expect, str());
+}
+
 TEST_F(SpirvWriter_ShaderIOTest, ClampFragDepth) {
     auto* str_ty = ty.Struct(mod.symbols.New("Outputs"),
                              {