[hlsl] Fix matrix ordering in HLSL mul operator.
When creating the HLSL `mul` operator, the parameters are swapped because
HLSL treats matrices as transposed. This CL updates the def file to swap
the matrix columns and rows in the `mul` overloads so that validating the
matrix operands will work correctly.
Bug: 42251045
Change-Id: I8afa0ed5cf295568a23ab8bffda618b0198d434d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/197496
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/hlsl.def b/src/tint/lang/hlsl/hlsl.def
index abfcd55..e557759 100644
--- a/src/tint/lang/hlsl/hlsl.def
+++ b/src/tint/lang/hlsl/hlsl.def
@@ -123,9 +123,10 @@
fn f16tof32(u32) -> f32
fn f16tof32[N: num](vec<N, u32>) -> vec<N, f32>
-fn mul [T: f32_f16, C: num, R: num](mat<C, R, T>, vec<C, T>) -> vec<R, T>
-fn mul [T: f32_f16, C: num, R: num](vec<R, T>, mat<C, R, T>) -> vec<C, T>
-fn mul [T: f32_f16, K: num, C: num, R: num](mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
+// Treat the HLSL matrices as transposed when used in `mul`
+fn mul [T: f32_f16, C: num, R: num](mat<R, C, T>, vec<C, T>) -> vec<R, T>
+fn mul [T: f32_f16, C: num, R: num](vec<R, T>, mat<R, C, T>) -> vec<C, T>
+fn mul [T: f32_f16, K: num, C: num, R: num](mat<R, K, T>, mat<K, C, T>) -> mat<R, C, T>
fn sign[T: fi32_f16](T) -> i32
fn sign[N: num, T: fi32_f16](vec<N, T>) -> vec<N, i32>
diff --git a/src/tint/lang/hlsl/intrinsic/data.cc b/src/tint/lang/hlsl/intrinsic/data.cc
index 90ec334..ff76180 100644
--- a/src/tint/lang/hlsl/intrinsic/data.cc
+++ b/src/tint/lang/hlsl/intrinsic/data.cc
@@ -1057,20 +1057,20 @@
constexpr MatcherIndex kMatcherIndices[] = {
/* [0] */ MatcherIndex(22),
- /* [1] */ MatcherIndex(1),
- /* [2] */ MatcherIndex(2),
+ /* [1] */ MatcherIndex(2),
+ /* [2] */ MatcherIndex(1),
/* [3] */ MatcherIndex(0),
/* [4] */ MatcherIndex(22),
- /* [5] */ MatcherIndex(2),
- /* [6] */ MatcherIndex(3),
+ /* [5] */ MatcherIndex(3),
+ /* [6] */ MatcherIndex(2),
/* [7] */ MatcherIndex(0),
/* [8] */ MatcherIndex(22),
- /* [9] */ MatcherIndex(1),
- /* [10] */ MatcherIndex(3),
+ /* [9] */ MatcherIndex(3),
+ /* [10] */ MatcherIndex(1),
/* [11] */ MatcherIndex(0),
/* [12] */ MatcherIndex(22),
- /* [13] */ MatcherIndex(2),
- /* [14] */ MatcherIndex(1),
+ /* [13] */ MatcherIndex(1),
+ /* [14] */ MatcherIndex(2),
/* [15] */ MatcherIndex(0),
/* [16] */ MatcherIndex(8),
/* [17] */ MatcherIndex(5),
@@ -2600,9 +2600,9 @@
},
{
/* [5] */
- /* fn mul[T : f32_f16, C : num, R : num](mat<C, R, T>, vec<C, T>) -> vec<R, T> */
- /* fn mul[T : f32_f16, C : num, R : num](vec<R, T>, mat<C, R, T>) -> vec<C, T> */
- /* fn mul[T : f32_f16, K : num, C : num, R : num](mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
+ /* fn mul[T : f32_f16, C : num, R : num](mat<R, C, T>, vec<C, T>) -> vec<R, T> */
+ /* fn mul[T : f32_f16, C : num, R : num](vec<R, T>, mat<R, C, T>) -> vec<C, T> */
+ /* fn mul[T : f32_f16, K : num, C : num, R : num](mat<R, K, T>, mat<K, C, T>) -> mat<R, C, T> */
/* num overloads */ 3,
/* overloads */ OverloadIndex(27),
},
diff --git a/src/tint/lang/hlsl/writer/binary_test.cc b/src/tint/lang/hlsl/writer/binary_test.cc
index 14d8b16..f029a9f 100644
--- a/src/tint/lang/hlsl/writer/binary_test.cc
+++ b/src/tint/lang/hlsl/writer/binary_test.cc
@@ -444,6 +444,54 @@
)");
}
+TEST_F(HlslWriterTest, BinaryMulVec4Mat3x4) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ func->SetWorkgroupSize(1, 1, 1);
+ b.Append(func->Block(), [&] {
+ auto* x = b.Var("x", b.Zero<vec4<f32>>());
+ auto* y = b.Var("y", b.Zero<mat3x4<f32>>());
+ auto* l = b.Load(x);
+ auto* r = b.Load(y);
+ b.Var("c", b.Multiply(ty.vec3<f32>(), l, r));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+[numthreads(1, 1, 1)]
+void foo() {
+ float4 x = (0.0f).xxxx;
+ float3x4 y = float3x4((0.0f).xxxx, (0.0f).xxxx, (0.0f).xxxx);
+ float3 c = mul(y, x);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BinaryMulMat3x2Vec3) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ func->SetWorkgroupSize(1, 1, 1);
+ b.Append(func->Block(), [&] {
+ auto* x = b.Var("x", b.Zero<mat3x2<f32>>());
+ auto* y = b.Var("y", b.Zero<vec3<f32>>());
+ auto* l = b.Load(x);
+ auto* r = b.Load(y);
+ b.Var("c", b.Multiply(ty.vec2<f32>(), l, r));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+[numthreads(1, 1, 1)]
+void foo() {
+ float3x2 x = float3x2((0.0f).xx, (0.0f).xx, (0.0f).xx);
+ float3 y = (0.0f).xxx;
+ float2 c = mul(y, x);
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, BinaryMulMatMat) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
func->SetWorkgroupSize(1, 1, 1);