[ir] Move dot over to spirv builtins.
Move the dot SPIR-V intrinsics over to a builtin.
Bug: tint:1718
Change-Id: I26af2068e43afa764e0a574927c47b96fbe3082d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/150087
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/intrinsic/data/data.cc b/src/tint/lang/spirv/intrinsic/data/data.cc
index 24c1125..c7b4f7a 100644
--- a/src/tint/lang/spirv/intrinsic/data/data.cc
+++ b/src/tint/lang/spirv/intrinsic/data/data.cc
@@ -642,9 +642,9 @@
/* [4] */ TypeMatcherIndex(22),
/* [5] */ TypeMatcherIndex(23),
/* [6] */ TypeMatcherIndex(0),
- /* [7] */ TypeMatcherIndex(19),
+ /* [7] */ TypeMatcherIndex(18),
/* [8] */ TypeMatcherIndex(0),
- /* [9] */ TypeMatcherIndex(18),
+ /* [9] */ TypeMatcherIndex(19),
/* [10] */ TypeMatcherIndex(0),
/* [11] */ TypeMatcherIndex(5),
/* [12] */ TypeMatcherIndex(1),
@@ -746,53 +746,59 @@
/* [12] */
/* usage */ core::ParameterUsage::kNone,
/* type_matcher_indices */ TypeMatcherIndicesIndex(7),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(5),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(1),
},
{
/* [13] */
/* usage */ core::ParameterUsage::kNone,
/* type_matcher_indices */ TypeMatcherIndicesIndex(7),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(7),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(1),
},
{
/* [14] */
/* usage */ core::ParameterUsage::kNone,
- /* type_matcher_indices */ TypeMatcherIndicesIndex(7),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(8),
+ /* type_matcher_indices */ TypeMatcherIndicesIndex(9),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(7),
},
{
/* [15] */
/* usage */ core::ParameterUsage::kNone,
- /* type_matcher_indices */ TypeMatcherIndicesIndex(2),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(/* invalid */),
+ /* type_matcher_indices */ TypeMatcherIndicesIndex(9),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(5),
},
{
/* [16] */
/* usage */ core::ParameterUsage::kNone,
- /* type_matcher_indices */ TypeMatcherIndicesIndex(7),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(8),
+ /* type_matcher_indices */ TypeMatcherIndicesIndex(9),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(7),
},
{
/* [17] */
/* usage */ core::ParameterUsage::kNone,
/* type_matcher_indices */ TypeMatcherIndicesIndex(9),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(1),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(8),
},
{
/* [18] */
/* usage */ core::ParameterUsage::kNone,
- /* type_matcher_indices */ TypeMatcherIndicesIndex(7),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(7),
+ /* type_matcher_indices */ TypeMatcherIndicesIndex(2),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(/* invalid */),
},
{
/* [19] */
/* usage */ core::ParameterUsage::kNone,
/* type_matcher_indices */ TypeMatcherIndicesIndex(9),
- /* number_matcher_indices */ NumberMatcherIndicesIndex(1),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(8),
},
{
/* [20] */
/* usage */ core::ParameterUsage::kNone,
+ /* type_matcher_indices */ TypeMatcherIndicesIndex(7),
+ /* number_matcher_indices */ NumberMatcherIndicesIndex(1),
+ },
+ {
+ /* [21] */
+ /* usage */ core::ParameterUsage::kNone,
/* type_matcher_indices */ TypeMatcherIndicesIndex(2),
/* number_matcher_indices */ NumberMatcherIndicesIndex(/* invalid */),
},
@@ -926,12 +932,12 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 2,
/* num_template_types */ 1,
- /* num_template_numbers */ 3,
+ /* num_template_numbers */ 1,
/* template_types */ TemplateTypeIndex(3),
- /* template_numbers */ TemplateNumberIndex(0),
+ /* template_numbers */ TemplateNumberIndex(3),
/* parameters */ ParameterIndex(12),
- /* return_type_matcher_indices */ TypeMatcherIndicesIndex(7),
- /* return_number_matcher_indices */ NumberMatcherIndicesIndex(3),
+ /* return_type_matcher_indices */ TypeMatcherIndicesIndex(2),
+ /* return_number_matcher_indices */ NumberMatcherIndicesIndex(/* invalid */),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -939,12 +945,12 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 2,
/* num_template_types */ 1,
- /* num_template_numbers */ 2,
+ /* num_template_numbers */ 3,
/* template_types */ TemplateTypeIndex(3),
- /* template_numbers */ TemplateNumberIndex(3),
- /* parameters */ ParameterIndex(14),
- /* return_type_matcher_indices */ TypeMatcherIndicesIndex(7),
- /* return_number_matcher_indices */ NumberMatcherIndicesIndex(8),
+ /* template_numbers */ TemplateNumberIndex(0),
+ /* parameters */ ParameterIndex(15),
+ /* return_type_matcher_indices */ TypeMatcherIndicesIndex(9),
+ /* return_number_matcher_indices */ NumberMatcherIndicesIndex(3),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -955,9 +961,9 @@
/* num_template_numbers */ 2,
/* template_types */ TemplateTypeIndex(3),
/* template_numbers */ TemplateNumberIndex(3),
- /* parameters */ ParameterIndex(16),
+ /* parameters */ ParameterIndex(17),
/* return_type_matcher_indices */ TypeMatcherIndicesIndex(9),
- /* return_number_matcher_indices */ NumberMatcherIndicesIndex(3),
+ /* return_number_matcher_indices */ NumberMatcherIndicesIndex(8),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
{
@@ -968,8 +974,8 @@
/* num_template_numbers */ 2,
/* template_types */ TemplateTypeIndex(3),
/* template_numbers */ TemplateNumberIndex(3),
- /* parameters */ ParameterIndex(17),
- /* return_type_matcher_indices */ TypeMatcherIndicesIndex(9),
+ /* parameters */ ParameterIndex(19),
+ /* return_type_matcher_indices */ TypeMatcherIndicesIndex(7),
/* return_number_matcher_indices */ NumberMatcherIndicesIndex(3),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -978,11 +984,24 @@
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 2,
/* num_template_types */ 1,
+ /* num_template_numbers */ 2,
+ /* template_types */ TemplateTypeIndex(3),
+ /* template_numbers */ TemplateNumberIndex(3),
+ /* parameters */ ParameterIndex(13),
+ /* return_type_matcher_indices */ TypeMatcherIndicesIndex(7),
+ /* return_number_matcher_indices */ NumberMatcherIndicesIndex(3),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [9] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
+ /* num_parameters */ 2,
+ /* num_template_types */ 1,
/* num_template_numbers */ 1,
/* template_types */ TemplateTypeIndex(3),
/* template_numbers */ TemplateNumberIndex(3),
- /* parameters */ ParameterIndex(19),
- /* return_type_matcher_indices */ TypeMatcherIndicesIndex(9),
+ /* parameters */ ParameterIndex(20),
+ /* return_type_matcher_indices */ TypeMatcherIndicesIndex(7),
/* return_number_matcher_indices */ NumberMatcherIndicesIndex(1),
/* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
},
@@ -1078,34 +1097,40 @@
},
{
/* [14] */
- /* fn matrix_times_matrix<T : f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
+ /* fn dot<N : num, T : f32_f16>(vec<N, T>, vec<N, T>) -> T */
/* num overloads */ 1,
/* overloads */ OverloadIndex(4),
},
{
/* [15] */
- /* fn matrix_times_scalar<T : f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
+ /* fn matrix_times_matrix<T : f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(5),
},
{
/* [16] */
- /* fn matrix_times_vector<T : f32_f16, N : num, M : num>(mat<N, M, T>, vec<N, T>) -> vec<M, T> */
+ /* fn matrix_times_scalar<T : f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(6),
},
{
/* [17] */
- /* fn vector_times_matrix<T : f32_f16, N : num, M : num>(vec<N, T>, mat<M, N, T>) -> vec<M, T> */
+ /* fn matrix_times_vector<T : f32_f16, N : num, M : num>(mat<N, M, T>, vec<N, T>) -> vec<M, T> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(7),
},
{
/* [18] */
- /* fn vector_times_scalar<T : f32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
+ /* fn vector_times_matrix<T : f32_f16, N : num, M : num>(vec<N, T>, mat<M, N, T>) -> vec<M, T> */
/* num overloads */ 1,
/* overloads */ OverloadIndex(8),
},
+ {
+ /* [19] */
+ /* fn vector_times_scalar<T : f32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(9),
+ },
};
// clang-format on
diff --git a/src/tint/lang/spirv/ir/function.cc b/src/tint/lang/spirv/ir/function.cc
index 0e47acc..b557bfe 100644
--- a/src/tint/lang/spirv/ir/function.cc
+++ b/src/tint/lang/spirv/ir/function.cc
@@ -57,6 +57,8 @@
return "spirv.atomic_umin";
case Function::kAtomicXor:
return "spirv.atomic_xor";
+ case Function::kDot:
+ return "spirv.dot";
case Function::kMatrixTimesMatrix:
return "spirv.matrix_times_matrix";
case Function::kMatrixTimesScalar:
diff --git a/src/tint/lang/spirv/ir/function.h b/src/tint/lang/spirv/ir/function.h
index 0c6ac9e..23fc191 100644
--- a/src/tint/lang/spirv/ir/function.h
+++ b/src/tint/lang/spirv/ir/function.h
@@ -48,6 +48,7 @@
kAtomicUmax,
kAtomicUmin,
kAtomicXor,
+ kDot,
kMatrixTimesMatrix,
kMatrixTimesScalar,
kMatrixTimesVector,
diff --git a/src/tint/lang/spirv/ir/intrinsic.cc b/src/tint/lang/spirv/ir/intrinsic.cc
index 69b3694..63bf520 100644
--- a/src/tint/lang/spirv/ir/intrinsic.cc
+++ b/src/tint/lang/spirv/ir/intrinsic.cc
@@ -29,9 +29,6 @@
/// @param str the string to parse
/// @returns the parsed enum, or Intrinsic::kUndefined if the string could not be parsed.
Intrinsic ParseIntrinsic(std::string_view str) {
- if (str == "dot") {
- return Intrinsic::kDot;
- }
if (str == "image_dref_gather") {
return Intrinsic::kImageDrefGather;
}
@@ -78,8 +75,6 @@
switch (value) {
case Intrinsic::kUndefined:
return "undefined";
- case Intrinsic::kDot:
- return "dot";
case Intrinsic::kImageDrefGather:
return "image_dref_gather";
case Intrinsic::kImageFetch:
diff --git a/src/tint/lang/spirv/ir/intrinsic.h b/src/tint/lang/spirv/ir/intrinsic.h
index c34bd40..c88f427 100644
--- a/src/tint/lang/spirv/ir/intrinsic.h
+++ b/src/tint/lang/spirv/ir/intrinsic.h
@@ -34,7 +34,6 @@
/// Intrinsic
enum class Intrinsic : uint8_t {
kUndefined,
- kDot,
kImageDrefGather,
kImageFetch,
kImageGather,
@@ -68,7 +67,6 @@
Intrinsic ParseIntrinsic(std::string_view str);
constexpr const char* kIntrinsicStrings[] = {
- "dot",
"image_dref_gather",
"image_fetch",
"image_gather",
diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def
index 40756a5..cf35b33 100644
--- a/src/tint/lang/spirv/spirv.def
+++ b/src/tint/lang/spirv/spirv.def
@@ -75,7 +75,6 @@
////////////////////////////////////////////////////////////////////////////////
enum intrinsic {
- dot
image_fetch
image_gather
image_dref_gather
@@ -108,6 +107,7 @@
@stage("fragment", "compute") fn atomic_umax<T: iu32, U: u32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, U, U, T) -> T
@stage("fragment", "compute") fn atomic_umin<T: iu32, U: u32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, U, U, T) -> T
@stage("fragment", "compute") fn atomic_xor<T: iu32, U: u32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, U, U, T) -> T
+fn dot<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> T
fn matrix_times_matrix<T: f32_f16, K: num, C: num, R: num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
fn matrix_times_scalar<T: f32_f16, N: num, M: num>(mat<N, M, T>, T) -> mat<N, M, T>
fn matrix_times_vector<T: f32_f16, N: num, M: num>(mat<N, M, T>, vec<N, T>) -> vec<M, T>
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index f4b0f65..7761590 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1134,6 +1134,9 @@
case spirv::ir::Function::kAtomicXor:
op = spv::Op::OpAtomicXor;
break;
+ case spirv::ir::Function::kDot:
+ op = spv::Op::OpDot;
+ break;
case spirv::ir::Function::kMatrixTimesMatrix:
op = spv::Op::OpMatrixTimesMatrix;
break;
@@ -1613,9 +1616,6 @@
spv::Op op = spv::Op::Max;
switch (call->Kind()) {
- case spirv::ir::Intrinsic::kDot:
- op = spv::Op::OpDot;
- break;
case spirv::ir::Intrinsic::kImageFetch:
op = spv::Op::OpImageFetch;
break;
diff --git a/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc b/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc
index 01b27a1..df86a78 100644
--- a/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc
@@ -329,8 +329,8 @@
// Replace the builtin call with a call to the spirv.dot intrinsic.
auto args = Vector<core::ir::Value*, 4>(builtin->Args());
- auto* call = b.Call<spirv::ir::IntrinsicCall>(builtin->Result()->Type(),
- spirv::ir::Intrinsic::kDot, std::move(args));
+ auto* call = b.Call<spirv::ir::BuiltinCall>(builtin->Result()->Type(),
+ spirv::ir::Function::kDot, std::move(args));
call->InsertBefore(builtin);
return call->Result();
}