Import Tint changes from Dawn
Changes:
- 5cf943e7972c52547ce3b2d581bd208ce2e18886 tint/hlsl: implement trunc in terms of floor/ceil to work... by Antonio Maiorano <amaiorano@google.com>
GitOrigin-RevId: 5cf943e7972c52547ce3b2d581bd208ce2e18886
Change-Id: I29d469dd0d74dd6fab71cb920bf99ad6a0937932
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/125580
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 4910f81..2c4279c 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -997,6 +997,9 @@
if (type == builtin::Function::kQuantizeToF16) {
return EmitQuantizeToF16Call(out, expr, builtin);
}
+ if (type == builtin::Function::kTrunc) {
+ return EmitTruncCall(out, expr, builtin);
+ }
if (builtin->IsDataPacking()) {
return EmitDataPackingCall(out, expr, builtin);
}
@@ -2116,6 +2119,20 @@
return true;
}
+bool GeneratorImpl::EmitTruncCall(utils::StringStream& out,
+ const ast::CallExpression* expr,
+ const sem::Builtin* builtin) {
+ // HLSL's trunc is broken for very large/small float values.
+ // See crbug.com/tint/1883
+ return CallBuiltinHelper( //
+ out, expr, builtin, [&](TextBuffer* b, const std::vector<std::string>& params) {
+ // value < 0 ? ceil(value) : floor(value)
+ line(b) << "return " << params[0] << " < 0 ? ceil(" << params[0] << ") : floor("
+ << params[0] << ");";
+ return true;
+ });
+}
+
bool GeneratorImpl::EmitDataPackingCall(utils::StringStream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin) {
@@ -2704,7 +2721,6 @@
case builtin::Function::kTan:
case builtin::Function::kTanh:
case builtin::Function::kTranspose:
- case builtin::Function::kTrunc:
return builtin->str();
case builtin::Function::kCountOneBits: // uint
return "countbits";
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index f7adc9f..103e355 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -281,6 +281,14 @@
bool EmitQuantizeToF16Call(utils::StringStream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin);
+ /// Handles generating a call to the `trunc()` intrinsic
+ /// @param out the output of the expression stream
+ /// @param expr the call expression
+ /// @param builtin the semantic information for the builtin
+ /// @returns true if the call expression is emitted
+ bool EmitTruncCall(utils::StringStream& out,
+ const ast::CallExpression* expr,
+ const sem::Builtin* builtin);
/// Handles generating a call to DP4a builtins (dot4I8Packed and dot4U8Packed)
/// @param out the output of the expression stream
/// @param expr the call expression
diff --git a/src/tint/writer/hlsl/generator_impl_builtin_test.cc b/src/tint/writer/hlsl/generator_impl_builtin_test.cc
index d552671..75fd463 100644
--- a/src/tint/writer/hlsl/generator_impl_builtin_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_builtin_test.cc
@@ -97,7 +97,6 @@
case builtin::Function::kSqrt:
case builtin::Function::kTan:
case builtin::Function::kTanh:
- case builtin::Function::kTrunc:
if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else {
@@ -309,8 +308,6 @@
BuiltinData{builtin::Function::kTan, CallParamType::kF16, "tan"},
BuiltinData{builtin::Function::kTanh, CallParamType::kF32, "tanh"},
BuiltinData{builtin::Function::kTanh, CallParamType::kF16, "tanh"},
- BuiltinData{builtin::Function::kTrunc, CallParamType::kF32, "trunc"},
- BuiltinData{builtin::Function::kTrunc, CallParamType::kF16, "trunc"},
/* Integer built-in */
BuiltinData{builtin::Function::kAbs, CallParamType::kU32, "abs"},
BuiltinData{builtin::Function::kClamp, CallParamType::kU32, "clamp"},
@@ -1089,6 +1086,94 @@
)");
}
+TEST_F(HlslGeneratorImplTest_Builtin, Trunc_Scalar_f32) {
+ auto* val = Var("val", ty.f32());
+ auto* call = Call("trunc", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"(float tint_trunc(float param_0) {
+ return param_0 < 0 ? ceil(param_0) : floor(param_0);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+ float val = 0.0f;
+ const float tint_symbol = tint_trunc(val);
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Trunc_Vector_f32) {
+ auto* val = Var("val", ty.vec3<f32>());
+ auto* call = Call("trunc", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"(float3 tint_trunc(float3 param_0) {
+ return param_0 < 0 ? ceil(param_0) : floor(param_0);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+ float3 val = float3(0.0f, 0.0f, 0.0f);
+ const float3 tint_symbol = tint_trunc(val);
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Trunc_Scalar_f16) {
+ Enable(builtin::Extension::kF16);
+
+ auto* val = Var("val", ty.f16());
+ auto* call = Call("trunc", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"(float16_t tint_trunc(float16_t param_0) {
+ return param_0 < 0 ? ceil(param_0) : floor(param_0);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+ float16_t val = float16_t(0.0h);
+ const float16_t tint_symbol = tint_trunc(val);
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Trunc_Vector_f16) {
+ Enable(builtin::Extension::kF16);
+
+ auto* val = Var("val", ty.vec3<f16>());
+ auto* call = Call("trunc", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"(vector<float16_t, 3> tint_trunc(vector<float16_t, 3> param_0) {
+ return param_0 < 0 ? ceil(param_0) : floor(param_0);
+}
+
+[numthreads(1, 1, 1)]
+void test_function() {
+ vector<float16_t, 3> val = vector<float16_t, 3>(float16_t(0.0h), float16_t(0.0h), float16_t(0.0h));
+ const vector<float16_t, 3> tint_symbol = tint_trunc(val);
+ return;
+}
+)");
+}
+
TEST_F(HlslGeneratorImplTest_Builtin, Pack4x8Snorm) {
auto* call = Call("pack4x8snorm", "p1");
GlobalVar("p1", ty.vec4<f32>(), builtin::AddressSpace::kPrivate);
diff --git a/src/tint/writer/hlsl/generator_impl_import_test.cc b/src/tint/writer/hlsl/generator_impl_import_test.cc
index 133774a..3fbad50 100644
--- a/src/tint/writer/hlsl/generator_impl_import_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_import_test.cc
@@ -66,8 +66,7 @@
HlslImportData{"sinh", "sinh"},
HlslImportData{"sqrt", "sqrt"},
HlslImportData{"tan", "tan"},
- HlslImportData{"tanh", "tanh"},
- HlslImportData{"trunc", "trunc"}));
+ HlslImportData{"tanh", "tanh"}));
using HlslImportData_SingleIntParamTest = TestParamHelper<HlslImportData>;
TEST_P(HlslImportData_SingleIntParamTest, IntScalar) {
@@ -125,8 +124,7 @@
HlslImportData{"sinh", "sinh"},
HlslImportData{"sqrt", "sqrt"},
HlslImportData{"tan", "tan"},
- HlslImportData{"tanh", "tanh"},
- HlslImportData{"trunc", "trunc"}));
+ HlslImportData{"tanh", "tanh"}));
using HlslImportData_DualParam_ScalarTest = TestParamHelper<HlslImportData>;
TEST_P(HlslImportData_DualParam_ScalarTest, Float) {