validation: Validate interpolation attributes

They are only valid on entry point parameters and return types, and
struct members. They must only be used on floating point scalar and
vector types. If the interpolation type is flat, the sampling type
must not be specified.

Bug: tint:746
Change-Id: Iab17816bc9947a74593a5937bdf513ac9ec664f1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56241
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/program_builder.h b/src/program_builder.h
index facd539..8627c9a 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -36,6 +36,7 @@
 #include "src/ast/float_literal.h"
 #include "src/ast/i32.h"
 #include "src/ast/if_statement.h"
+#include "src/ast/interpolate_decoration.h"
 #include "src/ast/loop_statement.h"
 #include "src/ast/matrix.h"
 #include "src/ast/member_accessor_expression.h"
@@ -1925,6 +1926,26 @@
     return create<ast::BuiltinDecoration>(source_, builtin);
   }
 
+  /// Creates an ast::InterpolateDecoration
+  /// @param source the source information
+  /// @param type the interpolation type
+  /// @param sampling the interpolation sampling
+  /// @returns the interpolate decoration pointer
+  ast::InterpolateDecoration* Interpolate(const Source& source,
+                                          ast::InterpolationType type,
+                                          ast::InterpolationSampling sampling) {
+    return create<ast::InterpolateDecoration>(source, type, sampling);
+  }
+
+  /// Creates an ast::InterpolateDecoration
+  /// @param type the interpolation type
+  /// @param sampling the interpolation sampling
+  /// @returns the interpolate decoration pointer
+  ast::InterpolateDecoration* Interpolate(ast::InterpolationType type,
+                                          ast::InterpolationSampling sampling) {
+    return create<ast::InterpolateDecoration>(source_, type, sampling);
+  }
+
   /// Creates an ast::LocationDecoration
   /// @param source the source information
   /// @param location the location value
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index 8702d2d..c40c70a 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -61,6 +61,7 @@
   kBinding,
   kBuiltin,
   kGroup,
+  kInterpolate,
   kLocation,
   kOverride,
   kOffset,
@@ -102,6 +103,10 @@
       return {builder.Builtin(source, ast::Builtin::kPosition)};
     case DecorationKind::kGroup:
       return {builder.create<ast::GroupDecoration>(source, 1u)};
+    case DecorationKind::kInterpolate:
+      return {builder.Interpolate(source, ast::InterpolationType::kLinear,
+                                  ast::InterpolationSampling::kCenter),
+              builder.Location(0)};
     case DecorationKind::kLocation:
       return {builder.Location(source, 1)};
     case DecorationKind::kOverride:
@@ -150,6 +155,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, false},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, false},
                     TestParams{DecorationKind::kLocation, false},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -185,6 +191,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, true},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, true},
                     TestParams{DecorationKind::kLocation, true},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -247,6 +254,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, false},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, false},
                     TestParams{DecorationKind::kLocation, false},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -282,6 +290,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, true},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, true},
                     TestParams{DecorationKind::kLocation, true},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -335,6 +344,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, false},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, false},
                     TestParams{DecorationKind::kLocation, false},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -369,6 +379,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, false},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, false},
                     TestParams{DecorationKind::kLocation, false},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -408,7 +419,7 @@
                 createDecorations(Source{{12, 34}}, *this, params.kind))});
   } else {
     members.push_back(
-        {Member("a", ty.i32(),
+        {Member("a", ty.f32(),
                 createDecorations(Source{{12, 34}}, *this, params.kind))});
   }
 
@@ -431,6 +442,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, true},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, true},
                     TestParams{DecorationKind::kLocation, true},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, true},
@@ -492,6 +504,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, false},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, false},
                     TestParams{DecorationKind::kLocation, false},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -542,6 +555,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, false},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, false},
                     TestParams{DecorationKind::kLocation, false},
                     TestParams{DecorationKind::kOverride, true},
                     TestParams{DecorationKind::kOffset, false},
@@ -590,6 +604,7 @@
                     TestParams{DecorationKind::kBinding, false},
                     TestParams{DecorationKind::kBuiltin, false},
                     TestParams{DecorationKind::kGroup, false},
+                    TestParams{DecorationKind::kInterpolate, false},
                     TestParams{DecorationKind::kLocation, false},
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
@@ -957,5 +972,101 @@
 }  // namespace
 }  // namespace WorkgroupDecorationTests
 
+namespace InterpolateTests {
+namespace {
+
+using InterpolateTest = ResolverTest;
+
+struct Params {
+  ast::InterpolationType type;
+  ast::InterpolationSampling sampling;
+  bool should_pass;
+};
+
+struct TestWithParams : ResolverTestWithParam<Params> {};
+
+using InterpolateParameterTest = TestWithParams;
+TEST_P(InterpolateParameterTest, All) {
+  auto& params = GetParam();
+
+  Func("main",
+       ast::VariableList{Param(
+           "a", ty.f32(),
+           {Location(0),
+            Interpolate(Source{{12, 34}}, params.type, params.sampling)})},
+       ty.void_(), {},
+       ast::DecorationList{Stage(ast::PipelineStage::kFragment)});
+
+  if (params.should_pass) {
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+  } else {
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              "12:34 error: flat interpolation attribute must not have a "
+              "sampling parameter");
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    ResolverDecorationValidationTest,
+    InterpolateParameterTest,
+    testing::Values(Params{ast::InterpolationType::kPerspective,
+                           ast::InterpolationSampling::kNone, true},
+                    Params{ast::InterpolationType::kPerspective,
+                           ast::InterpolationSampling::kCenter, true},
+                    Params{ast::InterpolationType::kPerspective,
+                           ast::InterpolationSampling::kCentroid, true},
+                    Params{ast::InterpolationType::kPerspective,
+                           ast::InterpolationSampling::kSample, true},
+                    Params{ast::InterpolationType::kLinear,
+                           ast::InterpolationSampling::kNone, true},
+                    Params{ast::InterpolationType::kLinear,
+                           ast::InterpolationSampling::kCenter, true},
+                    Params{ast::InterpolationType::kLinear,
+                           ast::InterpolationSampling::kCentroid, true},
+                    Params{ast::InterpolationType::kLinear,
+                           ast::InterpolationSampling::kSample, true},
+                    // flat interpolation must not have a sampling type
+                    Params{ast::InterpolationType::kFlat,
+                           ast::InterpolationSampling::kNone, true},
+                    Params{ast::InterpolationType::kFlat,
+                           ast::InterpolationSampling::kCenter, false},
+                    Params{ast::InterpolationType::kFlat,
+                           ast::InterpolationSampling::kCentroid, false},
+                    Params{ast::InterpolationType::kFlat,
+                           ast::InterpolationSampling::kSample, false}));
+
+TEST_F(InterpolateTest, Parameter_NotFloatingPoint) {
+  Func("main",
+       ast::VariableList{
+           Param("a", ty.i32(),
+                 {Location(0),
+                  Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat,
+                              ast::InterpolationSampling::kNone)})},
+       ty.void_(), {},
+       ast::DecorationList{Stage(ast::PipelineStage::kFragment)});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: store type of interpolate attribute must be floating "
+            "point scalar or vector");
+}
+
+TEST_F(InterpolateTest, ReturnType_NotFloatingPoint) {
+  Func(
+      "main", {}, ty.i32(), {Return(1)},
+      ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
+      {Location(0), Interpolate(Source{{12, 34}}, ast::InterpolationType::kFlat,
+                                ast::InterpolationSampling::kNone)});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: store type of interpolate attribute must be floating "
+            "point scalar or vector");
+}
+
+}  // namespace
+}  // namespace InterpolateTests
+
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 65b34bd..c6462df 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -30,6 +30,7 @@
 #include "src/ast/fallthrough_statement.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/internal_decoration.h"
+#include "src/ast/interpolate_decoration.h"
 #include "src/ast/loop_statement.h"
 #include "src/ast/matrix.h"
 #include "src/ast/override_decoration.h"
@@ -719,7 +720,8 @@
       }
     } else {
       bool is_shader_io_decoration =
-          deco->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
+          deco->IsAnyOf<ast::BuiltinDecoration, ast::InterpolateDecoration,
+                        ast::LocationDecoration>();
       bool has_io_storage_class =
           info->storage_class == ast::StorageClass::kInput ||
           info->storage_class == ast::StorageClass::kOutput;
@@ -947,6 +949,10 @@
       if (!ValidateBuiltinDecoration(builtin, info->type)) {
         return false;
       }
+    } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
+      if (!ValidateInterpolateDecoration(interpolate, info->type)) {
+        return false;
+      }
     } else if (!deco->IsAnyOf<ast::LocationDecoration,
                               ast::InternalDecoration>() &&
                !(IsValidationDisabled(
@@ -1015,6 +1021,29 @@
   return true;
 }
 
+bool Resolver::ValidateInterpolateDecoration(
+    const ast::InterpolateDecoration* deco,
+    const sem::Type* storage_type) {
+  auto* type = storage_type->UnwrapRef();
+
+  if (!type->is_float_scalar_or_vector()) {
+    AddError(
+        "store type of interpolate attribute must be floating point scalar or "
+        "vector",
+        deco->source());
+    return false;
+  }
+
+  if (deco->type() == ast::InterpolationType::kFlat &&
+      deco->sampling() != ast::InterpolationSampling::kNone) {
+    AddError("flat interpolation attribute must not have a sampling parameter",
+             deco->source());
+    return false;
+  }
+
+  return true;
+}
+
 bool Resolver::ValidateFunction(const ast::Function* func,
                                 const FunctionInfo* info) {
   auto func_it = symbol_to_function_.find(func->symbol());
@@ -1101,6 +1130,10 @@
         if (!ValidateBuiltinDecoration(builtin, info->return_type)) {
           return false;
         }
+      } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
+        if (!ValidateInterpolateDecoration(interpolate, info->return_type)) {
+          return false;
+        }
       } else if (!deco->Is<ast::LocationDecoration>()) {
         AddError("decoration is not valid for entry point return types",
                  deco->source());
@@ -3364,6 +3397,7 @@
 
     for (auto* deco : member->Declaration()->decorations()) {
       if (!(deco->Is<ast::BuiltinDecoration>() ||
+            deco->Is<ast::InterpolateDecoration>() ||
             deco->Is<ast::LocationDecoration>() ||
             deco->Is<ast::StructMemberOffsetDecoration>() ||
             deco->Is<ast::StructMemberSizeDecoration>() ||
@@ -3376,6 +3410,10 @@
         if (!ValidateBuiltinDecoration(builtin, member->Type())) {
           return false;
         }
+      } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
+        if (!ValidateInterpolateDecoration(interpolate, member->Type())) {
+          return false;
+        }
       }
     }
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 336b44c..16758bf 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -276,6 +276,8 @@
   bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
   bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
   bool ValidateGlobalVariable(const VariableInfo* var);
+  bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
+                                     const sem::Type* storage_type);
   bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source);
   bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
                                  const sem::Matrix* matrix_type);