tint: Add matrix identity and single-scalar ctors
Fixed: tint:1545
Change-Id: I86451223765f620861bf98861142e6d34c7e945b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90502
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index b7020f6..9a66118 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -66,6 +66,7 @@
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/unwind_discard_functions.h"
+#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
#include "src/tint/transform/zero_init_workgroup_memory.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/map.h"
@@ -198,6 +199,7 @@
manager.Add<transform::ExpandCompoundAssignment>();
manager.Add<transform::PromoteSideEffectsToDecl>();
manager.Add<transform::UnwindDiscardFunctions>();
+ manager.Add<transform::VectorizeScalarMatrixConstructors>();
manager.Add<transform::SimplifyPointers>();
manager.Add<transform::RemovePhonies>();
// ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
@@ -1080,6 +1082,55 @@
return EmitZeroValue(out, type);
}
+ if (auto* mat = call->Type()->As<sem::Matrix>()) {
+ if (ctor->Parameters().size() == 1) {
+ // Matrix constructor with single scalar.
+ auto fn = utils::GetOrCreate(matrix_scalar_ctors_, mat, [&]() -> std::string {
+ TextBuffer b;
+ TINT_DEFER(helpers_.Append(b));
+
+ auto name = UniqueIdentifier("build_mat" + std::to_string(mat->columns()) + "x" +
+ std::to_string(mat->rows()));
+ {
+ auto l = line(&b);
+ if (!EmitType(l, mat, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
+ return "";
+ }
+ l << " " << name << "(";
+ if (!EmitType(l, mat->type(), ast::StorageClass::kNone, ast::Access::kUndefined,
+ "")) {
+ return "";
+ }
+ l << " value) {";
+ }
+ {
+ ScopedIndent si(&b);
+ auto l = line(&b);
+ l << "return ";
+ if (!EmitType(l, mat, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
+ return "";
+ }
+ l << "(";
+ for (uint32_t i = 0; i < mat->columns() * mat->rows(); i++) {
+ l << ((i > 0) ? ", value" : "value");
+ }
+ l << ");";
+ }
+ line(&b) << "}";
+ return name;
+ });
+ if (fn.empty()) {
+ return false;
+ }
+ out << fn << "(";
+ if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
+ return false;
+ }
+ out << ")";
+ return true;
+ }
+ }
+
bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
// For single-value vector initializers, swizzle the scalar to the right