Implement addition and subtraction of float matrices
Bug: tint:316
Change-Id: I3a1082c41c47daacf0220d029cb2a5f118684959
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52580
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index baceffa..99dc700 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2124,13 +2124,20 @@
}
// Matrix arithmetic
- // TODO(amaiorano): matrix-matrix addition and subtraction
+ auto* lhs_mat = lhs_type->As<Matrix>();
+ auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
+ auto* rhs_mat = rhs_type->As<Matrix>();
+ auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
+ // Addition and subtraction of float matrices
+ if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
+ lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
+ rhs_mat_elem_type->Is<F32>() &&
+ (lhs_mat->columns() == rhs_mat->columns()) &&
+ (lhs_mat->rows() == rhs_mat->rows())) {
+ SetType(expr, rhs_type);
+ return true;
+ }
if (expr->IsMultiply()) {
- auto* lhs_mat = lhs_type->As<Matrix>();
- auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
- auto* rhs_mat = rhs_type->As<Matrix>();
- auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
-
// Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) {
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 06b6ce1..7406204 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1238,9 +1238,12 @@
};
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
- ast_bool, ast_u32, ast_i32, ast_f32,
- ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
- ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>};
+ ast_bool, ast_u32, ast_i32, ast_f32,
+ ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
+ ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>, //
+ ast_mat2x3<i32>, ast_mat2x3<u32>, ast_mat2x3<f32>, //
+ ast_mat3x2<i32>, ast_mat3x2<u32>, ast_mat3x2<f32> //
+};
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
// matNxN, we only test N=3.
@@ -1338,14 +1341,43 @@
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
// Matrix arithmetic
+ Params{Op::kMultiply, ast_mat2x3<f32>, ast_f32, sem_mat2x3<sem_f32>},
+ Params{Op::kMultiply, ast_mat3x2<f32>, ast_f32, sem_mat3x2<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
+
+ Params{Op::kMultiply, ast_f32, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
+ Params{Op::kMultiply, ast_f32, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
+ Params{Op::kMultiply, ast_vec3<f32>, ast_mat2x3<f32>, sem_vec2<sem_f32>},
+ Params{Op::kMultiply, ast_vec2<f32>, ast_mat3x2<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
+
+ Params{Op::kMultiply, ast_mat3x2<f32>, ast_vec3<f32>, sem_vec2<sem_f32>},
+ Params{Op::kMultiply, ast_mat2x3<f32>, ast_vec2<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
- // TODO(amaiorano): add mat+mat and mat-mat
+
+ Params{Op::kMultiply, ast_mat2x3<f32>, ast_mat3x2<f32>,
+ sem_mat3x3<sem_f32>},
+ Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat2x3<f32>,
+ sem_mat2x2<sem_f32>},
+ Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat3x3<f32>,
+ sem_mat3x2<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>},
+ Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat2x3<f32>,
+ sem_mat2x3<sem_f32>},
+
+ Params{Op::kAdd, ast_mat2x3<f32>, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
+ Params{Op::kAdd, ast_mat3x2<f32>, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
+ Params{Op::kAdd, ast_mat3x3<f32>, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
+
+ Params{Op::kSubtract, ast_mat2x3<f32>, ast_mat2x3<f32>,
+ sem_mat2x3<sem_f32>},
+ Params{Op::kSubtract, ast_mat3x2<f32>, ast_mat3x2<f32>,
+ sem_mat3x2<sem_f32>},
+ Params{Op::kSubtract, ast_mat3x3<f32>, ast_mat3x3<f32>,
+ sem_mat3x3<sem_f32>},
// Comparison expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index b82aec0..4a10110 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -177,6 +177,26 @@
}
template <typename T>
+ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
+ return ty.mat2x3<T>();
+}
+
+template <create_ast_type_func_ptr create_type>
+ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
+ return ty.mat2x3(create_type(ty));
+}
+
+template <typename T>
+ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
+ return ty.mat3x2<T>();
+}
+
+template <create_ast_type_func_ptr create_type>
+ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
+ return ty.mat3x2(create_type(ty));
+}
+
+template <typename T>
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>();
}
@@ -250,6 +270,18 @@
}
template <create_sem_type_func_ptr create_type>
+sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
+ auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
+ return ty.builder->create<sem::Matrix>(column_type, 2u);
+}
+
+template <create_sem_type_func_ptr create_type>
+sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
+ auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u);
+ return ty.builder->create<sem::Matrix>(column_type, 3u);
+}
+
+template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
return ty.builder->create<sem::Matrix>(column_type, 3u);
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index dd8de52..79e551c 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1748,6 +1748,67 @@
return splat_result.to_i();
}
+uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
+ uint32_t rhs_id,
+ const sem::Matrix* type,
+ spv::Op op) {
+ // Example addition of two matrices:
+ // %31 = OpLoad %mat3v4float %m34
+ // %32 = OpLoad %mat3v4float %m34
+ // %33 = OpCompositeExtract %v4float %31 0
+ // %34 = OpCompositeExtract %v4float %32 0
+ // %35 = OpFAdd %v4float %33 %34
+ // %36 = OpCompositeExtract %v4float %31 1
+ // %37 = OpCompositeExtract %v4float %32 1
+ // %38 = OpFAdd %v4float %36 %37
+ // %39 = OpCompositeExtract %v4float %31 2
+ // %40 = OpCompositeExtract %v4float %32 2
+ // %41 = OpFAdd %v4float %39 %40
+ // %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
+
+ auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
+ auto column_type_id = GenerateTypeIfNeeded(column_type);
+
+ OperandList ops;
+
+ for (uint32_t i = 0; i < type->columns(); ++i) {
+ // Extract column `i` from lhs mat
+ auto lhs_column_id = result_op();
+ if (!push_function_inst(spv::Op::OpCompositeExtract,
+ {Operand::Int(column_type_id), lhs_column_id,
+ Operand::Int(lhs_id), Operand::Int(i)})) {
+ return 0;
+ }
+
+ // Extract column `i` from rhs mat
+ auto rhs_column_id = result_op();
+ if (!push_function_inst(spv::Op::OpCompositeExtract,
+ {Operand::Int(column_type_id), rhs_column_id,
+ Operand::Int(rhs_id), Operand::Int(i)})) {
+ return 0;
+ }
+
+ // Add or subtract the two columns
+ auto result = result_op();
+ if (!push_function_inst(op, {Operand::Int(column_type_id), result,
+ lhs_column_id, rhs_column_id})) {
+ return 0;
+ }
+
+ ops.push_back(result);
+ }
+
+ // Create the result matrix from the added/subtracted column vectors
+ auto result_mat_id = result_op();
+ ops.insert(ops.begin(), result_mat_id);
+ ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type)));
+ if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
+ return 0;
+ }
+
+ return result_mat_id.to_i();
+}
+
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// There is special logic for short circuiting operators.
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
@@ -1779,6 +1840,24 @@
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
+ // Handle matrix-matrix addition and subtraction
+ if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
+ rhs_type->is_float_matrix()) {
+ auto* lhs_mat = lhs_type->As<sem::Matrix>();
+ auto* rhs_mat = rhs_type->As<sem::Matrix>();
+
+ // This should already have been validated by resolver
+ if (lhs_mat->rows() != rhs_mat->rows() ||
+ lhs_mat->columns() != rhs_mat->columns()) {
+ error_ = "matrices must have same dimensionality for add or subtract";
+ return 0;
+ }
+
+ return GenerateMatrixAddOrSub(
+ lhs_id, rhs_id, lhs_mat,
+ expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
+ }
+
// For vector-scalar arithmetic operations, splat scalar into a vector. We
// skip this for multiply as we can use OpVectorTimesScalar.
const bool is_float_scalar_vector_multiply =
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 3dac6d2..1a7271d 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -482,6 +482,17 @@
/// @returns id of the new vector
uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
+ /// Generates instructions to add or subtract two matrices
+ /// @param lhs_id id of multiplicand
+ /// @param rhs_id id of multiplier
+ /// @param type type of both matrices and of result
+ /// @param op one of `spv::Op::OpFAdd` or `spv::Op::OpFSub`
+ /// @returns id of the result matrix
+ uint32_t GenerateMatrixAddOrSub(uint32_t lhs_id,
+ uint32_t rhs_id,
+ const sem::Matrix* type,
+ spv::Op op);
+
/// Converts AST image format to SPIR-V and pushes an appropriate capability.
/// @param format AST image format type
/// @returns SPIR-V image format type
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index e15d4ec..3b8391b 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -863,8 +863,10 @@
)");
}
+namespace BinaryArithVectorScalar {
+
enum class Type { f32, i32, u32 };
-ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
+static ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
switch (type) {
case Type::f32:
return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
@@ -875,7 +877,7 @@
}
return nullptr;
}
-ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
+static ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
switch (type) {
case Type::f32:
return builder->Expr(1.f);
@@ -886,7 +888,7 @@
}
return nullptr;
}
-std::string OpTypeDecl(Type type) {
+static std::string OpTypeDecl(Type type) {
switch (type) {
case Type::f32:
return "OpTypeFloat 32";
@@ -904,8 +906,8 @@
std::string name;
};
-using MixedBinaryArithTest = TestParamHelper<Param>;
-TEST_P(MixedBinaryArithTest, VectorScalar) {
+using BinaryArithVectorScalarTest = TestParamHelper<Param>;
+TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
auto& param = GetParam();
ast::Expression* lhs = MakeVectorExpr(this, param.type);
@@ -943,7 +945,7 @@
Validate(b);
}
-TEST_P(MixedBinaryArithTest, ScalarVector) {
+TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
auto& param = GetParam();
ast::Expression* lhs = MakeScalarExpr(this, param.type);
@@ -983,7 +985,7 @@
}
INSTANTIATE_TEST_SUITE_P(
BuilderTest,
- MixedBinaryArithTest,
+ BinaryArithVectorScalarTest,
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
// NOTE: Modulo not allowed on mixed float scalar-vector
@@ -1005,8 +1007,8 @@
Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
-using MixedBinaryArithMultiplyTest = TestParamHelper<Param>;
-TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) {
+using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
+TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
auto& param = GetParam();
ast::Expression* lhs = MakeVectorExpr(this, param.type);
@@ -1040,7 +1042,7 @@
Validate(b);
}
-TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) {
+TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
auto& param = GetParam();
ast::Expression* lhs = MakeScalarExpr(this, param.type);
@@ -1075,10 +1077,113 @@
Validate(b);
}
INSTANTIATE_TEST_SUITE_P(BuilderTest,
- MixedBinaryArithMultiplyTest,
+ BinaryArithVectorScalarMultiplyTest,
testing::Values(Param{
Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
+} // namespace BinaryArithVectorScalar
+
+namespace BinaryArithMatrixMatrix {
+
+struct Param {
+ ast::BinaryOp op;
+ std::string name;
+};
+
+using BinaryArithMatrixMatrix = TestParamHelper<Param>;
+TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
+ auto& param = GetParam();
+
+ ast::Expression* lhs = mat3x4<f32>();
+ ast::Expression* rhs = mat3x4<f32>();
+
+ auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+ WrapInFunction(expr);
+
+ spirv::Builder& b = Build();
+ ASSERT_TRUE(b.Build()) << b.error();
+
+ EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeFloat 32
+%6 = OpTypeVector %7 4
+%5 = OpTypeMatrix %6 3
+%8 = OpConstantNull %5
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%10 = OpCompositeExtract %6 %8 0
+%11 = OpCompositeExtract %6 %8 0
+%12 = )" + param.name + R"( %6 %10 %11
+%13 = OpCompositeExtract %6 %8 1
+%14 = OpCompositeExtract %6 %8 1
+%15 = )" + param.name + R"( %6 %13 %14
+%16 = OpCompositeExtract %6 %8 2
+%17 = OpCompositeExtract %6 %8 2
+%18 = )" + param.name + R"( %6 %16 %17
+%19 = OpCompositeConstruct %5 %12 %15 %18
+OpReturn
+OpFunctionEnd
+)");
+
+ Validate(b);
+}
+INSTANTIATE_TEST_SUITE_P( //
+ BuilderTest,
+ BinaryArithMatrixMatrix,
+ testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
+ Param{ast::BinaryOp::kSubtract, "OpFSub"}));
+
+using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
+TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
+ auto& param = GetParam();
+
+ ast::Expression* lhs = mat3x4<f32>();
+ ast::Expression* rhs = mat4x3<f32>();
+
+ auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+ WrapInFunction(expr);
+
+ spirv::Builder& b = Build();
+ ASSERT_TRUE(b.Build()) << b.error();
+
+ EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeFloat 32
+%6 = OpTypeVector %7 4
+%5 = OpTypeMatrix %6 3
+%8 = OpConstantNull %5
+%10 = OpTypeVector %7 3
+%9 = OpTypeMatrix %10 4
+%11 = OpConstantNull %9
+%13 = OpTypeMatrix %6 4
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%12 = OpMatrixTimesMatrix %13 %8 %11
+OpReturn
+OpFunctionEnd
+)");
+
+ Validate(b);
+}
+INSTANTIATE_TEST_SUITE_P( //
+ BuilderTest,
+ BinaryArithMatrixMatrixMultiply,
+ testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
+
+} // namespace BinaryArithMatrixMatrix
+
} // namespace
} // namespace spirv
} // namespace writer
diff --git a/test/expressions/binary_expressions.wgsl b/test/expressions/binary_expressions.wgsl
index a58eece..b7444a0 100644
--- a/test/expressions/binary_expressions.wgsl
+++ b/test/expressions/binary_expressions.wgsl
@@ -64,6 +64,19 @@
r = s % v;
}
+fn matrix_matrix_f32() {
+ var m34 : mat3x4<f32>;
+ var m43 : mat4x3<f32>;
+ var m33 : mat3x3<f32>;
+ var m44 : mat4x4<f32>;
+
+ m34 = m34 + m34;
+ m34 = m34 - m34;
+
+ m33 = m43 * m34;
+ m44 = m34 * m43;
+}
+
[[stage(fragment)]]
fn main() -> [[location(0)]] vec4<f32> {
return vec4<f32>(0.0,0.0,0.0,0.0);
diff --git a/test/expressions/binary_expressions.wgsl.expected.hlsl b/test/expressions/binary_expressions.wgsl.expected.hlsl
index 01f2d2b..c234983 100644
--- a/test/expressions/binary_expressions.wgsl.expected.hlsl
+++ b/test/expressions/binary_expressions.wgsl.expected.hlsl
@@ -66,6 +66,17 @@
r = (s % v);
}
+void matrix_matrix_f32() {
+ float3x4 m34 = float3x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+ float4x3 m43 = float4x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+ float3x3 m33 = float3x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+ float4x4 m44 = float4x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+ m34 = (m34 + m34);
+ m34 = (m34 - m34);
+ m33 = mul(m34, m43);
+ m44 = mul(m43, m34);
+}
+
tint_symbol main() {
const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)};
return tint_symbol_1;
diff --git a/test/expressions/binary_expressions.wgsl.expected.msl b/test/expressions/binary_expressions.wgsl.expected.msl
index 5630940..e46aa45 100644
--- a/test/expressions/binary_expressions.wgsl.expected.msl
+++ b/test/expressions/binary_expressions.wgsl.expected.msl
@@ -69,6 +69,17 @@
r = (s % v);
}
+void matrix_matrix_f32() {
+ float3x4 m34 = float3x4(0.0f);
+ float4x3 m43 = float4x3(0.0f);
+ float3x3 m33 = float3x3(0.0f);
+ float4x4 m44 = float4x4(0.0f);
+ m34 = (m34 + m34);
+ m34 = (m34 - m34);
+ m33 = (m43 * m34);
+ m44 = (m34 * m43);
+}
+
fragment tint_symbol_1 tint_symbol() {
return {float4(0.0f, 0.0f, 0.0f, 0.0f)};
}
diff --git a/test/expressions/binary_expressions.wgsl.expected.spvasm b/test/expressions/binary_expressions.wgsl.expected.spvasm
index 445bb49..8052168 100644
--- a/test/expressions/binary_expressions.wgsl.expected.spvasm
+++ b/test/expressions/binary_expressions.wgsl.expected.spvasm
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
-; Bound: 200
+; Bound: 250
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
@@ -32,6 +32,11 @@
OpName %v_4 "v"
OpName %s_4 "s"
OpName %r_4 "r"
+ OpName %matrix_matrix_f32 "matrix_matrix_f32"
+ OpName %m34 "m34"
+ OpName %m43 "m43"
+ OpName %m33 "m33"
+ OpName %m44 "m44"
OpName %tint_symbol_2 "tint_symbol_2"
OpName %tint_symbol "tint_symbol"
OpName %main "main"
@@ -60,9 +65,21 @@
%78 = OpConstantNull %v3uint
%_ptr_Function_uint = OpTypePointer Function %uint
%81 = OpConstantNull %uint
- %191 = OpTypeFunction %void %v4float
+%mat3v4float = OpTypeMatrix %v4float 3
+%_ptr_Function_mat3v4float = OpTypePointer Function %mat3v4float
+ %196 = OpConstantNull %mat3v4float
+%mat4v3float = OpTypeMatrix %v3float 4
+%_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float
+ %200 = OpConstantNull %mat4v3float
+%mat3v3float = OpTypeMatrix %v3float 3
+%_ptr_Function_mat3v3float = OpTypePointer Function %mat3v3float
+ %204 = OpConstantNull %mat3v3float
+%mat4v4float = OpTypeMatrix %v4float 4
+%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float
+ %208 = OpConstantNull %mat4v4float
+ %241 = OpTypeFunction %void %v4float
%float_0 = OpConstant %float 0
- %199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+ %249 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%vector_scalar_f32 = OpFunction %void None %6
%9 = OpLabel
%v = OpVariable %_ptr_Function_v3float Function %13
@@ -269,14 +286,56 @@
OpStore %r_4 %188
OpReturn
OpFunctionEnd
-%tint_symbol_2 = OpFunction %void None %191
+%matrix_matrix_f32 = OpFunction %void None %6
+ %192 = OpLabel
+ %m34 = OpVariable %_ptr_Function_mat3v4float Function %196
+ %m43 = OpVariable %_ptr_Function_mat4v3float Function %200
+ %m33 = OpVariable %_ptr_Function_mat3v3float Function %204
+ %m44 = OpVariable %_ptr_Function_mat4v4float Function %208
+ %209 = OpLoad %mat3v4float %m34
+ %210 = OpLoad %mat3v4float %m34
+ %212 = OpCompositeExtract %v4float %209 0
+ %213 = OpCompositeExtract %v4float %210 0
+ %214 = OpFAdd %v4float %212 %213
+ %215 = OpCompositeExtract %v4float %209 1
+ %216 = OpCompositeExtract %v4float %210 1
+ %217 = OpFAdd %v4float %215 %216
+ %218 = OpCompositeExtract %v4float %209 2
+ %219 = OpCompositeExtract %v4float %210 2
+ %220 = OpFAdd %v4float %218 %219
+ %221 = OpCompositeConstruct %mat3v4float %214 %217 %220
+ OpStore %m34 %221
+ %222 = OpLoad %mat3v4float %m34
+ %223 = OpLoad %mat3v4float %m34
+ %225 = OpCompositeExtract %v4float %222 0
+ %226 = OpCompositeExtract %v4float %223 0
+ %227 = OpFSub %v4float %225 %226
+ %228 = OpCompositeExtract %v4float %222 1
+ %229 = OpCompositeExtract %v4float %223 1
+ %230 = OpFSub %v4float %228 %229
+ %231 = OpCompositeExtract %v4float %222 2
+ %232 = OpCompositeExtract %v4float %223 2
+ %233 = OpFSub %v4float %231 %232
+ %234 = OpCompositeConstruct %mat3v4float %227 %230 %233
+ OpStore %m34 %234
+ %235 = OpLoad %mat4v3float %m43
+ %236 = OpLoad %mat3v4float %m34
+ %237 = OpMatrixTimesMatrix %mat3v3float %235 %236
+ OpStore %m33 %237
+ %238 = OpLoad %mat3v4float %m34
+ %239 = OpLoad %mat4v3float %m43
+ %240 = OpMatrixTimesMatrix %mat4v4float %238 %239
+ OpStore %m44 %240
+ OpReturn
+ OpFunctionEnd
+%tint_symbol_2 = OpFunction %void None %241
%tint_symbol = OpFunctionParameter %v4float
- %194 = OpLabel
+ %244 = OpLabel
OpStore %tint_symbol_1 %tint_symbol
OpReturn
OpFunctionEnd
%main = OpFunction %void None %6
- %196 = OpLabel
- %197 = OpFunctionCall %void %tint_symbol_2 %199
+ %246 = OpLabel
+ %247 = OpFunctionCall %void %tint_symbol_2 %249
OpReturn
OpFunctionEnd
diff --git a/test/expressions/binary_expressions.wgsl.expected.wgsl b/test/expressions/binary_expressions.wgsl.expected.wgsl
index ec7e7d5..cd0bf55 100644
--- a/test/expressions/binary_expressions.wgsl.expected.wgsl
+++ b/test/expressions/binary_expressions.wgsl.expected.wgsl
@@ -62,6 +62,17 @@
r = (s % v);
}
+fn matrix_matrix_f32() {
+ var m34 : mat3x4<f32>;
+ var m43 : mat4x3<f32>;
+ var m33 : mat3x3<f32>;
+ var m44 : mat4x4<f32>;
+ m34 = (m34 + m34);
+ m34 = (m34 - m34);
+ m33 = (m43 * m34);
+ m44 = (m34 * m43);
+}
+
[[stage(fragment)]]
fn main() -> [[location(0)]] vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);