spirv-writer: fix bool equality, inequality

Fixed: tint:743
Change-Id: I03b5d50d2bf3cd17b672401f1922bde35cbf2640
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52740
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 7f51152..bb45f2d 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1885,6 +1885,8 @@
   }
 
   bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
+  bool lhs_is_bool_or_vec = lhs_type->is_bool_scalar_or_vector();
+  bool lhs_is_integer_or_vec = lhs_type->is_integer_scalar_or_vector();
   bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
 
   spv::Op op = spv::Op::OpNop;
@@ -1901,7 +1903,16 @@
       op = spv::Op::OpSDiv;
     }
   } else if (expr->IsEqual()) {
-    op = lhs_is_float_or_vec ? spv::Op::OpFOrdEqual : spv::Op::OpIEqual;
+    if (lhs_is_float_or_vec) {
+      op = spv::Op::OpFOrdEqual;
+    } else if (lhs_is_bool_or_vec) {
+      op = spv::Op::OpLogicalEqual;
+    } else if (lhs_is_integer_or_vec) {
+      op = spv::Op::OpIEqual;
+    } else {
+      error_ = "invalid equal expression";
+      return 0;
+    }
   } else if (expr->IsGreaterThan()) {
     if (lhs_is_float_or_vec) {
       op = spv::Op::OpFOrdGreaterThan;
@@ -1983,7 +1994,16 @@
       return 0;
     }
   } else if (expr->IsNotEqual()) {
-    op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual;
+    if (lhs_is_float_or_vec) {
+      op = spv::Op::OpFOrdNotEqual;
+    } else if (lhs_is_bool_or_vec) {
+      op = spv::Op::OpLogicalNotEqual;
+    } else if (lhs_is_integer_or_vec) {
+      op = spv::Op::OpINotEqual;
+    } else {
+      error_ = "invalid not-equal expression";
+      return 0;
+    }
   } else if (expr->IsOr()) {
     op = spv::Op::OpBitwiseOr;
   } else if (expr->IsShiftLeft()) {
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index 3b8391b..f18f460 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -250,6 +250,61 @@
                     BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
                     BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
 
+using BinaryCompareBoolTest = TestParamHelper<BinaryData>;
+TEST_P(BinaryCompareBoolTest, Scalar) {
+  auto param = GetParam();
+
+  auto* lhs = Expr(true);
+  auto* rhs = Expr(false);
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+
+  b.push_function(Function{});
+
+  EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
+%2 = OpConstantTrue %1
+%3 = OpConstantFalse %1
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            "%4 = " + param.name + " %1 %2 %3\n");
+}
+
+TEST_P(BinaryCompareBoolTest, Vector) {
+  auto param = GetParam();
+
+  auto* lhs = vec3<bool>(false, true, false);
+  auto* rhs = vec3<bool>(true, false, true);
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+
+  b.push_function(Function{});
+
+  EXPECT_EQ(b.GenerateBinaryExpression(expr), 7u) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
+%1 = OpTypeVector %2 3
+%3 = OpConstantFalse %2
+%4 = OpConstantTrue %2
+%5 = OpConstantComposite %1 %3 %4 %3
+%6 = OpConstantComposite %1 %4 %3 %4
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            "%7 = " + param.name + " %1 %5 %6\n");
+}
+INSTANTIATE_TEST_SUITE_P(
+    BuilderTest,
+    BinaryCompareBoolTest,
+    testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpLogicalEqual"},
+                    BinaryData{ast::BinaryOp::kNotEqual, "OpLogicalNotEqual"}));
+
 using BinaryCompareUnsignedIntegerTest = TestParamHelper<BinaryData>;
 TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) {
   auto param = GetParam();