[hlsl] Add polyfill for `trunc`

This CL adds a polyfill for the `trunc` builtin with turns it into a
ternary with floor and ceil.

Bug: 42251045

Change-Id: I8a51654726ca38ae2bad66cb84d9aab1e0100917
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196814
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/ir/ternary.h b/src/tint/lang/hlsl/ir/ternary.h
index 136dccd..46c42aa 100644
--- a/src/tint/lang/hlsl/ir/ternary.h
+++ b/src/tint/lang/hlsl/ir/ternary.h
@@ -42,6 +42,9 @@
 class Ternary final : public Castable<Ternary, core::ir::Call> {
   public:
     /// Constructor
+    ///
+    /// Note, the args are in the order of (`false`, `true`, `compare`) to match select.
+    /// Note, the ternary evaluates all branches, not just the selected branch.
     Ternary(core::ir::InstructionResult* result, VectorRef<core::ir::Value*> args);
     ~Ternary() override;
 
diff --git a/src/tint/lang/hlsl/writer/builtin_test.cc b/src/tint/lang/hlsl/writer/builtin_test.cc
index 843951e..d5edeee 100644
--- a/src/tint/lang/hlsl/writer/builtin_test.cc
+++ b/src/tint/lang/hlsl/writer/builtin_test.cc
@@ -38,7 +38,7 @@
 namespace tint::hlsl::writer {
 namespace {
 
-TEST_F(HlslWriterTest, SelectScalar) {
+TEST_F(HlslWriterTest, BuiltinSelectScalar) {
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     b.Append(func->Block(), [&] {
         auto* x = b.Let("x", 1_i);
@@ -60,7 +60,7 @@
 )");
 }
 
-TEST_F(HlslWriterTest, SelectVector) {
+TEST_F(HlslWriterTest, BuiltinSelectVector) {
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     b.Append(func->Block(), [&] {
         auto* x = b.Let("x", b.Construct<vec2<i32>>(1_i, 2_i));
@@ -83,5 +83,77 @@
 )");
 }
 
+TEST_F(HlslWriterTest, BuiltinTrunc) {
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* val = b.Var("v", b.Zero(ty.f32()));
+
+        auto* v = b.Load(val);
+        auto* t = b.Call(ty.f32(), core::BuiltinFn::kTrunc, v);
+
+        b.Let("val", t);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  float v = 0.0f;
+  float v_1 = v;
+  float v_2 = floor(v_1);
+  float val = (((v_1 < 0.0f)) ? (ceil(v_1)) : (v_2));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BuiltinTruncVec) {
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* val = b.Var("v", b.Splat(ty.vec3<f32>(), 2_f));
+
+        auto* v = b.Load(val);
+        auto* t = b.Call(ty.vec3<f32>(), core::BuiltinFn::kTrunc, v);
+
+        b.Let("val", t);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  float3 v = (2.0f).xxx;
+  float3 v_1 = v;
+  float3 v_2 = floor(v_1);
+  float3 val = (((v_1 < (0.0f).xxx)) ? (ceil(v_1)) : (v_2));
+}
+
+)");
+}
+
+TEST_F(HlslWriterTest, BuiltinTruncF16) {
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* val = b.Var("v", b.Zero(ty.f16()));
+
+        auto* v = b.Load(val);
+        auto* t = b.Call(ty.f16(), core::BuiltinFn::kTrunc, v);
+
+        b.Let("val", t);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+void foo() {
+  float16_t v = float16_t(0.0h);
+  float16_t v_1 = v;
+  float16_t v_2 = floor(v_1);
+  float16_t val = (((v_1 < float16_t(0.0h))) ? (ceil(v_1)) : (v_2));
+}
+
+)");
+}
+
 }  // namespace
 }  // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
index b04094d..efcca68 100644
--- a/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/hlsl/writer/raise/builtin_polyfill.cc
@@ -58,12 +58,11 @@
     /// The type manager.
     core::type::Manager& ty{ir.Types()};
 
-    using BinaryType =
-        tint::UnorderedKeyWrapper<std::tuple<const core::type::Type*, const core::type::Type*>>;
-
-    // Polyfill functions for bitcast expression, BinaryType indicates the source type and the
+    // Polyfill functions for bitcast expression, BitcastType indicates the source type and the
     // destination type.
-    Hashmap<BinaryType, core::ir::Function*, 4> bitcast_funcs_{};
+    using BitcastType =
+        tint::UnorderedKeyWrapper<std::tuple<const core::type::Type*, const core::type::Type*>>;
+    Hashmap<BitcastType, core::ir::Function*, 4> bitcast_funcs_{};
 
     /// Process the module.
     void Process() {
@@ -80,6 +79,9 @@
                     case core::BuiltinFn::kSelect:
                         call_worklist.Push(call);
                         break;
+                    case core::BuiltinFn::kTrunc:
+                        call_worklist.Push(call);
+                        break;
                     default:
                         break;
                 }
@@ -110,6 +112,9 @@
                 case core::BuiltinFn::kSelect:
                     Select(call);
                     break;
+                case core::BuiltinFn::kTrunc:
+                    Trunc(call);
+                    break;
                 default:
                     TINT_UNREACHABLE();
             }
@@ -124,6 +129,33 @@
         call->Destroy();
     }
 
+    // HLSL's trunc is broken for very large/small float values.
+    // See crbug.com/tint/1883
+    //
+    // Replace with:
+    //   value < 0 ? ceil(value) : floor(value)
+    void Trunc(core::ir::CoreBuiltinCall* call) {
+        auto* val = call->Args()[0];
+
+        auto* type = call->Result(0)->Type();
+        Vector<core::ir::Value*, 4> args;
+        b.InsertBefore(call, [&] {
+            args.Push(b.Call(type, core::BuiltinFn::kFloor, val)->Result(0));
+            args.Push(b.Call(type, core::BuiltinFn::kCeil, val)->Result(0));
+
+            const core::type::Type* comp_ty = ty.bool_();
+            if (auto* vec = type->As<core::type::Vector>()) {
+                comp_ty = ty.vec(comp_ty, vec->Width());
+            }
+            args.Push(b.LessThan(comp_ty, val, b.Zero(type))->Result(0));
+        });
+        auto* trunc =
+            b.ir.allocators.instructions.Create<hlsl::ir::Ternary>(call->DetachResult(), args);
+        trunc->InsertBefore(call);
+
+        call->Destroy();
+    }
+
     /// Replaces an identity bitcast result with the value.
     void ReplaceBitcastWithValue(core::ir::Bitcast* bitcast) {
         bitcast->Result(0)->ReplaceAllUsesWith(bitcast->Val());
@@ -154,7 +186,7 @@
     core::ir::Function* CreateBitcastFromF16(const core::type::Type* src_type,
                                              const core::type::Type* dst_type) {
         return bitcast_funcs_.GetOrAdd(
-            BinaryType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
+            BitcastType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
                 TINT_ASSERT(src_type->Is<core::type::Vector>());
 
                 // Generate a helper function that performs the following (in HLSL):
@@ -240,7 +272,7 @@
     core::ir::Function* CreateBitcastToF16(const core::type::Type* src_type,
                                            const core::type::Type* dst_type) {
         return bitcast_funcs_.GetOrAdd(
-            BinaryType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
+            BitcastType{{src_type, dst_type}}, [&]() -> core::ir::Function* {
                 TINT_ASSERT(dst_type->Is<core::type::Vector>());
 
                 // Generate a helper function that performs the following (in HLSL):