blob: c71c963b64ab85d2e67fdc04c78ddf03dbb2f89f [file] [log] [blame]
// Copyright 2024 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/hlsl/writer/raise/decompose_memory_access.h"
#include <utility>
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/hlsl/builtin_fn.h"
#include "src/tint/lang/hlsl/ir/member_builtin_call.h"
#include "src/tint/lang/hlsl/type/byte_address_buffer.h"
#include "src/tint/utils/result/result.h"
namespace tint::hlsl::writer::raise {
namespace {
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
/// PIMPL state for the transform.
struct State {
/// The IR module.
core::ir::Module& ir;
/// The IR builder.
core::ir::Builder b{ir};
/// The type manager.
core::type::Manager& ty{ir.Types()};
using VarTypePair = std::pair<core::ir::Var*, const core::type::Type*>;
/// Maps a (var, type) pair to the load function generated for that combination
Hashmap<VarTypePair, core::ir::Function*, 2> var_and_type_to_load_fn_{};
/// Process the module.
void Process() {
Vector<core::ir::Var*, 4> var_worklist;
for (auto* inst : *ir.root_block) {
// Allow this to run before or after PromoteInitializers by handling non-var root_block
// entries
auto* var = inst->As<core::ir::Var>();
if (!var) {
continue;
}
// Var must be a pointer
auto* var_ty = var->Result(0)->Type()->As<core::type::Pointer>();
TINT_ASSERT(var_ty);
// Only care about storage address space variables.
if (var_ty->AddressSpace() != core::AddressSpace::kStorage) {
continue;
}
var_worklist.Push(var);
}
for (auto* var : var_worklist) {
auto* result = var->Result(0);
// Find all the usages of the `var` which is loading or storing.
Vector<core::ir::Instruction*, 4> usage_worklist;
for (auto& usage : result->Usages()) {
Switch(
usage->instruction,
[&](core::ir::LoadVectorElement* lve) { usage_worklist.Push(lve); },
[&](core::ir::StoreVectorElement* sve) { usage_worklist.Push(sve); },
[&](core::ir::Store* st) { usage_worklist.Push(st); },
[&](core::ir::Load* ld) { usage_worklist.Push(ld); },
[&](core::ir::Access* a) { usage_worklist.Push(a); });
}
auto* var_ty = result->Type()->As<core::type::Pointer>();
for (auto* inst : usage_worklist) {
// Load instructions can be destroyed by the replacing access function
if (!inst->Alive()) {
continue;
}
Switch(
inst,
[&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var, var_ty); },
[&](core::ir::StoreVectorElement* s) { StoreVectorElement(s, var, var_ty); },
[&](core::ir::Store* s) { Store(s); }, //
[&](core::ir::Load* l) { Load(l, var); }, //
[&](core::ir::Access* a) { Access(a, var, a->Object()->Type(), 0u); }, //
TINT_ICE_ON_NO_MATCH);
}
// Swap the result type of the `var` to the new HLSL result type
result->SetType(ty.Get<hlsl::type::ByteAddressBuffer>(var_ty->Access()));
}
}
/// Converts a constant vector element index into a byte offset.
/// @param v the index value; must be an IR constant
/// @param store_ty the type being indexed; the size of its deepest element type is used as the
/// per-element byte size
/// @returns the element index multiplied by the element byte size
uint32_t CalculateVectorIndex(core::ir::Value* v, const core::type::Type* store_ty) {
    // TODO(dsinclair): Handle non-constant vector indices.
    auto* const_index = v->As<core::ir::Constant>();
    TINT_ASSERT(const_index);

    const auto element_index = const_index->Value()->ValueAs<uint32_t>();
    const auto element_size = store_ty->DeepestElement()->Size();
    return element_index * element_size;
}
/// Creates the instructions that load a value of type `result_ty` from `var` at `offset`.
///
/// Numeric scalars and vectors are loaded directly. Structs, matrices and arrays are loaded by
/// calling the load helper function for the (`var`, type) combination.
/// @param inst the instruction being replaced; forwarded to the load helper builders
/// @param var the storage variable to load from
/// @param result_ty the type of the value to produce
/// @param offset the byte offset to load from
/// @returns the call whose result is the loaded value
core::ir::Call* MakeLoad(core::ir::Instruction* inst,
                         core::ir::Var* var,
                         const core::type::Type* result_ty,
                         core::ir::Value* offset) {
    if (result_ty->is_numeric_scalar_or_vector()) {
        return MakeScalarOrVectorLoad(var, result_ty, offset);
    }

    // Every composite kind is handled identically once its helper function is known.
    auto call_helper = [&](core::ir::Function* helper) { return b.Call(helper, offset); };

    return tint::Switch(
        result_ty,
        [&](const core::type::Struct* s) { return call_helper(GetLoadFunctionFor(inst, var, s)); },
        [&](const core::type::Matrix* m) { return call_helper(GetLoadFunctionFor(inst, var, m)); },
        [&](const core::type::Array* a) { return call_helper(GetLoadFunctionFor(inst, var, a)); },
        TINT_ICE_ON_NO_MATCH);
}
/// Creates a `v.Load{2,3,4} offset` member builtin call on `var`, picking the builtin from the
/// vector width of `result_ty` (the un-suffixed `Load` for scalars).
///
/// For non-`f16` element types the builtin yields a `u32` (or vector of `u32`) which is then
/// bitcast to `result_ty`. This only works for `u32`, `i32`, `f16`, `f32` and the vector sizes
/// of those types.
///
/// For `f16` element types the `F16` builtin variants are used. Those are typed as `f16` (or a
/// vector of `f16`) because HLSL uses a templated `Load<float16_t>` that already returns the
/// correct type, so no bitcast is emitted.
/// @param var the storage variable to load from
/// @param result_ty the scalar or vector type to produce
/// @param offset the byte offset to load from
/// @returns the builtin call for `f16` results, otherwise the bitcast of the builtin's result
core::ir::Call* MakeScalarOrVectorLoad(core::ir::Var* var,
                                       const core::type::Type* result_ty,
                                       core::ir::Value* offset) {
    const bool is_f16 = result_ty->DeepestElement()->Is<core::type::F16>();

    // Start from the scalar form of the load, then widen it if the result is a vector.
    const core::type::Type* loaded_ty = ty.u32();
    auto builtin_fn = BuiltinFn::kLoad;
    if (is_f16) {
        // An `f16` load returns an `f16` instead of a `u32`
        loaded_ty = ty.f16();
        builtin_fn = BuiltinFn::kLoadF16;
    }

    if (auto* vec = result_ty->As<core::type::Vector>()) {
        const auto width = vec->Width();
        loaded_ty = ty.vec(loaded_ty, width);
        switch (width) {
            case 2:
                builtin_fn = is_f16 ? BuiltinFn::kLoad2F16 : BuiltinFn::kLoad2;
                break;
            case 3:
                builtin_fn = is_f16 ? BuiltinFn::kLoad3F16 : BuiltinFn::kLoad3;
                break;
            case 4:
                builtin_fn = is_f16 ? BuiltinFn::kLoad4F16 : BuiltinFn::kLoad4;
                break;
            default:
                TINT_UNREACHABLE();
        }
    }

    auto* load = b.MemberCall<hlsl::ir::MemberBuiltinCall>(loaded_ty, builtin_fn, var, offset);

    // Do not bitcast the `f16` conversions as they need to be a templated Load instruction
    if (is_f16) {
        return load;
    }
    return b.Bitcast(result_ty, load->Result(0));
}
/// Returns the load function for the given `var` and `struct` combination, creating it on first
/// request and reusing it afterwards. The function is similar to:
///
///   fn custom_load_S(offset: u32) {
///     let a = <load S member 0>(offset + member 0 offset);
///     let b = <load S member 1>(offset + member 1 offset);
///     ...
///     let z = <load S member last>(offset + member last offset);
///     return S(a, b, ..., z);
///   }
/// @param inst the instruction being replaced; forwarded to the member loads
/// @param var the storage variable the function loads from
/// @param s the struct type the function returns
/// @returns the load function
core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
                                       core::ir::Var* var,
                                       const core::type::Struct* s) {
    auto build_fn = [&] {
        auto* offset_param = b.FunctionParam("offset", ty.u32());
        auto* load_fn = b.Function(s);
        load_fn->SetParams({offset_param});

        b.Append(load_fn->Block(), [&] {
            // Load each member from its own offset, in declaration order.
            Vector<core::ir::Value*, 4> member_values;
            for (const auto* member : s->Members()) {
                auto* member_offset = b.Add<u32>(offset_param, u32(member->Offset()));
                auto* member_load = MakeLoad(inst, var, member->Type(), member_offset->Result(0));
                member_values.Push(member_load->Result(0));
            }
            b.Return(load_fn, b.Construct(s, member_values));
        });
        return load_fn;
    };
    return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, s}, build_fn);
}
/// Returns the load function for the given `var` and `matrix` combination, creating it on first
/// request and reusing it afterwards. The function is similar to:
///
///   fn custom_load_M(offset: u32) {
///     let a = <load M column 0>(offset + (0 * ColumnStride));
///     let b = <load M column 1>(offset + (1 * ColumnStride));
///     ...
///     let z = <load M column last>(offset + (last * ColumnStride));
///     return M(a, b, ... z);
///   }
/// @param inst the instruction being replaced; forwarded to the column loads
/// @param var the storage variable the function loads from
/// @param mat the matrix type the function returns
/// @returns the load function
core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
                                       core::ir::Var* var,
                                       const core::type::Matrix* mat) {
    auto build_fn = [&] {
        auto* offset_param = b.FunctionParam("offset", ty.u32());
        auto* load_fn = b.Function(mat);
        load_fn->SetParams({offset_param});

        b.Append(load_fn->Block(), [&] {
            // Columns are loaded one at a time, each one column stride after the previous.
            Vector<core::ir::Value*, 4> columns;
            for (size_t col = 0; col < mat->columns(); ++col) {
                auto* col_offset = b.Add<u32>(offset_param, u32(col * mat->ColumnStride()));
                auto* col_load = MakeLoad(inst, var, mat->ColumnType(), col_offset->Result(0));
                columns.Push(col_load->Result(0));
            }
            b.Return(load_fn, b.Construct(mat, columns));
        });
        return load_fn;
    };
    return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, mat}, build_fn);
}
/// Returns the load function for the given `var` and `array` combination, creating it on first
/// request and reusing it afterwards. The array must have a constant element count. The
/// function is similar to:
///
///   fn custom_load_A(offset: u32) {
///     A a = A();
///     u32 i = 0;
///     loop {
///       if (i >= A length) {
///         break;
///       }
///       a[i] = <load array type>(offset + (i * A->Stride()));
///       i = i + 1;
///     }
///     return a;
///   }
/// @param inst the instruction being replaced; forwarded to the element loads
/// @param var the storage variable the function loads from
/// @param arr the array type the function returns
/// @returns the load function
core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
                                       core::ir::Var* var,
                                       const core::type::Array* arr) {
    auto build_fn = [&] {
        auto* offset_param = b.FunctionParam("offset", ty.u32());
        auto* load_fn = b.Function(arr);
        load_fn->SetParams({offset_param});

        b.Append(load_fn->Block(), [&] {
            // Function-local array, zero initialized, that is filled one element per iteration.
            auto* local_arr = b.Var<function>("a", b.Zero(arr));

            auto* elem_count = arr->Count()->As<core::type::ConstantArrayCount>();
            TINT_ASSERT(elem_count);

            b.LoopRange(ty, 0_u, u32(elem_count->value), 1_u, [&](core::ir::Value* idx) {
                auto* elem_ptr = b.Access(ty.ptr<function>(arr->ElemType()), local_arr, idx);
                auto* elem_stride = b.Multiply<u32>(idx, u32(arr->Stride()));
                auto* elem_offset = b.Add<u32>(offset_param, elem_stride);
                auto* elem_load = MakeLoad(inst, var, arr->ElemType(), elem_offset->Result(0));
                b.Store(elem_ptr, elem_load);
            });

            b.Return(load_fn, b.Load(local_arr));
        });
        return load_fn;
    };
    return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, arr}, build_fn);
}
/// Replaces an `access` chain rooted at `var` with loads at the equivalent byte offset.
///
/// The (constant) indices of `a` are folded into `byte_offset`, then every user of the access
/// result is visited: nested `access` instructions continue the chain recursively, `load` and
/// `load_vector_element` users are replaced by buffer loads. `let`, `store` and
/// `store_vector_element` users are not handled yet (see TODOs). `a` is destroyed at the end.
/// @param a the access instruction to replace
/// @param var the storage variable at the root of the chain
/// @param obj the type being indexed by `a`; memory views are unwrapped to their store type
/// @param byte_offset the byte offset accumulated by the enclosing access chain, `0` for the
/// outermost access
void Access(core::ir::Access* a,
            core::ir::Var* var,
            const core::type::Type* obj,
            uint32_t byte_offset) {
    // Note, because we recurse through the `access` helper, the object passed in isn't
    // necessarily the originating `var` object, but maybe a partially resolved access chain
    // object.
    if (auto* view = obj->As<core::type::MemoryView>()) {
        obj = view->StoreType();
    }

    // Walk the indices, accumulating the byte offset and narrowing `obj` to the indexed type.
    for (auto* idx_value : a->Indices()) {
        auto* cnst = idx_value->As<core::ir::Constant>();
        // TODO(dsinclair): Handle non-constant accessors where the indices are dynamic
        TINT_ASSERT(cnst);
        uint32_t idx = cnst->Value()->ValueAs<uint32_t>();

        tint::Switch(
            obj,  //
            [&](const core::type::Vector* v) {
                byte_offset += v->type()->Size() * idx;
                obj = v->type();
            },
            [&](const core::type::Matrix* m) {
                // Step by the column stride, exactly as the matrix load helper does. The stride
                // is not always `rows * element size` (columns may be padded), so computing it
                // by hand yields the wrong offset for such matrices.
                byte_offset += m->ColumnStride() * idx;
                obj = m->ColumnType();
            },
            [&](const core::type::Array* ary) {
                byte_offset += ary->Stride() * idx;
                obj = ary->ElemType();
            },
            [&](const core::type::Struct* s) {
                auto* mem = s->Members()[idx];
                byte_offset += mem->Offset();
                obj = mem->Type();
            },
            TINT_ICE_ON_NO_MATCH);
    }

    // Replaces `inst` with a load of its result type from `var` at the constant `offset`.
    auto insert_load = [&](core::ir::Instruction* inst, uint32_t offset) {
        b.InsertBefore(inst, [&] {
            auto* call = MakeLoad(inst, var, inst->Result(0)->Type(), b.Value(u32(offset)));
            inst->Result(0)->ReplaceAllUsesWith(call->Result(0));
        });
        inst->Destroy();
    };

    // Copy the usages into a vector so we can remove items from the hashset.
    auto usages = a->Result(0)->Usages().Vector();
    for (auto& usage : usages) {
        tint::Switch(
            usage.instruction,  //
            [&](core::ir::Let*) {
                // TODO(dsinclair): handle let
            },
            [&](core::ir::Access* sub_access) {
                // Treat an access chain of the access chain as a continuation of the outer
                // chain. Pass through the object we stopped at and the current byte_offset and
                // then restart the access chain replacement for the new access chain.
                Access(sub_access, var, obj, byte_offset);
            },
            [&](core::ir::LoadVectorElement* lve) {
                a->Result(0)->RemoveUsage(usage);
                // The element offset belongs to this usage only. `byte_offset` must stay
                // untouched so the remaining usages of this access still see the offset of the
                // access itself.
                insert_load(lve, byte_offset + CalculateVectorIndex(lve->Index(), obj));
            },  //
            [&](core::ir::Load* ld) {
                a->Result(0)->RemoveUsage(usage);
                insert_load(ld, byte_offset);
            },
            [&](core::ir::StoreVectorElement*) {
                // TODO(dsinclair): Handle store vector elements
            },  //
            [&](core::ir::Store*) {
                // TODO(dsinclair): Handle store
            },  //
            TINT_ICE_ON_NO_MATCH);
    }

    // NOTE(review): users covered only by the TODO branches above still reference this access
    // when it is destroyed here - confirm once those branches are implemented.
    a->Destroy();
}
// Placeholder for rewriting a `store` instruction that uses the storage `var` directly.
// Currently a no-op: the instruction is left unchanged.
void Store(core::ir::Store*) {
    // TODO(dsinclair): Handle store
}
/// Replaces a `load` of the whole storage variable with a buffer load at byte offset 0.
///
/// This should _only_ be handling a `var` parameter as any `access` parameters would have been
/// replaced by the `access` being converted.
/// @param ld the load instruction to replace; it is destroyed
/// @param var the storage variable to load from
void Load(core::ir::Load* ld, core::ir::Var* var) {
    // The load source must be the result of a `var` instruction.
    auto* source = ld->From()->As<core::ir::InstructionResult>();
    TINT_ASSERT(source);
    auto* source_var = source->Instruction()->As<core::ir::Var>();
    TINT_ASSERT(source_var);

    // The loaded type is whatever the variable's pointer type points at.
    const core::type::Type* loaded_ty = source_var->Result(0)->Type()->UnwrapPtr();

    b.InsertBefore(ld, [&] {
        auto* replacement = MakeLoad(ld, var, loaded_ty, b.Value(0_u));
        ld->Result(0)->ReplaceAllUsesWith(replacement->Result(0));
    });
    ld->Destroy();
}
/// Replaces a `load_vector_element` on the storage variable with a scalar buffer load at the
/// element's byte offset. For a non-`f16` element this converts to something like:
///
///   %1:u32 = v.Load <element byte offset>
///   %b:f32 = bitcast %1
/// @param lve the instruction to replace; it is destroyed
/// @param var the storage variable to load from
/// @param var_ty the pointer type of `var`; its store type provides the element size
void LoadVectorElement(core::ir::LoadVectorElement* lve,
                       core::ir::Var* var,
                       const core::type::Pointer* var_ty) {
    const uint32_t element_offset = CalculateVectorIndex(lve->Index(), var_ty->StoreType());

    b.InsertBefore(lve, [&] {
        auto* offset_value = b.Value(u32(element_offset));
        auto* element_load = MakeScalarOrVectorLoad(var, lve->Result(0)->Type(), offset_value);
        lve->Result(0)->ReplaceAllUsesWith(element_load->Result(0));
    });
    lve->Destroy();
}
/// Replaces a `store_vector_element` on the storage variable with a buffer store at the
/// element's byte offset. Converts to:
///
///   %1 = <sve->Value()>
///   %2:u32 = bitcast %1
///   %3:void = v.Store <element byte offset>, %2
/// @param sve the instruction to replace; it is destroyed
/// @param var the storage variable to store to
/// @param var_ty the pointer type of `var`; its store type provides the element size
void StoreVectorElement(core::ir::StoreVectorElement* sve,
                        core::ir::Var* var,
                        const core::type::Pointer* var_ty) {
    const uint32_t element_offset = CalculateVectorIndex(sve->Index(), var_ty->StoreType());

    // Emit the bitcast followed by the store, both placed immediately before `sve`.
    b.InsertBefore(sve, [&] {
        auto* as_u32 = b.Bitcast(ty.u32(), sve->Value());
        b.MemberCall<hlsl::ir::MemberBuiltinCall>(ty.void_(), BuiltinFn::kStore, var,
                                                  u32(element_offset), as_u32);
    });
    sve->Destroy();
}
};
} // namespace
/// Entry point of the transform: validates `ir` and, if validation succeeds, runs the
/// decomposition over the module.
/// @param ir the module to transform
/// @returns the validation failure if validation did not succeed, otherwise success
Result<SuccessType> DecomposeMemoryAccess(core::ir::Module& ir) {
    if (auto validated = ValidateAndDumpIfNeeded(ir, "DecomposeMemoryAccess transform");
        validated != Success) {
        return validated.Failure();
    }

    State state{ir};
    state.Process();
    return Success;
}
} // namespace tint::hlsl::writer::raise