[hlsl] Add Bitcast support to the HLSL IR printer.
This CL adds support for emitting the `bitcast` instruction to the HLSL IR printer.
Bug: 42251045
Change-Id: Ica594d0da33e0d4b65b25aa9b4bf3d6f66ca5dfb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/195015
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/BUILD.bazel b/src/tint/lang/hlsl/writer/BUILD.bazel
index 8bb3b26..38b2173 100644
--- a/src/tint/lang/hlsl/writer/BUILD.bazel
+++ b/src/tint/lang/hlsl/writer/BUILD.bazel
@@ -91,6 +91,7 @@
srcs = [
"access_test.cc",
"binary_test.cc",
+ "bitcast_test.cc",
"constant_test.cc",
"construct_test.cc",
"convert_test.cc",
diff --git a/src/tint/lang/hlsl/writer/BUILD.cmake b/src/tint/lang/hlsl/writer/BUILD.cmake
index d36ee9d..a0b853c 100644
--- a/src/tint/lang/hlsl/writer/BUILD.cmake
+++ b/src/tint/lang/hlsl/writer/BUILD.cmake
@@ -102,6 +102,7 @@
tint_add_target(tint_lang_hlsl_writer_test test
lang/hlsl/writer/access_test.cc
lang/hlsl/writer/binary_test.cc
+ lang/hlsl/writer/bitcast_test.cc
lang/hlsl/writer/constant_test.cc
lang/hlsl/writer/construct_test.cc
lang/hlsl/writer/convert_test.cc
diff --git a/src/tint/lang/hlsl/writer/BUILD.gn b/src/tint/lang/hlsl/writer/BUILD.gn
index 4057fc3..3e6fd99 100644
--- a/src/tint/lang/hlsl/writer/BUILD.gn
+++ b/src/tint/lang/hlsl/writer/BUILD.gn
@@ -94,6 +94,7 @@
sources = [
"access_test.cc",
"binary_test.cc",
+ "bitcast_test.cc",
"constant_test.cc",
"construct_test.cc",
"convert_test.cc",
diff --git a/src/tint/lang/hlsl/writer/bitcast_test.cc b/src/tint/lang/hlsl/writer/bitcast_test.cc
new file mode 100644
index 0000000..14b4cee
--- /dev/null
+++ b/src/tint/lang/hlsl/writer/bitcast_test.cc
@@ -0,0 +1,306 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/core/fluent_types.h"
+#include "src/tint/lang/core/ir/function.h"
+#include "src/tint/lang/core/number.h"
+#include "src/tint/lang/hlsl/writer/helper_test.h"
+
+#include "gtest/gtest.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::hlsl::writer {
+namespace {
+
+TEST_F(HlslWriterTest, BitcastIdentityNumeric) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", 1_i);
+ b.Let("bc", b.Bitcast<i32>(b.Load(a)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ int a = 1;
+ int bc = a;
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastIdentityVec) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", b.Construct<vec2<f32>>(1_f, 2_f));
+ b.Let("bc", b.Bitcast<vec2<f32>>(b.Load(a)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ float2 a = float2(1.0f, 2.0f);
+ float2 bc = a;
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToFloat) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", 1_i);
+ b.Let("bc", b.Bitcast<f32>(b.Load(a)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ int a = 1;
+ float bc = asfloat(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToInt) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", 1_u);
+ b.Let("bc", b.Bitcast<i32>(b.Load(a)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ uint a = 1u;
+ int bc = asint(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToUint) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", 1_i);
+ b.Let("bc", b.Bitcast<u32>(b.Load(a)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ int a = 1;
+ uint bc = asuint(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastFromVec2F16) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", b.Construct<vec2<f16>>(1_h, 2_h));
+ auto* z = b.Load(a);
+ b.Let("b", b.Bitcast<i32>(z));
+ b.Let("c", b.Bitcast<f32>(z));
+ b.Let("d", b.Bitcast<u32>(z));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(int tint_bitcast_from_f16(vector<float16_t, 2> src) {
+ uint2 r = f32tof16(float2(src));
+ return asint(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+float tint_bitcast_from_f16_1(vector<float16_t, 2> src) {
+ uint2 r = f32tof16(float2(src));
+ return asfloat(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+uint tint_bitcast_from_f16_2(vector<float16_t, 2> src) {
+ uint2 r = f32tof16(float2(src));
+ return asuint(uint((r.x & 0xffff) | ((r.y & 0xffff) << 16)));
+}
+
+
+void foo() {
+ vector<float16_t, 2> a = vector<float16_t, 2>(float16_t(1.0h), float16_t(2.0h));
+ int b = tint_bitcast_from_f16(a);
+ float c = tint_bitcast_from_f16_1(a);
+ uint d = tint_bitcast_from_f16_2(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToVec2F16) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", 1_i);
+ b.Let("b", b.Bitcast<vec2<f16>>(b.Load(a)));
+
+ auto* c = b.Var("c", 1_f);
+ b.Let("d", b.Bitcast<vec2<f16>>(b.Load(c)));
+
+ auto* e = b.Var("e", 1_u);
+ b.Let("f", b.Bitcast<vec2<f16>>(b.Load(e)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(vector<float16_t, 2> tint_bitcast_to_f16(int src) {
+ uint v = asuint(src);
+ float t_low = f16tof32(v & 0xffff);
+ float t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+vector<float16_t, 2> tint_bitcast_to_f16_1(float src) {
+ uint v = asuint(src);
+ float t_low = f16tof32(v & 0xffff);
+ float t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+vector<float16_t, 2> tint_bitcast_to_f16_2(uint src) {
+ uint v = asuint(src);
+ float t_low = f16tof32(v & 0xffff);
+ float t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 2>(t_low.x, t_high.x);
+}
+
+
+void foo() {
+ int a = 1;
+ vector<float16_t, 2> b = tint_bitcast_to_f16(a);
+ float c = 1.0f;
+ vector<float16_t, 2> d = tint_bitcast_to_f16_1(c);
+ uint e = 1u;
+ vector<float16_t, 2> f = tint_bitcast_to_f16_2(e);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastFromVec4F16) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", b.Construct<vec4<f16>>(1_h, 2_h, 3_h, 4_h));
+ auto* z = b.Load(a);
+ b.Let("b", b.Bitcast<vec2<i32>>(z));
+ b.Let("c", b.Bitcast<vec2<f32>>(z));
+ b.Let("d", b.Bitcast<vec2<u32>>(z));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(int2 tint_bitcast_from_f16(vector<float16_t, 4> src) {
+ uint4 r = f32tof16(float4(src));
+ return asint(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+float2 tint_bitcast_from_f16_1(vector<float16_t, 4> src) {
+ uint4 r = f32tof16(float4(src));
+ return asfloat(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+uint2 tint_bitcast_from_f16_2(vector<float16_t, 4> src) {
+ uint4 r = f32tof16(float4(src));
+ return asuint(uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), (r.z & 0xffff) | ((r.w & 0xffff) << 16)));
+}
+
+
+void foo() {
+ vector<float16_t, 4> a = vector<float16_t, 4>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h), float16_t(4.0h));
+ int2 b = tint_bitcast_from_f16(a);
+ float2 c = tint_bitcast_from_f16_1(a);
+ uint2 d = tint_bitcast_from_f16_2(a);
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BitcastToVec4F16) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* a = b.Var("a", b.Construct<vec2<i32>>(1_i, 2_i));
+ b.Let("b", b.Bitcast<vec4<f16>>(b.Load(a)));
+
+ auto* c = b.Var("c", b.Construct<vec2<f32>>(1_f, 2_f));
+ b.Let("d", b.Bitcast<vec4<f16>>(b.Load(c)));
+
+ auto* e = b.Var("e", b.Construct<vec2<u32>>(1_u, 2_u));
+ b.Let("f", b.Bitcast<vec4<f16>>(b.Load(e)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(vector<float16_t, 4> tint_bitcast_to_f16(int2 src) {
+ uint2 v = asuint(src);
+ float2 t_low = f16tof32(v & 0xffff);
+ float2 t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+vector<float16_t, 4> tint_bitcast_to_f16_1(float2 src) {
+ uint2 v = asuint(src);
+ float2 t_low = f16tof32(v & 0xffff);
+ float2 t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+vector<float16_t, 4> tint_bitcast_to_f16_2(uint2 src) {
+ uint2 v = asuint(src);
+ float2 t_low = f16tof32(v & 0xffff);
+ float2 t_high = f16tof32((v >> 16) & 0xffff);
+ return vector<float16_t, 4>(t_low.x, t_high.x, t_low.y, t_high.y);
+}
+
+
+void foo() {
+ int2 a = int2(1, 2);
+ vector<float16_t, 4> b = tint_bitcast_to_f16(a);
+ float2 c = float2(1.0f, 2.0f);
+ vector<float16_t, 4> d = tint_bitcast_to_f16_1(c);
+ uint2 e = uint2(1u, 2u);
+ vector<float16_t, 4> f = tint_bitcast_to_f16_2(e);
+}
+
+)");
+}
+
+} // namespace
+} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/printer/printer.cc b/src/tint/lang/hlsl/writer/printer/printer.cc
index 5a66655..e4be041 100644
--- a/src/tint/lang/hlsl/writer/printer/printer.cc
+++ b/src/tint/lang/hlsl/writer/printer/printer.cc
@@ -31,6 +31,7 @@
#include <cstddef>
#include <cstdint>
#include <string>
+#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -107,6 +108,7 @@
#include "src/tint/utils/ice/ice.h"
#include "src/tint/utils/macros/compiler.h"
#include "src/tint/utils/macros/scoped_assignment.h"
+#include "src/tint/utils/math/hash.h"
#include "src/tint/utils/rtti/switch.h"
#include "src/tint/utils/strconv/float_to_string.h"
#include "src/tint/utils/text/string.h"
@@ -192,6 +194,13 @@
/// Block to emit for a continuing
std::function<void()> emit_continuing_;
+ using BinaryType =
+ tint::UnorderedKeyWrapper<std::tuple<const core::type::Type*, const core::type::Type*>>;
+
+ // Names of the bitcast polyfill functions emitted so far, keyed by BinaryType, which holds
+ // the (source type, destination type) pair of the bitcast.
+ std::unordered_map<BinaryType, std::string> bitcast_funcs_;
+
/// Emit the root block.
/// @param root_block the root block to emit
void EmitRootBlock(core::ir::Block* root_block) {
@@ -608,6 +617,7 @@
Switch(
r->Instruction(), //
[&](const core::ir::Access* a) { EmitAccess(out, a); }, //
+ [&](const core::ir::Bitcast* b) { EmitBitcast(out, b); }, //
[&](const core::ir::Construct* c) { EmitConstruct(out, c); }, //
[&](const core::ir::Convert* c) { EmitConvert(out, c); }, //
[&](const core::ir::CoreBinary* b) { EmitBinary(out, b); }, //
@@ -627,6 +637,158 @@
TINT_ICE_ON_NO_MATCH);
}
+ /// Emit a bitcast instruction: identity emits the value, f16 uses a polyfill, else as<type>().
+ void EmitBitcast(StringStream& out, const core::ir::Bitcast* b) {
+ auto* src_type = b->Val()->Type();
+ auto* dst_type = b->Result(0)->Type();
+ // Identity transform
+ if (src_type == dst_type) {
+ EmitValue(out, b->Val());
+ return;
+ }
+
+ if (src_type->DeepestElement()->As<core::type::F16>()) {
+ out << EmitBitcastFromF16(src_type, dst_type);
+ } else if (dst_type->DeepestElement()->As<core::type::F16>()) {
+ out << EmitBitcastToF16(src_type, dst_type);
+ } else {
+ // HLSL names the intrinsic after the element type (asfloat, never asfloat2).
+ out << "as";
+ EmitType(out, dst_type->DeepestElement());
+ }
+ out << "(";
+ EmitValue(out, b->Val());
+ out << ")";
+ }
+
+ // Returns the name of a preamble polyfill (cached in bitcast_funcs_) that bitcasts an f16
+ // vector by converting it to f32, calling f32tof16 to get the bits and packing each pair into
+ // one uint. Should be safe: finite and infinite f16 values are exactly representable in f32.
+ std::string EmitBitcastFromF16(const core::type::Type* src_type,
+ const core::type::Type* dst_type) {
+ return tint::GetOrAdd(
+ bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
+ TextBuffer b;
+ auto fn_name = UniqueIdentifier(std::string("tint_bitcast_from_f16"));
+ {
+ auto decl = Line(&b);
+ EmitTypeAndName(decl, dst_type, core::AddressSpace::kUndefined,
+ core::Access::kUndefined, fn_name);
+ {
+ const ScopedParen sp(decl);
+ EmitTypeAndName(decl, src_type, core::AddressSpace::kUndefined,
+ core::Access::kUndefined, "src");
+ }
+ decl << " {";
+ }
+ {
+ auto* src_vec = src_type->As<core::type::Vector>();
+
+ const ScopedIndent si(&b);
+ {
+ Line(&b) << "uint" << src_vec->Width() << " r = f32tof16(float"
+ << src_vec->Width() << "(src));";
+
+ {
+ auto* dst_el_type = dst_type->DeepestElement();
+
+ auto s = Line(&b);
+ s << "return as";
+ EmitType(s, dst_el_type, core::AddressSpace::kUndefined,
+ core::Access::kReadWrite, "");
+ s << "(";
+ switch (src_vec->Width()) {
+ case 2: {
+ s << "uint((r.x & 0xffff) | ((r.y & 0xffff) << 16))";
+ break;
+ }
+ case 4: {
+ s << "uint2((r.x & 0xffff) | ((r.y & 0xffff) << 16), "
+ "(r.z & 0xffff) | ((r.w & 0xffff) << 16))";
+ break;
+ }
+ default: {
+ TINT_UNREACHABLE();
+ }
+ }
+ s << ");";
+ }
+ }
+ }
+ Line(&b) << "}";
+ Line(&b);
+
+ preamble_buffer_.Append(b);
+ return fn_name;
+ });
+ }
+
+ // Returns the name of a preamble polyfill (cached in bitcast_funcs_) that bitcasts a scalar
+ // or two-element vector source to an f16 vector: the source bits are viewed as uint, the low
+ // and high 16 bits of each word are decoded with f16tof32, and the f32 results are converted
+ // to f16. Should be safe: finite and infinite f16 values are exactly representable in f32.
+ std::string EmitBitcastToF16(const core::type::Type* src_type,
+ const core::type::Type* dst_type) {
+ return tint::GetOrAdd(
+ bitcast_funcs_, BinaryType{{src_type, dst_type}}, [&]() -> std::string {
+ TextBuffer b;
+ auto fn_name = UniqueIdentifier(std::string("tint_bitcast_to_f16"));
+ {
+ auto decl = Line(&b);
+ EmitTypeAndName(decl, dst_type, core::AddressSpace::kUndefined,
+ core::Access::kUndefined, fn_name);
+ {
+ const ScopedParen sp(decl);
+ EmitTypeAndName(decl, src_type, core::AddressSpace::kUndefined,
+ core::Access::kUndefined, "src");
+ }
+ decl << " {";
+ }
+ {
+ const ScopedIndent si(&b);
+ {
+ auto* dst_vec = dst_type->As<core::type::Vector>();
+ auto* src_vec = src_type->As<core::type::Vector>();
+ const std::string src_type_suffix = (src_vec ? "2" : "");
+
+ // View the source bits as uint so they can be masked and passed to f16tof32.
+ Line(&b) << "uint" << src_type_suffix << " v = asuint(src);";
+ // Decode the low and the high 16 bits of each word as f16 values (widened to f32).
+ Line(&b) << "float" << src_type_suffix << " t_low = f16tof32(v & 0xffff);";
+ Line(&b) << "float" << src_type_suffix
+ << " t_high = f16tof32((v >> 16) & 0xffff);";
+ // Construct the result f16 vector, interleaving the low/high halves in source order.
+ {
+ auto s = Line(&b);
+ s << "return ";
+ EmitType(s, dst_type, core::AddressSpace::kUndefined,
+ core::Access::kReadWrite, "");
+ s << "(";
+ switch (dst_vec->Width()) {
+ case 2: {
+ s << "t_low.x, t_high.x";
+ break;
+ }
+ case 4: {
+ s << "t_low.x, t_high.x, t_low.y, t_high.y";
+ break;
+ }
+ default: {
+ TINT_UNREACHABLE();
+ }
+ }
+ s << ");";
+ }
+ }
+ }
+ Line(&b) << "}";
+ Line(&b);
+
+ preamble_buffer_.Append(b);
+ return fn_name;
+ });
+ }
+
/// Emit a convert instruction
void EmitConvert(StringStream& out, const core::ir::Convert* c) {
EmitType(out, c->Result(0)->Type());