[tint][ast] Fix std140 matrix size
matNx3 matrices decomposed to a column vectors with a size one scalar shorter than the original matrix. This could change the memory layout if a scalar followed the matrix.
Change-Id: Ib82bad350fb2c8773cddac3cfcad29a89a4c8fba
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/179000
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/wgsl/ast/transform/std140.cc b/src/tint/lang/wgsl/ast/transform/std140.cc
index 9dc49b5..42fb0f9 100644
--- a/src/tint/lang/wgsl/ast/transform/std140.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140.cc
@@ -456,6 +456,8 @@
uint32_t size) {
// Replace the member with column vectors.
const auto num_columns = mat->columns();
+ const auto column_size = mat->ColumnType()->Size();
+ const auto column_stride = mat->ColumnStride();
// Build a struct member for each column of the matrix
tint::Vector<const StructMember*, 4> out;
for (uint32_t i = 0; i < num_columns; i++) {
@@ -466,10 +468,13 @@
// needs to be applied to the first column vector.
attributes.Push(b.MemberAlign(i32(align)));
}
- if ((i == num_columns - 1) && mat->Size() != size) {
- // The matrix was @size() annotated with a larger size than the
- // natural size for the matrix. This extra padding needs to be
- // applied to the last column vector.
+ if ((i == num_columns - 1) &&
+ (column_stride * (num_columns - 1) + column_size) != size) {
+ // The matrix size is larger than the individual component vectors.
+ // This occurs with matNx3 matrices, as the last vec3 column has space for one extra
+ // trailing scalar, which is occupied by the matrix. It also applies to matrices
+ // with an explicit @size() attribute.
+ // Apply extra padding needs to the last column vector.
attributes.Push(
b.MemberSize(AInt(size - mat->ColumnType()->Align() * (num_columns - 1))));
}
diff --git a/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc b/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
index f9dca89..18e722f 100644
--- a/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140_exhaustive_test.cc
@@ -80,11 +80,11 @@
}
// For each column, replaces "${col_id_for_tmpl}" by column index in `tmpl` to get a string, and
- // join all these strings with `seperator`. If `tmpl_for_last_column` is not empty, use it
+ // join all these strings with `separator`. If `tmpl_for_last_column` is not empty, use it
// instead of `tmpl` for the last column.
std::string JoinTemplatedStringForEachMatrixColumn(
std::string tmpl,
- std::string seperator,
+ std::string separator,
std::string tmpl_for_last_column = "") const {
std::string result;
if (tmpl_for_last_column.size() == 0) {
@@ -92,13 +92,13 @@
}
for (size_t c = 0; c < columns - 1; c++) {
if (c > 0) {
- result += seperator;
+ result += separator;
}
std::string string_for_current_column =
tint::ReplaceAll(tmpl, "${col_id_for_tmpl}", std::to_string(c));
result += string_for_current_column;
}
- result += seperator;
+ result += separator;
std::string string_for_last_column = tint::ReplaceAll(
tmpl_for_last_column, "${col_id_for_tmpl}", std::to_string(columns - 1));
result += string_for_last_column;
@@ -106,13 +106,17 @@
}
std::string ExpendedColumnVectors(uint32_t leading_space, std::string name) const {
+ if (rows == 3) {
+ return ExpendedColumnVectorsWithLastSize(leading_space, name,
+ type == MatrixType::f16 ? 8 : 16);
+ }
std::string space(leading_space, ' ');
return JoinTemplatedStringForEachMatrixColumn(
space + name + "${col_id_for_tmpl} : " + ColumnVector() + ",", "\n");
}
- std::string ExpendedColumnVectorsInline(std::string name, std::string seperator) const {
- return JoinTemplatedStringForEachMatrixColumn(name + "${col_id_for_tmpl}", seperator);
+ std::string ExpendedColumnVectorsInline(std::string name, std::string separator) const {
+ return JoinTemplatedStringForEachMatrixColumn(name + "${col_id_for_tmpl}", separator);
}
std::string ExpendedColumnVectorsWithLastSize(uint32_t leading_space,
diff --git a/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc b/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
index 43ddf38..7c17213 100644
--- a/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140_f16_test.cc
@@ -121,6 +121,7 @@
struct S2x3F16_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -131,6 +132,7 @@
struct S3x3F16_std140 {
m_0 : vec3<f16>,
m_1 : vec3<f16>,
+ @size(8)
m_2 : vec3<f16>,
}
@@ -142,6 +144,7 @@
m_0 : vec3<f16>,
m_1 : vec3<f16>,
m_2 : vec3<f16>,
+ @size(8)
m_3 : vec3<f16>,
}
@@ -223,6 +226,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -262,6 +266,7 @@
before : i32,
@align(128i)
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
after : i32,
}
@@ -380,6 +385,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -426,6 +432,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -461,6 +468,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -493,6 +501,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -539,6 +548,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -571,6 +581,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -617,6 +628,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -649,6 +661,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -696,6 +709,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -729,6 +743,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -782,6 +797,7 @@
struct S_std140 {
m_1 : i32,
m__0 : vec3<f16>,
+ @size(8)
m__1 : vec3<f16>,
}
@@ -817,6 +833,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -860,6 +877,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -904,6 +922,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -944,6 +963,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -998,6 +1018,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -1038,6 +1059,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -1093,6 +1115,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -1134,6 +1157,7 @@
struct S_std140 {
m_0 : vec3<f16>,
+ @size(8)
m_1 : vec3<f16>,
}
@@ -2294,6 +2318,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2337,6 +2362,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2373,6 +2399,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2409,6 +2436,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2441,6 +2469,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2474,6 +2503,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2522,6 +2552,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2572,6 +2603,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2631,6 +2663,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2686,6 +2719,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2734,6 +2768,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2783,6 +2818,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2828,6 +2864,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2874,6 +2911,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2934,6 +2972,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -2988,6 +3027,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3039,6 +3079,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3083,6 +3124,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3128,6 +3170,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3165,6 +3208,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3203,6 +3247,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3241,6 +3286,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3279,6 +3325,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3313,6 +3360,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3362,6 +3410,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3398,6 +3447,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3448,6 +3498,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3484,6 +3535,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3535,6 +3587,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}
@@ -3573,6 +3626,7 @@
struct mat2x3_f16 {
col0 : vec3<f16>,
+ @size(8)
col1 : vec3<f16>,
}