spirv-reader: support Flat decoration

Bug: tint:935
Change-Id: Ie6f97d8d9a273fd8099d8e9807ffa189ba3031a0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56820
Auto-Submit: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index c5830c5..608ba3f 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -21,6 +21,7 @@
 
 #include "source/opt/build_module.h"
 #include "src/ast/bitcast_expression.h"
+#include "src/ast/interpolate_decoration.h"
 #include "src/ast/override_decoration.h"
 #include "src/ast/struct_block_decoration.h"
 #include "src/ast/type_name.h"
@@ -1093,21 +1094,28 @@
     bool is_non_writable = false;
     ast::DecorationList ast_member_decorations;
     for (auto& decoration : GetDecorationsForMember(type_id, member_index)) {
-      if (decoration[0] == SpvDecorationNonWritable) {
-        // WGSL doesn't represent individual members as non-writable. Instead,
-        // apply the ReadOnly access control to the containing struct if all
-        // the members are non-writable.
-        is_non_writable = true;
-      } else if (decoration[0] == SpvDecorationLocation) {
-        // Location decorations are handled when emitting the entry point.
-      } else {
-        auto* ast_member_decoration =
-            ConvertMemberDecoration(type_id, member_index, decoration);
-        if (!success_) {
-          return nullptr;
-        }
-        if (ast_member_decoration) {
-          ast_member_decorations.push_back(ast_member_decoration);
+      switch (decoration[0]) {
+        case SpvDecorationNonWritable:
+
+          // WGSL doesn't represent individual members as non-writable. Instead,
+          // apply the ReadOnly access control to the containing struct if all
+          // the members are non-writable.
+          is_non_writable = true;
+          break;
+        case SpvDecorationLocation:
+        case SpvDecorationFlat:
+          // IO decorations are handled when emitting the entry point.
+          break;
+        default: {
+          auto* ast_member_decoration =
+              ConvertMemberDecoration(type_id, member_index, decoration);
+          if (!success_) {
+            return nullptr;
+          }
+          if (ast_member_decoration) {
+            ast_member_decorations.push_back(ast_member_decoration);
+          }
+          break;
         }
       }
     }
@@ -1635,6 +1643,17 @@
             create<ast::LocationDecoration>(Source{}, deco[1]));
       }
     }
+    if (deco[0] == SpvDecorationFlat) {
+      if (transfer_pipeline_io) {
+        // In WGSL, integral types are always flat, and so the decoration
+        // is never specified.
+        if (!(*store_type)->IsIntegerScalarOrVector()) {
+          decorations->emplace_back(create<ast::InterpolateDecoration>(
+              Source{}, ast::InterpolationType::kFlat,
+              ast::InterpolationSampling::kNone));
+        }
+      }
+    }
     if (deco[0] == SpvDecorationDescriptorSet) {
       if (deco.size() == 1) {
         return Fail() << "malformed DescriptorSet decoration on ID " << id
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index beeef85..941129c 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -87,6 +87,7 @@
     %v2uint = OpTypeVector %uint 2
     %v2int = OpTypeVector %int 2
     %v2float = OpTypeVector %float 2
+    %v4float = OpTypeVector %float 4
     %m3v2float = OpTypeMatrix %v2float 3
 
     %arr2uint = OpTypeArray %uint %uint_2
@@ -7278,6 +7279,487 @@
                         "or member of a structure type"));
 }
 
+TEST_F(SpvModuleScopeVarParserTest,
+       EntryPointWrapping_Interpolation_Flat_Vertex_In) {
+  // Flat decorations are dropped for integral
+  const auto assembly = CommonCapabilities() + R"(
+     OpEntryPoint Vertex %main "main" %1 %2 %3 %4 %5 %6 %10
+     OpDecorate %1 Location 1
+     OpDecorate %2 Location 2
+     OpDecorate %3 Location 3
+     OpDecorate %4 Location 4
+     OpDecorate %5 Location 5
+     OpDecorate %6 Location 6
+     OpDecorate %1 Flat
+     OpDecorate %2 Flat
+     OpDecorate %3 Flat
+     OpDecorate %4 Flat
+     OpDecorate %5 Flat
+     OpDecorate %6 Flat
+     OpDecorate %10 BuiltIn Position
+)" + CommonTypes() +
+                        R"(
+     %ptr_in_uint = OpTypePointer Input %uint
+     %ptr_in_v2uint = OpTypePointer Input %v2uint
+     %ptr_in_int = OpTypePointer Input %int
+     %ptr_in_v2int = OpTypePointer Input %v2int
+     %ptr_in_float = OpTypePointer Input %float
+     %ptr_in_v2float = OpTypePointer Input %v2float
+     %1 = OpVariable %ptr_in_uint Input
+     %2 = OpVariable %ptr_in_v2uint Input
+     %3 = OpVariable %ptr_in_int Input
+     %4 = OpVariable %ptr_in_v2int Input
+     %5 = OpVariable %ptr_in_float Input
+     %6 = OpVariable %ptr_in_v2float Input
+
+     %ptr_out_v4float = OpTypePointer Output %v4float
+     %10 = OpVariable %ptr_out_v4float Output
+
+     %main = OpFunction %void None %voidfn
+     %entry = OpLabel
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+
+  ASSERT_TRUE(p->BuildAndParseInternalModule());
+  EXPECT_TRUE(p->error().empty());
+  const auto got = p->program().to_str();
+  const std::string expected =
+      R"(Module{
+  Struct main_out {
+    StructMember{[[ BuiltinDecoration{position}
+ ]] x_10_1: __vec_4__f32}
+  }
+  Variable{
+    x_1
+    private
+    undefined
+    __u32
+  }
+  Variable{
+    x_2
+    private
+    undefined
+    __vec_2__u32
+  }
+  Variable{
+    x_3
+    private
+    undefined
+    __i32
+  }
+  Variable{
+    x_4
+    private
+    undefined
+    __vec_2__i32
+  }
+  Variable{
+    x_5
+    private
+    undefined
+    __f32
+  }
+  Variable{
+    x_6
+    private
+    undefined
+    __vec_2__f32
+  }
+  Variable{
+    x_10
+    private
+    undefined
+    __vec_4__f32
+  }
+  Function main_1 -> __void
+  ()
+  {
+    Return{}
+  }
+  Function main -> __type_name_main_out
+  StageDecoration{vertex}
+  (
+    VariableConst{
+      Decorations{
+        LocationDecoration{1}
+      }
+      x_1_param
+      none
+      undefined
+      __u32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{2}
+      }
+      x_2_param
+      none
+      undefined
+      __vec_2__u32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{3}
+      }
+      x_3_param
+      none
+      undefined
+      __i32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{4}
+      }
+      x_4_param
+      none
+      undefined
+      __vec_2__i32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{5}
+        InterpolateDecoration{flat none}
+      }
+      x_5_param
+      none
+      undefined
+      __f32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{6}
+        InterpolateDecoration{flat none}
+      }
+      x_6_param
+      none
+      undefined
+      __vec_2__f32
+    }
+  )
+  {
+    Assignment{
+      Identifier[not set]{x_1}
+      Identifier[not set]{x_1_param}
+    }
+    Assignment{
+      Identifier[not set]{x_2}
+      Identifier[not set]{x_2_param}
+    }
+    Assignment{
+      Identifier[not set]{x_3}
+      Identifier[not set]{x_3_param}
+    }
+    Assignment{
+      Identifier[not set]{x_4}
+      Identifier[not set]{x_4_param}
+    }
+    Assignment{
+      Identifier[not set]{x_5}
+      Identifier[not set]{x_5_param}
+    }
+    Assignment{
+      Identifier[not set]{x_6}
+      Identifier[not set]{x_6_param}
+    }
+    Call[not set]{
+      Identifier[not set]{main_1}
+      (
+      )
+    }
+    Return{
+      {
+        TypeConstructor[not set]{
+          __type_name_main_out
+          Identifier[not set]{x_10}
+        }
+      }
+    }
+  }
+}
+)";
+  EXPECT_EQ(got, expected) << got;
+}
+
+TEST_F(SpvModuleScopeVarParserTest,
+       EntryPointWrapping_Interpolation_Flat_Vertex_Output) {
+  // Flat decorations are dropped for integral
+  const auto assembly = CommonCapabilities() + R"(
+     OpEntryPoint Vertex %main "main" %1 %2 %3 %4 %5 %6 %10
+     OpDecorate %1 Location 1
+     OpDecorate %2 Location 2
+     OpDecorate %3 Location 3
+     OpDecorate %4 Location 4
+     OpDecorate %5 Location 5
+     OpDecorate %6 Location 6
+     OpDecorate %1 Flat
+     OpDecorate %2 Flat
+     OpDecorate %3 Flat
+     OpDecorate %4 Flat
+     OpDecorate %5 Flat
+     OpDecorate %6 Flat
+     OpDecorate %10 BuiltIn Position
+)" + CommonTypes() +
+                        R"(
+     %ptr_out_uint = OpTypePointer Output %uint
+     %ptr_out_v2uint = OpTypePointer Output %v2uint
+     %ptr_out_int = OpTypePointer Output %int
+     %ptr_out_v2int = OpTypePointer Output %v2int
+     %ptr_out_float = OpTypePointer Output %float
+     %ptr_out_v2float = OpTypePointer Output %v2float
+     %1 = OpVariable %ptr_out_uint Output
+     %2 = OpVariable %ptr_out_v2uint Output
+     %3 = OpVariable %ptr_out_int Output
+     %4 = OpVariable %ptr_out_v2int Output
+     %5 = OpVariable %ptr_out_float Output
+     %6 = OpVariable %ptr_out_v2float Output
+
+     %ptr_out_v4float = OpTypePointer Output %v4float
+     %10 = OpVariable %ptr_out_v4float Output
+
+     %main = OpFunction %void None %voidfn
+     %entry = OpLabel
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+
+  ASSERT_TRUE(p->BuildAndParseInternalModule());
+  EXPECT_TRUE(p->error().empty());
+  const auto got = p->program().to_str();
+  const std::string expected =
+      R"(Module{
+  Struct main_out {
+    StructMember{[[ LocationDecoration{1}
+ ]] x_1_1: __u32}
+    StructMember{[[ LocationDecoration{2}
+ ]] x_2_1: __vec_2__u32}
+    StructMember{[[ LocationDecoration{3}
+ ]] x_3_1: __i32}
+    StructMember{[[ LocationDecoration{4}
+ ]] x_4_1: __vec_2__i32}
+    StructMember{[[ LocationDecoration{5}
+ InterpolateDecoration{flat none}
+ ]] x_5_1: __f32}
+    StructMember{[[ LocationDecoration{6}
+ InterpolateDecoration{flat none}
+ ]] x_6_1: __vec_2__f32}
+    StructMember{[[ BuiltinDecoration{position}
+ ]] x_10_1: __vec_4__f32}
+  }
+  Variable{
+    x_1
+    private
+    undefined
+    __u32
+  }
+  Variable{
+    x_2
+    private
+    undefined
+    __vec_2__u32
+  }
+  Variable{
+    x_3
+    private
+    undefined
+    __i32
+  }
+  Variable{
+    x_4
+    private
+    undefined
+    __vec_2__i32
+  }
+  Variable{
+    x_5
+    private
+    undefined
+    __f32
+  }
+  Variable{
+    x_6
+    private
+    undefined
+    __vec_2__f32
+  }
+  Variable{
+    x_10
+    private
+    undefined
+    __vec_4__f32
+  }
+  Function main_1 -> __void
+  ()
+  {
+    Return{}
+  }
+  Function main -> __type_name_main_out
+  StageDecoration{vertex}
+  ()
+  {
+    Call[not set]{
+      Identifier[not set]{main_1}
+      (
+      )
+    }
+    Return{
+      {
+        TypeConstructor[not set]{
+          __type_name_main_out
+          Identifier[not set]{x_1}
+          Identifier[not set]{x_2}
+          Identifier[not set]{x_3}
+          Identifier[not set]{x_4}
+          Identifier[not set]{x_5}
+          Identifier[not set]{x_6}
+          Identifier[not set]{x_10}
+        }
+      }
+    }
+  }
+}
+)";
+  EXPECT_EQ(got, expected) << got;
+}
+
+TEST_F(SpvModuleScopeVarParserTest,
+       EntryPointWrapping_Flatten_Interpolation_Flat_Fragment_In) {
+  // Flat decorations are dropped for integral
+  const auto assembly = CommonCapabilities() + R"(
+     OpEntryPoint Fragment %main "main" %1 %2
+     OpExecutionMode %main OriginUpperLeft
+     OpDecorate %1 Location 1
+     OpDecorate %2 Location 5
+     OpDecorate %1 Flat
+     OpDecorate %2 Flat
+)" + CommonTypes() +
+                        R"(
+     %arr = OpTypeArray %float %uint_2
+     %strct = OpTypeStruct %float %float
+     %ptr_in_arr = OpTypePointer Input %arr
+     %ptr_in_strct = OpTypePointer Input %strct
+     %1 = OpVariable %ptr_in_arr Input
+     %2 = OpVariable %ptr_in_strct Input
+
+     %main = OpFunction %void None %voidfn
+     %entry = OpLabel
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+
+  ASSERT_TRUE(p->BuildAndParseInternalModule());
+  EXPECT_TRUE(p->error().empty());
+  const auto got = p->program().to_str();
+  const std::string expected =
+      R"(Module{
+  Struct S {
+    StructMember{field0: __f32}
+    StructMember{field1: __f32}
+  }
+  Variable{
+    x_1
+    private
+    undefined
+    __array__f32_2
+  }
+  Variable{
+    x_2
+    private
+    undefined
+    __type_name_S
+  }
+  Function main_1 -> __void
+  ()
+  {
+    Return{}
+  }
+  Function main -> __void
+  StageDecoration{fragment}
+  (
+    VariableConst{
+      Decorations{
+        LocationDecoration{1}
+        InterpolateDecoration{flat none}
+      }
+      x_1_param
+      none
+      undefined
+      __f32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{2}
+        InterpolateDecoration{flat none}
+      }
+      x_1_param_1
+      none
+      undefined
+      __f32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{5}
+        InterpolateDecoration{flat none}
+      }
+      x_2_param
+      none
+      undefined
+      __f32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{6}
+        InterpolateDecoration{flat none}
+      }
+      x_2_param_1
+      none
+      undefined
+      __f32
+    }
+  )
+  {
+    Assignment{
+      ArrayAccessor[not set]{
+        Identifier[not set]{x_1}
+        ScalarConstructor[not set]{0}
+      }
+      Identifier[not set]{x_1_param}
+    }
+    Assignment{
+      ArrayAccessor[not set]{
+        Identifier[not set]{x_1}
+        ScalarConstructor[not set]{1}
+      }
+      Identifier[not set]{x_1_param_1}
+    }
+    Assignment{
+      MemberAccessor[not set]{
+        Identifier[not set]{x_2}
+        Identifier[not set]{field0}
+      }
+      Identifier[not set]{x_2_param}
+    }
+    Assignment{
+      MemberAccessor[not set]{
+        Identifier[not set]{x_2}
+        Identifier[not set]{field1}
+      }
+      Identifier[not set]{x_2_param_1}
+    }
+    Call[not set]{
+      Identifier[not set]{main_1}
+      (
+      )
+    }
+  }
+}
+)";
+  EXPECT_EQ(got, expected) << got;
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader