[hlsl] Enable uniform vector load
Add support for `load_vector_element` to the uniform decompose
transform.
Bug: 349867642
Change-Id: Ie00cd71d82e27d3446b3f63c79986d3fd0b74f69
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/197074
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 3637485..8587d5d 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -789,17 +789,18 @@
)");
}
-TEST_F(HlslWriterTest, DISABLED_AccessUniformVectorLoad) {
+TEST_F(HlslWriterTest, AccessUniformVectorLoad) {
auto* var = b.Var<uniform, vec4<f32>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
b.ir.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
- b.Let("a", b.LoadVectorElement(var, 0_u));
- b.Let("b", b.LoadVectorElement(var, 1_u));
- b.Let("c", b.LoadVectorElement(var, 2_u));
- b.Let("d", b.LoadVectorElement(var, 3_u));
+ b.Let("a", b.Load(var));
+ b.Let("b", b.LoadVectorElement(var, 0_u));
+ b.Let("c", b.LoadVectorElement(var, 1_u));
+ b.Let("d", b.LoadVectorElement(var, 2_u));
+ b.Let("e", b.LoadVectorElement(var, 3_u));
b.Return(func);
});
@@ -808,13 +809,14 @@
cbuffer cbuffer_v : register(b0) {
uint4 v[1];
};
-
-void m() {
- float a = asfloat(v[0].x);
- float b = asfloat(v[0].y);
- float c = asfloat(v[0].z);
- float d = asfloat(v[0].w);
+void foo() {
+ float4 a = asfloat(v[0u]);
+ float b = asfloat(v[0u][0u]);
+ float c = asfloat(v[0u][1u]);
+ float d = asfloat(v[0u][2u]);
+ float e = asfloat(v[0u][3u]);
}
+
)");
}
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
index 4087e53..607f3f6 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
@@ -86,40 +86,31 @@
for (auto* var : var_worklist) {
auto* result = var->Result(0);
- // Find all the usages of the `var` which is loading.
- Vector<core::ir::Instruction*, 4> usage_worklist;
- for (auto& usage : result->Usages()) {
- Switch(
- usage->instruction,
- [&](core::ir::LoadVectorElement* lve) { usage_worklist.Push(lve); },
- [&](core::ir::Load* ld) { usage_worklist.Push(ld); },
- [&](core::ir::Access* a) { usage_worklist.Push(a); },
- [&](core::ir::Let* l) { usage_worklist.Push(l); }, //
- TINT_ICE_ON_NO_MATCH);
- }
-
+ auto usage_worklist = result->Usages().Vector();
auto* var_ty = result->Type()->As<core::type::Pointer>();
while (!usage_worklist.IsEmpty()) {
- auto* inst = usage_worklist.Pop();
+ auto usage = usage_worklist.Pop();
+ auto* inst = usage.instruction;
+
// Load instructions can be destroyed by the replacing access function
if (!inst->Alive()) {
continue;
}
Switch(
- inst,
- [&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var, var_ty); },
- [&](core::ir::Load* l) { Load(l, var); }, //
- [&](core::ir::Access* a) {
- OffsetData offset;
- Access(a, var, a->Object()->Type(), &offset);
- },
+ inst, //
+ [&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var); },
+ [&](core::ir::Load* l) { Load(l, var); },
+                    // TODO(349867642): restore `Access` handling for uniform access chains:
+                    //   OffsetData offset;
+                    //   Access(a, var, a->Object()->Type(), &offset);
+                    // Access instructions are not yet supported by this transform.
[&](core::ir::Let* let) {
// The `let` is, essentially, an alias for the `var` as it's assigned
// directly. Gather all the `let` usages into our worklist, and then replace
// the `let` with the `var` itself.
- for (auto& usage : let->Result(0)->Usages()) {
- usage_worklist.Push(usage->instruction);
+ for (auto& use : let->Result(0)->Usages()) {
+ usage_worklist.Push(use);
}
let->Result(0)->ReplaceAllUsesWith(result);
let->Destroy();
@@ -134,369 +125,26 @@
}
}
- struct OffsetData {
- uint32_t byte_offset = 0;
- Vector<core::ir::Value*, 4> expr{};
- };
-
- // Note, must be called inside a builder insert block (Append, InsertBefore, etc)
- void CalculateVectorIndex(core::ir::Value* v,
- const core::type::Type* store_ty,
- OffsetData* offset) {
- auto elm_size = store_ty->DeepestElement()->Size();
- tint::Switch(
- v, //
- [&](core::ir::Constant* idx_value) {
- offset->byte_offset += idx_value->Value()->ValueAs<uint32_t>() * elm_size;
- },
- [&](core::ir::Value* val) {
- offset->expr.Push(
- b.Multiply(ty.u32(), b.Convert(ty.u32(), val), u32(elm_size))->Result(0));
- },
- TINT_ICE_ON_NO_MATCH);
- }
-
- // Note, must be called inside a builder insert block (Append, InsertBefore, etc)
- core::ir::Value* OffsetToValue(OffsetData offset) {
- core::ir::Value* val = b.Value(u32(offset.byte_offset));
- for (core::ir::Value* expr : offset.expr) {
- val = b.Add(ty.u32(), val, expr)->Result(0);
- }
- return val;
- }
-
- // Creates the appropriate load instructions for the given result type.
- core::ir::Call* MakeLoad(core::ir::Instruction* inst,
- core::ir::Var* var,
- const core::type::Type* result_ty,
- core::ir::Value* offset) {
- if (result_ty->is_numeric_scalar_or_vector()) {
- return MakeScalarOrVectorLoad(var, result_ty, offset);
- }
-
- return tint::Switch(
- result_ty, //
- [&](const core::type::Struct* s) {
- auto* fn = GetLoadFunctionFor(inst, var, s);
- return b.Call(fn, offset);
- },
- [&](const core::type::Matrix* m) {
- auto* fn = GetLoadFunctionFor(inst, var, m);
- return b.Call(fn, offset);
- }, //
- [&](const core::type::Array* a) {
- auto* fn = GetLoadFunctionFor(inst, var, a);
- return b.Call(fn, offset);
- }, //
- TINT_ICE_ON_NO_MATCH);
- }
-
- // Creates a `v.Load{2,3,4} offset` call based on the provided type. The load returns a
- // `u32` or vector of `u32` and then a `bitcast` is done to get back to the desired type.
- //
- // This only works for `u32`, `i32`, `f16`, `f32` and the vector sizes of those types.
- //
- // The `f16` type is special in that `f16` uses a templated load in HLSL `Load<float16_t>`
- // and returns the correct type, so there is no bitcast.
- core::ir::Call* MakeScalarOrVectorLoad(core::ir::Var* var,
- const core::type::Type* result_ty,
- core::ir::Value* offset) {
- bool is_f16 = result_ty->DeepestElement()->Is<core::type::F16>();
-
- const core::type::Type* load_ty = ty.u32();
- // An `f16` load returns an `f16` instead of a `u32`
- if (is_f16) {
- load_ty = ty.f16();
- }
-
- auto fn = is_f16 ? BuiltinFn::kLoadF16 : BuiltinFn::kLoad;
- if (auto* v = result_ty->As<core::type::Vector>()) {
- load_ty = ty.vec(load_ty, v->Width());
- switch (v->Width()) {
- case 2:
- fn = is_f16 ? BuiltinFn::kLoad2F16 : BuiltinFn::kLoad2;
- break;
- case 3:
- fn = is_f16 ? BuiltinFn::kLoad3F16 : BuiltinFn::kLoad3;
- break;
- case 4:
- fn = is_f16 ? BuiltinFn::kLoad4F16 : BuiltinFn::kLoad4;
- break;
- default:
- TINT_UNREACHABLE();
- }
- }
-
- auto* builtin = b.MemberCall<hlsl::ir::MemberBuiltinCall>(load_ty, fn, var, offset);
- core::ir::Call* res = nullptr;
-
- // Do not bitcast the `f16` conversions as they need to be a templated Load instruction
- if (is_f16) {
- res = builtin;
- } else {
- res = b.Bitcast(result_ty, builtin->Result(0));
- }
- return res;
- }
-
- // Creates a load function for the given `var` and `struct` combination. Essentially creates
- // a function similar to:
- //
- // fn custom_load_S(offset: u32) {
- // let a = <load S member 0>(offset + member 0 offset);
- // let b = <load S member 1>(offset + member 1 offset);
- // ...
- // let z = <load S member last>(offset + member last offset);
- // return S(a, b, ..., z);
- // }
- core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
- core::ir::Var* var,
- const core::type::Struct* s) {
- return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, s}, [&] {
- auto* p = b.FunctionParam("offset", ty.u32());
- auto* fn = b.Function(s);
- fn->SetParams({p});
-
- b.Append(fn->Block(), [&] {
- Vector<core::ir::Value*, 4> values;
- for (const auto* mem : s->Members()) {
- values.Push(MakeLoad(inst, var, mem->Type(),
- b.Add<u32>(p, u32(mem->Offset()))->Result(0))
- ->Result(0));
- }
-
- b.Return(fn, b.Construct(s, values));
- });
-
- return fn;
- });
- }
-
- // Creates a load function for the given `var` and `matrix` combination. Essentially creates
- // a function similar to:
- //
- // fn custom_load_M(offset: u32) {
- // let a = <load M column 1>(offset + (1 * ColumnStride));
- // let b = <load M column 2>(offset + (2 * ColumnStride));
- // ...
- // let z = <load M column last>(offset + (last * ColumnStride));
- // return M(a, b, ... z);
- // }
- core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
- core::ir::Var* var,
- const core::type::Matrix* mat) {
- return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, mat}, [&] {
- auto* p = b.FunctionParam("offset", ty.u32());
- auto* fn = b.Function(mat);
- fn->SetParams({p});
-
- b.Append(fn->Block(), [&] {
- Vector<core::ir::Value*, 4> values;
- for (size_t i = 0; i < mat->columns(); ++i) {
- auto* add = b.Add<u32>(p, u32(i * mat->ColumnStride()));
- auto* load = MakeLoad(inst, var, mat->ColumnType(), add->Result(0));
- values.Push(load->Result(0));
- }
-
- b.Return(fn, b.Construct(mat, values));
- });
-
- return fn;
- });
- }
-
- // Creates a load function for the given `var` and `array` combination. Essentially creates
- // a function similar to:
- //
- // fn custom_load_A(offset: u32) {
- // A a = A();
- // u32 i = 0;
- // loop {
- // if (i >= A length) {
- // break;
- // }
- // a[i] = <load array type>(offset + (i * A->Stride()));
- // i = i + 1;
- // }
- // return a;
- // }
- core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
- core::ir::Var* var,
- const core::type::Array* arr) {
- return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, arr}, [&] {
- auto* p = b.FunctionParam("offset", ty.u32());
- auto* fn = b.Function(arr);
- fn->SetParams({p});
-
- b.Append(fn->Block(), [&] {
- auto* result_arr = b.Var<function>("a", b.Zero(arr));
-
- auto* count = arr->Count()->As<core::type::ConstantArrayCount>();
- TINT_ASSERT(count);
-
- b.LoopRange(ty, 0_u, u32(count->value), 1_u, [&](core::ir::Value* idx) {
- auto* access = b.Access(ty.ptr<function>(arr->ElemType()), result_arr, idx);
- auto* stride = b.Multiply<u32>(idx, u32(arr->Stride()));
- auto* byte_offset = b.Add<u32>(p, stride);
- b.Store(access, MakeLoad(inst, var, arr->ElemType(), byte_offset->Result(0)));
- });
-
- b.Return(fn, b.Load(result_arr));
- });
-
- return fn;
- });
- }
-
- void InsertLoad(core::ir::Var* var, core::ir::Instruction* inst, OffsetData offset) {
- b.InsertBefore(inst, [&] {
- auto* call =
- MakeLoad(inst, var, inst->Result(0)->Type()->UnwrapPtr(), OffsetToValue(offset));
- inst->Result(0)->ReplaceAllUsesWith(call->Result(0));
- });
- inst->Destroy();
- }
-
- void Access(core::ir::Access* a,
- core::ir::Var* var,
- const core::type::Type* obj,
- OffsetData* offset) {
- TINT_ASSERT(offset);
-
- // Note, because we recurse through the `access` helper, the object passed in isn't
- // necessarily the originating `var` object, but maybe a partially resolved access chain
- // object.
- if (auto* view = obj->As<core::type::MemoryView>()) {
- obj = view->StoreType();
- }
-
- const core::type::StructMember* last_member = nullptr;
- for (auto* idx_value : a->Indices()) {
- uint32_t size = 0;
- auto* cur_obj = obj;
- tint::Switch(
- obj, //
- [&](const core::type::Vector* v) {
- size = v->type()->Size();
- obj = v->type();
- },
- [&](const core::type::Matrix* m) {
- size = m->type()->Size() * m->rows();
- obj = m->ColumnType();
- },
- [&](const core::type::Array* ary) {
- size = ary->Stride();
- obj = ary->ElemType();
- },
- [&](const core::type::Struct* s) {
- auto* cnst = idx_value->As<core::ir::Constant>();
-
- // A struct index must be a constant
- TINT_ASSERT(cnst);
-
- uint32_t idx = cnst->Value()->ValueAs<uint32_t>();
- last_member = s->Members()[idx];
- obj = last_member->Type();
- },
- TINT_ICE_ON_NO_MATCH);
-
- tint::Switch(
- idx_value, //
- [&](core::ir::Constant* cnst) {
- uint32_t idx = cnst->Value()->ValueAs<uint32_t>();
- tint::Switch(
- cur_obj, //
- [&](const core::type::Vector*) { offset->byte_offset += size * idx; },
- [&](const core::type::Matrix*) { offset->byte_offset += size * idx; },
- [&](const core::type::Array*) { offset->byte_offset += size * idx; },
- [&](const core::type::Struct* s) {
- auto* mem = s->Members()[idx];
- offset->byte_offset += mem->Offset();
- },
- TINT_ICE_ON_NO_MATCH);
- },
- [&](core::ir::Value* val) {
- b.InsertBefore(a, [&] {
- offset->expr.Push(
- b.Multiply(ty.u32(), u32(size), b.Convert(ty.u32(), val))->Result(0));
- });
- }, //
- TINT_ICE_ON_NO_MATCH);
- }
-
- // Copy the usages into a vector so we can remove items from the hashset.
- auto usages = a->Result(0)->Usages().Vector();
- while (!usages.IsEmpty()) {
- auto usage = usages.Pop();
- tint::Switch(
- usage.instruction,
- [&](core::ir::Let* let) {
- // The `let` is essentially an alias to the `access`. So, add the `let`
- // usages into the usage worklist, and replace the let with the access chain
- // directly.
- for (auto& u : let->Result(0)->Usages()) {
- usages.Push(u);
- }
- let->Result(0)->ReplaceAllUsesWith(a->Result(0));
- let->Destroy();
- },
- [&](core::ir::Access* sub_access) {
- // Treat an access chain of the access chain as a continuation of the outer
- // chain. Pass through the object we stopped at and the current byte_offset
- // and then restart the access chain replacement for the new access chain.
- Access(sub_access, var, obj, offset);
- },
-
- [&](core::ir::LoadVectorElement* lve) {
- a->Result(0)->RemoveUsage(usage);
-
- b.InsertBefore(lve, [&] { CalculateVectorIndex(lve->Index(), obj, offset); });
- InsertLoad(var, lve, *offset);
- },
- [&](core::ir::Load* ld) {
- a->Result(0)->RemoveUsage(usage);
- InsertLoad(var, ld, *offset);
- },
- TINT_ICE_ON_NO_MATCH);
- }
-
- a->Destroy();
- }
-
- // This should _only_ be handling a `var` parameter as any `access` parameters would have
- // been replaced by the `access` being converted.
void Load(core::ir::Load* ld, core::ir::Var* var) {
- auto* result = ld->From()->As<core::ir::InstructionResult>();
- TINT_ASSERT(result);
-
- auto* inst = result->Instruction()->As<core::ir::Var>();
- TINT_ASSERT(inst);
-
- const core::type::Type* result_ty = inst->Result(0)->Type()->UnwrapPtr();
-
b.InsertBefore(ld, [&] {
- auto* call = MakeLoad(ld, var, result_ty, b.Value(0_u));
- ld->Result(0)->ReplaceAllUsesWith(call->Result(0));
+ auto* access = b.Access(ty.ptr(uniform, ty.vec4<u32>()), var, 0_u);
+ auto* load = b.Load(access);
+ auto* bitcast = b.Bitcast(ld->Result(0)->Type(), load);
+ ld->Result(0)->ReplaceAllUsesWith(bitcast->Result(0));
});
ld->Destroy();
}
- // Converts to:
- //
- // %1:u32 = v.Load 0u
- // %b:f32 = bitcast %1
- void LoadVectorElement(core::ir::LoadVectorElement* lve,
- core::ir::Var* var,
- const core::type::Pointer* var_ty) {
+    // A direct vector load on the `var` means the `var` is a vector. Replace the `var` usage
+    // with an access chain into element `0` of the `var` array, a `load_vector_element` to
+    // retrieve the item, and a `bitcast` to the correct type.
+ void LoadVectorElement(core::ir::LoadVectorElement* lve, core::ir::Var* var) {
b.InsertBefore(lve, [&] {
- OffsetData offset{};
- CalculateVectorIndex(lve->Index(), var_ty->StoreType(), &offset);
-
- auto* result =
- MakeScalarOrVectorLoad(var, lve->Result(0)->Type(), OffsetToValue(offset));
- lve->Result(0)->ReplaceAllUsesWith(result->Result(0));
+ auto* access = b.Access(ty.ptr(uniform, ty.vec4<u32>()), var, 0_u);
+ auto* load = b.LoadVectorElement(access, lve->Index());
+ auto* bitcast = b.Bitcast(lve->Result(0)->Type(), load);
+ lve->Result(0)->ReplaceAllUsesWith(bitcast->Result(0));
});
-
lve->Destroy();
}
};
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
index ac7d72d..270c194 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
@@ -228,7 +228,7 @@
EXPECT_EQ(expect, str());
}
-TEST_F(HlslWriterDecomposeUniformAccessTest, DISABLED_UniformAccessVectorLoad) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessVectorLoad) {
auto* var = b.Var<uniform, vec4<f32>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -268,26 +268,31 @@
auto* expect = R"(
$B1: { # root
- %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+ %v:ptr<uniform, array<vec4<u32>, 1>, read> = var @binding_point(0, 0)
}
%foo = @fragment func():void {
$B2: {
- %3:vec4<u32> = %v.Load4 0u
- %4:vec4<f32> = bitcast %3
- %a:vec4<f32> = let %4
- %6:u32 = %v.Load 0u
- %7:f32 = bitcast %6
- %b:f32 = let %7
- %9:u32 = %v.Load 4u
- %10:f32 = bitcast %9
- %c:f32 = let %10
- %12:u32 = %v.Load 8u
+ %3:ptr<uniform, vec4<u32>, read> = access %v, 0u
+ %4:vec4<u32> = load %3
+ %5:vec4<f32> = bitcast %4
+ %a:vec4<f32> = let %5
+ %7:ptr<uniform, vec4<u32>, read> = access %v, 0u
+ %8:u32 = load_vector_element %7, 0u
+ %9:f32 = bitcast %8
+ %b:f32 = let %9
+ %11:ptr<uniform, vec4<u32>, read> = access %v, 0u
+ %12:u32 = load_vector_element %11, 1u
%13:f32 = bitcast %12
- %d:f32 = let %13
- %15:u32 = %v.Load 12u
- %16:f32 = bitcast %15
- %e:f32 = let %16
+ %c:f32 = let %13
+ %15:ptr<uniform, vec4<u32>, read> = access %v, 0u
+ %16:u32 = load_vector_element %15, 2u
+ %17:f32 = bitcast %16
+ %d:f32 = let %17
+ %19:ptr<uniform, vec4<u32>, read> = access %v, 0u
+ %20:u32 = load_vector_element %19, 3u
+ %21:f32 = bitcast %20
+ %e:f32 = let %21
ret
}
}