[hlsl] Add Bitcast support to the HLSL IR printer.

This CL adds support for emitting the `bitcast` instruction to the HLSL IR printer.

Bug: 42251045
Change-Id: Ica594d0da33e0d4b65b25aa9b4bf3d6f66ca5dfb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/195015
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/BUILD.bazel b/src/tint/lang/hlsl/writer/BUILD.bazel
index 8bb3b26..38b2173 100644
--- a/src/tint/lang/hlsl/writer/BUILD.bazel
+++ b/src/tint/lang/hlsl/writer/BUILD.bazel
@@ -91,6 +91,7 @@
   srcs = [
     "access_test.cc",
     "binary_test.cc",
+    "bitcast_test.cc",
     "constant_test.cc",
     "construct_test.cc",
     "convert_test.cc",
diff --git a/src/tint/lang/hlsl/writer/BUILD.cmake b/src/tint/lang/hlsl/writer/BUILD.cmake
index d36ee9d..a0b853c 100644
--- a/src/tint/lang/hlsl/writer/BUILD.cmake
+++ b/src/tint/lang/hlsl/writer/BUILD.cmake
@@ -102,6 +102,7 @@
 tint_add_target(tint_lang_hlsl_writer_test test
   lang/hlsl/writer/access_test.cc
   lang/hlsl/writer/binary_test.cc
+  lang/hlsl/writer/bitcast_test.cc
   lang/hlsl/writer/constant_test.cc
   lang/hlsl/writer/construct_test.cc
   lang/hlsl/writer/convert_test.cc
diff --git a/src/tint/lang/hlsl/writer/BUILD.gn b/src/tint/lang/hlsl/writer/BUILD.gn
index 4057fc3..3e6fd99 100644
--- a/src/tint/lang/hlsl/writer/BUILD.gn
+++ b/src/tint/lang/hlsl/writer/BUILD.gn
@@ -94,6 +94,7 @@
       sources = [
         "access_test.cc",
         "binary_test.cc",
+        "bitcast_test.cc",
         "constant_test.cc",
         "construct_test.cc",
         "convert_test.cc",
diff --git a/src/tint/lang/hlsl/writer/bitcast_test.cc b/src/tint/lang/hlsl/writer/bitcast_test.cc
new file mode 100644
index 0000000..14b4cee
--- /dev/null
+++ b/src/tint/lang/hlsl/writer/bitcast_test.cc
@@ -0,0 +1,306 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/core/fluent_types.h"
+#include "src/tint/lang/core/ir/function.h"
+#include "src/tint/lang/core/number.h"
+#include "src/tint/lang/hlsl/writer/helper_test.h"
+
+#include "gtest/gtest.h"
+
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
+
+namespace tint::hlsl::writer {
+namespace {
+
+TEST_F(HlslWriterTest, BitcastIdentityNumeric) {  // i32 -> i32: the operand is emitted as-is
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", 1_i);
+        b.Let("bc", b.Bitcast<i32>(b.Load(a)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  int a = 1;
+  int bc = a;
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastIdentityVec) {  // vec2<f32> -> vec2<f32>: operand emitted as-is
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", b.Construct<vec2<f32>>(1_f, 2_f));
+        b.Let("bc", b.Bitcast<vec2<f32>>(b.Load(a)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  float2 a = float2(1.0f, 2.0f);
+  float2 bc = a;
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToFloat) {  // i32 -> f32 is emitted as asfloat()
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", 1_i);
+        b.Let("bc", b.Bitcast<f32>(b.Load(a)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  int a = 1;
+  float bc = asfloat(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToInt) {  // u32 -> i32 is emitted as asint()
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", 1_u);
+        b.Let("bc", b.Bitcast<i32>(b.Load(a)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  uint a = 1u;
+  int bc = asint(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToUint) {  // i32 -> u32 is emitted as asuint()
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", 1_i);
+        b.Let("bc", b.Bitcast<u32>(b.Load(a)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  int a = 1;
+  uint bc = asuint(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastFromVec2F16) {  // vec2<f16> -> i32/f32/u32: one helper per target
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", b.Construct<vec2<f16>>(1_h, 2_h));
+        auto* z = b.Load(a);
+        b.Let("b", b.Bitcast<i32>(z));
+        b.Let("c", b.Bitcast<f32>(z));
+        b.Let("d", b.Bitcast<u32>(z));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(int tint_bitcast_from_f16(vector<float16_t, 2> src) {
+  uint2 r = f32tof16(float2(src));
+  return asint(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+float tint_bitcast_from_f16_1(vector<float16_t, 2> src) {
+  uint2 r = f32tof16(float2(src));
+  return asfloat(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+uint tint_bitcast_from_f16_2(vector<float16_t, 2> src) {
+  uint2 r = f32tof16(float2(src));
+  return asuint(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+
+void foo() {
+  vector<float16_t, 2> a = vector<float16_t, 2>(float16_t(1.0h), float16_t(2.0h));
+  int b = tint_bitcast_from_f16(a);
+  float c = tint_bitcast_from_f16_1(a);
+  uint d = tint_bitcast_from_f16_2(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToVec2F16) {  // i32/f32/u32 -> vec2<f16>: one helper per source
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", 1_i);
+        b.Let("b", b.Bitcast<vec2<f16>>(b.Load(a)));
+
+        auto* c = b.Var("c", 1_f);
+        b.Let("d", b.Bitcast<vec2<f16>>(b.Load(c)));
+
+        auto* e = b.Var("e", 1_u);
+        b.Let("f", b.Bitcast<vec2<f16>>(b.Load(e)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(vector<float16_t, 2> tint_bitcast_to_f16(int src) {
+  uint v = asuint(src);
+  float t_low = f16tof32(v & 0xffff);
+  float t_high = f16tof32((v >> 16) & 0xffff);
+  return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+vector<float16_t, 2> tint_bitcast_to_f16_1(float src) {
+  uint v = asuint(src);
+  float t_low = f16tof32(v & 0xffff);
+  float t_high = f16tof32((v >> 16) & 0xffff);
+  return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+vector<float16_t, 2> tint_bitcast_to_f16_2(uint src) {
+  uint v = asuint(src);
+  float t_low = f16tof32(v & 0xffff);
+  float t_high = f16tof32((v >> 16) & 0xffff);
+  return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+
+void foo() {
+  int a = 1;
+  vector<float16_t, 2> b = tint_bitcast_to_f16(a);
+  float c = 1.0f;
+  vector<float16_t, 2> d = tint_bitcast_to_f16_1(c);
+  uint e = 1u;
+  vector<float16_t, 2> f = tint_bitcast_to_f16_2(e);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastFromVec4F16) {  // vec4<f16> -> vec2 of i32/f32/u32 via helpers
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", b.Construct<vec4<f16>>(1_h, 2_h, 3_h, 4_h));
+        auto* z = b.Load(a);
+        b.Let("b", b.Bitcast<vec2<i32>>(z));
+        b.Let("c", b.Bitcast<vec2<f32>>(z));
+        b.Let("d", b.Bitcast<vec2<u32>>(z));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(int2 tint_bitcast_from_f16(vector<float16_t, 4> src) {
+  uint4 r = f32tof16(float4(src));
+  return asint(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+float2 tint_bitcast_from_f16_1(vector<float16_t, 4> src) {
+  uint4 r = f32tof16(float4(src));
+  return asfloat(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+uint2 tint_bitcast_from_f16_2(vector<float16_t, 4> src) {
+  uint4 r = f32tof16(float4(src));
+  return asuint(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+
+void foo() {
+  vector<float16_t, 4> a = vector<float16_t, 4>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h), float16_t(4.0h));
+  int2 b = tint_bitcast_from_f16(a);
+  float2 c = tint_bitcast_from_f16_1(a);
+  uint2 d = tint_bitcast_from_f16_2(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToVec4F16) {  // vec2 of i32/f32/u32 -> vec4<f16> via helpers
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* a = b.Var("a", b.Construct<vec2<i32>>(1_i, 2_i));
+        b.Let("b", b.Bitcast<vec4<f16>>(b.Load(a)));
+
+        auto* c = b.Var("c", b.Construct<vec2<f32>>(1_f, 2_f));
+        b.Let("d", b.Bitcast<vec4<f16>>(b.Load(c)));
+
+        auto* e = b.Var("e", b.Construct<vec2<u32>>(1_u, 2_u));
+        b.Let("f", b.Bitcast<vec4<f16>>(b.Load(e)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(vector<float16_t, 4> tint_bitcast_to_f16(int2 src) {
+  uint2 v = asuint(src);
+  float2 t_low = f16tof32(v & 0xffff);
+  float2 t_high = f16tof32((v >> 16) & 0xffff);
+  return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+vector<float16_t, 4> tint_bitcast_to_f16_1(float2 src) {
+  uint2 v = asuint(src);
+  float2 t_low = f16tof32(v & 0xffff);
+  float2 t_high = f16tof32((v >> 16) & 0xffff);
+  return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+vector<float16_t, 4> tint_bitcast_to_f16_2(uint2 src) {
+  uint2 v = asuint(src);
+  float2 t_low = f16tof32(v & 0xffff);
+  float2 t_high = f16tof32((v >> 16) & 0xffff);
+  return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+
+void foo() {
+  int2 a = int2(1, 2);
+  vector<float16_t, 4> b = tint_bitcast_to_f16(a);
+  float2 c = float2(1.0f, 2.0f);
+  vector<float16_t, 4> d = tint_bitcast_to_f16_1(c);
+  uint2 e = uint2(1u, 2u);
+  vector<float16_t, 4> f = tint_bitcast_to_f16_2(e);
+}
+
+)");
+}
+
+}  // namespace
+}  // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 5a66655..e4be041 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -31,6 +31,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <string>
+#include <tuple>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
@@ -107,6 +108,7 @@
 #include "src/tint/utils/ice/ice.h"
 #include "src/tint/utils/macros/compiler.h"
 #include "src/tint/utils/macros/scoped_assignment.h"
+#include "src/tint/utils/math/hash.h"
 #include "src/tint/utils/rtti/switch.h"
 #include "src/tint/utils/strconv/float_to_string.h"
 #include "src/tint/utils/text/string.h"
@@ -192,6 +194,13 @@
     /// Block to emit for a continuing
     std::function<void()> emit_continuing_;
 
+    using BinaryType =
+        tint::UnorderedKeyWrapper<std::tuple<const core::type::Type*, const core::type::Type*>>;
+
+    /// Names of the polyfill functions emitted for bitcast instructions, keyed by the
+    /// (source type, destination type) pair of the bitcast.
+    std::unordered_map<BinaryType, std::string> bitcast_funcs_;
+
     /// Emit the root block.
     /// @param root_block the root block to emit
     void EmitRootBlock(core::ir::Block* root_block) {
@@ -608,6 +617,7 @@
                 Switch(
                     r->Instruction(),                                                          //
                     [&](const core::ir::Access* a) { EmitAccess(out, a); },                    //
+                    [&](const core::ir::Bitcast* b) { EmitBitcast(out, b); },                  //
                     [&](const core::ir::Construct* c) { EmitConstruct(out, c); },              //
                     [&](const core::ir::Convert* c) { EmitConvert(out, c); },                  //
                     [&](const core::ir::CoreBinary* b) { EmitBinary(out, b); },                //
@@ -627,6 +637,158 @@
             TINT_ICE_ON_NO_MATCH);
     }
 
+    /// Emit a bitcast instruction: a plain value, an `as*()` intrinsic or an f16 polyfill call
+    void EmitBitcast(StringStream& out, const core::ir::Bitcast* b) {
+        auto* src_type = b->Val()->Type();
+        auto* dst_type = b->Result(0)->Type();
+
+        // Identity transform: a bitcast to the same type emits the operand unchanged
+        if (src_type == dst_type) {
+            EmitValue(out, b->Val());
+            return;
+        }
+
+        if (src_type->DeepestElement()->As<core::type::F16>()) {
+            out << EmitBitcastFromF16(src_type, dst_type);
+        } else if (dst_type->DeepestElement()->As<core::type::F16>()) {
+            out << EmitBitcastToF16(src_type, dst_type);
+        } else {
+            out << "as";  // asfloat/asint/asuint accept vectors; `asfloat2` etc. do not exist
+            EmitType(out, dst_type->DeepestElement());
+        }
+        out << "(";
+        EmitValue(out, b->Val());
+        out << ")";
+    }
+
+    /// Returns the name of the polyfill that bitcasts the f16 vector `src_type` to `dst_type`. The
+    /// polyfill converts the f16 values to f32 and calls f32tof16 to get their bits. This should be
+    /// safe, as finite and infinite f16 values are exactly representable by f32.
+    std::string EmitBitcastFromF16(const core::type::Type* src_type,
+                                   const core::type::Type* dst_type) {
+        return tint::GetOrAdd(
+            bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
+                TextBuffer b;
+                auto fn_name = UniqueIdentifier(std::string("tint_bitcast_from_f16"));
+                {
+                    auto decl = Line(&b);
+                    EmitTypeAndName(decl, dst_type, core::AddressSpace::kUndefined,
+                                    core::Access::kUndefined, fn_name);
+                    {
+                        const ScopedParen sp(decl);
+                        EmitTypeAndName(decl, src_type, core::AddressSpace::kUndefined,
+                                        core::Access::kUndefined, "src");
+                    }
+                    decl << " {";
+                }
+                {
+                    auto* src_vec = src_type->As<core::type::Vector>();  // assumed to be a vector
+
+                    const ScopedIndent si(&b);
+                    {
+                        Line(&b) << "uint" << src_vec->Width() << " r = f32tof16(float"
+                                 << src_vec->Width() << "(src));";
+
+                        {
+                            auto* dst_el_type = dst_type->DeepestElement();
+
+                            auto s = Line(&b);
+                            s << "return as";
+                            EmitType(s, dst_el_type, core::AddressSpace::kUndefined,
+                                     core::Access::kReadWrite, "");
+                            s << "(";
+                            switch (src_vec->Width()) {  // packs two f16 per uint, low half first
+                                case 2: {
+                                    s << "uint((r.x & 0xffff) | ((r.y & 0xffff) << 16))";
+                                    break;
+                                }
+                                case 4: {
+                                    s << "uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), "
+                                         "(r.z & 0xffff) | ((r.w & 0xffff) << 16))";
+                                    break;
+                                }
+                                default: {
+                                    TINT_UNREACHABLE();
+                                }
+                            }
+                            s << ");";
+                        }
+                    }
+                }
+                Line(&b) << "}";
+                Line(&b);
+
+                preamble_buffer_.Append(b);
+                return fn_name;
+            });
+    }
+
+    /// Returns the name of the polyfill that bitcasts `src_type` to the f16 vector `dst_type`. The
+    /// polyfill reinterprets each 16-bit half of the source bits as f16 using f16tof32, then
+    /// converts the resulting f32 values to f16. This should be safe, as finite and infinite f16
+    /// result values are exactly representable by f32.
+    std::string EmitBitcastToF16(const core::type::Type* src_type,
+                                 const core::type::Type* dst_type) {
+        return tint::GetOrAdd(
+            bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
+                TextBuffer b;
+                auto fn_name = UniqueIdentifier(std::string("tint_bitcast_to_f16"));
+                {
+                    auto decl = Line(&b);
+                    EmitTypeAndName(decl, dst_type, core::AddressSpace::kUndefined,
+                                    core::Access::kUndefined, fn_name);
+                    {
+                        const ScopedParen sp(decl);
+                        EmitTypeAndName(decl, src_type, core::AddressSpace::kUndefined,
+                                        core::Access::kUndefined, "src");
+                    }
+                    decl << " {";
+                }
+                {
+                    const ScopedIndent si(&b);
+                    {
+                        auto* dst_vec = dst_type->As<core::type::Vector>();  // assumed non-null
+                        auto* src_vec = src_type->As<core::type::Vector>();
+                        const std::string src_type_suffix = (src_vec ? "2" : "");  // assumes vec2
+
+                        // Convert the source to uint for f16tof32.
+                        Line(&b) << "uint" << src_type_suffix << " v = asuint(src);";
+                        // Reinterpret the low 16 bits and high 16 bits
+                        Line(&b) << "float" << src_type_suffix << " t_low = f16tof32(v & 0xffff);";
+                        Line(&b) << "float" << src_type_suffix
+                                 << " t_high = f16tof32((v >> 16) & 0xffff);";
+                        // Construct the result f16 vector
+                        {
+                            auto s = Line(&b);
+                            s << "return ";
+                            EmitType(s, dst_type, core::AddressSpace::kUndefined,
+                                     core::Access::kReadWrite, "");
+                            s << "(";
+                            switch (dst_vec->Width()) {
+                                case 2: {
+                                    s << "t_low.x, t_high.x";
+                                    break;
+                                }
+                                case 4: {
+                                    s << "t_low.x, t_high.x, t_low.y, t_high.y";
+                                    break;
+                                }
+                                default: {
+                                    TINT_UNREACHABLE();
+                                }
+                            }
+                            s << ");";
+                        }
+                    }
+                }
+                Line(&b) << "}";
+                Line(&b);
+
+                preamble_buffer_.Append(b);
+                return fn_name;
+            });
+    }
+
     /// Emit a convert instruction
     void EmitConvert(StringStream& out, const core::ir::Convert* c) {
         EmitType(out, c->Result(0)->Type());