spirv-reader: expand OuterProduct to primitive ops
Bug: tint:3
Change-Id: Id6de3554d945bc743a484e80b494690c26552079
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37660
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 696b305..abd6e25 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -57,6 +57,7 @@
#include "src/ast/type/depth_texture_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
#include "src/ast/type/pointer_type.h"
#include "src/ast/type/storage_texture_type.h"
#include "src/ast/type/texture_type.h"
@@ -3013,6 +3014,10 @@
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
+ case SpvOpOuterProduct:
+ // Synthesize an outer product expression in its own statement.
+ return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));
+
case SpvOpFunctionCall:
return EmitFunctionCall(inst);
@@ -3707,7 +3712,8 @@
// but only if they are defined in this function as well.
for (auto& id_def_info_pair : def_info_) {
const auto& inst = id_def_info_pair.second->inst;
- if (inst.opcode() == SpvOpVectorShuffle) {
+ const auto opcode = inst.opcode();
+ if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
// We might access the vector operands multiple times. Make sure they
// are evaluated only once.
for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
@@ -4578,6 +4584,52 @@
return {parser_impl_.ConvertType(inst.type_id()), call_expr};
}
+TypedExpression FunctionEmitter::MakeOuterProduct(
+ const spvtools::opt::Instruction& inst) {
+ // Synthesize the result.
+ auto col = MakeOperand(inst, 0);
+ auto row = MakeOperand(inst, 1);
+ auto* col_ty = col.type->As<ast::type::Vector>();
+ auto* row_ty = row.type->As<ast::type::Vector>();
+ auto* result_ty =
+ parser_impl_.ConvertType(inst.type_id())->As<ast::type::Matrix>();
+ if (!col_ty || !col_ty || !result_ty || result_ty->type() != col_ty->type() ||
+ result_ty->type() != row_ty->type() ||
+ result_ty->columns() != row_ty->size() ||
+ result_ty->rows() != col_ty->size()) {
+ Fail() << "invalid outer product instruction: bad types "
+ << inst.PrettyPrint();
+ return {};
+ }
+
+ // Example:
+ // c : vec3 column vector
+ // r : vec2 row vector
+ // OuterProduct c r : mat2x3 (2 columns, 3 rows)
+ // Result:
+ // | c.x * r.x c.x * r.y |
+ // | c.y * r.x c.y * r.y |
+ // | c.z * r.x c.z * r.y |
+
+ ast::ExpressionList result_columns;
+ for (uint32_t icol = 0; icol < result_ty->columns(); icol++) {
+ ast::ExpressionList result_row;
+ auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
+ Swizzle(icol));
+ for (uint32_t irow = 0; irow < result_ty->rows(); irow++) {
+ auto* column_factor = create<ast::MemberAccessorExpression>(
+ Source{}, col.expr, Swizzle(irow));
+ auto* elem = create<ast::BinaryExpression>(
+ Source{}, ast::BinaryOp::kMultiply, row_factor, column_factor);
+ result_row.push_back(elem);
+ }
+ result_columns.push_back(
+ create<ast::TypeConstructorExpression>(Source{}, col_ty, result_row));
+ }
+ return {result_ty, create<ast::TypeConstructorExpression>(Source{}, result_ty,
+ result_columns)};
+}
+
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 9002ea8..74dc3ca 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -851,6 +851,11 @@
/// @returns an expression
TypedExpression MakeArrayLength(const spvtools::opt::Instruction& inst);
+ /// Generates an expression for a SPIR-V OpOuterProduct instruction.
+ /// @param inst the SPIR-V instruction
+ /// @returns an expression
+ TypedExpression MakeOuterProduct(const spvtools::opt::Instruction& inst);
+
/// Emits a texture builtin function call for a SPIR-V instruction that
/// accesses an image or sampled image.
/// @param inst the SPIR-V instruction
diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc
index 044fca5..b5c5f20 100644
--- a/src/reader/spirv/function_arithmetic_test.cc
+++ b/src/reader/spirv/function_arithmetic_test.cc
@@ -43,6 +43,7 @@
%int_40 = OpConstant %int 40
%float_50 = OpConstant %float 50
%float_60 = OpConstant %float 60
+ %float_70 = OpConstant %float 70
%ptr_uint = OpTypePointer Function %uint
%ptr_int = OpTypePointer Function %int
@@ -51,6 +52,7 @@
%v2uint = OpTypeVector %uint 2
%v2int = OpTypeVector %int 2
%v2float = OpTypeVector %float 2
+ %v3float = OpTypeVector %float 3
%v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
%v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10
@@ -58,10 +60,12 @@
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
+ %v3float_50_60_70 = OpConstantComposite %v2float %float_50 %float_60 %float_70
%m2v2float = OpTypeMatrix %v2float 2
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
+ %m2v3float = OpTypeMatrix %v3float 2
)";
}
@@ -1099,6 +1103,108 @@
<< ToString(p->get_module(), fe.ast_body());
}
+TEST_F(SpvBinaryArithTestBasic, OuterProduct) {
+ // OpOuterProduct is expanded to basic operations.
+ // The operands, even if used once, are given their own const definitions.
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpFAdd %v3float %v3float_50_60_70 %v3float_50_60_70 ; column vector
+ %2 = OpFAdd %v2float %v2float_60_50 %v2float_50_60 ; row vector
+ %3 = OpOuterProduct %m2v3float %1 %2
+ OpReturn
+ OpFunctionEnd
+)";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto got = ToString(p->get_module(), fe.ast_body());
+ EXPECT_THAT(got, HasSubstr(R"(VariableConst{
+ x_3
+ none
+ __mat_3_2__f32
+ {
+ TypeConstructor[not set]{
+ __mat_3_2__f32
+ TypeConstructor[not set]{
+ __vec_3__f32
+ Binary[not set]{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_2}
+ Identifier[not set]{x}
+ }
+ multiply
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{x}
+ }
+ }
+ Binary[not set]{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_2}
+ Identifier[not set]{x}
+ }
+ multiply
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{y}
+ }
+ }
+ Binary[not set]{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_2}
+ Identifier[not set]{x}
+ }
+ multiply
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{z}
+ }
+ }
+ }
+ TypeConstructor[not set]{
+ __vec_3__f32
+ Binary[not set]{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_2}
+ Identifier[not set]{y}
+ }
+ multiply
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{x}
+ }
+ }
+ Binary[not set]{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_2}
+ Identifier[not set]{y}
+ }
+ multiply
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{y}
+ }
+ }
+ Binary[not set]{
+ MemberAccessor[not set]{
+ Identifier[not set]{x_2}
+ Identifier[not set]{y}
+ }
+ multiply
+ MemberAccessor[not set]{
+ Identifier[not set]{x_1}
+ Identifier[not set]{z}
+ }
+ }
+ }
+ }
+ }
+ })"))
+ << got;
+}
+
// TODO(dneto): OpSRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702