[ir][spirv-writer] Emit comparison instructions
Bug: tint:1906
Change-Id: I82a9e70c3b20999e991865e2ef00113c63db84c4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134321
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index f1e3410..b9ce817 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -412,6 +412,7 @@
uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
auto id = module_.NextId();
+ auto* lhs_ty = binary->LHS()->Type();
// Determine the opcode.
spv::Op op = spv::Op::Max;
@@ -424,6 +425,68 @@
op = binary->Type()->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub;
break;
}
+
+ case ir::Binary::Kind::kEqual: {
+ if (lhs_ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalEqual;
+ } else if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdEqual;
+ } else if (lhs_ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpIEqual;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kNotEqual: {
+ if (lhs_ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalNotEqual;
+ } else if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdNotEqual;
+ } else if (lhs_ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpINotEqual;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kGreaterThan: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdGreaterThan;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSGreaterThan;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUGreaterThan;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kGreaterThanEqual: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdGreaterThanEqual;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSGreaterThanEqual;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUGreaterThanEqual;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kLessThan: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdLessThan;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSLessThan;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpULessThan;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kLessThanEqual: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdLessThanEqual;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSLessThanEqual;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpULessThanEqual;
+ }
+ break;
+ }
+
default: {
TINT_ICE(Writer, diagnostics_)
<< "unimplemented binary instruction: " << static_cast<uint32_t>(binary->Kind());
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 264f619..a139ad1 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -14,11 +14,109 @@
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+#include "gmock/gmock.h"
+#include "src/tint/ir/binary.h"
+
using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::spirv {
namespace {
+/// The element type of a test.
+enum Type {
+ kBool,
+ kI32,
+ kU32,
+ kF32,
+ kF16,
+};
+
+/// A parameterized test case.
+struct BinaryTestCase {
+ /// The element type to test.
+ Type type;
+ /// The binary operation.
+ enum ir::Binary::Kind kind;
+ /// The expected SPIR-V instruction.
+ std::string spirv_inst;
+};
+
+/// A helper class for parameterized binary instruction tests.
+class BinaryInstructionTest : public SpvGeneratorImplTestWithParam<BinaryTestCase> {
+ protected:
+ /// Helper to make a scalar type corresponding to the element type `ty`.
+ /// @param ty the element type
+ /// @returns the scalar type
+ const type::Type* MakeScalarType(Type ty) {
+ switch (ty) {
+ case kBool:
+ return mod.Types().bool_();
+ case kI32:
+ return mod.Types().i32();
+ case kU32:
+ return mod.Types().u32();
+ case kF32:
+ return mod.Types().f32();
+ case kF16:
+ return mod.Types().f16();
+ }
+ return nullptr;
+ }
+
+ /// Helper to make a vector type corresponding to the element type `ty`.
+ /// @param ty the element type
+ /// @returns the vector type
+ const type::Type* MakeVectorType(Type ty) { return mod.Types().vec2(MakeScalarType(ty)); }
+
+ /// Helper to make a scalar value with the scalar type `ty`.
+ /// @param ty the element type
+ /// @returns the scalar value
+ ir::Value* MakeScalarValue(Type ty) {
+ switch (ty) {
+ case kBool:
+ return b.Constant(true);
+ case kI32:
+ return b.Constant(1_i);
+ case kU32:
+ return b.Constant(1_u);
+ case kF32:
+ return b.Constant(1_f);
+ case kF16:
+ return b.Constant(1_h);
+ }
+ return nullptr;
+ }
+
+ /// Helper to make a vector value with an element type of `ty`.
+ /// @param ty the element type
+ /// @returns the vector value
+ ir::Value* MakeVectorValue(Type ty) {
+ switch (ty) {
+ case kBool:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(true),
+ b.ir.constant_values.Get(false)}));
+ case kI32:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_i),
+ b.ir.constant_values.Get(-10_i)}));
+ case kU32:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty),
+ utils::Vector{b.ir.constant_values.Get(42_u), b.ir.constant_values.Get(10_u)}));
+ case kF32:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_f),
+ b.ir.constant_values.Get(-0.5_f)}));
+ case kF16:
+ return b.Constant(b.ir.constant_values.Composite(
+ MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_h),
+ b.ir.constant_values.Get(-0.5_h)}));
+ }
+ return nullptr;
+ }
+};
+
TEST_F(SpvGeneratorImplTest, Binary_Add_I32) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{
@@ -210,6 +308,78 @@
)");
}
+using Comparison = BinaryInstructionTest;
+TEST_P(Comparison, Scalar) {
+ auto params = GetParam();
+
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+ func->StartTarget()->SetInstructions(
+ utils::Vector{b.CreateBinary(params.kind, mod.Types().bool_(), MakeScalarValue(params.type),
+ MakeScalarValue(params.type)),
+ b.Branch(func->EndTarget())});
+
+ generator_.EmitFunction(func);
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+}
+TEST_P(Comparison, Vector) {
+ auto params = GetParam();
+
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+ func->StartTarget()->SetInstructions(
+ utils::Vector{b.CreateBinary(params.kind, mod.Types().vec2(mod.Types().bool_()),
+ MakeVectorValue(params.type), MakeVectorValue(params.type)),
+
+ b.Branch(func->EndTarget())});
+
+ generator_.EmitFunction(func);
+ EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_I32,
+ Comparison,
+ testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kEqual, "OpIEqual"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThan, "OpSGreaterThan"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThanEqual,
+ "OpSGreaterThanEqual"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kLessThan, "OpSLessThan"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kLessThanEqual, "OpSLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_U32,
+ Comparison,
+ testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kEqual, "OpIEqual"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThan, "OpUGreaterThan"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThanEqual,
+ "OpUGreaterThanEqual"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kLessThan, "OpULessThan"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kLessThanEqual, "OpULessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_F32,
+ Comparison,
+ testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThanEqual,
+ "OpFOrdGreaterThanEqual"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_F16,
+ Comparison,
+ testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThanEqual,
+ "OpFOrdGreaterThanEqual"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_Bool,
+ Comparison,
+ testing::Values(BinaryTestCase{kBool, ir::Binary::Kind::kEqual, "OpLogicalEqual"},
+ BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual, "OpLogicalNotEqual"}));
+
TEST_F(SpvGeneratorImplTest, Binary_Chain) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
auto* a = b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i));