blob: 17f179c152002495f316f06bc99fe67efa44a767 [file] [log] [blame]
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/lang/core/ir/transform/builtin_polyfill.h"
#include <utility>
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/sampled_texture.h"
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
namespace tint::ir::transform {
namespace {
/// PIMPL state for the transform.
/// Collects the core builtin calls that the config asks to polyfill, then replaces each one with
/// an equivalent sequence of IR instructions inserted immediately before the call.
struct State {
    /// The polyfill config.
    const BuiltinPolyfillConfig& config;

    /// The IR module.
    Module* ir = nullptr;

    /// The IR builder.
    /// Built from `*ir`; `ir` is declared above this member, so it is initialized first.
    Builder b{*ir};

    /// The type manager.
    core::type::Manager& ty{ir->Types()};

    /// The symbol table.
    /// Not referenced by any of the polyfills below.
    SymbolTable& sym{ir->symbols};

    /// Process the module.
    /// Runs in two phases. First, every live instruction is scanned and the builtin calls that
    /// the config enables a polyfill for are collected. Second, each collected call is rewritten.
    /// Collecting first means the instruction list is not mutated while it is being iterated.
    void Process() {
        // Find the builtin call instructions that may need to be polyfilled.
        Vector<ir::CoreBuiltinCall*, 4> worklist;
        for (auto* inst : ir->instructions.Objects()) {
            if (!inst->Alive()) {
                continue;
            }
            if (auto* builtin = inst->As<ir::CoreBuiltinCall>()) {
                switch (builtin->Func()) {
                    case core::Function::kCountLeadingZeros:
                        if (config.count_leading_zeros) {
                            worklist.Push(builtin);
                        }
                        break;
                    case core::Function::kCountTrailingZeros:
                        if (config.count_trailing_zeros) {
                            worklist.Push(builtin);
                        }
                        break;
                    case core::Function::kFirstLeadingBit:
                        if (config.first_leading_bit) {
                            worklist.Push(builtin);
                        }
                        break;
                    case core::Function::kFirstTrailingBit:
                        if (config.first_trailing_bit) {
                            worklist.Push(builtin);
                        }
                        break;
                    case core::Function::kSaturate:
                        if (config.saturate) {
                            worklist.Push(builtin);
                        }
                        break;
                    case core::Function::kTextureSampleBaseClampToEdge:
                        // Only calls on 2D sampled textures with an f32 sample type are
                        // polyfilled; any other texture argument is left untouched.
                        if (config.texture_sample_base_clamp_to_edge_2d_f32) {
                            auto* tex =
                                builtin->Args()[0]->Type()->As<core::type::SampledTexture>();
                            if (tex && tex->dim() == core::type::TextureDimension::k2d &&
                                tex->type()->Is<core::type::F32>()) {
                                worklist.Push(builtin);
                            }
                        }
                        break;
                    default:
                        break;
                }
            }
        }

        // Polyfill the builtin call instructions that we found.
        for (auto* builtin : worklist) {
            ir::Value* replacement = nullptr;
            switch (builtin->Func()) {
                case core::Function::kCountLeadingZeros:
                    replacement = CountLeadingZeros(builtin);
                    break;
                case core::Function::kCountTrailingZeros:
                    replacement = CountTrailingZeros(builtin);
                    break;
                case core::Function::kFirstLeadingBit:
                    replacement = FirstLeadingBit(builtin);
                    break;
                case core::Function::kFirstTrailingBit:
                    replacement = FirstTrailingBit(builtin);
                    break;
                case core::Function::kSaturate:
                    replacement = Saturate(builtin);
                    break;
                case core::Function::kTextureSampleBaseClampToEdge:
                    replacement = TextureSampleBaseClampToEdge_2d_f32(builtin);
                    break;
                default:
                    break;
            }
            // Every function the first phase can push has a case above, so a null replacement
            // means a polyfill failed to produce a value. Judging by its name,
            // TINT_ASSERT_OR_RETURN reports this and returns from Process().
            TINT_ASSERT_OR_RETURN(replacement);

            // Replace the old builtin call result with the new value.
            // Any name attached to the old result is carried over to the replacement.
            if (auto name = ir->NameOf(builtin->Result())) {
                ir->SetName(replacement, name);
            }
            builtin->Result()->ReplaceAllUsesWith(replacement);
            builtin->Destroy();
        }
    }

    /// Return a type with element type @p el_ty that has the same number of vector components as
    /// @p match. If @p match is scalar just return @p el_ty.
    /// @param el_ty the type to extend
    /// @param match the type to match the component count of
    /// @returns a type with the same number of vector components as @p match
    const core::type::Type* MatchWidth(const core::type::Type* el_ty,
                                       const core::type::Type* match) {
        if (auto* vec = match->As<core::type::Vector>()) {
            return ty.vec(el_ty, vec->Width());
        }
        return el_ty;
    }

    /// Return a constant that has the same number of vector components as @p match, each with the
    /// value @p element. If @p match is scalar just return @p element.
    /// @param element the value to extend
    /// @param match the type to match the component count of
    /// @returns a value with the same number of vector components as @p match
    ir::Constant* MatchWidth(ir::Constant* element, const core::type::Type* match) {
        if (auto* vec = match->As<core::type::Vector>()) {
            // Splat the scalar across a vector of the element's type with `match`'s width.
            return b.Splat(MatchWidth(element->Type(), match), element, vec->Width());
        }
        return element;
    }

    /// Polyfill a `countLeadingZeros()` builtin call.
    /// Uses a branch-free binary search. At each step, if the top N bits are all zero, they are
    /// shifted out and N is recorded. Signed inputs are bitcast to u32 for the computation, and
    /// the result is bitcast back to the input type.
    /// @param call the builtin call instruction
    /// @returns the replacement value
    ir::Value* CountLeadingZeros(ir::CoreBuiltinCall* call) {
        auto* input = call->Args()[0];
        auto* result_ty = input->Type();
        auto* uint_ty = MatchWidth(ty.u32(), result_ty);
        auto* bool_ty = MatchWidth(ty.bool_(), result_ty);

        // Make a u32 constant with the same component count as result_ty.
        auto V = [&](uint32_t u) { return MatchWidth(b.Constant(u32(u)), result_ty); };

        Value* result = nullptr;
        b.InsertBefore(call, [&] {
            // WGSL `select(f, t, cond)` yields `t` where `cond` is true.
            //
            // %x = %input;
            // if (%x is signed) {
            //   %x = bitcast<u32>(%x)
            // }
            // %b16 = select(0, 16, %x <= 0x0000ffff);
            // %x <<= %b16;
            // %b8 = select(0, 8, %x <= 0x00ffffff);
            // %x <<= %b8;
            // %b4 = select(0, 4, %x <= 0x0fffffff);
            // %x <<= %b4;
            // %b2 = select(0, 2, %x <= 0x3fffffff);
            // %x <<= %b2;
            // %b1 = select(0, 1, %x <= 0x7fffffff);
            // %b0 = select(0, 1, %x == 0);
            // %result = (%b16 | %b8 | %b4 | %b2 | %b1 | %b0) + %b0;
            //
            // %b0 is OR'd in as well as added, which does not change the result. When %x == 0,
            // %b1 is also 1, so (%b1 | %b0) == %b1. Otherwise %b0 is 0. For a zero input the
            // result is 31 + 1 == 32.
            auto* x = input;
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                x = b.Bitcast(uint_ty, x)->Result();
            }
            auto* b16 = b.Call(uint_ty, core::Function::kSelect, V(0), V(16),
                               b.LessThanEqual(bool_ty, x, V(0x0000ffff)));
            x = b.ShiftLeft(uint_ty, x, b16)->Result();
            auto* b8 = b.Call(uint_ty, core::Function::kSelect, V(0), V(8),
                              b.LessThanEqual(bool_ty, x, V(0x00ffffff)));
            x = b.ShiftLeft(uint_ty, x, b8)->Result();
            auto* b4 = b.Call(uint_ty, core::Function::kSelect, V(0), V(4),
                              b.LessThanEqual(bool_ty, x, V(0x0fffffff)));
            x = b.ShiftLeft(uint_ty, x, b4)->Result();
            auto* b2 = b.Call(uint_ty, core::Function::kSelect, V(0), V(2),
                              b.LessThanEqual(bool_ty, x, V(0x3fffffff)));
            x = b.ShiftLeft(uint_ty, x, b2)->Result();
            auto* b1 = b.Call(uint_ty, core::Function::kSelect, V(0), V(1),
                              b.LessThanEqual(bool_ty, x, V(0x7fffffff)));
            auto* b0 =
                b.Call(uint_ty, core::Function::kSelect, V(0), V(1), b.Equal(bool_ty, x, V(0)));
            result = b.Add(uint_ty,
                           b.Or(uint_ty, b16,
                                b.Or(uint_ty, b8,
                                     b.Or(uint_ty, b4, b.Or(uint_ty, b2, b.Or(uint_ty, b1, b0))))),
                           b0)
                         ->Result();
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                result = b.Bitcast(result_ty, result)->Result();
            }
        });
        return result;
    }

    /// Polyfill a `countTrailingZeros()` builtin call.
    /// Uses a branch-free binary search. At each step, if the low N bits are all zero, they are
    /// shifted out and N is recorded. Signed inputs are bitcast to u32 for the computation, and
    /// the result is bitcast back to the input type.
    /// @param call the builtin call instruction
    /// @returns the replacement value
    ir::Value* CountTrailingZeros(ir::CoreBuiltinCall* call) {
        auto* input = call->Args()[0];
        auto* result_ty = input->Type();
        auto* uint_ty = MatchWidth(ty.u32(), result_ty);
        auto* bool_ty = MatchWidth(ty.bool_(), result_ty);

        // Make a u32 constant with the same component count as result_ty.
        auto V = [&](uint32_t u) { return MatchWidth(b.Constant(u32(u)), result_ty); };

        Value* result = nullptr;
        b.InsertBefore(call, [&] {
            // WGSL `select(f, t, cond)` yields `t` where `cond` is true.
            //
            // %x = %input;
            // if (%x is signed) {
            //   %x = bitcast<u32>(%x)
            // }
            // %b16 = select(0, 16, (%x & 0x0000ffff) == 0);
            // %x >>= %b16;
            // %b8 = select(0, 8, (%x & 0x000000ff) == 0);
            // %x >>= %b8;
            // %b4 = select(0, 4, (%x & 0x0000000f) == 0);
            // %x >>= %b4;
            // %b2 = select(0, 2, (%x & 0x00000003) == 0);
            // %x >>= %b2;
            // %b1 = select(0, 1, (%x & 0x00000001) == 0);
            // %b0 = select(0, 1, %x == 0);
            // %result = (%b16 | %b8 | %b4 | %b2 | %b1) + %b0;
            //
            // %b0 is only 1 when the input was zero. In that case every %bN is set, so the
            // result is 31 + 1 == 32.
            auto* x = input;
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                x = b.Bitcast(uint_ty, x)->Result();
            }
            auto* b16 = b.Call(uint_ty, core::Function::kSelect, V(0), V(16),
                               b.Equal(bool_ty, b.And(uint_ty, x, V(0x0000ffff)), V(0)));
            x = b.ShiftRight(uint_ty, x, b16)->Result();
            auto* b8 = b.Call(uint_ty, core::Function::kSelect, V(0), V(8),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x000000ff)), V(0)));
            x = b.ShiftRight(uint_ty, x, b8)->Result();
            auto* b4 = b.Call(uint_ty, core::Function::kSelect, V(0), V(4),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x0000000f)), V(0)));
            x = b.ShiftRight(uint_ty, x, b4)->Result();
            auto* b2 = b.Call(uint_ty, core::Function::kSelect, V(0), V(2),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x00000003)), V(0)));
            x = b.ShiftRight(uint_ty, x, b2)->Result();
            auto* b1 = b.Call(uint_ty, core::Function::kSelect, V(0), V(1),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x00000001)), V(0)));
            auto* b0 =
                b.Call(uint_ty, core::Function::kSelect, V(0), V(1), b.Equal(bool_ty, x, V(0)));
            result = b.Add(uint_ty,
                           b.Or(uint_ty, b16,
                                b.Or(uint_ty, b8, b.Or(uint_ty, b4, b.Or(uint_ty, b2, b1)))),
                           b0)
                         ->Result();
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                result = b.Bitcast(result_ty, result)->Result();
            }
        });
        return result;
    }

    /// Polyfill a `firstLeadingBit()` builtin call.
    /// Uses a branch-free binary search for the highest set bit. Negative signed inputs are
    /// bit-inverted first, so the search instead finds the highest bit that differs from the sign
    /// bit. A value with no such bit yields 0xffffffff, which is -1 after the bitcast back for
    /// signed types.
    /// @param call the builtin call instruction
    /// @returns the replacement value
    ir::Value* FirstLeadingBit(ir::CoreBuiltinCall* call) {
        auto* input = call->Args()[0];
        auto* result_ty = input->Type();
        auto* uint_ty = MatchWidth(ty.u32(), result_ty);
        auto* bool_ty = MatchWidth(ty.bool_(), result_ty);

        // Make a u32 constant with the same component count as result_ty.
        auto V = [&](uint32_t u) { return MatchWidth(b.Constant(u32(u)), result_ty); };

        Value* result = nullptr;
        b.InsertBefore(call, [&] {
            // WGSL `select(f, t, cond)` yields `t` where `cond` is true.
            //
            // %x = %input;
            // if (%x is signed) {
            //   %x = bitcast<u32>(%x);
            //   %x = select(~%x, %x, %x < 0x80000000);  // invert when the sign bit is set
            // }
            // %b16 = select(16, 0, (%x & 0xffff0000) == 0);
            // %x >>= %b16;
            // %b8 = select(8, 0, (%x & 0x0000ff00) == 0);
            // %x >>= %b8;
            // %b4 = select(4, 0, (%x & 0x000000f0) == 0);
            // %x >>= %b4;
            // %b2 = select(2, 0, (%x & 0x0000000c) == 0);
            // %x >>= %b2;
            // %b1 = select(1, 0, (%x & 0x00000002) == 0);
            // %result = %b16 | %b8 | %b4 | %b2 | %b1;
            // %result = select(%result, 0xffffffff, %x == 0);
            auto* x = input;
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                x = b.Bitcast(uint_ty, x)->Result();
                auto* inverted = b.Complement(uint_ty, x);
                x = b.Call(uint_ty, core::Function::kSelect, inverted, x,
                           b.LessThan(bool_ty, x, V(0x80000000)))
                        ->Result();
            }
            auto* b16 = b.Call(uint_ty, core::Function::kSelect, V(16), V(0),
                               b.Equal(bool_ty, b.And(uint_ty, x, V(0xffff0000)), V(0)));
            x = b.ShiftRight(uint_ty, x, b16)->Result();
            auto* b8 = b.Call(uint_ty, core::Function::kSelect, V(8), V(0),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x0000ff00)), V(0)));
            x = b.ShiftRight(uint_ty, x, b8)->Result();
            auto* b4 = b.Call(uint_ty, core::Function::kSelect, V(4), V(0),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x000000f0)), V(0)));
            x = b.ShiftRight(uint_ty, x, b4)->Result();
            auto* b2 = b.Call(uint_ty, core::Function::kSelect, V(2), V(0),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x0000000c)), V(0)));
            x = b.ShiftRight(uint_ty, x, b2)->Result();
            auto* b1 = b.Call(uint_ty, core::Function::kSelect, V(1), V(0),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x00000002)), V(0)));
            result = b.Or(uint_ty, b16, b.Or(uint_ty, b8, b.Or(uint_ty, b4, b.Or(uint_ty, b2, b1))))
                         ->Result();
            // A zero (post-inversion) value has no leading bit to report.
            result = b.Call(uint_ty, core::Function::kSelect, result, V(0xffffffff),
                            b.Equal(bool_ty, x, V(0)))
                         ->Result();
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                result = b.Bitcast(result_ty, result)->Result();
            }
        });
        return result;
    }

    /// Polyfill a `firstTrailingBit()` builtin call.
    /// Uses a branch-free binary search for the lowest set bit. A zero input yields 0xffffffff,
    /// which is -1 after the bitcast back for signed types.
    /// @param call the builtin call instruction
    /// @returns the replacement value
    ir::Value* FirstTrailingBit(ir::CoreBuiltinCall* call) {
        auto* input = call->Args()[0];
        auto* result_ty = input->Type();
        auto* uint_ty = MatchWidth(ty.u32(), result_ty);
        auto* bool_ty = MatchWidth(ty.bool_(), result_ty);

        // Make a u32 constant with the same component count as result_ty.
        auto V = [&](uint32_t u) { return MatchWidth(b.Constant(u32(u)), result_ty); };

        Value* result = nullptr;
        b.InsertBefore(call, [&] {
            // WGSL `select(f, t, cond)` yields `t` where `cond` is true.
            //
            // %x = %input;
            // if (%x is signed) {
            //   %x = bitcast<u32>(%x)
            // }
            // %b16 = select(0, 16, (%x & 0x0000ffff) == 0);
            // %x >>= %b16;
            // %b8 = select(0, 8, (%x & 0x000000ff) == 0);
            // %x >>= %b8;
            // %b4 = select(0, 4, (%x & 0x0000000f) == 0);
            // %x >>= %b4;
            // %b2 = select(0, 2, (%x & 0x00000003) == 0);
            // %x >>= %b2;
            // %b1 = select(0, 1, (%x & 0x00000001) == 0);
            // %result = %b16 | %b8 | %b4 | %b2 | %b1;
            // %result = select(%result, 0xffffffff, %x == 0);
            auto* x = input;
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                x = b.Bitcast(uint_ty, x)->Result();
            }
            auto* b16 = b.Call(uint_ty, core::Function::kSelect, V(0), V(16),
                               b.Equal(bool_ty, b.And(uint_ty, x, V(0x0000ffff)), V(0)));
            x = b.ShiftRight(uint_ty, x, b16)->Result();
            auto* b8 = b.Call(uint_ty, core::Function::kSelect, V(0), V(8),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x000000ff)), V(0)));
            x = b.ShiftRight(uint_ty, x, b8)->Result();
            auto* b4 = b.Call(uint_ty, core::Function::kSelect, V(0), V(4),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x0000000f)), V(0)));
            x = b.ShiftRight(uint_ty, x, b4)->Result();
            auto* b2 = b.Call(uint_ty, core::Function::kSelect, V(0), V(2),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x00000003)), V(0)));
            x = b.ShiftRight(uint_ty, x, b2)->Result();
            auto* b1 = b.Call(uint_ty, core::Function::kSelect, V(0), V(1),
                              b.Equal(bool_ty, b.And(uint_ty, x, V(0x00000001)), V(0)));
            result = b.Or(uint_ty, b16, b.Or(uint_ty, b8, b.Or(uint_ty, b4, b.Or(uint_ty, b2, b1))))
                         ->Result();
            // A zero input has no trailing bit to report.
            result = b.Call(uint_ty, core::Function::kSelect, result, V(0xffffffff),
                            b.Equal(bool_ty, x, V(0)))
                         ->Result();
            if (result_ty->is_signed_integer_scalar_or_vector()) {
                result = b.Bitcast(result_ty, result)->Result();
            }
        });
        return result;
    }

    /// Polyfill a `saturate()` builtin call.
    /// @param call the builtin call instruction
    /// @returns the replacement value
    ir::Value* Saturate(ir::CoreBuiltinCall* call) {
        // Replace `saturate(x)` with `clamp(x, 0., 1.)`.
        auto* type = call->Result()->Type();
        ir::Constant* zero = nullptr;
        ir::Constant* one = nullptr;
        // NOTE(review): only f32 and f16 element types are handled. For any other element type
        // `zero` and `one` stay null and are passed to the clamp call as-is. Presumably callers
        // only ever see float types here; confirm.
        if (type->DeepestElement()->Is<core::type::F32>()) {
            zero = MatchWidth(b.Constant(0_f), type);
            one = MatchWidth(b.Constant(1_f), type);
        } else if (type->DeepestElement()->Is<core::type::F16>()) {
            zero = MatchWidth(b.Constant(0_h), type);
            one = MatchWidth(b.Constant(1_h), type);
        }
        auto* clamp = b.Call(type, core::Function::kClamp, Vector{call->Args()[0], zero, one});
        clamp->InsertBefore(call);
        return clamp->Result();
    }

    /// Polyfill a `textureSampleBaseClampToEdge()` builtin call for 2D F32 textures.
    /// The coordinates are clamped half a texel inside the [0, 1] range of the level-0 dimensions,
    /// and the texture is then sampled at level 0 with textureSampleLevel.
    /// @param call the builtin call instruction
    /// @returns the replacement value
    ir::Value* TextureSampleBaseClampToEdge_2d_f32(ir::CoreBuiltinCall* call) {
        // Replace `textureSampleBaseClampToEdge(%texture, %sampler, %coords)` with:
        //   %dims = vec2f(textureDimensions(%texture));
        //   %half_texel = vec2f(0.5) / %dims;
        //   %clamped = clamp(%coords, %half_texel, 1.0 - %half_texel);
        //   %result = textureSampleLevel(%texture, %sampler, %clamped, 0);
        ir::Value* result = nullptr;
        auto* texture = call->Args()[0];
        auto* sampler = call->Args()[1];
        auto* coords = call->Args()[2];
        b.InsertBefore(call, [&] {
            auto* vec2f = ty.vec2<f32>();
            auto* dims = b.Call(ty.vec2<u32>(), core::Function::kTextureDimensions, texture);
            auto* fdims = b.Convert(vec2f, dims);
            auto* half_texel = b.Divide(vec2f, b.Splat(vec2f, 0.5_f, 2), fdims);
            auto* one_minus_half_texel = b.Subtract(vec2f, b.Splat(vec2f, 1_f, 2), half_texel);
            auto* clamped =
                b.Call(vec2f, core::Function::kClamp, coords, half_texel, one_minus_half_texel);
            result = b.Call(ty.vec4<f32>(), core::Function::kTextureSampleLevel, texture, sampler,
                            clamped, 0_f)
                         ->Result();
        });
        return result;
    }
};
} // namespace
/// Polyfills the core builtin calls in @p ir that @p config enables.
/// The module is validated first. If validation fails, that failure is returned and the module is
/// not modified.
/// @param ir the module to transform
/// @param config the polyfill configuration
/// @returns success, or the validation failure
Result<SuccessType, std::string> BuiltinPolyfill(Module* ir, const BuiltinPolyfillConfig& config) {
    if (auto valid = ValidateAndDumpIfNeeded(*ir, "BuiltinPolyfill transform"); !valid) {
        return valid;
    }

    State state{config, ir};
    state.Process();

    return Success;
}
} // namespace tint::ir::transform