[hlsl] Add f16 matrix support to uniform loading

This CL adds the ability to load f16 matrix values to the decompose
uniform access transform.
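
Two f16 values are packed into each u32 of the uint4 array that backs
a uniform buffer. A scalar f16 load at byte offset `off` therefore
lowers to HLSL of roughly this shape (illustrative, simplified from
the generated code in the tests below):

  uint word = v[off / 16u][(off % 16u) / 4u];
  float16_t x = float16_t(f16tof32(word >> (((off % 4u) == 0u) ? 0u : 16u)));

Vector and matrix loads instead bitcast whole u32 words; the bitcast
polyfill now swizzles the `f16tof32` results out as f32 and converts
each lane to f16, rather than swizzling f16 components directly.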

Bug: 349867642
Change-Id: Ia963c49af1e00146b573e31b46a3ef20bf41fc29
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/198455
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 7de4034..f2d4120 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -865,6 +865,52 @@
 )");
 }
 
+TEST_F(HlslWriterTest, AccessUniformScalar) {
+    auto* var = b.Var<uniform, f32, core::Access::kRead>("v");
+    var->SetBindingPoint(0, 0);
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Let("a", b.Load(var));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+  uint4 v[1];
+};
+void foo() {
+  float a = asfloat(v[0u].x);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, AccessUniformScalarF16) {
+    auto* var = b.Var<uniform, f16, core::Access::kRead>("v");
+    var->SetBindingPoint(0, 0);
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Let("a", b.Load(var));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+  uint4 v[1];
+};
+void foo() {
+  float16_t a = float16_t(f16tof32(v[0u].x));
+}
+
+)");
+}
+
 TEST_F(HlslWriterTest, AccessUniformVector) {
     auto* var = b.Var<uniform, vec4<f32>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
@@ -896,39 +942,17 @@
 )");
 }
 
-TEST_F(HlslWriterTest, AccessUniformStorageScalarF16) {
-    auto* var = b.Var<uniform, f16, core::Access::kRead>("v");
-    var->SetBindingPoint(0, 0);
-
-    b.ir.root_block->Append(var);
-    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
-    b.Append(func->Block(), [&] {
-        b.Let("a", b.Load(var));
-        b.Return(func);
-    });
-
-    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
-    EXPECT_EQ(output_.hlsl, R"(
-cbuffer cbuffer_v : register(b0) {
-  uint4 v[1];
-};
-void foo() {
-  float16_t a = float16_t(f16tof32(v[0u].x));
-}
-
-)");
-}
-
-TEST_F(HlslWriterTest, AccessUniformStorageVectorF16) {
+TEST_F(HlslWriterTest, AccessUniformVectorF16) {
     auto* var = b.Var<uniform, vec4<f16>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
     b.ir.root_block->Append(var);
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     b.Append(func->Block(), [&] {
+        auto* x = b.Var("x", 1_u);
         b.Let("a", b.Load(var));
         b.Let("b", b.LoadVectorElement(var, 0_u));
-        b.Let("c", b.LoadVectorElement(var, 1_u));
+        b.Let("c", b.LoadVectorElement(var, b.Load(x)));
         b.Let("d", b.LoadVectorElement(var, 2_u));
         b.Let("e", b.LoadVectorElement(var, 3_u));
         b.Return(func);
@@ -945,13 +969,19 @@
   uint4 shift = (16u).xxxx;
   float4 t_low = f16tof32((v & mask));
   float4 t_high = f16tof32(((v >> shift) & mask));
-  return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+  float16_t v_1 = float16_t(t_low.x);
+  float16_t v_2 = float16_t(t_high.x);
+  float16_t v_3 = float16_t(t_low.y);
+  return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
 }
 
 void foo() {
+  uint x = 1u;
   vector<float16_t, 4> a = tint_bitcast_to_f16(v[0u]);
   float16_t b = float16_t(f16tof32(v[0u].x));
-  float16_t c = float16_t(f16tof32((v[0u].x >> 16u)));
+  uint v_4 = (min(x, 3u) * 2u);
+  uint v_5 = v[(v_4 / 16u)][((v_4 % 16u) / 4u)];
+  float16_t c = float16_t(f16tof32((v_5 >> ((((v_4 % 4u) == 0u)) ? (0u) : (16u)))));
   float16_t d = float16_t(f16tof32(v[0u].y));
   float16_t e = float16_t(f16tof32((v[0u].y >> 16u)));
 }
@@ -959,50 +989,6 @@
 )");
 }
 
-TEST_F(HlslWriterTest, DISABLED_AccessUniformStorageMat2x3F16) {
-    auto* var = b.Var<uniform, mat2x3<f16>, core::Access::kRead>("v");
-    var->SetBindingPoint(0, 0);
-
-    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
-    b.Append(func->Block(), [&] {
-        b.Let("a", b.Load(var));
-        b.Let("b", b.Load(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u)));
-        b.Let("c", b.LoadVectorElement(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u), 2_u));
-        b.Return(func);
-    });
-
-    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
-    EXPECT_EQ(output_.hlsl, R"(
-cbuffer cbuffer_v : register(b0) {
-  uint4 v[1];
-};
-
-matrix<float16_t, 2, 3> v_load(uint offset) {
-  const uint scalar_offset = ((offset + 0u)) / 4;
-  uint4 ubo_load_1 = v[scalar_offset / 4];
-  uint2 ubo_load = ((scalar_offset & 2) ? ubo_load_1.zw : ubo_load_1.xy);
-  vector<float16_t, 2> ubo_load_xz = vector<float16_t, 2>(f16tof32(ubo_load & 0xFFFF));
-  float16_t ubo_load_y = f16tof32(ubo_load[0] >> 16);
-  const uint scalar_offset_1 = ((offset + 8u)) / 4;
-  uint4 ubo_load_3 = v[scalar_offset_1 / 4];
-  uint2 ubo_load_2 = ((scalar_offset_1 & 2) ? ubo_load_3.zw : ubo_load_3.xy);
-  vector<float16_t, 2> ubo_load_2_xz = vector<float16_t, 2>(f16tof32(ubo_load_2 & 0xFFFF));
-  float16_t ubo_load_2_y = f16tof32(ubo_load_2[0] >> 16);
-  return matrix<float16_t, 2, 3>(vector<float16_t, 3>(ubo_load_xz[0], ubo_load_y, ubo_load_xz[1]), vector<float16_t, 3>(ubo_load_2_xz[0], ubo_load_2_y, ubo_load_2_xz[1]));
-}
-
-void foo() {
-  matrix<float16_t, 2, 3> a = v_load(0u);
-  uint2 ubo_load_4 = v[0].zw;
-  vector<float16_t, 2> ubo_load_4_xz = vector<float16_t, 2>(f16tof32(ubo_load_4 & 0xFFFF));
-  float16_t ubo_load_4_y = f16tof32(ubo_load_4[0] >> 16);
-  vector<float16_t, 3> b = vector<float16_t, 3>(ubo_load_4_xz[0], ubo_load_4_y, ubo_load_4_xz[1]);
-  float16_t c = float16_t(f16tof32(((v[0].w) & 0xFFFF)));
-}
-
-)");
-}
-
 TEST_F(HlslWriterTest, AccessUniformMatrix) {
     auto* var = b.Var<uniform, mat4x4<f32>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
@@ -1071,6 +1057,49 @@
 )");
 }
 
+TEST_F(HlslWriterTest, AccessUniformMat2x3F16) {
+    auto* var = b.Var<uniform, mat2x3<f16>, core::Access::kRead>("v");
+    var->SetBindingPoint(0, 0);
+    b.ir.root_block->Append(var);
+
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Let("a", b.Load(var));
+        b.Let("b", b.Load(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u)));
+        b.Let("c", b.LoadVectorElement(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u), 2_u));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+  uint4 v[1];
+};
+vector<float16_t, 4> tint_bitcast_to_f16(uint4 src) {
+  uint4 v = src;
+  uint4 mask = (65535u).xxxx;
+  uint4 shift = (16u).xxxx;
+  float4 t_low = f16tof32((v & mask));
+  float4 t_high = f16tof32(((v >> shift) & mask));
+  float16_t v_1 = float16_t(t_low.x);
+  float16_t v_2 = float16_t(t_high.x);
+  float16_t v_3 = float16_t(t_low.y);
+  return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
+}
+
+matrix<float16_t, 2, 3> v_4(uint start_byte_offset) {
+  vector<float16_t, 3> v_5 = tint_bitcast_to_f16(v[(start_byte_offset / 16u)]).xyz;
+  return matrix<float16_t, 2, 3>(v_5, tint_bitcast_to_f16(v[((8u + start_byte_offset) / 16u)]).xyz);
+}
+
+void foo() {
+  matrix<float16_t, 2, 3> a = v_4(0u);
+  vector<float16_t, 3> b = tint_bitcast_to_f16(v[0u]).xyz;
+  float16_t c = float16_t(f16tof32(v[0u].w));
+}
+
+)");
+}
+
 TEST_F(HlslWriterTest, AccessUniformMatrix3x2) {
     auto* var = b.Var<uniform, mat3x2<f32>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
@@ -1143,6 +1172,49 @@
 )");
 }
 
+TEST_F(HlslWriterTest, AccessUniformMatrix2x2F16) {
+    auto* var = b.Var<uniform, mat2x2<f16>, core::Access::kRead>("v");
+    var->SetBindingPoint(0, 0);
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Let("a", b.Load(var));
+        b.Let("b", b.Load(b.Access(ty.ptr<uniform, vec2<f16>, core::Access::kRead>(), var, 1_u)));
+        b.Let("c", b.LoadVectorElement(
+                       b.Access(ty.ptr<uniform, vec2<f16>, core::Access::kRead>(), var, 1_u), 1_u));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+  uint4 v[1];
+};
+vector<float16_t, 2> tint_bitcast_to_f16(uint src) {
+  uint v = src;
+  float t_low = f16tof32((v & 65535u));
+  float t_high = f16tof32(((v >> 16u) & 65535u));
+  float16_t v_1 = float16_t(t_low);
+  return vector<float16_t, 2>(v_1, float16_t(t_high));
+}
+
+matrix<float16_t, 2, 2> v_2(uint start_byte_offset) {
+  uint4 v_3 = v[(start_byte_offset / 16u)];
+  vector<float16_t, 2> v_4 = tint_bitcast_to_f16((((((start_byte_offset % 16u) / 4u) == 2u)) ? (v_3.z) : (v_3.x)));
+  uint4 v_5 = v[((4u + start_byte_offset) / 16u)];
+  return matrix<float16_t, 2, 2>(v_4, tint_bitcast_to_f16(((((((4u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_5.z) : (v_5.x))));
+}
+
+void foo() {
+  matrix<float16_t, 2, 2> a = v_2(0u);
+  vector<float16_t, 2> b = tint_bitcast_to_f16(v[0u].x);
+  float16_t c = float16_t(f16tof32((v[0u].y >> 16u)));
+}
+
+)");
+}
+
 TEST_F(HlslWriterTest, AccessUniformArray) {
     auto* var = b.Var<uniform, array<vec3<f32>, 5>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
@@ -1190,6 +1262,65 @@
 )");
 }
 
+TEST_F(HlslWriterTest, AccessUniformArrayF16) {
+    auto* var = b.Var<uniform, array<vec3<f16>, 5>, core::Access::kRead>("v");
+    var->SetBindingPoint(0, 0);
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Let("a", b.Load(var));
+        b.Let("b", b.Load(b.Access(ty.ptr<uniform, vec3<f16>, core::Access::kRead>(), var, 3_u)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_v : register(b0) {
+  uint4 v[3];
+};
+vector<float16_t, 4> tint_bitcast_to_f16(uint4 src) {
+  uint4 v = src;
+  uint4 mask = (65535u).xxxx;
+  uint4 shift = (16u).xxxx;
+  float4 t_low = f16tof32((v & mask));
+  float4 t_high = f16tof32(((v >> shift) & mask));
+  float16_t v_1 = float16_t(t_low.x);
+  float16_t v_2 = float16_t(t_high.x);
+  float16_t v_3 = float16_t(t_low.y);
+  return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
+}
+
+typedef vector<float16_t, 3> ary_ret[5];
+ary_ret v_4(uint start_byte_offset) {
+  vector<float16_t, 3> a[5] = (vector<float16_t, 3>[5])0;
+  {
+    uint v_5 = 0u;
+    v_5 = 0u;
+    while(true) {
+      uint v_6 = v_5;
+      if ((v_6 >= 5u)) {
+        break;
+      }
+      a[v_6] = tint_bitcast_to_f16(v[((start_byte_offset + (v_6 * 8u)) / 16u)]).xyz;
+      {
+        v_5 = (v_6 + 1u);
+      }
+      continue;
+    }
+  }
+  vector<float16_t, 3> v_7[5] = a;
+  return v_7;
+}
+
+void foo() {
+  vector<float16_t, 3> a[5] = v_4(0u);
+  vector<float16_t, 3> b = tint_bitcast_to_f16(v[1u]).xyz;
+}
+
+)");
+}
+
 TEST_F(HlslWriterTest, AccessUniformArrayWhichCanHaveSizesOtherThenFive) {
     auto* var = b.Var<uniform, array<vec3<f32>, 42>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
@@ -1278,6 +1409,48 @@
 )");
 }
 
+TEST_F(HlslWriterTest, AccessUniformStructF16) {
+    auto* SB = ty.Struct(mod.symbols.New("SB"), {
+                                                    {mod.symbols.New("a"), ty.i32()},
+                                                    {mod.symbols.New("b"), ty.f16()},
+                                                });
+
+    auto* var = b.Var("v", uniform, SB, core::Access::kRead);
+    var->SetBindingPoint(0, 0);
+
+    b.ir.root_block->Append(var);
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Let("a", b.Load(var));
+        b.Let("b", b.Load(b.Access(ty.ptr<uniform, f16, core::Access::kRead>(), var, 1_u)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(struct SB {
+  int a;
+  float16_t b;
+};
+
+
+cbuffer cbuffer_v : register(b0) {
+  uint4 v[1];
+};
+SB v_1(uint start_byte_offset) {
+  int v_2 = asint(v[(start_byte_offset / 16u)][((start_byte_offset % 16u) / 4u)]);
+  uint v_3 = v[((4u + start_byte_offset) / 16u)][(((4u + start_byte_offset) % 16u) / 4u)];
+  SB v_4 = {v_2, float16_t(f16tof32((v_3 >> (((((4u + start_byte_offset) % 4u) == 0u)) ? (0u) : (16u)))))};
+  return v_4;
+}
+
+void foo() {
+  SB a = v_1(0u);
+  float16_t b = float16_t(f16tof32(v[0u].y));
+}
+
+)");
+}
+
 TEST_F(HlslWriterTest, AccessUniformStructNested) {
     auto* Inner =
         ty.Struct(mod.symbols.New("Inner"), {
diff --git a/src/tint/lang/hlsl/writer/bitcast_test.cc b/src/tint/lang/hlsl/writer/bitcast_test.cc
index d873ada..ab69c25 100644
--- a/src/tint/lang/hlsl/writer/bitcast_test.cc
+++ b/src/tint/lang/hlsl/writer/bitcast_test.cc
@@ -187,21 +187,24 @@
   uint v = src;
   float t_low = f16tof32((v & 65535u));
   float t_high = f16tof32(((v >> 16u) & 65535u));
-  return vector<float16_t, 2>(t_low.x, t_high.x);
+  float16_t v_1 = float16_t(t_low);
+  return vector<float16_t, 2>(v_1, float16_t(t_high));
 }
 
 vector<float16_t, 2> tint_bitcast_to_f16_1(float src) {
   uint v = asuint(src);
   float t_low = f16tof32((v & 65535u));
   float t_high = f16tof32(((v >> 16u) & 65535u));
-  return vector<float16_t, 2>(t_low.x, t_high.x);
+  float16_t v_2 = float16_t(t_low);
+  return vector<float16_t, 2>(v_2, float16_t(t_high));
 }
 
 vector<float16_t, 2> tint_bitcast_to_f16(int src) {
   uint v = asuint(src);
   float t_low = f16tof32((v & 65535u));
   float t_high = f16tof32(((v >> 16u) & 65535u));
-  return vector<float16_t, 2>(t_low.x, t_high.x);
+  float16_t v_3 = float16_t(t_low);
+  return vector<float16_t, 2>(v_3, float16_t(t_high));
 }
 
 void foo() {
@@ -277,7 +280,10 @@
   uint2 shift = (16u).xx;
   float2 t_low = f16tof32((v & mask));
   float2 t_high = f16tof32(((v >> shift) & mask));
-  return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+  float16_t v_1 = float16_t(t_low.x);
+  float16_t v_2 = float16_t(t_high.x);
+  float16_t v_3 = float16_t(t_low.y);
+  return vector<float16_t, 4>(v_1, v_2, v_3, float16_t(t_high.y));
 }
 
 vector<float16_t, 4> tint_bitcast_to_f16_1(float2 src) {
@@ -286,7 +292,10 @@
   uint2 shift = (16u).xx;
   float2 t_low = f16tof32((v & mask));
   float2 t_high = f16tof32(((v >> shift) & mask));
-  return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+  float16_t v_4 = float16_t(t_low.x);
+  float16_t v_5 = float16_t(t_high.x);
+  float16_t v_6 = float16_t(t_low.y);
+  return vector<float16_t, 4>(v_4, v_5, v_6, float16_t(t_high.y));
 }
 
 vector<float16_t, 4> tint_bitcast_to_f16(int2 src) {
@@ -295,7 +304,10 @@
   uint2 shift = (16u).xx;
   float2 t_low = f16tof32((v & mask));
   float2 t_high = f16tof32(((v >> shift) & mask));
-  return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+  float16_t v_7 = float16_t(t_low.x);
+  float16_t v_8 = float16_t(t_high.x);
+  float16_t v_9 = float16_t(t_low.y);
+  return vector<float16_t, 4>(v_7, v_8, v_9, float16_t(t_high.y));
 }
 
 void foo() {
diff --git a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
index 29bc9dd..93e09aa 100644
--- a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
@@ -526,9 +526,11 @@
                         },
                         TINT_ICE_ON_NO_MATCH);
 
+                    bool src_vec = src_type->Is<core::type::Vector>();
+
                     core::ir::Value* mask = nullptr;
                     core::ir::Value* shift = nullptr;
-                    if (src_type->Is<core::type::Vector>()) {
+                    if (src_vec) {
                         mask = b.Let("mask", b.Splat(uint_ty, 0xffff_u))->Result(0);
                         shift = b.Let("shift", b.Splat(uint_ty, 16_u))->Result(0);
                     } else {
@@ -544,13 +546,26 @@
                     auto* t_high = b.Let(
                         "t_high", b.Call<hlsl::ir::BuiltinCall>(float_ty, BuiltinFn::kF16Tof32, h));
 
-                    auto* x = b.Swizzle(ty.f16(), t_low, {0_u});
-                    auto* y = b.Swizzle(ty.f16(), t_high, {0_u});
-                    if (dst_type->As<core::type::Vector>()->Width() == 2) {
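+                    // `f16tof32` yields f32 results, so swizzle each lane out
+                    // as f32 (scalar sources are already f32), then convert
+                    // the lanes to f16.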
+                    core::ir::Instruction* x = nullptr;
+                    core::ir::Instruction* y = nullptr;
+                    if (src_vec) {
+                        x = b.Swizzle(ty.f32(), t_low, {0_u});
+                        y = b.Swizzle(ty.f32(), t_high, {0_u});
+                    } else {
+                        x = t_low;
+                        y = t_high;
+                    }
+                    x = b.Convert(ty.f16(), x);
+                    y = b.Convert(ty.f16(), y);
+
+                    auto dst_width = dst_type->As<core::type::Vector>()->Width();
+                    TINT_ASSERT(dst_width == 2 || dst_width == 4);
+
+                    if (dst_width == 2) {
                         b.Return(f, b.Construct(dst_type, x, y));
                     } else {
-                        auto* z = b.Swizzle(ty.f16(), t_low, {1_u});
-                        auto* w = b.Swizzle(ty.f16(), t_high, {1_u});
+                        auto* z = b.Convert(ty.f16(), b.Swizzle(ty.f32(), t_low, {1_u}));
+                        auto* w = b.Convert(ty.f16(), b.Swizzle(ty.f32(), t_high, {1_u}));
                         b.Return(f, b.Construct(dst_type, x, y, z, w));
                     }
                 });
diff --git a/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc b/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc
index 3867b6f..65034f5 100644
--- a/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/builtin_polyfill_test.cc
@@ -270,8 +270,8 @@
     %12:u32 = and %11, 65535u
     %13:f32 = hlsl.f16tof32 %12
     %t_high:f32 = let %13
-    %15:f16 = swizzle %t_low, x
-    %16:f16 = swizzle %t_high, x
+    %15:f16 = convert %t_low
+    %16:f16 = convert %t_high
     %17:vec2<f16> = construct %15, %16
     ret %17
   }
@@ -383,12 +383,16 @@
     %17:vec2<u32> = and %16, %mask
     %18:vec2<f32> = hlsl.f16tof32 %17
     %t_high:vec2<f32> = let %18
-    %20:f16 = swizzle %t_low, x
-    %21:f16 = swizzle %t_high, x
-    %22:f16 = swizzle %t_low, y
-    %23:f16 = swizzle %t_high, y
-    %24:vec4<f16> = construct %20, %21, %22, %23
-    ret %24
+    %20:f32 = swizzle %t_low, x
+    %21:f32 = swizzle %t_high, x
+    %22:f16 = convert %20
+    %23:f16 = convert %21
+    %24:f32 = swizzle %t_low, y
+    %25:f16 = convert %24
+    %26:f32 = swizzle %t_high, y
+    %27:f16 = convert %26
+    %28:vec4<f16> = construct %22, %23, %25, %27
+    ret %28
   }
 }
 )";
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
index 732ce18..9fb8ff5 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
@@ -290,10 +290,10 @@
     }
 
     // Creates the appropriate load instructions for the given result type.
-    core::ir::Call* MakeLoad(core::ir::Instruction* inst,
-                             core::ir::Var* var,
-                             const core::type::Type* result_ty,
-                             core::ir::Value* byte_idx) {
+    core::ir::Instruction* MakeLoad(core::ir::Instruction* inst,
+                                    core::ir::Var* var,
+                                    const core::type::Type* result_ty,
+                                    core::ir::Value* byte_idx) {
         if (result_ty->is_float_scalar() || result_ty->is_integer_scalar()) {
             return MakeScalarLoad(var, result_ty, byte_idx);
         }
@@ -327,28 +327,35 @@
         auto* vec_idx = CalculateVectorOffset(byte_idx);
         core::ir::Instruction* load = b.LoadVectorElement(access, vec_idx);
         if (result_ty->Is<core::type::F16>()) {
-            if (auto* cnst = byte_idx->As<core::ir::Constant>()) {
-                if (cnst->Value()->ValueAs<uint32_t>() % 4 != 0) {
-                    load = b.ShiftRight(ty.u32(), load, 16_u);
-                }
-            } else {
-                auto* false_ = b.Value(16_u);
-                auto* true_ = b.Value(0_u);
-                auto* cond = b.Equal(ty.bool_(), b.Modulo(ty.u32(), byte_idx, 4_u), 0_u);
-
-                Vector<core::ir::Value*, 3> args{false_, true_, cond->Result(0)};
-                auto* shift_amt = b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(
-                    b.InstructionResult(ty.u32()), args);
-                b.Append(shift_amt);
-
-                load = b.ShiftRight(ty.u32(), load, shift_amt);
-            }
-            load = b.Call<hlsl::ir::BuiltinCall>(ty.f32(), hlsl::BuiltinFn::kF16Tof32, load);
-            return b.Convert(result_ty, load);
+            return MakeScalarLoadF16(load, result_ty, byte_idx);
         }
         return b.Bitcast(result_ty, load);
     }
 
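+    // Loads a single f16 value from the backing uint4 array. Two f16 values
+    // are packed into each u32; `byte_idx` determines which half holds it.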
+    core::ir::Call* MakeScalarLoadF16(core::ir::Instruction* load,
+                                      const core::type::Type* result_ty,
+                                      core::ir::Value* byte_idx) {
+        // The f16 value sits in the low 16 bits of the loaded u32 when the
+        // byte offset is 4-byte aligned, otherwise in the high 16 bits and
+        // must be shifted down before `f16tof32`.
+        if (auto* cnst = byte_idx->As<core::ir::Constant>()) {
+            if (cnst->Value()->ValueAs<uint32_t>() % 4 != 0) {
+                load = b.ShiftRight(ty.u32(), load, 16_u);
+            }
+        } else {
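+            // Dynamic offset: select the shift amount at runtime with a
+            // ternary, `(byte_idx % 4u == 0u) ? 0u : 16u`.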
+            auto* false_ = b.Value(16_u);
+            auto* true_ = b.Value(0_u);
+            auto* cond = b.Equal(ty.bool_(), b.Modulo(ty.u32(), byte_idx, 4_u), 0_u);
+
+            Vector<core::ir::Value*, 3> args{false_, true_, cond->Result(0)};
+            auto* shift_amt = b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(
+                b.InstructionResult(ty.u32()), args);
+            b.Append(shift_amt);
+
+            load = b.ShiftRight(ty.u32(), load, shift_amt);
+        }
+        load = b.Call<hlsl::ir::BuiltinCall>(ty.f32(), hlsl::BuiltinFn::kF16Tof32, load);
+        return b.Convert(result_ty, load);
+    }
+
     // When loading a vector we have to take the alignment into account to determine which part of
     // the `uint4` to load. A `vec`  of `u32`, `f32` or `i32` has an alignment requirement of
     // a multiple of 8-bytes (`f16` is 4-bytes). So, this means we'll have memory like:
@@ -366,12 +373,16 @@
     // * A 2-element row, we have to decide if we want the `xy` or `zw` element. We have a minimum
     //   alignment of 8-bytes as per the WGSL spec. So if the `vector_idx != 2` is `0` then we
     //   access the `.xy` component, otherwise it is in the `.zw` component.
-    core::ir::Call* MakeVectorLoad(core::ir::Var* var,
-                                   const core::type::Vector* result_ty,
-                                   core::ir::Value* byte_idx) {
+    core::ir::Instruction* MakeVectorLoad(core::ir::Var* var,
+                                          const core::type::Vector* result_ty,
+                                          core::ir::Value* byte_idx) {
         auto* array_idx = OffsetValueToArrayIndex(byte_idx);
         auto* access = b.Access(ty.ptr(uniform, ty.vec4<u32>()), var, array_idx);
 
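+        // f16 vectors pack two components into each u32 of the backing uint4
+        // array, so they need a dedicated unpacking path.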
+        if (result_ty->DeepestElement()->Is<core::type::F16>()) {
+            return MakeVectorLoadF16(access, result_ty, byte_idx);
+        }
+
         core::ir::Instruction* load = nullptr;
         if (result_ty->Width() == 4) {
             load = b.Load(access);
@@ -406,6 +417,52 @@
         return b.Bitcast(result_ty, load);
     }
 
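+    // Loads a vec2, vec3 or vec4 of f16 from the backing uint4 array by
+    // bitcasting whole u32 words; the bitcast polyfill expands these into
+    // `tint_bitcast_to_f16` helpers that unpack via `f16tof32`.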
+    core::ir::Instruction* MakeVectorLoadF16(core::ir::Access* access,
+                                             const core::type::Vector* result_ty,
+                                             core::ir::Value* byte_idx) {
+        core::ir::Instruction* load = nullptr;
+        // A vec4<f16> occupies two u32s, so it is equivalent to bitcasting a
+        // vec2<u32> to a vec4<f16>.
+        if (result_ty->Width() == 4) {
+            return b.Bitcast(result_ty, b.Load(access));
+        }
+
+        // A vec3 is stored like a vec4, so we can bitcast as if it were a
+        // vec4 and swizzle off the last element.
+        if (result_ty->Width() == 3) {
+            auto* bc = b.Bitcast(ty.vec4(result_ty->type()), b.Load(access));
+            return b.Swizzle(result_ty, bc, {0, 1, 2});
+        }
+
+        // A vec2<f16> fits in a single u32, so it is equivalent to bitcasting
+        // a u32 to a vec2<f16>.
+        if (result_ty->Width() == 2) {
+            auto* vec_idx = CalculateVectorOffset(byte_idx);
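+            // `vec_idx` selects the u32 component of the uint4 holding the
+            // f16 pair: 0 (`.x`) or 2 (`.z`) given the 8-byte row alignment.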
+            if (auto* cnst = vec_idx->As<core::ir::Constant>()) {
+                if (cnst->Value()->ValueAs<uint32_t>() == 2u) {
+                    load = b.Swizzle(ty.u32(), b.Load(access), {2});
+                } else {
+                    load = b.Swizzle(ty.u32(), b.Load(access), {0});
+                }
+            } else {
+                auto* ubo = b.Load(access);
+                // if vec_idx == 2 -> .z
+                auto* sw_lhs = b.Swizzle(ty.u32(), ubo, {2});
+                // else -> .x
+                auto* sw_rhs = b.Swizzle(ty.u32(), ubo, {0});
+                auto* cond = b.Equal(ty.bool_(), vec_idx, 2_u);
+
+                Vector<core::ir::Value*, 3> args{sw_rhs->Result(0), sw_lhs->Result(0),
+                                                 cond->Result(0)};
+
+                load = b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(
+                    b.InstructionResult(ty.u32()), args);
+                b.Append(load);
+            }
+            return b.Bitcast(result_ty, load);
+        }
+
+        TINT_UNREACHABLE();
+    }
+
     // Creates a load function for the given `var` and `matrix` combination. Essentially creates
     // a function similar to:
     //
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
index 997a9d9..86b61e8 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
@@ -438,7 +438,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_F(HlslWriterDecomposeUniformAccessTest, DISABLED_UniformAccessMat2x3F16) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessMat2x3F16) {
     auto* var = b.Var<uniform, mat2x3<f16>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
@@ -446,10 +446,8 @@
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     b.Append(func->Block(), [&] {
         b.Let("a", b.Load(var));
-        b.Let("b", b.LoadVectorElement(var, 0_u));
-        b.Let("c", b.LoadVectorElement(var, 1_u));
-        b.Let("d", b.LoadVectorElement(var, 2_u));
-        b.Let("e", b.LoadVectorElement(var, 3_u));
+        b.Let("b", b.Load(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u)));
+        b.Let("c", b.LoadVectorElement(b.Access(ty.ptr(uniform, ty.vec3<f16>()), var, 1_u), 2_u));
         b.Return(func);
     });
 
@@ -460,16 +458,14 @@
 
 %foo = @fragment func():void {
   $B2: {
-    %3:vec4<f16> = load %v
-    %a:vec4<f16> = let %3
-    %5:f16 = load_vector_element %v, 0u
-    %b:f16 = let %5
-    %7:f16 = load_vector_element %v, 1u
-    %c:f16 = let %7
-    %9:f16 = load_vector_element %v, 2u
-    %d:f16 = let %9
-    %11:f16 = load_vector_element %v, 3u
-    %e:f16 = let %11
+    %3:mat2x3<f16> = load %v
+    %a:mat2x3<f16> = let %3
+    %5:ptr<uniform, vec3<f16>, read> = access %v, 1u
+    %6:vec3<f16> = load %5
+    %b:vec3<f16> = let %6
+    %8:ptr<uniform, vec3<f16>, read> = access %v, 1u
+    %9:f16 = load_vector_element %8, 2u
+    %c:f16 = let %9
     ret
   }
 }
@@ -478,24 +474,43 @@
 
     auto* expect = R"(
 $B1: {  # root
-  %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+  %v:ptr<uniform, array<vec4<u32>, 1>, read> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {
   $B2: {
-    %3:vec4<f16> = %v.Load4F16 0u
-    %a:vec4<f16> = let %3
-    %5:f16 = %v.LoadF16 0u
-    %b:f16 = let %5
-    %7:f16 = %v.LoadF16 2u
-    %c:f16 = let %7
-    %9:f16 = %v.LoadF16 4u
-    %d:f16 = let %9
-    %11:f16 = %v.LoadF16 6u
-    %e:f16 = let %11
+    %3:mat2x3<f16> = call %4, 0u
+    %a:mat2x3<f16> = let %3
+    %6:ptr<uniform, vec4<u32>, read> = access %v, 0u
+    %7:vec4<u32> = load %6
+    %8:vec4<f16> = bitcast %7
+    %9:vec3<f16> = swizzle %8, xyz
+    %b:vec3<f16> = let %9
+    %11:ptr<uniform, vec4<u32>, read> = access %v, 0u
+    %12:u32 = load_vector_element %11, 3u
+    %13:f32 = hlsl.f16tof32 %12
+    %14:f16 = convert %13
+    %c:f16 = let %14
     ret
   }
 }
+%4 = func(%start_byte_offset:u32):mat2x3<f16> {
+  $B3: {
+    %17:u32 = div %start_byte_offset, 16u
+    %18:ptr<uniform, vec4<u32>, read> = access %v, %17
+    %19:vec4<u32> = load %18
+    %20:vec4<f16> = bitcast %19
+    %21:vec3<f16> = swizzle %20, xyz
+    %22:u32 = add 8u, %start_byte_offset
+    %23:u32 = div %22, 16u
+    %24:ptr<uniform, vec4<u32>, read> = access %v, %23
+    %25:vec4<u32> = load %24
+    %26:vec4<f16> = bitcast %25
+    %27:vec3<f16> = swizzle %26, xyz
+    %28:mat2x3<f16> = construct %21, %27
+    ret %28
+  }
+}
 )";
     Run(DecomposeUniformAccess);
     EXPECT_EQ(expect, str());