[ir][spirv-writer] Emit comparison instructions

Bug: tint:1906
Change-Id: I82a9e70c3b20999e991865e2ef00113c63db84c4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134321
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index f1e3410..b9ce817 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -412,6 +412,7 @@
 
 uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
     auto id = module_.NextId();
+    auto* lhs_ty = binary->LHS()->Type();
 
     // Determine the opcode.
     spv::Op op = spv::Op::Max;
@@ -424,6 +425,68 @@
             op = binary->Type()->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub;
             break;
         }
+
+        case ir::Binary::Kind::kEqual: {
+            if (lhs_ty->is_bool_scalar_or_vector()) {
+                op = spv::Op::OpLogicalEqual;
+            } else if (lhs_ty->is_float_scalar_or_vector()) {
+                op = spv::Op::OpFOrdEqual;
+            } else if (lhs_ty->is_integer_scalar_or_vector()) {
+                op = spv::Op::OpIEqual;
+            }
+            break;
+        }
+        case ir::Binary::Kind::kNotEqual: {
+            if (lhs_ty->is_bool_scalar_or_vector()) {
+                op = spv::Op::OpLogicalNotEqual;
+            } else if (lhs_ty->is_float_scalar_or_vector()) {
+                op = spv::Op::OpFOrdNotEqual;
+            } else if (lhs_ty->is_integer_scalar_or_vector()) {
+                op = spv::Op::OpINotEqual;
+            }
+            break;
+        }
+        case ir::Binary::Kind::kGreaterThan: {
+            if (lhs_ty->is_float_scalar_or_vector()) {
+                op = spv::Op::OpFOrdGreaterThan;
+            } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+                op = spv::Op::OpSGreaterThan;
+            } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+                op = spv::Op::OpUGreaterThan;
+            }
+            break;
+        }
+        case ir::Binary::Kind::kGreaterThanEqual: {
+            if (lhs_ty->is_float_scalar_or_vector()) {
+                op = spv::Op::OpFOrdGreaterThanEqual;
+            } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+                op = spv::Op::OpSGreaterThanEqual;
+            } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+                op = spv::Op::OpUGreaterThanEqual;
+            }
+            break;
+        }
+        case ir::Binary::Kind::kLessThan: {
+            if (lhs_ty->is_float_scalar_or_vector()) {
+                op = spv::Op::OpFOrdLessThan;
+            } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+                op = spv::Op::OpSLessThan;
+            } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+                op = spv::Op::OpULessThan;
+            }
+            break;
+        }
+        case ir::Binary::Kind::kLessThanEqual: {
+            if (lhs_ty->is_float_scalar_or_vector()) {
+                op = spv::Op::OpFOrdLessThanEqual;
+            } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+                op = spv::Op::OpSLessThanEqual;
+            } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+                op = spv::Op::OpULessThanEqual;
+            }
+            break;
+        }
+
         default: {
             TINT_ICE(Writer, diagnostics_)
                 << "unimplemented binary instruction: " << static_cast<uint32_t>(binary->Kind());
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 264f619..a139ad1 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -14,11 +14,109 @@
 
 #include "src/tint/writer/spirv/ir/test_helper_ir.h"
 
+#include "gmock/gmock.h"
+#include "src/tint/ir/binary.h"
+
 using namespace tint::number_suffixes;  // NOLINT
 
 namespace tint::writer::spirv {
 namespace {
 
+/// The element type of a test.
+enum Type {
+    kBool,
+    kI32,
+    kU32,
+    kF32,
+    kF16,
+};
+
+/// A parameterized test case.
+struct BinaryTestCase {
+    /// The element type to test.
+    Type type;
+    /// The binary operation.
+    enum ir::Binary::Kind kind;
+    /// The expected SPIR-V instruction.
+    std::string spirv_inst;
+};
+
+/// A helper class for parameterized binary instruction tests.
+class BinaryInstructionTest : public SpvGeneratorImplTestWithParam<BinaryTestCase> {
+  protected:
+    /// Helper to make a scalar type corresponding to the element type `ty`.
+    /// @param ty the element type
+    /// @returns the scalar type
+    const type::Type* MakeScalarType(Type ty) {
+        switch (ty) {
+            case kBool:
+                return mod.Types().bool_();
+            case kI32:
+                return mod.Types().i32();
+            case kU32:
+                return mod.Types().u32();
+            case kF32:
+                return mod.Types().f32();
+            case kF16:
+                return mod.Types().f16();
+        }
+        return nullptr;
+    }
+
+    /// Helper to make a vector type corresponding to the element type `ty`.
+    /// @param ty the element type
+    /// @returns the vector type
+    const type::Type* MakeVectorType(Type ty) { return mod.Types().vec2(MakeScalarType(ty)); }
+
+    /// Helper to make a scalar value with the scalar type `ty`.
+    /// @param ty the element type
+    /// @returns the scalar value
+    ir::Value* MakeScalarValue(Type ty) {
+        switch (ty) {
+            case kBool:
+                return b.Constant(true);
+            case kI32:
+                return b.Constant(1_i);
+            case kU32:
+                return b.Constant(1_u);
+            case kF32:
+                return b.Constant(1_f);
+            case kF16:
+                return b.Constant(1_h);
+        }
+        return nullptr;
+    }
+
+    /// Helper to make a vector value with an element type of `ty`.
+    /// @param ty the element type
+    /// @returns the vector value
+    ir::Value* MakeVectorValue(Type ty) {
+        switch (ty) {
+            case kBool:
+                return b.Constant(b.ir.constant_values.Composite(
+                    MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(true),
+                                                      b.ir.constant_values.Get(false)}));
+            case kI32:
+                return b.Constant(b.ir.constant_values.Composite(
+                    MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_i),
+                                                      b.ir.constant_values.Get(-10_i)}));
+            case kU32:
+                return b.Constant(b.ir.constant_values.Composite(
+                    MakeVectorType(ty),
+                    utils::Vector{b.ir.constant_values.Get(42_u), b.ir.constant_values.Get(10_u)}));
+            case kF32:
+                return b.Constant(b.ir.constant_values.Composite(
+                    MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_f),
+                                                      b.ir.constant_values.Get(-0.5_f)}));
+            case kF16:
+                return b.Constant(b.ir.constant_values.Composite(
+                    MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_h),
+                                                      b.ir.constant_values.Get(-0.5_h)}));
+        }
+        return nullptr;
+    }
+};
+
 TEST_F(SpvGeneratorImplTest, Binary_Add_I32) {
     auto* func = b.CreateFunction("foo", mod.Types().void_());
     func->StartTarget()->SetInstructions(utils::Vector{
@@ -210,6 +308,78 @@
 )");
 }
 
+using Comparison = BinaryInstructionTest;
+TEST_P(Comparison, Scalar) {
+    auto params = GetParam();
+
+    auto* func = b.CreateFunction("foo", mod.Types().void_());
+    func->StartTarget()->SetInstructions(
+        utils::Vector{b.CreateBinary(params.kind, mod.Types().bool_(), MakeScalarValue(params.type),
+                                     MakeScalarValue(params.type)),
+                      b.Branch(func->EndTarget())});
+
+    generator_.EmitFunction(func);
+    EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+}
+TEST_P(Comparison, Vector) {
+    auto params = GetParam();
+
+    auto* func = b.CreateFunction("foo", mod.Types().void_());
+    func->StartTarget()->SetInstructions(
+        utils::Vector{b.CreateBinary(params.kind, mod.Types().vec2(mod.Types().bool_()),
+                                     MakeVectorValue(params.type), MakeVectorValue(params.type)),
+
+                      b.Branch(func->EndTarget())});
+
+    generator_.EmitFunction(func);
+    EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+}
+INSTANTIATE_TEST_SUITE_P(
+    SpvGeneratorImplTest_Binary_I32,
+    Comparison,
+    testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kEqual, "OpIEqual"},
+                    BinaryTestCase{kI32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
+                    BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThan, "OpSGreaterThan"},
+                    BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThanEqual,
+                                   "OpSGreaterThanEqual"},
+                    BinaryTestCase{kI32, ir::Binary::Kind::kLessThan, "OpSLessThan"},
+                    BinaryTestCase{kI32, ir::Binary::Kind::kLessThanEqual, "OpSLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+    SpvGeneratorImplTest_Binary_U32,
+    Comparison,
+    testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kEqual, "OpIEqual"},
+                    BinaryTestCase{kU32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
+                    BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThan, "OpUGreaterThan"},
+                    BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThanEqual,
+                                   "OpUGreaterThanEqual"},
+                    BinaryTestCase{kU32, ir::Binary::Kind::kLessThan, "OpULessThan"},
+                    BinaryTestCase{kU32, ir::Binary::Kind::kLessThanEqual, "OpULessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+    SpvGeneratorImplTest_Binary_F32,
+    Comparison,
+    testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
+                    BinaryTestCase{kF32, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
+                    BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
+                    BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThanEqual,
+                                   "OpFOrdGreaterThanEqual"},
+                    BinaryTestCase{kF32, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
+                    BinaryTestCase{kF32, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+    SpvGeneratorImplTest_Binary_F16,
+    Comparison,
+    testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
+                    BinaryTestCase{kF16, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
+                    BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
+                    BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThanEqual,
+                                   "OpFOrdGreaterThanEqual"},
+                    BinaryTestCase{kF16, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
+                    BinaryTestCase{kF16, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
+INSTANTIATE_TEST_SUITE_P(
+    SpvGeneratorImplTest_Binary_Bool,
+    Comparison,
+    testing::Values(BinaryTestCase{kBool, ir::Binary::Kind::kEqual, "OpLogicalEqual"},
+                    BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual, "OpLogicalNotEqual"}));
+
 TEST_F(SpvGeneratorImplTest, Binary_Chain) {
     auto* func = b.CreateFunction("foo", mod.Types().void_());
     auto* a = b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i));