[hlsl] Add polyfill for `trunc`
This CL adds a polyfill for the `trunc` builtin with turns it into a
ternary with floor and ceil.
Bug: 42251045
Change-Id: I8a51654726ca38ae2bad66cb84d9aab1e0100917
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196814
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/ir/ternary.h b/src/tint/lang/hlsl/ir/ternary.h
index 136dccd..46c42aa 100644
--- a/src/tint/lang/hlsl/ir/ternary.h
+++ b/src/tint/lang/hlsl/ir/ternary.h
@@ -42,6 +42,9 @@
class Ternary final : public Castable<Ternary, core::ir::Call> {
public:
/// Constructor
+ ///
+ /// Note, the args are in the order of (`false`, `true`, `compare`) to match select.
+ /// Note, the ternary evaluates all branches, not just the selected branch.
Ternary(core::ir::InstructionResult* result, VectorRef<core::ir::Value*> args);
~Ternary() override;
diff --git a/src/tint/lang/hlsl/writer/builtin_test.cc b/src/tint/lang/hlsl/writer/builtin_test.cc
index 843951e..d5edeee 100644
--- a/src/tint/lang/hlsl/writer/builtin_test.cc
+++ b/src/tint/lang/hlsl/writer/builtin_test.cc
@@ -38,7 +38,7 @@
namespace tint::hlsl::writer {
namespace {
-TEST_F(HlslWriterTest, SelectScalar) {
+TEST_F(HlslWriterTest, BuiltinSelectScalar) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
auto* x = b.Let("x", 1_i);
@@ -60,7 +60,7 @@
)");
}
-TEST_F(HlslWriterTest, SelectVector) {
+TEST_F(HlslWriterTest, BuiltinSelectVector) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] {
auto* x = b.Let("x", b.Construct<vec2<i32>>(1_i, 2_i));
@@ -83,5 +83,77 @@
)");
}
+TEST_F(HlslWriterTest, BuiltinTrunc) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* val = b.Var("v", b.Zero(ty.f32()));
+
+ auto* v = b.Load(val);
+ auto* t = b.Call(ty.f32(), core::BuiltinFn::kTrunc, v);
+
+ b.Let("val", t);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ float v = 0.0f;
+ float v_1 = v;
+ float v_2 = floor(v_1);
+ float val = (((v_1 < 0.0f)) ? (ceil(v_1)) : (v_2));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BuiltinTruncVec) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* val = b.Var("v", b.Splat(ty.vec3<f32>(), 2_f));
+
+ auto* v = b.Load(val);
+ auto* t = b.Call(ty.vec3<f32>(), core::BuiltinFn::kTrunc, v);
+
+ b.Let("val", t);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ float3 v = (2.0f).xxx;
+ float3 v_1 = v;
+ float3 v_2 = floor(v_1);
+ float3 val = (((v_1 < (0.0f).xxx)) ? (ceil(v_1)) : (v_2));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BuiltinTruncF16) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* val = b.Var("v", b.Zero(ty.f16()));
+
+ auto* v = b.Load(val);
+ auto* t = b.Call(ty.f16(), core::BuiltinFn::kTrunc, v);
+
+ b.Let("val", t);
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+ EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+ float16_t v = float16_t(0.0h);
+ float16_t v_1 = v;
+ float16_t v_2 = floor(v_1);
+ float16_t val = (((v_1 < float16_t(0.0h))) ? (ceil(v_1)) : (v_2));
+}
+
+)");
+}
+
} // namespace
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
index b04094d..efcca68 100644
--- a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
@@ -58,12 +58,11 @@
/// The type manager.
core::type::Manager& ty{ir.Types()};
- using BinaryType =
- tint::UnorderedKeyWrapper<std::tuple<const core::type::Type*, const core::type::Type*>>;
-
- // Polyfill functions for bitcast expression, BinaryType indicates the source type and the
+ // Polyfill functions for bitcast expression, BitcastType indicates the source type and the
// destination type.
- Hashmap<BinaryType, core::ir::Function*, 4> bitcast_funcs_{};
+ using BitcastType =
+ tint::UnorderedKeyWrapper<std::tuple<const core::type::Type*, const core::type::Type*>>;
+ Hashmap<BitcastType, core::ir::Function*, 4> bitcast_funcs_{};
/// Process the module.
void Process() {
@@ -80,6 +79,9 @@
case core::BuiltinFn::kSelect:
call_worklist.Push(call);
break;
+ case core::BuiltinFn::kTrunc:
+ call_worklist.Push(call);
+ break;
default:
break;
}
@@ -110,6 +112,9 @@
case core::BuiltinFn::kSelect:
Select(call);
break;
+ case core::BuiltinFn::kTrunc:
+ Trunc(call);
+ break;
default:
TINT_UNREACHABLE();
}
@@ -124,6 +129,33 @@
call->Destroy();
}
+ // HLSL's trunc is broken for very large/small float values.
+ // See crbug.com/tint/1883
+ //
+ // Replace with:
+ // value < 0 ? ceil(value) : floor(value)
+ void Trunc(core::ir::CoreBuiltinCall* call) {
+ auto* val = call->Args()[0];
+
+ auto* type = call->Result(0)->Type();
+ Vector<core::ir::Value*, 4> args;
+ b.InsertBefore(call, [&] {
+ args.Push(b.Call(type, core::BuiltinFn::kFloor, val)->Result(0));
+ args.Push(b.Call(type, core::BuiltinFn::kCeil, val)->Result(0));
+
+ const core::type::Type* comp_ty = ty.bool_();
+ if (auto* vec = type->As<core::type::Vector>()) {
+ comp_ty = ty.vec(comp_ty, vec->Width());
+ }
+ args.Push(b.LessThan(comp_ty, val, b.Zero(type))->Result(0));
+ });
+ auto* trunc =
+ b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(call->DetachResult(), args);
+ trunc->InsertBefore(call);
+
+ call->Destroy();
+ }
+
/// Replaces an identity bitcast result with the value.
void ReplaceBitcastWithValue(core::ir::Bitcast* bitcast) {
bitcast->Result(0)->ReplaceAllUsesWith(bitcast->Val());
@@ -154,7 +186,7 @@
core::ir::Function* CreateBitcastFromF16(const core::type::Type* src_type,
const core::type::Type* dst_type) {
return bitcast_funcs_.GetOrAdd(
- BinaryType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
+ BitcastType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
TINT_ASSERT(src_type->Is<core::type::Vector>());
// Generate a helper function that performs the following (in HLSL):
@@ -240,7 +272,7 @@
core::ir::Function* CreateBitcastToF16(const core::type::Type* src_type,
const core::type::Type* dst_type) {
return bitcast_funcs_.GetOrAdd(
- BinaryType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
+ BitcastType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
TINT_ASSERT(dst_type->Is<core::type::Vector>());
// Generate a helper function that performs the following (in HLSL):