[hlsl] Fixup loading of f16 vectors.
When loading an F16 vector in HLSL we have to make sure not to emit `>>`
which parses incorrect, instead we need to emit `> >`.
When templating the load, the `Load` instruction is always used, not the
`Load{n}` version for the size of the loaded vector.
Bug: 42251045
Change-Id: I3334ba26eb8a66b7e78c7db6fbccff21c091612e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196958
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 015c5f3..997d30b 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -298,13 +298,7 @@
)");
}
-// TODO(dsinclair): Fails DXC validation
-// hlsl.hlsl:4:55: warning: use of right-shift operator ('>>') in template argument will require
-// parentheses in C++11 [-Wc++11-compat]
-// vector<float16_t, 4> a = v.Load4<vector<float16_t, 4>>(0u);
-// ^
-// ( )
-TEST_F(HlslWriterTest, DISABLED_AccessStorageVectorF16) {
+TEST_F(HlslWriterTest, AccessStorageVectorF16) {
auto* var = b.Var<storage, vec4<f16>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -323,7 +317,7 @@
EXPECT_EQ(output_.hlsl, R"(
ByteAddressBuffer v : register(t0);
void foo() {
- vector<float16_t, 4> a = v.Load4<vector<float16_t, 4>>(0u);
+ vector<float16_t, 4> a = v.Load<vector<float16_t, 4> >(0u);
float16_t b = v.Load<float16_t>(0u);
float16_t c = v.Load<float16_t>(2u);
float16_t d = v.Load<float16_t>(4u);
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index f8cdf56..253ef03 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -654,14 +654,17 @@
fn = BuiltinFn::kLoad;
suffix = "<float16_t>";
} else if (fn == BuiltinFn::kLoad2F16) {
- fn = BuiltinFn::kLoad2;
- suffix = "<vector<float16_t, 2>>";
+ fn = BuiltinFn::kLoad;
+ // Note space between '> >' is required for DXC
+ suffix = "<vector<float16_t, 2> >";
} else if (fn == BuiltinFn::kLoad3F16) {
- fn = BuiltinFn::kLoad3;
- suffix = "<vector<float16_t, 3>>";
+ fn = BuiltinFn::kLoad;
+ // Note space between '> >' is required for DXC
+ suffix = "<vector<float16_t, 3> >";
} else if (fn == BuiltinFn::kLoad4F16) {
- fn = BuiltinFn::kLoad4;
- suffix = "<vector<float16_t, 4>>";
+ fn = BuiltinFn::kLoad;
+ // Note space between '> >' is required for DXC
+ suffix = "<vector<float16_t, 4> >";
}
EmitValue(out, c->Object());