spirv-reader: expand OuterProduct to primitive ops

Bug: tint:3
Change-Id: Id6de3554d945bc743a484e80b494690c26552079
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37660
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 696b305..abd6e25 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -57,6 +57,7 @@
 #include "src/ast/type/depth_texture_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
 #include "src/ast/type/pointer_type.h"
 #include "src/ast/type/storage_texture_type.h"
 #include "src/ast/type/texture_type.h"
@@ -3013,6 +3014,10 @@
       return EmitConstDefOrWriteToHoistedVar(inst, expr);
     }
 
+    case SpvOpOuterProduct:
+      // Synthesize an outer product expression in its own statement.
+      return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));
+
     case SpvOpFunctionCall:
       return EmitFunctionCall(inst);
 
@@ -3707,7 +3712,8 @@
   // but only if they are defined in this function as well.
   for (auto& id_def_info_pair : def_info_) {
     const auto& inst = id_def_info_pair.second->inst;
-    if (inst.opcode() == SpvOpVectorShuffle) {
+    const auto opcode = inst.opcode();
+    if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
       // We might access the vector operands multiple times. Make sure they
       // are evaluated only once.
       for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
@@ -4578,6 +4584,52 @@
   return {parser_impl_.ConvertType(inst.type_id()), call_expr};
 }
 
+TypedExpression FunctionEmitter::MakeOuterProduct(
+    const spvtools::opt::Instruction& inst) {
+  // Synthesize the result.
+  auto col = MakeOperand(inst, 0);
+  auto row = MakeOperand(inst, 1);
+  auto* col_ty = col.type->As<ast::type::Vector>();
+  auto* row_ty = row.type->As<ast::type::Vector>();
+  auto* result_ty =
+      parser_impl_.ConvertType(inst.type_id())->As<ast::type::Matrix>();
+  if (!col_ty || !col_ty || !result_ty || result_ty->type() != col_ty->type() ||
+      result_ty->type() != row_ty->type() ||
+      result_ty->columns() != row_ty->size() ||
+      result_ty->rows() != col_ty->size()) {
+    Fail() << "invalid outer product instruction: bad types "
+           << inst.PrettyPrint();
+    return {};
+  }
+
+  // Example:
+  //    c : vec3 column vector
+  //    r : vec2 row vector
+  //    OuterProduct c r : mat2x3 (2 columns, 3 rows)
+  //    Result:
+  //      | c.x * r.x   c.x * r.y |
+  //      | c.y * r.x   c.y * r.y |
+  //      | c.z * r.x   c.z * r.y |
+
+  ast::ExpressionList result_columns;
+  for (uint32_t icol = 0; icol < result_ty->columns(); icol++) {
+    ast::ExpressionList result_row;
+    auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
+                                                             Swizzle(icol));
+    for (uint32_t irow = 0; irow < result_ty->rows(); irow++) {
+      auto* column_factor = create<ast::MemberAccessorExpression>(
+          Source{}, col.expr, Swizzle(irow));
+      auto* elem = create<ast::BinaryExpression>(
+          Source{}, ast::BinaryOp::kMultiply, row_factor, column_factor);
+      result_row.push_back(elem);
+    }
+    result_columns.push_back(
+        create<ast::TypeConstructorExpression>(Source{}, col_ty, result_row));
+  }
+  return {result_ty, create<ast::TypeConstructorExpression>(Source{}, result_ty,
+                                                            result_columns)};
+}
+
 FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
 FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
 
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 9002ea8..74dc3ca 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -851,6 +851,11 @@
   /// @returns an expression
   TypedExpression MakeArrayLength(const spvtools::opt::Instruction& inst);
 
+  /// Generates an expression for a SPIR-V OpOuterProduct instruction.
+  /// @param inst the SPIR-V instruction
+  /// @returns an expression
+  TypedExpression MakeOuterProduct(const spvtools::opt::Instruction& inst);
+
   /// Emits a texture builtin function call for a SPIR-V instruction that
   /// accesses an image or sampled image.
   /// @param inst the SPIR-V instruction
diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc
index 044fca5..b5c5f20 100644
--- a/src/reader/spirv/function_arithmetic_test.cc
+++ b/src/reader/spirv/function_arithmetic_test.cc
@@ -43,6 +43,7 @@
   %int_40 = OpConstant %int 40
   %float_50 = OpConstant %float 50
   %float_60 = OpConstant %float 60
+  %float_70 = OpConstant %float 70
 
   %ptr_uint = OpTypePointer Function %uint
   %ptr_int = OpTypePointer Function %int
@@ -51,6 +52,7 @@
   %v2uint = OpTypeVector %uint 2
   %v2int = OpTypeVector %int 2
   %v2float = OpTypeVector %float 2
+  %v3float = OpTypeVector %float 3
 
   %v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
   %v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10
@@ -58,10 +60,12 @@
   %v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
   %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
   %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
+  %v3float_50_60_70 = OpConstantComposite %v2float %float_50 %float_60 %float_70
 
   %m2v2float = OpTypeMatrix %v2float 2
   %m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
   %m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
+  %m2v3float = OpTypeMatrix %v3float 2
 )";
 }
 
@@ -1099,6 +1103,108 @@
       << ToString(p->get_module(), fe.ast_body());
 }
 
+TEST_F(SpvBinaryArithTestBasic, OuterProduct) {
+  // OpOuterProduct is expanded to basic operations.
+  // The operands, even if used once, are given their own const definitions.
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpFAdd %v3float %v3float_50_60_70 %v3float_50_60_70 ; column vector
+     %2 = OpFAdd %v2float %v2float_60_50 %v2float_50_60 ; row vector
+     %3 = OpOuterProduct %m2v3float %1 %2
+     OpReturn
+     OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto got = ToString(p->get_module(), fe.ast_body());
+  EXPECT_THAT(got, HasSubstr(R"(VariableConst{
+    x_3
+    none
+    __mat_3_2__f32
+    {
+      TypeConstructor[not set]{
+        __mat_3_2__f32
+        TypeConstructor[not set]{
+          __vec_3__f32
+          Binary[not set]{
+            MemberAccessor[not set]{
+              Identifier[not set]{x_2}
+              Identifier[not set]{x}
+            }
+            multiply
+            MemberAccessor[not set]{
+              Identifier[not set]{x_1}
+              Identifier[not set]{x}
+            }
+          }
+          Binary[not set]{
+            MemberAccessor[not set]{
+              Identifier[not set]{x_2}
+              Identifier[not set]{x}
+            }
+            multiply
+            MemberAccessor[not set]{
+              Identifier[not set]{x_1}
+              Identifier[not set]{y}
+            }
+          }
+          Binary[not set]{
+            MemberAccessor[not set]{
+              Identifier[not set]{x_2}
+              Identifier[not set]{x}
+            }
+            multiply
+            MemberAccessor[not set]{
+              Identifier[not set]{x_1}
+              Identifier[not set]{z}
+            }
+          }
+        }
+        TypeConstructor[not set]{
+          __vec_3__f32
+          Binary[not set]{
+            MemberAccessor[not set]{
+              Identifier[not set]{x_2}
+              Identifier[not set]{y}
+            }
+            multiply
+            MemberAccessor[not set]{
+              Identifier[not set]{x_1}
+              Identifier[not set]{x}
+            }
+          }
+          Binary[not set]{
+            MemberAccessor[not set]{
+              Identifier[not set]{x_2}
+              Identifier[not set]{y}
+            }
+            multiply
+            MemberAccessor[not set]{
+              Identifier[not set]{x_1}
+              Identifier[not set]{y}
+            }
+          }
+          Binary[not set]{
+            MemberAccessor[not set]{
+              Identifier[not set]{x_2}
+              Identifier[not set]{y}
+            }
+            multiply
+            MemberAccessor[not set]{
+              Identifier[not set]{x_1}
+              Identifier[not set]{z}
+            }
+          }
+        }
+      }
+    }
+  })"))
+      << got;
+}
+
 // TODO(dneto): OpSRem. Missing from WGSL
 // https://github.com/gpuweb/gpuweb/issues/702