[hlsl] Add f16 matrix support to uniform loading
This CL adds f16 matrix loading support to the decompose uniform
access transform.
Bug: 349867642
Change-Id: Ia963c49af1e00146b573e31b46a3ef20bf41fc29
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/198455
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
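
For context: this transform turns each uniform buffer into a `uint4` array
(`cbuffer cbuffer_v { uint4 v[N]; }`), so every load becomes word extraction
plus a reinterpretation. f16 values are packed two per 32-bit word, even
element in the low 16 bits, odd element in the high 16 bits. A minimal HLSL
sketch of the unpacking pattern the expectations below rely on (names are
illustrative only):

  cbuffer cbuffer_v : register(b0) {
    uint4 v[1];
  };

  void example() {
    // 32-bit scalars reinterpret the word directly.
    float f = asfloat(v[0u].x);
    // f16tof32 converts the low 16 bits of its operand, so an even
    // (low-half) element needs no shift...
    float16_t lo = float16_t(f16tof32(v[0u].x));
    // ...while an odd (high-half) element is shifted down first.
    float16_t hi = float16_t(f16tof32(v[0u].x >> 16u));
  }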
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 7de4034..f2d4120 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -865,6 +865,52 @@
)");
}
+TEST_F(HlslWriterTest, AccessUniformScalar) {
+ auto* var = b.Var<uniform, f32, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+ uint4 v[1];
+};
+void foo() {
+ float a = asfloat(v[0u].x);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessUniformScalarF16) {
+ auto* var = b.Var<uniform, f16, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+ uint4 v[1];
+};
+void foo() {
+ float16_t a = float16_t(f16tof32(v[0u].x));
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, AccessUniformVector) {
auto* var = b.Var<uniform, vec4<f32>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -896,39 +942,17 @@
)");
}
-TEST_F(HlslWriterTest, AccessUniformStorageScalarF16) {
- auto* var = b.Var<uniform, f16, core::Access::kRead>("v");
- var->SetBindingPoint(0, 0);
-
- b.ir.root_block->Append(var);
- auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
- b.Append(func->Block(), [&] {
- b.Let("a", b.Load(var));
- b.Return(func);
- });
-
- ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(
-cbuffer cbuffer_v : register(b0) {
- uint4 v[1];
-};
-void foo() {
- float16_t a = float16_t(f16tof32(v[0u].x));
-}
-
-)");
-}
-
-TEST_F(HlslWriterTest, AccessUniformStorageVectorF16) {
+TEST_F(HlslWriterTest, AccessUniformVectorF16) {
auto* var = b.Var<uniform, vec4<f16>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
b.ir.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
+ auto* x = b.Var("x", 1_u);
b.Let("a", b.Load(var));
b.Let("b", b.LoadVectorElement(var, 0_u));
- b.Let("c", b.LoadVectorElement(var, 1_u));
+ b.Let("c", b.LoadVectorElement(var, b.Load(x)));
b.Let("d", b.LoadVectorElement(var, 2_u));
b.Let("e", b.LoadVectorElement(var, 3_u));
b.Return(func);
@@ -945,13 +969,19 @@
uint4 shift = (16u).xxxx;
float4 t_low = f16tof32((v & mask));
float4 t_high = f16tof32(((v >> shift) & mask));
- return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+ float16_t v_1 = float16_t(t_low.x);
+ float16_t v_2 = float16_t(t_high.x);
+ float16_t v_3 = float16_t(t_low.y);
+ return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
}
void foo() {
+ uint x = 1u;
vector<float16_t, 4> a = tint_bitcast_to_f16(v[0u]);
float16_t b = float16_t(f16tof32(v[0u].x));
- float16_t c = float16_t(f16tof32((v[0u].x >> 16u)));
+ uint v_4 = (min(x, 3u) * 2u);
+ uint v_5 = v[(v_4 / 16u)][((v_4 % 16u) / 4u)];
+ float16_t c = float16_t(f16tof32((v_5 >> ((((v_4 % 4u) == 0u)) ? (0u) : (16u)))));
float16_t d = float16_t(f16tof32(v[0u].y));
float16_t e = float16_t(f16tof32((v[0u].y >> 16u)));
}
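
The interesting case above is `c`, which indexes the vector with the runtime
value `x`: the transform clamps the index (`min(x, 3u)`), scales it by the
2-byte f16 size, then derives the `uint4` register, the 32-bit word inside
it, and a 0/16-bit shift selecting the correct half. Worked through for
x == 1:

  // byte offset: v_4 = min(1u, 3u) * 2u            = 2u
  // register:    v_4 / 16u                         = 0u  -> v[0]
  // word:        (v_4 % 16u) / 4u                  = 0u  -> v[0].x
  // shift:       ((v_4 % 4u) == 0u) ? 0u : 16u     = 16u -> high half
  // so `c` reads element 1 from the high 16 bits of v[0].x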
@@ -959,50 +989,6 @@
)");
}
-TEST_F(HlslWriterTest, DISABLED_AccessUniformStorageMat2x3F16) {
- auto* var = b.Var<uniform, mat2x3<f16>, core::Access::kRead>("v");
- var->SetBindingPoint(0, 0);
-
- auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
- b.Append(func->Block(), [&] {
- b.Let("a", b.Load(var));
- b.Let("b", b.Load(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u)));
- b.Let("c", b.LoadVectorElement(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u), 2_u));
- b.Return(func);
- });
-
- ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(
-cbuffer cbuffer_v : register(b0) {
- uint4 v[1];
-};
-
-matrix<float16_t, 2, 3> v_load(uint offset) {
- const uint scalar_offset = ((offset + 0u)) / 4;
- uint4 ubo_load_1 = v[scalar_offset / 4];
- uint2 ubo_load = ((scalar_offset & 2) ? ubo_load_1.zw : ubo_load_1.xy);
- vector<float16_t, 2> ubo_load_xz = vector<float16_t, 2>(f16tof32(ubo_load & 0xFFFF));
- float16_t ubo_load_y = f16tof32(ubo_load[0] >> 16);
- const uint scalar_offset_1 = ((offset + 8u)) / 4;
- uint4 ubo_load_3 = v[scalar_offset_1 / 4];
- uint2 ubo_load_2 = ((scalar_offset_1 & 2) ? ubo_load_3.zw : ubo_load_3.xy);
- vector<float16_t, 2> ubo_load_2_xz = vector<float16_t, 2>(f16tof32(ubo_load_2 & 0xFFFF));
- float16_t ubo_load_2_y = f16tof32(ubo_load_2[0] >> 16);
- return matrix<float16_t, 2, 3>(vector<float16_t, 3>(ubo_load_xz[0], ubo_load_y, ubo_load_xz[1]), vector<float16_t, 3>(ubo_load_2_xz[0], ubo_load_2_y, ubo_load_2_xz[1]));
-}
-
-void foo() {
- matrix<float16_t, 2, 3> a = v_load(0u);
- uint2 ubo_load_4 = v[0].zw;
- vector<float16_t, 2> ubo_load_4_xz = vector<float16_t, 2>(f16tof32(ubo_load_4 & 0xFFFF));
- float16_t ubo_load_4_y = f16tof32(ubo_load_4[0] >> 16);
- vector<float16_t, 3> b = vector<float16_t, 3>(ubo_load_4_xz[0], ubo_load_4_y, ubo_load_4_xz[1]);
- float16_t c = float16_t(f16tof32(((v[0].w) & 0xFFFF)));
-}
-
-)");
-}
-
TEST_F(HlslWriterTest, AccessUniformMatrix) {
auto* var = b.Var<uniform, mat4x4<f32>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -1071,6 +1057,49 @@
)");
}
+TEST_F(HlslWriterTest, AccessUniformMat2x3F16) {
+ auto* var = b.Var<uniform, mat2x3<f16>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+ b.ir.root_block->Append(var);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u)));
+ b.Let("c", b.LoadVectorElement(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u), 2_u));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+ uint4 v[1];
+};
+vector<float16_t, 4> tint_bitcast_to_f16(uint4 src) {
+ uint4 v = src;
+ uint4 mask = (65535u).xxxx;
+ uint4 shift = (16u).xxxx;
+ float4 t_low = f16tof32((v & mask));
+ float4 t_high = f16tof32(((v >> shift) & mask));
+ float16_t v_1 = float16_t(t_low.x);
+ float16_t v_2 = float16_t(t_high.x);
+ float16_t v_3 = float16_t(t_low.y);
+ return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
+}
+
+matrix<float16_t, 2, 3> v_4(uint start_byte_offset) {
+ vector<float16_t, 3> v_5 = tint_bitcast_to_f16(v[(start_byte_offset / 16u)]).xyz;
+ return matrix<float16_t, 2, 3>(v_5, tint_bitcast_to_f16(v[((8u + start_byte_offset) / 16u)]).xyz);
+}
+
+void foo() {
+ matrix<float16_t, 2, 3> a = v_4(0u);
+ vector<float16_t, 3> b = tint_bitcast_to_f16(v[0u]).xyz;
+ float16_t c = float16_t(f16tof32(v[0u].w));
+}
+
+)");
+}
TEST_F(HlslWriterTest, AccessUniformMatrix3x2) {
auto* var = b.Var<uniform, mat3x2<f32>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -1143,6 +1172,49 @@
)");
}
+TEST_F(HlslWriterTest, AccessUniformMatrix2x2F16) {
+ auto* var = b.Var<uniform, mat2x2<f16>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<uniform, vec2<f16>, core::Access::kRead>(), var, 1_u)));
+ b.Let("c", b.LoadVectorElement(
+ b.Access(ty.ptr<uniform, vec2<f16>, core::Access::kRead>(), var, 1_u), 1_u));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+ uint4 v[1];
+};
+vector<float16_t, 2> tint_bitcast_to_f16(uint src) {
+ uint v = src;
+ float t_low = f16tof32((v & 65535u));
+ float t_high = f16tof32(((v >> 16u) & 65535u));
+ float16_t v_1 = float16_t(t_low);
+ return vector<float16_t, 2>(v_1, float16_t(t_high));
+}
+
+matrix<float16_t, 2, 2> v_2(uint start_byte_offset) {
+ uint4 v_3 = v[(start_byte_offset / 16u)];
+ vector<float16_t, 2> v_4 = tint_bitcast_to_f16((((((start_byte_offset % 16u) / 4u) == 2u)) ? (v_3.z) : (v_3.x)));
+ uint4 v_5 = v[((4u + start_byte_offset) / 16u)];
+ return matrix<float16_t, 2, 2>(v_4, tint_bitcast_to_f16(((((((4u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_5.z) : (v_5.x))));
+}
+
+void foo() {
+ matrix<float16_t, 2, 2> a = v_2(0u);
+ vector<float16_t, 2> b = tint_bitcast_to_f16(v[0u].x);
+ float16_t c = float16_t(f16tof32((v[0u].y >> 16u)));
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, AccessUniformArray) {
auto* var = b.Var<uniform, array<vec3<f32>, 5>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -1190,6 +1262,65 @@
)");
}
+TEST_F(HlslWriterTest, AccessUniformArrayF16) {
+ auto* var = b.Var<uniform, array<vec3<f16>, 5>, core::Access::kRead>("v");
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<uniform, vec3<f16>, core::Access::kRead>(), var, 3_u)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+ uint4 v[3];
+};
+vector<float16_t, 4> tint_bitcast_to_f16(uint4 src) {
+ uint4 v = src;
+ uint4 mask = (65535u).xxxx;
+ uint4 shift = (16u).xxxx;
+ float4 t_low = f16tof32((v & mask));
+ float4 t_high = f16tof32(((v >> shift) & mask));
+ float16_t v_1 = float16_t(t_low.x);
+ float16_t v_2 = float16_t(t_high.x);
+ float16_t v_3 = float16_t(t_low.y);
+ return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
+}
+
+typedef vector<float16_t, 3> ary_ret[5];
+ary_ret v_4(uint start_byte_offset) {
+ vector<float16_t, 3> a[5] = (vector<float16_t, 3>[5])0;
+ {
+ uint v_5 = 0u;
+ v_5 = 0u;
+ while(true) {
+ uint v_6 = v_5;
+ if ((v_6 >= 5u)) {
+ break;
+ }
+ a[v_6] = tint_bitcast_to_f16(v[((start_byte_offset + (v_6 * 8u)) / 16u)]).xyz;
+ {
+ v_5 = (v_6 + 1u);
+ }
+ continue;
+ }
+ }
+ vector<float16_t, 3> v_7[5] = a;
+ return v_7;
+}
+
+void foo() {
+ vector<float16_t, 3> a[5] = v_4(0u);
+ vector<float16_t, 3> b = tint_bitcast_to_f16(v[1u]).xyz;
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, AccessUniformArrayWhichCanHaveSizesOtherThenFive) {
auto* var = b.Var<uniform, array<vec3<f32>, 42>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
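
The sizing in the array expectation above is worth spelling out: each
`vec3<f16>` element is 6 bytes of data in an 8-byte slot, which is the
`(v_6 * 8u)` stride in the generated loop, so two elements share each
16-byte `uint4` register. In comment form:

  // element i starts at byte:  start_byte_offset + (i * 8u)
  // its register:              (start_byte_offset + (i * 8u)) / 16u
  // i = 0,1 -> v[0];  i = 2,3 -> v[1];  i = 4 -> v[2]
  // five elements span 40 bytes, hence the `uint4 v[3]` declaration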
@@ -1278,6 +1409,48 @@
)");
}
+TEST_F(HlslWriterTest, AccessUniformStructF16) {
+ auto* SB = ty.Struct(mod.symbols.New("SB"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.f16()},
+ });
+
+ auto* var = b.Var("v", uniform, SB, core::Access::kRead);
+ var->SetBindingPoint(0, 0);
+
+ b.ir.root_block->Append(var);
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Let("a", b.Load(var));
+ b.Let("b", b.Load(b.Access(ty.ptr<uniform, f16, core::Access::kRead>(), var, 1_u)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(struct SB {
+ int a;
+ float16_t b;
+};
+
+
+cbuffer cbuffer_v : register(b0) {
+ uint4 v[1];
+};
+SB v_1(uint start_byte_offset) {
+ int v_2 = asint(v[(start_byte_offset / 16u)][((start_byte_offset % 16u) / 4u)]);
+ uint v_3 = v[((4u + start_byte_offset) / 16u)][(((4u + start_byte_offset) % 16u) / 4u)];
+ SB v_4 = {v_2, float16_t(f16tof32((v_3 >> (((((4u + start_byte_offset) % 4u) == 0u)) ? (0u) : (16u)))))};
+ return v_4;
+}
+
+void foo() {
+ SB a = v_1(0u);
+ float16_t b = float16_t(f16tof32(v[0u].y));
+}
+
+)");
+}
+
TEST_F(HlslWriterTest, AccessUniformStructNested) {
auto* Inner =
ty.Struct(mod.symbols.New("Inner"), {
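
The struct case above runs the same scalar-f16 machinery with a field offset
folded in: `b` is an f16 at byte 4, after the `int` field. With
start_byte_offset == 0:

  // register: (4u + 0u) / 16u                       = 0u -> v[0]
  // word:     ((4u + 0u) % 16u) / 4u                = 1u -> v[0].y
  // shift:    (((4u + 0u) % 4u) == 0u) ? 0u : 16u   = 0u -> low half
  // which is exactly the inline form: float16_t(f16tof32(v[0u].y))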
diff --git a/src/tint/lang/hlsl/writer/bitcast_test.cc b/src/tint/lang/hlsl/writer/bitcast_test.cc
index d873ada..ab69c25 100644
--- a/src/tint/lang/hlsl/writer/bitcast_test.cc
+++ b/src/tint/lang/hlsl/writer/bitcast_test.cc
@@ -187,21 +187,24 @@
uint v = src;
float t_low = f16tof32((v & 65535u));
float t_high = f16tof32(((v >> 16u) & 65535u));
- return vector<float16_t, 2>(t_low.x, t_high.x);
+ float16_t v_1 = float16_t(t_low);
+ return vector<float16_t, 2>(v_1, float16_t(t_high));
}
vector<float16_t, 2> tint_bitcast_to_f16_1(float src) {
uint v = asuint(src);
float t_low = f16tof32((v & 65535u));
float t_high = f16tof32(((v >> 16u) & 65535u));
- return vector<float16_t, 2>(t_low.x, t_high.x);
+ float16_t v_2 = float16_t(t_low);
+ return vector<float16_t, 2>(v_2, float16_t(t_high));
}
vector<float16_t, 2> tint_bitcast_to_f16(int src) {
uint v = asuint(src);
float t_low = f16tof32((v & 65535u));
float t_high = f16tof32(((v >> 16u) & 65535u));
- return vector<float16_t, 2>(t_low.x, t_high.x);
+ float16_t v_3 = float16_t(t_low);
+ return vector<float16_t, 2>(v_3, float16_t(t_high));
}
void foo() {
@@ -277,7 +280,10 @@
uint2 shift = (16u).xx;
float2 t_low = f16tof32((v & mask));
float2 t_high = f16tof32(((v >> shift) & mask));
- return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+ float16_t v_1 = float16_t(t_low.x);
+ float16_t v_2 = float16_t(t_high.x);
+ float16_t v_3 = float16_t(t_low.y);
+ return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
}
vector<float16_t, 4> tint_bitcast_to_f16_1(float2 src) {
@@ -286,7 +292,10 @@
uint2 shift = (16u).xx;
float2 t_low = f16tof32((v & mask));
float2 t_high = f16tof32(((v >> shift) & mask));
- return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+ float16_t v_4 = float16_t(t_low.x);
+ float16_t v_5 = float16_t(t_high.x);
+ float16_t v_6 = float16_t(t_low.y);
+ return vector<float16_t, 4>(v_4, v_5, v_6, float16_t(t_high.y));
}
vector<float16_t, 4> tint_bitcast_to_f16(int2 src) {
@@ -295,7 +304,10 @@
uint2 shift = (16u).xx;
float2 t_low = f16tof32((v & mask));
float2 t_high = f16tof32(((v >> shift) & mask));
- return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+ float16_t v_7 = float16_t(t_low.x);
+ float16_t v_8 = float16_t(t_high.x);
+ float16_t v_9 = float16_t(t_low.y);
+ return vector<float16_t, 4>(v_7, v_8, v_9, float16_t(t_high.y));
}
void foo() {
diff --git a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
index 29bc9dd..93e09aa 100644
--- a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
@@ -526,9 +526,11 @@
},
TINT_ICE_ON_NO_MATCH);
+ bool src_vec = src_type->Is<core::type::Vector>();
+
core::ir::Value* mask = nullptr;
core::ir::Value* shift = nullptr;
- if (src_type->Is<core::type::Vector>()) {
+ if (src_vec) {
mask = b.Let("mask", b.Splat(uint_ty, 0xffff_u))->Result(0);
shift = b.Let("shift", b.Splat(uint_ty, 16_u))->Result(0);
} else {
@@ -544,13 +546,26 @@
auto* t_high = b.Let(
"t_high", b.Call<hlsl::ir::BuiltinCall>(float_ty, BuiltinFn::kF16Tof32, h));
- auto* x = b.Swizzle(ty.f16(), t_low, {0_u});
- auto* y = b.Swizzle(ty.f16(), t_high, {0_u});
- if (dst_type->As<core::type::Vector>()->Width() == 2) {
+ core::ir::Instruction* x = nullptr;
+ core::ir::Instruction* y = nullptr;
+ if (src_vec) {
+ x = b.Swizzle(ty.f32(), t_low, {0_u});
+ y = b.Swizzle(ty.f32(), t_high, {0_u});
+ } else {
+ x = t_low;
+ y = t_high;
+ }
+ x = b.Convert(ty.f16(), x);
+ y = b.Convert(ty.f16(), y);
+
+ auto dst_width = dst_type->As<core::type::Vector>()->Width();
+ TINT_ASSERT(dst_width == 2 || dst_width == 4);
+
+ if (dst_width == 2) {
b.Return(f, b.Construct(dst_type, x, y));
} else {
- auto* z = b.Swizzle(ty.f16(), t_low, {1_u});
- auto* w = b.Swizzle(ty.f16(), t_high, {1_u});
+ auto* z = b.Convert(ty.f16(), b.Swizzle(ty.f32(), t_low, {1_u}));
+ auto* w = b.Convert(ty.f16(), b.Swizzle(ty.f32(), t_high, {1_u}));
b.Return(f, b.Construct(dst_type, x, y, z, w));
}
});
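
The fix here is about types: `f16tof32` returns `f32`, so `t_low` and
`t_high` are f32 values, and the old code swizzled them with an f16 result
type. The transform now swizzles at f32 and inserts an explicit `convert`
to f16 per component, which is why the emitted HLSL (see the bitcast tests
above) replaces the one-line constructor with converted temporaries:

  // before: return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
  float16_t v_1 = float16_t(t_low.x);
  float16_t v_2 = float16_t(t_high.x);
  float16_t v_3 = float16_t(t_low.y);
  return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));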
diff --git a/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc b/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc
index 3867b6f..65034f5 100644
--- a/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc
@@ -270,8 +270,8 @@
%12:u32 = and %11, 65535u
%13:f32 = hlsl.f16tof32 %12
%t_high:f32 = let %13
- %15:f16 = swizzle %t_low, x
- %16:f16 = swizzle %t_high, x
+ %15:f16 = convert %t_low
+ %16:f16 = convert %t_high
%17:vec2<f16> = construct %15, %16
ret %17
}
@@ -383,12 +383,16 @@
%17:vec2<u32> = and %16, %mask
%18:vec2<f32> = hlsl.f16tof32 %17
%t_high:vec2<f32> = let %18
- %20:f16 = swizzle %t_low, x
- %21:f16 = swizzle %t_high, x
- %22:f16 = swizzle %t_low, y
- %23:f16 = swizzle %t_high, y
- %24:vec4<f16> = construct %20, %21, %22, %23
- ret %24
+ %20:f32 = swizzle %t_low, x
+ %21:f32 = swizzle %t_high, x
+ %22:f16 = convert %20
+ %23:f16 = convert %21
+ %24:f32 = swizzle %t_low, y
+ %25:f16 = convert %24
+ %26:f32 = swizzle %t_high, y
+ %27:f16 = convert %26
+ %28:vec4<f16> = construct %22, %23, %25, %27
+ ret %28
}
}
)";
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
index 732ce18..9fb8ff5 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
@@ -290,10 +290,10 @@
}
// Creates the appropriate load instructions for the given result type.
- core::ir::Call* MakeLoad(core::ir::Instruction* inst,
- core::ir::Var* var,
- const core::type::Type* result_ty,
- core::ir::Value* byte_idx) {
+ core::ir::Instruction* MakeLoad(core::ir::Instruction* inst,
+ core::ir::Var* var,
+ const core::type::Type* result_ty,
+ core::ir::Value* byte_idx) {
if (result_ty->is_float_scalar() || result_ty->is_integer_scalar()) {
return MakeScalarLoad(var, result_ty, byte_idx);
}
@@ -327,28 +327,35 @@
auto* vec_idx = CalculateVectorOffset(byte_idx);
core::ir::Instruction* load = b.LoadVectorElement(access, vec_idx);
if (result_ty->Is<core::type::F16>()) {
- if (auto* cnst = byte_idx->As<core::ir::Constant>()) {
- if (cnst->Value()->ValueAs<uint32_t>() % 4 != 0) {
- load = b.ShiftRight(ty.u32(), load, 16_u);
- }
- } else {
- auto* false_ = b.Value(16_u);
- auto* true_ = b.Value(0_u);
- auto* cond = b.Equal(ty.bool_(), b.Modulo(ty.u32(), byte_idx, 4_u), 0_u);
-
- Vector<core::ir::Value*, 3> args{false_, true_, cond->Result(0)};
- auto* shift_amt = b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(
- b.InstructionResult(ty.u32()), args);
- b.Append(shift_amt);
-
- load = b.ShiftRight(ty.u32(), load, shift_amt);
- }
- load = b.Call<hlsl::ir::BuiltinCall>(ty.f32(), hlsl::BuiltinFn::kF16Tof32, load);
- return b.Convert(result_ty, load);
+ return MakeScalarLoadF16(load, result_ty, byte_idx);
}
return b.Bitcast(result_ty, load);
}
+ core::ir::Call* MakeScalarLoadF16(core::ir::Instruction* load,
+ const core::type::Type* result_ty,
+ core::ir::Value* byte_idx) {
+        // f16 scalars are packed two per 32-bit word: shift the wanted half
+        // into the low 16 bits, widen with f16tof32, then convert down to f16.
+ if (auto* cnst = byte_idx->As<core::ir::Constant>()) {
+ if (cnst->Value()->ValueAs<uint32_t>() % 4 != 0) {
+ load = b.ShiftRight(ty.u32(), load, 16_u);
+ }
+ } else {
+ auto* false_ = b.Value(16_u);
+ auto* true_ = b.Value(0_u);
+ auto* cond = b.Equal(ty.bool_(), b.Modulo(ty.u32(), byte_idx, 4_u), 0_u);
+
+ Vector<core::ir::Value*, 3> args{false_, true_, cond->Result(0)};
+ auto* shift_amt = b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(
+ b.InstructionResult(ty.u32()), args);
+ b.Append(shift_amt);
+
+ load = b.ShiftRight(ty.u32(), load, shift_amt);
+ }
+ load = b.Call<hlsl::ir::BuiltinCall>(ty.f32(), hlsl::BuiltinFn::kF16Tof32, load);
+ return b.Convert(result_ty, load);
+ }
+
// When loading a vector we have to take the alignment into account to determine which part of
// the `uint4` to load. A `vec` of `u32`, `f32` or `i32` has an alignment requirement of
// a multiple of 8-bytes (`f16` is 4-bytes). So, this means we'll have memory like:
@@ -366,12 +373,16 @@
// * A 2-element row, we have to decide if we want the `xy` or `zw` element. We have a minimum
// alignment of 8-bytes as per the WGSL spec. So if the `vector_idx != 2` is `0` then we
// access the `.xy` component, otherwise it is in the `.zw` component.
- core::ir::Call* MakeVectorLoad(core::ir::Var* var,
- const core::type::Vector* result_ty,
- core::ir::Value* byte_idx) {
+ core::ir::Instruction* MakeVectorLoad(core::ir::Var* var,
+ const core::type::Vector* result_ty,
+ core::ir::Value* byte_idx) {
auto* array_idx = OffsetValueToArrayIndex(byte_idx);
auto* access = b.Access(ty.ptr(uniform, ty.vec4<u32>()), var, array_idx);
+ if (result_ty->DeepestElement()->Is<core::type::F16>()) {
+ return MakeVectorLoadF16(access, result_ty, byte_idx);
+ }
+
core::ir::Instruction* load = nullptr;
if (result_ty->Width() == 4) {
load = b.Load(access);
@@ -406,6 +417,52 @@
return b.Bitcast(result_ty, load);
}
+ core::ir::Instruction* MakeVectorLoadF16(core::ir::Access* access,
+ const core::type::Vector* result_ty,
+ core::ir::Value* byte_idx) {
+ core::ir::Instruction* load = nullptr;
+ // Vec4 ends up being the same as a bitcast of vec2<u32> to a vec4<f16>
+ if (result_ty->Width() == 4) {
+ return b.Bitcast(result_ty, b.Load(access));
+ }
+
+        // A vec3 is stored like a vec4, so we can bitcast as if it were a vec4 and
+        // swizzle away the last element
+ if (result_ty->Width() == 3) {
+ auto* bc = b.Bitcast(ty.vec4(result_ty->type()), b.Load(access));
+ return b.Swizzle(result_ty, bc, {0, 1, 2});
+ }
+
+        // Vec2 ends up being the same as a bitcast of a u32 to a vec2<f16>
+ if (result_ty->Width() == 2) {
+ auto* vec_idx = CalculateVectorOffset(byte_idx);
+ if (auto* cnst = vec_idx->As<core::ir::Constant>()) {
+ if (cnst->Value()->ValueAs<uint32_t>() == 2u) {
+ load = b.Swizzle(ty.u32(), b.Load(access), {2});
+ } else {
+ load = b.Swizzle(ty.u32(), b.Load(access), {0});
+ }
+ } else {
+ auto* ubo = b.Load(access);
+                // if vec_idx == 2 -> the z word
+ auto* sw_lhs = b.Swizzle(ty.u32(), ubo, {2});
+                // else -> the x word
+ auto* sw_rhs = b.Swizzle(ty.u32(), ubo, {0});
+ auto* cond = b.Equal(ty.bool_(), vec_idx, 2_u);
+
+ Vector<core::ir::Value*, 3> args{sw_rhs->Result(0), sw_lhs->Result(0),
+ cond->Result(0)};
+
+ load = b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(
+ b.InstructionResult(ty.u32()), args);
+ b.Append(load);
+ }
+ return b.Bitcast(result_ty, load);
+ }
+
+ TINT_UNREACHABLE();
+ }
+
// Creates a load function for the given `var` and `matrix` combination. Essentially creates
// a function similar to:
//
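
The hunk cuts this comment off before its example. For a concrete picture of
such a generated load function, this is what the mat2x3<f16> access test
earlier in this CL expects (copied from that expectation, not from the
elided comment):

  matrix<float16_t, 2, 3> v_4(uint start_byte_offset) {
    vector<float16_t, 3> v_5 = tint_bitcast_to_f16(v[(start_byte_offset / 16u)]).xyz;
    return matrix<float16_t, 2, 3>(v_5, tint_bitcast_to_f16(v[((8u + start_byte_offset) / 16u)]).xyz);
  }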
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
index 997a9d9..86b61e8 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
@@ -438,7 +438,7 @@
EXPECT_EQ(expect, str());
}
-TEST_F(HlslWriterDecomposeUniformAccessTest, DISABLED_UniformAccessMat2x3F16) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessMat2x3F16) {
auto* var = b.Var<uniform, mat2x3<f16>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -446,10 +446,8 @@
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
b.Let("a", b.Load(var));
- b.Let("b", b.LoadVectorElement(var, 0_u));
- b.Let("c", b.LoadVectorElement(var, 1_u));
- b.Let("d", b.LoadVectorElement(var, 2_u));
- b.Let("e", b.LoadVectorElement(var, 3_u));
+ b.Let("b", b.Load(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u)));
+ b.Let("c", b.LoadVectorElement(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u), 2_u));
b.Return(func);
});
@@ -460,16 +458,14 @@
%foo = @fragment func():void {
$B2: {
- %3:vec4<f16> = load %v
- %a:vec4<f16> = let %3
- %5:f16 = load_vector_element %v, 0u
- %b:f16 = let %5
- %7:f16 = load_vector_element %v, 1u
- %c:f16 = let %7
- %9:f16 = load_vector_element %v, 2u
- %d:f16 = let %9
- %11:f16 = load_vector_element %v, 3u
- %e:f16 = let %11
+ %3:mat2x3<f16> = load %v
+ %a:mat2x3<f16> = let %3
+ %5:ptr<uniform, vec3<f16>, read> = access %v, 1u
+ %6:vec3<f16> = load %5
+ %b:vec3<f16> = let %6
+ %8:ptr<uniform, vec3<f16>, read> = access %v, 1u
+ %9:f16 = load_vector_element %8, 2u
+ %c:f16 = let %9
ret
}
}
@@ -478,24 +474,43 @@
auto* expect = R"(
$B1: { # root
- %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+ %v:ptr<uniform, array<vec4<u32>, 1>, read> = var @binding_point(0, 0)
}
%foo = @fragment func():void {
$B2: {
- %3:vec4<f16> = %v.Load4F16 0u
- %a:vec4<f16> = let %3
- %5:f16 = %v.LoadF16 0u
- %b:f16 = let %5
- %7:f16 = %v.LoadF16 2u
- %c:f16 = let %7
- %9:f16 = %v.LoadF16 4u
- %d:f16 = let %9
- %11:f16 = %v.LoadF16 6u
- %e:f16 = let %11
+ %3:mat2x3<f16> = call %4, 0u
+ %a:mat2x3<f16> = let %3
+ %6:ptr<uniform, vec4<u32>, read> = access %v, 0u
+ %7:vec4<u32> = load %6
+ %8:vec4<f16> = bitcast %7
+ %9:vec3<f16> = swizzle %8, xyz
+ %b:vec3<f16> = let %9
+ %11:ptr<uniform, vec4<u32>, read> = access %v, 0u
+ %12:u32 = load_vector_element %11, 3u
+ %13:f32 = hlsl.f16tof32 %12
+ %14:f16 = convert %13
+ %c:f16 = let %14
ret
}
}
+%4 = func(%start_byte_offset:u32):mat2x3<f16> {
+ $B3: {
+ %17:u32 = div %start_byte_offset, 16u
+ %18:ptr<uniform, vec4<u32>, read> = access %v, %17
+ %19:vec4<u32> = load %18
+ %20:vec4<f16> = bitcast %19
+ %21:vec3<f16> = swizzle %20, xyz
+ %22:u32 = add 8u, %start_byte_offset
+ %23:u32 = div %22, 16u
+ %24:ptr<uniform, vec4<u32>, read> = access %v, %23
+ %25:vec4<u32> = load %24
+ %26:vec4<f16> = bitcast %25
+ %27:vec3<f16> = swizzle %26, xyz
+ %28:mat2x3<f16> = construct %21, %27
+ ret %28
+ }
+}
)";
Run(DecomposeUniformAccess);
EXPECT_EQ(expect, str());