[tint][ast] Fix std140 matrix size

matNx3 matrices decomposed to a column vectors with a size one scalar shorter than the original matrix. This could change the memory layout if a scalar followed the matrix.

Change-Id: Ib82bad350fb2c8773cddac3cfcad29a89a4c8fba
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/179000
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/wgsl/ast/transform/std140.cc b/src/tint/lang/wgsl/ast/transform/std140.cc
index 9dc49b5..42fb0f9 100644
--- a/src/tint/lang/wgsl/ast/transform/std140.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140.cc
@@ -456,6 +456,8 @@
         uint32_t size) {
         // Replace the member with column vectors.
         const auto num_columns = mat->columns();
+        const auto column_size = mat->ColumnType()->Size();
+        const auto column_stride = mat->ColumnStride();
         // Build a struct member for each column of the matrix
         tint::Vector<const StructMember*, 4> out;
         for (uint32_t i = 0; i < num_columns; i++) {
@@ -466,10 +468,13 @@
                 // needs to be applied to the first column vector.
                 attributes.Push(b.MemberAlign(i32(align)));
             }
-            if ((i == num_columns - 1) && mat->Size() != size) {
-                // The matrix was @size() annotated with a larger size than the
-                // natural size for the matrix. This extra padding needs to be
-                // applied to the last column vector.
+            if ((i == num_columns - 1) &&
+                (column_stride * (num_columns - 1) + column_size) != size) {
+                // The matrix size is larger than the individual component vectors.
+                // This occurs with matNx3 matrices, as the last vec3 column has space for one extra
+                // trailing scalar, which is occupied by the matrix. It also applies to matrices
+                // with an explicit @size() attribute.
+                // Apply extra padding needs to the last column vector.
                 attributes.Push(
                     b.MemberSize(AInt(size - mat->ColumnType()->Align() * (num_columns - 1))));
             }
diff --git a/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc b/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
index f9dca89..18e722f 100644
--- a/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
@@ -80,11 +80,11 @@
     }
 
     // For each column, replaces "${col_id_for_tmpl}" by column index in `tmpl` to get a string, and
-    // join all these strings with `seperator`. If `tmpl_for_last_column` is not empty, use it
+    // join all these strings with `separator`. If `tmpl_for_last_column` is not empty, use it
     // instead of `tmpl` for the last column.
     std::string JoinTemplatedStringForEachMatrixColumn(
         std::string tmpl,
-        std::string seperator,
+        std::string separator,
         std::string tmpl_for_last_column = "") const {
         std::string result;
         if (tmpl_for_last_column.size() == 0) {
@@ -92,13 +92,13 @@
         }
         for (size_t c = 0; c < columns - 1; c++) {
             if (c > 0) {
-                result += seperator;
+                result += separator;
             }
             std::string string_for_current_column =
                 tint::ReplaceAll(tmpl, "${col_id_for_tmpl}", std::to_string(c));
             result += string_for_current_column;
         }
-        result += seperator;
+        result += separator;
         std::string string_for_last_column = tint::ReplaceAll(
             tmpl_for_last_column, "${col_id_for_tmpl}", std::to_string(columns - 1));
         result += string_for_last_column;
@@ -106,13 +106,17 @@
     }
 
     std::string ExpendedColumnVectors(uint32_t leading_space, std::string name) const {
+        if (rows == 3) {
+            return ExpendedColumnVectorsWithLastSize(leading_space, name,
+                                                     type == MatrixType::f16 ? 8 : 16);
+        }
         std::string space(leading_space, ' ');
         return JoinTemplatedStringForEachMatrixColumn(
             space + name + "${col_id_for_tmpl} : " + ColumnVector() + ",", "\n");
     }
 
-    std::string ExpendedColumnVectorsInline(std::string name, std::string seperator) const {
-        return JoinTemplatedStringForEachMatrixColumn(name + "${col_id_for_tmpl}", seperator);
+    std::string ExpendedColumnVectorsInline(std::string name, std::string separator) const {
+        return JoinTemplatedStringForEachMatrixColumn(name + "${col_id_for_tmpl}", separator);
     }
 
     std::string ExpendedColumnVectorsWithLastSize(uint32_t leading_space,
diff --git a/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc b/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
index 43ddf38..7c17213 100644
--- a/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
@@ -121,6 +121,7 @@
 
 struct S2x3F16_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -131,6 +132,7 @@
 struct S3x3F16_std140 {
   m_0 : vec3<f16>,
   m_1 : vec3<f16>,
+  @size(8)
   m_2 : vec3<f16>,
 }
 
@@ -142,6 +144,7 @@
   m_0 : vec3<f16>,
   m_1 : vec3<f16>,
   m_2 : vec3<f16>,
+  @size(8)
   m_3 : vec3<f16>,
 }
 
@@ -223,6 +226,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -262,6 +266,7 @@
   before : i32,
   @align(128i)
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
   after : i32,
 }
@@ -380,6 +385,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -426,6 +432,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -461,6 +468,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -493,6 +501,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -539,6 +548,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -571,6 +581,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -617,6 +628,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -649,6 +661,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -696,6 +709,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -729,6 +743,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -782,6 +797,7 @@
 struct S_std140 {
   m_1 : i32,
   m__0 : vec3<f16>,
+  @size(8)
   m__1 : vec3<f16>,
 }
 
@@ -817,6 +833,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -860,6 +877,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -904,6 +922,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -944,6 +963,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -998,6 +1018,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -1038,6 +1059,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -1093,6 +1115,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -1134,6 +1157,7 @@
 
 struct S_std140 {
   m_0 : vec3<f16>,
+  @size(8)
   m_1 : vec3<f16>,
 }
 
@@ -2294,6 +2318,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2337,6 +2362,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2373,6 +2399,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2409,6 +2436,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2441,6 +2469,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2474,6 +2503,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2522,6 +2552,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2572,6 +2603,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2631,6 +2663,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2686,6 +2719,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2734,6 +2768,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2783,6 +2818,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2828,6 +2864,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2874,6 +2911,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2934,6 +2972,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -2988,6 +3027,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3039,6 +3079,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3083,6 +3124,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3128,6 +3170,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3165,6 +3208,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3203,6 +3247,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3241,6 +3286,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3279,6 +3325,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3313,6 +3360,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3362,6 +3410,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3398,6 +3447,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3448,6 +3498,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3484,6 +3535,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3535,6 +3587,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }
 
@@ -3573,6 +3626,7 @@
 
 struct mat2x3_f16 {
   col0 : vec3<f16>,
+  @size(8)
   col1 : vec3<f16>,
 }