blob: ec04027a9a4be8c05463125a6da9700b24ea6976 [file] [log] [blame]
// Copyright 2024 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/spirv/reader/lower/builtins.h"
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/builtin_structs.h"
#include "src/tint/lang/spirv/ir/builtin_call.h"
namespace tint::spirv::reader::lower {
namespace {
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
/// PIMPL state for the transform.
struct State {
/// The IR module.
core::ir::Module& ir;
/// The IR builder.
core::ir::Builder b{ir};
/// The type manager.
core::type::Manager& ty{ir.Types()};
/// Process the module.
void Process() {
Vector<spirv::ir::BuiltinCall*, 4> builtin_worklist;
for (auto* inst : ir.Instructions()) {
if (auto* builtin = inst->As<spirv::ir::BuiltinCall>()) {
builtin_worklist.Push(builtin);
}
}
// Replace the builtins that we found.
for (auto* builtin : builtin_worklist) {
switch (builtin->Func()) {
case spirv::BuiltinFn::kNormalize:
Normalize(builtin);
break;
case spirv::BuiltinFn::kInverse:
Inverse(builtin);
break;
case spirv::BuiltinFn::kSign:
Sign(builtin);
break;
case spirv::BuiltinFn::kAbs:
Abs(builtin);
break;
case spirv::BuiltinFn::kSMax:
SMax(builtin);
break;
case spirv::BuiltinFn::kSMin:
SMin(builtin);
break;
case spirv::BuiltinFn::kSClamp:
SClamp(builtin);
break;
case spirv::BuiltinFn::kUMax:
UMax(builtin);
break;
case spirv::BuiltinFn::kUMin:
UMin(builtin);
break;
case spirv::BuiltinFn::kUClamp:
UClamp(builtin);
break;
case spirv::BuiltinFn::kFindILsb:
FindILsb(builtin);
break;
case spirv::BuiltinFn::kFindSMsb:
FindSMsb(builtin);
break;
case spirv::BuiltinFn::kFindUMsb:
FindUMsb(builtin);
break;
case spirv::BuiltinFn::kRefract:
Refract(builtin);
break;
case spirv::BuiltinFn::kReflect:
Reflect(builtin);
break;
case spirv::BuiltinFn::kFaceForward:
FaceForward(builtin);
break;
case spirv::BuiltinFn::kLdexp:
Ldexp(builtin);
break;
case spirv::BuiltinFn::kModf:
Modf(builtin);
break;
case spirv::BuiltinFn::kFrexp:
Frexp(builtin);
break;
case spirv::BuiltinFn::kBitCount:
BitCount(builtin);
break;
case spirv::BuiltinFn::kBitFieldInsert:
BitFieldInsert(builtin);
break;
case spirv::BuiltinFn::kBitFieldSExtract:
BitFieldSExtract(builtin);
break;
case spirv::BuiltinFn::kBitFieldUExtract:
BitFieldUExtract(builtin);
break;
case spirv::BuiltinFn::kAdd:
Add(builtin);
break;
case spirv::BuiltinFn::kSub:
Sub(builtin);
break;
case spirv::BuiltinFn::kMul:
Mul(builtin);
break;
case spirv::BuiltinFn::kSDiv:
SDiv(builtin);
break;
case spirv::BuiltinFn::kSMod:
SMod(builtin);
break;
case spirv::BuiltinFn::kConvertFToS:
ConvertFToS(builtin);
break;
case spirv::BuiltinFn::kConvertSToF:
ConvertSToF(builtin);
break;
case spirv::BuiltinFn::kConvertUToF:
ConvertUToF(builtin);
break;
case spirv::BuiltinFn::kBitwiseAnd:
BitwiseAnd(builtin);
break;
case spirv::BuiltinFn::kBitwiseOr:
BitwiseOr(builtin);
break;
case spirv::BuiltinFn::kBitwiseXor:
BitwiseXor(builtin);
break;
case spirv::BuiltinFn::kEqual:
Equal(builtin);
break;
case spirv::BuiltinFn::kNotEqual:
NotEqual(builtin);
break;
case spirv::BuiltinFn::kSGreaterThan:
SGreaterThan(builtin);
break;
case spirv::BuiltinFn::kSGreaterThanEqual:
SGreaterThanEqual(builtin);
break;
case spirv::BuiltinFn::kSLessThan:
SLessThan(builtin);
break;
case spirv::BuiltinFn::kSLessThanEqual:
SLessThanEqual(builtin);
break;
case spirv::BuiltinFn::kUGreaterThan:
UGreaterThan(builtin);
break;
case spirv::BuiltinFn::kUGreaterThanEqual:
UGreaterThanEqual(builtin);
break;
case spirv::BuiltinFn::kULessThan:
ULessThan(builtin);
break;
case spirv::BuiltinFn::kULessThanEqual:
ULessThanEqual(builtin);
break;
case spirv::BuiltinFn::kShiftLeftLogical:
ShiftLeftLogical(builtin);
break;
case spirv::BuiltinFn::kShiftRightLogical:
ShiftRightLogical(builtin);
break;
case spirv::BuiltinFn::kShiftRightArithmetic:
ShiftRightArithmetic(builtin);
break;
case spirv::BuiltinFn::kNot:
Not(builtin);
break;
case spirv::BuiltinFn::kSNegate:
SNegate(builtin);
break;
case spirv::BuiltinFn::kFMod:
FMod(builtin);
break;
case spirv::BuiltinFn::kSelect:
Select(builtin);
break;
case spirv::BuiltinFn::kOuterProduct:
OuterProduct(builtin);
break;
case spirv::BuiltinFn::kAtomicLoad:
case spirv::BuiltinFn::kAtomicStore:
case spirv::BuiltinFn::kAtomicExchange:
case spirv::BuiltinFn::kAtomicCompareExchange:
case spirv::BuiltinFn::kAtomicIAdd:
case spirv::BuiltinFn::kAtomicISub:
case spirv::BuiltinFn::kAtomicSMax:
case spirv::BuiltinFn::kAtomicSMin:
case spirv::BuiltinFn::kAtomicUMax:
case spirv::BuiltinFn::kAtomicUMin:
case spirv::BuiltinFn::kAtomicAnd:
case spirv::BuiltinFn::kAtomicOr:
case spirv::BuiltinFn::kAtomicXor:
case spirv::BuiltinFn::kAtomicIIncrement:
case spirv::BuiltinFn::kAtomicIDecrement:
// Ignore Atomics, they'll be handled by the `Atomics` transform.
break;
case spirv::BuiltinFn::kSampledImage:
case spirv::BuiltinFn::kImageRead:
case spirv::BuiltinFn::kImageFetch:
case spirv::BuiltinFn::kImageGather:
case spirv::BuiltinFn::kImageQueryLevels:
case spirv::BuiltinFn::kImageQuerySamples:
case spirv::BuiltinFn::kImageQuerySize:
case spirv::BuiltinFn::kImageQuerySizeLod:
case spirv::BuiltinFn::kImageSampleExplicitLod:
case spirv::BuiltinFn::kImageSampleImplicitLod:
case spirv::BuiltinFn::kImageSampleProjImplicitLod:
case spirv::BuiltinFn::kImageSampleProjExplicitLod:
case spirv::BuiltinFn::kImageSampleDrefImplicitLod:
case spirv::BuiltinFn::kImageWrite:
// Ignore image methods, they'll be handled by the `Texture` transform.
break;
default:
TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
}
}
}
/// Lowers OuterProduct to an explicit matrix construct.
/// Column `c` of the result holds `vector1[r] * vector2[c]` for every row `r`.
/// @param call the SPIR-V builtin call to replace
void OuterProduct(spirv::ir::BuiltinCall* call) {
    auto* lhs_vec = call->Args()[0];
    auto* rhs_vec = call->Args()[1];

    // The first operand supplies the rows, the second operand supplies the columns.
    const uint32_t num_rows = lhs_vec->Type()->As<core::type::Vector>()->Width();
    const uint32_t num_cols = rhs_vec->Type()->As<core::type::Vector>()->Width();
    auto* scalar_ty = lhs_vec->Type()->DeepestElement();

    b.InsertBefore(call, [&] {
        Vector<core::ir::Value*, 4> columns;
        for (uint32_t c = 0; c != num_cols; ++c) {
            Vector<core::ir::Value*, 4> products;
            auto* scale = b.Access(scalar_ty, rhs_vec, u32(c));
            for (uint32_t r = 0; r != num_rows; ++r) {
                auto* component = b.Access(scalar_ty, lhs_vec, u32(r));
                auto* product = b.Multiply(scalar_ty, component, scale)->Result();
                products.Push(product);
            }
            auto* column = b.Construct(ty.vec(scalar_ty, num_rows), products)->Result();
            columns.Push(column);
        }
        // Reuse the original result so existing uses now refer to the constructed matrix.
        b.ConstructWithResult(call->DetachResult(), columns);
    });
    call->Destroy();
}
/// Lowers the SPIR-V Select builtin to the core `select` builtin.
/// The SPIR-V call's arguments are (condition, true value, false value); the core call
/// is emitted with them in the order (false value, true value, condition).
/// @param call the SPIR-V builtin call to replace
void Select(spirv::ir::BuiltinCall* call) {
    const auto& operands = call->Args();
    auto* condition = operands[0];
    auto* when_true = operands[1];
    auto* when_false = operands[2];

    b.InsertBefore(call, [&] {
        b.CallWithResult(call->DetachResult(), core::BuiltinFn::kSelect, when_false, when_true,
                         condition);
    });
    call->Destroy();
}
/// Lowers FMod(x, y) to the expression `x - y * floor(x / y)`.
/// @param call the SPIR-V builtin call to replace
void FMod(spirv::ir::BuiltinCall* call) {
    auto* dividend = call->Args()[0];
    auto* divisor = call->Args()[1];
    auto* out_ty = call->Result()->Type();

    b.InsertBefore(call, [&] {
        auto* quotient = b.Divide(out_ty, dividend, divisor);
        auto* floored = b.Call(out_ty, core::BuiltinFn::kFloor, quotient);
        auto* scaled = b.Multiply(out_ty, divisor, floored);
        auto* remainder = b.Subtract(out_ty, dividend, scaled);
        call->Result()->ReplaceAllUsesWith(remainder->Result());
    });
    call->Destroy();
}
/// Lowers SNegate to a core negation performed on a signed type.
/// The operand is bitcast to the signed type of matching width when needed, and the
/// negated value is bitcast to the call's result type when that differs.
/// @param call the SPIR-V builtin call to replace
void SNegate(spirv::ir::BuiltinCall* call) {
    core::ir::Value* current = call->Args()[0];
    auto* out_ty = call->Result()->Type();
    auto* signed_ty = ty.MatchWidth(ty.i32(), current->Type());

    b.InsertBefore(call, [&] {
        const bool operand_needs_cast = current->Type() != signed_ty;
        if (operand_needs_cast) {
            current = b.Bitcast(signed_ty, current)->Result();
        }

        current = b.Negation(signed_ty, current)->Result();

        const bool result_needs_cast = signed_ty != out_ty;
        if (result_needs_cast) {
            current = b.Bitcast(out_ty, current)->Result();
        }
        call->Result()->ReplaceAllUsesWith(current);
    });
    call->Destroy();
}
/// Lowers Not to a core complement on the operand's own type, bitcasting the outcome
/// to the call's result type when the two types differ.
/// @param call the SPIR-V builtin call to replace
void Not(spirv::ir::BuiltinCall* call) {
    auto* operand = call->Args()[0];
    auto* out_ty = call->Result()->Type();

    b.InsertBefore(call, [&] {
        auto* in_ty = operand->Type();
        auto* flipped = b.Complement(in_ty, operand)->Result();
        if (in_ty != out_ty) {
            flipped = b.Bitcast(out_ty, flipped)->Result();
        }
        call->Result()->ReplaceAllUsesWith(flipped);
    });
    call->Destroy();
}
/// Lowers ConvertSToF to a core convert. An unsigned operand is first bitcast to the
/// signed integer type whose width matches the result type.
/// @param call the SPIR-V builtin call to replace
void ConvertSToF(spirv::ir::BuiltinCall* call) {
    b.InsertBefore(call, [&] {
        auto* float_ty = call->Result()->Type();
        core::ir::Value* source = call->Args()[0];

        const bool is_unsigned = source->Type()->IsUnsignedIntegerScalarOrVector();
        if (is_unsigned) {
            auto* signed_ty = ty.MatchWidth(ty.i32(), float_ty);
            source = b.Bitcast(signed_ty, source)->Result();
        }
        b.ConvertWithResult(call->DetachResult(), source);
    });
    call->Destroy();
}
/// Lowers ConvertUToF to a core convert. A signed operand is first bitcast to the
/// unsigned integer type whose width matches the result type.
/// @param call the SPIR-V builtin call to replace
void ConvertUToF(spirv::ir::BuiltinCall* call) {
    b.InsertBefore(call, [&] {
        auto* float_ty = call->Result()->Type();
        core::ir::Value* source = call->Args()[0];

        const bool is_signed = source->Type()->IsSignedIntegerScalarOrVector();
        if (is_signed) {
            auto* unsigned_ty = ty.MatchWidth(ty.u32(), float_ty);
            source = b.Bitcast(unsigned_ty, source)->Result();
        }
        b.ConvertWithResult(call->DetachResult(), source);
    });
    call->Destroy();
}
/// Lowers ConvertFToS to a core convert into a signed integer type, followed by a
/// bitcast to the call's result type when that result type is unsigned.
/// @param call the SPIR-V builtin call to replace
void ConvertFToS(spirv::ir::BuiltinCall* call) {
    b.InsertBefore(call, [&] {
        auto* out_ty = call->Result()->Type();
        auto* signed_ty = ty.MatchWidth(ty.i32(), out_ty);

        auto* converted = b.Convert(signed_ty, call->Args()[0])->Result();
        if (out_ty->DeepestElement()->IsUnsignedIntegerScalar()) {
            converted = b.Bitcast(out_ty, converted)->Result();
        }
        call->Result()->ReplaceAllUsesWith(converted);
    });
    call->Destroy();
}
/// Replaces a two-operand SPIR-V builtin call with a core binary instruction that
/// operates on the type of the first operand. The second operand is bitcast to that
/// type when it differs, and the binary result is bitcast to the call's result type
/// when that differs.
/// @param call the SPIR-V builtin call to replace
/// @param op the core binary operator to emit
void EmitBinaryWrappedAsFirstArg(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
    const auto& operands = call->Args();
    core::ir::Value* first = operands[0];
    core::ir::Value* second = operands[1];
    auto* out_ty = call->Result()->Type();
    auto* work_ty = first->Type();

    b.InsertBefore(call, [&] {
        if (second->Type() != work_ty) {
            second = b.Bitcast(work_ty, second)->Result();
        }

        auto* value = b.Binary(op, work_ty, first, second)->Result();
        if (out_ty != work_ty) {
            value = b.Bitcast(out_ty, value)->Result();
        }
        call->Result()->ReplaceAllUsesWith(value);
    });
    call->Destroy();
}
/// Lowers the SPIR-V BitwiseAnd builtin to a core `kAnd` binary instruction via
/// EmitBinaryWrappedAsFirstArg.
/// @param call the SPIR-V builtin call to replace
void BitwiseAnd(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kAnd);
}
/// Lowers the SPIR-V BitwiseOr builtin to a core `kOr` binary instruction via
/// EmitBinaryWrappedAsFirstArg.
/// @param call the SPIR-V builtin call to replace
void BitwiseOr(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kOr);
}
/// Lowers the SPIR-V BitwiseXor builtin to a core `kXor` binary instruction via
/// EmitBinaryWrappedAsFirstArg.
/// @param call the SPIR-V builtin call to replace
void BitwiseXor(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kXor);
}
/// Lowers the SPIR-V Add builtin to a core `kAdd` binary instruction via
/// EmitBinaryWrappedAsFirstArg.
/// @param call the SPIR-V builtin call to replace
void Add(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kAdd);
}
/// Lowers the SPIR-V Sub builtin to a core `kSubtract` binary instruction via
/// EmitBinaryWrappedAsFirstArg.
/// @param call the SPIR-V builtin call to replace
void Sub(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kSubtract);
}
/// Lowers the SPIR-V Mul builtin to a core `kMultiply` binary instruction via
/// EmitBinaryWrappedAsFirstArg.
/// @param call the SPIR-V builtin call to replace
void Mul(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kMultiply);
}
/// Replaces a two-operand SPIR-V builtin call with a core binary instruction evaluated
/// on the signed integer type whose width matches the call's result type. Operands of
/// any other type are bitcast to that signed type first, and the binary result is
/// bitcast to the call's result type when that differs.
/// @param call the SPIR-V builtin call to replace
/// @param op the core binary operator to emit
void EmitBinaryWrappedSignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
    auto* out_ty = call->Result()->Type();
    auto* signed_ty = ty.MatchWidth(ty.i32(), out_ty);
    const auto& operands = call->Args();
    core::ir::Value* left = operands[0];
    core::ir::Value* right = operands[1];

    b.InsertBefore(call, [&] {
        // Reinterprets `v` as the signed working type unless it already has that type.
        auto as_signed = [&](core::ir::Value* v) -> core::ir::Value* {
            if (v->Type() == signed_ty) {
                return v;
            }
            return b.Bitcast(signed_ty, v)->Result();
        };
        left = as_signed(left);
        right = as_signed(right);

        auto* value = b.Binary(op, signed_ty, left, right)->Result();
        if (out_ty != signed_ty) {
            value = b.Bitcast(out_ty, value)->Result();
        }
        call->Result()->ReplaceAllUsesWith(value);
    });
    call->Destroy();
}
/// Lowers the SPIR-V SDiv builtin to a core `kDivide` binary instruction on signed
/// operands via EmitBinaryWrappedSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void SDiv(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedSignedSpirvMethods(call, core::BinaryOp::kDivide);
}
/// Lowers the SPIR-V SMod builtin to a core `kModulo` binary instruction on signed
/// operands via EmitBinaryWrappedSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void SMod(spirv::ir::BuiltinCall* call) {
EmitBinaryWrappedSignedSpirvMethods(call, core::BinaryOp::kModulo);
}
/// Replaces a two-operand SPIR-V builtin call with a core binary instruction that
/// reuses the call's result. The second operand is bitcast to the first operand's type
/// when the two types differ.
/// @param call the SPIR-V builtin call to replace
/// @param op the core binary operator to emit
void EmitBinaryMatchedArgs(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
    const auto& operands = call->Args();
    core::ir::Value* first = operands[0];
    core::ir::Value* second = operands[1];

    b.InsertBefore(call, [&] {
        auto* wanted_ty = first->Type();
        if (second->Type() != wanted_ty) {
            second = b.Bitcast(wanted_ty, second)->Result();
        }
        b.BinaryWithResult(call->DetachResult(), op, first, second);
    });
    call->Destroy();
}
/// Lowers the SPIR-V Equal builtin to a core `kEqual` binary instruction via
/// EmitBinaryMatchedArgs.
/// @param call the SPIR-V builtin call to replace
void Equal(spirv::ir::BuiltinCall* call) {
EmitBinaryMatchedArgs(call, core::BinaryOp::kEqual);
}
/// Lowers the SPIR-V NotEqual builtin to a core `kNotEqual` binary instruction via
/// EmitBinaryMatchedArgs.
/// @param call the SPIR-V builtin call to replace
void NotEqual(spirv::ir::BuiltinCall* call) {
EmitBinaryMatchedArgs(call, core::BinaryOp::kNotEqual);
}
/// Replaces a two-operand SPIR-V builtin call with a core binary instruction that
/// reuses the call's result. Both operands are bitcast, when needed, to the signed
/// integer type whose width matches the call's result type.
/// @param call the SPIR-V builtin call to replace
/// @param op the core binary operator to emit
void EmitBinaryWithSignedArgs(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
    const auto& operands = call->Args();
    core::ir::Value* left = operands[0];
    core::ir::Value* right = operands[1];
    auto* signed_ty = ty.MatchWidth(ty.i32(), call->Result()->Type());

    b.InsertBefore(call, [&] {
        // Reinterprets `v` as the signed operand type unless it already has that type.
        auto as_signed = [&](core::ir::Value* v) -> core::ir::Value* {
            if (v->Type() == signed_ty) {
                return v;
            }
            return b.Bitcast(signed_ty, v)->Result();
        };
        left = as_signed(left);
        right = as_signed(right);
        b.BinaryWithResult(call->DetachResult(), op, left, right);
    });
    call->Destroy();
}
/// Lowers the SPIR-V SGreaterThan builtin to a core `kGreaterThan` comparison on signed
/// operands via EmitBinaryWithSignedArgs.
/// @param call the SPIR-V builtin call to replace
void SGreaterThan(spirv::ir::BuiltinCall* call) {
EmitBinaryWithSignedArgs(call, core::BinaryOp::kGreaterThan);
}
/// Lowers the SPIR-V SGreaterThanEqual builtin to a core `kGreaterThanEqual` comparison
/// on signed operands via EmitBinaryWithSignedArgs.
/// @param call the SPIR-V builtin call to replace
void SGreaterThanEqual(spirv::ir::BuiltinCall* call) {
EmitBinaryWithSignedArgs(call, core::BinaryOp::kGreaterThanEqual);
}
/// Lowers the SPIR-V SLessThan builtin to a core `kLessThan` comparison on signed
/// operands via EmitBinaryWithSignedArgs.
/// @param call the SPIR-V builtin call to replace
void SLessThan(spirv::ir::BuiltinCall* call) {
EmitBinaryWithSignedArgs(call, core::BinaryOp::kLessThan);
}
/// Lowers the SPIR-V SLessThanEqual builtin to a core `kLessThanEqual` comparison on
/// signed operands via EmitBinaryWithSignedArgs.
/// @param call the SPIR-V builtin call to replace
void SLessThanEqual(spirv::ir::BuiltinCall* call) {
EmitBinaryWithSignedArgs(call, core::BinaryOp::kLessThanEqual);
}
/// Replaces a two-operand SPIR-V builtin call with a core binary instruction that
/// reuses the call's result. Both operands are bitcast, when needed, to the unsigned
/// integer type whose width matches the call's result type.
/// @param call the SPIR-V builtin call to replace
/// @param op the core binary operator to emit
void EmitBinaryWithUnsignedArgs(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
    const auto& operands = call->Args();
    core::ir::Value* left = operands[0];
    core::ir::Value* right = operands[1];
    auto* unsigned_ty = ty.MatchWidth(ty.u32(), call->Result()->Type());

    b.InsertBefore(call, [&] {
        // Reinterprets `v` as the unsigned operand type unless it already has that type.
        auto as_unsigned = [&](core::ir::Value* v) -> core::ir::Value* {
            if (v->Type() == unsigned_ty) {
                return v;
            }
            return b.Bitcast(unsigned_ty, v)->Result();
        };
        left = as_unsigned(left);
        right = as_unsigned(right);
        b.BinaryWithResult(call->DetachResult(), op, left, right);
    });
    call->Destroy();
}
/// Lowers the SPIR-V UGreaterThan builtin to a core `kGreaterThan` comparison on
/// unsigned operands via EmitBinaryWithUnsignedArgs.
/// @param call the SPIR-V builtin call to replace
void UGreaterThan(spirv::ir::BuiltinCall* call) {
EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kGreaterThan);
}
/// Lowers the SPIR-V UGreaterThanEqual builtin to a core `kGreaterThanEqual` comparison
/// on unsigned operands via EmitBinaryWithUnsignedArgs.
/// @param call the SPIR-V builtin call to replace
void UGreaterThanEqual(spirv::ir::BuiltinCall* call) {
EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kGreaterThanEqual);
}
/// Lowers the SPIR-V ULessThan builtin to a core `kLessThan` comparison on unsigned
/// operands via EmitBinaryWithUnsignedArgs.
/// @param call the SPIR-V builtin call to replace
void ULessThan(spirv::ir::BuiltinCall* call) {
EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kLessThan);
}
/// Lowers the SPIR-V ULessThanEqual builtin to a core `kLessThanEqual` comparison on
/// unsigned operands via EmitBinaryWithUnsignedArgs.
/// @param call the SPIR-V builtin call to replace
void ULessThanEqual(spirv::ir::BuiltinCall* call) {
EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kLessThanEqual);
}
// The SPIR-V signed methods treat every argument as signed, whatever its declared type.
// To honour that, each unsigned argument is bitcast to the signed type of matching width
// before the equivalent core builtin is called.
//
// The core builtin then yields a signed value. When the SPIR-V instruction declared an
// unsigned result, the core call's result is retyped as signed and the value is bitcast
// back to the declared unsigned type.
void WrapSignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BuiltinFn func) {
    auto operands = call->Args();

    b.InsertBefore(call, [&] {
        auto* out_ty = call->Result()->Type();

        Vector<core::ir::Value*, 2> signed_operands;
        for (auto* operand : operands) {
            core::ir::Value* as_signed = operand;
            if (as_signed->Type()->IsUnsignedIntegerScalarOrVector()) {
                as_signed = b.Bitcast(ty.MatchWidth(ty.i32(), out_ty), as_signed)->Result();
            }
            signed_operands.Push(as_signed);
        }

        auto* lowered = b.Call(out_ty, func, signed_operands);
        core::ir::Value* replacement = lowered->Result();

        const bool wants_unsigned = out_ty->DeepestElement() == ty.u32();
        if (wants_unsigned) {
            // The core call produces a signed value, so fix up its type and cast back.
            lowered->Result()->SetType(ty.MatchWidth(ty.i32(), out_ty));
            replacement = b.Bitcast(out_ty, replacement)->Result();
        }
        call->Result()->ReplaceAllUsesWith(replacement);
    });
    call->Destroy();
}
/// Lowers the SPIR-V Sign builtin to the core `sign` builtin via WrapSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void Sign(spirv::ir::BuiltinCall* call) {
WrapSignedSpirvMethods(call, core::BuiltinFn::kSign);
}
/// Lowers the SPIR-V Abs builtin to the core `abs` builtin via WrapSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void Abs(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kAbs); }
/// Lowers the SPIR-V FindSMsb builtin to the core `firstLeadingBit` builtin via
/// WrapSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void FindSMsb(spirv::ir::BuiltinCall* call) {
WrapSignedSpirvMethods(call, core::BuiltinFn::kFirstLeadingBit);
}
/// Lowers the SPIR-V SMax builtin to the core `max` builtin via WrapSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void SMax(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMax); }
/// Lowers the SPIR-V SMin builtin to the core `min` builtin via WrapSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void SMin(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMin); }
/// Lowers the SPIR-V SClamp builtin to the core `clamp` builtin via
/// WrapSignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void SClamp(spirv::ir::BuiltinCall* call) {
WrapSignedSpirvMethods(call, core::BuiltinFn::kClamp);
}
/// Lowers the SPIR-V Ldexp builtin to the core `ldexp` builtin via
/// WrapSignedSpirvMethods, so any unsigned argument is bitcast to a signed type first.
/// @param call the SPIR-V builtin call to replace
void Ldexp(spirv::ir::BuiltinCall* call) {
WrapSignedSpirvMethods(call, core::BuiltinFn::kLdexp);
}
// The SPIR-V unsigned methods treat every argument as unsigned, whatever its declared
// type. To honour that, each signed argument is bitcast to the unsigned type of matching
// width before the equivalent core builtin is called.
//
// The core builtin then yields an unsigned value. When the SPIR-V instruction declared a
// signed result, the core call's result is retyped as unsigned and the value is bitcast
// back to the declared signed type.
void WrapUnsignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BuiltinFn func) {
    auto operands = call->Args();

    b.InsertBefore(call, [&] {
        auto* out_ty = call->Result()->Type();

        Vector<core::ir::Value*, 2> unsigned_operands;
        for (auto* operand : operands) {
            core::ir::Value* as_unsigned = operand;
            if (as_unsigned->Type()->IsSignedIntegerScalarOrVector()) {
                as_unsigned = b.Bitcast(ty.MatchWidth(ty.u32(), out_ty), as_unsigned)->Result();
            }
            unsigned_operands.Push(as_unsigned);
        }

        auto* lowered = b.Call(out_ty, func, unsigned_operands);
        core::ir::Value* replacement = lowered->Result();

        const bool wants_signed = out_ty->DeepestElement() == ty.i32();
        if (wants_signed) {
            // The core call produces an unsigned value, so fix up its type and cast back.
            lowered->Result()->SetType(ty.MatchWidth(ty.u32(), out_ty));
            replacement = b.Bitcast(out_ty, replacement)->Result();
        }
        call->Result()->ReplaceAllUsesWith(replacement);
    });
    call->Destroy();
}
/// Lowers the SPIR-V UMax builtin to the core `max` builtin via WrapUnsignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void UMax(spirv::ir::BuiltinCall* call) {
WrapUnsignedSpirvMethods(call, core::BuiltinFn::kMax);
}
/// Lowers the SPIR-V UMin builtin to the core `min` builtin via WrapUnsignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void UMin(spirv::ir::BuiltinCall* call) {
WrapUnsignedSpirvMethods(call, core::BuiltinFn::kMin);
}
/// Lowers the SPIR-V UClamp builtin to the core `clamp` builtin via
/// WrapUnsignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void UClamp(spirv::ir::BuiltinCall* call) {
WrapUnsignedSpirvMethods(call, core::BuiltinFn::kClamp);
}
/// Lowers the SPIR-V FindUMsb builtin to the core `firstLeadingBit` builtin via
/// WrapUnsignedSpirvMethods.
/// @param call the SPIR-V builtin call to replace
void FindUMsb(spirv::ir::BuiltinCall* call) {
WrapUnsignedSpirvMethods(call, core::BuiltinFn::kFirstLeadingBit);
}
/// Lowers the SPIR-V Normalize builtin. A scalar operand is lowered to the core `sign`
/// builtin; any other operand is lowered to the core `normalize` builtin.
/// @param call the SPIR-V builtin call to replace
void Normalize(spirv::ir::BuiltinCall* call) {
    auto* operand = call->Args()[0];

    b.InsertBefore(call, [&] {
        const core::BuiltinFn target = operand->Type()->IsScalar()
                                           ? core::BuiltinFn::kSign
                                           : core::BuiltinFn::kNormalize;
        b.CallWithResult(call->DetachResult(), target, Vector<core::ir::Value*, 1>{operand});
    });
    call->Destroy();
}
/// Lowers the SPIR-V FindILsb builtin to the core `firstTrailingBit` builtin evaluated
/// on the operand's type, bitcasting to the call's result type when the types differ.
/// @param call the SPIR-V builtin call to replace
void FindILsb(spirv::ir::BuiltinCall* call) {
    auto* operand = call->Args()[0];

    b.InsertBefore(call, [&] {
        auto* in_ty = operand->Type();
        auto* out_ty = call->Result()->Type();

        auto* lowered = b.Call(in_ty, core::BuiltinFn::kFirstTrailingBit,
                               Vector<core::ir::Value*, 1>{operand});
        auto* value = lowered->Result();
        if (in_ty != out_ty) {
            value = b.Bitcast(out_ty, value)->Result();
        }
        call->Result()->ReplaceAllUsesWith(value);
    });
    call->Destroy();
}
/// Lowers the SPIR-V Refract builtin to the core `refract` builtin.
/// For float scalar operands the incident and normal values are widened to two-element
/// vectors padded with zero, `refract` is called on those, and the first component of
/// the result is extracted. Other operands are passed to `refract` unchanged.
/// @param call the SPIR-V builtin call to replace
void Refract(spirv::ir::BuiltinCall* call) {
    auto operands = call->Args();
    core::ir::Value* incident = operands[0];
    core::ir::Value* normal = operands[1];
    core::ir::Value* eta = operands[2];

    b.InsertBefore(call, [&] {
        if (!incident->Type()->IsFloatScalar()) {
            b.CallWithResult(call->DetachResult(), core::BuiltinFn::kRefract,
                             Vector<core::ir::Value*, 3>{incident, normal, eta});
            return;
        }

        auto* scalar_ty = incident->Type();
        auto* pair_ty = ty.vec(scalar_ty, 2);
        auto* padding = b.Zero(scalar_ty);
        incident = b.Construct(pair_ty, incident, padding)->Result();
        normal = b.Construct(pair_ty, normal, padding)->Result();

        auto* refracted = b.Call(pair_ty, core::BuiltinFn::kRefract,
                                 Vector<core::ir::Value*, 3>{incident, normal, eta});
        auto* first_component = b.Swizzle(scalar_ty, refracted, {0});
        call->Result()->ReplaceAllUsesWith(first_component->Result());
    });
    call->Destroy();
}
/// Lowers the SPIR-V Reflect builtin.
/// Float scalar operands are expanded inline to `I - ((I * N) * N) * 2`; all other
/// operands are passed to the core `reflect` builtin.
/// @param call the SPIR-V builtin call to replace
void Reflect(spirv::ir::BuiltinCall* call) {
    auto args = call->Args();
    auto* I = args[0];
    auto* N = args[1];
    b.InsertBefore(call, [&] {
        if (I->Type()->IsFloatScalar()) {
            auto* scalar_ty = I->Type();

            // The literal `2` must have the same type as the operands: an f32 literal
            // would not match f16 operands in the multiply below.
            core::ir::Value* two = nullptr;
            if (scalar_ty->Is<core::type::F16>()) {
                two = b.Constant(2.0_h);
            } else {
                two = b.Constant(2.0_f);
            }

            auto* v = b.Multiply(scalar_ty, I, N)->Result();
            v = b.Multiply(scalar_ty, v, N)->Result();
            v = b.Multiply(scalar_ty, v, two)->Result();
            v = b.Subtract(scalar_ty, I, v)->Result();
            call->Result()->ReplaceAllUsesWith(v);
        } else {
            b.CallWithResult(call->DetachResult(), core::BuiltinFn::kReflect,
                             Vector<core::ir::Value*, 2>{I, N});
        }
    });
    call->Destroy();
}
/// Lowers the SPIR-V FaceForward builtin.
/// For float scalar operands the result is selected inline: `N` when `I * Nref` is
/// less than zero, otherwise `-N`. Other operands are passed to the core `faceForward`
/// builtin.
/// @param call the SPIR-V builtin call to replace
void FaceForward(spirv::ir::BuiltinCall* call) {
    auto operands = call->Args();
    auto* normal = operands[0];
    auto* incident = operands[1];
    auto* reference = operands[2];

    b.InsertBefore(call, [&] {
        if (!incident->Type()->IsFloatScalar()) {
            b.CallWithResult(call->DetachResult(), core::BuiltinFn::kFaceForward, normal,
                             incident, reference);
            return;
        }

        auto* flipped = b.Negation(normal->Type(), normal);
        auto* product = b.Multiply(incident->Type(), incident, reference)->Result();
        auto* is_negative =
            b.LessThan(ty.bool_(), product, b.Zero(product->Type()))->Result();
        b.CallWithResult(call->DetachResult(), core::BuiltinFn::kSelect, flipped, normal,
                         is_negative);
    });
    call->Destroy();
}
/// Lowers the SPIR-V Modf builtin to the core `modf` builtin, which returns a struct.
/// Member 1 of that struct is stored through the call's second argument, and member 0
/// becomes the value of the original call.
/// @param call the SPIR-V builtin call to replace
void Modf(spirv::ir::BuiltinCall* call) {
    auto* value = call->Args()[0];
    auto* whole_ptr = call->Args()[1];
    auto* part_ty = call->Result()->Type();
    auto* struct_ty = core::type::CreateModfResult(ty, ir.symbols, part_ty);

    b.InsertBefore(call, [&] {
        auto* modf = b.Call(struct_ty, core::BuiltinFn::kModf, value);

        auto* whole_part = b.Access(part_ty, modf, 1_u);
        b.Store(whole_ptr, whole_part);

        b.AccessWithResult(call->DetachResult(), modf, 0_u);
    });
    call->Destroy();
}
/// Lowers the SPIR-V Frexp builtin to the core `frexp` builtin, which returns a struct.
/// Member 1 of that struct is read as a signed integer, bitcast when the destination
/// pointer stores an unsigned type, and stored through the call's second argument.
/// Member 0 becomes the value of the original call.
/// @param call the SPIR-V builtin call to replace
void Frexp(spirv::ir::BuiltinCall* call) {
    auto* value = call->Args()[0];
    auto* exp_ptr = call->Args()[1];
    auto* fract_ty = call->Result()->Type();
    auto* struct_ty = core::type::CreateFrexpResult(ty, ir.symbols, fract_ty);

    b.InsertBefore(call, [&] {
        auto* frexp = b.Call(struct_ty, core::BuiltinFn::kFrexp, value);

        auto* exponent = b.Access(ty.MatchWidth(ty.i32(), fract_ty), frexp, 1_u)->Result();
        auto* stored_ty = exp_ptr->Type()->UnwrapPtr();
        if (stored_ty->DeepestElement()->IsUnsignedIntegerScalar()) {
            exponent = b.Bitcast(stored_ty, exponent)->Result();
        }
        b.Store(exp_ptr, exponent);

        b.AccessWithResult(call->DetachResult(), frexp, 0_u);
    });
    call->Destroy();
}
/// Lowers the SPIR-V BitFieldInsert builtin to the core `insertBits` builtin. Signed
/// scalar offset and count operands are bitcast to `u32` first.
/// @param call the SPIR-V builtin call to replace
void BitFieldInsert(spirv::ir::BuiltinCall* call) {
    const auto& operands = call->Args();
    core::ir::Value* base = operands[0];
    core::ir::Value* insert = operands[1];
    core::ir::Value* bit_offset = operands[2];
    core::ir::Value* bit_count = operands[3];

    b.InsertBefore(call, [&] {
        // Reinterprets a signed scalar as `u32`; other values are returned unchanged.
        auto as_u32 = [&](core::ir::Value* v) -> core::ir::Value* {
            if (!v->Type()->IsSignedIntegerScalar()) {
                return v;
            }
            return b.Bitcast(ty.u32(), v)->Result();
        };
        bit_offset = as_u32(bit_offset);
        bit_count = as_u32(bit_count);

        b.CallWithResult(call->DetachResult(), core::BuiltinFn::kInsertBits, base, insert,
                         bit_offset, bit_count);
    });
    call->Destroy();
}
/// Lowers the SPIR-V BitFieldUExtract builtin to the core `extractBits` builtin
/// evaluated on an unsigned type. A signed base is bitcast to unsigned before the call
/// and the extracted value is bitcast back to the call's result type afterwards.
/// Signed scalar offset and count operands are bitcast to `u32`.
/// @param call the SPIR-V builtin call to replace
void BitFieldUExtract(spirv::ir::BuiltinCall* call) {
    const auto& operands = call->Args();
    core::ir::Value* base = operands[0];
    core::ir::Value* bit_offset = operands[1];
    core::ir::Value* bit_count = operands[2];

    b.InsertBefore(call, [&] {
        auto* unsigned_ty = ty.MatchWidth(ty.u32(), base->Type());

        const bool base_is_signed = base->Type()->DeepestElement()->IsSignedIntegerScalar();
        if (base_is_signed) {
            base = b.Bitcast(unsigned_ty, base)->Result();
        }
        if (bit_offset->Type()->IsSignedIntegerScalar()) {
            bit_offset = b.Bitcast(ty.u32(), bit_offset)->Result();
        }
        if (bit_count->Type()->IsSignedIntegerScalar()) {
            bit_count = b.Bitcast(ty.u32(), bit_count)->Result();
        }

        auto* extracted =
            b.Call(unsigned_ty, core::BuiltinFn::kExtractBits, base, bit_offset, bit_count)
                ->Result();
        if (base_is_signed) {
            extracted = b.Bitcast(call->Result()->Type(), extracted)->Result();
        }
        call->Result()->ReplaceAllUsesWith(extracted);
    });
    call->Destroy();
}
/// Lowers the SPIR-V BitFieldSExtract builtin to the core `extractBits` builtin
/// evaluated on a signed type. An unsigned base is bitcast to signed before the call
/// and the extracted value is bitcast back to the call's result type afterwards.
/// Signed scalar offset and count operands are bitcast to `u32`.
/// @param call the SPIR-V builtin call to replace
void BitFieldSExtract(spirv::ir::BuiltinCall* call) {
    const auto& operands = call->Args();
    core::ir::Value* base = operands[0];
    core::ir::Value* bit_offset = operands[1];
    core::ir::Value* bit_count = operands[2];

    b.InsertBefore(call, [&] {
        auto* signed_ty = ty.MatchWidth(ty.i32(), base->Type());

        const bool base_is_unsigned =
            base->Type()->DeepestElement()->IsUnsignedIntegerScalar();
        if (base_is_unsigned) {
            base = b.Bitcast(signed_ty, base)->Result();
        }
        if (bit_offset->Type()->IsSignedIntegerScalar()) {
            bit_offset = b.Bitcast(ty.u32(), bit_offset)->Result();
        }
        if (bit_count->Type()->IsSignedIntegerScalar()) {
            bit_count = b.Bitcast(ty.u32(), bit_count)->Result();
        }

        auto* extracted =
            b.Call(signed_ty, core::BuiltinFn::kExtractBits, base, bit_offset, bit_count)
                ->Result();
        if (base_is_unsigned) {
            extracted = b.Bitcast(call->Result()->Type(), extracted)->Result();
        }
        call->Result()->ReplaceAllUsesWith(extracted);
    });
    call->Destroy();
}
/// Lowers the SPIR-V BitCount builtin to the core `countOneBits` builtin evaluated on
/// the operand's type, bitcasting to the call's result type when the types differ.
/// @param call the SPIR-V builtin call to replace
void BitCount(spirv::ir::BuiltinCall* call) {
    auto operand = call->Args()[0];

    b.InsertBefore(call, [&] {
        auto* out_ty = call->Result()->Type();
        auto* in_ty = operand->Type();

        auto* ones = b.Call(in_ty, core::BuiltinFn::kCountOneBits, operand)->Result();
        if (out_ty != in_ty) {
            ones = b.Bitcast(out_ty, ones)->Result();
        }
        call->Result()->ReplaceAllUsesWith(ones);
    });
    call->Destroy();
}
/// Lowers the SPIR-V ShiftLeftLogical builtin to a core shift-left on the base
/// operand's type. The shift amount is bitcast to unsigned when needed, and the result
/// is bitcast to the call's result type when the types differ.
/// @param call the SPIR-V builtin call to replace
void ShiftLeftLogical(spirv::ir::BuiltinCall* call) {
    const auto& operands = call->Args();

    b.InsertBefore(call, [&] {
        core::ir::Value* value = operands[0];
        core::ir::Value* amount = operands[1];

        if (!amount->Type()->IsUnsignedIntegerScalarOrVector()) {
            auto* unsigned_amount_ty = ty.MatchWidth(ty.u32(), amount->Type());
            amount = b.Bitcast(unsigned_amount_ty, amount)->Result();
        }

        auto* shifted =
            b.Binary(core::BinaryOp::kShiftLeft, value->Type(), value, amount)->Result();
        auto* out_ty = call->Result()->Type();
        if (value->Type() != out_ty) {
            shifted = b.Bitcast(out_ty, shifted)->Result();
        }
        call->Result()->ReplaceAllUsesWith(shifted);
    });
    call->Destroy();
}
/// Lowers the SPIR-V ShiftRightLogical builtin to a core shift-right evaluated on the
/// unsigned type matching the base operand's width. Both operands are bitcast to that
/// unsigned type when needed, and the result is bitcast to the call's result type when
/// the types differ.
/// @param call the SPIR-V builtin call to replace
void ShiftRightLogical(spirv::ir::BuiltinCall* call) {
    const auto& operands = call->Args();

    b.InsertBefore(call, [&] {
        core::ir::Value* value = operands[0];
        core::ir::Value* amount = operands[1];
        auto* unsigned_ty = ty.MatchWidth(ty.u32(), value->Type());

        // Reinterprets `v` as the unsigned working type unless it is already unsigned.
        auto as_unsigned = [&](core::ir::Value* v) -> core::ir::Value* {
            if (v->Type()->IsUnsignedIntegerScalarOrVector()) {
                return v;
            }
            return b.Bitcast(unsigned_ty, v)->Result();
        };
        value = as_unsigned(value);
        amount = as_unsigned(amount);

        auto* shifted =
            b.Binary(core::BinaryOp::kShiftRight, unsigned_ty, value, amount)->Result();
        auto* out_ty = call->Result()->Type();
        if (unsigned_ty != out_ty) {
            shifted = b.Bitcast(out_ty, shifted)->Result();
        }
        call->Result()->ReplaceAllUsesWith(shifted);
    });
    call->Destroy();
}
/// Lowers the SPIR-V ShiftRightArithmetic builtin to a core shift-right evaluated on
/// the signed type matching the base operand's width. The base is bitcast to signed and
/// the shift amount to unsigned when needed, and the result is bitcast to the call's
/// result type when the types differ.
/// @param call the SPIR-V builtin call to replace
void ShiftRightArithmetic(spirv::ir::BuiltinCall* call) {
    const auto& operands = call->Args();

    b.InsertBefore(call, [&] {
        core::ir::Value* value = operands[0];
        core::ir::Value* amount = operands[1];
        auto* signed_ty = ty.MatchWidth(ty.i32(), value->Type());

        if (!value->Type()->IsSignedIntegerScalarOrVector()) {
            value = b.Bitcast(signed_ty, value)->Result();
        }
        if (!amount->Type()->IsUnsignedIntegerScalarOrVector()) {
            auto* unsigned_amount_ty = ty.MatchWidth(ty.u32(), amount->Type());
            amount = b.Bitcast(unsigned_amount_ty, amount)->Result();
        }

        auto* shifted =
            b.Binary(core::BinaryOp::kShiftRight, signed_ty, value, amount)->Result();
        auto* out_ty = call->Result()->Type();
        if (signed_ty != out_ty) {
            shifted = b.Bitcast(out_ty, shifted)->Result();
        }
        call->Result()->ReplaceAllUsesWith(shifted);
    });
    call->Destroy();
}
/// Lowers the SPIR-V Inverse builtin for square 2x2, 3x3 and 4x4 matrices of f32 or f16.
/// The inverse is emitted explicitly: `1 / determinant(m)` is computed once and applied
/// to the cofactor terms of the matrix (per element for 2x2, as a scalar * matrix
/// multiply for 3x3 and 4x4).
/// @param call the SPIR-V builtin call to replace
void Inverse(spirv::ir::BuiltinCall* call) {
    auto* arg = call->Args()[0];
    auto* mat_ty = arg->Type()->As<core::type::Matrix>();
    TINT_ASSERT(mat_ty);
    TINT_ASSERT(mat_ty->Columns() == mat_ty->Rows());

    auto* elem_ty = mat_ty->Type();

    b.InsertBefore(call, [&] {
        auto* det =
            b.Call(elem_ty, core::BuiltinFn::kDeterminant, Vector<core::ir::Value*, 1>{arg});

        // The literal `1` must have the same type as the matrix elements.
        core::ir::Value* one = nullptr;
        if (elem_ty->Is<core::type::F32>()) {
            one = b.Constant(1.0_f);
        } else if (elem_ty->Is<core::type::F16>()) {
            one = b.Constant(1.0_h);
        } else {
            TINT_UNREACHABLE();
        }
        auto* inv_det = b.Divide(elem_ty, one, det);

        // Returns (m * n) - (o * p)
        auto sub_mul2 = [&](auto* m, auto* n, auto* o, auto* p) {
            auto* x = b.Multiply(elem_ty, m, n);
            auto* y = b.Multiply(elem_ty, o, p);
            return b.Subtract(elem_ty, x, y);
        };

        // Returns (m * n) - (o * p) + (q * r)
        auto sub_add_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) {
            auto* w = b.Multiply(elem_ty, m, n);
            auto* x = b.Multiply(elem_ty, o, p);
            auto* y = b.Multiply(elem_ty, q, r);

            auto* z = b.Subtract(elem_ty, w, x);
            return b.Add(elem_ty, z, y);
        };

        // Returns (m * n) + (o * p) - (q * r)
        auto add_sub_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) {
            auto* w = b.Multiply(elem_ty, m, n);
            auto* x = b.Multiply(elem_ty, o, p);
            auto* y = b.Multiply(elem_ty, q, r);

            auto* z = b.Add(elem_ty, w, x);
            return b.Subtract(elem_ty, z, y);
        };

        switch (mat_ty->Columns()) {
            case 2: {
                // Each element is scaled by +/- inv_det directly, and the original
                // result is reused for the constructed matrix.
                auto* neg_inv_det = b.Negation(elem_ty, inv_det);

                auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
                auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
                auto* mc = b.Access(elem_ty, arg, 1_u, 0_u);
                auto* md = b.Access(elem_ty, arg, 1_u, 1_u);

                auto* r_00 = b.Multiply(elem_ty, inv_det, md);
                auto* r_01 = b.Multiply(elem_ty, neg_inv_det, mb);
                auto* r_10 = b.Multiply(elem_ty, neg_inv_det, mc);
                auto* r_11 = b.Multiply(elem_ty, inv_det, ma);

                auto* r1 = b.Construct(ty.vec2(elem_ty), r_00, r_01);
                auto* r2 = b.Construct(ty.vec2(elem_ty), r_10, r_11);
                b.ConstructWithResult(call->DetachResult(), r1, r2);
                break;
            }
            case 3: {
                auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
                auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
                auto* mc = b.Access(elem_ty, arg, 0_u, 2_u);
                auto* md = b.Access(elem_ty, arg, 1_u, 0_u);
                auto* me = b.Access(elem_ty, arg, 1_u, 1_u);
                auto* mf = b.Access(elem_ty, arg, 1_u, 2_u);
                auto* mg = b.Access(elem_ty, arg, 2_u, 0_u);
                auto* mh = b.Access(elem_ty, arg, 2_u, 1_u);
                auto* mi = b.Access(elem_ty, arg, 2_u, 2_u);

                // e * i - f * h
                auto* r_00 = sub_mul2(me, mi, mf, mh);
                // c * h - b * i
                auto* r_01 = sub_mul2(mc, mh, mb, mi);
                // b * f - c * e
                auto* r_02 = sub_mul2(mb, mf, mc, me);

                // f * g - d * i
                auto* r_10 = sub_mul2(mf, mg, md, mi);
                // a * i - c * g
                auto* r_11 = sub_mul2(ma, mi, mc, mg);
                // c * d - a * f
                auto* r_12 = sub_mul2(mc, md, ma, mf);

                // d * h - e * g
                auto* r_20 = sub_mul2(md, mh, me, mg);
                // b * g - a * h
                auto* r_21 = sub_mul2(mb, mg, ma, mh);
                // a * e - b * d
                auto* r_22 = sub_mul2(ma, me, mb, md);

                auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02);
                auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12);
                auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22);

                auto* m = b.Construct(mat_ty, r1, r2, r3);
                auto* inv = b.Multiply(mat_ty, inv_det, m);
                call->Result()->ReplaceAllUsesWith(inv->Result());
                break;
            }
            case 4: {
                auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
                auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
                auto* mc = b.Access(elem_ty, arg, 0_u, 2_u);
                auto* md = b.Access(elem_ty, arg, 0_u, 3_u);
                auto* me = b.Access(elem_ty, arg, 1_u, 0_u);
                auto* mf = b.Access(elem_ty, arg, 1_u, 1_u);
                auto* mg = b.Access(elem_ty, arg, 1_u, 2_u);
                auto* mh = b.Access(elem_ty, arg, 1_u, 3_u);
                auto* mi = b.Access(elem_ty, arg, 2_u, 0_u);
                auto* mj = b.Access(elem_ty, arg, 2_u, 1_u);
                auto* mk = b.Access(elem_ty, arg, 2_u, 2_u);
                auto* ml = b.Access(elem_ty, arg, 2_u, 3_u);
                auto* mm = b.Access(elem_ty, arg, 3_u, 0_u);
                auto* mn = b.Access(elem_ty, arg, 3_u, 1_u);
                auto* mo = b.Access(elem_ty, arg, 3_u, 2_u);
                auto* mp = b.Access(elem_ty, arg, 3_u, 3_u);

                // kplo = k * p - l * o
                auto* kplo = sub_mul2(mk, mp, ml, mo);
                // jpln = j * p - l * n
                auto* jpln = sub_mul2(mj, mp, ml, mn);
                // jokn = j * o - k * n;
                auto* jokn = sub_mul2(mj, mo, mk, mn);
                // gpho = g * p - h * o
                auto* gpho = sub_mul2(mg, mp, mh, mo);
                // fphn = f * p - h * n
                auto* fphn = sub_mul2(mf, mp, mh, mn);
                // fogn = f * o - g * n;
                auto* fogn = sub_mul2(mf, mo, mg, mn);
                // glhk = g * l - h * k
                auto* glhk = sub_mul2(mg, ml, mh, mk);
                // flhj = f * l - h * j
                auto* flhj = sub_mul2(mf, ml, mh, mj);
                // fkgj = f * k - g * j;
                auto* fkgj = sub_mul2(mf, mk, mg, mj);
                // iplm = i * p - l * m
                auto* iplm = sub_mul2(mi, mp, ml, mm);
                // iokm = i * o - k * m
                auto* iokm = sub_mul2(mi, mo, mk, mm);
                // ephm = e * p - h * m;
                auto* ephm = sub_mul2(me, mp, mh, mm);
                // eogm = e * o - g * m
                auto* eogm = sub_mul2(me, mo, mg, mm);
                // elhi = e * l - h * i
                auto* elhi = sub_mul2(me, ml, mh, mi);
                // ekgi = e * k - g * i;
                auto* ekgi = sub_mul2(me, mk, mg, mi);
                // injm = i * n - j * m
                auto* injm = sub_mul2(mi, mn, mj, mm);
                // enfm = e * n - f * m
                auto* enfm = sub_mul2(me, mn, mf, mm);
                // ejfi = e * j - f * i;
                auto* ejfi = sub_mul2(me, mj, mf, mi);

                auto* neg_b = b.Negation(elem_ty, mb);
                // f * kplo - g * jpln + h * jokn
                auto* r_00 = sub_add_mul3(mf, kplo, mg, jpln, mh, jokn);
                // -b * kplo + c * jpln - d * jokn
                auto* r_01 = add_sub_mul3(neg_b, kplo, mc, jpln, md, jokn);
                // b * gpho - c * fphn + d * fogn
                auto* r_02 = sub_add_mul3(mb, gpho, mc, fphn, md, fogn);
                // -b * glhk + c * flhj - d * fkgj
                auto* r_03 = add_sub_mul3(neg_b, glhk, mc, flhj, md, fkgj);

                auto* neg_e = b.Negation(elem_ty, me);
                auto* neg_a = b.Negation(elem_ty, ma);
                // -e * kplo + g * iplm - h * iokm
                auto* r_10 = add_sub_mul3(neg_e, kplo, mg, iplm, mh, iokm);
                // a * kplo - c * iplm + d * iokm
                auto* r_11 = sub_add_mul3(ma, kplo, mc, iplm, md, iokm);
                // -a * gpho + c * ephm - d * eogm
                auto* r_12 = add_sub_mul3(neg_a, gpho, mc, ephm, md, eogm);
                // a * glhk - c * elhi + d * ekgi
                auto* r_13 = sub_add_mul3(ma, glhk, mc, elhi, md, ekgi);

                // e * jpln - f * iplm + h * injm
                auto* r_20 = sub_add_mul3(me, jpln, mf, iplm, mh, injm);
                // -a * jpln + b * iplm - d * injm
                auto* r_21 = add_sub_mul3(neg_a, jpln, mb, iplm, md, injm);
                // a * fphn - b * ephm + d * enfm
                auto* r_22 = sub_add_mul3(ma, fphn, mb, ephm, md, enfm);
                // -a * flhj + b * elhi - d * ejfi
                auto* r_23 = add_sub_mul3(neg_a, flhj, mb, elhi, md, ejfi);

                // -e * jokn + f * iokm - g * injm
                auto* r_30 = add_sub_mul3(neg_e, jokn, mf, iokm, mg, injm);
                // a * jokn - b * iokm + c * injm
                auto* r_31 = sub_add_mul3(ma, jokn, mb, iokm, mc, injm);
                // -a * fogn + b * eogm - c * enfm
                auto* r_32 = add_sub_mul3(neg_a, fogn, mb, eogm, mc, enfm);
                // a * fkgj - b * ekgi + c * ejfi
                auto* r_33 = sub_add_mul3(ma, fkgj, mb, ekgi, mc, ejfi);

                // Each column of a 4x4 matrix has four elements, so the columns must be
                // four-element vectors.
                auto* r1 = b.Construct(ty.vec4(elem_ty), r_00, r_01, r_02, r_03);
                auto* r2 = b.Construct(ty.vec4(elem_ty), r_10, r_11, r_12, r_13);
                auto* r3 = b.Construct(ty.vec4(elem_ty), r_20, r_21, r_22, r_23);
                auto* r4 = b.Construct(ty.vec4(elem_ty), r_30, r_31, r_32, r_33);

                auto* m = b.Construct(mat_ty, r1, r2, r3, r4);
                auto* inv = b.Multiply(mat_ty, inv_det, m);
                call->Result()->ReplaceAllUsesWith(inv->Result());
                break;
            }
            default: {
                TINT_UNREACHABLE();
            }
        }
    });
    call->Destroy();
}
};
} // namespace
/// Validates the module, then lowers its SPIR-V builtin calls using `State`.
/// @param ir the module to transform
/// @returns success, or the validation failure if the input module is invalid
Result<SuccessType> Builtins(core::ir::Module& ir) {
    // Validation is run with overrides and non-core types permitted.
    if (auto validated = ValidateAndDumpIfNeeded(ir, "spirv.Builtins",
                                                 core::ir::Capabilities{
                                                     core::ir::Capability::kAllowOverrides,
                                                     core::ir::Capability::kAllowNonCoreTypes,
                                                 });
        validated != Success) {
        return validated.Failure();
    }

    State state{ir};
    state.Process();

    return Success;
}
} // namespace tint::spirv::reader::lower