spirv-writer: support isNormal
Fixed: tint:158
Change-Id: Iabe7c1afe7dea87e62277bacb2086ee6d2964e78
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52460
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 26cacbe..fc6cbb5 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -32,6 +32,7 @@
#include "src/sem/sampled_texture_type.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
+#include "src/sem/vector_type.h"
#include "src/utils/get_or_create.h"
#include "src/writer/append_vector.h"
@@ -2251,6 +2252,83 @@
}
return 0;
}
+ case IntrinsicType::kIsNormal: {
+ // A normal number is finite, non-zero, and not subnormal.
+ // Its exponent is neither of the extreme possible values.
+ // Implemented as:
+ // exponent_bits = bitcast<u32>(f);
+ // clamped = uclamp(1,254,exponent_bits);
+ // result = (clamped == exponent_bits);
+ //
+ auto val_id = get_param_as_value_id(0);
+ if (!val_id) {
+ return 0;
+ }
+
+ // These parameters are valid for IEEE 754 binary32
+ const uint32_t kExponentMask = 0x7f80000;
+ const uint32_t kMinNormalExponent = 0x0080000;
+ const uint32_t kMaxNormalExponent = 0x7f00000;
+
+ auto set_id = GetGLSLstd450Import();
+ sem::U32 u32;
+ auto unsigned_id = GenerateTypeIfNeeded(&u32);
+ auto exponent_mask_id =
+ GenerateConstantIfNeeded(ScalarConstant::U32(kExponentMask));
+ auto min_exponent_id =
+ GenerateConstantIfNeeded(ScalarConstant::U32(kMinNormalExponent));
+ auto max_exponent_id =
+ GenerateConstantIfNeeded(ScalarConstant::U32(kMaxNormalExponent));
+ if (auto* fvec_ty = intrinsic->ReturnType()->As<sem::Vector>()) {
+ // In the vector case, update the unsigned type to a vector type of the
+ // same size, and create vector constants by replicating the scalars.
+ // I expect backend compilers to fold these into unique constants, so
+ // there is no loss of efficiency.
+ sem::Vector uvec_ty(&u32, fvec_ty->size());
+ unsigned_id = GenerateTypeIfNeeded(&uvec_ty);
+ auto splat = [&](uint32_t scalar_id) -> uint32_t {
+ auto splat_result = result_op();
+ OperandList splat_params{Operand::Int(unsigned_id), splat_result};
+ for (size_t i = 0; i < fvec_ty->size(); i++) {
+ splat_params.emplace_back(Operand::Int(scalar_id));
+ }
+ if (!push_function_inst(spv::Op::OpCompositeConstruct,
+ std::move(splat_params))) {
+ return 0;
+ }
+ return splat_result.to_i();
+ };
+ exponent_mask_id = splat(exponent_mask_id);
+ min_exponent_id = splat(min_exponent_id);
+ max_exponent_id = splat(max_exponent_id);
+ }
+ auto cast_result = result_op();
+ auto exponent_bits_result = result_op();
+ auto clamp_result = result_op();
+
+ if (set_id && unsigned_id && exponent_mask_id && min_exponent_id &&
+ max_exponent_id &&
+ push_function_inst(
+ spv::Op::OpBitcast,
+ {Operand::Int(unsigned_id), cast_result, Operand::Int(val_id)}) &&
+ push_function_inst(spv::Op::OpBitwiseAnd,
+ {Operand::Int(unsigned_id), exponent_bits_result,
+ Operand::Int(cast_result.to_i()),
+ Operand::Int(exponent_mask_id)}) &&
+ push_function_inst(
+ spv::Op::OpExtInst,
+ {Operand::Int(unsigned_id), clamp_result, Operand::Int(set_id),
+ Operand::Int(GLSLstd450UClamp),
+ Operand::Int(exponent_bits_result.to_i()),
+ Operand::Int(min_exponent_id), Operand::Int(max_exponent_id)}) &&
+ push_function_inst(spv::Op::OpIEqual,
+ {Operand::Int(result_type_id), result,
+ Operand::Int(exponent_bits_result.to_i()),
+ Operand::Int(clamp_result.to_i())})) {
+ return result_id;
+ }
+ return 0;
+ }
case IntrinsicType::kReverseBits:
op = spv::Op::OpBitReverse;
break;
diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc
index 002ee07..0435128 100644
--- a/src/writer/spirv/builder_intrinsic_test.cc
+++ b/src/writer/spirv/builder_intrinsic_test.cc
@@ -184,6 +184,97 @@
)");
}
+TEST_F(IntrinsicBuilderTest, IsNormal_Scalar) {
+ auto* var = Global("v", ty.f32(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call("isNormal", "v");
+ WrapInFunction(expr);
+
+ auto* func = Func("a_func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{}, ast::DecorationList{});
+
+ spirv::Builder& b = Build();
+
+ ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+ ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
+
+ EXPECT_EQ(b.GenerateCallExpression(expr), 9u) << b.error();
+ auto got = DumpBuilder(b);
+ EXPECT_EQ(got, R"(%12 = OpExtInstImport "GLSL.std.450"
+OpName %1 "v"
+OpName %7 "a_func"
+%3 = OpTypeFloat 32
+%2 = OpTypePointer Private %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Private %4
+%6 = OpTypeVoid
+%5 = OpTypeFunction %6
+%10 = OpTypeBool
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 133693440
+%15 = OpConstant %13 524288
+%16 = OpConstant %13 133169152
+%7 = OpFunction %6 None %5
+%8 = OpLabel
+%11 = OpLoad %3 %1
+%17 = OpBitcast %13 %11
+%18 = OpBitwiseAnd %13 %17 %14
+%19 = OpExtInst %13 %12 UClamp %18 %15 %16
+%9 = OpIEqual %10 %18 %19
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(IntrinsicBuilderTest, IsNormal_Vector) {
+ auto* var = Global("v", ty.vec2<f32>(), ast::StorageClass::kPrivate);
+
+ auto* expr = Call("isNormal", "v");
+ WrapInFunction(expr);
+
+ auto* func = Func("a_func", ast::VariableList{}, ty.void_(),
+ ast::StatementList{}, ast::DecorationList{});
+
+ spirv::Builder& b = Build();
+
+ ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+ ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
+
+ EXPECT_EQ(b.GenerateCallExpression(expr), 10u) << b.error();
+ auto got = DumpBuilder(b);
+ std::cout << got << std::endl;
+ EXPECT_EQ(got, R"(%14 = OpExtInstImport "GLSL.std.450"
+OpName %1 "v"
+OpName %8 "a_func"
+%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 2
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%7 = OpTypeVoid
+%6 = OpTypeFunction %7
+%12 = OpTypeBool
+%11 = OpTypeVector %12 2
+%15 = OpTypeInt 32 0
+%16 = OpConstant %15 133693440
+%17 = OpConstant %15 524288
+%18 = OpConstant %15 133169152
+%19 = OpTypeVector %15 2
+%8 = OpFunction %7 None %6
+%9 = OpLabel
+%13 = OpLoad %3 %1
+%20 = OpCompositeConstruct %19 %16 %16
+%21 = OpCompositeConstruct %19 %17 %17
+%22 = OpCompositeConstruct %19 %18 %18
+%23 = OpBitcast %19 %13
+%24 = OpBitwiseAnd %19 %23 %20
+%25 = OpExtInst %19 %14 UClamp %24 %21 %22
+%10 = OpIEqual %11 %24 %25
+OpReturn
+OpFunctionEnd
+)");
+}
+
using IntrinsicIntTest = IntrinsicBuilderTestWithParam<IntrinsicData>;
TEST_P(IntrinsicIntTest, Call_SInt_Scalar) {
auto param = GetParam();