transform/CanonicalizeEntryPointIO: Sort struct members

HLSL interface matching rules impose additional requirements on the
order of structure members. We now sort members such that all members
with location attributes appear first (ordered by location slot),
followed by those with builtin attributes.

Fixed: tint:710
Change-Id: I90940bcb7a5b9eeb1f50f132d406d4cf74e47ea2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47822
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index f655b30..22bab4a 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -14,6 +14,7 @@
 
 #include "src/transform/canonicalize_entry_point_io.h"
 
+#include <algorithm>
 #include <utility>
 
 #include "src/program_builder.h"
@@ -27,6 +28,36 @@
 CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default;
 CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;
 
+namespace {
+
+// Comparison function used to reorder struct members such that all members with
+// location attributes appear first (ordered by location slot), followed by
+// those with builtin attributes.
+bool StructMemberComparator(const ast::StructMember* a,
+                            const ast::StructMember* b) {
+  auto* a_loc = ast::GetDecoration<ast::LocationDecoration>(a->decorations());
+  auto* b_loc = ast::GetDecoration<ast::LocationDecoration>(b->decorations());
+  auto* a_blt = ast::GetDecoration<ast::BuiltinDecoration>(a->decorations());
+  auto* b_blt = ast::GetDecoration<ast::BuiltinDecoration>(b->decorations());
+  if (a_loc) {
+    if (!b_loc) {
+      // `a` has location attribute and `b` does not: `a` goes first.
+      return true;
+    }
+    // Both have location attributes: smallest goes first.
+    return a_loc->value() < b_loc->value();
+  } else {
+    if (b_loc) {
+      // `b` has location attribute and `a` does not: `b` goes first.
+      return false;
+    }
+    // Both are builtins: order doesn't matter, just use enum value.
+    return a_blt->value() < b_blt->value();
+  }
+}
+
+}  // namespace
+
 Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
                                                 const DataMap&) {
   ProgramBuilder out;
@@ -133,6 +164,10 @@
         }
       }
 
+      // Sort struct members to satisfy HLSL interfacing matching rules.
+      std::sort(new_struct_members.begin(), new_struct_members.end(),
+                StructMemberComparator);
+
       // Create the new struct type.
       auto* in_struct = ctx.dst->create<type::Struct>(
           ctx.dst->Symbols().New(),
@@ -175,6 +210,10 @@
                             ctx.Clone(func->return_type_decorations())));
       }
 
+      // Sort struct members to satisfy HLSL interfacing matching rules.
+      std::sort(new_struct_members.begin(), new_struct_members.end(),
+                StructMemberComparator);
+
       // Create the new struct type.
       auto* out_struct = ctx.dst->create<type::Struct>(
           ctx.dst->Symbols().New(),
@@ -190,7 +229,7 @@
         // Reconstruct the return value using the newly created struct.
         auto* new_ret_value = ctx.Clone(ret->value());
         ast::ExpressionList ret_values;
-        if (auto* struct_ty = ret_type->As<type::Struct>()) {
+        if (ret_type->Is<type::Struct>()) {
           if (!ret->value()->Is<ast::IdentifierExpression>()) {
             // Create a const to hold the return value expression to avoid
             // re-evaluating it multiple times.
@@ -201,9 +240,9 @@
             new_ret_value = ctx.dst->Expr(temp);
           }
 
-          for (auto* member : struct_ty->impl()->members()) {
-            ret_values.push_back(ctx.dst->MemberAccessor(
-                new_ret_value, ctx.Clone(member->symbol())));
+          for (auto* member : new_struct_members) {
+            ret_values.push_back(
+                ctx.dst->MemberAccessor(new_ret_value, member->symbol()));
           }
         } else {
           ret_values.push_back(new_ret_value);
diff --git a/src/transform/canonicalize_entry_point_io_test.cc b/src/transform/canonicalize_entry_point_io_test.cc
index aadf916..2f4b4ab 100644
--- a/src/transform/canonicalize_entry_point_io_test.cc
+++ b/src/transform/canonicalize_entry_point_io_test.cc
@@ -25,28 +25,28 @@
 TEST_F(CanonicalizeEntryPointIOTest, Parameters) {
   auto* src = R"(
 [[stage(fragment)]]
-fn frag_main([[builtin(position)]] coord : vec4<f32>,
-             [[location(1)]] loc1 : f32,
-             [[location(2)]] loc2 : vec4<u32>) {
+fn frag_main([[location(1)]] loc1 : f32,
+             [[location(2)]] loc2 : vec4<u32>,
+             [[builtin(position)]] coord : vec4<f32>) {
   var col : f32 = (coord.x * loc1);
 }
 )";
 
   auto* expect = R"(
 struct tint_symbol_1 {
-  [[builtin(position)]]
-  coord : vec4<f32>;
   [[location(1)]]
   loc1 : f32;
   [[location(2)]]
   loc2 : vec4<u32>;
+  [[builtin(position)]]
+  coord : vec4<f32>;
 };
 
 [[stage(fragment)]]
 fn frag_main(tint_symbol : tint_symbol_1) {
-  let coord : vec4<f32> = tint_symbol.coord;
   let loc1 : f32 = tint_symbol.loc1;
   let loc2 : vec4<u32> = tint_symbol.loc2;
+  let coord : vec4<f32> = tint_symbol.coord;
   var col : f32 = (coord.x * loc1);
 }
 )";
@@ -89,20 +89,20 @@
 TEST_F(CanonicalizeEntryPointIOTest, Parameters_EmptyBody) {
   auto* src = R"(
 [[stage(fragment)]]
-fn frag_main([[builtin(position)]] coord : vec4<f32>,
-             [[location(1)]] loc1 : f32,
-             [[location(2)]] loc2 : vec4<u32>) {
+fn frag_main([[location(1)]] loc1 : f32,
+             [[location(2)]] loc2 : vec4<u32>,
+             [[builtin(position)]] coord : vec4<f32>) {
 }
 )";
 
   auto* expect = R"(
 struct tint_symbol_1 {
-  [[builtin(position)]]
-  coord : vec4<f32>;
   [[location(1)]]
   loc1 : f32;
   [[location(2)]]
   loc2 : vec4<u32>;
+  [[builtin(position)]]
+  coord : vec4<f32>;
 };
 
 [[stage(fragment)]]
@@ -126,9 +126,9 @@
 };
 
 [[stage(fragment)]]
-fn frag_main(builtins : FragBuiltins,
+fn frag_main([[location(0)]] loc0 : f32,
              locations : FragLocations,
-             [[location(0)]] loc0 : f32) {
+             builtins : FragBuiltins) {
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 )";
@@ -144,21 +144,21 @@
 };
 
 struct tint_symbol_1 {
-  [[builtin(position)]]
-  coord : vec4<f32>;
+  [[location(0)]]
+  loc0 : f32;
   [[location(1)]]
   loc1 : f32;
   [[location(2)]]
   loc2 : vec4<u32>;
-  [[location(0)]]
-  loc0 : f32;
+  [[builtin(position)]]
+  coord : vec4<f32>;
 };
 
 [[stage(fragment)]]
 fn frag_main(tint_symbol : tint_symbol_1) {
-  let builtins : FragBuiltins = FragBuiltins(tint_symbol.coord);
-  let locations : FragLocations = FragLocations(tint_symbol.loc1, tint_symbol.loc2);
   let loc0 : f32 = tint_symbol.loc0;
+  let locations : FragLocations = FragLocations(tint_symbol.loc1, tint_symbol.loc2);
+  let builtins : FragBuiltins = FragBuiltins(tint_symbol.coord);
   var col : f32 = ((builtins.coord.x * locations.loc1) + loc0);
 }
 )";
@@ -196,9 +196,9 @@
 TEST_F(CanonicalizeEntryPointIOTest, Return_Struct) {
   auto* src = R"(
 struct FragOutput {
+  [[location(0)]] color : vec4<f32>;
   [[builtin(frag_depth)]] depth : f32;
   [[builtin(sample_mask)]] mask : u32;
-  [[location(0)]] color : vec4<f32>;
 };
 
 [[stage(fragment)]]
@@ -213,18 +213,18 @@
 
   auto* expect = R"(
 struct FragOutput {
+  color : vec4<f32>;
   depth : f32;
   mask : u32;
-  color : vec4<f32>;
 };
 
 struct tint_symbol {
+  [[location(0)]]
+  color : vec4<f32>;
   [[builtin(frag_depth)]]
   depth : f32;
   [[builtin(sample_mask)]]
   mask : u32;
-  [[location(0)]]
-  color : vec4<f32>;
 };
 
 [[stage(fragment)]]
@@ -233,7 +233,7 @@
   output.depth = 1.0;
   output.mask = 7u;
   output.color = vec4<f32>(0.5, 0.5, 0.5, 1.0);
-  return tint_symbol(output.depth, output.mask, output.color);
+  return tint_symbol(output.color, output.depth, output.mask);
 }
 )";
 
@@ -503,6 +503,94 @@
   EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(CanonicalizeEntryPointIOTest, SortedMembers) {
+  auto* src = R"(
+struct VertexOutput {
+  [[location(1)]] b : u32;
+  [[builtin(position)]] pos : vec4<f32>;
+  [[location(3)]] d : bool;
+  [[location(0)]] a : f32;
+  [[location(2)]] c : i32;
+};
+
+struct FragmentInputExtra {
+  [[location(3)]] d : bool;
+  [[builtin(position)]] pos : vec4<f32>;
+  [[location(0)]] a : f32;
+};
+
+[[stage(vertex)]]
+fn vert_main() -> VertexOutput {
+  return VertexOutput();
+}
+
+[[stage(fragment)]]
+fn frag_main([[builtin(front_facing)]] ff : bool,
+             [[location(2)]] c : i32,
+             inputs : FragmentInputExtra,
+             [[location(1)]] b : u32) {
+}
+)";
+
+  auto* expect = R"(
+struct VertexOutput {
+  b : u32;
+  pos : vec4<f32>;
+  d : bool;
+  a : f32;
+  c : i32;
+};
+
+struct FragmentInputExtra {
+  d : bool;
+  pos : vec4<f32>;
+  a : f32;
+};
+
+struct tint_symbol {
+  [[location(0)]]
+  a : f32;
+  [[location(1)]]
+  b : u32;
+  [[location(2)]]
+  c : i32;
+  [[location(3)]]
+  d : bool;
+  [[builtin(position)]]
+  pos : vec4<f32>;
+};
+
+[[stage(vertex)]]
+fn vert_main() -> tint_symbol {
+  let tint_symbol_1 : VertexOutput = VertexOutput();
+  return tint_symbol(tint_symbol_1.a, tint_symbol_1.b, tint_symbol_1.c, tint_symbol_1.d, tint_symbol_1.pos);
+}
+
+struct tint_symbol_3 {
+  [[location(0)]]
+  a : f32;
+  [[location(1)]]
+  b : u32;
+  [[location(2)]]
+  c : i32;
+  [[location(3)]]
+  d : bool;
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[builtin(front_facing)]]
+  ff : bool;
+};
+
+[[stage(fragment)]]
+fn frag_main(tint_symbol_2 : tint_symbol_3) {
+}
+)";
+
+  auto got = Run<CanonicalizeEntryPointIO>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(CanonicalizeEntryPointIOTest, DontRenameSymbols) {
   auto* src = R"(
 [[stage(fragment)]]