[hlsl] Polyfill matrix multiplication.
This CL modifies the `mat * vec`, `vec * mat` and `mat * mat` binary
expressions to use the custom `hlsl.mul` intrinsic and flips the LHS and
RHS operands in order to match HLSL semantics.
Bug: 42251045
Change-Id: I2433d4c341c99276d90fd1ee765f79568698948a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196815
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/builtin_fn.cc b/src/tint/lang/hlsl/builtin_fn.cc
index 091e8d2..3bfa425 100644
--- a/src/tint/lang/hlsl/builtin_fn.cc
+++ b/src/tint/lang/hlsl/builtin_fn.cc
@@ -76,6 +76,8 @@
return "Store3";
case BuiltinFn::kStore4:
return "Store4";
+ case BuiltinFn::kMul:
+ return "mul";
}
return "<unknown>";
}
diff --git a/src/tint/lang/hlsl/builtin_fn.h b/src/tint/lang/hlsl/builtin_fn.h
index f9f0d3f..ab8b68a 100644
--- a/src/tint/lang/hlsl/builtin_fn.h
+++ b/src/tint/lang/hlsl/builtin_fn.h
@@ -64,6 +64,7 @@
kStore2,
kStore3,
kStore4,
+ kMul,
kNone,
};
diff --git a/src/tint/lang/hlsl/hlsl.def b/src/tint/lang/hlsl/hlsl.def
index 6adb7ca..b226e1d 100644
--- a/src/tint/lang/hlsl/hlsl.def
+++ b/src/tint/lang/hlsl/hlsl.def
@@ -51,6 +51,16 @@
type vec3<T>
type vec4<T>
@display("vec{N}<{T}>") type vec<N: num, T>
+type mat2x2<T>
+type mat2x3<T>
+type mat2x4<T>
+type mat3x2<T>
+type mat3x3<T>
+type mat3x4<T>
+type mat4x2<T>
+type mat4x3<T>
+type mat4x4<T>
+@display("mat{N}x{M}<{T}>") type mat<N: num, M: num, T>
type byte_address_buffer<A: access>
@@ -61,6 +71,7 @@
match iu32: i32 | u32
match f32_u32: f32 | u32
match f32_i32: f32 | i32
+match f32_f16: f32 | f16
match storage: address_space.storage
@@ -103,3 +114,7 @@
@member_function fn Store2(byte_address_buffer<writable>, offset: u32, value: vec2<u32>)
@member_function fn Store3(byte_address_buffer<writable>, offset: u32, value: vec3<u32>)
@member_function fn Store4(byte_address_buffer<writable>, offset: u32, value: vec4<u32>)
+
+fn mul [T: f32_f16, C: num, R: num](mat<C, R, T>, vec<C, T>) -> vec<R, T>
+fn mul [T: f32_f16, C: num, R: num](vec<R, T>, mat<C, R, T>) -> vec<C, T>
+fn mul [T: f32_f16, K: num, C: num, R: num](mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
diff --git a/src/tint/lang/hlsl/intrinsic/data.cc b/src/tint/lang/hlsl/intrinsic/data.cc
index 5780365..8db2ea0 100644
--- a/src/tint/lang/hlsl/intrinsic/data.cc
+++ b/src/tint/lang/hlsl/intrinsic/data.cc
@@ -249,6 +249,218 @@
};
+/// TypeMatcher for 'type mat2x2'
+constexpr TypeMatcher kMat2X2Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat2X2(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat2X2(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat2x2", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat2x3'
+constexpr TypeMatcher kMat2X3Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat2X3(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat2X3(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat2x3", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat2x4'
+constexpr TypeMatcher kMat2X4Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat2X4(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat2X4(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat2x4", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat3x2'
+constexpr TypeMatcher kMat3X2Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat3X2(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat3X2(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat3x2", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat3x3'
+constexpr TypeMatcher kMat3X3Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat3X3(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat3X3(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat3x3", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat3x4'
+constexpr TypeMatcher kMat3X4Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat3X4(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat3X4(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat3x4", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat4x2'
+constexpr TypeMatcher kMat4X2Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat4X2(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat4X2(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat4x2", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat4x3'
+constexpr TypeMatcher kMat4X3Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat4X3(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat4X3(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat4x3", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat4x4'
+constexpr TypeMatcher kMat4X4Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ const Type* T = nullptr;
+ if (!MatchMat4X4(state, ty, T)) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat4X4(state, ty, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat4x4", "<", T, ">");
+ }
+};
+
+
+/// TypeMatcher for 'type mat'
+constexpr TypeMatcher kMatMatcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ Number N = Number::invalid;
+ Number M = Number::invalid;
+ const Type* T = nullptr;
+ if (!MatchMat(state, ty, N, M, T)) {
+ return nullptr;
+ }
+ N = state.Num(N);
+ if (!N.IsValid()) {
+ return nullptr;
+ }
+ M = state.Num(M);
+ if (!M.IsValid()) {
+ return nullptr;
+ }
+ T = state.Type(T);
+ if (T == nullptr) {
+ return nullptr;
+ }
+ return BuildMat(state, ty, N, M, T);
+ },
+/* print */ []([[maybe_unused]] MatchState* state, StyledText& out) {StyledText N;
+ state->PrintNum(N);StyledText M;
+ state->PrintNum(M);StyledText T;
+ state->PrintType(T);
+ out << style::Type("mat", N, "x", M, "<", T, ">");
+ }
+};
+
+
/// TypeMatcher for 'type byte_address_buffer'
constexpr TypeMatcher kByteAddressBufferMatcher {
/* match */ [](MatchState& state, const Type* ty) -> const Type* {
@@ -320,6 +532,23 @@
kF32Matcher.print(nullptr, out); out << style::Plain(" or "); kI32Matcher.print(nullptr, out);}
};
+/// TypeMatcher for 'match f32_f16'
+constexpr TypeMatcher kF32F16Matcher {
+/* match */ [](MatchState& state, const Type* ty) -> const Type* {
+ if (MatchF32(state, ty)) {
+ return BuildF32(state, ty);
+ }
+ if (MatchF16(state, ty)) {
+ return BuildF16(state, ty);
+ }
+ return nullptr;
+ },
+/* print */ [](MatchState*, StyledText& out) {
+ // Note: We pass nullptr to the Matcher.print() functions, as matchers do not support
+ // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
+ kF32Matcher.print(nullptr, out); out << style::Plain(" or "); kF16Matcher.print(nullptr, out);}
+};
+
/// EnumMatcher for 'match storage'
constexpr NumberMatcher kStorageMatcher {
/* match */ [](MatchState&, Number number) -> Number {
@@ -371,68 +600,103 @@
constexpr TypeMatcher kTypeMatchers[] = {
/* [0] */ TemplateTypeMatcher<0>::matcher,
/* [1] */ TemplateTypeMatcher<1>::matcher,
- /* [2] */ kI32Matcher,
- /* [3] */ kU32Matcher,
- /* [4] */ kF32Matcher,
- /* [5] */ kF16Matcher,
- /* [6] */ kPtrMatcher,
- /* [7] */ kVec2Matcher,
- /* [8] */ kVec3Matcher,
- /* [9] */ kVec4Matcher,
- /* [10] */ kVecMatcher,
- /* [11] */ kByteAddressBufferMatcher,
- /* [12] */ kIu32Matcher,
- /* [13] */ kF32U32Matcher,
- /* [14] */ kF32I32Matcher,
+ /* [2] */ TemplateTypeMatcher<2>::matcher,
+ /* [3] */ TemplateTypeMatcher<3>::matcher,
+ /* [4] */ kI32Matcher,
+ /* [5] */ kU32Matcher,
+ /* [6] */ kF32Matcher,
+ /* [7] */ kF16Matcher,
+ /* [8] */ kPtrMatcher,
+ /* [9] */ kVec2Matcher,
+ /* [10] */ kVec3Matcher,
+ /* [11] */ kVec4Matcher,
+ /* [12] */ kVecMatcher,
+ /* [13] */ kMat2X2Matcher,
+ /* [14] */ kMat2X3Matcher,
+ /* [15] */ kMat2X4Matcher,
+ /* [16] */ kMat3X2Matcher,
+ /* [17] */ kMat3X3Matcher,
+ /* [18] */ kMat3X4Matcher,
+ /* [19] */ kMat4X2Matcher,
+ /* [20] */ kMat4X3Matcher,
+ /* [21] */ kMat4X4Matcher,
+ /* [22] */ kMatMatcher,
+ /* [23] */ kByteAddressBufferMatcher,
+ /* [24] */ kIu32Matcher,
+ /* [25] */ kF32U32Matcher,
+ /* [26] */ kF32I32Matcher,
+ /* [27] */ kF32F16Matcher,
};
/// The template numbers, and number matchers
constexpr NumberMatcher kNumberMatchers[] = {
/* [0] */ TemplateNumberMatcher<0>::matcher,
/* [1] */ TemplateNumberMatcher<1>::matcher,
- /* [2] */ kStorageMatcher,
- /* [3] */ kReadableMatcher,
- /* [4] */ kWritableMatcher,
+ /* [2] */ TemplateNumberMatcher<2>::matcher,
+ /* [3] */ TemplateNumberMatcher<3>::matcher,
+ /* [4] */ kStorageMatcher,
+ /* [5] */ kReadableMatcher,
+ /* [6] */ kWritableMatcher,
};
constexpr MatcherIndex kMatcherIndices[] = {
- /* [0] */ MatcherIndex(10),
+ /* [0] */ MatcherIndex(22),
/* [1] */ MatcherIndex(1),
/* [2] */ MatcherIndex(2),
- /* [3] */ MatcherIndex(10),
- /* [4] */ MatcherIndex(1),
- /* [5] */ MatcherIndex(0),
- /* [6] */ MatcherIndex(10),
- /* [7] */ MatcherIndex(1),
- /* [8] */ MatcherIndex(3),
- /* [9] */ MatcherIndex(10),
- /* [10] */ MatcherIndex(1),
- /* [11] */ MatcherIndex(4),
- /* [12] */ MatcherIndex(10),
- /* [13] */ MatcherIndex(0),
- /* [14] */ MatcherIndex(3),
- /* [15] */ MatcherIndex(10),
- /* [16] */ MatcherIndex(0),
- /* [17] */ MatcherIndex(4),
- /* [18] */ MatcherIndex(11),
- /* [19] */ MatcherIndex(3),
- /* [20] */ MatcherIndex(7),
- /* [21] */ MatcherIndex(3),
- /* [22] */ MatcherIndex(8),
- /* [23] */ MatcherIndex(3),
- /* [24] */ MatcherIndex(9),
- /* [25] */ MatcherIndex(3),
- /* [26] */ MatcherIndex(7),
- /* [27] */ MatcherIndex(5),
- /* [28] */ MatcherIndex(8),
- /* [29] */ MatcherIndex(5),
- /* [30] */ MatcherIndex(9),
- /* [31] */ MatcherIndex(5),
- /* [32] */ MatcherIndex(11),
- /* [33] */ MatcherIndex(4),
- /* [34] */ MatcherIndex(13),
- /* [35] */ MatcherIndex(14),
- /* [36] */ MatcherIndex(12),
+ /* [3] */ MatcherIndex(0),
+ /* [4] */ MatcherIndex(22),
+ /* [5] */ MatcherIndex(2),
+ /* [6] */ MatcherIndex(3),
+ /* [7] */ MatcherIndex(0),
+ /* [8] */ MatcherIndex(22),
+ /* [9] */ MatcherIndex(1),
+ /* [10] */ MatcherIndex(3),
+ /* [11] */ MatcherIndex(0),
+ /* [12] */ MatcherIndex(22),
+ /* [13] */ MatcherIndex(2),
+ /* [14] */ MatcherIndex(1),
+ /* [15] */ MatcherIndex(0),
+ /* [16] */ MatcherIndex(12),
+ /* [17] */ MatcherIndex(1),
+ /* [18] */ MatcherIndex(4),
+ /* [19] */ MatcherIndex(12),
+ /* [20] */ MatcherIndex(1),
+ /* [21] */ MatcherIndex(0),
+ /* [22] */ MatcherIndex(12),
+ /* [23] */ MatcherIndex(1),
+ /* [24] */ MatcherIndex(5),
+ /* [25] */ MatcherIndex(12),
+ /* [26] */ MatcherIndex(1),
+ /* [27] */ MatcherIndex(6),
+ /* [28] */ MatcherIndex(12),
+ /* [29] */ MatcherIndex(0),
+ /* [30] */ MatcherIndex(5),
+ /* [31] */ MatcherIndex(12),
+ /* [32] */ MatcherIndex(0),
+ /* [33] */ MatcherIndex(6),
+ /* [34] */ MatcherIndex(12),
+ /* [35] */ MatcherIndex(2),
+ /* [36] */ MatcherIndex(0),
+ /* [37] */ MatcherIndex(23),
+ /* [38] */ MatcherIndex(5),
+ /* [39] */ MatcherIndex(9),
+ /* [40] */ MatcherIndex(5),
+ /* [41] */ MatcherIndex(10),
+ /* [42] */ MatcherIndex(5),
+ /* [43] */ MatcherIndex(11),
+ /* [44] */ MatcherIndex(5),
+ /* [45] */ MatcherIndex(9),
+ /* [46] */ MatcherIndex(7),
+ /* [47] */ MatcherIndex(10),
+ /* [48] */ MatcherIndex(7),
+ /* [49] */ MatcherIndex(11),
+ /* [50] */ MatcherIndex(7),
+ /* [51] */ MatcherIndex(23),
+ /* [52] */ MatcherIndex(6),
+ /* [53] */ MatcherIndex(25),
+ /* [54] */ MatcherIndex(26),
+ /* [55] */ MatcherIndex(24),
+ /* [56] */ MatcherIndex(27),
};
static_assert(MatcherIndicesIndex::CanIndex(kMatcherIndices),
@@ -442,92 +706,92 @@
{
/* [0] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(32),
+ /* matcher_indices */ MatcherIndicesIndex(51),
},
{
/* [1] */
/* usage */ core::ParameterUsage::kOffset,
- /* matcher_indices */ MatcherIndicesIndex(8),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [2] */
/* usage */ core::ParameterUsage::kValue,
- /* matcher_indices */ MatcherIndicesIndex(8),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [3] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(32),
+ /* matcher_indices */ MatcherIndicesIndex(51),
},
{
/* [4] */
/* usage */ core::ParameterUsage::kOffset,
- /* matcher_indices */ MatcherIndicesIndex(8),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [5] */
/* usage */ core::ParameterUsage::kValue,
- /* matcher_indices */ MatcherIndicesIndex(20),
+ /* matcher_indices */ MatcherIndicesIndex(39),
},
{
/* [6] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(32),
+ /* matcher_indices */ MatcherIndicesIndex(51),
},
{
/* [7] */
/* usage */ core::ParameterUsage::kOffset,
- /* matcher_indices */ MatcherIndicesIndex(8),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [8] */
/* usage */ core::ParameterUsage::kValue,
- /* matcher_indices */ MatcherIndicesIndex(22),
+ /* matcher_indices */ MatcherIndicesIndex(41),
},
{
/* [9] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(32),
+ /* matcher_indices */ MatcherIndicesIndex(51),
},
{
/* [10] */
/* usage */ core::ParameterUsage::kOffset,
- /* matcher_indices */ MatcherIndicesIndex(8),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [11] */
/* usage */ core::ParameterUsage::kValue,
- /* matcher_indices */ MatcherIndicesIndex(24),
+ /* matcher_indices */ MatcherIndicesIndex(43),
},
{
/* [12] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(18),
+ /* matcher_indices */ MatcherIndicesIndex(37),
},
{
/* [13] */
/* usage */ core::ParameterUsage::kOffset,
- /* matcher_indices */ MatcherIndicesIndex(8),
+ /* matcher_indices */ MatcherIndicesIndex(24),
},
{
/* [14] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(5),
+ /* matcher_indices */ MatcherIndicesIndex(0),
},
{
/* [15] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(3),
+ /* matcher_indices */ MatcherIndicesIndex(19),
},
{
/* [16] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(11),
+ /* matcher_indices */ MatcherIndicesIndex(34),
},
{
/* [17] */
/* usage */ core::ParameterUsage::kNone,
- /* matcher_indices */ MatcherIndicesIndex(15),
+ /* matcher_indices */ MatcherIndicesIndex(0),
},
{
/* [18] */
@@ -539,6 +803,31 @@
/* usage */ core::ParameterUsage::kNone,
/* matcher_indices */ MatcherIndicesIndex(12),
},
+ {
+ /* [20] */
+ /* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(3),
+ },
+ {
+ /* [21] */
+ /* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(27),
+ },
+ {
+ /* [22] */
+ /* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(31),
+ },
+ {
+ /* [23] */
+ /* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(24),
+ },
+ {
+ /* [24] */
+ /* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(28),
+ },
};
static_assert(ParameterIndex::CanIndex(kParameters),
@@ -548,35 +837,77 @@
{
/* [0] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(34),
+ /* matcher_indices */ MatcherIndicesIndex(56),
/* kind */ TemplateInfo::Kind::kType,
},
{
/* [1] */
- /* name */ "N",
+ /* name */ "K",
/* matcher_indices */ MatcherIndicesIndex(/* invalid */),
/* kind */ TemplateInfo::Kind::kNumber,
},
{
/* [2] */
- /* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(35),
- /* kind */ TemplateInfo::Kind::kType,
+ /* name */ "C",
+ /* matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* kind */ TemplateInfo::Kind::kNumber,
},
{
/* [3] */
- /* name */ "N",
+ /* name */ "R",
/* matcher_indices */ MatcherIndicesIndex(/* invalid */),
/* kind */ TemplateInfo::Kind::kNumber,
},
{
/* [4] */
/* name */ "T",
- /* matcher_indices */ MatcherIndicesIndex(36),
+ /* matcher_indices */ MatcherIndicesIndex(56),
/* kind */ TemplateInfo::Kind::kType,
},
{
/* [5] */
+ /* name */ "C",
+ /* matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* kind */ TemplateInfo::Kind::kNumber,
+ },
+ {
+ /* [6] */
+ /* name */ "R",
+ /* matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* kind */ TemplateInfo::Kind::kNumber,
+ },
+ {
+ /* [7] */
+ /* name */ "T",
+ /* matcher_indices */ MatcherIndicesIndex(53),
+ /* kind */ TemplateInfo::Kind::kType,
+ },
+ {
+ /* [8] */
+ /* name */ "N",
+ /* matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* kind */ TemplateInfo::Kind::kNumber,
+ },
+ {
+ /* [9] */
+ /* name */ "T",
+ /* matcher_indices */ MatcherIndicesIndex(54),
+ /* kind */ TemplateInfo::Kind::kType,
+ },
+ {
+ /* [10] */
+ /* name */ "N",
+ /* matcher_indices */ MatcherIndicesIndex(/* invalid */),
+ /* kind */ TemplateInfo::Kind::kNumber,
+ },
+ {
+ /* [11] */
+ /* name */ "T",
+ /* matcher_indices */ MatcherIndicesIndex(55),
+ /* kind */ TemplateInfo::Kind::kType,
+ },
+ {
+ /* [12] */
/* name */ "N",
/* matcher_indices */ MatcherIndicesIndex(/* invalid */),
/* kind */ TemplateInfo::Kind::kNumber,
@@ -590,34 +921,34 @@
{
/* [0] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* num_parameters */ 1,
+ /* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(0),
+ /* num_templates */ 3,
+ /* templates */ TemplateIndex(4),
/* parameters */ ParameterIndex(14),
- /* return_matcher_indices */ MatcherIndicesIndex(2),
+ /* return_matcher_indices */ MatcherIndicesIndex(34),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [1] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* num_parameters */ 1,
+ /* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 2,
- /* templates */ TemplateIndex(0),
- /* parameters */ ParameterIndex(15),
- /* return_matcher_indices */ MatcherIndicesIndex(0),
+ /* num_templates */ 3,
+ /* templates */ TemplateIndex(4),
+ /* parameters */ ParameterIndex(16),
+ /* return_matcher_indices */ MatcherIndicesIndex(19),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [2] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* num_parameters */ 1,
+ /* num_parameters */ 2,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(2),
- /* parameters */ ParameterIndex(14),
- /* return_matcher_indices */ MatcherIndicesIndex(8),
+ /* num_templates */ 4,
+ /* templates */ TemplateIndex(0),
+ /* parameters */ ParameterIndex(18),
+ /* return_matcher_indices */ MatcherIndicesIndex(4),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -625,10 +956,10 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 2,
- /* templates */ TemplateIndex(2),
- /* parameters */ ParameterIndex(15),
- /* return_matcher_indices */ MatcherIndicesIndex(6),
+ /* num_templates */ 1,
+ /* templates */ TemplateIndex(7),
+ /* parameters */ ParameterIndex(20),
+ /* return_matcher_indices */ MatcherIndicesIndex(18),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -636,10 +967,10 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(4),
- /* parameters */ ParameterIndex(14),
- /* return_matcher_indices */ MatcherIndicesIndex(11),
+ /* num_templates */ 2,
+ /* templates */ TemplateIndex(7),
+ /* parameters */ ParameterIndex(15),
+ /* return_matcher_indices */ MatcherIndicesIndex(16),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -647,10 +978,10 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 2,
- /* templates */ TemplateIndex(4),
- /* parameters */ ParameterIndex(15),
- /* return_matcher_indices */ MatcherIndicesIndex(9),
+ /* num_templates */ 1,
+ /* templates */ TemplateIndex(9),
+ /* parameters */ ParameterIndex(20),
+ /* return_matcher_indices */ MatcherIndicesIndex(24),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -658,10 +989,10 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 0,
- /* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(16),
- /* return_matcher_indices */ MatcherIndicesIndex(8),
+ /* num_templates */ 2,
+ /* templates */ TemplateIndex(9),
+ /* parameters */ ParameterIndex(15),
+ /* return_matcher_indices */ MatcherIndicesIndex(22),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -670,9 +1001,9 @@
/* num_parameters */ 1,
/* num_explicit_templates */ 0,
/* num_templates */ 1,
- /* templates */ TemplateIndex(1),
- /* parameters */ ParameterIndex(17),
- /* return_matcher_indices */ MatcherIndicesIndex(12),
+ /* templates */ TemplateIndex(11),
+ /* parameters */ ParameterIndex(20),
+ /* return_matcher_indices */ MatcherIndicesIndex(27),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -680,10 +1011,10 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 0,
- /* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(18),
- /* return_matcher_indices */ MatcherIndicesIndex(11),
+ /* num_templates */ 2,
+ /* templates */ TemplateIndex(11),
+ /* parameters */ ParameterIndex(15),
+ /* return_matcher_indices */ MatcherIndicesIndex(25),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -691,43 +1022,43 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 1,
- /* templates */ TemplateIndex(1),
- /* parameters */ ParameterIndex(19),
- /* return_matcher_indices */ MatcherIndicesIndex(15),
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(21),
+ /* return_matcher_indices */ MatcherIndicesIndex(24),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [10] */
- /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
- /* num_parameters */ 2,
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
+ /* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 0,
- /* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(8),
+ /* num_templates */ 1,
+ /* templates */ TemplateIndex(8),
+ /* parameters */ ParameterIndex(22),
+ /* return_matcher_indices */ MatcherIndicesIndex(28),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [11] */
- /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
- /* num_parameters */ 2,
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
+ /* num_parameters */ 1,
/* num_explicit_templates */ 0,
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(20),
+ /* parameters */ ParameterIndex(23),
+ /* return_matcher_indices */ MatcherIndicesIndex(27),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [12] */
- /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
- /* num_parameters */ 2,
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
+ /* num_parameters */ 1,
/* num_explicit_templates */ 0,
- /* num_templates */ 0,
- /* templates */ TemplateIndex(/* invalid */),
- /* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(22),
+ /* num_templates */ 1,
+ /* templates */ TemplateIndex(8),
+ /* parameters */ ParameterIndex(24),
+ /* return_matcher_indices */ MatcherIndicesIndex(31),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -749,7 +1080,7 @@
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(27),
+ /* return_matcher_indices */ MatcherIndicesIndex(39),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -760,7 +1091,7 @@
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(26),
+ /* return_matcher_indices */ MatcherIndicesIndex(41),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -771,7 +1102,7 @@
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(28),
+ /* return_matcher_indices */ MatcherIndicesIndex(43),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -782,12 +1113,45 @@
/* num_templates */ 0,
/* templates */ TemplateIndex(/* invalid */),
/* parameters */ ParameterIndex(12),
- /* return_matcher_indices */ MatcherIndicesIndex(30),
+ /* return_matcher_indices */ MatcherIndicesIndex(46),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
/* [18] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 2,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(12),
+ /* return_matcher_indices */ MatcherIndicesIndex(45),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [19] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 2,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(12),
+ /* return_matcher_indices */ MatcherIndicesIndex(47),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [20] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
+ /* num_parameters */ 2,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(12),
+ /* return_matcher_indices */ MatcherIndicesIndex(49),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [21] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 3,
/* num_explicit_templates */ 0,
/* num_templates */ 0,
@@ -797,7 +1161,7 @@
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
- /* [19] */
+ /* [22] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 3,
/* num_explicit_templates */ 0,
@@ -808,7 +1172,7 @@
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
- /* [20] */
+ /* [23] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 3,
/* num_explicit_templates */ 0,
@@ -819,7 +1183,7 @@
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
- /* [21] */
+ /* [24] */
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMemberFunction),
/* num_parameters */ 3,
/* num_explicit_templates */ 0,
@@ -840,107 +1204,115 @@
/* fn asint[T : f32_u32](T) -> i32 */
/* fn asint[T : f32_u32, N : num](vec<N, T>) -> vec<N, i32> */
/* num overloads */ 2,
- /* overloads */ OverloadIndex(0),
+ /* overloads */ OverloadIndex(3),
},
{
/* [1] */
/* fn asuint[T : f32_i32](T) -> u32 */
/* fn asuint[T : f32_i32, N : num](vec<N, T>) -> vec<N, u32> */
/* num overloads */ 2,
- /* overloads */ OverloadIndex(2),
+ /* overloads */ OverloadIndex(5),
},
{
/* [2] */
/* fn asfloat[T : iu32](T) -> f32 */
/* fn asfloat[T : iu32, N : num](vec<N, T>) -> vec<N, f32> */
/* num overloads */ 2,
- /* overloads */ OverloadIndex(4),
+ /* overloads */ OverloadIndex(7),
},
{
/* [3] */
/* fn f32tof16(f32) -> u32 */
/* fn f32tof16[N : num](vec<N, f32>) -> vec<N, u32> */
/* num overloads */ 2,
- /* overloads */ OverloadIndex(6),
+ /* overloads */ OverloadIndex(9),
},
{
/* [4] */
/* fn f16tof32(u32) -> f32 */
/* fn f16tof32[N : num](vec<N, u32>) -> vec<N, f32> */
/* num overloads */ 2,
- /* overloads */ OverloadIndex(8),
+ /* overloads */ OverloadIndex(11),
},
{
/* [5] */
/* fn Load(byte_address_buffer<readable>, offset: u32) -> u32 */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(10),
+ /* overloads */ OverloadIndex(13),
},
{
/* [6] */
/* fn Load2(byte_address_buffer<readable>, offset: u32) -> vec2<u32> */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(11),
+ /* overloads */ OverloadIndex(14),
},
{
/* [7] */
/* fn Load3(byte_address_buffer<readable>, offset: u32) -> vec3<u32> */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(12),
+ /* overloads */ OverloadIndex(15),
},
{
/* [8] */
/* fn Load4(byte_address_buffer<readable>, offset: u32) -> vec4<u32> */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(13),
+ /* overloads */ OverloadIndex(16),
},
{
/* [9] */
/* fn LoadF16(byte_address_buffer<readable>, offset: u32) -> f16 */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(14),
+ /* overloads */ OverloadIndex(17),
},
{
/* [10] */
/* fn Load2F16(byte_address_buffer<readable>, offset: u32) -> vec2<f16> */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(15),
+ /* overloads */ OverloadIndex(18),
},
{
/* [11] */
/* fn Load3F16(byte_address_buffer<readable>, offset: u32) -> vec3<f16> */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(16),
+ /* overloads */ OverloadIndex(19),
},
{
/* [12] */
/* fn Load4F16(byte_address_buffer<readable>, offset: u32) -> vec4<f16> */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(17),
+ /* overloads */ OverloadIndex(20),
},
{
/* [13] */
/* fn Store(byte_address_buffer<writable>, offset: u32, value: u32) */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(18),
+ /* overloads */ OverloadIndex(21),
},
{
/* [14] */
/* fn Store2(byte_address_buffer<writable>, offset: u32, value: vec2<u32>) */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(19),
+ /* overloads */ OverloadIndex(22),
},
{
/* [15] */
/* fn Store3(byte_address_buffer<writable>, offset: u32, value: vec3<u32>) */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(20),
+ /* overloads */ OverloadIndex(23),
},
{
/* [16] */
/* fn Store4(byte_address_buffer<writable>, offset: u32, value: vec4<u32>) */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(21),
+ /* overloads */ OverloadIndex(24),
+ },
+ {
+ /* [17] */
+ /* fn mul[T : f32_f16, C : num, R : num](mat<C, R, T>, vec<C, T>) -> vec<R, T> */
+ /* fn mul[T : f32_f16, C : num, R : num](vec<R, T>, mat<C, R, T>) -> vec<C, T> */
+ /* fn mul[T : f32_f16, K : num, C : num, R : num](mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
+ /* num overloads */ 3,
+ /* overloads */ OverloadIndex(0),
},
};
diff --git a/src/tint/lang/hlsl/writer/binary_test.cc b/src/tint/lang/hlsl/writer/binary_test.cc
index a87b0fc..14d8b16 100644
--- a/src/tint/lang/hlsl/writer/binary_test.cc
+++ b/src/tint/lang/hlsl/writer/binary_test.cc
@@ -396,8 +396,7 @@
)");
}
-// TODO(dsinclair): Needs binary polyfill
-TEST_F(HlslWriterTest, DISABLED_BinaryMulMatVec) {
+TEST_F(HlslWriterTest, BinaryMulMatVec) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
func->SetWorkgroupSize(1, 1, 1);
b.Append(func->Block(), [&] {
@@ -421,8 +420,7 @@
)");
}
-// TODO(dsinclair): Needs binary polyfill
-TEST_F(HlslWriterTest, DISABLED_BinaryMulVecMat) {
+TEST_F(HlslWriterTest, BinaryMulVecMat) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
func->SetWorkgroupSize(1, 1, 1);
b.Append(func->Block(), [&] {
@@ -446,8 +444,7 @@
)");
}
-// TODO(dsinclair): Needs binary polyfill
-TEST_F(HlslWriterTest, DISABLED_BinaryMulMatMat) {
+TEST_F(HlslWriterTest, BinaryMulMatMat) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
func->SetWorkgroupSize(1, 1, 1);
b.Append(func->Block(), [&] {
@@ -465,7 +462,7 @@
void foo() {
float4x4 x = float4x4((0.0f).xxxx, (0.0f).xxxx, (0.0f).xxxx, (0.0f).xxxx);
float4x4 y = float4x4((0.0f).xxxx, (0.0f).xxxx, (0.0f).xxxx, (0.0f).xxxx);
- float4 c = mul(y, x);
+ float4x4 c = mul(y, x);
}
)");
diff --git a/src/tint/lang/hlsl/writer/raise/binary_polyfill.cc b/src/tint/lang/hlsl/writer/raise/binary_polyfill.cc
index 70c78fd..0e265d4 100644
--- a/src/tint/lang/hlsl/writer/raise/binary_polyfill.cc
+++ b/src/tint/lang/hlsl/writer/raise/binary_polyfill.cc
@@ -35,6 +35,7 @@
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/manager.h"
+#include "src/tint/lang/hlsl/ir/builtin_call.h"
namespace tint::hlsl::writer::raise {
namespace {
@@ -66,6 +67,21 @@
}
break;
}
+ case core::BinaryOp::kMultiply: {
+ auto* lhs_ty = binary->LHS()->Type();
+ auto* rhs_ty = binary->RHS()->Type();
+
+ if ((lhs_ty->Is<core::type::Vector>() &&
+ rhs_ty->Is<core::type::Matrix>()) ||
+ (lhs_ty->Is<core::type::Matrix>() &&
+ rhs_ty->Is<core::type::Vector>()) ||
+ (lhs_ty->Is<core::type::Matrix>() &&
+ rhs_ty->Is<core::type::Matrix>())) {
+ binary_worklist.Push(binary);
+ }
+ break;
+ }
+
default:
break;
}
@@ -79,12 +95,26 @@
case core::BinaryOp::kModulo:
PreciseFloatMod(binary);
break;
+ case core::BinaryOp::kMultiply:
+ Mul(binary);
+ break;
default:
TINT_UNIMPLEMENTED();
}
}
}
+ // Multiplying by a matrix requires the use of `mul` in order to get the
+ // type of multiply we desire.
+ //
+ // Matrices are transposed, so swap LHS and RHS.
+ void Mul(core::ir::Binary* binary) {
+ auto* call = b.CallWithResult<hlsl::ir::BuiltinCall>(
+ binary->DetachResult(), hlsl::BuiltinFn::kMul, binary->RHS(), binary->LHS());
+ call->InsertBefore(binary);
+ binary->Destroy();
+ }
+
// Replace with:
//
// (lhs - (trunc(lhs / rhs)) * rhs)
diff --git a/src/tint/lang/hlsl/writer/raise/binary_polyfill_test.cc b/src/tint/lang/hlsl/writer/raise/binary_polyfill_test.cc
index 7c3bbfc..6b3381f 100644
--- a/src/tint/lang/hlsl/writer/raise/binary_polyfill_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/binary_polyfill_test.cc
@@ -196,5 +196,215 @@
EXPECT_EQ(expect, str());
}
+TEST_F(HlslWriter_BinaryPolyfillTest, MulVecMatF32) {
+ auto* x = b.FunctionParam<vec3<f32>>("x");
+ auto* y = b.FunctionParam<mat3x3<f32>>("y");
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({x, y});
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Multiply(ty.vec3<f32>(), x, y));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%x:vec3<f32>, %y:mat3x3<f32>):void {
+ $B1: {
+ %4:vec3<f32> = mul %x, %y
+ %a:vec3<f32> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%x:vec3<f32>, %y:mat3x3<f32>):void {
+ $B1: {
+ %4:vec3<f32> = hlsl.mul %y, %x
+ %a:vec3<f32> = let %4
+ ret
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriter_BinaryPolyfillTest, MulVecMatF16) {
+ auto* x = b.FunctionParam<vec3<f16>>("x");
+ auto* y = b.FunctionParam<mat3x3<f16>>("y");
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({x, y});
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Multiply(ty.vec3<f16>(), x, y));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%x:vec3<f16>, %y:mat3x3<f16>):void {
+ $B1: {
+ %4:vec3<f16> = mul %x, %y
+ %a:vec3<f16> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%x:vec3<f16>, %y:mat3x3<f16>):void {
+ $B1: {
+ %4:vec3<f16> = hlsl.mul %y, %x
+ %a:vec3<f16> = let %4
+ ret
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriter_BinaryPolyfillTest, MulMatVecF32) {
+ auto* x = b.FunctionParam<mat3x3<f32>>("x");
+ auto* y = b.FunctionParam<vec3<f32>>("y");
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({x, y});
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Multiply(ty.vec3<f32>(), x, y));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%x:mat3x3<f32>, %y:vec3<f32>):void {
+ $B1: {
+ %4:vec3<f32> = mul %x, %y
+ %a:vec3<f32> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%x:mat3x3<f32>, %y:vec3<f32>):void {
+ $B1: {
+ %4:vec3<f32> = hlsl.mul %y, %x
+ %a:vec3<f32> = let %4
+ ret
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriter_BinaryPolyfillTest, MulMatVecF16) {
+ auto* x = b.FunctionParam<mat3x3<f16>>("x");
+ auto* y = b.FunctionParam<vec3<f16>>("y");
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({x, y});
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Multiply(ty.vec3<f16>(), x, y));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%x:mat3x3<f16>, %y:vec3<f16>):void {
+ $B1: {
+ %4:vec3<f16> = mul %x, %y
+ %a:vec3<f16> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%x:mat3x3<f16>, %y:vec3<f16>):void {
+ $B1: {
+ %4:vec3<f16> = hlsl.mul %y, %x
+ %a:vec3<f16> = let %4
+ ret
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriter_BinaryPolyfillTest, MulMatMat32) {
+ auto* x = b.FunctionParam<mat3x3<f32>>("x");
+ auto* y = b.FunctionParam<mat3x3<f32>>("y");
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({x, y});
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Multiply(ty.mat3x3<f32>(), x, y));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%x:mat3x3<f32>, %y:mat3x3<f32>):void {
+ $B1: {
+ %4:mat3x3<f32> = mul %x, %y
+ %a:mat3x3<f32> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%x:mat3x3<f32>, %y:mat3x3<f32>):void {
+ $B1: {
+ %4:mat3x3<f32> = hlsl.mul %y, %x
+ %a:mat3x3<f32> = let %4
+ ret
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(HlslWriter_BinaryPolyfillTest, MulMatMat16) {
+ auto* x = b.FunctionParam<mat3x3<f16>>("x");
+ auto* y = b.FunctionParam<mat3x3<f16>>("y");
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({x, y});
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Multiply(ty.mat3x3<f16>(), x, y));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%x:mat3x3<f16>, %y:mat3x3<f16>):void {
+ $B1: {
+ %4:mat3x3<f16> = mul %x, %y
+ %a:mat3x3<f16> = let %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%x:mat3x3<f16>, %y:mat3x3<f16>):void {
+ $B1: {
+ %4:mat3x3<f16> = hlsl.mul %y, %x
+ %a:mat3x3<f16> = let %4
+ ret
+ }
+}
+)";
+
+ Run(BinaryPolyfill);
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::hlsl::writer::raise