// Copyright 2024 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
//    contributors may be used to endorse or promote products derived from
//    this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "src/tint/lang/spirv/reader/lower/builtins.h"

#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/builtin_structs.h"
#include "src/tint/lang/spirv/ir/builtin_call.h"

namespace tint::spirv::reader::lower {
namespace {

using namespace tint::core::fluent_types;     // NOLINT
using namespace tint::core::number_suffixes;  // NOLINT

/// PIMPL state for the transform.
struct State {
    /// The IR module being rewritten. Every replacement instruction is inserted into it.
    core::ir::Module& ir;

    /// The IR builder, constructed from `ir`. Used to create all replacement instructions.
    /// NOTE: it is initialized from `ir`, so it must stay declared after `ir`.
    core::ir::Builder b{ir};

    /// The type manager of `ir`, used to look up (and create) the types of new values.
    core::type::Manager& ty{ir.Types()};

    /// Process the module.
    void Process() {
        Vector<spirv::ir::BuiltinCall*, 4> builtin_worklist;
        for (auto* inst : ir.Instructions()) {
            if (auto* builtin = inst->As<spirv::ir::BuiltinCall>()) {
                builtin_worklist.Push(builtin);
            }
        }

        // Replace the builtins that we found.
        for (auto* builtin : builtin_worklist) {
            switch (builtin->Func()) {
                case spirv::BuiltinFn::kNormalize:
                    Normalize(builtin);
                    break;
                case spirv::BuiltinFn::kInverse:
                    Inverse(builtin);
                    break;
                case spirv::BuiltinFn::kSign:
                    Sign(builtin);
                    break;
                case spirv::BuiltinFn::kAbs:
                    Abs(builtin);
                    break;
                case spirv::BuiltinFn::kSMax:
                    SMax(builtin);
                    break;
                case spirv::BuiltinFn::kSMin:
                    SMin(builtin);
                    break;
                case spirv::BuiltinFn::kSClamp:
                    SClamp(builtin);
                    break;
                case spirv::BuiltinFn::kUMax:
                    UMax(builtin);
                    break;
                case spirv::BuiltinFn::kUMin:
                    UMin(builtin);
                    break;
                case spirv::BuiltinFn::kUClamp:
                    UClamp(builtin);
                    break;
                case spirv::BuiltinFn::kFindILsb:
                    FindILsb(builtin);
                    break;
                case spirv::BuiltinFn::kFindSMsb:
                    FindSMsb(builtin);
                    break;
                case spirv::BuiltinFn::kFindUMsb:
                    FindUMsb(builtin);
                    break;
                case spirv::BuiltinFn::kRefract:
                    Refract(builtin);
                    break;
                case spirv::BuiltinFn::kReflect:
                    Reflect(builtin);
                    break;
                case spirv::BuiltinFn::kFaceForward:
                    FaceForward(builtin);
                    break;
                case spirv::BuiltinFn::kLdexp:
                    Ldexp(builtin);
                    break;
                case spirv::BuiltinFn::kModf:
                    Modf(builtin);
                    break;
                case spirv::BuiltinFn::kFrexp:
                    Frexp(builtin);
                    break;
                case spirv::BuiltinFn::kBitCount:
                    BitCount(builtin);
                    break;
                case spirv::BuiltinFn::kBitFieldInsert:
                    BitFieldInsert(builtin);
                    break;
                case spirv::BuiltinFn::kBitFieldSExtract:
                    BitFieldSExtract(builtin);
                    break;
                case spirv::BuiltinFn::kBitFieldUExtract:
                    BitFieldUExtract(builtin);
                    break;
                case spirv::BuiltinFn::kAdd:
                    Add(builtin);
                    break;
                case spirv::BuiltinFn::kSub:
                    Sub(builtin);
                    break;
                case spirv::BuiltinFn::kMul:
                    Mul(builtin);
                    break;
                case spirv::BuiltinFn::kSDiv:
                    SDiv(builtin);
                    break;
                case spirv::BuiltinFn::kSMod:
                    SMod(builtin);
                    break;
                case spirv::BuiltinFn::kConvertFToS:
                    ConvertFToS(builtin);
                    break;
                case spirv::BuiltinFn::kConvertSToF:
                    ConvertSToF(builtin);
                    break;
                case spirv::BuiltinFn::kConvertUToF:
                    ConvertUToF(builtin);
                    break;
                case spirv::BuiltinFn::kBitwiseAnd:
                    BitwiseAnd(builtin);
                    break;
                case spirv::BuiltinFn::kBitwiseOr:
                    BitwiseOr(builtin);
                    break;
                case spirv::BuiltinFn::kBitwiseXor:
                    BitwiseXor(builtin);
                    break;
                case spirv::BuiltinFn::kEqual:
                    Equal(builtin);
                    break;
                case spirv::BuiltinFn::kNotEqual:
                    NotEqual(builtin);
                    break;
                case spirv::BuiltinFn::kSGreaterThan:
                    SGreaterThan(builtin);
                    break;
                case spirv::BuiltinFn::kSGreaterThanEqual:
                    SGreaterThanEqual(builtin);
                    break;
                case spirv::BuiltinFn::kSLessThan:
                    SLessThan(builtin);
                    break;
                case spirv::BuiltinFn::kSLessThanEqual:
                    SLessThanEqual(builtin);
                    break;
                case spirv::BuiltinFn::kUGreaterThan:
                    UGreaterThan(builtin);
                    break;
                case spirv::BuiltinFn::kUGreaterThanEqual:
                    UGreaterThanEqual(builtin);
                    break;
                case spirv::BuiltinFn::kULessThan:
                    ULessThan(builtin);
                    break;
                case spirv::BuiltinFn::kULessThanEqual:
                    ULessThanEqual(builtin);
                    break;
                case spirv::BuiltinFn::kShiftLeftLogical:
                    ShiftLeftLogical(builtin);
                    break;
                case spirv::BuiltinFn::kShiftRightLogical:
                    ShiftRightLogical(builtin);
                    break;
                case spirv::BuiltinFn::kShiftRightArithmetic:
                    ShiftRightArithmetic(builtin);
                    break;
                case spirv::BuiltinFn::kNot:
                    Not(builtin);
                    break;
                case spirv::BuiltinFn::kSNegate:
                    SNegate(builtin);
                    break;
                case spirv::BuiltinFn::kFMod:
                    FMod(builtin);
                    break;
                case spirv::BuiltinFn::kSelect:
                    Select(builtin);
                    break;
                case spirv::BuiltinFn::kOuterProduct:
                    OuterProduct(builtin);
                    break;
                case spirv::BuiltinFn::kAtomicLoad:
                case spirv::BuiltinFn::kAtomicStore:
                case spirv::BuiltinFn::kAtomicExchange:
                case spirv::BuiltinFn::kAtomicCompareExchange:
                case spirv::BuiltinFn::kAtomicIAdd:
                case spirv::BuiltinFn::kAtomicISub:
                case spirv::BuiltinFn::kAtomicSMax:
                case spirv::BuiltinFn::kAtomicSMin:
                case spirv::BuiltinFn::kAtomicUMax:
                case spirv::BuiltinFn::kAtomicUMin:
                case spirv::BuiltinFn::kAtomicAnd:
                case spirv::BuiltinFn::kAtomicOr:
                case spirv::BuiltinFn::kAtomicXor:
                case spirv::BuiltinFn::kAtomicIIncrement:
                case spirv::BuiltinFn::kAtomicIDecrement:
                    // Ignore Atomics, they'll be handled by the `Atomics` transform.
                    break;
                case spirv::BuiltinFn::kSampledImage:
                case spirv::BuiltinFn::kImageRead:
                case spirv::BuiltinFn::kImageFetch:
                case spirv::BuiltinFn::kImageGather:
                case spirv::BuiltinFn::kImageQueryLevels:
                case spirv::BuiltinFn::kImageQuerySamples:
                case spirv::BuiltinFn::kImageQuerySize:
                case spirv::BuiltinFn::kImageQuerySizeLod:
                case spirv::BuiltinFn::kImageSampleExplicitLod:
                case spirv::BuiltinFn::kImageSampleImplicitLod:
                case spirv::BuiltinFn::kImageSampleProjImplicitLod:
                case spirv::BuiltinFn::kImageSampleProjExplicitLod:
                case spirv::BuiltinFn::kImageWrite:
                    // Ignore image methods, they'll be handled by the `Texture` transform.
                    break;
                default:
                    TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
            }
        }
    }
    /// Lowers SPIR-V OuterProduct(v1, v2) by building the result matrix one column at a time.
    /// Column `c` holds `v1[r] * v2[c]` for every row `r`, each emitted as a scalar multiply.
    void OuterProduct(spirv::ir::BuiltinCall* call) {
        auto* lhs_vec = call->Args()[0];
        auto* rhs_vec = call->Args()[1];

        const uint32_t num_rows = lhs_vec->Type()->As<core::type::Vector>()->Width();
        const uint32_t num_cols = rhs_vec->Type()->As<core::type::Vector>()->Width();
        auto* scalar_ty = lhs_vec->Type()->DeepestElement();

        b.InsertBefore(call, [&] {
            auto* column_ty = ty.vec(scalar_ty, num_rows);
            Vector<core::ir::Value*, 4> columns;

            for (uint32_t c = 0; c < num_cols; ++c) {
                auto* rhs_elem = b.Access(scalar_ty, rhs_vec, u32(c));

                Vector<core::ir::Value*, 4> products;
                for (uint32_t r = 0; r < num_rows; ++r) {
                    auto* lhs_elem = b.Access(scalar_ty, lhs_vec, u32(r));
                    products.Push(b.Multiply(scalar_ty, lhs_elem, rhs_elem)->Result());
                }
                columns.Push(b.Construct(column_ty, products)->Result());
            }

            b.ConstructWithResult(call->DetachResult(), columns);
        });

        call->Destroy();
    }

    /// Lowers SPIR-V Select(cond, t, f) to the core `select` builtin. The core builtin takes its
    /// operands in the order (false_value, true_value, condition).
    void Select(spirv::ir::BuiltinCall* call) {
        const auto& args = call->Args();
        auto* condition = args[0];
        auto* if_true = args[1];
        auto* if_false = args[2];

        b.InsertBefore(call, [&] {
            b.CallWithResult(call->DetachResult(), core::BuiltinFn::kSelect, if_false, if_true,
                             condition);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V FMod(x, y) using the floored-modulo formula `x - y * floor(x / y)`.
    /// All intermediate values use the call's result type.
    void FMod(spirv::ir::BuiltinCall* call) {
        auto* result_ty = call->Result()->Type();
        auto* x = call->Args()[0];
        auto* y = call->Args()[1];

        b.InsertBefore(call, [&] {
            auto* quotient = b.Divide(result_ty, x, y)->Result();
            auto* floored = b.Call(result_ty, core::BuiltinFn::kFloor, quotient)->Result();
            auto* scaled = b.Multiply(result_ty, y, floored)->Result();
            auto* difference = b.Subtract(result_ty, x, scaled)->Result();
            call->Result()->ReplaceAllUsesWith(difference);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V SNegate, which treats its operand as signed. The steps are:
    /// - bitcast the operand to i32 (matching its width) when needed;
    /// - negate it;
    /// - bitcast the negation to the call's result type when that differs.
    void SNegate(spirv::ir::BuiltinCall* call) {
        auto* operand = call->Args()[0];
        auto* result_ty = call->Result()->Type();
        auto* signed_ty = ty.MatchWidth(ty.i32(), operand->Type());

        b.InsertBefore(call, [&] {
            core::ir::Value* value = operand;
            if (value->Type() != signed_ty) {
                value = b.Bitcast(signed_ty, value)->Result();
            }

            value = b.Negation(signed_ty, value)->Result();

            if (result_ty != signed_ty) {
                value = b.Bitcast(result_ty, value)->Result();
            }
            call->Result()->ReplaceAllUsesWith(value);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V Not to a core bitwise complement. The complement is computed in the
    /// operand's type and bitcast to the call's result type when the two differ.
    void Not(spirv::ir::BuiltinCall* call) {
        auto* operand = call->Args()[0];
        auto* operand_ty = operand->Type();
        auto* result_ty = call->Result()->Type();

        b.InsertBefore(call, [&] {
            core::ir::Value* inverted = b.Complement(operand_ty, operand)->Result();
            if (operand_ty != result_ty) {
                inverted = b.Bitcast(result_ty, inverted)->Result();
            }
            call->Result()->ReplaceAllUsesWith(inverted);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V ConvertSToF, which treats its operand as signed. An unsigned operand is
    /// first bitcast to i32 (matching the result width), then converted with core `convert`.
    void ConvertSToF(spirv::ir::BuiltinCall* call) {
        auto* source = call->Args()[0];

        b.InsertBefore(call, [&] {
            if (source->Type()->IsUnsignedIntegerScalarOrVector()) {
                auto* signed_ty = ty.MatchWidth(ty.i32(), call->Result()->Type());
                source = b.Bitcast(signed_ty, source)->Result();
            }
            b.ConvertWithResult(call->DetachResult(), source);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V ConvertUToF, which treats its operand as unsigned. A signed operand is first
    /// bitcast to u32 (matching the result width), then converted with core `convert`.
    void ConvertUToF(spirv::ir::BuiltinCall* call) {
        auto* source = call->Args()[0];

        b.InsertBefore(call, [&] {
            if (source->Type()->IsSignedIntegerScalarOrVector()) {
                auto* unsigned_ty = ty.MatchWidth(ty.u32(), call->Result()->Type());
                source = b.Bitcast(unsigned_ty, source)->Result();
            }
            b.ConvertWithResult(call->DetachResult(), source);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V ConvertFToS. The conversion always targets i32 (matching the result width).
    /// When the SPIR-V result type is unsigned, the signed value is then bitcast to it.
    void ConvertFToS(spirv::ir::BuiltinCall* call) {
        auto* source = call->Args()[0];
        auto* result_ty = call->Result()->Type();

        b.InsertBefore(call, [&] {
            auto* signed_ty = ty.MatchWidth(ty.i32(), result_ty);
            core::ir::Value* converted = b.Convert(signed_ty, source)->Result();
            if (result_ty->DeepestElement()->IsUnsignedIntegerScalar()) {
                converted = b.Bitcast(result_ty, converted)->Result();
            }
            call->Result()->ReplaceAllUsesWith(converted);
        });
        call->Destroy();
    }

    /// Emits a core binary `op` evaluated in the type of the first operand:
    /// - the second operand is bitcast to that type when it differs;
    /// - the result is bitcast to the call's result type when that differs.
    void EmitBinaryWrappedAsFirstArg(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
        auto* lhs = call->Args()[0];
        auto* rhs = call->Args()[1];
        auto* operand_ty = lhs->Type();
        auto* result_ty = call->Result()->Type();

        b.InsertBefore(call, [&] {
            core::ir::Value* rhs_value = rhs;
            if (rhs_value->Type() != operand_ty) {
                rhs_value = b.Bitcast(operand_ty, rhs_value)->Result();
            }

            core::ir::Value* value = b.Binary(op, operand_ty, lhs, rhs_value)->Result();
            if (result_ty != operand_ty) {
                value = b.Bitcast(result_ty, value)->Result();
            }
            call->Result()->ReplaceAllUsesWith(value);
        });
        call->Destroy();
    }

    /// SPIR-V BitwiseAnd -> core `&`, evaluated in the first operand's type.
    void BitwiseAnd(spirv::ir::BuiltinCall* call) {
        EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kAnd);
    }
    /// SPIR-V BitwiseOr -> core `|`, evaluated in the first operand's type.
    void BitwiseOr(spirv::ir::BuiltinCall* call) {
        EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kOr);
    }
    /// SPIR-V BitwiseXor -> core `^`, evaluated in the first operand's type.
    void BitwiseXor(spirv::ir::BuiltinCall* call) {
        EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kXor);
    }

    /// SPIR-V IAdd -> core `+`, evaluated in the first operand's type.
    void Add(spirv::ir::BuiltinCall* call) {
        EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kAdd);
    }
    /// SPIR-V ISub -> core `-`, evaluated in the first operand's type.
    void Sub(spirv::ir::BuiltinCall* call) {
        EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kSubtract);
    }
    /// SPIR-V IMul -> core `*`, evaluated in the first operand's type.
    void Mul(spirv::ir::BuiltinCall* call) {
        EmitBinaryWrappedAsFirstArg(call, core::BinaryOp::kMultiply);
    }

    /// Emits a core binary `op` evaluated on i32 (matching the result width), for SPIR-V
    /// instructions that treat their operands as signed.
    /// - Each operand is bitcast to that signed type when needed.
    /// - The result is bitcast back to the call's result type when that differs.
    void EmitBinaryWrappedSignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
        auto* result_ty = call->Result()->Type();
        auto* signed_ty = ty.MatchWidth(ty.i32(), result_ty);
        auto* lhs = call->Args()[0];
        auto* rhs = call->Args()[1];

        b.InsertBefore(call, [&] {
            // Reinterpret `v` as `signed_ty`, emitting a bitcast only if the type differs.
            auto as_signed = [&](core::ir::Value* v) -> core::ir::Value* {
                if (v->Type() == signed_ty) {
                    return v;
                }
                return b.Bitcast(signed_ty, v)->Result();
            };
            core::ir::Value* lhs_signed = as_signed(lhs);
            core::ir::Value* rhs_signed = as_signed(rhs);

            core::ir::Value* value = b.Binary(op, signed_ty, lhs_signed, rhs_signed)->Result();
            if (result_ty != signed_ty) {
                value = b.Bitcast(result_ty, value)->Result();
            }
            call->Result()->ReplaceAllUsesWith(value);
        });
        call->Destroy();
    }

    /// SPIR-V SDiv -> core `/` on signed operands.
    void SDiv(spirv::ir::BuiltinCall* call) {
        EmitBinaryWrappedSignedSpirvMethods(call, core::BinaryOp::kDivide);
    }
    /// Lowers SPIR-V SMod. For a non-zero result, SPIR-V SMod takes the sign of the divisor
    /// (operand 2). The core `%` operator instead takes the sign of the dividend, which matches
    /// SPIR-V SRem. Lowering SMod directly to `%` is therefore wrong when the operand signs
    /// differ; for example, SMod(-7, 3) must be 2, but -7 % 3 is -1.
    ///
    /// The operation is emulated in i32 (matching the result width) as:
    ///   rem    = x % y
    ///   result = (rem != 0 && (rem < 0) != (y < 0)) ? rem + y : rem
    /// `select` chooses between the two candidates. The `rem + y` value is only observable when
    /// rem and y have opposite signs, where the sum cannot overflow.
    /// Operands and the result are bitcast exactly as for the other signed binary ops.
    void SMod(spirv::ir::BuiltinCall* call) {
        auto* result_ty = call->Result()->Type();
        auto* signed_ty = ty.MatchWidth(ty.i32(), result_ty);
        auto* bool_ty = ty.MatchWidth(ty.bool_(), result_ty);
        auto* x = call->Args()[0];
        auto* y = call->Args()[1];

        b.InsertBefore(call, [&] {
            if (x->Type() != signed_ty) {
                x = b.Bitcast(signed_ty, x)->Result();
            }
            if (y->Type() != signed_ty) {
                y = b.Bitcast(signed_ty, y)->Result();
            }

            auto* zero = b.Zero(signed_ty);

            // Truncating remainder: its sign follows the dividend `x`.
            auto* rem = b.Modulo(signed_ty, x, y)->Result();

            // A fixup is needed only when the remainder is non-zero and its sign differs from
            // the divisor's sign.
            auto* rem_nonzero = b.NotEqual(bool_ty, rem, zero)->Result();
            auto* rem_negative = b.LessThan(bool_ty, rem, zero)->Result();
            auto* divisor_negative = b.LessThan(bool_ty, y, zero)->Result();
            auto* signs_differ = b.NotEqual(bool_ty, rem_negative, divisor_negative)->Result();
            auto* needs_fixup = b.And(bool_ty, rem_nonzero, signs_differ)->Result();

            auto* adjusted = b.Add(signed_ty, rem, y)->Result();
            core::ir::Value* res =
                b.Call(signed_ty, core::BuiltinFn::kSelect, rem, adjusted, needs_fixup)->Result();

            if (result_ty != signed_ty) {
                res = b.Bitcast(result_ty, res)->Result();
            }
            call->Result()->ReplaceAllUsesWith(res);
        });
        call->Destroy();
    }

    /// Emits a core binary `op` that directly produces the call's result value. The second
    /// operand is bitcast to the first operand's type when the two differ.
    void EmitBinaryMatchedArgs(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
        auto* lhs = call->Args()[0];
        auto* rhs = call->Args()[1];

        b.InsertBefore(call, [&] {
            auto* lhs_ty = lhs->Type();
            if (rhs->Type() != lhs_ty) {
                rhs = b.Bitcast(lhs_ty, rhs)->Result();
            }
            b.BinaryWithResult(call->DetachResult(), op, lhs, rhs);
        });
        call->Destroy();
    }
    /// SPIR-V IEqual -> core `==`.
    void Equal(spirv::ir::BuiltinCall* call) {
        EmitBinaryMatchedArgs(call, core::BinaryOp::kEqual);
    }
    /// SPIR-V INotEqual -> core `!=`.
    void NotEqual(spirv::ir::BuiltinCall* call) {
        EmitBinaryMatchedArgs(call, core::BinaryOp::kNotEqual);
    }

    /// Emits a core binary `op`, used here for the SPIR-V signed comparisons.
    /// - Both operands are reinterpreted as i32, with the width matching the call's result type.
    /// - The instruction directly produces the call's result value.
    void EmitBinaryWithSignedArgs(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
        auto* signed_ty = ty.MatchWidth(ty.i32(), call->Result()->Type());
        auto* lhs = call->Args()[0];
        auto* rhs = call->Args()[1];

        b.InsertBefore(call, [&] {
            // Reinterpret `v` as `signed_ty`, emitting a bitcast only if the type differs.
            auto to_signed = [&](core::ir::Value* v) -> core::ir::Value* {
                if (v->Type() == signed_ty) {
                    return v;
                }
                return b.Bitcast(signed_ty, v)->Result();
            };
            core::ir::Value* lhs_signed = to_signed(lhs);
            core::ir::Value* rhs_signed = to_signed(rhs);
            b.BinaryWithResult(call->DetachResult(), op, lhs_signed, rhs_signed);
        });
        call->Destroy();
    }
    /// SPIR-V SGreaterThan -> core `>` on signed operands.
    void SGreaterThan(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithSignedArgs(call, core::BinaryOp::kGreaterThan);
    }
    /// SPIR-V SGreaterThanEqual -> core `>=` on signed operands.
    void SGreaterThanEqual(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithSignedArgs(call, core::BinaryOp::kGreaterThanEqual);
    }
    /// SPIR-V SLessThan -> core `<` on signed operands.
    void SLessThan(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithSignedArgs(call, core::BinaryOp::kLessThan);
    }
    /// SPIR-V SLessThanEqual -> core `<=` on signed operands.
    void SLessThanEqual(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithSignedArgs(call, core::BinaryOp::kLessThanEqual);
    }

    /// Emits a core binary `op`, used here for the SPIR-V unsigned comparisons.
    /// - Both operands are reinterpreted as u32, with the width matching the call's result type.
    /// - The instruction directly produces the call's result value.
    void EmitBinaryWithUnsignedArgs(spirv::ir::BuiltinCall* call, core::BinaryOp op) {
        auto* unsigned_ty = ty.MatchWidth(ty.u32(), call->Result()->Type());
        auto* lhs = call->Args()[0];
        auto* rhs = call->Args()[1];

        b.InsertBefore(call, [&] {
            // Reinterpret `v` as `unsigned_ty`, emitting a bitcast only if the type differs.
            auto to_unsigned = [&](core::ir::Value* v) -> core::ir::Value* {
                if (v->Type() == unsigned_ty) {
                    return v;
                }
                return b.Bitcast(unsigned_ty, v)->Result();
            };
            core::ir::Value* lhs_unsigned = to_unsigned(lhs);
            core::ir::Value* rhs_unsigned = to_unsigned(rhs);
            b.BinaryWithResult(call->DetachResult(), op, lhs_unsigned, rhs_unsigned);
        });
        call->Destroy();
    }
    /// SPIR-V UGreaterThan -> core `>` on unsigned operands.
    void UGreaterThan(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kGreaterThan);
    }
    /// SPIR-V UGreaterThanEqual -> core `>=` on unsigned operands.
    void UGreaterThanEqual(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kGreaterThanEqual);
    }
    /// SPIR-V ULessThan -> core `<` on unsigned operands.
    void ULessThan(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kLessThan);
    }
    /// SPIR-V ULessThanEqual -> core `<=` on unsigned operands.
    void ULessThanEqual(spirv::ir::BuiltinCall* call) {
        EmitBinaryWithUnsignedArgs(call, core::BinaryOp::kLessThanEqual);
    }

    // The SPIR-V Signed methods treat their integer arguments as signed, whatever their declared
    // type. To honour that, every unsigned argument is bitcast to the signed type (matching the
    // result width) before the WGSL equivalent is called.
    //
    // The WGSL method then returns a signed value. If the SPIR-V instruction expects an unsigned
    // result, the call is retyped to signed and its value is bitcast to the correct unsigned type.
    void WrapSignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BuiltinFn func) {
        auto args = call->Args();

        b.InsertBefore(call, [&] {
            auto* result_ty = call->Result()->Type();

            Vector<core::ir::Value*, 2> call_args;
            for (auto* arg : args) {
                core::ir::Value* value = arg;
                if (value->Type()->IsUnsignedIntegerScalarOrVector()) {
                    value = b.Bitcast(ty.MatchWidth(ty.i32(), result_ty), value)->Result();
                }
                call_args.Push(value);
            }

            auto* core_call = b.Call(result_ty, func, call_args);
            core::ir::Value* replacement = core_call->Result();

            const bool wants_unsigned = result_ty->DeepestElement() == ty.u32();
            if (wants_unsigned) {
                core_call->Result()->SetType(ty.MatchWidth(ty.i32(), result_ty));
                replacement = b.Bitcast(result_ty, replacement)->Result();
            }
            call->Result()->ReplaceAllUsesWith(replacement);
        });
        call->Destroy();
    }

    /// SPIR-V SSign -> core `sign`.
    void Sign(spirv::ir::BuiltinCall* call) {
        WrapSignedSpirvMethods(call, core::BuiltinFn::kSign);
    }
    /// SPIR-V SAbs -> core `abs`.
    void Abs(spirv::ir::BuiltinCall* call) {
        WrapSignedSpirvMethods(call, core::BuiltinFn::kAbs);
    }
    /// SPIR-V FindSMsb -> core `firstLeadingBit` on signed values.
    void FindSMsb(spirv::ir::BuiltinCall* call) {
        WrapSignedSpirvMethods(call, core::BuiltinFn::kFirstLeadingBit);
    }
    /// SPIR-V SMax -> core `max` on signed values.
    void SMax(spirv::ir::BuiltinCall* call) {
        WrapSignedSpirvMethods(call, core::BuiltinFn::kMax);
    }
    /// SPIR-V SMin -> core `min` on signed values.
    void SMin(spirv::ir::BuiltinCall* call) {
        WrapSignedSpirvMethods(call, core::BuiltinFn::kMin);
    }
    /// SPIR-V SClamp -> core `clamp` on signed values.
    void SClamp(spirv::ir::BuiltinCall* call) {
        WrapSignedSpirvMethods(call, core::BuiltinFn::kClamp);
    }

    /// SPIR-V Ldexp -> core `ldexp`. An unsigned exponent argument is bitcast to signed.
    void Ldexp(spirv::ir::BuiltinCall* call) {
        WrapSignedSpirvMethods(call, core::BuiltinFn::kLdexp);
    }

    // The SPIR-V Unsigned methods treat their integer arguments as unsigned, whatever their
    // declared type. To honour that, every signed argument is bitcast to the unsigned type
    // (matching the result width) before the WGSL equivalent is called.
    //
    // The WGSL method then returns an unsigned value. If the SPIR-V instruction expects a signed
    // result, the call is retyped to unsigned and its value is bitcast to the correct signed type.
    void WrapUnsignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BuiltinFn func) {
        auto args = call->Args();

        b.InsertBefore(call, [&] {
            auto* result_ty = call->Result()->Type();

            Vector<core::ir::Value*, 2> call_args;
            for (auto* arg : args) {
                core::ir::Value* value = arg;
                if (value->Type()->IsSignedIntegerScalarOrVector()) {
                    value = b.Bitcast(ty.MatchWidth(ty.u32(), result_ty), value)->Result();
                }
                call_args.Push(value);
            }

            auto* core_call = b.Call(result_ty, func, call_args);
            core::ir::Value* replacement = core_call->Result();

            const bool wants_signed = result_ty->DeepestElement() == ty.i32();
            if (wants_signed) {
                core_call->Result()->SetType(ty.MatchWidth(ty.u32(), result_ty));
                replacement = b.Bitcast(result_ty, replacement)->Result();
            }
            call->Result()->ReplaceAllUsesWith(replacement);
        });
        call->Destroy();
    }

    /// SPIR-V UMax -> core `max` on unsigned values.
    void UMax(spirv::ir::BuiltinCall* call) {
        WrapUnsignedSpirvMethods(call, core::BuiltinFn::kMax);
    }
    /// SPIR-V UMin -> core `min` on unsigned values.
    void UMin(spirv::ir::BuiltinCall* call) {
        WrapUnsignedSpirvMethods(call, core::BuiltinFn::kMin);
    }
    /// SPIR-V UClamp -> core `clamp` on unsigned values.
    void UClamp(spirv::ir::BuiltinCall* call) {
        WrapUnsignedSpirvMethods(call, core::BuiltinFn::kClamp);
    }
    /// SPIR-V FindUMsb -> core `firstLeadingBit` on unsigned values.
    void FindUMsb(spirv::ir::BuiltinCall* call) {
        WrapUnsignedSpirvMethods(call, core::BuiltinFn::kFirstLeadingBit);
    }

    /// Lowers SPIR-V Normalize.
    /// - Vector operands use core `normalize`.
    /// - Scalar operands use core `sign`, since x / |x| reduces to the sign of x.
    void Normalize(spirv::ir::BuiltinCall* call) {
        auto* operand = call->Args()[0];
        const core::BuiltinFn fn =
            operand->Type()->IsScalar() ? core::BuiltinFn::kSign : core::BuiltinFn::kNormalize;

        b.InsertBefore(call, [&] {
            b.CallWithResult(call->DetachResult(), fn, Vector<core::ir::Value*, 1>{operand});
        });
        call->Destroy();
    }

    /// Lowers SPIR-V FindILsb to core `firstTrailingBit`. The builtin is computed in the operand's
    /// type and bitcast to the call's result type when they differ.
    void FindILsb(spirv::ir::BuiltinCall* call) {
        auto* operand = call->Args()[0];

        b.InsertBefore(call, [&] {
            auto* operand_ty = operand->Type();
            auto* result_ty = call->Result()->Type();

            auto* lsb_call = b.Call(operand_ty, core::BuiltinFn::kFirstTrailingBit,
                                    Vector<core::ir::Value*, 1>{operand});
            core::ir::Value* lsb = lsb_call->Result();
            if (result_ty != operand_ty) {
                lsb = b.Bitcast(result_ty, lsb)->Result();
            }
            call->Result()->ReplaceAllUsesWith(lsb);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V Refract(I, N, eta) to core `refract`.
    /// - Vector operands are forwarded unchanged.
    /// - Scalar operands are run through a vec2 refract whose second component is zero, and the
    ///   first component of the result is used.
    void Refract(spirv::ir::BuiltinCall* call) {
        auto args = call->Args();
        auto* incident = args[0];
        auto* normal = args[1];
        auto* eta = args[2];

        b.InsertBefore(call, [&] {
            if (!incident->Type()->IsFloatScalar()) {
                b.CallWithResult(call->DetachResult(), core::BuiltinFn::kRefract,
                                 Vector<core::ir::Value*, 3>{incident, normal, eta});
                return;
            }

            auto* scalar_ty = incident->Type();
            auto* vec2_ty = ty.vec(scalar_ty, 2);
            auto* zero = b.Zero(scalar_ty);
            auto* incident_vec = b.Construct(vec2_ty, incident, zero)->Result();
            auto* normal_vec = b.Construct(vec2_ty, normal, zero)->Result();

            auto* refracted = b.Call(vec2_ty, core::BuiltinFn::kRefract,
                                     Vector<core::ir::Value*, 3>{incident_vec, normal_vec, eta});
            auto* first = b.Swizzle(scalar_ty, refracted, {0});
            call->Result()->ReplaceAllUsesWith(first->Result());
        });
        call->Destroy();
    }

    /// Lowers SPIR-V Reflect(I, N).
    /// - Vector operands map directly to core `reflect`.
    /// - Scalar operands are expanded to `I - ((I * N) * N) * 2`.
    ///
    /// The scalar expansion creates the literal `2` in the operand's own float type. A hard-coded
    /// f32 literal would form an invalid f16 * f32 multiply when the operands are f16.
    void Reflect(spirv::ir::BuiltinCall* call) {
        auto args = call->Args();

        auto* I = args[0];
        auto* N = args[1];

        b.InsertBefore(call, [&] {
            auto* arg_ty = I->Type();
            if (arg_ty->IsFloatScalar()) {
                // Pick a literal whose type matches the operands (f16 or f32).
                core::ir::Value* two = arg_ty == ty.f16() ? b.Constant(2.0_h) : b.Constant(2.0_f);

                auto* v = b.Multiply(arg_ty, I, N)->Result();
                v = b.Multiply(arg_ty, v, N)->Result();
                v = b.Multiply(arg_ty, v, two)->Result();
                v = b.Subtract(arg_ty, I, v)->Result();
                call->Result()->ReplaceAllUsesWith(v);
            } else {
                b.CallWithResult(call->DetachResult(), core::BuiltinFn::kReflect,
                                 Vector<core::ir::Value*, 2>{I, N});
            }
        });
        call->Destroy();
    }

    /// Lowers SPIR-V FaceForward(N, I, Nref).
    /// - Vector operands map to core `faceForward`.
    /// - Scalar operands compute `(I * Nref < 0) ? N : -N` with core `select`.
    void FaceForward(spirv::ir::BuiltinCall* call) {
        auto args = call->Args();
        auto* N = args[0];
        auto* I = args[1];
        auto* Nref = args[2];

        b.InsertBefore(call, [&] {
            if (!I->Type()->IsFloatScalar()) {
                b.CallWithResult(call->DetachResult(), core::BuiltinFn::kFaceForward, N, I, Nref);
                return;
            }

            auto* negated = b.Negation(N->Type(), N);
            auto* product = b.Multiply(I->Type(), I, Nref)->Result();
            auto* is_negative =
                b.LessThan(ty.bool_(), product, b.Zero(product->Type()))->Result();
            b.CallWithResult(call->DetachResult(), core::BuiltinFn::kSelect, negated, N,
                             is_negative);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V Modf(x, i) in three steps:
    /// - call core `modf`;
    /// - store member 1 of its result struct (the whole part) through the pointer `i`;
    /// - produce member 0 as the call's result.
    void Modf(spirv::ir::BuiltinCall* call) {
        auto* value = call->Args()[0];
        auto* whole_ptr = call->Args()[1];
        auto* part_ty = call->Result()->Type();
        auto* result_struct_ty = core::type::CreateModfResult(ty, ir.symbols, part_ty);

        b.InsertBefore(call, [&] {
            auto* modf_call = b.Call(result_struct_ty, core::BuiltinFn::kModf, value);
            auto* whole_part = b.Access(part_ty, modf_call, 1_u);
            b.Store(whole_ptr, whole_part);
            b.AccessWithResult(call->DetachResult(), modf_call, 0_u);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V Frexp(x, exp) in three steps:
    /// - call core `frexp`;
    /// - store member 1 of its result struct (an i32-based exponent) through the pointer `exp`,
    ///   bitcasting it first if the pointee is unsigned;
    /// - produce member 0 as the call's result.
    void Frexp(spirv::ir::BuiltinCall* call) {
        auto* value = call->Args()[0];
        auto* exp_ptr = call->Args()[1];
        auto* fract_ty = call->Result()->Type();
        auto* result_struct_ty = core::type::CreateFrexpResult(ty, ir.symbols, fract_ty);

        b.InsertBefore(call, [&] {
            auto* frexp_call = b.Call(result_struct_ty, core::BuiltinFn::kFrexp, value);
            core::ir::Value* exponent =
                b.Access(ty.MatchWidth(ty.i32(), fract_ty), frexp_call, 1_u)->Result();

            auto* store_ty = exp_ptr->Type()->UnwrapPtr();
            if (store_ty->DeepestElement()->IsUnsignedIntegerScalar()) {
                exponent = b.Bitcast(store_ty, exponent)->Result();
            }
            b.Store(exp_ptr, exponent);

            b.AccessWithResult(call->DetachResult(), frexp_call, 0_u);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V BitFieldInsert to core `insertBits`. Signed scalar `offset` and `count`
    /// operands are bitcast to u32 first.
    void BitFieldInsert(spirv::ir::BuiltinCall* call) {
        auto* base = call->Args()[0];
        auto* insert = call->Args()[1];
        auto* offset = call->Args()[2];
        auto* count = call->Args()[3];

        b.InsertBefore(call, [&] {
            // Reinterpret a signed scalar as u32; other values pass through unchanged.
            auto to_u32 = [&](core::ir::Value* v) -> core::ir::Value* {
                if (!v->Type()->IsSignedIntegerScalar()) {
                    return v;
                }
                return b.Bitcast(ty.u32(), v)->Result();
            };
            core::ir::Value* offset_u32 = to_u32(offset);
            core::ir::Value* count_u32 = to_u32(count);

            b.CallWithResult(call->DetachResult(), core::BuiltinFn::kInsertBits, base, insert,
                             offset_u32, count_u32);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V BitFieldUExtract to core `extractBits` evaluated on unsigned values.
    /// - A signed base is bitcast to u32 (matching width), and the extracted value is bitcast
    ///   back to the call's result type.
    /// - Signed scalar `offset` and `count` operands are bitcast to u32.
    void BitFieldUExtract(spirv::ir::BuiltinCall* call) {
        auto* base = call->Args()[0];
        auto* offset = call->Args()[1];
        auto* count = call->Args()[2];

        b.InsertBefore(call, [&] {
            auto* unsigned_ty = ty.MatchWidth(ty.u32(), base->Type());

            const bool base_is_signed = base->Type()->DeepestElement()->IsSignedIntegerScalar();
            if (base_is_signed) {
                base = b.Bitcast(unsigned_ty, base)->Result();
            }

            // Reinterpret a signed scalar as u32; other values pass through unchanged.
            auto to_u32 = [&](core::ir::Value* v) -> core::ir::Value* {
                if (!v->Type()->IsSignedIntegerScalar()) {
                    return v;
                }
                return b.Bitcast(ty.u32(), v)->Result();
            };
            core::ir::Value* offset_u32 = to_u32(offset);
            core::ir::Value* count_u32 = to_u32(count);

            core::ir::Value* extracted =
                b.Call(unsigned_ty, core::BuiltinFn::kExtractBits, base, offset_u32, count_u32)
                    ->Result();
            if (base_is_signed) {
                extracted = b.Bitcast(call->Result()->Type(), extracted)->Result();
            }
            call->Result()->ReplaceAllUsesWith(extracted);
        });
        call->Destroy();
    }

    /// Lowers SPIR-V BitFieldSExtract to core `extractBits` evaluated on signed values.
    /// - An unsigned base is bitcast to i32 (matching width), and the extracted value is bitcast
    ///   back to the call's result type.
    /// - Signed scalar `offset` and `count` operands are bitcast to u32.
    void BitFieldSExtract(spirv::ir::BuiltinCall* call) {
        auto* base = call->Args()[0];
        auto* offset = call->Args()[1];
        auto* count = call->Args()[2];

        b.InsertBefore(call, [&] {
            auto* signed_ty = ty.MatchWidth(ty.i32(), base->Type());

            const bool base_is_unsigned =
                base->Type()->DeepestElement()->IsUnsignedIntegerScalar();
            if (base_is_unsigned) {
                base = b.Bitcast(signed_ty, base)->Result();
            }

            // Reinterpret a signed scalar as u32; other values pass through unchanged.
            auto to_u32 = [&](core::ir::Value* v) -> core::ir::Value* {
                if (!v->Type()->IsSignedIntegerScalar()) {
                    return v;
                }
                return b.Bitcast(ty.u32(), v)->Result();
            };
            core::ir::Value* offset_u32 = to_u32(offset);
            core::ir::Value* count_u32 = to_u32(count);

            core::ir::Value* extracted =
                b.Call(signed_ty, core::BuiltinFn::kExtractBits, base, offset_u32, count_u32)
                    ->Result();
            if (base_is_unsigned) {
                extracted = b.Bitcast(call->Result()->Type(), extracted)->Result();
            }
            call->Result()->ReplaceAllUsesWith(extracted);
        });
        call->Destroy();
    }

    /// Replaces a SPIR-V bit-count call with the core `countOneBits` builtin.
    ///
    /// `countOneBits` is called with the operand's own type as its result type. When the
    /// SPIR-V call declared a different result type, the count is bitcast to that type.
    /// @param call the SPIR-V builtin call to replace; it is destroyed before returning
    void BitCount(spirv::ir::BuiltinCall* call) {
        auto* operand = call->Args()[0];

        b.InsertBefore(call, [&] {
            auto* wanted_ty = call->Result()->Type();
            auto* operand_ty = operand->Type();

            core::ir::Value* ones =
                b.Call(operand_ty, core::BuiltinFn::kCountOneBits, operand)->Result();
            if (wanted_ty != operand_ty) {
                ones = b.Bitcast(wanted_ty, ones)->Result();
            }
            call->Result()->ReplaceAllUsesWith(ones);
        });
        call->Destroy();
    }

    /// Replaces a SPIR-V logical left-shift call with a core `<<` binary instruction.
    ///
    /// A shift amount that is not already unsigned is bitcast to the u32-based type matching
    /// its own width. The shift is performed in the type of `Base`; if that differs from the
    /// call's result type, the shifted value is bitcast to the result type.
    /// @param call the SPIR-V builtin call to replace; it is destroyed before returning
    void ShiftLeftLogical(spirv::ir::BuiltinCall* call) {
        const auto& operands = call->Args();

        b.InsertBefore(call, [&] {
            auto* value = operands[0];
            auto* amount = operands[1];
            auto* value_ty = value->Type();
            auto* result_ty = call->Result()->Type();

            if (!amount->Type()->IsUnsignedIntegerScalarOrVector()) {
                auto* amount_ty = ty.MatchWidth(ty.u32(), amount->Type());
                amount = b.Bitcast(amount_ty, amount)->Result();
            }

            core::ir::Value* shifted =
                b.Binary(core::BinaryOp::kShiftLeft, value_ty, value, amount)->Result();
            if (value_ty != result_ty) {
                shifted = b.Bitcast(result_ty, shifted)->Result();
            }
            call->Result()->ReplaceAllUsesWith(shifted);
        });
        call->Destroy();
    }

    /// Replaces a SPIR-V logical right-shift call with a core `>>` binary instruction.
    ///
    /// The shift is performed on the u32-based type matching the width of `Base`. A `Base`
    /// that is not already unsigned is bitcast to that type first, so the right shift acts on
    /// an unsigned value. A shift amount that is not already unsigned is bitcast to the
    /// u32-based type matching its *own* width. If the call's result type differs from the
    /// unsigned type, the shifted value is bitcast to the result type.
    /// @param call the SPIR-V builtin call to replace; it is destroyed before returning
    void ShiftRightLogical(spirv::ir::BuiltinCall* call) {
        const auto& args = call->Args();

        b.InsertBefore(call, [&] {
            auto* base = args[0];
            auto* shift = args[1];

            auto* u_ty = ty.MatchWidth(ty.u32(), base->Type());
            if (!base->Type()->IsUnsignedIntegerScalarOrVector()) {
                base = b.Bitcast(u_ty, base)->Result();
            }
            if (!shift->Type()->IsUnsignedIntegerScalarOrVector()) {
                // Size the unsigned shift type from `shift` itself rather than from `base`, so
                // the bitcast can never change the width of the shift operand.
                shift = b.Bitcast(ty.MatchWidth(ty.u32(), shift->Type()), shift)->Result();
            }

            auto* bin = b.Binary(core::BinaryOp::kShiftRight, u_ty, base, shift)->Result();
            if (u_ty != call->Result()->Type()) {
                bin = b.Bitcast(call->Result()->Type(), bin)->Result();
            }
            call->Result()->ReplaceAllUsesWith(bin);
        });
        call->Destroy();
    }

    /// Replaces a SPIR-V arithmetic right-shift call with a core `>>` binary instruction.
    ///
    /// The shift is performed on the i32-based type matching the width of `Base`. A `Base`
    /// that is not already signed is bitcast to that type first. A shift amount that is not
    /// already unsigned is bitcast to the u32-based type matching its own width. If the call's
    /// result type differs from the signed type, the shifted value is bitcast to it.
    /// @param call the SPIR-V builtin call to replace; it is destroyed before returning
    void ShiftRightArithmetic(spirv::ir::BuiltinCall* call) {
        const auto& operands = call->Args();

        b.InsertBefore(call, [&] {
            auto* value = operands[0];
            auto* amount = operands[1];

            auto* signed_ty = ty.MatchWidth(ty.i32(), value->Type());
            const bool value_is_signed = value->Type()->IsSignedIntegerScalarOrVector();
            if (!value_is_signed) {
                value = b.Bitcast(signed_ty, value)->Result();
            }
            const bool amount_is_unsigned = amount->Type()->IsUnsignedIntegerScalarOrVector();
            if (!amount_is_unsigned) {
                amount = b.Bitcast(ty.MatchWidth(ty.u32(), amount->Type()), amount)->Result();
            }

            core::ir::Value* shifted =
                b.Binary(core::BinaryOp::kShiftRight, signed_ty, value, amount)->Result();
            auto* result_ty = call->Result()->Type();
            if (result_ty != signed_ty) {
                shifted = b.Bitcast(result_ty, shifted)->Result();
            }
            call->Result()->ReplaceAllUsesWith(shifted);
        });
        call->Destroy();
    }

    /// Replaces a SPIR-V matrix inverse builtin call with an explicit closed-form inverse.
    ///
    /// The inverse is computed as `(1 / determinant(arg)) * adjugate(arg)`. Every element of
    /// the adjugate is expanded into scalar multiplies, adds and subtracts. Only square 2x2,
    /// 3x3 and 4x4 matrices with f32 or f16 elements are handled; anything else is
    /// unreachable. No check is made for a zero determinant.
    ///
    /// Elements are read as `arg[column][row]`. The locals `ma`, `mb`, `mc`, ... name the
    /// elements in that access order, i.e. column by column.
    /// @param call the SPIR-V builtin call to replace; it is destroyed before returning
    void Inverse(spirv::ir::BuiltinCall* call) {
        auto* arg = call->Args()[0];
        auto* mat_ty = arg->Type()->As<core::type::Matrix>();
        TINT_ASSERT(mat_ty);
        TINT_ASSERT(mat_ty->Columns() == mat_ty->Rows());

        auto* elem_ty = mat_ty->Type();

        b.InsertBefore(call, [&] {
            auto* det =
                b.Call(elem_ty, core::BuiltinFn::kDeterminant, Vector<core::ir::Value*, 1>{arg});
            core::ir::Value* one = nullptr;
            if (elem_ty->Is<core::type::F32>()) {
                one = b.Constant(1.0_f);
            } else if (elem_ty->Is<core::type::F16>()) {
                one = b.Constant(1.0_h);
            } else {
                TINT_UNREACHABLE();
            }
            auto* inv_det = b.Divide(elem_ty, one, det);

            // Returns (m * n) - (o * p)
            auto sub_mul2 = [&](auto* m, auto* n, auto* o, auto* p) {
                auto* x = b.Multiply(elem_ty, m, n);
                auto* y = b.Multiply(elem_ty, o, p);
                return b.Subtract(elem_ty, x, y);
            };

            // Returns (m * n) - (o * p) + (q * r)
            auto sub_add_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) {
                auto* w = b.Multiply(elem_ty, m, n);
                auto* x = b.Multiply(elem_ty, o, p);
                auto* y = b.Multiply(elem_ty, q, r);

                auto* z = b.Subtract(elem_ty, w, x);
                return b.Add(elem_ty, z, y);
            };

            // Returns (m * n) + (o * p) - (q * r)
            auto add_sub_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) {
                auto* w = b.Multiply(elem_ty, m, n);
                auto* x = b.Multiply(elem_ty, o, p);
                auto* y = b.Multiply(elem_ty, q, r);

                auto* z = b.Add(elem_ty, w, x);
                return b.Subtract(elem_ty, z, y);
            };

            switch (mat_ty->Columns()) {
                case 2: {
                    // The 2x2 case folds 1/det into each element directly, so the original
                    // call's result is moved onto the final construct.
                    auto* neg_inv_det = b.Negation(elem_ty, inv_det);

                    auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
                    auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
                    auto* mc = b.Access(elem_ty, arg, 1_u, 0_u);
                    auto* md = b.Access(elem_ty, arg, 1_u, 1_u);

                    auto* r_00 = b.Multiply(elem_ty, inv_det, md);
                    auto* r_01 = b.Multiply(elem_ty, neg_inv_det, mb);
                    auto* r_10 = b.Multiply(elem_ty, neg_inv_det, mc);
                    auto* r_11 = b.Multiply(elem_ty, inv_det, ma);

                    auto* r1 = b.Construct(ty.vec2(elem_ty), r_00, r_01);
                    auto* r2 = b.Construct(ty.vec2(elem_ty), r_10, r_11);
                    b.ConstructWithResult(call->DetachResult(), r1, r2);
                    break;
                }
                case 3: {
                    auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
                    auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
                    auto* mc = b.Access(elem_ty, arg, 0_u, 2_u);
                    auto* md = b.Access(elem_ty, arg, 1_u, 0_u);
                    auto* me = b.Access(elem_ty, arg, 1_u, 1_u);
                    auto* mf = b.Access(elem_ty, arg, 1_u, 2_u);
                    auto* mg = b.Access(elem_ty, arg, 2_u, 0_u);
                    auto* mh = b.Access(elem_ty, arg, 2_u, 1_u);
                    auto* mi = b.Access(elem_ty, arg, 2_u, 2_u);

                    // e * i - f * h
                    auto* r_00 = sub_mul2(me, mi, mf, mh);
                    // c * h - b * i
                    auto* r_01 = sub_mul2(mc, mh, mb, mi);
                    // b * f - c * e
                    auto* r_02 = sub_mul2(mb, mf, mc, me);

                    // f * g - d * i
                    auto* r_10 = sub_mul2(mf, mg, md, mi);
                    // a * i - c * g
                    auto* r_11 = sub_mul2(ma, mi, mc, mg);
                    // c * d - a * f
                    auto* r_12 = sub_mul2(mc, md, ma, mf);

                    // d * h - e * g
                    auto* r_20 = sub_mul2(md, mh, me, mg);
                    // b * g - a * h
                    auto* r_21 = sub_mul2(mb, mg, ma, mh);
                    // a * e - b * d
                    auto* r_22 = sub_mul2(ma, me, mb, md);

                    auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02);
                    auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12);
                    auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22);

                    auto* m = b.Construct(mat_ty, r1, r2, r3);
                    auto* inv = b.Multiply(mat_ty, inv_det, m);
                    call->Result()->ReplaceAllUsesWith(inv->Result());
                    break;
                }
                case 4: {
                    auto* ma = b.Access(elem_ty, arg, 0_u, 0_u);
                    auto* mb = b.Access(elem_ty, arg, 0_u, 1_u);
                    auto* mc = b.Access(elem_ty, arg, 0_u, 2_u);
                    auto* md = b.Access(elem_ty, arg, 0_u, 3_u);
                    auto* me = b.Access(elem_ty, arg, 1_u, 0_u);
                    auto* mf = b.Access(elem_ty, arg, 1_u, 1_u);
                    auto* mg = b.Access(elem_ty, arg, 1_u, 2_u);
                    auto* mh = b.Access(elem_ty, arg, 1_u, 3_u);
                    auto* mi = b.Access(elem_ty, arg, 2_u, 0_u);
                    auto* mj = b.Access(elem_ty, arg, 2_u, 1_u);
                    auto* mk = b.Access(elem_ty, arg, 2_u, 2_u);
                    auto* ml = b.Access(elem_ty, arg, 2_u, 3_u);
                    auto* mm = b.Access(elem_ty, arg, 3_u, 0_u);
                    auto* mn = b.Access(elem_ty, arg, 3_u, 1_u);
                    auto* mo = b.Access(elem_ty, arg, 3_u, 2_u);
                    auto* mp = b.Access(elem_ty, arg, 3_u, 3_u);

                    // kplo = k * p - l * o
                    auto* kplo = sub_mul2(mk, mp, ml, mo);
                    // jpln = j * p - l * n
                    auto* jpln = sub_mul2(mj, mp, ml, mn);
                    // jokn = j * o - k * n;
                    auto* jokn = sub_mul2(mj, mo, mk, mn);
                    // gpho = g * p - h * o
                    auto* gpho = sub_mul2(mg, mp, mh, mo);
                    // fphn = f * p - h * n
                    auto* fphn = sub_mul2(mf, mp, mh, mn);
                    // fogn = f * o - g * n;
                    auto* fogn = sub_mul2(mf, mo, mg, mn);
                    // glhk = g * l - h * k
                    auto* glhk = sub_mul2(mg, ml, mh, mk);
                    // flhj = f * l - h * j
                    auto* flhj = sub_mul2(mf, ml, mh, mj);
                    // fkgj = f * k - g * j;
                    auto* fkgj = sub_mul2(mf, mk, mg, mj);
                    // iplm = i * p - l * m
                    auto* iplm = sub_mul2(mi, mp, ml, mm);
                    // iokm = i * o - k * m
                    auto* iokm = sub_mul2(mi, mo, mk, mm);
                    // ephm = e * p - h * m;
                    auto* ephm = sub_mul2(me, mp, mh, mm);
                    // eogm = e * o - g * m
                    auto* eogm = sub_mul2(me, mo, mg, mm);
                    // elhi = e * l - h * i
                    auto* elhi = sub_mul2(me, ml, mh, mi);
                    // ekgi = e * k - g * i;
                    auto* ekgi = sub_mul2(me, mk, mg, mi);
                    // injm = i * n - j * m
                    auto* injm = sub_mul2(mi, mn, mj, mm);
                    // enfm = e * n - f * m
                    auto* enfm = sub_mul2(me, mn, mf, mm);
                    // ejfi = e * j - f * i;
                    auto* ejfi = sub_mul2(me, mj, mf, mi);

                    auto* neg_b = b.Negation(elem_ty, mb);
                    // f * kplo - g * jpln + h * jokn
                    auto* r_00 = sub_add_mul3(mf, kplo, mg, jpln, mh, jokn);
                    // -b * kplo + c * jpln - d * jokn
                    auto* r_01 = add_sub_mul3(neg_b, kplo, mc, jpln, md, jokn);
                    // b * gpho - c * fphn + d * fogn
                    auto* r_02 = sub_add_mul3(mb, gpho, mc, fphn, md, fogn);
                    // -b * glhk + c * flhj - d * fkgj
                    auto* r_03 = add_sub_mul3(neg_b, glhk, mc, flhj, md, fkgj);

                    auto* neg_e = b.Negation(elem_ty, me);
                    auto* neg_a = b.Negation(elem_ty, ma);
                    // -e * kplo + g * iplm - h * iokm
                    auto* r_10 = add_sub_mul3(neg_e, kplo, mg, iplm, mh, iokm);
                    // a * kplo - c * iplm + d * iokm
                    auto* r_11 = sub_add_mul3(ma, kplo, mc, iplm, md, iokm);
                    // -a * gpho + c * ephm - d * eogm
                    auto* r_12 = add_sub_mul3(neg_a, gpho, mc, ephm, md, eogm);
                    // a * glhk - c * elhi + d * ekgi
                    auto* r_13 = sub_add_mul3(ma, glhk, mc, elhi, md, ekgi);

                    // e * jpln - f * iplm + h * injm
                    auto* r_20 = sub_add_mul3(me, jpln, mf, iplm, mh, injm);
                    // -a * jpln + b * iplm - d * injm
                    auto* r_21 = add_sub_mul3(neg_a, jpln, mb, iplm, md, injm);
                    // a * fphn - b * ephm + d * enfm
                    auto* r_22 = sub_add_mul3(ma, fphn, mb, ephm, md, enfm);
                    // -a * flhj + b * elhi - d * ejfi
                    auto* r_23 = add_sub_mul3(neg_a, flhj, mb, elhi, md, ejfi);

                    // -e * jokn + f * iokm - g * injm
                    auto* r_30 = add_sub_mul3(neg_e, jokn, mf, iokm, mg, injm);
                    // a * jokn - b * iokm + c * injm
                    auto* r_31 = sub_add_mul3(ma, jokn, mb, iokm, mc, injm);
                    // -a * fogn + b * eogm - c * enfm
                    auto* r_32 = add_sub_mul3(neg_a, fogn, mb, eogm, mc, enfm);
                    // a * fkgj - b * ekgi + c * ejfi
                    auto* r_33 = sub_add_mul3(ma, fkgj, mb, ekgi, mc, ejfi);

                    // Each column of a 4x4 matrix holds four elements, so the columns must be
                    // vec4 (constructing a vec3 from four values yields ill-typed IR).
                    auto* col_ty = ty.vec4(elem_ty);
                    auto* r1 = b.Construct(col_ty, r_00, r_01, r_02, r_03);
                    auto* r2 = b.Construct(col_ty, r_10, r_11, r_12, r_13);
                    auto* r3 = b.Construct(col_ty, r_20, r_21, r_22, r_23);
                    auto* r4 = b.Construct(col_ty, r_30, r_31, r_32, r_33);

                    auto* m = b.Construct(mat_ty, r1, r2, r3, r4);
                    auto* inv = b.Multiply(mat_ty, inv_det, m);
                    call->Result()->ReplaceAllUsesWith(inv->Result());
                    break;
                }
                default: {
                    TINT_UNREACHABLE();
                }
            }
        });
        call->Destroy();
    }
};

}  // namespace

/// Runs the SPIR-V builtin lowering pass over @p ir.
///
/// The module is first checked by `ValidateAndDumpIfNeeded`, which is allowed overrides and
/// non-core types. A validation failure is returned unchanged. Otherwise every builtin is
/// processed by `State` and success is returned.
/// @param ir the module to transform in place
/// @returns success, or the validation failure
Result<SuccessType> Builtins(core::ir::Module& ir) {
    const core::ir::Capabilities capabilities{
        core::ir::Capability::kAllowOverrides,
        core::ir::Capability::kAllowNonCoreTypes,
    };
    if (auto validated = ValidateAndDumpIfNeeded(ir, "spirv.Builtins", capabilities);
        validated != Success) {
        return validated.Failure();
    }

    State{ir}.Process();
    return Success;
}

}  // namespace tint::spirv::reader::lower
