[ir] Create a real logical not operator.
Currently logical not is turned into a `a == false` comparison. This Cl
creates a real unary not operator and updates the backends to use it as
needed.
Change-Id: If91c5cfbece7d48c18d4b462c83afd863b0b379a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/194162
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 9508335..defcab7 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -927,12 +927,8 @@
/// @param val the value
/// @returns the operation
template <typename VAL>
- ir::CoreBinary* Not(const core::type::Type* type, VAL&& val) {
- if (auto* vec = type->As<core::type::Vector>()) {
- return Equal(type, std::forward<VAL>(val), Splat(vec, false));
- } else {
- return Equal(type, std::forward<VAL>(val), Constant(false));
- }
+ ir::CoreUnary* Not(const core::type::Type* type, VAL&& val) {
+ return Unary(UnaryOp::kNot, type, std::forward<VAL>(val));
}
/// Creates a Not operation
@@ -940,7 +936,7 @@
/// @param val the value
/// @returns the operation
template <typename TYPE, typename VAL>
- ir::CoreBinary* Not(VAL&& val) {
+ ir::CoreUnary* Not(VAL&& val) {
auto* type = ir.Types().Get<TYPE>();
return Not(type, std::forward<VAL>(val));
}
diff --git a/src/tint/lang/core/ir/core_binary_test.cc b/src/tint/lang/core/ir/core_binary_test.cc
index 6ce32d7..4eac99c 100644
--- a/src/tint/lang/core/ir/core_binary_test.cc
+++ b/src/tint/lang/core/ir/core_binary_test.cc
@@ -211,23 +211,6 @@
EXPECT_EQ(2_i, rhs->As<core::constant::Scalar<i32>>()->ValueAs<i32>());
}
-TEST_F(IR_BinaryTest, CreateNot) {
- auto* inst = b.Not(mod.Types().bool_(), true);
-
- ASSERT_TRUE(inst->Is<Binary>());
- EXPECT_EQ(inst->Op(), BinaryOp::kEqual);
-
- ASSERT_TRUE(inst->LHS()->Is<Constant>());
- auto lhs = inst->LHS()->As<Constant>()->Value();
- ASSERT_TRUE(lhs->Is<core::constant::Scalar<bool>>());
- EXPECT_TRUE(lhs->As<core::constant::Scalar<bool>>()->ValueAs<bool>());
-
- ASSERT_TRUE(inst->RHS()->Is<Constant>());
- auto rhs = inst->RHS()->As<Constant>()->Value();
- ASSERT_TRUE(rhs->Is<core::constant::Scalar<bool>>());
- EXPECT_FALSE(rhs->As<core::constant::Scalar<bool>>()->ValueAs<bool>());
-}
-
TEST_F(IR_BinaryTest, CreateShiftLeft) {
auto* inst = b.ShiftLeft(mod.Types().i32(), 4_i, 2_i);
diff --git a/src/tint/lang/core/ir/core_unary_test.cc b/src/tint/lang/core/ir/core_unary_test.cc
index 26054a0..67a68f0 100644
--- a/src/tint/lang/core/ir/core_unary_test.cc
+++ b/src/tint/lang/core/ir/core_unary_test.cc
@@ -63,6 +63,18 @@
EXPECT_EQ(4_i, lhs->As<core::constant::Scalar<i32>>()->ValueAs<i32>());
}
+TEST_F(IR_UnaryTest, CreateNot) {
+ auto* inst = b.Not(mod.Types().bool_(), true);
+
+ ASSERT_TRUE(inst->Is<Unary>());
+ EXPECT_EQ(inst->Op(), UnaryOp::kNot);
+
+ ASSERT_TRUE(inst->Val()->Is<Constant>());
+ auto lhs = inst->Val()->As<Constant>()->Value();
+ ASSERT_TRUE(lhs->Is<core::constant::Scalar<bool>>());
+ EXPECT_EQ(true, lhs->As<core::constant::Scalar<bool>>()->ValueAs<bool>());
+}
+
TEST_F(IR_UnaryTest, Usage) {
auto* inst = b.Negation(mod.Types().i32(), 4_i);
diff --git a/src/tint/lang/core/ir/transform/demote_to_helper.cc b/src/tint/lang/core/ir/transform/demote_to_helper.cc
index a460c41..eb7835f 100644
--- a/src/tint/lang/core/ir/transform/demote_to_helper.cc
+++ b/src/tint/lang/core/ir/transform/demote_to_helper.cc
@@ -194,7 +194,7 @@
if (ret->Func()->Stage() == Function::PipelineStage::kFragment) {
b.InsertBefore(ret, [&] {
auto* cond = b.Load(continue_execution);
- auto* ifelse = b.If(b.Equal(ty.bool_(), cond, false));
+ auto* ifelse = b.If(b.Not<bool>(cond));
b.Append(ifelse->True(), [&] { //
b.TerminateInvocation();
});
diff --git a/src/tint/lang/core/ir/transform/demote_to_helper_test.cc b/src/tint/lang/core/ir/transform/demote_to_helper_test.cc
index db15893..057c9d0 100644
--- a/src/tint/lang/core/ir/transform/demote_to_helper_test.cc
+++ b/src/tint/lang/core/ir/transform/demote_to_helper_test.cc
@@ -139,7 +139,7 @@
}
}
%6:bool = load %continue_execution
- %7:bool = eq %6, false
+ %7:bool = not %6
if %7 [t: $B5] { # if_3
$B5: { # true
terminate_invocation
@@ -236,7 +236,7 @@
}
%7:void = call %foo
%8:bool = load %continue_execution
- %9:bool = eq %8, false
+ %9:bool = not %8
if %9 [t: $B6] { # if_3
$B6: { # true
terminate_invocation
@@ -335,7 +335,7 @@
}
}
%9:bool = load %continue_execution
- %10:bool = eq %9, false
+ %10:bool = not %9
if %10 [t: $B6] { # if_3
$B6: { # true
terminate_invocation
@@ -434,7 +434,7 @@
$B5: {
%8:void = call %foo, %front_facing
%9:bool = load %continue_execution
- %10:bool = eq %9, false
+ %10:bool = not %9
if %10 [t: $B6] { # if_3
$B6: { # true
terminate_invocation
@@ -510,7 +510,7 @@
store %priv, 42i
store %func, 42i
%6:bool = load %continue_execution
- %7:bool = eq %6, false
+ %7:bool = not %6
if %7 [t: $B4] { # if_2
$B4: { # true
terminate_invocation
@@ -598,7 +598,7 @@
}
}
%9:bool = load %continue_execution
- %10:bool = eq %9, false
+ %10:bool = not %9
if %10 [t: $B5] { # if_3
$B5: { # true
terminate_invocation
@@ -677,7 +677,7 @@
}
}
%7:bool = load %continue_execution
- %8:bool = eq %7, false
+ %8:bool = not %7
if %8 [t: $B5] { # if_3
$B5: { # true
terminate_invocation
@@ -760,7 +760,7 @@
}
%8:i32 = add %6, 1i
%9:bool = load %continue_execution
- %10:bool = eq %9, false
+ %10:bool = not %9
if %10 [t: $B5] { # if_3
$B5: { # true
terminate_invocation
@@ -857,7 +857,7 @@
%8:i32 = access %6, 0i
%9:i32 = add %8, 1i
%10:bool = load %continue_execution
- %11:bool = eq %10, false
+ %11:bool = not %10
if %11 [t: $B5] { # if_3
$B5: { # true
terminate_invocation
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 43f956f..89adda9 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -447,6 +447,9 @@
case core::UnaryOp::kComplement:
out << "~";
break;
+ case core::UnaryOp::kNot:
+ out << "!";
+ break;
default:
TINT_UNIMPLEMENTED() << u->Op();
}
@@ -458,18 +461,6 @@
/// Emit a binary instruction
/// @param b the binary instruction
void EmitBinary(StringStream& out, const core::ir::CoreBinary* b) {
- if (b->Op() == core::BinaryOp::kEqual) {
- auto* rhs = b->RHS()->As<core::ir::Constant>();
- if (rhs && rhs->Type()->Is<core::type::Bool>() &&
- rhs->Value()->ValueAs<bool>() == false) {
- // expr == false
- out << "!(";
- EmitValue(out, b->LHS());
- out << ")";
- return;
- }
- }
-
auto kind = [&] {
switch (b->Op()) {
case core::BinaryOp::kAdd:
diff --git a/src/tint/lang/spirv/writer/discard_test.cc b/src/tint/lang/spirv/writer/discard_test.cc
index 843ffa6..b9755e3 100644
--- a/src/tint/lang/spirv/writer/discard_test.cc
+++ b/src/tint/lang/spirv/writer/discard_test.cc
@@ -74,7 +74,7 @@
OpBranch %26
%26 = OpLabel
%29 = OpLoad %bool %continue_execution
- %30 = OpLogicalEqual %bool %29 %false
+ %30 = OpLogicalNot %bool %29
OpSelectionMerge %31 None
OpBranchConditional %30 %32 %31
%32 = OpLabel
@@ -130,7 +130,7 @@
%26 = OpLabel
%32 = OpPhi %int %29 %27 %33 %28
%34 = OpLoad %bool %continue_execution
- %35 = OpLogicalEqual %bool %34 %false
+ %35 = OpLogicalNot %bool %34
OpSelectionMerge %36 None
OpBranchConditional %35 %37 %36
%37 = OpLabel
diff --git a/src/tint/lang/spirv/writer/loop_test.cc b/src/tint/lang/spirv/writer/loop_test.cc
index 1631549..a02873e 100644
--- a/src/tint/lang/spirv/writer/loop_test.cc
+++ b/src/tint/lang/spirv/writer/loop_test.cc
@@ -432,7 +432,7 @@
%13 = OpPhi %int %18 %6
%19 = OpPhi %bool %15 %6
%20 = OpSGreaterThan %bool %13 %int_5
- %17 = OpLogicalEqual %bool %19 %false
+ %17 = OpLogicalNot %bool %19
OpBranchConditional %20 %9 %8
%9 = OpLabel
OpReturn
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 8a7ccc1..2a3df4d 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1983,6 +1983,9 @@
op = spv::Op::OpSNegate;
}
break;
+ case core::UnaryOp::kNot:
+ op = spv::Op::OpLogicalNot;
+ break;
default:
TINT_UNIMPLEMENTED() << unary->Op();
}
diff --git a/src/tint/lang/wgsl/reader/program_to_ir/unary_test.cc b/src/tint/lang/wgsl/reader/program_to_ir/unary_test.cc
index 4c8fe31..87ecb03 100644
--- a/src/tint/lang/wgsl/reader/program_to_ir/unary_test.cc
+++ b/src/tint/lang/wgsl/reader/program_to_ir/unary_test.cc
@@ -51,7 +51,7 @@
%test_function = @compute @workgroup_size(1, 1, 1) func():void {
$B2: {
%3:bool = call %my_func
- %4:bool = eq %3, false
+ %4:bool = not %3
%tint_symbol:bool = let %4
ret
}
@@ -75,7 +75,7 @@
%test_function = @compute @workgroup_size(1, 1, 1) func():void {
$B2: {
%3:vec4<bool> = call %my_func
- %4:vec4<bool> = eq %3, vec4<bool>(false)
+ %4:vec4<bool> = not %3
%tint_symbol:vec4<bool> = let %4
ret
}
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
index e856dd4..6f0b757 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
@@ -680,14 +680,15 @@
case core::UnaryOp::kNegation:
expr = b.Negation(Expr(u->Val()));
break;
+ case core::UnaryOp::kNot:
+ expr = b.Not(Expr(u->Val()));
+ break;
case core::UnaryOp::kAddressOf:
expr = b.AddressOf(Expr(u->Val()));
break;
case core::UnaryOp::kIndirection:
expr = b.Deref(Expr(u->Val()));
break;
- default:
- TINT_UNIMPLEMENTED() << u->Op();
}
Bind(u->Result(0), expr);
}