[spirv-reader][ir] Add OpOuterProduct support. Bug: 391486476 Change-Id: I899d09b53c0f45bfc619423c5e10603ae439f99d Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/231354 Reviewed-by: James Price <jrprice@google.com> Reviewed-by: dan sinclair <dsinclair@chromium.org> Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/builtin_fn.cc b/src/tint/lang/spirv/builtin_fn.cc index ccf7cb9..fde80f6 100644 --- a/src/tint/lang/spirv/builtin_fn.cc +++ b/src/tint/lang/spirv/builtin_fn.cc
@@ -212,6 +212,8 @@ return "s_negate"; case BuiltinFn::kFMod: return "f_mod"; + case BuiltinFn::kOuterProduct: + return "outer_product"; case BuiltinFn::kSDot: return "s_dot"; case BuiltinFn::kUDot: @@ -326,6 +328,7 @@ case BuiltinFn::kNot: case BuiltinFn::kSNegate: case BuiltinFn::kFMod: + case BuiltinFn::kOuterProduct: break; } return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.cc.tmpl b/src/tint/lang/spirv/builtin_fn.cc.tmpl index e04e1b3..1f39b83 100644 --- a/src/tint/lang/spirv/builtin_fn.cc.tmpl +++ b/src/tint/lang/spirv/builtin_fn.cc.tmpl
@@ -127,6 +127,7 @@ case BuiltinFn::kNot: case BuiltinFn::kSNegate: case BuiltinFn::kFMod: + case BuiltinFn::kOuterProduct: break; } return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.h b/src/tint/lang/spirv/builtin_fn.h index 23d96ca..653a833 100644 --- a/src/tint/lang/spirv/builtin_fn.h +++ b/src/tint/lang/spirv/builtin_fn.h
@@ -133,6 +133,7 @@ kNot, kSNegate, kFMod, + kOuterProduct, kSDot, kUDot, kCooperativeMatrixLoad,
diff --git a/src/tint/lang/spirv/intrinsic/data.cc b/src/tint/lang/spirv/intrinsic/data.cc index 415ddf0..5bd1f22 100644 --- a/src/tint/lang/spirv/intrinsic/data.cc +++ b/src/tint/lang/spirv/intrinsic/data.cc
@@ -3393,62 +3393,62 @@ { /* [359] */ /* usage */ core::ParameterUsage::kNone, - /* matcher_indices */ MatcherIndicesIndex(168), + /* matcher_indices */ MatcherIndicesIndex(159), }, { /* [360] */ /* usage */ core::ParameterUsage::kNone, - /* matcher_indices */ MatcherIndicesIndex(171), + /* matcher_indices */ MatcherIndicesIndex(9), }, { /* [361] */ - /* usage */ core::ParameterUsage::kX, - /* matcher_indices */ MatcherIndicesIndex(4), + /* usage */ core::ParameterUsage::kNone, + /* matcher_indices */ MatcherIndicesIndex(168), }, { /* [362] */ - /* usage */ core::ParameterUsage::kI, - /* matcher_indices */ MatcherIndicesIndex(73), + /* usage */ core::ParameterUsage::kNone, + /* matcher_indices */ MatcherIndicesIndex(171), }, { /* [363] */ /* usage */ core::ParameterUsage::kX, - /* matcher_indices */ MatcherIndicesIndex(2), + /* matcher_indices */ MatcherIndicesIndex(4), }, { /* [364] */ /* usage */ core::ParameterUsage::kI, - /* matcher_indices */ MatcherIndicesIndex(0), + /* matcher_indices */ MatcherIndicesIndex(73), }, { /* [365] */ /* usage */ core::ParameterUsage::kX, - /* matcher_indices */ MatcherIndicesIndex(4), + /* matcher_indices */ MatcherIndicesIndex(2), }, { /* [366] */ /* usage */ core::ParameterUsage::kI, - /* matcher_indices */ MatcherIndicesIndex(77), + /* matcher_indices */ MatcherIndicesIndex(0), }, { /* [367] */ /* usage */ core::ParameterUsage::kX, - /* matcher_indices */ MatcherIndicesIndex(159), + /* matcher_indices */ MatcherIndicesIndex(4), }, { /* [368] */ /* usage */ core::ParameterUsage::kI, - /* matcher_indices */ MatcherIndicesIndex(7), + /* matcher_indices */ MatcherIndicesIndex(77), }, { /* [369] */ - /* usage */ core::ParameterUsage::kNone, + /* usage */ core::ParameterUsage::kX, /* matcher_indices */ MatcherIndicesIndex(159), }, { /* [370] */ - /* usage */ core::ParameterUsage::kNone, - /* matcher_indices */ MatcherIndicesIndex(9), + /* usage */ core::ParameterUsage::kI, + /* matcher_indices */ MatcherIndicesIndex(7), }, { /* [371] */ @@ -5853,7 +5853,7 @@ /* num_explicit_templates */ 1, /* num_templates */ 3, /* templates */ TemplateIndex(70), - /* parameters */ ParameterIndex(370), + /* parameters */ ParameterIndex(360), /* return_matcher_indices */ MatcherIndicesIndex(159), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -5875,7 +5875,7 @@ /* num_explicit_templates */ 1, /* num_templates */ 4, /* templates */ TemplateIndex(42), - /* parameters */ ParameterIndex(359), + /* parameters */ ParameterIndex(361), /* return_matcher_indices */ MatcherIndicesIndex(165), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -5996,7 +5996,7 @@ /* num_explicit_templates */ 0, /* num_templates */ 2, /* templates */ TemplateIndex(94), - /* parameters */ ParameterIndex(361), + /* parameters */ ParameterIndex(363), /* return_matcher_indices */ MatcherIndicesIndex(4), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6007,7 +6007,7 @@ /* num_explicit_templates */ 0, /* num_templates */ 3, /* templates */ TemplateIndex(73), - /* parameters */ ParameterIndex(363), + /* parameters */ ParameterIndex(365), /* return_matcher_indices */ MatcherIndicesIndex(2), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6018,7 +6018,7 @@ /* num_explicit_templates */ 0, /* num_templates */ 3, /* templates */ TemplateIndex(76), - /* parameters */ ParameterIndex(365), + /* parameters */ ParameterIndex(367), /* return_matcher_indices */ MatcherIndicesIndex(4), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6029,7 +6029,7 @@ /* num_explicit_templates */ 0, /* num_templates */ 4, /* templates */ TemplateIndex(47), - /* parameters */ ParameterIndex(367), + /* parameters */ ParameterIndex(369), /* return_matcher_indices */ MatcherIndicesIndex(159), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6095,7 +6095,7 @@ /* num_explicit_templates */ 1, /* num_templates */ 4, /* templates */ TemplateIndex(51), - /* parameters */ ParameterIndex(359), + /* parameters */ ParameterIndex(361), /* return_matcher_indices */ MatcherIndicesIndex(165), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6117,7 +6117,7 @@ /* num_explicit_templates */ 0, /* num_templates */ 3, /* templates */ TemplateIndex(52), - /* parameters */ ParameterIndex(369), + /* parameters */ ParameterIndex(359), /* return_matcher_indices */ MatcherIndicesIndex(189), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6139,7 +6139,7 @@ /* num_explicit_templates */ 1, /* num_templates */ 3, /* templates */ TemplateIndex(79), - /* parameters */ ParameterIndex(370), + /* parameters */ ParameterIndex(360), /* return_matcher_indices */ MatcherIndicesIndex(159), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6161,7 +6161,7 @@ /* num_explicit_templates */ 1, /* num_templates */ 3, /* templates */ TemplateIndex(82), - /* parameters */ ParameterIndex(370), + /* parameters */ ParameterIndex(360), /* return_matcher_indices */ MatcherIndicesIndex(159), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6183,7 +6183,7 @@ /* num_explicit_templates */ 1, /* num_templates */ 3, /* templates */ TemplateIndex(85), - /* parameters */ ParameterIndex(370), + /* parameters */ ParameterIndex(360), /* return_matcher_indices */ MatcherIndicesIndex(159), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, @@ -6278,12 +6278,12 @@ { /* [202] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* num_parameters */ 3, + /* num_parameters */ 2, /* num_explicit_templates */ 0, - /* num_templates */ 0, - /* templates */ TemplateIndex(/* invalid */), - /* parameters */ ParameterIndex(127), - /* return_matcher_indices */ MatcherIndicesIndex(124), + /* num_templates */ 3, + /* templates */ TemplateIndex(67), + /* parameters */ ParameterIndex(358), + /* return_matcher_indices */ MatcherIndicesIndex(65), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { @@ -6294,11 +6294,22 @@ /* num_templates */ 0, /* templates */ TemplateIndex(/* invalid */), /* parameters */ ParameterIndex(127), - /* return_matcher_indices */ MatcherIndicesIndex(127), + /* return_matcher_indices */ MatcherIndicesIndex(124), /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { /* [204] */ + /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), + /* num_parameters */ 3, + /* num_explicit_templates */ 0, + /* num_templates */ 0, + /* templates */ TemplateIndex(/* invalid */), + /* parameters */ ParameterIndex(127), + /* return_matcher_indices */ MatcherIndicesIndex(127), + /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), + }, + { + /* [205] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 4, /* num_explicit_templates */ 1, @@ -6309,7 +6320,7 @@ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [205] */ + /* [206] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 5, /* num_explicit_templates */ 0, @@ -6320,7 +6331,7 @@ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [206] */ + /* [207] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsComputePipeline, OverloadFlag::kMustUse), /* num_parameters */ 4, /* num_explicit_templates */ 0, @@ -7048,34 +7059,40 @@ }, { /* [85] */ - /* fn s_dot(u32, u32, u32) -> i32 */ + /* fn outer_product[T : f32_f16, N : num, M : num](vec<N, T>, vec<M, T>) -> mat<M, N, T> */ /* num overloads */ 1, /* overloads */ OverloadIndex(202), }, { /* [86] */ - /* fn u_dot(u32, u32, u32) -> u32 */ + /* fn s_dot(u32, u32, u32) -> i32 */ /* num overloads */ 1, /* overloads */ OverloadIndex(203), }, { /* [87] */ - /* fn cooperative_matrix_load<T : subgroup_matrix<K, S, C, R>>[K : subgroup_matrix_kind, S : fiu32_f16, C : num, R : num](ptr<workgroup_or_storage, S, readable>, u32, u32, u32) -> T */ + /* fn u_dot(u32, u32, u32) -> u32 */ /* num overloads */ 1, /* overloads */ OverloadIndex(204), }, { /* [88] */ - /* fn cooperative_matrix_store[K : subgroup_matrix_kind, S : fiu32_f16, C : num, R : num](ptr<workgroup_or_storage, S, writable>, subgroup_matrix<K, S, C, R>, u32, u32, u32) */ + /* fn cooperative_matrix_load<T : subgroup_matrix<K, S, C, R>>[K : subgroup_matrix_kind, S : fiu32_f16, C : num, R : num](ptr<workgroup_or_storage, S, readable>, u32, u32, u32) -> T */ /* num overloads */ 1, /* overloads */ OverloadIndex(205), }, { /* [89] */ - /* fn cooperative_matrix_mul_add[T : subgroup_matrix_elements, TR : subgroup_matrix_elements, C : num, R : num, K : num](subgroup_matrix<subgroup_matrix_kind_left, T, K, R>, subgroup_matrix<subgroup_matrix_kind_right, T, C, K>, subgroup_matrix<subgroup_matrix_kind_result, TR, C, R>, u32) -> subgroup_matrix<subgroup_matrix_kind_result, TR, C, R> */ + /* fn cooperative_matrix_store[K : subgroup_matrix_kind, S : fiu32_f16, C : num, R : num](ptr<workgroup_or_storage, S, writable>, subgroup_matrix<K, S, C, R>, u32, u32, u32) */ /* num overloads */ 1, /* overloads */ OverloadIndex(206), }, + { + /* [90] */ + /* fn cooperative_matrix_mul_add[T : subgroup_matrix_elements, TR : subgroup_matrix_elements, C : num, R : num, K : num](subgroup_matrix<subgroup_matrix_kind_left, T, K, R>, subgroup_matrix<subgroup_matrix_kind_right, T, C, K>, subgroup_matrix<subgroup_matrix_kind_result, TR, C, R>, u32) -> subgroup_matrix<subgroup_matrix_kind_result, TR, C, R> */ + /* num overloads */ 1, + /* overloads */ OverloadIndex(207), + }, }; // clang-format on
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc index 5fe41de..0d194d3 100644 --- a/src/tint/lang/spirv/reader/lower/builtins.cc +++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -215,11 +215,44 @@ case spirv::BuiltinFn::kSelect: Select(builtin); break; + case spirv::BuiltinFn::kOuterProduct: + OuterProduct(builtin); + break; default: TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func(); } } } + void OuterProduct(spirv::ir::BuiltinCall* call) { + auto* vector1 = call->Args()[0]; + auto* vector2 = call->Args()[1]; + + uint32_t rows = vector1->Type()->As<core::type::Vector>()->Width(); + uint32_t cols = vector2->Type()->As<core::type::Vector>()->Width(); + + auto* elem_ty = vector1->Type()->DeepestElement(); + + b.InsertBefore(call, [&] { + Vector<core::ir::Value*, 4> col_vectors; + + for (uint32_t col = 0; col < cols; ++col) { + Vector<core::ir::Value*, 4> col_elements; + auto* v2_element = b.Access(elem_ty, vector2, u32(col)); + + for (uint32_t row = 0; row < rows; ++row) { + auto* v1_element = b.Access(elem_ty, vector1, u32(row)); + auto* result = b.Multiply(elem_ty, v1_element, v2_element)->Result(0); + col_elements.Push(result); + } + + auto* row_vector = b.Construct(ty.vec(elem_ty, rows), col_elements)->Result(0); + col_vectors.Push(row_vector); + } + b.ConstructWithResult(call->DetachResult(), col_vectors); + }); + + call->Destroy(); + } void Select(spirv::ir::BuiltinCall* call) { auto* cond = call->Args()[0];
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc index 82b806b..ac934bd 100644 --- a/src/tint/lang/spirv/reader/lower/builtins_test.cc +++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -9121,5 +9121,61 @@ EXPECT_EQ(expect, str()); } +TEST_F(SpirvReader_BuiltinsTest, OuterProduct_Vector) { + auto* ep = b.ComputeFunction("foo"); + + b.Append(ep->Block(), [&] { // + // Call the OuterProduct builtin function + b.Call<spirv::ir::BuiltinCall>(ty.mat2x4<f32>(), spirv::BuiltinFn::kOuterProduct, + b.Splat<vec4<f32>>(1_f), b.Splat<vec2<f32>>(2_f)); + b.Return(ep); + }); + + // Expected SPIR-V source code + auto src = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x4<f32> = spirv.outer_product vec4<f32>(1.0f), vec2<f32>(2.0f) + ret + } +} +)"; + EXPECT_EQ(src, str()); + + // Run the test + Run(Builtins); + + // Updated expected expanded SPIR-V code after lowering + auto expect = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:f32 = access vec2<f32>(2.0f), 0u + %3:f32 = access vec4<f32>(1.0f), 0u + %4:f32 = mul %3, %2 + %5:f32 = access vec4<f32>(1.0f), 1u + %6:f32 = mul %5, %2 + %7:f32 = access vec4<f32>(1.0f), 2u + %8:f32 = mul %7, %2 + %9:f32 = access vec4<f32>(1.0f), 3u + %10:f32 = mul %9, %2 + %11:vec4<f32> = construct %4, %6, %8, %10 + %12:f32 = access vec2<f32>(2.0f), 1u + %13:f32 = access vec4<f32>(1.0f), 0u + %14:f32 = mul %13, %12 + %15:f32 = access vec4<f32>(1.0f), 1u + %16:f32 = mul %15, %12 + %17:f32 = access vec4<f32>(1.0f), 2u + %18:f32 = mul %17, %12 + %19:f32 = access vec4<f32>(1.0f), 3u + %20:f32 = mul %19, %12 + %21:vec4<f32> = construct %14, %16, %18, %20 + %22:mat2x4<f32> = construct %11, %21 + ret + } +} +)"; + EXPECT_EQ(expect, str()); +} + } // namespace } // namespace tint::spirv::reader::lower
diff --git a/src/tint/lang/spirv/reader/parser/builtin_test.cc b/src/tint/lang/spirv/reader/parser/builtin_test.cc index 2837a95..8d3f4ea 100644 --- a/src/tint/lang/spirv/reader/parser/builtin_test.cc +++ b/src/tint/lang/spirv/reader/parser/builtin_test.cc
@@ -1621,5 +1621,81 @@ )"); } +TEST_F(SpirvParserTest, OuterProductVec2Vec3) { + EXPECT_IR(R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + %void = OpTypeVoid +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v3float = OpTypeVector %float 3 +%mat2x3float = OpTypeMatrix %v3float 2 +%ep_type = OpTypeFunction %void + +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%float_3 = OpConstant %float 3 +%float_4 = OpConstant %float 4 +%float_5 = OpConstant %float 5 + +%vec2 = OpConstantComposite %v2float %float_1 %float_2 +%vec3 = OpConstantComposite %v3float %float_3 %float_4 %float_5 + + %main = OpFunction %void None %ep_type +%entry = OpLabel + %1 = OpOuterProduct %mat2x3float %vec3 %vec2 + OpReturn + OpFunctionEnd)", + R"( +%main = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x3<f32> = spirv.outer_product vec3<f32>(3.0f, 4.0f, 5.0f), vec2<f32>(1.0f, 2.0f) + ret + } +} +)"); +} + +TEST_F(SpirvParserTest, OuterProductVec3Vec2) { + EXPECT_IR(R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + %void = OpTypeVoid +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v3float = OpTypeVector %float 3 +%mat3x2float = OpTypeMatrix %v2float 3 +%ep_type = OpTypeFunction %void + +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%float_3 = OpConstant %float 3 +%float_4 = OpConstant %float 4 +%float_5 = OpConstant %float 5 + +%vec2 = OpConstantComposite %v2float %float_1 %float_2 +%vec3 = OpConstantComposite %v3float %float_3 %float_4 %float_5 + + %main = OpFunction %void None %ep_type +%entry = OpLabel + %1 = OpOuterProduct %mat3x2float %vec2 %vec3 + OpReturn + OpFunctionEnd)", + R"( +%main = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat3x2<f32> = spirv.outer_product vec2<f32>(1.0f, 2.0f), vec3<f32>(3.0f, 4.0f, 5.0f) + ret + } +} +)"); +} + } // namespace } // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc index 32dd6f9..56cf033 100644 --- a/src/tint/lang/spirv/reader/parser/parser.cc +++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -1129,6 +1129,9 @@ case spv::Op::OpVectorExtractDynamic: EmitAccess(inst); break; + case spv::Op::OpOuterProduct: + EmitSpirvBuiltinCall(inst, spirv::BuiltinFn::kOuterProduct); + break; default: TINT_UNIMPLEMENTED() << "unhandled SPIR-V instruction: " << static_cast<uint32_t>(inst.opcode());
diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def index c4e7773..242bf4f 100644 --- a/src/tint/lang/spirv/spirv.def +++ b/src/tint/lang/spirv/spirv.def
@@ -486,6 +486,8 @@ implicit(T: f32_f16) fn f_mod(T, T) -> T implicit(T: f32_f16, N: num) fn f_mod(vec<N, T>, vec<N, T>) -> vec<N, T> +implicit(T: f32_f16, N: num, M: num) fn outer_product(vec<N, T>, vec<M, T>) -> mat<M, N, T> + //////////////////////////////////////////////////////////////////////////////// // SPV_KHR_integer_dot_product instructions ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc index 94f8208..8e54b06 100644 --- a/src/tint/lang/spirv/writer/printer/printer.cc +++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1609,6 +1609,9 @@ case BuiltinFn::kFMod: op = spv::Op::OpFMod; break; + case BuiltinFn::kOuterProduct: + op = spv::Op::OpOuterProduct; + break; case spirv::BuiltinFn::kNone: TINT_ICE() << "undefined spirv ir function"; }