blob: 4746914218cd3b28c808ff6e6680f277a5655751 [file] [log] [blame]
// Copyright 2024 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/spirv/reader/lower/builtins.h"
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/spirv/ir/builtin_call.h"
namespace tint::spirv::reader::lower {
namespace {
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
/// PIMPL state for the transform.
struct State {
/// The IR module.
core::ir::Module& ir;
/// The IR builder.
core::ir::Builder b{ir};
/// The type manager.
core::type::Manager& ty{ir.Types()};
/// Process the module.
void Process() {
Vector<spirv::ir::BuiltinCall*, 4> builtin_worklist;
for (auto* inst : ir.Instructions()) {
if (auto* builtin = inst->As<spirv::ir::BuiltinCall>()) {
builtin_worklist.Push(builtin);
}
}
// Replace the builtins that we found.
for (auto* builtin : builtin_worklist) {
switch (builtin->Func()) {
case spirv::BuiltinFn::kNormalize:
Normalize(builtin);
break;
case spirv::BuiltinFn::kInverse:
Inverse(builtin);
break;
case spirv::BuiltinFn::kSign:
Sign(builtin);
break;
default:
TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
}
}
}
// The `spirv.sign` method takes a `u32` operand which is not accepted in WGSL. If the operand
// is a `u32` or a `vec<N, u32>` then we need to bitcast the operand to `i32` before doing the
// comparison.
void Sign(spirv::ir::BuiltinCall* call) {
auto* arg = call->Args()[0];
b.InsertBefore(call, [&] {
auto* result_ty = call->Result(0)->Type();
if (arg->Type()->IsUnsignedIntegerScalarOrVector()) {
arg = b.Bitcast(ty.MatchWidth(ty.i32(), result_ty), arg)->Result(0);
}
auto* new_call =
b.Call(result_ty, core::BuiltinFn::kSign, Vector<core::ir::Value*, 1>{arg});
core::ir::Value* replacement = new_call->Result(0);
// If the call is a `u32` result type, we need to cast it to `i32`.
if (result_ty->DeepestElement() == ty.u32()) {
new_call->Result(0)->SetType(ty.MatchWidth(ty.i32(), result_ty));
replacement = b.Bitcast(result_ty, replacement)->Result(0);
}
call->Result(0)->ReplaceAllUsesWith(replacement);
});
call->Destroy();
}
void Normalize(spirv::ir::BuiltinCall* call) {
auto* arg = call->Args()[0];
b.InsertBefore(call, [&] {
core::BuiltinFn fn = core::BuiltinFn::kNormalize;
if (arg->Type()->IsScalar()) {
fn = core::BuiltinFn::kSign;
}
b.CallWithResult(call->DetachResult(), fn, Vector<core::ir::Value*, 1>{arg});
});
call->Destroy();
}
void Inverse(spirv::ir::BuiltinCall* call) {
auto* arg = call->Args()[0];
auto* mat_ty = arg->Type()->As<core::type::Matrix>();
TINT_ASSERT(mat_ty);
TINT_ASSERT(mat_ty->Columns() == mat_ty->Rows());
auto* elem_ty = mat_ty->Type();
b.InsertBefore(call, [&] {
auto* det =
b.Call(elem_ty, core::BuiltinFn::kDeterminant, Vector<core::ir::Value*, 1>{arg});
core::ir::Value* one = nullptr;
if (elem_ty->Is<core::type::F32>()) {
one = b.Constant(1.0_f);
} else if (elem_ty->Is<core::type::F16>()) {
one = b.Constant(1.0_h);
} else {
TINT_UNREACHABLE();
}
auto* inv_det = b.Divide(elem_ty, one, det);
// Returns (m * n) - (o * p)
auto sub_mul2 = [&](auto* m, auto* n, auto* o, auto* p) {
auto* x = b.Multiply(elem_ty, m, n);
auto* y = b.Multiply(elem_ty, o, p);
return b.Subtract(elem_ty, x, y);
};
// Returns (m * n) - (o * p) + (q * r)
auto sub_add_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) {
auto* w = b.Multiply(elem_ty, m, n);
auto* x = b.Multiply(elem_ty, o, p);
auto* y = b.Multiply(elem_ty, q, r);
auto* z = b.Subtract(elem_ty, w, x);
return b.Add(elem_ty, z, y);
};
// Returns (m * n) + (o * p) - (q * r)
auto add_sub_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) {
auto* w = b.Multiply(elem_ty, m, n);
auto* x = b.Multiply(elem_ty, o, p);
auto* y = b.Multiply(elem_ty, q, r);
auto* z = b.Add(elem_ty, w, x);
return b.Subtract(elem_ty, z, y);
};
switch (mat_ty->Columns()) {
case 2: {
auto* neg_inv_det = b.Negation(elem_ty, inv_det);
auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
auto* mc = b.Access(elem_ty, arg, 1_u, 0_u);
auto* md = b.Access(elem_ty, arg, 1_u, 1_u);
auto* r_00 = b.Multiply(elem_ty, inv_det, md);
auto* r_01 = b.Multiply(elem_ty, neg_inv_det, mb);
auto* r_10 = b.Multiply(elem_ty, neg_inv_det, mc);
auto* r_11 = b.Multiply(elem_ty, inv_det, ma);
auto* r1 = b.Construct(ty.vec2(elem_ty), r_00, r_01);
auto* r2 = b.Construct(ty.vec2(elem_ty), r_10, r_11);
b.ConstructWithResult(call->DetachResult(), r1, r2);
break;
}
case 3: {
auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
auto* mc = b.Access(elem_ty, arg, 0_u, 2_u);
auto* md = b.Access(elem_ty, arg, 1_u, 0_u);
auto* me = b.Access(elem_ty, arg, 1_u, 1_u);
auto* mf = b.Access(elem_ty, arg, 1_u, 2_u);
auto* mg = b.Access(elem_ty, arg, 2_u, 0_u);
auto* mh = b.Access(elem_ty, arg, 2_u, 1_u);
auto* mi = b.Access(elem_ty, arg, 2_u, 2_u);
// e * i - f * h
auto* r_00 = sub_mul2(me, mi, mf, mh);
// c * h - b * i
auto* r_01 = sub_mul2(mc, mh, mb, mi);
// b * f - c * e
auto* r_02 = sub_mul2(mb, mf, mc, me);
// f * g - d * i
auto* r_10 = sub_mul2(mf, mg, md, mi);
// a * i - c * g
auto* r_11 = sub_mul2(ma, mi, mc, mg);
// c * d - a * f
auto* r_12 = sub_mul2(mc, md, ma, mf);
// d * h - e * g
auto* r_20 = sub_mul2(md, mh, me, mg);
// b * g - a * h
auto* r_21 = sub_mul2(mb, mg, ma, mh);
// a * e - b * d
auto* r_22 = sub_mul2(ma, me, mb, md);
auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02);
auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12);
auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22);
auto* m = b.Construct(mat_ty, r1, r2, r3);
auto* inv = b.Multiply(mat_ty, inv_det, m);
call->Result(0)->ReplaceAllUsesWith(inv->Result(0));
break;
}
case 4: {
auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
auto* mc = b.Access(elem_ty, arg, 0_u, 2_u);
auto* md = b.Access(elem_ty, arg, 0_u, 3_u);
auto* me = b.Access(elem_ty, arg, 1_u, 0_u);
auto* mf = b.Access(elem_ty, arg, 1_u, 1_u);
auto* mg = b.Access(elem_ty, arg, 1_u, 2_u);
auto* mh = b.Access(elem_ty, arg, 1_u, 3_u);
auto* mi = b.Access(elem_ty, arg, 2_u, 0_u);
auto* mj = b.Access(elem_ty, arg, 2_u, 1_u);
auto* mk = b.Access(elem_ty, arg, 2_u, 2_u);
auto* ml = b.Access(elem_ty, arg, 2_u, 3_u);
auto* mm = b.Access(elem_ty, arg, 3_u, 0_u);
auto* mn = b.Access(elem_ty, arg, 3_u, 1_u);
auto* mo = b.Access(elem_ty, arg, 3_u, 2_u);
auto* mp = b.Access(elem_ty, arg, 3_u, 3_u);
// kplo = k * p - l * o
auto* kplo = sub_mul2(mk, mp, ml, mo);
// jpln = j * p - l * n
auto* jpln = sub_mul2(mj, mp, ml, mn);
// jokn = j * o - k * n;
auto* jokn = sub_mul2(mj, mo, mk, mn);
// gpho = g * p - h * o
auto* gpho = sub_mul2(mg, mp, mh, mo);
// fphn = f * p - h * n
auto* fphn = sub_mul2(mf, mp, mh, mn);
// fogn = f * o - g * n;
auto* fogn = sub_mul2(mf, mo, mg, mn);
// glhk = g * l - h * k
auto* glhk = sub_mul2(mg, ml, mh, mk);
// flhj = f * l - h * j
auto* flhj = sub_mul2(mf, ml, mh, mj);
// fkgj = f * k - g * j;
auto* fkgj = sub_mul2(mf, mk, mg, mj);
// iplm = i * p - l * m
auto* iplm = sub_mul2(mi, mp, ml, mm);
// iokm = i * o - k * m
auto* iokm = sub_mul2(mi, mo, mk, mm);
// ephm = e * p - h * m;
auto* ephm = sub_mul2(me, mp, mh, mm);
// eogm = e * o - g * m
auto* eogm = sub_mul2(me, mo, mg, mm);
// elhi = e * l - h * i
auto* elhi = sub_mul2(me, ml, mh, mi);
// ekgi = e * k - g * i;
auto* ekgi = sub_mul2(me, mk, mg, mi);
// injm = i * n - j * m
auto* injm = sub_mul2(mi, mn, mj, mm);
// enfm = e * n - f * m
auto* enfm = sub_mul2(me, mn, mf, mm);
// ejfi = e * j - f * i;
auto* ejfi = sub_mul2(me, mj, mf, mi);
auto* neg_b = b.Negation(elem_ty, mb);
// f * kplo - g * jpln + h * jokn
auto* r_00 = sub_add_mul3(mf, kplo, mg, jpln, mh, jokn);
// -b * kplo + c * jpln - d * jokn
auto* r_01 = add_sub_mul3(neg_b, kplo, mc, jpln, md, jokn);
// b * gpho - c * fphn + d * fogn
auto* r_02 = sub_add_mul3(mb, gpho, mc, fphn, md, fogn);
// -b * glhk + c * flhj - d * fkgj
auto* r_03 = add_sub_mul3(neg_b, glhk, mc, flhj, md, fkgj);
auto* neg_e = b.Negation(elem_ty, me);
auto* neg_a = b.Negation(elem_ty, ma);
// -e * kplo + g * iplm - h * iokm
auto* r_10 = add_sub_mul3(neg_e, kplo, mg, iplm, mh, iokm);
// a * kplo - c * iplm + d * iokm
auto* r_11 = sub_add_mul3(ma, kplo, mc, iplm, md, iokm);
// -a * gpho + c * ephm - d * eogm
auto* r_12 = add_sub_mul3(neg_a, gpho, mc, ephm, md, eogm);
// a * glhk - c * elhi + d * ekgi
auto* r_13 = sub_add_mul3(ma, glhk, mc, elhi, md, ekgi);
// e * jpln - f * iplm + h * injm
auto* r_20 = sub_add_mul3(me, jpln, mf, iplm, mh, injm);
// -a * jpln + b * iplm - d * injm
auto* r_21 = add_sub_mul3(neg_a, jpln, mb, iplm, md, injm);
// a * fphn - b * ephm + d * enfm
auto* r_22 = sub_add_mul3(ma, fphn, mb, ephm, md, enfm);
// -a * flhj + b * elhi - d * ejfi
auto* r_23 = add_sub_mul3(neg_a, flhj, mb, elhi, md, ejfi);
// -e * jokn + f * iokm - g * injm
auto* r_30 = add_sub_mul3(neg_e, jokn, mf, iokm, mg, injm);
// a * jokn - b * iokm + c * injm
auto* r_31 = sub_add_mul3(ma, jokn, mb, iokm, mc, injm);
// -a * fogn + b * eogm - c * enfm
auto* r_32 = add_sub_mul3(neg_a, fogn, mb, eogm, mc, enfm);
// a * fkgj - b * ekgi + c * ejfi
auto* r_33 = sub_add_mul3(ma, fkgj, mb, ekgi, mc, ejfi);
auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02, r_03);
auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12, r_13);
auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22, r_23);
auto* r4 = b.Construct(ty.vec3(elem_ty), r_30, r_31, r_32, r_33);
auto* m = b.Construct(mat_ty, r1, r2, r3, r4);
auto* inv = b.Multiply(mat_ty, inv_det, m);
call->Result(0)->ReplaceAllUsesWith(inv->Result(0));
break;
}
default: {
TINT_UNREACHABLE();
}
}
});
call->Destroy();
}
};
} // namespace
Result<SuccessType> Builtins(core::ir::Module& ir) {
auto result = ValidateAndDumpIfNeeded(ir, "spirv.Builtins");
if (result != Success) {
return result.Failure();
}
State{ir}.Process();
return Success;
}
} // namespace tint::spirv::reader::lower