[tint][resolver] Fix @index validation, tweak diagnostics

Validator::EntryPoint() was returning early for IndexAttribute, which
was missing later validation. Replace the if-else chain with a Switch()
which makes early return impossible.

Also change the diagnostics for attributes that can only be applied to a
specific stage, so that the message states what stage it can be used in.
Knowing whether the usage is input or output is not that helpful.

Change-Id: I47352a006f45ad6421aa102697216206d0e2044d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/159783
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/wgsl/resolver/attribute_validation_test.cc b/src/tint/lang/wgsl/resolver/attribute_validation_test.cc
index a82437d..96bbcae 100644
--- a/src/tint/lang/wgsl/resolver/attribute_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/attribute_validation_test.cc
@@ -660,15 +660,15 @@
         },
         TestParams{
             {AttributeKind::kInterpolate},
-            R"(1:2 error: @interpolate is not valid for compute shader inputs)",
+            R"(1:2 error: @interpolate cannot be used by compute shaders)",
         },
         TestParams{
             {AttributeKind::kInvariant},
-            R"(1:2 error: @invariant is not valid for compute shader inputs)",
+            R"(1:2 error: @invariant cannot be used by compute shaders)",
         },
         TestParams{
             {AttributeKind::kLocation},
-            R"(1:2 error: @location is not valid for compute shader inputs)",
+            R"(1:2 error: @location cannot be used by compute shaders)",
         },
         TestParams{
             {AttributeKind::kMustUse},
@@ -936,19 +936,19 @@
         },
         TestParams{
             {AttributeKind::kIndex},
-            R"(1:2 error: @index is not valid for compute shader output)",
+            R"(1:2 error: @index can only be used as fragment shader output)",
         },
         TestParams{
             {AttributeKind::kInterpolate},
-            R"(1:2 error: @interpolate is not valid for compute shader output)",
+            R"(1:2 error: @interpolate cannot be used by compute shaders)",
         },
         TestParams{
             {AttributeKind::kInvariant},
-            R"(1:2 error: @invariant is not valid for compute shader output)",
+            R"(1:2 error: @invariant cannot be used by compute shaders)",
         },
         TestParams{
             {AttributeKind::kLocation},
-            R"(1:2 error: @location is not valid for compute shader output)",
+            R"(1:2 error: @location cannot be used by compute shaders)",
         },
         TestParams{
             {AttributeKind::kMustUse},
@@ -1018,7 +1018,7 @@
         },
         TestParams{
             {AttributeKind::kIndex},
-            Pass,
+            R"(9:9 error: missing entry point IO attribute on return type)",
         },
         TestParams{
             {AttributeKind::kIndex, AttributeKind::kLocation},
@@ -1126,7 +1126,7 @@
         },
         TestParams{
             {AttributeKind::kIndex},
-            R"(1:2 error: @index is not valid for vertex shader output)",
+            R"(1:2 error: @index can only be used as fragment shader output)",
         },
         TestParams{
             {AttributeKind::kInterpolate},
@@ -1326,7 +1326,7 @@
                              },
                              TestParams{
                                  {AttributeKind::kIndex},
-                                 R"(1:2 error: @index can only be used with @location)",
+                                 R"(1:2 error: @index can only be used with @location(0))",
                              },
                              TestParams{
                                  {AttributeKind::kInterpolate},
diff --git a/src/tint/lang/wgsl/resolver/dual_source_blending_extension_test.cc b/src/tint/lang/wgsl/resolver/dual_source_blending_extension_test.cc
index 6adfb36..10acd30 100644
--- a/src/tint/lang/wgsl/resolver/dual_source_blending_extension_test.cc
+++ b/src/tint/lang/wgsl/resolver/dual_source_blending_extension_test.cc
@@ -115,14 +115,30 @@
 }
 
 // Using the index attribute without a location attribute should fail.
-TEST_F(DualSourceBlendingExtensionTests, IndexWithMissingLocationAttribute) {
+TEST_F(DualSourceBlendingExtensionTests, IndexWithMissingLocationAttribute_Struct) {
     Structure("Output", Vector{
                             Member(Source{{12, 34}}, "a", ty.vec4<f32>(),
                                    Vector{Index(Source{{12, 34}}, 1_a)}),
                         });
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(), "12:34 error: @index can only be used with @location");
+    EXPECT_EQ(r()->error(), "12:34 error: @index can only be used with @location(0)");
+}
+
+// Using the index attribute without a location attribute should fail.
+TEST_F(DualSourceBlendingExtensionTests, IndexWithMissingLocationAttribute_ReturnValue) {
+    Func("F", Empty, ty.vec4<f32>(),
+         Vector{
+             Return(Call<vec4<f32>>()),
+         },
+         Vector{Stage(ast::PipelineStage::kFragment)},
+         Vector{
+             Index(Source{{12, 34}}, 1_a),
+             Builtin(core::BuiltinValue::kPointSize),
+         });
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(), "12:34 error: @index can only be used with @location(0)");
 }
 
 // Using an index attribute on a struct member should pass.
@@ -148,7 +164,7 @@
 }
 
 // Using the index attribute with a non-zero location should fail.
-TEST_F(DualSourceBlendingExtensionTests, IndexWithNonZeroLocation) {
+TEST_F(DualSourceBlendingExtensionTests, IndexWithNonZeroLocation_Struct) {
     Structure("Output",
               Vector{
                   Member("a", ty.vec4<f32>(), Vector{Location(1_a), Index(Source{{12, 34}}, 0_a)}),
@@ -158,6 +174,22 @@
     EXPECT_EQ(r()->error(), "12:34 error: @index can only be used with @location(0)");
 }
 
+// Using the index attribute with a non-zero location should fail.
+TEST_F(DualSourceBlendingExtensionTests, IndexWithNonZeroLocation_ReturnValue) {
+    Func("F", Empty, ty.vec4<f32>(),
+         Vector{
+             Return(Call<vec4<f32>>()),
+         },
+         Vector{Stage(ast::PipelineStage::kFragment)},
+         Vector{
+             Location(1_a),
+             Index(Source{{12, 34}}, 1_a),
+         });
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(), "12:34 error: @index can only be used with @location(0)");
+}
+
 class DualSourceBlendingExtensionTestWithParams : public ResolverTestWithParam<int> {
   public:
     DualSourceBlendingExtensionTestWithParams() {
@@ -167,12 +199,14 @@
 
 // Rendering to multiple render targets while using dual source blending should fail.
 TEST_P(DualSourceBlendingExtensionTestWithParams, MultipleRenderTargetsNotAllowed) {
-    Structure("Output",
+    Structure("S",
               Vector{
                   Member("a", ty.vec4<f32>(), Vector{Location(0_a), Index(0_a)}),
                   Member("b", ty.vec4<f32>(), Vector{Location(0_a), Index(1_a)}),
                   Member("c", ty.vec4<f32>(), Vector{Location(Source{{12, 34}}, AInt(GetParam()))}),
               });
+    Func("F", Empty, ty("S"), Vector{Return(Call("S"))},
+         Vector{Stage(ast::PipelineStage::kFragment)});
 
     EXPECT_FALSE(r()->Resolve());
     StringStream err;
diff --git a/src/tint/lang/wgsl/resolver/entry_point_validation_test.cc b/src/tint/lang/wgsl/resolver/entry_point_validation_test.cc
index 034e3b0..444f9e5 100644
--- a/src/tint/lang/wgsl/resolver/entry_point_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/entry_point_validation_test.cc
@@ -1072,7 +1072,7 @@
          });
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(), R"(12:34 error: @location is not valid for compute shader output)");
+    EXPECT_EQ(r()->error(), R"(12:34 error: @location cannot be used by compute shaders)");
 }
 
 TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) {
@@ -1087,7 +1087,7 @@
          });
 
     EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(), R"(12:34 error: @location is not valid for compute shader inputs)");
+    EXPECT_EQ(r()->error(), R"(12:34 error: @location cannot be used by compute shaders)");
 }
 
 TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) {
@@ -1107,7 +1107,7 @@
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: @location is not valid for compute shader output\n"
+              "12:34 error: @location cannot be used by compute shaders\n"
               "56:78 note: while analyzing entry point 'main'");
 }
 
@@ -1126,7 +1126,7 @@
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: @location is not valid for compute shader inputs\n"
+              "12:34 error: @location cannot be used by compute shaders\n"
               "56:78 note: while analyzing entry point 'main'");
 }
 
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 4df0d80..929a508 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -137,10 +137,7 @@
 
 // Helper to stringify a pipeline IO attribute.
 std::string AttrToStr(const ast::Attribute* attr) {
-    return Switch(
-        attr,  //
-        [&](const ast::BuiltinAttribute*) { return "@builtin"; },
-        [&](const ast::LocationAttribute*) { return "@location"; });
+    return "@" + attr->Name();
 }
 
 template <typename CALLBACK>
@@ -1000,7 +997,13 @@
 }
 
 bool Validator::InterpolateAttribute(const ast::InterpolateAttribute* attr,
-                                     const core::type::Type* storage_ty) const {
+                                     const core::type::Type* storage_ty,
+                                     const ast::PipelineStage stage) const {
+    if (stage == ast::PipelineStage::kCompute) {
+        AddError(AttrToStr(attr) + " cannot be used by compute shaders", attr->source);
+        return false;
+    }
+
     auto* type = storage_ty->UnwrapRef();
 
     auto i_type = sem_.AsInterpolationType(sem_.Get(attr->type));
@@ -1022,6 +1025,15 @@
     return true;
 }
 
+bool Validator::InvariantAttribute(const ast::InvariantAttribute* attr,
+                                   const ast::PipelineStage stage) const {
+    if (stage == ast::PipelineStage::kCompute) {
+        AddError(AttrToStr(attr) + " cannot be used by compute shaders", attr->source);
+        return false;
+    }
+    return true;
+}
+
 bool Validator::Function(const sem::Function* func, ast::PipelineStage stage) const {
     auto* decl = func->Declaration();
 
@@ -1125,73 +1137,74 @@
         const ast::InterpolateAttribute* interpolate_attribute = nullptr;
         const ast::InvariantAttribute* invariant_attribute = nullptr;
         for (auto* attr : attrs) {
-            auto is_invalid_compute_shader_attribute = false;
+            bool ok = Switch(
+                attr,  //
+                [&](const ast::BuiltinAttribute* builtin_attr) {
+                    auto builtin = sem_.Get(builtin_attr)->Value();
 
-            if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
-                auto builtin = sem_.Get(builtin_attr)->Value();
+                    if (pipeline_io_attribute) {
+                        AddError("multiple entry point IO attributes", attr->source);
+                        AddNote("previously consumed " + AttrToStr(pipeline_io_attribute),
+                                pipeline_io_attribute->source);
+                        return false;
+                    }
+                    pipeline_io_attribute = attr;
 
-                if (pipeline_io_attribute) {
-                    AddError("multiple entry point IO attributes", attr->source);
-                    AddNote("previously consumed " + AttrToStr(pipeline_io_attribute),
-                            pipeline_io_attribute->source);
-                    return false;
-                }
-                pipeline_io_attribute = attr;
+                    if (builtins.Contains(builtin)) {
+                        StringStream err;
+                        err << "@builtin(" << builtin << ") appears multiple times as pipeline "
+                            << (param_or_ret == ParamOrRetType::kParameter ? "input" : "output");
+                        AddError(err.str(), decl->source);
+                        return false;
+                    }
 
-                if (builtins.Contains(builtin)) {
-                    StringStream err;
-                    err << "@builtin(" << builtin << ") appears multiple times as pipeline "
-                        << (param_or_ret == ParamOrRetType::kParameter ? "input" : "output");
-                    AddError(err.str(), decl->source);
-                    return false;
-                }
+                    if (!BuiltinAttribute(
+                            builtin_attr, ty, stage,
+                            /* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
+                        return false;
+                    }
+                    builtins.Add(builtin);
+                    return true;
+                },
+                [&](const ast::LocationAttribute* loc_attr) {
+                    location_attribute = loc_attr;
+                    if (pipeline_io_attribute) {
+                        AddError("multiple entry point IO attributes", attr->source);
+                        AddNote("previously consumed " + AttrToStr(pipeline_io_attribute),
+                                pipeline_io_attribute->source);
+                        return false;
+                    }
+                    pipeline_io_attribute = attr;
 
-                if (!BuiltinAttribute(builtin_attr, ty, stage,
-                                      /* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
-                    return false;
-                }
-                builtins.Add(builtin);
-            } else if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
-                location_attribute = loc_attr;
-                if (pipeline_io_attribute) {
-                    AddError("multiple entry point IO attributes", attr->source);
-                    AddNote("previously consumed " + AttrToStr(pipeline_io_attribute),
-                            pipeline_io_attribute->source);
-                    return false;
-                }
-                pipeline_io_attribute = attr;
+                    if (TINT_UNLIKELY(!location.has_value())) {
+                        TINT_ICE() << "@location has no value";
+                        return false;
+                    }
 
-                bool is_input = param_or_ret == ParamOrRetType::kParameter;
+                    return LocationAttribute(loc_attr, ty, stage, source);
+                },
+                [&](const ast::IndexAttribute* index_attr) {
+                    bool is_input = param_or_ret == ParamOrRetType::kParameter;
+                    index_attribute = index_attr;
 
-                if (TINT_UNLIKELY(!location.has_value())) {
-                    TINT_ICE() << "Location has no value";
-                    return false;
-                }
+                    if (TINT_UNLIKELY(!index.has_value())) {
+                        TINT_ICE() << "@index has no value";
+                        return false;
+                    }
 
-                if (!LocationAttribute(loc_attr, ty, stage, source, is_input)) {
-                    return false;
-                }
-            } else if (auto* index_attr = attr->As<ast::IndexAttribute>()) {
-                index_attribute = index_attr;
-                return IndexAttribute(index_attr, stage);
-            } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
-                if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
-                    is_invalid_compute_shader_attribute = true;
-                } else if (!InterpolateAttribute(interpolate, ty)) {
-                    return false;
-                }
-                interpolate_attribute = interpolate;
-            } else if (auto* invariant = attr->As<ast::InvariantAttribute>()) {
-                if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
-                    is_invalid_compute_shader_attribute = true;
-                }
-                invariant_attribute = invariant;
-            }
-            if (is_invalid_compute_shader_attribute) {
-                std::string input_or_output =
-                    param_or_ret == ParamOrRetType::kParameter ? "inputs" : "output";
-                AddError("@" + attr->Name() + " is not valid for compute shader " + input_or_output,
-                         attr->source);
+                    return IndexAttribute(index_attr, stage, is_input);
+                },
+                [&](const ast::InterpolateAttribute* interpolate) {
+                    interpolate_attribute = interpolate;
+                    return InterpolateAttribute(interpolate, ty, stage);
+                },
+                [&](const ast::InvariantAttribute* invariant) {
+                    invariant_attribute = invariant;
+                    return InvariantAttribute(invariant, stage);
+                },
+                [&](Default) { return true; });
+
+            if (!ok) {
                 return false;
             }
         }
@@ -1234,15 +1247,10 @@
             }
 
             if (index_attribute) {
-                if (Is<ast::LocationAttribute>(pipeline_io_attribute)) {
-                    AddError("@index can only be used with @location", index_attribute->source);
-                    return false;
-                }
-
                 // Because HLSL specifies dual source blending targets with SV_Target0 and 1, we
-                // should restrict targets with index attributes to location 0 for easy translation
+                // should restrict targets with @index to location 0 for easy translation
                 // in the backend writers.
-                if (location.value() != 0) {
+                if (location.value_or(1) != 0) {
                     AddError("@index can only be used with @location(0)", index_attribute->source);
                     return false;
                 }
@@ -2177,16 +2185,13 @@
                 attr,  //
                 [&](const ast::InvariantAttribute* invariant) {
                     invariant_attribute = invariant;
-                    return true;
+                    return InvariantAttribute(invariant, stage);
                 },
                 [&](const ast::LocationAttribute* location) {
                     location_attribute = location;
                     TINT_ASSERT(member->Attributes().location.has_value());
-                    if (!LocationAttribute(location, member->Type(), stage,
-                                           member->Declaration()->source)) {
-                        return false;
-                    }
-                    return true;
+                    return LocationAttribute(location, member->Type(), stage,
+                                             member->Declaration()->source);
                 },
                 [&](const ast::IndexAttribute* index) {
                     index_attribute = index;
@@ -2206,10 +2211,7 @@
                 },
                 [&](const ast::InterpolateAttribute* interpolate) {
                     interpolate_attribute = interpolate;
-                    if (!InterpolateAttribute(interpolate, member->Type())) {
-                        return false;
-                    }
-                    return true;
+                    return InterpolateAttribute(interpolate, member->Type(), stage);
                 },
                 [&](const ast::StructMemberSizeAttribute*) {
                     if (!member->Type()->HasCreationFixedFootprint()) {
@@ -2234,15 +2236,10 @@
         }
 
         if (index_attribute) {
-            if (!location_attribute) {
-                AddError("@index can only be used with @location", index_attribute->source);
-                return false;
-            }
-
             // Because HLSL specifies dual source blending targets with SV_Target0 and 1, we should
             // restrict targets with index attributes to location 0 for easy translation in the
             // backend writers.
-            if (member->Attributes().location.value() != 0) {
+            if (member->Attributes().location.value_or(1) != 0) {
                 AddError("@index can only be used with @location(0)", index_attribute->source);
                 return false;
             }
@@ -2289,15 +2286,12 @@
     return true;
 }
 
-bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
+bool Validator::LocationAttribute(const ast::LocationAttribute* attr,
                                   const core::type::Type* type,
                                   ast::PipelineStage stage,
-                                  const Source& source,
-                                  const bool is_input) const {
-    std::string inputs_or_output = is_input ? "inputs" : "output";
+                                  const Source& source) const {
     if (stage == ast::PipelineStage::kCompute) {
-        AddError("@" + loc_attr->Name() + " is not valid for compute shader " + inputs_or_output,
-                 loc_attr->source);
+        AddError(AttrToStr(attr) + " cannot be used by compute shaders", attr->source);
         return false;
     }
 
@@ -2307,32 +2301,29 @@
         AddNote(
             "@location must only be applied to declarations of numeric scalar or numeric vector "
             "type",
-            loc_attr->source);
+            attr->source);
         return false;
     }
 
     return true;
 }
 
-bool Validator::IndexAttribute(const ast::IndexAttribute* index_attr,
-                               ast::PipelineStage stage) const {
+bool Validator::IndexAttribute(const ast::IndexAttribute* attr,
+                               ast::PipelineStage stage,
+                               const std::optional<bool> is_input) const {
     if (!enabled_extensions_.Contains(wgsl::Extension::kChromiumInternalDualSourceBlending)) {
         AddError(
             "use of '@index' attribute requires enabling extension "
             "'chromium_internal_dual_source_blending'",
-            index_attr->source);
+            attr->source);
         return false;
     }
 
-    if (stage == ast::PipelineStage::kCompute) {
-        AddError("@" + index_attr->Name() + " is not valid for compute shader output",
-                 index_attr->source);
-        return false;
-    }
-
-    if (stage == ast::PipelineStage::kVertex) {
-        AddError("@" + index_attr->Name() + " is not valid for vertex shader output",
-                 index_attr->source);
+    bool is_stage_non_fragment =
+        stage != ast::PipelineStage::kNone && stage != ast::PipelineStage::kFragment;
+    bool is_output = is_input.value_or(false);
+    if (is_stage_non_fragment || is_output) {
+        AddError(AttrToStr(attr) + " can only be used as fragment shader output", attr->source);
         return false;
     }
 
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index 20735de..e78512a 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -315,11 +315,20 @@
     bool IncrementDecrementStatement(const ast::IncrementDecrementStatement* stmt) const;
 
     /// Validates an interpolate attribute
-    /// @param attr the interpolation attribute to validate
+    /// @param attr the attribute to validate
     /// @param storage_type the storage type of the attached variable
-    /// @returns true on succes, false otherwise
+    /// @param stage the current pipeline stage
+    /// @returns true on success, false otherwise
     bool InterpolateAttribute(const ast::InterpolateAttribute* attr,
-                              const core::type::Type* storage_type) const;
+                              const core::type::Type* storage_type,
+                              const ast::PipelineStage stage) const;
+
+    /// Validates an invariant attribute
+    /// @param attr the attribute to validate
+    /// @param stage the current pipeline stage
+    /// @returns true on success, false otherwise
+    bool InvariantAttribute(const ast::InvariantAttribute* attr,
+                            const ast::PipelineStage stage) const;
 
     /// Validates a builtin call
     /// @param call the builtin call to validate
@@ -332,23 +341,25 @@
     bool LocalVariable(const sem::Variable* v) const;
 
     /// Validates a location attribute
-    /// @param loc_attr the location attribute to validate
+    /// @param attr the attribute to validate
     /// @param type the variable type
     /// @param stage the current pipeline stage
     /// @param source the source of the attribute
-    /// @param is_input true if this is an input variable
     /// @returns true on success, false otherwise.
-    bool LocationAttribute(const ast::LocationAttribute* loc_attr,
+    bool LocationAttribute(const ast::LocationAttribute* attr,
                            const core::type::Type* type,
-                           ast::PipelineStage stage,
-                           const Source& source,
-                           const bool is_input = false) const;
+                           const ast::PipelineStage stage,
+                           const Source& source) const;
 
     /// Validates a index attribute
     /// @param index_attr the index attribute to validate
     /// @param stage the current pipeline stage
+    /// @param is_input true if is an input variable, false if output variable, std::nullopt is
+    /// unknown.
     /// @returns true on success, false otherwise.
-    bool IndexAttribute(const ast::IndexAttribute* index_attr, ast::PipelineStage stage) const;
+    bool IndexAttribute(const ast::IndexAttribute* index_attr,
+                        ast::PipelineStage stage,
+                        const std::optional<bool> is_input = {}) const;
 
     /// Validates a loop statement
     /// @param stmt the loop statement