Import Tint changes from Dawn
Changes:
- 657e61d43dcd6958b455251a3b896e22893e06ca tint: Add and use new Std140 transform by Ben Clayton <bclayton@google.com>
- 655db070220ab631f58104db52038ea20ffe7a58 tint/sem: Rename sem::Manager to TypeManager by Ben Clayton <bclayton@google.com>
- c20c5dfb4a3cc606f51742f24e2566df999cc1d0 tint: Implement const eval of binary multiply by Antonio Maiorano <amaiorano@google.com>
- 2b47c216407a0ab0e0add702115f375290b0677e tint/sem: Add Find() to the type manager by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 657e61d43dcd6958b455251a3b896e22893e06ca
Change-Id: I08abfa36d71bf37ace7abd1892df8b6a217716b9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/101140
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 61bdea6..c1e84d1 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -533,6 +533,8 @@
"transform/single_entry_point.h",
"transform/spirv_atomic.cc",
"transform/spirv_atomic.h",
+ "transform/std140.cc",
+ "transform/std140.h",
"transform/substitute_override.cc",
"transform/substitute_override.h",
"transform/transform.cc",
@@ -1213,6 +1215,7 @@
"transform/simplify_pointers_test.cc",
"transform/single_entry_point_test.cc",
"transform/spirv_atomic_test.cc",
+ "transform/std140_test.cc",
"transform/substitute_override_test.cc",
"transform/test_helper.h",
"transform/transform_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 62809ab..e8f875b 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -445,6 +445,8 @@
transform/single_entry_point.h
transform/spirv_atomic.cc
transform/spirv_atomic.h
+ transform/std140.cc
+ transform/std140.h
transform/substitute_override.cc
transform/substitute_override.h
transform/transform.cc
@@ -1128,6 +1130,7 @@
transform/simplify_pointers_test.cc
transform/single_entry_point_test.cc
transform/spirv_atomic_test.cc
+ transform/std140_test.cc
transform/substitute_override_test.cc
transform/test_helper.h
transform/unshadow_test.cc
diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def
index ae89f5c..a565160 100644
--- a/src/tint/intrinsics.def
+++ b/src/tint/intrinsics.def
@@ -898,15 +898,15 @@
@const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
-op * <T: fiu32_f16>(T, T) -> T
-op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
-op * <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
-op * <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
-op * <T: f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
-op * <T: f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
-op * <T: f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
-op * <T: f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
-op * <T: f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
+@const("Multiply") op * <T: fia_fiu32_f16>(T, T) -> T
+@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
+@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
+@const("Multiply") op * <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
+@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
+@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
+@const("MultiplyMatVec") op * <T: fa_f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
+@const("MultiplyVecMat") op * <T: fa_f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
+@const("MultiplyMatMat") op * <T: fa_f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
op / <T: fiu32_f16>(T, T) -> T
op / <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
diff --git a/src/tint/number.h b/src/tint/number.h
index 4635051..032844c 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -22,6 +22,7 @@
#include <optional>
#include <ostream>
+#include "src/tint/traits.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/result.h"
@@ -33,6 +34,14 @@
} // namespace tint
namespace tint::detail {
+/// Base template for IsNumber
+template <typename T>
+struct IsNumber : std::false_type {};
+
+/// Specialization for IsNumber
+template <typename T>
+struct IsNumber<Number<T>> : std::true_type {};
+
/// An empty structure used as a unique template type for Number when
/// specializing for the f16 type.
struct NumberKindF16 {};
@@ -68,6 +77,10 @@
template <typename T>
constexpr bool IsNumeric = IsInteger<T> || IsFloatingPoint<T>;
+/// Evaluates to true iff T is a Number
+template <typename T>
+constexpr bool IsNumber = detail::IsNumber<T>::value;
+
/// Resolves to the underlying type for a Number.
template <typename T>
using UnwrapNumber = typename detail::NumberUnwrapper<T>::type;
@@ -236,6 +249,26 @@
/// However since C++ don't have native binary16 type, the value is stored as float.
using f16 = Number<detail::NumberKindF16>;
+/// @returns the friendly name of Number type T
+template <typename T, typename = traits::EnableIf<IsNumber<T>>>
+const char* FriendlyName() {
+ if constexpr (std::is_same_v<T, AInt>) {
+ return "abstract-int";
+ } else if constexpr (std::is_same_v<T, AFloat>) {
+ return "abstract-float";
+ } else if constexpr (std::is_same_v<T, i32>) {
+ return "i32";
+ } else if constexpr (std::is_same_v<T, u32>) {
+ return "u32";
+ } else if constexpr (std::is_same_v<T, f32>) {
+ return "f32";
+ } else if constexpr (std::is_same_v<T, f16>) {
+ return "f16";
+ } else {
+ static_assert(!sizeof(T), "Unhandled type");
+ }
+}
+
/// Enumerator of failure reasons when converting from one number to another.
enum class ConversionFailure {
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
@@ -437,6 +470,15 @@
return AInt(result);
}
+/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat
+inline std::optional<AFloat> CheckedMul(AFloat a, AFloat b) {
+ auto result = a.value * b.value;
+ if (!std::isfinite(result)) {
+ return {};
+ }
+ return AFloat{result};
+}
+
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
diff --git a/src/tint/program.h b/src/tint/program.h
index ea88470..b482f90 100644
--- a/src/tint/program.h
+++ b/src/tint/program.h
@@ -73,7 +73,7 @@
ast::NodeID HighestASTNodeID() const { return highest_node_id_; }
/// @returns a reference to the program's types
- const sem::Manager& Types() const {
+ const sem::TypeManager& Types() const {
AssertNotMoved();
return types_;
}
@@ -165,7 +165,7 @@
ProgramID id_;
ast::NodeID highest_node_id_;
- sem::Manager types_;
+ sem::TypeManager types_;
ASTNodeAllocator ast_nodes_;
SemNodeAllocator sem_nodes_;
ConstantAllocator constant_nodes_;
diff --git a/src/tint/program_builder.cc b/src/tint/program_builder.cc
index c6cff0f..443be11 100644
--- a/src/tint/program_builder.cc
+++ b/src/tint/program_builder.cc
@@ -70,7 +70,7 @@
ProgramBuilder builder;
builder.id_ = program->ID();
builder.last_ast_node_id_ = program->HighestASTNodeID();
- builder.types_ = sem::Manager::Wrap(program->Types());
+ builder.types_ = sem::TypeManager::Wrap(program->Types());
builder.ast_ =
builder.create<ast::Module>(program->AST().source, program->AST().GlobalDeclarations());
builder.sem_ = sem::Info::Wrap(program->Sem());
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 487420c..0af3a90 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -297,13 +297,13 @@
ProgramID ID() const { return id_; }
/// @returns a reference to the program's types
- sem::Manager& Types() {
+ sem::TypeManager& Types() {
AssertNotMoved();
return types_;
}
/// @returns a reference to the program's types
- const sem::Manager& Types() const {
+ const sem::TypeManager& Types() const {
AssertNotMoved();
return types_;
}
@@ -3169,7 +3169,7 @@
private:
ProgramID id_;
ast::NodeID last_ast_node_id_ = ast::NodeID{static_cast<decltype(ast::NodeID::value)>(0) - 1};
- sem::Manager types_;
+ sem::TypeManager types_;
ASTNodeAllocator ast_nodes_;
SemNodeAllocator sem_nodes_;
ConstantAllocator constant_nodes_;
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 9bc1d47..d363b67 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -38,6 +38,7 @@
#include "src/tint/sem/vector.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
+#include "src/tint/utils/scoped_assignment.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
@@ -508,11 +509,202 @@
auto* ty = n0 > n1 ? c0->Type() : c1->Type();
return CreateComposite(builder, ty, std::move(els));
}
-
} // namespace
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
+template <typename NumberT>
+utils::Result<NumberT> ConstEval::Add(NumberT a, NumberT b) {
+ using T = UnwrapNumber<NumberT>;
+ auto add_values = [](T lhs, T rhs) {
+ if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
+ // Ensure no UB for signed overflow
+ using UT = std::make_unsigned_t<T>;
+ return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
+ } else {
+ return lhs + rhs;
+ }
+ };
+ NumberT result;
+ if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
+ // Check for over/underflow for abstract values
+ if (auto r = CheckedAdd(a, b)) {
+ result = r->value;
+ } else {
+ AddError("'" + std::to_string(add_values(a.value, b.value)) +
+ "' cannot be represented as '" + FriendlyName<NumberT>() + "'",
+ *current_source);
+ return utils::Failure;
+ }
+ } else {
+ result = add_values(a.value, b.value);
+ }
+ return result;
+}
+
+template <typename NumberT>
+utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) {
+ using T = UnwrapNumber<NumberT>;
+ auto mul_values = [](T lhs, T rhs) { //
+ if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
+ // For signed integrals, avoid C++ UB by multiplying as unsigned
+ using UT = std::make_unsigned_t<T>;
+ return static_cast<T>(static_cast<UT>(lhs) * static_cast<UT>(rhs));
+ } else {
+ return lhs * rhs;
+ }
+ };
+ NumberT result;
+ if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
+ // Check for over/underflow for abstract values
+ if (auto r = CheckedMul(a, b)) {
+ result = r->value;
+ } else {
+ AddError("'" + std::to_string(mul_values(a.value, b.value)) +
+ "' cannot be represented as '" + FriendlyName<NumberT>() + "'",
+ *current_source);
+ return utils::Failure;
+ }
+ } else {
+ result = mul_values(a.value, b.value);
+ }
+ return result;
+}
+
+template <typename NumberT>
+utils::Result<NumberT> ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) {
+ auto r1 = Mul(a1, b1);
+ if (!r1) {
+ return utils::Failure;
+ }
+ auto r2 = Mul(a2, b2);
+ if (!r2) {
+ return utils::Failure;
+ }
+ auto r = Add(r1.Get(), r2.Get());
+ if (!r) {
+ return utils::Failure;
+ }
+ return r;
+}
+
+template <typename NumberT>
+utils::Result<NumberT> ConstEval::Dot3(NumberT a1,
+ NumberT a2,
+ NumberT a3,
+ NumberT b1,
+ NumberT b2,
+ NumberT b3) {
+ auto r1 = Mul(a1, b1);
+ if (!r1) {
+ return utils::Failure;
+ }
+ auto r2 = Mul(a2, b2);
+ if (!r2) {
+ return utils::Failure;
+ }
+ auto r3 = Mul(a3, b3);
+ if (!r3) {
+ return utils::Failure;
+ }
+ auto r = Add(r1.Get(), r2.Get());
+ if (!r) {
+ return utils::Failure;
+ }
+ r = Add(r.Get(), r3.Get());
+ if (!r) {
+ return utils::Failure;
+ }
+ return r;
+}
+
+template <typename NumberT>
+utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
+ NumberT a2,
+ NumberT a3,
+ NumberT a4,
+ NumberT b1,
+ NumberT b2,
+ NumberT b3,
+ NumberT b4) {
+ auto r1 = Mul(a1, b1);
+ if (!r1) {
+ return utils::Failure;
+ }
+ auto r2 = Mul(a2, b2);
+ if (!r2) {
+ return utils::Failure;
+ }
+ auto r3 = Mul(a3, b3);
+ if (!r3) {
+ return utils::Failure;
+ }
+ auto r4 = Mul(a4, b4);
+ if (!r4) {
+ return utils::Failure;
+ }
+ auto r = Add(r1.Get(), r2.Get());
+ if (!r) {
+ return utils::Failure;
+ }
+ r = Add(r.Get(), r3.Get());
+ if (!r) {
+ return utils::Failure;
+ }
+ r = Add(r.Get(), r4.Get());
+ if (!r) {
+ return utils::Failure;
+ }
+ return r;
+}
+
+auto ConstEval::AddFunc(const sem::Type* elem_ty) {
+ return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
+ if (auto r = Add(a1, a2)) {
+ return CreateElement(builder, elem_ty, r.Get());
+ }
+ return utils::Failure;
+ };
+}
+
+auto ConstEval::MulFunc(const sem::Type* elem_ty) {
+ return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
+ if (auto r = Mul(a1, a2)) {
+ return CreateElement(builder, elem_ty, r.Get());
+ }
+ return utils::Failure;
+ };
+}
+
+auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
+ return [=](auto a1, auto a2, auto b1, auto b2) -> utils::Result<const Constant*> {
+ if (auto r = Dot2(a1, a2, b1, b2)) {
+ return CreateElement(builder, elem_ty, r.Get());
+ }
+ return utils::Failure;
+ };
+}
+
+auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
+ return [=](auto a1, auto a2, auto a3, auto b1, auto b2,
+ auto b3) -> utils::Result<const Constant*> {
+ if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) {
+ return CreateElement(builder, elem_ty, r.Get());
+ }
+ return utils::Failure;
+ };
+}
+
+auto ConstEval::Dot4Func(const sem::Type* elem_ty) {
+ return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3,
+ auto b4) -> utils::Result<const Constant*> {
+ if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) {
+ return CreateElement(builder, elem_ty, r.Get());
+ }
+ return utils::Failure;
+ };
+}
+
ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
const ast::LiteralExpression* literal) {
return Switch(
@@ -756,42 +948,15 @@
return TransformElements(builder, transform, args[0]);
}
-ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
+ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
- auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
- auto create = [&](auto i, auto j) -> const Constant* {
- using NumberT = decltype(i);
- using T = UnwrapNumber<NumberT>;
-
- auto add_values = [](T lhs, T rhs) {
- if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
- // Ensure no UB for signed overflow
- using UT = std::make_unsigned_t<T>;
- return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
- } else {
- return lhs + rhs;
- }
- };
-
- NumberT result;
- if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
- // Check for over/underflow for abstract values
- if (auto r = CheckedAdd(i, j)) {
- result = r->value;
- } else {
- AddError("'" + std::to_string(add_values(i.value, j.value)) +
- "' cannot be represented as '" +
- ty->FriendlyName(builder.Symbols()) + "'",
- source);
- return nullptr;
- }
- } else {
- result = add_values(i.value, j.value);
- }
- return CreateElement(builder, c0->Type(), result);
- };
- return Dispatch_fia_fiu32_f16(create, c0, c1);
+ TINT_SCOPED_ASSIGNMENT(current_source, &source);
+ auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
+ if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) {
+ return r.Get();
+ }
+ return nullptr;
};
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
@@ -846,6 +1011,192 @@
return r;
}
+ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* /*ty*/,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
+ TINT_SCOPED_ASSIGNMENT(current_source, &source);
+ auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
+ if (auto r = Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1)) {
+ return r.Get();
+ }
+ return nullptr;
+ };
+
+ auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
+ if (builder.Diagnostics().contains_errors()) {
+ return utils::Failure;
+ }
+ return r;
+}
+
+ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
+ TINT_SCOPED_ASSIGNMENT(current_source, &source);
+ auto* mat_ty = args[0]->Type()->As<sem::Matrix>();
+ auto* vec_ty = args[1]->Type()->As<sem::Vector>();
+ auto* elem_ty = vec_ty->type();
+
+ auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) {
+ utils::Result<const Constant*> result;
+ switch (mat_ty->columns()) {
+ case 2:
+ result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
+ m->Index(0)->Index(row), //
+ m->Index(1)->Index(row), //
+ v->Index(0), //
+ v->Index(1));
+ break;
+ case 3:
+ result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
+ m->Index(0)->Index(row), //
+ m->Index(1)->Index(row), //
+ m->Index(2)->Index(row), //
+ v->Index(0), //
+ v->Index(1), v->Index(2));
+ break;
+ case 4:
+ result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
+ m->Index(0)->Index(row), //
+ m->Index(1)->Index(row), //
+ m->Index(2)->Index(row), //
+ m->Index(3)->Index(row), //
+ v->Index(0), //
+ v->Index(1), //
+ v->Index(2), //
+ v->Index(3));
+ break;
+ }
+ return result;
+ };
+
+ utils::Vector<const sem::Constant*, 4> result;
+ for (size_t i = 0; i < mat_ty->rows(); ++i) {
+ auto r = dot(args[0], i, args[1]); // matrix row i * vector
+ if (!r) {
+ return utils::Failure;
+ }
+ result.Push(r.Get());
+ }
+ return CreateComposite(builder, ty, result);
+}
+ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
+ TINT_SCOPED_ASSIGNMENT(current_source, &source);
+ auto* vec_ty = args[0]->Type()->As<sem::Vector>();
+ auto* mat_ty = args[1]->Type()->As<sem::Matrix>();
+ auto* elem_ty = vec_ty->type();
+
+ auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) {
+ utils::Result<const Constant*> result;
+ switch (mat_ty->rows()) {
+ case 2:
+ result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
+ m->Index(col)->Index(0), //
+ m->Index(col)->Index(1), //
+ v->Index(0), //
+ v->Index(1));
+ break;
+ case 3:
+ result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
+ m->Index(col)->Index(0), //
+ m->Index(col)->Index(1), //
+ m->Index(col)->Index(2),
+ v->Index(0), //
+ v->Index(1), //
+ v->Index(2));
+ break;
+ case 4:
+ result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
+ m->Index(col)->Index(0), //
+ m->Index(col)->Index(1), //
+ m->Index(col)->Index(2), //
+ m->Index(col)->Index(3), //
+ v->Index(0), //
+ v->Index(1), //
+ v->Index(2), //
+ v->Index(3));
+ }
+ return result;
+ };
+
+ utils::Vector<const sem::Constant*, 4> result;
+ for (size_t i = 0; i < mat_ty->columns(); ++i) {
+ auto r = dot(args[0], args[1], i); // vector * matrix col i
+ if (!r) {
+ return utils::Failure;
+ }
+ result.Push(r.Get());
+ }
+ return CreateComposite(builder, ty, result);
+}
+
+ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
+ TINT_SCOPED_ASSIGNMENT(current_source, &source);
+ auto* mat1 = args[0];
+ auto* mat2 = args[1];
+ auto* mat1_ty = mat1->Type()->As<sem::Matrix>();
+ auto* mat2_ty = mat2->Type()->As<sem::Matrix>();
+ auto* elem_ty = mat1_ty->type();
+
+ auto dot = [&](const sem::Constant* m1, size_t row, const sem::Constant* m2, size_t col) {
+ auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); };
+ auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); };
+
+ utils::Result<const Constant*> result;
+ switch (mat1_ty->columns()) {
+ case 2:
+ result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
+ m1e(row, 0), //
+ m1e(row, 1), //
+ m2e(0, col), //
+ m2e(1, col));
+ break;
+ case 3:
+ result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
+ m1e(row, 0), //
+ m1e(row, 1), //
+ m1e(row, 2), //
+ m2e(0, col), //
+ m2e(1, col), //
+ m2e(2, col));
+ break;
+ case 4:
+ result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
+ m1e(row, 0), //
+ m1e(row, 1), //
+ m1e(row, 2), //
+ m1e(row, 3), //
+ m2e(0, col), //
+ m2e(1, col), //
+ m2e(2, col), //
+ m2e(3, col));
+ break;
+ }
+ return result;
+ };
+
+ utils::Vector<const sem::Constant*, 4> result_mat;
+ for (size_t c = 0; c < mat2_ty->columns(); ++c) {
+ utils::Vector<const sem::Constant*, 4> col_vec;
+ for (size_t r = 0; r < mat1_ty->rows(); ++r) {
+ auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c
+ if (!v) {
+ return utils::Failure;
+ }
+ col_vec.Push(v.Get()); // mat1 row r * mat2 col c
+ }
+
+ // Add column vector to matrix
+ auto* col_vec_ty = ty->As<sem::Matrix>()->ColumnType();
+ result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec));
+ }
+ return CreateComposite(builder, ty, result_mat);
+}
+
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index 9cd38d7..4007162 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -230,6 +230,42 @@
utils::VectorRef<const sem::Constant*> args,
const Source& source);
+ /// Multiply operator '*' for the same type on the LHS and RHS
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the conversion
+ /// @return the result value, or null if the value cannot be calculated
+ ConstantResult OpMultiply(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Multiply operator '*' for matCxR<T> * vecC<T>
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the conversion
+ /// @return the result value, or null if the value cannot be calculated
+ ConstantResult OpMultiplyMatVec(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Multiply operator '*' for vecR<T> * matCxR<T>
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the conversion
+ /// @return the result value, or null if the value cannot be calculated
+ ConstantResult OpMultiplyVecMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Multiply operator '*' for matKxR<T> * matCxK<T>
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the conversion
+ /// @return the result value, or null if the value cannot be calculated
+ ConstantResult OpMultiplyMatMat(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////
@@ -259,7 +295,97 @@
/// Adds the given warning message to the diagnostics
void AddWarning(const std::string& msg, const Source& source) const;
+ /// Adds two Number<T>s
+ /// @param a the lhs number
+ /// @param b the rhs number
+ /// @returns the result number on success, or logs an error and returns Failure
+ template <typename NumberT>
+ utils::Result<NumberT> Add(NumberT a, NumberT b);
+
+ /// Multiplies two Number<T>s
+ /// @param a the lhs number
+ /// @param b the rhs number
+ /// @returns the result number on success, or logs an error and returns Failure
+ template <typename NumberT>
+ utils::Result<NumberT> Mul(NumberT a, NumberT b);
+
+ /// Returns the dot product of (a1,a2) with (b1,b2)
+ /// @param a1 component 1 of lhs vector
+ /// @param a2 component 2 of lhs vector
+ /// @param b1 component 1 of rhs vector
+ /// @param b2 component 2 of rhs vector
+ /// @returns the result number on success, or logs an error and returns Failure
+ template <typename NumberT>
+ utils::Result<NumberT> Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2);
+
+ /// Returns the dot product of (a1,a2,a3) with (b1,b2,b3)
+ /// @param a1 component 1 of lhs vector
+ /// @param a2 component 2 of lhs vector
+ /// @param a3 component 3 of lhs vector
+ /// @param b1 component 1 of rhs vector
+ /// @param b2 component 2 of rhs vector
+ /// @param b3 component 3 of rhs vector
+ /// @returns the result number on success, or logs an error and returns Failure
+ template <typename NumberT>
+ utils::Result<NumberT> Dot3(NumberT a1,
+ NumberT a2,
+ NumberT a3,
+ NumberT b1,
+ NumberT b2,
+ NumberT b3);
+
+ /// Returns the dot product of (a1,a2,a3,a4) with (b1,b2,b3,b4)
+ /// @param a1 component 1 of lhs vector
+ /// @param a2 component 2 of lhs vector
+ /// @param a3 component 3 of lhs vector
+ /// @param a4 component 4 of lhs vector
+ /// @param b1 component 1 of rhs vector
+ /// @param b2 component 2 of rhs vector
+ /// @param b3 component 3 of rhs vector
+ /// @param b4 component 4 of rhs vector
+ /// @returns the result number on success, or logs an error and returns Failure
+ template <typename NumberT>
+ utils::Result<NumberT> Dot4(NumberT a1,
+ NumberT a2,
+ NumberT a3,
+ NumberT a4,
+ NumberT b1,
+ NumberT b2,
+ NumberT b3,
+ NumberT b4);
+
+ /// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty`
+ /// if successful, or returns Failure otherwise.
+ /// @param elem_ty the element type of the Constant to create on success
+ /// @returns the callable function
+ auto AddFunc(const sem::Type* elem_ty);
+
+ /// Returns a callable that calls Mul, and creates a Constant with its result of type `elem_ty`
+ /// if successful, or returns Failure otherwise.
+ /// @param elem_ty the element type of the Constant to create on success
+ /// @returns the callable function
+ auto MulFunc(const sem::Type* elem_ty);
+
+ /// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
+ /// if successful, or returns Failure otherwise.
+ /// @param elem_ty the element type of the Constant to create on success
+ /// @returns the callable function
+ auto Dot2Func(const sem::Type* elem_ty);
+
+ /// Returns a callable that calls Dot3, and creates a Constant with its result of type `elem_ty`
+ /// if successful, or returns Failure otherwise.
+ /// @param elem_ty the element type of the Constant to create on success
+ /// @returns the callable function
+ auto Dot3Func(const sem::Type* elem_ty);
+
+ /// Returns a callable that calls Dot4, and creates a Constant with its result of type `elem_ty`
+ /// if successful, or returns Failure otherwise.
+ /// @param elem_ty the element type of the Constant to create on success
+ /// @returns the callable function
+ auto Dot4Func(const sem::Type* elem_ty);
+
ProgramBuilder& builder;
+ const Source* current_source = nullptr;
};
} // namespace tint::resolver
diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc
index 23e621f..a83cfa1 100644
--- a/src/tint/resolver/const_eval_test.cc
+++ b/src/tint/resolver/const_eval_test.cc
@@ -15,6 +15,7 @@
#include <cmath>
#include <type_traits>
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/builtin_type.h"
@@ -24,6 +25,8 @@
#include "src/tint/sem/test_helper.h"
#include "src/tint/utils/transform.h"
+using ::testing::HasSubstr;
+
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
@@ -74,6 +77,19 @@
}
}
+TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
+template <typename T>
+constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
+ if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
+ // For signed integrals, avoid C++ UB by multiplying as unsigned
+ using UT = std::make_unsigned_t<T>;
+ return static_cast<Number<T>>(static_cast<UT>(v1) * static_cast<UT>(v2));
+ } else {
+ return static_cast<Number<T>>(v1 * v2);
+ }
+}
+TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
+
// Concats any number of std::vectors
template <typename Vec, typename... Vecs>
[[nodiscard]] auto Concat(Vec&& v1, Vecs&&... vs) {
@@ -3228,29 +3244,26 @@
return Case{Val(lhs), Val(rhs), Val(expected), overflow};
}
-static std::ostream& operator<<(std::ostream& o, const Case& c) {
- auto print_value = [&](auto&& value) {
- std::visit(
- [&](auto&& v) {
- using ValueType = std::decay_t<decltype(v)>;
- o << ValueType::DataType::Name() << "(";
- for (auto& a : v.args.values) {
- o << std::get<typename ValueType::ElementType>(a);
- if (&a != &v.args.values.Back()) {
- o << ", ";
- }
+static std::ostream& operator<<(std::ostream& o, const Types& types) {
+ std::visit(
+ [&](auto&& v) {
+ using ValueType = std::decay_t<decltype(v)>;
+ o << ValueType::DataType::Name() << "(";
+ for (auto& a : v.args.values) {
+ o << std::get<typename ValueType::ElementType>(a);
+ if (&a != &v.args.values.Back()) {
+ o << ", ";
}
- o << ")";
- },
- value);
- };
- o << "lhs: ";
- print_value(c.lhs);
- o << ", rhs: ";
- print_value(c.rhs);
- o << ", expected: ";
- print_value(c.expected);
- o << ", overflow: " << c.overflow;
+ }
+ o << ")";
+ },
+ types);
+ return o;
+}
+
+static std::ostream& operator<<(std::ostream& o, const Case& c) {
+ o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: " << c.expected
+ << ", overflow: " << c.overflow;
return o;
}
@@ -3281,7 +3294,6 @@
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
TEST_P(ResolverConstEvalBinaryOpTest, Test) {
Enable(ast::Extension::kF16);
-
auto op = std::get<0>(GetParam());
auto& c = std::get<1>(GetParam());
@@ -3300,10 +3312,8 @@
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
GlobalConst("C", expr);
-
auto* expected_expr = expected.Expr(*this);
GlobalConst("E", expected_expr);
-
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
@@ -3413,6 +3423,215 @@
OpSubFloatCases<f32>(),
OpSubFloatCases<f16>()))));
+template <typename T>
+std::vector<Case> OpMulScalarCases() {
+ return {
+ C(T{0}, T{0}, T{0}),
+ C(T{1}, T{2}, T{2}),
+ C(T{2}, T{3}, T{6}),
+ C(Negate(T{2}), T{3}, Negate(T{6})),
+ C(T::Highest(), T{1}, T::Highest()),
+ C(T::Lowest(), T{1}, T::Lowest()),
+ C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest()), true),
+ C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest()), true),
+ };
+}
+
+template <typename T>
+std::vector<Case> OpMulVecCases() {
+ return {
+ // s * vec3 = vec3
+ C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})),
+ // vec3 * s = vec3
+ C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
+ // vec3 * vec3 = vec3
+ C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
+ };
+}
+
+template <typename T>
+std::vector<Case> OpMulMatCases() {
+ return {
+ // s * mat3x2 = mat3x2
+ C(Val(T{2.25}),
+ Mat({T{1.0}, T{4.0}}, //
+ {T{2.0}, T{5.0}}, //
+ {T{3.0}, T{6.0}}),
+ Mat({T{2.25}, T{9.0}}, //
+ {T{4.5}, T{11.25}}, //
+ {T{6.75}, T{13.5}})),
+ // mat3x2 * s = mat3x2
+ C(Mat({T{1.0}, T{4.0}}, //
+ {T{2.0}, T{5.0}}, //
+ {T{3.0}, T{6.0}}),
+ Val(T{2.25}),
+ Mat({T{2.25}, T{9.0}}, //
+ {T{4.5}, T{11.25}}, //
+ {T{6.75}, T{13.5}})),
+ // vec3 * mat2x3 = vec2
+ C(Vec(T{1.25}, T{2.25}, T{3.25}), //
+ Mat({T{1.0}, T{2.0}, T{3.0}}, //
+ {T{4.0}, T{5.0}, T{6.0}}), //
+ Vec(T{15.5}, T{35.75})),
+ // mat2x3 * vec2 = vec3
+ C(Mat({T{1.0}, T{2.0}, T{3.0}}, //
+ {T{4.0}, T{5.0}, T{6.0}}), //
+ Vec(T{1.25}, T{2.25}), //
+ Vec(T{10.25}, T{13.75}, T{17.25})),
+ // mat3x2 * mat2x3 = mat2x2
+ C(Mat({T{1.0}, T{2.0}}, //
+ {T{3.0}, T{4.0}}, //
+ {T{5.0}, T{6.0}}), //
+ Mat({T{1.25}, T{2.25}, T{3.25}}, //
+ {T{4.25}, T{5.25}, T{6.25}}), //
+ Mat({T{24.25}, T{31.0}}, //
+ {T{51.25}, T{67.0}})), //
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(Mul,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kMultiply),
+ testing::ValuesIn(Concat( //
+ OpMulScalarCases<AInt>(),
+ OpMulScalarCases<i32>(),
+ OpMulScalarCases<u32>(),
+ OpMulScalarCases<AFloat>(),
+ OpMulScalarCases<f32>(),
+ OpMulScalarCases<f16>(),
+ OpMulVecCases<AInt>(),
+ OpMulVecCases<i32>(),
+ OpMulVecCases<u32>(),
+ OpMulVecCases<AFloat>(),
+ OpMulVecCases<f32>(),
+ OpMulVecCases<f16>(),
+ OpMulMatCases<AFloat>(),
+ OpMulMatCases<f32>(),
+ OpMulMatCases<f16>()))));
+
+// Tests for errors on overflow/underflow of binary operations with abstract numbers
+struct OverflowCase {
+ ast::BinaryOp op;
+ Types lhs;
+ Types rhs;
+ std::string overflowed_result;
+};
+
+static std::ostream& operator<<(std::ostream& o, const OverflowCase& c) {
+ o << ast::FriendlyName(c.op) << ", lhs: " << c.lhs << ", rhs: " << c.rhs;
+ return o;
+}
+using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam<OverflowCase>;
+TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
+ Enable(ast::Extension::kF16);
+ auto& c = GetParam();
+ auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
+ auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
+ auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, c.op, lhs_expr, rhs_expr);
+ GlobalConst("C", expr);
+ ASSERT_FALSE(r()->Resolve());
+
+ std::string type_name = std::visit(
+ [&](auto&& value) {
+ using ValueType = std::decay_t<decltype(value)>;
+ return tint::FriendlyName<typename ValueType::ElementType>();
+ },
+ c.lhs);
+
+ EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '" + c.overflowed_result +
+ "' cannot be represented as '" + type_name + "'"));
+}
+INSTANTIATE_TEST_SUITE_P(
+ Test,
+ ResolverConstEvalBinaryOpTest_Overflow,
+ testing::Values( //
+ // scalar-scalar add
+ OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Highest()), Val(1_a), "-9223372036854775808"},
+ OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Lowest()), Val(-1_a), "9223372036854775807"},
+ OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Highest()), Val(AFloat::Highest()), "inf"},
+ OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Lowest()), Val(AFloat::Lowest()), "-inf"},
+ // scalar-scalar subtract
+ OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Lowest()), Val(1_a),
+ "9223372036854775807"},
+ OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Highest()), Val(-1_a),
+ "-9223372036854775808"},
+ OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Highest()), Val(AFloat::Lowest()),
+ "inf"},
+ OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Lowest()), Val(AFloat::Highest()),
+ "-inf"},
+
+ // scalar-scalar multiply
+ OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Val(2_a), "-2"},
+ OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Val(-2_a), "0"},
+
+ // scalar-vector multiply
+ OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Vec(2_a, 1_a), "-2"},
+ OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Vec(-2_a, 1_a), "0"},
+
+ // vector-matrix multiply
+
+ // Overflow from first multiplication of dot product of vector and matrix column 0
+ // i.e. (v[0] * m[0][0] + v[1] * m[0][1])
+ // ^
+ OverflowCase{ast::BinaryOp::kMultiply, //
+ Vec(AFloat::Highest(), 1.0_a), //
+ Mat({2.0_a, 1.0_a}, //
+ {1.0_a, 1.0_a}), //
+ "inf"},
+
+ // Overflow from second multiplication of dot product of vector and matrix column 0
+ // i.e. (v[0] * m[0][0] + v[1] * m[0][1])
+ // ^
+ OverflowCase{ast::BinaryOp::kMultiply, //
+ Vec(1.0_a, AFloat::Highest()), //
+ Mat({1.0_a, 2.0_a}, //
+ {1.0_a, 1.0_a}), //
+ "inf"},
+
+ // Overflow from addition of dot product of vector and matrix column 0
+ // i.e. (v[0] * m[0][0] + v[1] * m[0][1])
+ // ^
+ OverflowCase{ast::BinaryOp::kMultiply, //
+ Vec(AFloat::Highest(), AFloat::Highest()), //
+ Mat({1.0_a, 1.0_a}, //
+ {1.0_a, 1.0_a}), //
+ "inf"},
+
+ // matrix-matrix multiply
+
+ // Overflow from first multiplication of dot product of lhs row 0 and rhs column 0
+ // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]
+ // ^
+ OverflowCase{ast::BinaryOp::kMultiply, //
+ Mat({AFloat::Highest(), 1.0_a}, //
+ {1.0_a, 1.0_a}), //
+ Mat({2.0_a, 1.0_a}, //
+ {1.0_a, 1.0_a}), //
+ "inf"},
+
+ // Overflow from second multiplication of dot product of lhs row 0 and rhs column 0
+ // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]
+ // ^
+ OverflowCase{ast::BinaryOp::kMultiply, //
+ Mat({1.0_a, AFloat::Highest()}, //
+ {1.0_a, 1.0_a}), //
+ Mat({1.0_a, 1.0_a}, //
+ {2.0_a, 1.0_a}), //
+ "inf"},
+
+ // Overflow from addition of dot product of lhs row 0 and rhs column 0
+ // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]
+ // ^
+ OverflowCase{ast::BinaryOp::kMultiply, //
+ Mat({AFloat::Highest(), 1.0_a}, //
+ {AFloat::Highest(), 1.0_a}), //
+ Mat({1.0_a, 1.0_a}, //
+ {1.0_a, 1.0_a}), //
+ "inf"}
+
+ ));
+
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl
index ff5878a..e0a8ddc 100644
--- a/src/tint/resolver/intrinsic_table.inl
+++ b/src/tint/resolver/intrinsic_table.inl
@@ -9572,108 +9572,108 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[727],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiply,
},
{
/* [118] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[725],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiply,
},
{
/* [119] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[723],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiply,
},
{
/* [120] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[721],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiply,
},
{
/* [121] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
- /* template types */ &kTemplateTypes[11],
+ /* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[719],
/* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiply,
},
{
/* [122] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
- /* template types */ &kTemplateTypes[11],
+ /* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[717],
/* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiply,
},
{
/* [123] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
- /* template types */ &kTemplateTypes[11],
+ /* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[1],
/* parameters */ &kParameters[715],
/* return matcher indices */ &kMatcherIndices[69],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiplyMatVec,
},
{
/* [124] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
- /* template types */ &kTemplateTypes[11],
+ /* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[1],
/* parameters */ &kParameters[713],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiplyVecMat,
},
{
/* [125] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 3,
- /* template types */ &kTemplateTypes[11],
+ /* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[0],
/* parameters */ &kParameters[711],
/* return matcher indices */ &kMatcherIndices[22],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMultiplyMatMat,
},
{
/* [126] */
@@ -14630,15 +14630,15 @@
},
{
/* [2] */
- /* op *<T : fiu32_f16>(T, T) -> T */
- /* op *<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
- /* op *<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
- /* op *<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
- /* op *<T : f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
- /* op *<T : f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
- /* op *<T : f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
- /* op *<T : f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
- /* op *<T : f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
+ /* op *<T : fia_fiu32_f16>(T, T) -> T */
+ /* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
+ /* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
+ /* op *<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
+ /* op *<T : fa_f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
+ /* op *<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
+ /* op *<T : fa_f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
+ /* op *<T : fa_f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
+ /* op *<T : fa_f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
/* num overloads */ 9,
/* overloads */ &kOverloads[117],
},
diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc
index b9f9d53..0ffc011 100644
--- a/src/tint/resolver/intrinsic_table_test.cc
+++ b/src/tint/resolver/intrinsic_table_test.cc
@@ -641,15 +641,15 @@
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
9 candidate operators:
- operator * (T, T) -> T where: T is f32, i32, u32 or f16
- operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
- operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
- operator * (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
- operator * (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
- operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
- operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
- operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
- operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
+ operator * (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator * (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator * (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator * (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
+ operator * (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
+ operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
+ operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
+ operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
)");
}
@@ -673,15 +673,15 @@
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
9 candidate operators:
- operator *= (T, T) -> T where: T is f32, i32, u32 or f16
- operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
- operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
- operator *= (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
- operator *= (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
- operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
- operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
- operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
- operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
+ operator *= (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator *= (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator *= (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator *= (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
+ operator *= (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
+ operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
+ operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
+ operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
+ operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
)");
}
diff --git a/src/tint/sem/storage_texture.cc b/src/tint/sem/storage_texture.cc
index 5b743e9..caa8a79 100644
--- a/src/tint/sem/storage_texture.cc
+++ b/src/tint/sem/storage_texture.cc
@@ -48,7 +48,7 @@
return out.str();
}
-sem::Type* StorageTexture::SubtypeFor(ast::TexelFormat format, sem::Manager& type_mgr) {
+sem::Type* StorageTexture::SubtypeFor(ast::TexelFormat format, sem::TypeManager& type_mgr) {
switch (format) {
case ast::TexelFormat::kR32Uint:
case ast::TexelFormat::kRgba8Uint:
diff --git a/src/tint/sem/storage_texture.h b/src/tint/sem/storage_texture.h
index 68dbff9..00258d6 100644
--- a/src/tint/sem/storage_texture.h
+++ b/src/tint/sem/storage_texture.h
@@ -23,7 +23,7 @@
// Forward declarations
namespace tint::sem {
-class Manager;
+class TypeManager;
} // namespace tint::sem
namespace tint::sem {
@@ -67,9 +67,9 @@
std::string FriendlyName(const SymbolTable& symbols) const override;
/// @param format the storage texture image format
- /// @param type_mgr the sem::Manager used to build the returned type
+ /// @param type_mgr the sem::TypeManager used to build the returned type
/// @returns the storage texture subtype for the given TexelFormat
- static sem::Type* SubtypeFor(ast::TexelFormat format, sem::Manager& type_mgr);
+ static sem::Type* SubtypeFor(ast::TexelFormat format, sem::TypeManager& type_mgr);
private:
ast::TexelFormat const texel_format_;
diff --git a/src/tint/sem/type_manager.cc b/src/tint/sem/type_manager.cc
index e6f5f6d..8c26a4f 100644
--- a/src/tint/sem/type_manager.cc
+++ b/src/tint/sem/type_manager.cc
@@ -16,9 +16,9 @@
namespace tint::sem {
-Manager::Manager() = default;
-Manager::Manager(Manager&&) = default;
-Manager& Manager::operator=(Manager&& rhs) = default;
-Manager::~Manager() = default;
+TypeManager::TypeManager() = default;
+TypeManager::TypeManager(TypeManager&&) = default;
+TypeManager& TypeManager::operator=(TypeManager&& rhs) = default;
+TypeManager::~TypeManager() = default;
} // namespace tint::sem
diff --git a/src/tint/sem/type_manager.h b/src/tint/sem/type_manager.h
index fb08689..636b7a0 100644
--- a/src/tint/sem/type_manager.h
+++ b/src/tint/sem/type_manager.h
@@ -25,24 +25,24 @@
namespace tint::sem {
/// The type manager holds all the pointers to the known types.
-class Manager final : public utils::UniqueAllocator<Type> {
+class TypeManager final : public utils::UniqueAllocator<Type> {
public:
/// Iterator is the type returned by begin() and end()
using Iterator = utils::BlockAllocator<Type>::ConstIterator;
/// Constructor
- Manager();
+ TypeManager();
/// Move constructor
- Manager(Manager&&);
+ TypeManager(TypeManager&&);
/// Move assignment operator
/// @param rhs the Manager to move
/// @return this Manager
- Manager& operator=(Manager&& rhs);
+ TypeManager& operator=(TypeManager&& rhs);
/// Destructor
- ~Manager();
+ ~TypeManager();
/// Wrap returns a new Manager created with the types of `inner`.
/// The Manager returned by Wrap is intended to temporarily extend the types
@@ -53,12 +53,28 @@
/// function. See crbug.com/tint/460.
/// @param inner the immutable Manager to extend
/// @return the Manager that wraps `inner`
- static Manager Wrap(const Manager& inner) {
- Manager out;
+ static TypeManager Wrap(const TypeManager& inner) {
+ TypeManager out;
out.items = inner.items;
return out;
}
+ /// @param args the arguments used to create the temporary type used for the search.
+ /// @return a pointer to an instance of `TYPE` with the provided arguments, or nullptr if the
+ /// type was not found.
+ template <typename TYPE, typename... ARGS>
+ TYPE* Find(ARGS&&... args) const {
+ // Create a temporary TYPE instance on the stack so that we can hash it, and
+ // use it for equality lookup for the std::unordered_set.
+ TYPE key{args...};
+ auto hash = Hasher{}(key);
+ auto it = items.find(Entry{hash, &key});
+ if (it != items.end()) {
+ return static_cast<TYPE*>(it->ptr);
+ }
+ return nullptr;
+ }
+
/// @returns an iterator to the beginning of the types
Iterator begin() const { return allocator.Objects().begin(); }
/// @returns an iterator to the end of the types
diff --git a/src/tint/sem/type_manager_test.cc b/src/tint/sem/type_manager_test.cc
index c670db0..d2ca916 100644
--- a/src/tint/sem/type_manager_test.cc
+++ b/src/tint/sem/type_manager_test.cc
@@ -34,14 +34,14 @@
using TypeManagerTest = testing::Test;
TEST_F(TypeManagerTest, GetUnregistered) {
- Manager tm;
+ TypeManager tm;
auto* t = tm.Get<I32>();
ASSERT_NE(t, nullptr);
EXPECT_TRUE(t->Is<I32>());
}
TEST_F(TypeManagerTest, GetSameTypeReturnsSamePtr) {
- Manager tm;
+ TypeManager tm;
auto* t = tm.Get<I32>();
ASSERT_NE(t, nullptr);
EXPECT_TRUE(t->Is<I32>());
@@ -51,7 +51,7 @@
}
TEST_F(TypeManagerTest, GetDifferentTypeReturnsDifferentPtr) {
- Manager tm;
+ TypeManager tm;
Type* t = tm.Get<I32>();
ASSERT_NE(t, nullptr);
EXPECT_TRUE(t->Is<I32>());
@@ -62,9 +62,17 @@
EXPECT_TRUE(t2->Is<U32>());
}
+TEST_F(TypeManagerTest, Find) {
+ TypeManager tm;
+ auto* created = tm.Get<I32>();
+
+ EXPECT_EQ(tm.Find<U32>(), nullptr);
+ EXPECT_EQ(tm.Find<I32>(), created);
+}
+
TEST_F(TypeManagerTest, WrapDoesntAffectInner) {
- Manager inner;
- Manager outer = Manager::Wrap(inner);
+ TypeManager inner;
+ TypeManager outer = TypeManager::Wrap(inner);
inner.Get<I32>();
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
new file mode 100644
index 0000000..57da4da
--- /dev/null
+++ b/src/tint/transform/std140.cc
@@ -0,0 +1,950 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/std140.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/index_accessor_expression.h"
+#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/module.h"
+#include "src/tint/sem/struct.h"
+#include "src/tint/sem/variable.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/transform.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::transform::Std140);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace {
+
+/// DynamicIndex is used by Std140::State::AccessIndex to indicate a runtime-expression index
+struct DynamicIndex {
+ size_t slot; // The index of the expression in Std140::State::AccessChain::dynamic_indices
+};
+
+/// Inequality operator for DynamicIndex
+bool operator!=(const DynamicIndex& a, const DynamicIndex& b) {
+ return a.slot != b.slot;
+}
+
+} // namespace
+
+namespace tint::utils {
+
+/// Hasher specialization for DynamicIndex
+template <>
+struct Hasher<DynamicIndex> {
+ /// The hash function for the DynamicIndex
+ /// @param d the DynamicIndex to hash
+ /// @return the hash for the given DynamicIndex
+ uint64_t operator()(const DynamicIndex& d) const { return utils::Hash(d.slot); }
+};
+
+} // namespace tint::utils
+
+namespace tint::transform {
+
+/// The PIMPL state for the Std140 transform
+struct Std140::State {
+ /// Constructor
+ /// @param c the CloneContext
+ explicit State(CloneContext& c) : ctx(c) {}
+
+ /// Runs the transform
+ void Run() {
+ // Begin by creating forked structures for any struct that is used as a uniform buffer, that
+ // either directly or transitively contains a matrix that needs splitting for std140 layout.
+ ForkStructs();
+
+ // Next, replace all the uniform variables to use the forked types.
+ ReplaceUniformVarTypes();
+
+ // Finally, replace all expression chains that used the authored types with those that
+ // correctly use the forked types.
+ ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
+ if (auto access = AccessChainFor(expr)) {
+ if (!access->std140_mat_idx.has_value()) {
+ // loading a std140 type, which is not a whole or partial decomposed matrix
+ return LoadWithConvert(access.value());
+ }
+ if (!access->IsMatrixSubset() || // loading a whole matrix
+ std::holds_alternative<DynamicIndex>(
+ access->indices[*access->std140_mat_idx + 1])) {
+ // Whole object or matrix is loaded, or the matrix column is indexed with a
+ // non-constant index. Build a helper function to load the expression chain.
+ return LoadMatrixWithFn(access.value());
+ }
+ // Matrix column is statically indexed. Can be emitted as an inline expression.
+ return LoadSubMatrixInline(access.value());
+ }
+ // Expression isn't an access to a std140-layout uniform buffer.
+ // Just clone.
+ return nullptr;
+ });
+
+ ctx.Clone();
+ }
+
+ /// @returns true if this transform should be run for the given program
+ /// @param program the program to inspect
+ static bool ShouldRun(const Program* program) {
+ for (auto* ty : program->Types()) {
+ if (auto* str = ty->As<sem::Struct>()) {
+ if (str->UsedAs(ast::StorageClass::kUniform)) {
+ for (auto* member : str->Members()) {
+ if (auto* mat = member->Type()->As<sem::Matrix>()) {
+ if (MatrixNeedsDecomposing(mat)) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ private:
+ /// Swizzle describes a vector swizzle
+ using Swizzle = utils::Vector<uint32_t, 4>;
+
+ /// AccessIndex describes a single access in an access chain.
+ /// The access is one of:
+ /// u32 - a static member index on a struct, static array index, static matrix column
+ /// index, static vector element index.
+ /// DynamicIndex - a runtime-expression index on an array, matrix column selection, or vector
+ /// element index.
+ /// Swizzle - a static vector swizzle.
+ using AccessIndex = std::variant<u32, DynamicIndex, Swizzle>;
+
+ /// A vector of AccessIndex.
+ using AccessIndices = utils::Vector<AccessIndex, 8>;
+
+ /// A key used to cache load functions for an access chain.
+ struct LoadFnKey {
+ /// The root uniform buffer variable for the access chain.
+ const sem::GlobalVariable* var;
+
+ /// The chain of access indices.
+ AccessIndices indices;
+
+ /// Hash function for LoadFnKey.
+ struct Hasher {
+ /// @param fn the LoadFnKey to hash
+ /// @return the hash for the given LoadFnKey
+ uint64_t operator()(const LoadFnKey& fn) const {
+ return utils::Hash(fn.var, fn.indices);
+ }
+ };
+
+ /// Equality operator
+ bool operator==(const LoadFnKey& other) const {
+ return var == other.var && indices == other.indices;
+ }
+ };
+
+ /// The clone context
+ CloneContext& ctx;
+ /// Alias to the semantic info in ctx.src
+ const sem::Info& sem = ctx.src->Sem();
+ /// Alias to the symbols in ctx.src
+ const SymbolTable& sym = ctx.src->Symbols();
+ /// Alias to the ctx.dst program builder
+ ProgramBuilder& b = *ctx.dst;
+
+ /// Map of load function signature, to the generated function
+ utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns;
+
+ /// Map of std140-forked type to converter function name
+ utils::Hashmap<const sem::Type*, Symbol, 8> conv_fns;
+
+ // Uniform variables that have been modified to use a std140 type
+ utils::Hashset<const sem::Variable*, 8> std140_uniforms;
+
+ // Map of original structure to 'std140' forked structure
+ utils::Hashmap<const sem::Struct*, Symbol, 8> std140_structs;
+
+ // Map of structure member in ctx.src of a matrix type, to list of decomposed column
+ // members in ctx.dst.
+ utils::Hashmap<const sem::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8>
+ std140_mats;
+
+ /// AccessChain describes a chain of access expressions to a uniform buffer variable.
+ struct AccessChain {
+ /// The uniform buffer variable.
+ const sem::GlobalVariable* var;
+ /// The chain of access indices, starting with the first access on #var.
+ AccessIndices indices;
+ /// The runtime-evaluated expressions. This vector is indexed by the DynamicIndex::slot
+ utils::Vector<const sem::Expression*, 8> dynamic_indices;
+ /// The type of the std140-decomposed matrix being accessed.
+ /// May be nullptr if the chain does not pass through a std140-decomposed matrix.
+ const sem::Matrix* std140_mat_ty = nullptr;
+ /// The index in #indices of the access that resolves to the std140-decomposed matrix.
+ /// May hold no value if the chain does not pass through a std140-decomposed matrix.
+ std::optional<size_t> std140_mat_idx;
+
+ /// @returns true if the access chain is to part of (not the whole) std140-decomposed matrix
+ bool IsMatrixSubset() const {
+ return std140_mat_idx.has_value() && (std140_mat_idx.value() + 1 != indices.Length());
+ }
+ };
+
+ /// @returns true if the given matrix needs decomposing to column vectors for std140 layout.
+ /// TODO(crbug.com/tint/1502): This may need adjusting for `f16` matrices.
+ static bool MatrixNeedsDecomposing(const sem::Matrix* mat) { return mat->ColumnStride() == 8; }
+
+ /// ForkStructs walks the structures in dependency order, forking structures that are used as
+ /// uniform buffers which (transitively) use matrices that need std140 decomposition to column
+ /// vectors.
+ /// Populates the #std140_mats and #std140_structs maps.
+ void ForkStructs() {
+ // For each module scope declaration...
+ for (auto* global : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
+ // Check to see if this is a structure used by a uniform buffer...
+ auto* str = sem.Get<sem::Struct>(global);
+ if (str && str->UsedAs(ast::StorageClass::kUniform)) {
+ // Should this uniform buffer be forked for std140 usage?
+ bool fork_std140 = false;
+ utils::Vector<const ast::StructMember*, 8> members;
+ for (auto* member : str->Members()) {
+ if (auto* mat = member->Type()->As<sem::Matrix>()) {
+ // Is this member a matrix that needs decomposition for std140-layout?
+ if (MatrixNeedsDecomposing(mat)) {
+ // Structure member of matrix type needs decomposition.
+ fork_std140 = true;
+ // Replace the member with column vectors.
+ const auto num_columns = mat->columns();
+ const auto name_prefix = PrefixForUniqueNames(
+ str->Declaration(), member->Name(), num_columns);
+ // Build a struct member for each column of the matrix
+ utils::Vector<const ast::StructMember*, 4> column_members;
+ for (uint32_t i = 0; i < num_columns; i++) {
+ utils::Vector<const ast::Attribute*, 1> attributes;
+ if ((i == 0) && mat->Align() != member->Align()) {
+ // The matrix was @align() annotated with a larger alignment
+ // than the natural alignment for the matrix. This extra padding
+ // needs to be applied to the first column vector.
+ attributes.Push(b.MemberAlign(u32(member->Align())));
+ }
+ if ((i == num_columns - 1) && mat->Size() != member->Size()) {
+ // The matrix was @size() annotated with a larger size than the
+ // natural size for the matrix. This extra padding needs to be
+ // applied to the last column vector.
+ attributes.Push(
+ b.MemberSize(member->Size() - mat->ColumnType()->Size() *
+ (num_columns - 1)));
+ }
+
+ // Build the member
+ const auto col_name = name_prefix + std::to_string(i);
+ const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType());
+ const auto* col_member =
+ ctx.dst->Member(col_name, col_ty, std::move(attributes));
+ // Add the member to the forked structure
+ members.Push(col_member);
+ // Record the member for std140_mats
+ column_members.Push(col_member);
+ }
+ std140_mats.Add(member, std::move(column_members));
+ continue;
+ }
+ }
+
+ // Is the member part of a struct that has been forked for std140-layout?
+ if (auto* std140_ty = Std140Type(member->Type())) {
+ // Yes - use this type for the forked structure member.
+ fork_std140 = true;
+ auto attrs = ctx.Clone(member->Declaration()->attributes);
+ members.Push(
+ b.Member(sym.NameFor(member->Name()), std140_ty, std::move(attrs)));
+ continue;
+ }
+
+ // Nothing special about this member.
+ // Push the member in src to members without first cloning. We'll replace this
+ // with a cloned member once we know whether we need to fork the structure or
+ // not.
+ members.Push(member->Declaration());
+ }
+
+ // Did any of the members require forking the structure?
+ if (fork_std140) {
+ // Clone any members that have not already been cloned.
+ for (auto& member : members) {
+ if (member->program_id == ctx.src->ID()) {
+ member = ctx.Clone(member);
+ }
+ }
+ // Create a new forked structure, and insert it just under the original
+ // structure.
+ auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140");
+ auto* std140 = b.create<ast::Struct>(name, std::move(members),
+ ctx.Clone(str->Declaration()->attributes));
+ ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), global, std140);
+ std140_structs.Add(str, name);
+ }
+ }
+ }
+ }
+
+ /// Retypes each module-scope `var` declared in the uniform storage class whose store type has
+ /// been forked for std140-layout, so that it is declared with the forked type instead.
+ /// Populates the #std140_uniforms set.
+ void ReplaceUniformVarTypes() {
+ for (auto* decl : ctx.src->AST().GlobalVariables()) {
+ auto* uniform_var = decl->As<ast::Var>();
+ if (!uniform_var || uniform_var->declared_storage_class != ast::StorageClass::kUniform) {
+ continue;
+ }
+ auto* sem_var = sem.Get(uniform_var);
+ if (auto* forked_ty = Std140Type(sem_var->Type()->UnwrapRef())) {
+ ctx.Replace(decl->type, forked_ty);
+ std140_uniforms.Add(sem_var);
+ }
+ }
+ }
+
+ /// Finds a member-name prefix under which a matrix member can be split into @p count column
+ /// vector members without clashing with any existing member of @p str. The new members must be
+ /// named by appending a zero-based index in the range `[0..count)` to the returned prefix.
+ /// @param str the structure that will hold the uniquely named member.
+ /// @param unsuffixed the common name prefix to use for the new members.
+ /// @param count the number of members that need to be created.
+ std::string PrefixForUniqueNames(const ast::Struct* str,
+ Symbol unsuffixed,
+ uint32_t count) const {
+ // Each attempt appends one more '_' to the candidate prefix, so the first candidate tried is
+ // the unsuffixed name followed by a single '_'.
+ auto candidate = sym.NameFor(unsuffixed);
+ for (;;) {
+ candidate += "_";
+
+ // The full set of names that would be introduced with this candidate.
+ utils::Hashset<std::string, 4> proposed;
+ for (uint32_t n = 0; n < count; n++) {
+ proposed.Add(candidate + std::to_string(n));
+ }
+
+ // Reject the candidate if any existing member already uses one of the proposed names.
+ bool collision = false;
+ for (auto* existing : str->members) {
+ if (proposed.Contains(sym.NameFor(existing->symbol))) {
+ collision = true;
+ break;
+ }
+ }
+ if (!collision) {
+ return candidate;
+ }
+ }
+ }
+
+ /// @param ty the non-forked semantic type to look up.
+ /// @returns a new AST type referring to the std140-forked counterpart of @p ty, or nullptr if
+ /// the semantic type is not split for std140-layout.
+ const ast::Type* Std140Type(const sem::Type* ty) const {
+ return Switch(
+ ty, //
+ [&](const sem::Struct* str) -> const ast::Type* {
+ // Only structures recorded in std140_structs have a forked counterpart.
+ auto* forked_name = std140_structs.Find(str);
+ return forked_name ? b.create<ast::TypeName>(*forked_name) : nullptr;
+ },
+ [&](const sem::Array* arr) -> const ast::Type* {
+ // An array is forked only when its element type is.
+ auto* forked_el = Std140Type(arr->ElemType());
+ if (!forked_el) {
+ return nullptr;
+ }
+ utils::Vector<const ast::Attribute*, 1> attrs;
+ if (!arr->IsStrideImplicit()) {
+ attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride()));
+ }
+ return b.create<ast::Array>(forked_el, b.Expr(u32(arr->Count())), std::move(attrs));
+ });
+ }
+
+ /// Walks @p ast_expr from the outermost access inwards to its source variable, building an
+ /// AccessChain. @returns the AccessChain if the expression is an access to a std140-forked
+ /// uniform buffer, otherwise std::nullopt (also returned when the walk reports an error).
+ std::optional<AccessChain> AccessChainFor(const ast::Expression* ast_expr) {
+ auto* expr = sem.Get(ast_expr);
+ if (!expr) {
+ return std::nullopt;
+ }
+
+ AccessChain access;
+
+ // Start by looking at the source variable. This must be a std140-forked uniform buffer.
+ access.var = tint::As<sem::GlobalVariable>(expr->SourceVariable());
+ if (!access.var || !std140_uniforms.Contains(access.var)) {
+ // Not a std140-forked uniform buffer access chain.
+ return std::nullopt;
+ }
+
+ // Walk from the outer-most expression, inwards towards the source variable.
+ while (true) {
+ enum class Action { kStop, kContinue, kError };
+ Action action = Switch(
+ expr, //
+ [&](const sem::VariableUser* user) {
+ if (user->Variable() == access.var) {
+ // Walked all the way to the source variable. We're done traversing.
+ return Action::kStop;
+ }
+ if (user->Variable()->Type()->Is<sem::Pointer>()) {
+ // Found a pointer. As the source variable is a uniform buffer variable,
+ // this must be a pointer-let. Continue traversing from the let initializer.
+ expr = user->Variable()->Constructor();
+ return Action::kContinue;
+ }
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unexpected variable found walking access chain: "
+ << sym.NameFor(user->Variable()->Declaration()->symbol);
+ return Action::kError;
+ },
+ [&](const sem::StructMemberAccess* a) {
+ // Is this a std140 decomposed matrix?
+ if (!access.std140_mat_ty && std140_mats.Contains(a->Member())) {
+ // Record this on the access.
+ access.std140_mat_idx = access.indices.Length();  // position before the reversal below
+ access.std140_mat_ty = expr->Type()->UnwrapRef()->As<sem::Matrix>();
+ }
+ // Structure member accesses are always statically indexed
+ access.indices.Push(u32(a->Member()->Index()));
+ expr = a->Object();
+ return Action::kContinue;
+ },
+ [&](const sem::IndexAccessorExpression* a) {
+ // Array, matrix or vector index.
+ if (auto* val = a->Index()->ConstantValue()) {
+ access.indices.Push(val->As<u32>());  // index with a constant value: static index
+ } else {
+ access.indices.Push(DynamicIndex{access.dynamic_indices.Length()});
+ access.dynamic_indices.Push(a->Index());
+ }
+ expr = a->Object();
+ return Action::kContinue;
+ },
+ [&](const sem::Swizzle* s) {
+ // Vector swizzle.
+ if (s->Indices().Length() == 1) {
+ access.indices.Push(u32(s->Indices()[0]));  // single-element swizzle: plain index
+ } else {
+ access.indices.Push(s->Indices());
+ }
+ expr = s->Object();
+ return Action::kContinue;
+ },
+ [&](const sem::Expression* e) {
+ // Walk past indirection and address-of unary ops.
+ return Switch(e->Declaration(), //
+ [&](const ast::UnaryOpExpression* u) {
+ switch (u->op) {
+ case ast::UnaryOp::kAddressOf:
+ case ast::UnaryOp::kIndirection:
+ expr = sem.Get(u->expr);
+ return Action::kContinue;
+ default:
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled unary op for access chain: "
+ << u->op;
+ return Action::kError;
+ }
+ });  // NOTE(review): no Default arm here - confirm the Action yielded for non-unary expressions
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled expression type for access chain\n"
+ << "AST: " << expr->Declaration()->TypeInfo().name << "\n"
+ << "SEM: " << expr->TypeInfo().name;
+ return Action::kError;
+ });
+
+ switch (action) {
+ case Action::kContinue:
+ continue;  // continues the enclosing while-loop
+ case Action::kStop:
+ break;  // leaves the switch only; the 'break' below leaves the loop
+ case Action::kError:
+ return std::nullopt;
+ }
+
+ break;
+ }
+
+ // As the access walked from RHS to LHS, the last index operation applies to the source
+ // variable. We want this the other way around, so reverse the arrays and fix up the indices.
+ std::reverse(access.indices.begin(), access.indices.end());
+ std::reverse(access.dynamic_indices.begin(), access.dynamic_indices.end());
+ if (access.std140_mat_idx.has_value()) {
+ access.std140_mat_idx = access.indices.Length() - *access.std140_mat_idx - 1;  // mirror
+ }
+ for (auto& index : access.indices) {
+ if (auto* dyn_idx = std::get_if<DynamicIndex>(&index)) {
+ dyn_idx->slot = access.dynamic_indices.Length() - dyn_idx->slot - 1;  // mirror the slot
+ }
+ }
+
+ return access;
+ }
+
+ /// @param ty the type being converted; structures and arrays are handled.
+ /// @returns a name suffix for a std140 -> non-std140 conversion function for @p ty.
+ std::string ConvertSuffix(const sem::Type* ty) const {
+ return Switch(
+ ty, //
+ [&](const sem::Struct* str) { return sym.NameFor(str->Name()); },
+ [&](const sem::Array* arr) {
+ return "arr_" + std::to_string(arr->Count()) + "_" + ConvertSuffix(arr->ElemType());
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for conversion name: " << b.FriendlyName(ty);
+ return "";
+ });
+ }
+
+ /// Builds an expression that follows @p access from the std140 uniform buffer variable, one
+ /// index at a time, and then converts the value reached to its non-std140 type.
+ /// @param access the access chain from a uniform buffer to the value to load.
+ const ast::Expression* LoadWithConvert(const AccessChain& access) {
+ const ast::Expression* cur_expr = b.Expr(sym.NameFor(access.var->Declaration()->symbol));
+ const sem::Type* cur_ty = access.var->Type()->UnwrapRef();
+ auto cloned_index = [&](size_t slot) {
+ return ctx.Clone(access.dynamic_indices[slot]->Declaration());
+ };
+ for (auto index : access.indices) {
+ const auto step = BuildAccessExpr(cur_expr, cur_ty, index, cloned_index);
+ cur_expr = step.expr;
+ cur_ty = step.type;
+ }
+ return Convert(cur_ty, cur_expr);
+ }
+
+ /// Generates and returns an expression that converts the expression @p expr of the
+ /// std140-forked type to the type @p ty, by calling a conversion helper function that is
+ /// generated on first use. If @p ty was not forked for std140, @p expr is returned unchanged.
+ /// @returns the converted value expression.
+ const ast::Expression* Convert(const sem::Type* ty, const ast::Expression* expr) {
+ // Get an existing, or create a new function for converting the std140 type to ty.
+ auto fn = conv_fns.GetOrCreate(ty, [&] {
+ auto std140_ty = Std140Type(ty);
+ if (!std140_ty) {
+ // ty was not forked for std140.
+ return Symbol{};  // invalid symbol: checked with IsValid() after the lookup
+ }
+
+ // The converter function takes a single argument of the std140 type.
+ auto* param = b.Param("val", std140_ty);
+
+ utils::Vector<const ast::Statement*, 3> stmts;
+
+ Switch(
+ ty, //
+ [&](const sem::Struct* str) {
+ // Convert each of the structure members using either a converter function call,
+ // or by reassembling a std140 matrix from column vector members.
+ utils::Vector<const ast::Expression*, 8> args;
+ for (auto* member : str->Members()) {
+ if (auto* col_members = std140_mats.Find(member)) {
+ // std140 decomposed matrix. Reassemble.
+ auto* mat_ty = CreateASTTypeFor(ctx, member->Type());
+ auto mat_args =
+ utils::Transform(*col_members, [&](const ast::StructMember* m) {
+ return b.MemberAccessor(param, m->symbol);
+ });
+ args.Push(b.Construct(mat_ty, std::move(mat_args)));
+ } else {
+ // Convert the member (a recursive call; returns the accessor as-is if the
+ // member type was not forked)
+ args.Push(
+ Convert(member->Type(),
+ b.MemberAccessor(param, sym.NameFor(member->Name()))));
+ }
+ }
+ auto* converted = b.Construct(CreateASTTypeFor(ctx, ty), std::move(args));
+ stmts.Push(b.Return(converted));
+ }, //
+ [&](const sem::Array* arr) {
+ // Converting an array. Create a function var for the converted array, and loop
+ // over the input elements, converting each and assigning the result to the
+ // local array.
+ auto* var = b.Var("arr", CreateASTTypeFor(ctx, ty));
+ auto* i = b.Var("i", b.ty.u32());
+ auto* dst_el = b.IndexAccessor(var, i);
+ auto* src_el = Convert(arr->ElemType(), b.IndexAccessor(param, i));
+ stmts.Push(b.Decl(var));
+ stmts.Push(b.For(b.Decl(i), //
+ b.LessThan(i, u32(arr->Count())), //
+ b.Assign(i, b.Add(i, 1_a)), //
+ b.Block(b.Assign(dst_el, src_el))));
+ stmts.Push(b.Return(var));
+ },
+ [&](Default) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for conversion: " << b.FriendlyName(ty);
+ });
+
+ // Generate the function
+ auto* ret_ty = CreateASTTypeFor(ctx, ty);
+ auto fn_sym = b.Symbols().New("conv_" + ConvertSuffix(ty));
+ b.Func(fn_sym, utils::Vector{param}, ret_ty, std::move(stmts));
+ return fn_sym;
+ });
+
+ if (!fn.IsValid()) {
+ // Not a std140 type, nothing to convert.
+ return expr;
+ }
+
+ // Call the helper
+ return b.Call(fn, utils::Vector{expr});
+ }
+
+ /// Emits a call to a helper function that loads a whole std140-decomposed matrix, or a part of
+ /// one, from a uniform buffer. The helper is generated if it has not been already.
+ /// @param access the access chain from the uniform buffer to either the whole matrix or part of
+ /// the matrix (column, column-swizzle, or element).
+ /// @returns the loaded value expression.
+ const ast::Expression* LoadMatrixWithFn(const AccessChain& access) {
+ // Helpers are keyed off the uniform buffer variable together with the indices of the
+ // access chain.
+ auto helper = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] {
+ // A chain that passes through the matrix and ends at a column vector, column swizzle or
+ // element gets the partial loader; a chain that ends at the matrix gets the whole-matrix
+ // loader.
+ return access.IsMatrixSubset() ? BuildLoadPartialMatrixFn(access)
+ : BuildLoadWholeMatrixFn(access);
+ });
+
+ // Every dynamic index of the chain is cloned and passed to the helper, wrapped in a u32
+ // constructor.
+ auto call_args = utils::Transform(access.dynamic_indices, [&](const sem::Expression* idx) {
+ return b.Construct(b.ty.u32(), ctx.Clone(idx->Declaration()));
+ });
+
+ // Emit the call to the helper.
+ const ast::Expression* call = b.Call(helper, std::move(call_args));
+ return call;
+ }
+
+ /// Loads a part of a std140-decomposed matrix from a uniform buffer, inline (without calling a
+ /// helper function).
+ /// @param access the access chain from the uniform buffer to part of the matrix (column,
+ /// column-swizzle, or element).
+ /// @note The matrix column must be statically indexed to use this method.
+ /// @returns the loaded value expression.
+ const ast::Expression* LoadSubMatrixInline(const AccessChain& access) {
+ const ast::Expression* expr = b.Expr(ctx.Clone(access.var->Declaration()->symbol));
+ const sem::Type* ty = access.var->Type()->UnwrapRef();
+ // Method for generating dynamic index expressions.
+ // As this is inline, we can just clone the expression.
+ auto dynamic_index = [&](size_t idx) {
+ return ctx.Clone(access.dynamic_indices[idx]->Declaration());
+ };
+ for (size_t i = 0; i < access.indices.Length(); i++) {
+ if (i == access.std140_mat_idx) {  // std140_mat_idx is a std::optional; compared by value
+ // Access is to the std140 decomposed matrix.
+ // As this is accessing only part of the matrix, we just need to pick the right
+ // column vector member.
+ auto mat_member_idx = std::get<u32>(access.indices[i]);
+ auto* mat_member = ty->As<sem::Struct>()->Members()[mat_member_idx];
+ auto mat_columns = *std140_mats.Get(mat_member);
+ auto column_idx = std::get<u32>(access.indices[i + 1]);  // requires a static column index
+ expr = b.MemberAccessor(expr, mat_columns[column_idx]->symbol);
+ ty = mat_member->Type()->As<sem::Matrix>()->ColumnType();
+ // We've consumed both the matrix member access and the column access. Increment i.
+ i++;
+ } else {
+ // Access is to something that is not a decomposed matrix.
+ auto [new_expr, new_ty, _] =
+ BuildAccessExpr(expr, ty, access.indices[i], dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ }
+ }
+ return expr;
+ }
+
+ /// Generates a function to load part of a std140-decomposed matrix from a uniform buffer.
+ /// The generated function will have a parameter per dynamic (runtime-evaluated) index in the
+ /// access chain.
+ /// The generated function uses a WGSL switch statement to dynamically select the decomposed
+ /// matrix column, with one case per column and a default case returning a zero value.
+ /// @param access the access chain from the uniform buffer to part of the matrix (column,
+ /// column-swizzle, or element).
+ /// @note The matrix column must be dynamically indexed to use this method.
+ /// @returns the generated function name.
+ Symbol BuildLoadPartialMatrixFn(const AccessChain& access) {
+ // Build the dynamic index parameters
+ auto dynamic_index_params = utils::Transform(access.dynamic_indices, [&](auto*, size_t i) {
+ return b.Param("p" + std::to_string(i), b.ty.u32());
+ });
+ // Method for generating dynamic index expressions.
+ // These are passed in as arguments to the function.
+ auto dynamic_index = [&](size_t idx) { return b.Expr(dynamic_index_params[idx]->symbol); };
+
+ // Fetch the access chain indices of the matrix access and the parameter index that holds
+ // the matrix column index.
+ auto std140_mat_idx = *access.std140_mat_idx;
+ auto column_param_idx = std::get<DynamicIndex>(access.indices[std140_mat_idx + 1]).slot;
+
+ // Begin building the function name. This is extended with logic in the loop below
+ // (when column_idx == 0).
+ std::string name = "load_" + sym.NameFor(access.var->Declaration()->symbol);
+
+ // The switch cases
+ utils::Vector<const ast::CaseStatement*, 4> cases;
+
+ // The function return type. Assigned while building the case for the first column.
+ const sem::Type* ret_ty = nullptr;
+
+ // Build switch() cases for each column of the matrix
+ auto num_columns = access.std140_mat_ty->columns();
+ for (uint32_t column_idx = 0; column_idx < num_columns; column_idx++) {
+ const ast::Expression* expr = b.Expr(ctx.Clone(access.var->Declaration()->symbol));
+ const sem::Type* ty = access.var->Type()->UnwrapRef();
+ // Build the expression up to, but not including the matrix member
+ for (size_t i = 0; i < std140_mat_idx; i++) {
+ auto [new_expr, new_ty, access_name] =
+ BuildAccessExpr(expr, ty, access.indices[i], dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ if (column_idx == 0) {
+ name = name + "_" + access_name;
+ }
+ }
+
+ // Get the matrix member that was dynamically accessed.
+ auto mat_member_idx = std::get<u32>(access.indices[std140_mat_idx]);
+ auto* mat_member = ty->As<sem::Struct>()->Members()[mat_member_idx];
+ auto mat_columns = *std140_mats.Get(mat_member);
+ if (column_idx == 0) {
+ name = name + "_" + sym.NameFor(mat_member->Name()) + "_p" +
+ std::to_string(column_param_idx);
+ }
+
+ // Build the expression to the column vector member.
+ expr = b.MemberAccessor(expr, mat_columns[column_idx]->symbol);
+ ty = mat_member->Type()->As<sem::Matrix>()->ColumnType();
+ // Build the rest of the expression, skipping over the column index.
+ for (size_t i = std140_mat_idx + 2; i < access.indices.Length(); i++) {
+ auto [new_expr, new_ty, access_name] =
+ BuildAccessExpr(expr, ty, access.indices[i], dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ if (column_idx == 0) {
+ name = name + "_" + access_name;
+ }
+ }
+
+ if (column_idx == 0) {
+ ret_ty = ty;
+ }
+
+ auto* case_sel = b.Expr(u32(column_idx));
+ auto* case_body = b.Block(utils::Vector{b.Return(expr)});
+ cases.Push(b.Case(case_sel, case_body));
+ }
+
+ // Build the default case (required in WGSL).
+ // This just returns a zero value of the return type, as the index must be out of bounds.
+ cases.Push(b.DefaultCase(b.Block(b.Return(b.Construct(CreateASTTypeFor(ctx, ret_ty))))));
+
+ auto* column_selector = dynamic_index(column_param_idx);
+ auto* stmt = b.Switch(column_selector, std::move(cases));
+
+ auto fn_sym = b.Symbols().New(name);
+ b.Func(fn_sym, std::move(dynamic_index_params), CreateASTTypeFor(ctx, ret_ty),
+ utils::Vector{stmt});
+ return fn_sym;
+ }
+
+ /// Generates a function to load a whole std140-decomposed matrix from a uniform buffer.
+ /// The generated function will have a parameter per dynamic (runtime-evaluated) index in the
+ /// access chain.
+ /// @param access the access chain from the uniform buffer to the whole std140-decomposed
+ /// matrix.
+ /// @returns the generated function name.
+ Symbol BuildLoadWholeMatrixFn(const AccessChain& access) {
+ // Build the dynamic index parameters
+ auto dynamic_index_params = utils::Transform(access.dynamic_indices, [&](auto*, size_t i) {
+ return b.Param("p" + std::to_string(i), b.ty.u32());
+ });
+ // Method for generating dynamic index expressions.
+ // These are passed in as arguments to the function.
+ auto dynamic_index = [&](size_t idx) { return b.Expr(dynamic_index_params[idx]->symbol); };
+
+ const ast::Expression* expr = b.Expr(ctx.Clone(access.var->Declaration()->symbol));
+ std::string name = sym.NameFor(access.var->Declaration()->symbol);  // "load_" is prefixed below
+ const sem::Type* ty = access.var->Type()->UnwrapRef();
+
+ // Build the expression up to, but not including the matrix member
+ auto std140_mat_idx = *access.std140_mat_idx;
+ for (size_t i = 0; i < std140_mat_idx; i++) {
+ auto [new_expr, new_ty, access_name] =
+ BuildAccessExpr(expr, ty, access.indices[i], dynamic_index);
+ expr = new_expr;
+ ty = new_ty;
+ name = name + "_" + access_name;
+ }
+
+ utils::Vector<const ast::Statement*, 2> stmts;
+
+ // Create a temporary pointer to the structure that holds the matrix columns
+ auto* let = b.Let("s", b.AddressOf(expr));
+ stmts.Push(b.Decl(let));
+
+ // Gather the decomposed matrix columns
+ auto mat_member_idx = std::get<u32>(access.indices[std140_mat_idx]);
+ auto* mat_member = ty->As<sem::Struct>()->Members()[mat_member_idx];
+ auto mat_columns = *std140_mats.Get(mat_member);
+ auto columns = utils::Transform(mat_columns, [&](auto* column_member) {
+ return b.MemberAccessor(b.Deref(let), column_member->symbol);
+ });
+
+ // Reconstruct the matrix from the columns
+ expr = b.Construct(CreateASTTypeFor(ctx, access.std140_mat_ty), std::move(columns));
+ ty = mat_member->Type();
+ name = name + "_" + sym.NameFor(mat_member->Name());
+
+ // Have the function return the constructed matrix
+ stmts.Push(b.Return(expr));
+
+ // Build the function
+ auto* ret_ty = CreateASTTypeFor(ctx, ty);
+ auto fn_sym = b.Symbols().New("load_" + name);
+ b.Func(fn_sym, std::move(dynamic_index_params), ret_ty, std::move(stmts));
+ return fn_sym;
+ }
+
+ /// Return type of BuildAccessExpr()
+ struct ExprTypeName {
+ /// The new, post-access expression. Null if no expression was built.
+ const ast::Expression* expr = nullptr;
+ /// The type of #expr. Null if no expression was built.
+ const sem::Type* type = nullptr;
+ /// A name segment which can be used to build sensible names for helper functions
+ std::string name;
+ };
+
+ /// Builds a single access in an access chain.
+ /// @param lhs the expression to index using @p access
+ /// @param ty the type of the expression @p lhs
+ /// @param access the access index to perform on @p lhs
+ /// @param dynamic_index a function that obtains the i'th dynamic index
+ /// @returns a ExprTypeName which holds the new expression, new type and a name segment which
+ /// can be used for creating helper function names.
+ ExprTypeName BuildAccessExpr(const ast::Expression* lhs,
+ const sem::Type* ty,
+ AccessIndex access,
+ std::function<const ast::Expression*(size_t)> dynamic_index) {
+ if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) {
+ // The access uses a dynamic (runtime-expression) index.
+ auto name = "p" + std::to_string(dyn_idx->slot);
+ return Switch(
+ ty, //
+ [&](const sem::Array* arr) -> ExprTypeName {
+ auto* idx = dynamic_index(dyn_idx->slot);
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, arr->ElemType(), name};
+ }, //
+ [&](const sem::Matrix* mat) -> ExprTypeName {
+ auto* idx = dynamic_index(dyn_idx->slot);
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, mat->ColumnType(), name};
+ }, //
+ [&](const sem::Vector* vec) -> ExprTypeName {
+ auto* idx = dynamic_index(dyn_idx->slot);
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, vec->type(), name};
+ }, //
+ [&](Default) -> ExprTypeName {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for access chain: " << b.FriendlyName(ty);
+ return {};  // empty result: no expression was built
+ });
+ }
+ if (auto* swizzle = std::get_if<Swizzle>(&access)) {
+ // The access is a vector swizzle.
+ return Switch(
+ ty, //
+ [&](const sem::Vector* vec) -> ExprTypeName {
+ static const char xyzw[] = {'x', 'y', 'z', 'w'};
+ std::string rhs;  // the swizzle spelled with x/y/z/w, one character per element
+ for (auto el : *swizzle) {
+ rhs += xyzw[el];
+ }
+ auto swizzle_ty = ctx.src->Types().Find<sem::Vector>(
+ vec->type(), static_cast<uint32_t>(swizzle->Length()));
+ auto* expr = b.MemberAccessor(lhs, rhs);
+ return {expr, swizzle_ty, rhs};
+ }, //
+ [&](Default) -> ExprTypeName {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for access chain: " << b.FriendlyName(ty);
+ return {};  // empty result: no expression was built
+ });
+ }
+ // The access is a static index.
+ auto idx = std::get<u32>(access);
+ return Switch(
+ ty, //
+ [&](const sem::Struct* str) -> ExprTypeName {
+ auto* member = str->Members()[idx];
+ auto member_name = sym.NameFor(member->Name());
+ auto* expr = b.MemberAccessor(lhs, member_name);
+ ty = member->Type();  // NOTE(review): writes the by-value parameter; only used in the return below
+ return {expr, ty, member_name};
+ }, //
+ [&](const sem::Array* arr) -> ExprTypeName {
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, arr->ElemType(), std::to_string(idx)};
+ }, //
+ [&](const sem::Matrix* mat) -> ExprTypeName {
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, mat->ColumnType(), std::to_string(idx)};
+ }, //
+ [&](const sem::Vector* vec) -> ExprTypeName {
+ auto* expr = b.IndexAccessor(lhs, idx);
+ return {expr, vec->type(), std::to_string(idx)};
+ }, //
+ [&](Default) -> ExprTypeName {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled type for access chain: " << b.FriendlyName(ty);
+ return {};  // empty result: no expression was built
+ });
+ }
+};
+
+Std140::Std140() = default;
+
+Std140::~Std140() = default;
+
+bool Std140::ShouldRun(const Program* program, const DataMap&) const {
+ return State::ShouldRun(program);  // the DataMap argument is not consulted
+}
+
+void Std140::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ State(ctx).Run();  // all work is delegated to a temporary State; neither DataMap is used
+}
+
+} // namespace tint::transform
diff --git a/src/tint/transform/std140.h b/src/tint/transform/std140.h
new file mode 100644
index 0000000..f987805
--- /dev/null
+++ b/src/tint/transform/std140.h
@@ -0,0 +1,57 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TRANSFORM_STD140_H_
+#define SRC_TINT_TRANSFORM_STD140_H_
+
+#include "src/tint/transform/transform.h"
+
+namespace tint::transform {
+
+/// Std140 is a transform that forks structures used in the uniform storage class that contain
+/// `matNx2<f32>` matrices into `N`x`vec2<f32>` column vectors. Structure types that transitively
+/// use these forked structures as members are also forked. `var<uniform>` variables will use these
+/// forked structures, and expressions loading from these variables will do appropriate conversions
+/// to the regular WGSL types. As `matNx2<f32>` matrices are the only types that violate
+/// std140-layout, this transformation is sufficient to have any WGSL structure be std140-layout
+/// conformant.
+///
+/// @note This transform requires the PromoteSideEffectsToDecl transform to have been run first.
+class Std140 final : public Castable<Std140, Transform> {
+ public:
+ /// Constructor
+ Std140();
+ /// Destructor
+ ~Std140() override;
+
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+
+ private:
+ struct State;  // holds the implementation; defined in the .cc file
+
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::transform
+
+#endif // SRC_TINT_TRANSFORM_STD140_H_
diff --git a/src/tint/transform/std140_test.cc b/src/tint/transform/std140_test.cc
new file mode 100644
index 0000000..4681cdf
--- /dev/null
+++ b/src/tint/transform/std140_test.cc
@@ -0,0 +1,2082 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/std140.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/tint/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::transform {
+namespace {
+
+using Std140Test = TransformTest;
+
+TEST_F(Std140Test, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+TEST_F(Std140Test, ShouldRunStructMat2x2Unused) {
+ auto* src = R"(
+struct Unused {
+ m : mat2x2<f32>,
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+struct ShouldRunCase {
+ uint32_t columns;
+ uint32_t rows;
+ bool should_run;
+
+ std::string Mat() const { return "mat" + std::to_string(columns) + "x" + std::to_string(rows); }
+};
+
+inline std::ostream& operator<<(std::ostream& os, const ShouldRunCase& c) {
+ return os << c.Mat();
+}
+
+using Std140TestShouldRun = TransformTestWithParam<ShouldRunCase>;
+
+TEST_P(Std140TestShouldRun, StructStorage) {
+ std::string src = R"(
+struct S {
+ m : ${mat}<f32>,
+}
+
+@group(0) @binding(0) var<storage> s : S;
+)";
+
+ src = utils::ReplaceAll(src, "${mat}", GetParam().Mat());
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+TEST_P(Std140TestShouldRun, StructUniform) {
+ std::string src = R"(
+struct S {
+ m : ${mat}<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ src = utils::ReplaceAll(src, "${mat}", GetParam().Mat());
+
+ EXPECT_EQ(ShouldRun<Std140>(src), GetParam().should_run);
+}
+
+TEST_P(Std140TestShouldRun, ArrayStorage) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage> s : array<${mat}<f32>, 2>;
+)";
+
+ src = utils::ReplaceAll(src, "${mat}", GetParam().Mat());
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+TEST_P(Std140TestShouldRun, ArrayUniform) {
+ if (GetParam().columns == 3u && GetParam().rows == 2u) {
+ // This permutation is invalid. Skip the test:
+ // error: uniform storage requires that array elements be aligned to 16 bytes, but array
+ // element alignment is currently 24. Consider wrapping the element type in a struct and
+ // using the @size attribute.
+ return;
+ }
+
+ std::string src = R"(
+@group(0) @binding(0) var<uniform> s : array<${mat}<f32>, 2>;
+)";
+
+ src = utils::ReplaceAll(src, "${mat}", GetParam().Mat());
+
+ EXPECT_FALSE(ShouldRun<Std140>(src));
+}
+
+INSTANTIATE_TEST_SUITE_P(Std140TestShouldRun,
+ Std140TestShouldRun,
+ ::testing::ValuesIn(std::vector<ShouldRunCase>{
+ {2, 2, true},
+ {2, 3, false},
+ {2, 4, false},
+ {3, 2, true},
+ {3, 3, false},
+ {3, 4, false},
+ {4, 2, true},
+ {4, 3, false},
+ {4, 4, false},
+ }));
+
+TEST_F(Std140Test, EmptyModule) {
+ auto* src = R"()";
+
+ auto* expect = R"()";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, SingleStructMat4x4Uniform) {
+ auto* src = R"(
+struct S {
+ m : mat4x4<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = src; // Nothing violates std140 layout
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, SingleStructMat2x2Uniform) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, CustomAlignMat3x2) {
+ auto* src = R"(
+struct S {
+ before : i32,
+ @align(128) m : mat3x2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ before : i32,
+ @align(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128u)
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, CustomSizeMat3x2) {
+ auto* src = R"(
+struct S {
+ before : i32,
+ @size(128) m : mat3x2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ before : i32,
+ @size(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(112)
+ m_2 : vec2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, CustomAlignAndSizeMat3x2) {
+ auto* src = R"(
+struct S {
+ before : i32,
+ @align(128) @size(128) m : mat3x2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ before : i32,
+ @align(128) @size(128)
+ m : mat3x2<f32>,
+ after : i32,
+}
+
+struct S_std140 {
+ before : i32,
+ @align(128u)
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(112)
+ m_2 : vec2<f32>,
+ after : i32,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMatricesUniform) {
+ auto* src = R"(
+struct S2x2 {
+ m : mat2x2<f32>,
+}
+struct S3x2 {
+ m : mat3x2<f32>,
+}
+struct S4x2 {
+ m : mat4x2<f32>,
+}
+struct S2x3 {
+ m : mat2x3<f32>,
+}
+struct S3x3 {
+ m : mat3x3<f32>,
+}
+struct S4x3 {
+ m : mat4x3<f32>,
+}
+struct S2x4 {
+ m : mat2x4<f32>,
+}
+struct S3x4 {
+ m : mat3x4<f32>,
+}
+struct S4x4 {
+ m : mat4x4<f32>,
+}
+
+@group(2) @binding(2) var<uniform> s2x2 : S2x2;
+@group(3) @binding(2) var<uniform> s3x2 : S3x2;
+@group(4) @binding(2) var<uniform> s4x2 : S4x2;
+@group(2) @binding(3) var<uniform> s2x3 : S2x3;
+@group(3) @binding(3) var<uniform> s3x3 : S3x3;
+@group(4) @binding(3) var<uniform> s4x3 : S4x3;
+@group(2) @binding(4) var<uniform> s2x4 : S2x4;
+@group(3) @binding(4) var<uniform> s3x4 : S3x4;
+@group(4) @binding(4) var<uniform> s4x4 : S4x4;
+)";
+
+ auto* expect = R"(
+struct S2x2 {
+ m : mat2x2<f32>,
+}
+
+struct S2x2_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+struct S3x2 {
+ m : mat3x2<f32>,
+}
+
+struct S3x2_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+}
+
+struct S4x2 {
+ m : mat4x2<f32>,
+}
+
+struct S4x2_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+ m_3 : vec2<f32>,
+}
+
+struct S2x3 {
+ m : mat2x3<f32>,
+}
+
+struct S3x3 {
+ m : mat3x3<f32>,
+}
+
+struct S4x3 {
+ m : mat4x3<f32>,
+}
+
+struct S2x4 {
+ m : mat2x4<f32>,
+}
+
+struct S3x4 {
+ m : mat3x4<f32>,
+}
+
+struct S4x4 {
+ m : mat4x4<f32>,
+}
+
+@group(2) @binding(2) var<uniform> s2x2 : S2x2_std140;
+
+@group(3) @binding(2) var<uniform> s3x2 : S3x2_std140;
+
+@group(4) @binding(2) var<uniform> s4x2 : S4x2_std140;
+
+@group(2) @binding(3) var<uniform> s2x3 : S2x3;
+
+@group(3) @binding(3) var<uniform> s3x3 : S3x3;
+
+@group(4) @binding(3) var<uniform> s4x3 : S4x3;
+
+@group(2) @binding(4) var<uniform> s2x4 : S2x4;
+
+@group(3) @binding(4) var<uniform> s3x4 : S3x4;
+
+@group(4) @binding(4) var<uniform> s4x4 : S4x4;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_NameCollision) {
+ auto* src = R"(
+struct S {
+ m_1 : i32,
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+)";
+
+ auto* expect = R"(
+struct S {
+ m_1 : i32,
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_1 : i32,
+ m__0 : vec2<f32>,
+ m__1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadStruct) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat2x2<f32>(val.m_0, val.m_1));
+}
+
+fn f() {
+ let l = conv_S(s);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadMatrix) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m() -> mat2x2<f32> {
+ let s = &(s);
+ return mat2x2<f32>((*(s)).m_0, (*(s)).m_1);
+}
+
+fn f() {
+ let l = load_s_m();
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadColumn0) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_0;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadColumn1) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadColumnI) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0(p0 : u32) -> vec2<f32> {
+ switch(p0) {
+ case 0u: {
+ return s.m_0;
+ }
+ case 1u: {
+ return s.m_1;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalar00) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[0][0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_0[0u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalar10) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[1][0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_1[0u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalarI0) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_0(p0 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[0u];
+ }
+ case 1u: {
+ return s.m_1[0u];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_0(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalar01) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[0][1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_0[1u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalar11) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let l = s.m[1][1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let l = s.m_1[1u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalarI1) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_1(p0 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[1u];
+ }
+ case 1u: {
+ return s.m_1[1u];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_1(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalar0I) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[0][I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 0;
+ let l = s.m_0[I];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalar1I) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[1][I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn f() {
+ let I = 0;
+ let l = s.m_1[I];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, StructMat2x2Uniform_LoadScalarII) {
+ auto* src = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn f() {
+ let I = 0;
+ let l = s.m[I][I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : mat2x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S_std140;
+
+fn load_s_m_p0_p1(p0 : u32, p1 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return s.m_0[p1];
+ }
+ case 1u: {
+ return s.m_1[p1];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ let I = 0;
+ let l = load_s_m_p0_p1(u32(I), u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadArray) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn conv_arr_3_S(val : array<S_std140, 3u>) -> array<S, 3u> {
+ var arr : array<S, 3u>;
+ for(var i : u32; (i < 3u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ let l = conv_arr_3_S(a);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadStruct0) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ let l = conv_S(a[0u]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadStruct1) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ let l = conv_S(a[1u]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadStructI) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ let I = 1;
+ let l = conv_S(a[I]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrix0) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[0].m;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_0_m() -> mat3x2<f32> {
+ let s = &(a[0u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ let l = load_a_0_m();
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrix1) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[1].m;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_1_m() -> mat3x2<f32> {
+ let s = &(a[1u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ let l = load_a_1_m();
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrixI) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m(p0 : u32) -> mat3x2<f32> {
+ let s = &(a[p0]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m(u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrix0Column0) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[0].m[0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let l = a[0u].m_0;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrix1Column0) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[1].m[0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let l = a[1u].m_0;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrixIColumn0) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[0];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m_0;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrix0Column1) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[0].m[1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let l = a[0u].m_1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrix1Column1) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let l = a[1].m[1];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn f() {
+ let l = a[1u].m_1;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_LoadMatrixIColumnI) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S, 3>;
+
+fn f() {
+ let I = 1;
+ let l = a[I].m[I];
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<S_std140, 3u>;
+
+fn load_a_p0_m_p1(p0 : u32, p1 : u32) -> vec2<f32> {
+ switch(p1) {
+ case 0u: {
+ return a[p0].m_0;
+ }
+ case 1u: {
+ return a[p0].m_1;
+ }
+ case 2u: {
+ return a[p0].m_2;
+ }
+ default: {
+ return vec2<f32>();
+ }
+ }
+}
+
+fn f() {
+ let I = 1;
+ let l = load_a_p0_m_p1(u32(I), u32(I));
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructArrayStructMat4x2Uniform_Loads) {
+ // Loads every level of a nested array<struct<array<struct<mat4x2>>>> uniform: the whole array,
+ // an element, a member array, an inner struct, the matrix, a column and a scalar.
+ // The columns of a mat4x2<f32> are vec2<f32>, so the constant row index of the scalar load
+ // must be 0 or 1 for the shader to be valid WGSL.
+ auto* src = R"(
+struct Inner {
+ m : mat4x2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+
+ let l_a : array<Outer, 4> = a;
+ let l_a_1 : Outer = a[1];
+ let l_a_2_a : array<Inner, 4> = a[2].a;
+ let l_a_3_a_1 : Inner = a[3].a[1];
+ let l_a_0_a_2_m : mat4x2<f32> = a[0].a[2].m;
+ let l_a_1_a_3_m_0 : vec2<f32> = a[1].a[3].m[0];
+ let l_a_2_a_0_m_1_0 : f32 = a[2].a[0].m[1][0];
+}
+)";
+
+ auto* expect = R"(
+struct Inner {
+ m : mat4x2<f32>,
+}
+
+struct Inner_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+ m_3 : vec2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(mat4x2<f32>(val.m_0, val.m_1, val.m_2, val.m_3));
+}
+
+fn conv_arr_4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr_4_Inner(val.a));
+}
+
+fn conv_arr_4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_0_a_2_m() -> mat4x2<f32> {
+ let s = &(a[0u].a[2u]);
+ return mat4x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2, (*(s)).m_3);
+}
+
+fn f() {
+ let I = 1;
+ let l_a : array<Outer, 4> = conv_arr_4_Outer(a);
+ let l_a_1 : Outer = conv_Outer(a[1u]);
+ let l_a_2_a : array<Inner, 4> = conv_arr_4_Inner(a[2u].a);
+ let l_a_3_a_1 : Inner = conv_Inner(a[3u].a[1u]);
+ let l_a_0_a_2_m : mat4x2<f32> = load_a_0_a_2_m();
+ let l_a_1_a_3_m_0 : vec2<f32> = a[1u].a[3u].m_0;
+ let l_a_2_a_0_m_1_0 : f32 = a[2u].a[0u].m_1[0u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructArrayStructMat4x2Uniform_LoadsViaPtrs) {
+ // Note: the Std140 transform requires the PromoteSideEffectsToDecl transform to have been run
+ // first, so side-effects in the let-chain will not be a problem.
+ // Each local is named after the access path it loads: p_* are the pointer lets, l_* the loads
+ // made through those pointers.
+ auto* src = R"(
+struct Inner {
+ m : mat4x2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer, 4>;
+
+fn f() {
+ let I = 1;
+
+ let p_a = &a;
+ let p_a_3 = &((*p_a)[3]);
+ let p_a_3_a = &((*p_a_3).a);
+ let p_a_3_a_2 = &((*p_a_3_a)[2]);
+ let p_a_3_a_2_m = &((*p_a_3_a_2).m);
+ let p_a_3_a_2_m_1 = &((*p_a_3_a_2_m)[1]);
+
+
+ let l_a : array<Outer, 4> = *p_a;
+ let l_a_3 : Outer = *p_a_3;
+ let l_a_3_a : array<Inner, 4> = *p_a_3_a;
+ let l_a_3_a_2 : Inner = *p_a_3_a_2;
+ let l_a_3_a_2_m : mat4x2<f32> = *p_a_3_a_2_m;
+ let l_a_3_a_2_m_1 : vec2<f32> = *p_a_3_a_2_m_1;
+ let l_a_3_a_2_m_1_0 : f32 = (*p_a_3_a_2_m_1)[0];
+}
+)";
+
+ auto* expect = R"(
+struct Inner {
+ m : mat4x2<f32>,
+}
+
+struct Inner_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ m_2 : vec2<f32>,
+ m_3 : vec2<f32>,
+}
+
+struct Outer {
+ a : array<Inner, 4>,
+}
+
+struct Outer_std140 {
+ a : array<Inner_std140, 4u>,
+}
+
+@group(0) @binding(0) var<uniform> a : array<Outer_std140, 4u>;
+
+fn conv_Inner(val : Inner_std140) -> Inner {
+ return Inner(mat4x2<f32>(val.m_0, val.m_1, val.m_2, val.m_3));
+}
+
+fn conv_arr_4_Inner(val : array<Inner_std140, 4u>) -> array<Inner, 4u> {
+ var arr : array<Inner, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Inner(val[i]);
+ }
+ return arr;
+}
+
+fn conv_Outer(val : Outer_std140) -> Outer {
+ return Outer(conv_arr_4_Inner(val.a));
+}
+
+fn conv_arr_4_Outer(val : array<Outer_std140, 4u>) -> array<Outer, 4u> {
+ var arr : array<Outer, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_Outer(val[i]);
+ }
+ return arr;
+}
+
+fn load_a_3_a_2_m() -> mat4x2<f32> {
+ let s = &(a[3u].a[2u]);
+ return mat4x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2, (*(s)).m_3);
+}
+
+fn f() {
+ let I = 1;
+ let p_a = conv_arr_4_Outer(a);
+ let p_a_3 = conv_Outer(a[3u]);
+ let p_a_3_a = conv_arr_4_Inner(a[3u].a);
+ let p_a_3_a_2 = conv_Inner(a[3u].a[2u]);
+ let p_a_3_a_2_m = load_a_3_a_2_m();
+ let p_a_3_a_2_m_1 = a[3u].a[2u].m_1;
+ let l_a : array<Outer, 4> = conv_arr_4_Outer(a);
+ let l_a_3 : Outer = conv_Outer(a[3u]);
+ let l_a_3_a : array<Inner, 4> = conv_arr_4_Inner(a[3u].a);
+ let l_a_3_a_2 : Inner = conv_Inner(a[3u].a[2u]);
+ let l_a_3_a_2_m : mat4x2<f32> = load_a_3_a_2_m();
+ let l_a_3_a_2_m_1 : vec2<f32> = a[3u].a[2u].m_1;
+ let l_a_3_a_2_m_1_0 : f32 = a[3u].a[2u].m_1[0u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_CopyArray_UniformToStorage) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s = u;
+}
+)";
+
+ auto* expect =
+ R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn conv_arr_4_S(val : array<S_std140, 4u>) -> array<S, 4u> {
+ var arr : array<S, 4u>;
+ for(var i : u32; (i < 4u); i = (i + 1)) {
+ arr[i] = conv_S(val[i]);
+ }
+ return arr;
+}
+
+fn f() {
+ s = conv_arr_4_S(u);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_CopyStruct_UniformToWorkgroup) {
+ // Copies one struct element of the uniform array into a workgroup array.
+ // The workgroup variable is not a resource, so it is declared without @group / @binding:
+ // with those attributes the run yields only the diagnostic "non-resource variables must not
+ // have @group or @binding attributes" instead of transformed WGSL.
+ auto* src = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[0] = u[1];
+}
+)";
+
+ auto* expect =
+ R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn conv_S(val : S_std140) -> S {
+ return S(val.v, mat3x2<f32>(val.m_0, val.m_1, val.m_2));
+}
+
+fn f() {
+ w[0] = conv_S(u[1u]);
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_CopyMatrix_UniformToPrivate) {
+ auto* src = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+var<private> p : array<S, 4>;
+
+fn f() {
+ p[2].m = u[1].m;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ v : vec4<i32>,
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ v : vec4<i32>,
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<private> p : array<S, 4>;
+
+fn load_u_1_m() -> mat3x2<f32> {
+ let s = &(u[1u]);
+ return mat3x2<f32>((*(s)).m_0, (*(s)).m_1, (*(s)).m_2);
+}
+
+fn f() {
+ p[2].m = load_u_1_m();
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_CopyColumn_UniformToStorage) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2].m[0];
+}
+)";
+
+ auto* expect =
+ R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+@group(0) @binding(1) var<storage, read_write> s : array<S, 4>;
+
+fn f() {
+ s[3].m[1] = u[2u].m_0;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_CopySwizzle_UniformToWorkgroup) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2].m[0].yx.xy;
+}
+)";
+
+ auto* expect =
+ R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<workgroup> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1] = u[2u].m_0.yx.xy;
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, ArrayStructMat3x2Uniform_CopyScalar_UniformToPrivate) {
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S, 4>;
+var<private> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1].x = u[2].m[0].y;
+}
+)";
+
+ auto* expect =
+ R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : array<S_std140, 4u>;
+
+var<private> w : array<S, 4>;
+
+fn f() {
+ w[3].m[1].x = u[2u].m_0[1u];
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(Std140Test, MatrixUsageInForLoop) {
+ // Uses loads from a forked matrix in the initializer, condition and continuing statement of a
+ // for-loop. The columns of a mat3x2<f32> are vec2<f32>, so every constant row index is kept
+ // in the range [0, 1].
+ auto* src = R"(
+struct S {
+ @size(64) m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+fn f() {
+ for (var i = u32(u.m[0][1]); i < u32(u.m[i][1]); i += u32(u.m[1][i])) {
+ }
+}
+)";
+
+ auto* expect =
+ R"(
+struct S {
+ @size(64)
+ m : mat3x2<f32>,
+}
+
+struct S_std140 {
+ m_0 : vec2<f32>,
+ m_1 : vec2<f32>,
+ @size(48)
+ m_2 : vec2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> u : S_std140;
+
+fn load_u_m_p0_1(p0 : u32) -> f32 {
+ switch(p0) {
+ case 0u: {
+ return u.m_0[1u];
+ }
+ case 1u: {
+ return u.m_1[1u];
+ }
+ case 2u: {
+ return u.m_2[1u];
+ }
+ default: {
+ return f32();
+ }
+ }
+}
+
+fn f() {
+ for(var i = u32(u.m_0[1u]); (i < u32(load_u_m_p0_1(u32(i)))); i += u32(u.m_1[i])) {
+ }
+}
+)";
+
+ auto got = Run<Std140>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::transform
diff --git a/src/tint/utils/result.h b/src/tint/utils/result.h
index b535f4f..6a14352 100644
--- a/src/tint/utils/result.h
+++ b/src/tint/utils/result.h
@@ -17,6 +17,7 @@
#include <ostream>
#include <variant>
+#include "src/tint/debug.h"
namespace tint::utils {
@@ -36,6 +37,9 @@
static_assert(!std::is_same_v<SUCCESS_TYPE, FAILURE_TYPE>,
"Result must not have the same type for SUCCESS_TYPE and FAILURE_TYPE");
+ /// Default constructor initializes to invalid state
+ Result() : value(std::monostate{}) {}
+
/// Constructor
/// @param success the success result
Result(const SUCCESS_TYPE& success) // NOLINT(runtime/explicit):
@@ -47,27 +51,43 @@
: value{failure} {}
/// @returns true if the result was a success
- operator bool() const { return std::holds_alternative<SUCCESS_TYPE>(value); }
+ operator bool() const {
+ Validate();
+ return std::holds_alternative<SUCCESS_TYPE>(value);
+ }
/// @returns true if the result was a failure
- bool operator!() const { return std::holds_alternative<FAILURE_TYPE>(value); }
+ bool operator!() const {
+ Validate();
+ return std::holds_alternative<FAILURE_TYPE>(value);
+ }
/// @returns the success value
/// @warning attempting to call this when the Result holds an failure will result in UB.
- const SUCCESS_TYPE* operator->() const { return &std::get<SUCCESS_TYPE>(value); }
+ const SUCCESS_TYPE* operator->() const {
+ Validate();
+ return &(Get());
+ }
/// @returns the success value
/// @warning attempting to call this when the Result holds an failure value will result in UB.
- const SUCCESS_TYPE& Get() const { return std::get<SUCCESS_TYPE>(value); }
+ const SUCCESS_TYPE& Get() const {
+ Validate();
+ return std::get<SUCCESS_TYPE>(value);
+ }
/// @returns the failure value
/// @warning attempting to call this when the Result holds a success value will result in UB.
- const FAILURE_TYPE& Failure() const { return std::get<FAILURE_TYPE>(value); }
+ const FAILURE_TYPE& Failure() const {
+ Validate();
+ return std::get<FAILURE_TYPE>(value);
+ }
/// Equality operator
/// @param val the value to compare this Result to
/// @returns true if this result holds a success value equal to `value`
bool operator==(SUCCESS_TYPE val) const {
+ Validate();
if (auto* v = std::get_if<SUCCESS_TYPE>(&value)) {
return *v == val;
}
@@ -78,14 +98,18 @@
/// @param val the value to compare this Result to
/// @returns true if this result holds a failure value equal to `value`
bool operator==(FAILURE_TYPE val) const {
+ Validate();
if (auto* v = std::get_if<FAILURE_TYPE>(&value)) {
return *v == val;
}
return false;
}
+ private:
+ void Validate() const { TINT_ASSERT(Utils, !std::holds_alternative<std::monostate>(value)); }
+
/// The result. Either a success of failure value.
- std::variant<SUCCESS_TYPE, FAILURE_TYPE> value;
+ std::variant<std::monostate, SUCCESS_TYPE, FAILURE_TYPE> value;
};
/// Writes the result to the ostream.
diff --git a/src/tint/utils/string.h b/src/tint/utils/string.h
index f10e258..de1f4d6 100644
--- a/src/tint/utils/string.h
+++ b/src/tint/utils/string.h
@@ -24,9 +24,9 @@
/// @param substr the string to search for
/// @param replacement the replacement string to use instead of `substr`
/// @returns `str` with all occurrences of `substr` replaced with `replacement`
-inline std::string ReplaceAll(std::string str,
- const std::string& substr,
- const std::string& replacement) {
+[[nodiscard]] inline std::string ReplaceAll(std::string str,
+ const std::string& substr,
+ const std::string& replacement) {
size_t pos = 0;
while ((pos = str.find(substr, pos)) != std::string::npos) {
str.replace(pos, substr.length(), replacement);
diff --git a/src/tint/utils/unique_allocator.h b/src/tint/utils/unique_allocator.h
index 628bc79..1297701 100644
--- a/src/tint/utils/unique_allocator.h
+++ b/src/tint/utils/unique_allocator.h
@@ -39,7 +39,7 @@
// found in the set, then we create the persisted instance with the
// allocator.
TYPE key{args...};
- auto hash = HASH{}(key);
+ auto hash = Hasher{}(key);
auto it = items.find(Entry{hash, &key});
if (it != items.end()) {
return static_cast<TYPE*>(it->ptr);
@@ -50,6 +50,11 @@
}
protected:
+ /// The hash function
+ using Hasher = HASH;
+ /// The equality function
+ using Equality = EQUAL;
+
/// Entry is used as the entry to the unordered_set
struct Entry {
/// The pre-calculated hash of the entry
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 31f749f..e783b0b 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -64,6 +64,7 @@
#include "src/tint/transform/renamer.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/single_entry_point.h"
+#include "src/tint/transform/std140.h"
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/unwind_discard_functions.h"
#include "src/tint/transform/zero_init_workgroup_memory.h"
@@ -221,6 +222,7 @@
manager.Add<transform::CanonicalizeEntryPointIO>();
manager.Add<transform::ExpandCompoundAssignment>();
manager.Add<transform::PromoteSideEffectsToDecl>();
+ manager.Add<transform::Std140>(); // Must come after PromoteSideEffectsToDecl
manager.Add<transform::UnwindDiscardFunctions>();
manager.Add<transform::SimplifyPointers>();
diff --git a/src/tint/writer/glsl/generator_impl_binary_test.cc b/src/tint/writer/glsl/generator_impl_binary_test.cc
index 9113ae6..23e2ece 100644
--- a/src/tint/writer/glsl/generator_impl_binary_test.cc
+++ b/src/tint/writer/glsl/generator_impl_binary_test.cc
@@ -151,7 +151,8 @@
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
- auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
+ GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
+ auto* lhs = Expr("a");
auto* rhs = Expr(1_f);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -162,13 +163,14 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)");
+ EXPECT_EQ(out.str(), "(a * 1.0f)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
Enable(ast::Extension::kF16);
- auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
+ GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
+ auto* lhs = Expr("a");
auto* rhs = Expr(1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -179,12 +181,13 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)");
+ EXPECT_EQ(out.str(), "(a * 1.0hf)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
+ GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_f);
- auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
+ auto* rhs = Expr("a");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -194,14 +197,15 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))");
+ EXPECT_EQ(out.str(), "(1.0f * a)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
Enable(ast::Extension::kF16);
+ GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_h);
- auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
+ auto* rhs = Expr("a");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -211,7 +215,7 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))");
+ EXPECT_EQ(out.str(), "(1.0hf * a)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
diff --git a/src/tint/writer/hlsl/generator_impl_binary_test.cc b/src/tint/writer/hlsl/generator_impl_binary_test.cc
index aa54144..8a87697 100644
--- a/src/tint/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_binary_test.cc
@@ -184,7 +184,7 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)");
+ EXPECT_EQ(out.str(), "(1.0f).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
@@ -201,7 +201,7 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))");
+ EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
@@ -216,7 +216,7 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)");
+ EXPECT_EQ(out.str(), "(1.0f).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
@@ -233,7 +233,7 @@
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)");
+ EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
diff --git a/src/tint/writer/spirv/generator_impl.cc b/src/tint/writer/spirv/generator_impl.cc
index ace5209..d7a2b80 100644
--- a/src/tint/writer/spirv/generator_impl.cc
+++ b/src/tint/writer/spirv/generator_impl.cc
@@ -29,6 +29,7 @@
#include "src/tint/transform/remove_phonies.h"
#include "src/tint/transform/remove_unreachable_statements.h"
#include "src/tint/transform/simplify_pointers.h"
+#include "src/tint/transform/std140.h"
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/unwind_discard_functions.h"
#include "src/tint/transform/var_for_dynamic_index.h"
@@ -75,6 +76,7 @@
manager.Add<transform::RemoveUnreachableStatements>();
manager.Add<transform::ExpandCompoundAssignment>();
manager.Add<transform::PromoteSideEffectsToDecl>();
+ manager.Add<transform::Std140>(); // Must come after PromoteSideEffectsToDecl
manager.Add<transform::UnwindDiscardFunctions>();
manager.Add<transform::SimplifyPointers>(); // Required for arrayLength()
manager.Add<transform::RemovePhonies>();