[tint][ir] Use intrinsic table for binary ops.
Much like BuiltinCall, make ir::Binary abstract, with a derived class for each dialect.
Add ir::CoreUnary to hold the core-dialect binary operators.
Make the validator consult the dialect's intrinsic table.
Change-Id: I798cd0d2c3d417b95e959ef9551e5c2b1f2fb34d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/168226
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/BUILD.bazel b/src/tint/lang/core/ir/BUILD.bazel
index 07d5897..a02930e 100644
--- a/src/tint/lang/core/ir/BUILD.bazel
+++ b/src/tint/lang/core/ir/BUILD.bazel
@@ -54,6 +54,7 @@
"continue.cc",
"control_instruction.cc",
"convert.cc",
+ "core_binary.cc",
"core_builtin_call.cc",
"core_unary.cc",
"disassembler.cc",
@@ -105,6 +106,7 @@
"continue.h",
"control_instruction.h",
"convert.h",
+ "core_binary.h",
"core_builtin_call.h",
"core_unary.h",
"disassembler.h",
@@ -171,7 +173,6 @@
alwayslink = True,
srcs = [
"access_test.cc",
- "binary_test.cc",
"bitcast_test.cc",
"block_param_test.cc",
"block_test.cc",
@@ -180,6 +181,7 @@
"construct_test.cc",
"continue_test.cc",
"convert_test.cc",
+ "core_binary_test.cc",
"core_builtin_call_test.cc",
"core_unary_test.cc",
"discard_test.cc",
diff --git a/src/tint/lang/core/ir/BUILD.cmake b/src/tint/lang/core/ir/BUILD.cmake
index 33f118e..94365b4 100644
--- a/src/tint/lang/core/ir/BUILD.cmake
+++ b/src/tint/lang/core/ir/BUILD.cmake
@@ -72,6 +72,8 @@
lang/core/ir/control_instruction.h
lang/core/ir/convert.cc
lang/core/ir/convert.h
+ lang/core/ir/core_binary.cc
+ lang/core/ir/core_binary.h
lang/core/ir/core_builtin_call.cc
lang/core/ir/core_builtin_call.h
lang/core/ir/core_unary.cc
@@ -172,7 +174,6 @@
################################################################################
tint_add_target(tint_lang_core_ir_test test
lang/core/ir/access_test.cc
- lang/core/ir/binary_test.cc
lang/core/ir/bitcast_test.cc
lang/core/ir/block_param_test.cc
lang/core/ir/block_test.cc
@@ -181,6 +182,7 @@
lang/core/ir/construct_test.cc
lang/core/ir/continue_test.cc
lang/core/ir/convert_test.cc
+ lang/core/ir/core_binary_test.cc
lang/core/ir/core_builtin_call_test.cc
lang/core/ir/core_unary_test.cc
lang/core/ir/discard_test.cc
diff --git a/src/tint/lang/core/ir/BUILD.gn b/src/tint/lang/core/ir/BUILD.gn
index 2fbb5ab..ac59268 100644
--- a/src/tint/lang/core/ir/BUILD.gn
+++ b/src/tint/lang/core/ir/BUILD.gn
@@ -74,6 +74,8 @@
"control_instruction.h",
"convert.cc",
"convert.h",
+ "core_binary.cc",
+ "core_binary.h",
"core_builtin_call.cc",
"core_builtin_call.h",
"core_unary.cc",
@@ -171,7 +173,6 @@
tint_unittests_source_set("unittests") {
sources = [
"access_test.cc",
- "binary_test.cc",
"bitcast_test.cc",
"block_param_test.cc",
"block_test.cc",
@@ -180,6 +181,7 @@
"construct_test.cc",
"continue_test.cc",
"convert_test.cc",
+ "core_binary_test.cc",
"core_builtin_call_test.cc",
"core_unary_test.cc",
"discard_test.cc",
diff --git a/src/tint/lang/core/ir/binary.cc b/src/tint/lang/core/ir/binary.cc
index 732f029..2738b01 100644
--- a/src/tint/lang/core/ir/binary.cc
+++ b/src/tint/lang/core/ir/binary.cc
@@ -44,49 +44,4 @@
Binary::~Binary() = default;
-Binary* Binary::Clone(CloneContext& ctx) {
- auto* new_result = ctx.Clone(Result(0));
- auto* lhs = ctx.Remap(LHS());
- auto* rhs = ctx.Remap(RHS());
- return ctx.ir.instructions.Create<Binary>(new_result, op_, lhs, rhs);
-}
-
-std::string_view ToString(enum BinaryOp op) {
- switch (op) {
- case BinaryOp::kAdd:
- return "add";
- case BinaryOp::kSubtract:
- return "subtract";
- case BinaryOp::kMultiply:
- return "multiply";
- case BinaryOp::kDivide:
- return "divide";
- case BinaryOp::kModulo:
- return "modulo";
- case BinaryOp::kAnd:
- return "and";
- case BinaryOp::kOr:
- return "or";
- case BinaryOp::kXor:
- return "xor";
- case BinaryOp::kEqual:
- return "equal";
- case BinaryOp::kNotEqual:
- return "not equal";
- case BinaryOp::kLessThan:
- return "less than";
- case BinaryOp::kGreaterThan:
- return "greater than";
- case BinaryOp::kLessThanEqual:
- return "less than equal";
- case BinaryOp::kGreaterThanEqual:
- return "greater than equal";
- case BinaryOp::kShiftLeft:
- return "shift left";
- case BinaryOp::kShiftRight:
- return "shift right";
- }
- return "<unknown>";
-}
-
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/binary.h b/src/tint/lang/core/ir/binary.h
index 5f8c26a..378dae4 100644
--- a/src/tint/lang/core/ir/binary.h
+++ b/src/tint/lang/core/ir/binary.h
@@ -30,36 +30,18 @@
#include <string>
+#include "src/tint/lang/core/binary_op.h"
#include "src/tint/lang/core/ir/operand_instruction.h"
-#include "src/tint/utils/rtti/castable.h"
+
+// Forward declarations
+namespace tint::core::intrinsic {
+struct TableData;
+}
namespace tint::core::ir {
-/// A binary operator.
-enum class BinaryOp {
- kAdd,
- kSubtract,
- kMultiply,
- kDivide,
- kModulo,
-
- kAnd,
- kOr,
- kXor,
-
- kEqual,
- kNotEqual,
- kLessThan,
- kGreaterThan,
- kLessThanEqual,
- kGreaterThanEqual,
-
- kShiftLeft,
- kShiftRight
-};
-
-/// A binary instruction in the IR.
-class Binary final : public Castable<Binary, OperandInstruction<2, 1>> {
+/// The abstract base class for dialect-specific binary-op instructions in the IR.
+class Binary : public Castable<Binary, OperandInstruction<2, 1>> {
public:
/// The offset in Operands() for the LHS
static constexpr size_t kLhsOperandOffset = 0;
@@ -78,9 +60,6 @@
Binary(InstructionResult* result, BinaryOp op, Value* lhs, Value* rhs);
~Binary() override;
- /// @copydoc Instruction::Clone()
- Binary* Clone(CloneContext& ctx) override;
-
/// @returns the binary operator
BinaryOp Op() const { return op_; }
@@ -102,20 +81,13 @@
/// @returns the friendly name for the instruction
std::string FriendlyName() const override { return "binary"; }
+ /// @returns the table data to validate this builtin
+ virtual const core::intrinsic::TableData& TableData() const = 0;
+
private:
BinaryOp op_ = BinaryOp::kAdd;
};
-/// @param kind the enum value
-/// @returns the string for the given enum value
-std::string_view ToString(BinaryOp kind);
-
-/// Emits the name of the intrinsic type.
-template <typename STREAM, typename = traits::EnableIfIsOStream<STREAM>>
-auto& operator<<(STREAM& out, BinaryOp kind) {
- return out << ToString(kind);
-}
-
} // namespace tint::core::ir
#endif // SRC_TINT_LANG_CORE_IR_BINARY_H_
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index d5fd41f..8b3d75df 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -347,8 +347,8 @@
return mod_out_.instructions.Create<ir::Access>();
}
- ir::Binary* CreateInstructionBinary(const pb::InstructionBinary& binary_in) {
- auto* binary_out = mod_out_.instructions.Create<ir::Binary>();
+ ir::CoreBinary* CreateInstructionBinary(const pb::InstructionBinary& binary_in) {
+ auto* binary_out = mod_out_.instructions.Create<ir::CoreBinary>();
binary_out->SetOp(BinaryOp(binary_in.op()));
return binary_out;
}
@@ -913,44 +913,44 @@
}
}
- core::ir::BinaryOp BinaryOp(pb::BinaryOp in) {
+ core::BinaryOp BinaryOp(pb::BinaryOp in) {
switch (in) {
case pb::BinaryOp::add_:
- return core::ir::BinaryOp::kAdd;
+ return core::BinaryOp::kAdd;
case pb::BinaryOp::subtract:
- return core::ir::BinaryOp::kSubtract;
+ return core::BinaryOp::kSubtract;
case pb::BinaryOp::multiply:
- return core::ir::BinaryOp::kMultiply;
+ return core::BinaryOp::kMultiply;
case pb::BinaryOp::divide:
- return core::ir::BinaryOp::kDivide;
+ return core::BinaryOp::kDivide;
case pb::BinaryOp::modulo:
- return core::ir::BinaryOp::kModulo;
+ return core::BinaryOp::kModulo;
case pb::BinaryOp::and_:
- return core::ir::BinaryOp::kAnd;
+ return core::BinaryOp::kAnd;
case pb::BinaryOp::or_:
- return core::ir::BinaryOp::kOr;
+ return core::BinaryOp::kOr;
case pb::BinaryOp::xor_:
- return core::ir::BinaryOp::kXor;
+ return core::BinaryOp::kXor;
case pb::BinaryOp::equal:
- return core::ir::BinaryOp::kEqual;
+ return core::BinaryOp::kEqual;
case pb::BinaryOp::not_equal:
- return core::ir::BinaryOp::kNotEqual;
+ return core::BinaryOp::kNotEqual;
case pb::BinaryOp::less_than:
- return core::ir::BinaryOp::kLessThan;
+ return core::BinaryOp::kLessThan;
case pb::BinaryOp::greater_than:
- return core::ir::BinaryOp::kGreaterThan;
+ return core::BinaryOp::kGreaterThan;
case pb::BinaryOp::less_than_equal:
- return core::ir::BinaryOp::kLessThanEqual;
+ return core::BinaryOp::kLessThanEqual;
case pb::BinaryOp::greater_than_equal:
- return core::ir::BinaryOp::kGreaterThanEqual;
+ return core::BinaryOp::kGreaterThanEqual;
case pb::BinaryOp::shift_left:
- return core::ir::BinaryOp::kShiftLeft;
+ return core::BinaryOp::kShiftLeft;
case pb::BinaryOp::shift_right:
- return core::ir::BinaryOp::kShiftRight;
+ return core::BinaryOp::kShiftRight;
default:
TINT_ICE() << "invalid BinaryOp: " << in;
- return core::ir::BinaryOp::kAdd;
+ return core::BinaryOp::kAdd;
}
}
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index cb45846..6f57223 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -35,12 +35,12 @@
#include "src/tint/lang/core/constant/scalar.h"
#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/ir/access.h"
-#include "src/tint/lang/core/ir/binary.h"
#include "src/tint/lang/core/ir/bitcast.h"
#include "src/tint/lang/core/ir/break_if.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
#include "src/tint/lang/core/ir/convert.h"
+#include "src/tint/lang/core/ir/core_binary.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/core_unary.h"
#include "src/tint/lang/core/ir/discard.h"
@@ -198,12 +198,13 @@
tint::Switch(
inst_in, //
[&](const ir::Access* i) { InstructionAccess(*inst_out.mutable_access(), i); },
- [&](const ir::Binary* i) { InstructionBinary(*inst_out.mutable_binary(), i); },
[&](const ir::Bitcast* i) { InstructionBitcast(*inst_out.mutable_bitcast(), i); },
[&](const ir::BreakIf* i) { InstructionBreakIf(*inst_out.mutable_break_if(), i); },
+ [&](const ir::CoreBinary* i) { InstructionBinary(*inst_out.mutable_binary(), i); },
[&](const ir::CoreBuiltinCall* i) {
InstructionBuiltinCall(*inst_out.mutable_builtin_call(), i);
},
+ [&](const ir::CoreUnary* i) { InstructionUnary(*inst_out.mutable_unary(), i); },
[&](const ir::Construct* i) { InstructionConstruct(*inst_out.mutable_construct(), i); },
[&](const ir::Continue* i) { InstructionContinue(*inst_out.mutable_continue_(), i); },
[&](const ir::Convert* i) { InstructionConvert(*inst_out.mutable_convert(), i); },
@@ -230,7 +231,6 @@
},
[&](const ir::Switch* i) { InstructionSwitch(*inst_out.mutable_switch_(), i); },
[&](const ir::Swizzle* i) { InstructionSwizzle(*inst_out.mutable_swizzle(), i); },
- [&](const ir::CoreUnary* i) { InstructionUnary(*inst_out.mutable_unary(), i); },
[&](const ir::UserCall* i) { InstructionUserCall(*inst_out.mutable_user_call(), i); },
[&](const ir::Var* i) { InstructionVar(*inst_out.mutable_var(), i); },
[&](const ir::Unreachable* i) {
@@ -247,7 +247,7 @@
void InstructionAccess(pb::InstructionAccess&, const ir::Access*) {}
- void InstructionBinary(pb::InstructionBinary& binary_out, const ir::Binary* binary_in) {
+ void InstructionBinary(pb::InstructionBinary& binary_out, const ir::CoreBinary* binary_in) {
binary_out.set_op(BinaryOp(binary_in->Op()));
}
@@ -692,40 +692,44 @@
return pb::UnaryOp::complement;
}
- pb::BinaryOp BinaryOp(core::ir::BinaryOp in) {
+ pb::BinaryOp BinaryOp(core::BinaryOp in) {
switch (in) {
- case core::ir::BinaryOp::kAdd:
+ case core::BinaryOp::kAdd:
return pb::BinaryOp::add_;
- case core::ir::BinaryOp::kSubtract:
+ case core::BinaryOp::kSubtract:
return pb::BinaryOp::subtract;
- case core::ir::BinaryOp::kMultiply:
+ case core::BinaryOp::kMultiply:
return pb::BinaryOp::multiply;
- case core::ir::BinaryOp::kDivide:
+ case core::BinaryOp::kDivide:
return pb::BinaryOp::divide;
- case core::ir::BinaryOp::kModulo:
+ case core::BinaryOp::kModulo:
return pb::BinaryOp::modulo;
- case core::ir::BinaryOp::kAnd:
+ case core::BinaryOp::kAnd:
return pb::BinaryOp::and_;
- case core::ir::BinaryOp::kOr:
+ case core::BinaryOp::kOr:
return pb::BinaryOp::or_;
- case core::ir::BinaryOp::kXor:
+ case core::BinaryOp::kXor:
return pb::BinaryOp::xor_;
- case core::ir::BinaryOp::kEqual:
+ case core::BinaryOp::kEqual:
return pb::BinaryOp::equal;
- case core::ir::BinaryOp::kNotEqual:
+ case core::BinaryOp::kNotEqual:
return pb::BinaryOp::not_equal;
- case core::ir::BinaryOp::kLessThan:
+ case core::BinaryOp::kLessThan:
return pb::BinaryOp::less_than;
- case core::ir::BinaryOp::kGreaterThan:
+ case core::BinaryOp::kGreaterThan:
return pb::BinaryOp::greater_than;
- case core::ir::BinaryOp::kLessThanEqual:
+ case core::BinaryOp::kLessThanEqual:
return pb::BinaryOp::less_than_equal;
- case core::ir::BinaryOp::kGreaterThanEqual:
+ case core::BinaryOp::kGreaterThanEqual:
return pb::BinaryOp::greater_than_equal;
- case core::ir::BinaryOp::kShiftLeft:
+ case core::BinaryOp::kShiftLeft:
return pb::BinaryOp::shift_left;
- case core::ir::BinaryOp::kShiftRight:
+ case core::BinaryOp::kShiftRight:
return pb::BinaryOp::shift_right;
+ case core::BinaryOp::kLogicalAnd:
+ return pb::BinaryOp::logical_and;
+ case core::BinaryOp::kLogicalOr:
+ return pb::BinaryOp::logical_or;
}
TINT_ICE() << "invalid BinaryOp: " << in;
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index b964c92..20933ce 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -435,6 +435,8 @@
greater_than_equal = 13;
shift_left = 14;
shift_right = 15;
+ logical_and = 16;
+ logical_or = 17;
}
enum TextureDimension {
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 5e8f502..c906bb9 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -34,7 +34,6 @@
#include "src/tint/lang/core/constant/scalar.h"
#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/ir/access.h"
-#include "src/tint/lang/core/ir/binary.h"
#include "src/tint/lang/core/ir/bitcast.h"
#include "src/tint/lang/core/ir/block_param.h"
#include "src/tint/lang/core/ir/break_if.h"
@@ -42,6 +41,7 @@
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
#include "src/tint/lang/core/ir/convert.h"
+#include "src/tint/lang/core/ir/core_binary.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/core_unary.h"
#include "src/tint/lang/core/ir/discard.h"
@@ -472,12 +472,12 @@
/// @param rhs the right-hand-side of the operation
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Binary(BinaryOp op, const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Binary(BinaryOp op, const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
CheckForNonDeterministicEvaluation<LHS, RHS>();
auto* lhs_val = Value(std::forward<LHS>(lhs));
auto* rhs_val = Value(std::forward<RHS>(rhs));
return Append(
- ir.instructions.Create<ir::Binary>(InstructionResult(type), op, lhs_val, rhs_val));
+ ir.instructions.Create<ir::CoreBinary>(InstructionResult(type), op, lhs_val, rhs_val));
}
/// Creates an And operation
@@ -486,7 +486,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* And(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* And(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kAnd, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -496,7 +496,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* And(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* And(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return And(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -507,7 +507,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Or(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Or(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kOr, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -517,7 +517,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Or(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Or(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Or(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -528,7 +528,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Xor(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Xor(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kXor, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -538,7 +538,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Xor(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Xor(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Xor(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -549,7 +549,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Equal(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Equal(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kEqual, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -559,7 +559,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Equal(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Equal(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Equal(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -570,7 +570,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* NotEqual(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* NotEqual(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kNotEqual, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -580,7 +580,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* NotEqual(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* NotEqual(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return NotEqual(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -591,7 +591,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* LessThan(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* LessThan(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kLessThan, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -601,7 +601,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* LessThan(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* LessThan(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return LessThan(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -612,7 +612,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* GreaterThan(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* GreaterThan(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kGreaterThan, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -622,7 +622,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* GreaterThan(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* GreaterThan(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return GreaterThan(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -633,7 +633,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* LessThanEqual(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* LessThanEqual(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kLessThanEqual, type, std::forward<LHS>(lhs),
std::forward<RHS>(rhs));
}
@@ -644,7 +644,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* LessThanEqual(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* LessThanEqual(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return LessThanEqual(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -655,7 +655,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* GreaterThanEqual(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* GreaterThanEqual(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kGreaterThanEqual, type, std::forward<LHS>(lhs),
std::forward<RHS>(rhs));
}
@@ -666,7 +666,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* GreaterThanEqual(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* GreaterThanEqual(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return GreaterThanEqual(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -677,7 +677,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* ShiftLeft(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* ShiftLeft(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kShiftLeft, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -687,7 +687,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* ShiftLeft(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* ShiftLeft(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return ShiftLeft(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -698,7 +698,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* ShiftRight(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* ShiftRight(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kShiftRight, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -708,7 +708,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* ShiftRight(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* ShiftRight(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return ShiftRight(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -719,7 +719,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Add(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Add(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kAdd, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -729,7 +729,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Add(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Add(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Add(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -740,7 +740,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Subtract(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Subtract(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kSubtract, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -750,7 +750,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Subtract(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Subtract(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Subtract(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -761,7 +761,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Multiply(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Multiply(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kMultiply, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -771,7 +771,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Multiply(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Multiply(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Multiply(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -782,7 +782,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Divide(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Divide(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kDivide, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -792,7 +792,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Divide(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Divide(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Divide(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -803,7 +803,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename LHS, typename RHS>
- ir::Binary* Modulo(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Modulo(const core::type::Type* type, LHS&& lhs, RHS&& rhs) {
return Binary(BinaryOp::kModulo, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -813,7 +813,7 @@
/// @param rhs the rhs of the add
/// @returns the operation
template <typename TYPE, typename LHS, typename RHS>
- ir::Binary* Modulo(LHS&& lhs, RHS&& rhs) {
+ ir::CoreBinary* Modulo(LHS&& lhs, RHS&& rhs) {
auto* type = ir.Types().Get<TYPE>();
return Modulo(type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
}
@@ -883,7 +883,7 @@
/// @param val the value
/// @returns the operation
template <typename VAL>
- ir::Binary* Not(const core::type::Type* type, VAL&& val) {
+ ir::CoreBinary* Not(const core::type::Type* type, VAL&& val) {
if (auto* vec = type->As<core::type::Vector>()) {
return Equal(type, std::forward<VAL>(val), Splat(vec, false, vec->Width()));
} else {
@@ -896,7 +896,7 @@
/// @param val the value
/// @returns the operation
template <typename TYPE, typename VAL>
- ir::Binary* Not(VAL&& val) {
+ ir::CoreBinary* Not(VAL&& val) {
auto* type = ir.Types().Get<TYPE>();
return Not(type, std::forward<VAL>(val));
}
diff --git a/src/tint/lang/core/ir/core_binary.cc b/src/tint/lang/core/ir/core_binary.cc
new file mode 100644
index 0000000..0bab152
--- /dev/null
+++ b/src/tint/lang/core/ir/core_binary.cc
@@ -0,0 +1,56 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/core/ir/core_binary.h"
+
+#include "src/tint/lang/core/intrinsic/dialect.h"
+#include "src/tint/lang/core/ir/clone_context.h"
+#include "src/tint/lang/core/ir/module.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::core::ir::CoreBinary);
+
+namespace tint::core::ir {
+
+CoreBinary::CoreBinary() = default;
+
+CoreBinary::CoreBinary(InstructionResult* result, BinaryOp op, Value* lhs, Value* rhs)
+ : Base(result, op, lhs, rhs) {}
+
+CoreBinary::~CoreBinary() = default;
+
+CoreBinary* CoreBinary::Clone(CloneContext& ctx) {
+ auto* new_result = ctx.Clone(Result(0));
+ auto* lhs = ctx.Remap(LHS());
+ auto* rhs = ctx.Remap(RHS());
+ return ctx.ir.instructions.Create<CoreBinary>(new_result, Op(), lhs, rhs);
+}
+
+const core::intrinsic::TableData& CoreBinary::TableData() const {
+ return core::intrinsic::Dialect::kData;
+}
+
+} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/core_binary.h b/src/tint/lang/core/ir/core_binary.h
new file mode 100644
index 0000000..7329ce6
--- /dev/null
+++ b/src/tint/lang/core/ir/core_binary.h
@@ -0,0 +1,61 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_CORE_IR_CORE_BINARY_H_
+#define SRC_TINT_LANG_CORE_IR_CORE_BINARY_H_
+
+#include "src/tint/lang/core/ir/binary.h"
+
+namespace tint::core::ir {
+
+/// A core-dialect binary-op instruction in the IR.
+class CoreBinary final : public Castable<CoreBinary, Binary> {
+ public:
+ /// The offset in Operands() for the value
+ static constexpr size_t kValueOperandOffset = 0;
+
+ /// Constructor (no results, no operands)
+ CoreBinary();
+
+ /// Constructor
+ /// @param result the result value
+ /// @param op the Binary operator
+ /// @param lhs the lhs of the instruction
+ /// @param rhs the rhs of the instruction
+ CoreBinary(InstructionResult* result, BinaryOp op, Value* lhs, Value* rhs);
+ ~CoreBinary() override;
+
+ /// @copydoc Instruction::Clone()
+ CoreBinary* Clone(CloneContext& ctx) override;
+
+ /// @returns the table data to validate this builtin
+ const core::intrinsic::TableData& TableData() const override;
+};
+
+} // namespace tint::core::ir
+
+#endif // SRC_TINT_LANG_CORE_IR_CORE_BINARY_H_
diff --git a/src/tint/lang/core/ir/binary_test.cc b/src/tint/lang/core/ir/core_binary_test.cc
similarity index 100%
rename from src/tint/lang/core/ir/binary_test.cc
rename to src/tint/lang/core/ir/core_binary_test.cc
diff --git a/src/tint/lang/core/ir/disassembler.cc b/src/tint/lang/core/ir/disassembler.cc
index ee66e2b..cfa2e30 100644
--- a/src/tint/lang/core/ir/disassembler.cc
+++ b/src/tint/lang/core/ir/disassembler.cc
@@ -852,6 +852,12 @@
case BinaryOp::kShiftRight:
out_ << "shr";
break;
+ case BinaryOp::kLogicalAnd:
+ out_ << "logical-and";
+ break;
+ case BinaryOp::kLogicalOr:
+ out_ << "logical-or";
+ break;
}
out_ << " ";
EmitOperandList(b);
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill.cc b/src/tint/lang/core/ir/transform/binary_polyfill.cc
index 14c6756..83fb233 100644
--- a/src/tint/lang/core/ir/transform/binary_polyfill.cc
+++ b/src/tint/lang/core/ir/transform/binary_polyfill.cc
@@ -66,12 +66,12 @@
/// Process the module.
void Process() {
// Find the binary instructions that need to be polyfilled.
- Vector<ir::Binary*, 64> worklist;
+ Vector<ir::CoreBinary*, 64> worklist;
for (auto* inst : ir.instructions.Objects()) {
if (!inst->Alive()) {
continue;
}
- if (auto* binary = inst->As<ir::Binary>()) {
+ if (auto* binary = inst->As<ir::CoreBinary>()) {
switch (binary->Op()) {
case BinaryOp::kDivide:
case BinaryOp::kModulo:
@@ -149,7 +149,7 @@
/// divide-by-zero and signed integer overflow.
/// @param binary the binary instruction
/// @returns the replacement value
- ir::Value* IntDivMod(ir::Binary* binary) {
+ ir::Value* IntDivMod(ir::CoreBinary* binary) {
auto* result_ty = binary->Result(0)->Type();
bool is_div = binary->Op() == BinaryOp::kDivide;
bool is_signed = result_ty->is_signed_integer_scalar_or_vector();
@@ -232,13 +232,13 @@
/// Mask the RHS of a shift instruction to ensure it is modulo the bitwidth of the LHS.
/// @param binary the binary instruction
/// @returns the replacement value
- ir::Value* MaskShiftAmount(ir::Binary* binary) {
+ ir::Value* MaskShiftAmount(ir::CoreBinary* binary) {
auto* lhs = binary->LHS();
auto* rhs = binary->RHS();
auto* mask = b.Constant(u32(lhs->Type()->DeepestElement()->Size() * 8 - 1));
auto* masked = b.And(rhs->Type(), rhs, MatchWidth(mask, rhs->Type()));
masked->InsertBefore(binary);
- binary->SetOperand(ir::Binary::kRhsOperandOffset, masked->Result(0));
+ binary->SetOperand(ir::CoreBinary::kRhsOperandOffset, masked->Result(0));
return binary->Result(0);
}
};
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill_test.cc b/src/tint/lang/core/ir/transform/binary_polyfill_test.cc
index b582411..314ac0c 100644
--- a/src/tint/lang/core/ir/transform/binary_polyfill_test.cc
+++ b/src/tint/lang/core/ir/transform/binary_polyfill_test.cc
@@ -671,9 +671,9 @@
}
TEST_F(IR_BinaryPolyfillTest, Divide_Scalar_Vector) {
- Build(BinaryOp::kDivide, ty.vec4<i32>(), ty.i32(), ty.vec2<i32>());
+ Build(BinaryOp::kDivide, ty.vec4<i32>(), ty.i32(), ty.vec4<i32>());
auto* src = R"(
-%foo = func(%lhs:i32, %rhs:vec2<i32>):vec4<i32> -> %b1 {
+%foo = func(%lhs:i32, %rhs:vec4<i32>):vec4<i32> -> %b1 {
%b1 = block {
%result:vec4<i32> = div %lhs, %rhs
ret %result
@@ -681,7 +681,7 @@
}
)";
auto* expect = R"(
-%foo = func(%lhs:i32, %rhs:vec2<i32>):vec4<i32> -> %b1 {
+%foo = func(%lhs:i32, %rhs:vec4<i32>):vec4<i32> -> %b1 {
%b1 = block {
%4:vec4<i32> = construct %lhs
%result:vec4<i32> = call %tint_div_v4i32, %4, %rhs
@@ -711,9 +711,9 @@
}
TEST_F(IR_BinaryPolyfillTest, Divide_Vector_Scalar) {
- Build(BinaryOp::kDivide, ty.vec4<i32>(), ty.vec2<i32>(), ty.i32());
+ Build(BinaryOp::kDivide, ty.vec4<i32>(), ty.vec4<i32>(), ty.i32());
auto* src = R"(
-%foo = func(%lhs:vec2<i32>, %rhs:i32):vec4<i32> -> %b1 {
+%foo = func(%lhs:vec4<i32>, %rhs:i32):vec4<i32> -> %b1 {
%b1 = block {
%result:vec4<i32> = div %lhs, %rhs
ret %result
@@ -721,7 +721,7 @@
}
)";
auto* expect = R"(
-%foo = func(%lhs:vec2<i32>, %rhs:i32):vec4<i32> -> %b1 {
+%foo = func(%lhs:vec4<i32>, %rhs:i32):vec4<i32> -> %b1 {
%b1 = block {
%4:vec4<i32> = construct %rhs
%result:vec4<i32> = call %tint_div_v4i32, %lhs, %4
@@ -751,9 +751,9 @@
}
TEST_F(IR_BinaryPolyfillTest, Modulo_Scalar_Vector) {
- Build(BinaryOp::kModulo, ty.vec4<i32>(), ty.i32(), ty.vec2<i32>());
+ Build(BinaryOp::kModulo, ty.vec4<i32>(), ty.i32(), ty.vec4<i32>());
auto* src = R"(
-%foo = func(%lhs:i32, %rhs:vec2<i32>):vec4<i32> -> %b1 {
+%foo = func(%lhs:i32, %rhs:vec4<i32>):vec4<i32> -> %b1 {
%b1 = block {
%result:vec4<i32> = mod %lhs, %rhs
ret %result
@@ -761,7 +761,7 @@
}
)";
auto* expect = R"(
-%foo = func(%lhs:i32, %rhs:vec2<i32>):vec4<i32> -> %b1 {
+%foo = func(%lhs:i32, %rhs:vec4<i32>):vec4<i32> -> %b1 {
%b1 = block {
%4:vec4<i32> = construct %lhs
%result:vec4<i32> = call %tint_mod_v4i32, %4, %rhs
@@ -793,9 +793,9 @@
}
TEST_F(IR_BinaryPolyfillTest, Modulo_Vector_Scalar) {
- Build(BinaryOp::kModulo, ty.vec4<i32>(), ty.vec2<i32>(), ty.i32());
+ Build(BinaryOp::kModulo, ty.vec4<i32>(), ty.vec4<i32>(), ty.i32());
auto* src = R"(
-%foo = func(%lhs:vec2<i32>, %rhs:i32):vec4<i32> -> %b1 {
+%foo = func(%lhs:vec4<i32>, %rhs:i32):vec4<i32> -> %b1 {
%b1 = block {
%result:vec4<i32> = mod %lhs, %rhs
ret %result
@@ -803,7 +803,7 @@
}
)";
auto* expect = R"(
-%foo = func(%lhs:vec2<i32>, %rhs:i32):vec4<i32> -> %b1 {
+%foo = func(%lhs:vec4<i32>, %rhs:i32):vec4<i32> -> %b1 {
%b1 = block {
%4:vec4<i32> = construct %rhs
%result:vec4<i32> = call %tint_mod_v4i32, %lhs, %4
diff --git a/src/tint/lang/core/ir/unary.h b/src/tint/lang/core/ir/unary.h
index 7cd4b37..8ade151 100644
--- a/src/tint/lang/core/ir/unary.h
+++ b/src/tint/lang/core/ir/unary.h
@@ -40,7 +40,7 @@
namespace tint::core::ir {
-/// A unary instruction in the IR.
+/// The abstract base class for dialect-specific unary-op instructions in the IR.
class Unary : public Castable<Unary, OperandInstruction<1, 1>> {
public:
/// The offset in Operands() for the value
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 57180b3..e0c73e1 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -667,6 +667,33 @@
void Validator::CheckBinary(const Binary* b) {
CheckOperandsNotNull(b, Binary::kLhsOperandOffset, Binary::kRhsOperandOffset);
+ if (b->LHS() && b->RHS()) {
+ auto symbols = SymbolTable::Wrap(mod_.symbols);
+ auto type_mgr = type::Manager::Wrap(mod_.Types());
+ intrinsic::Context context{
+ b->TableData(),
+ type_mgr,
+ symbols,
+ };
+
+ auto overload =
+ core::intrinsic::LookupBinary(context, b->Op(), b->LHS()->Type(), b->RHS()->Type(),
+ core::EvaluationStage::kRuntime, /* is_compound */ false);
+ if (overload != Success) {
+ AddError(b, InstError(b, overload.Failure()));
+ return;
+ }
+
+ if (auto* result = b->Result(0)) {
+ if (overload->return_type != result->Type()) {
+ StringStream err;
+ err << "binary instruction result type (" << result->Type()->FriendlyName()
+ << ") does not match overload result type ("
+ << overload->return_type->FriendlyName() << ")";
+ AddError(b, InstError(b, err.str()));
+ }
+ }
+ }
}
void Validator::CheckUnary(const Unary* u) {
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 3f2fd5c..7a084a3 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -1271,8 +1271,8 @@
}
TEST_F(IR_ValidatorTest, Binary_Result_Nullptr) {
- auto* bin = mod.instructions.Create<ir::Binary>(nullptr, BinaryOp::kAdd, b.Constant(3_i),
- b.Constant(2_i));
+ auto* bin = mod.instructions.Create<ir::CoreBinary>(nullptr, BinaryOp::kAdd, b.Constant(3_i),
+ b.Constant(2_i));
auto* f = b.Function("my_func", ty.void_());
diff --git a/src/tint/lang/msl/writer/printer/binary_test.cc b/src/tint/lang/msl/writer/printer/binary_test.cc
index 74808bf..495b156 100644
--- a/src/tint/lang/msl/writer/printer/binary_test.cc
+++ b/src/tint/lang/msl/writer/printer/binary_test.cc
@@ -38,7 +38,7 @@
struct BinaryData {
const char* result;
- core::ir::BinaryOp op;
+ core::BinaryOp op;
};
inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
StringStream str;
@@ -70,22 +70,21 @@
}
)");
}
-INSTANTIATE_TEST_SUITE_P(
- MslPrinterTest,
- MslPrinterBinaryTest,
- testing::Values(BinaryData{"(left + right)", core::ir::BinaryOp::kAdd},
- BinaryData{"(left - right)", core::ir::BinaryOp::kSubtract},
- BinaryData{"(left * right)", core::ir::BinaryOp::kMultiply},
- BinaryData{"(left & right)", core::ir::BinaryOp::kAnd},
- BinaryData{"(left | right)", core::ir::BinaryOp::kOr},
- BinaryData{"(left ^ right)", core::ir::BinaryOp::kXor}));
+INSTANTIATE_TEST_SUITE_P(MslPrinterTest,
+ MslPrinterBinaryTest,
+ testing::Values(BinaryData{"(left + right)", core::BinaryOp::kAdd},
+ BinaryData{"(left - right)", core::BinaryOp::kSubtract},
+ BinaryData{"(left * right)", core::BinaryOp::kMultiply},
+ BinaryData{"(left & right)", core::BinaryOp::kAnd},
+ BinaryData{"(left | right)", core::BinaryOp::kOr},
+ BinaryData{"(left ^ right)", core::BinaryOp::kXor}));
TEST_F(MslPrinterTest, BinaryDivU32) {
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
auto* l = b.Let("left", b.Constant(1_u));
auto* r = b.Let("right", b.Constant(2_u));
- auto* bin = b.Binary(core::ir::BinaryOp::kDivide, ty.u32(), l, r);
+ auto* bin = b.Binary(core::BinaryOp::kDivide, ty.u32(), l, r);
b.Let("val", bin);
b.Return(func);
});
@@ -108,7 +107,7 @@
b.Append(func->Block(), [&] {
auto* l = b.Let("left", b.Constant(1_u));
auto* r = b.Let("right", b.Constant(2_u));
- auto* bin = b.Binary(core::ir::BinaryOp::kModulo, ty.u32(), l, r);
+ auto* bin = b.Binary(core::BinaryOp::kModulo, ty.u32(), l, r);
b.Let("val", bin);
b.Return(func);
});
@@ -132,7 +131,7 @@
b.Append(func->Block(), [&] {
auto* l = b.Let("left", b.Constant(1_u));
auto* r = b.Let("right", b.Constant(2_u));
- auto* bin = b.Binary(core::ir::BinaryOp::kShiftLeft, ty.u32(), l, r);
+ auto* bin = b.Binary(core::BinaryOp::kShiftLeft, ty.u32(), l, r);
b.Let("val", bin);
b.Return(func);
});
@@ -152,7 +151,7 @@
b.Append(func->Block(), [&] {
auto* l = b.Let("left", b.Constant(1_u));
auto* r = b.Let("right", b.Constant(2_u));
- auto* bin = b.Binary(core::ir::BinaryOp::kShiftRight, ty.u32(), l, r);
+ auto* bin = b.Binary(core::BinaryOp::kShiftRight, ty.u32(), l, r);
b.Let("val", bin);
b.Return(func);
});
@@ -193,12 +192,12 @@
INSTANTIATE_TEST_SUITE_P(
MslPrinterTest,
MslPrinterBinaryBoolTest,
- testing::Values(BinaryData{"(left == right)", core::ir::BinaryOp::kEqual},
- BinaryData{"(left != right)", core::ir::BinaryOp::kNotEqual},
- BinaryData{"(left < right)", core::ir::BinaryOp::kLessThan},
- BinaryData{"(left > right)", core::ir::BinaryOp::kGreaterThan},
- BinaryData{"(left <= right)", core::ir::BinaryOp::kLessThanEqual},
- BinaryData{"(left >= right)", core::ir::BinaryOp::kGreaterThanEqual}));
+ testing::Values(BinaryData{"(left == right)", core::BinaryOp::kEqual},
+ BinaryData{"(left != right)", core::BinaryOp::kNotEqual},
+ BinaryData{"(left < right)", core::BinaryOp::kLessThan},
+ BinaryData{"(left > right)", core::BinaryOp::kGreaterThan},
+ BinaryData{"(left <= right)", core::BinaryOp::kLessThanEqual},
+ BinaryData{"(left >= right)", core::BinaryOp::kGreaterThanEqual}));
// TODO(dsinclair): Needs transform
// TODO(dsinclair): Requires `bitcast` support
@@ -228,9 +227,9 @@
}
constexpr BinaryData signed_overflow_defined_behaviour_cases[] = {
- {"as_type<int>((as_type<uint>(left) + as_type<uint>(right)))", core::ir::BinaryOp::kAdd},
- {"as_type<int>((as_type<uint>(left) - as_type<uint>(right)))", core::ir::BinaryOp::kSubtract},
- {"as_type<int>((as_type<uint>(left) * as_type<uint>(right)))", core::ir::BinaryOp::kMultiply}};
+ {"as_type<int>((as_type<uint>(left) + as_type<uint>(right)))", core::BinaryOp::kAdd},
+ {"as_type<int>((as_type<uint>(left) - as_type<uint>(right)))", core::BinaryOp::kSubtract},
+ {"as_type<int>((as_type<uint>(left) * as_type<uint>(right)))", core::BinaryOp::kMultiply}};
INSTANTIATE_TEST_SUITE_P(MslPrinterTest,
MslPrinterBinaryTest_SignedOverflowDefinedBehaviour,
testing::ValuesIn(signed_overflow_defined_behaviour_cases));
@@ -263,8 +262,8 @@
}
constexpr BinaryData shift_signed_overflow_defined_behaviour_cases[] = {
- {"as_type<int>((as_type<uint>(left) << right))", core::ir::BinaryOp::kShiftLeft},
- {"(left >> right)", core::ir::BinaryOp::kShiftRight}};
+ {"as_type<int>((as_type<uint>(left) << right))", core::BinaryOp::kShiftLeft},
+ {"(left >> right)", core::BinaryOp::kShiftRight}};
INSTANTIATE_TEST_SUITE_P(MslPrinterTest,
MslPrinterBinaryTest_ShiftSignedOverflowDefinedBehaviour,
testing::ValuesIn(shift_signed_overflow_defined_behaviour_cases));
@@ -299,13 +298,13 @@
constexpr BinaryData signed_overflow_defined_behaviour_chained_cases[] = {
{R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) + as_type<uint>(right)))) +
as_type<uint>(right))))",
- core::ir::BinaryOp::kAdd},
+ core::BinaryOp::kAdd},
{R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) - as_type<uint>(right)))) -
as_type<uint>(right))))",
- core::ir::BinaryOp::kSubtract},
+ core::BinaryOp::kSubtract},
{R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) * as_type<uint>(right)))) *
as_type<uint>(right))))",
- core::ir::BinaryOp::kMultiply}};
+ core::BinaryOp::kMultiply}};
INSTANTIATE_TEST_SUITE_P(MslPrinterTest,
MslPrinterBinaryTest_SignedOverflowDefinedBehaviour_Chained,
testing::ValuesIn(signed_overflow_defined_behaviour_chained_cases));
@@ -339,8 +338,8 @@
}
constexpr BinaryData shift_signed_overflow_defined_behaviour_chained_cases[] = {
{R"(as_type<int>((as_type<uint>(as_type<int>((as_type<uint>(left) << right))) << right)))",
- core::ir::BinaryOp::kShiftLeft},
- {R"(((left >> right) >> right))", core::ir::BinaryOp::kShiftRight},
+ core::BinaryOp::kShiftLeft},
+ {R"(((left >> right) >> right))", core::BinaryOp::kShiftRight},
};
INSTANTIATE_TEST_SUITE_P(MslPrinterTest,
MslPrinterBinaryTest_ShiftSignedOverflowDefinedBehaviour_Chained,
@@ -353,7 +352,7 @@
auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, f32>());
auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, f32>());
- auto* expr1 = b.Binary(core::ir::BinaryOp::kModulo, ty.f32(), left, right);
+ auto* expr1 = b.Binary(core::BinaryOp::kModulo, ty.f32(), left, right);
b.Let("val", expr1);
});
@@ -376,7 +375,7 @@
auto* left = b.Var("left", ty.ptr<core::AddressSpace::kFunction, f16>());
auto* right = b.Var("right", ty.ptr<core::AddressSpace::kFunction, f16>());
- auto* expr1 = b.Binary(core::ir::BinaryOp::kModulo, ty.f16(), left, right);
+ auto* expr1 = b.Binary(core::BinaryOp::kModulo, ty.f16(), left, right);
b.Let("val", expr1);
});
@@ -397,7 +396,7 @@
auto* left = b.Var("left", ty.ptr(core::AddressSpace::kFunction, ty.vec3<f32>()));
auto* right = b.Var("right", ty.ptr(core::AddressSpace::kFunction, ty.vec3<f32>()));
- auto* expr1 = b.Binary(core::ir::BinaryOp::kModulo, ty.vec3<f32>(), left, right);
+ auto* expr1 = b.Binary(core::BinaryOp::kModulo, ty.vec3<f32>(), left, right);
b.Let("val", expr1);
});
@@ -420,7 +419,7 @@
auto* left = b.Var("left", ty.ptr(core::AddressSpace::kFunction, ty.vec3<f16>()));
auto* right = b.Var("right", ty.ptr(core::AddressSpace::kFunction, ty.vec3<f16>()));
- auto* expr1 = b.Binary(core::ir::BinaryOp::kModulo, ty.vec3<f16>(), left, right);
+ auto* expr1 = b.Binary(core::BinaryOp::kModulo, ty.vec3<f16>(), left, right);
b.Let("val", expr1);
});
@@ -441,7 +440,7 @@
auto* left = b.Var("left", ty.ptr(core::AddressSpace::kFunction, ty.bool_()));
auto* right = b.Var("right", ty.ptr(core::AddressSpace::kFunction, ty.bool_()));
- auto* expr1 = b.Binary(core::ir::BinaryOp::kAdd, ty.bool_(), left, right);
+ auto* expr1 = b.Binary(core::BinaryOp::kAdd, ty.bool_(), left, right);
b.Let("val", expr1);
});
@@ -462,7 +461,7 @@
auto* left = b.Var("left", ty.ptr(core::AddressSpace::kFunction, ty.bool_()));
auto* right = b.Var("right", ty.ptr(core::AddressSpace::kFunction, ty.bool_()));
- auto* expr1 = b.Binary(core::ir::BinaryOp::kOr, ty.bool_(), left, right);
+ auto* expr1 = b.Binary(core::BinaryOp::kOr, ty.bool_(), left, right);
b.Let("val", expr1);
});
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index dfc8eda..e2950c5 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -35,13 +35,13 @@
#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/ir/access.h"
-#include "src/tint/lang/core/ir/binary.h"
#include "src/tint/lang/core/ir/bitcast.h"
#include "src/tint/lang/core/ir/break_if.h"
#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
#include "src/tint/lang/core/ir/convert.h"
+#include "src/tint/lang/core/ir/core_binary.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/core_unary.h"
#include "src/tint/lang/core/ir/discard.h"
@@ -350,8 +350,8 @@
[&](core::ir::LoadVectorElement*) { /* inlined */ }, //
[&](core::ir::Swizzle*) { /* inlined */ }, //
[&](core::ir::Bitcast*) { /* inlined */ }, //
+ [&](core::ir::CoreBinary*) { /* inlined */ }, //
[&](core::ir::CoreUnary*) { /* inlined */ }, //
- [&](core::ir::Binary*) { /* inlined */ }, //
[&](core::ir::Load*) { /* inlined */ }, //
[&](core::ir::Construct*) { /* inlined */ }, //
[&](core::ir::Access*) { /* inlined */ }, //
@@ -366,8 +366,8 @@
[&](const core::ir::InstructionResult* r) {
Switch(
r->Instruction(), //
+ [&](const core::ir::CoreBinary* b) { EmitBinary(out, b); }, //
[&](const core::ir::CoreUnary* u) { EmitUnary(out, u); }, //
- [&](const core::ir::Binary* b) { EmitBinary(out, b); }, //
[&](const core::ir::Convert* b) { EmitConvert(out, b); }, //
[&](const core::ir::Let* l) { out << NameOf(l->Result(0)); }, //
[&](const core::ir::Load* l) { EmitValue(out, l->From()); }, //
@@ -407,8 +407,8 @@
/// Emit a binary instruction
/// @param b the binary instruction
- void EmitBinary(StringStream& out, const core::ir::Binary* b) {
- if (b->Op() == core::ir::BinaryOp::kEqual) {
+ void EmitBinary(StringStream& out, const core::ir::CoreBinary* b) {
+ if (b->Op() == core::BinaryOp::kEqual) {
auto* rhs = b->RHS()->As<core::ir::Constant>();
if (rhs && rhs->Type()->Is<core::type::Bool>() &&
rhs->Value()->ValueAs<bool>() == false) {
@@ -422,38 +422,42 @@
auto kind = [&] {
switch (b->Op()) {
- case core::ir::BinaryOp::kAdd:
+ case core::BinaryOp::kAdd:
return "+";
- case core::ir::BinaryOp::kSubtract:
+ case core::BinaryOp::kSubtract:
return "-";
- case core::ir::BinaryOp::kMultiply:
+ case core::BinaryOp::kMultiply:
return "*";
- case core::ir::BinaryOp::kDivide:
+ case core::BinaryOp::kDivide:
return "/";
- case core::ir::BinaryOp::kModulo:
+ case core::BinaryOp::kModulo:
return "%";
- case core::ir::BinaryOp::kAnd:
+ case core::BinaryOp::kAnd:
return "&";
- case core::ir::BinaryOp::kOr:
+ case core::BinaryOp::kOr:
return "|";
- case core::ir::BinaryOp::kXor:
+ case core::BinaryOp::kXor:
return "^";
- case core::ir::BinaryOp::kEqual:
+ case core::BinaryOp::kEqual:
return "==";
- case core::ir::BinaryOp::kNotEqual:
+ case core::BinaryOp::kNotEqual:
return "!=";
- case core::ir::BinaryOp::kLessThan:
+ case core::BinaryOp::kLessThan:
return "<";
- case core::ir::BinaryOp::kGreaterThan:
+ case core::BinaryOp::kGreaterThan:
return ">";
- case core::ir::BinaryOp::kLessThanEqual:
+ case core::BinaryOp::kLessThanEqual:
return "<=";
- case core::ir::BinaryOp::kGreaterThanEqual:
+ case core::BinaryOp::kGreaterThanEqual:
return ">=";
- case core::ir::BinaryOp::kShiftLeft:
+ case core::BinaryOp::kShiftLeft:
return "<<";
- case core::ir::BinaryOp::kShiftRight:
+ case core::BinaryOp::kShiftRight:
return ">>";
+ case core::BinaryOp::kLogicalAnd:
+ return "&&";
+ case core::BinaryOp::kLogicalOr:
+ return "||";
}
return "<error>";
};
diff --git a/src/tint/lang/spirv/writer/binary_test.cc b/src/tint/lang/spirv/writer/binary_test.cc
index 607b04b..a14b39b 100644
--- a/src/tint/lang/spirv/writer/binary_test.cc
+++ b/src/tint/lang/spirv/writer/binary_test.cc
@@ -38,10 +38,38 @@
/// A parameterized test case.
struct BinaryTestCase {
- /// The element type to test.
- TestElementType type;
+ BinaryTestCase(TestElementType type_,
+ core::BinaryOp op_,
+ std::string spirv_inst_,
+ std::string spirv_type_name_)
+ : res_type(type_),
+ lhs_type(type_),
+ rhs_type(type_),
+ op(op_),
+ spirv_inst(spirv_inst_),
+ spirv_type_name(spirv_type_name_) {}
+
+ BinaryTestCase(TestElementType res_type_,
+ TestElementType lhs_type_,
+ TestElementType rhs_type_,
+ core::BinaryOp op_,
+ std::string spirv_inst_,
+ std::string spirv_type_name_)
+ : res_type(res_type_),
+ lhs_type(lhs_type_),
+ rhs_type(rhs_type_),
+ op(op_),
+ spirv_inst(spirv_inst_),
+ spirv_type_name(spirv_type_name_) {}
+
+ /// The result type of the binary op.
+ TestElementType res_type;
+ /// The LHS type of the binary op.
+ TestElementType lhs_type;
+ /// The RHS type of the binary op.
+ TestElementType rhs_type;
/// The binary operation.
- core::ir::BinaryOp op;
+ core::BinaryOp op;
/// The expected SPIR-V instruction.
std::string spirv_inst;
/// The expected SPIR-V result type name.
@@ -54,9 +82,9 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
- auto* lhs = MakeScalarValue(params.type);
- auto* rhs = MakeScalarValue(params.type);
- auto* result = b.Binary(params.op, MakeScalarType(params.type), lhs, rhs);
+ auto* lhs = MakeScalarValue(params.lhs_type);
+ auto* rhs = MakeScalarValue(params.rhs_type);
+ auto* result = b.Binary(params.op, MakeScalarType(params.res_type), lhs, rhs);
b.Return(func);
mod.SetName(result, "result");
});
@@ -69,9 +97,9 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
- auto* lhs = MakeVectorValue(params.type);
- auto* rhs = MakeVectorValue(params.type);
- auto* result = b.Binary(params.op, MakeVectorType(params.type), lhs, rhs);
+ auto* lhs = MakeVectorValue(params.lhs_type);
+ auto* rhs = MakeVectorValue(params.rhs_type);
+ auto* result = b.Binary(params.op, MakeVectorType(params.res_type), lhs, rhs);
b.Return(func);
mod.SetName(result, "result");
});
@@ -82,48 +110,49 @@
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_I32,
Arithmetic_Bitwise,
- testing::Values(
- BinaryTestCase{kI32, core::ir::BinaryOp::kAdd, "OpIAdd", "int"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kSubtract, "OpISub", "int"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kMultiply, "OpIMul", "int"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kAnd, "OpBitwiseAnd", "int"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kOr, "OpBitwiseOr", "int"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kXor, "OpBitwiseXor", "int"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kShiftLeft, "OpShiftLeftLogical", "int"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kShiftRight, "OpShiftRightArithmetic", "int"}));
+ testing::Values(BinaryTestCase{kI32, core::BinaryOp::kAdd, "OpIAdd", "int"},
+ BinaryTestCase{kI32, core::BinaryOp::kSubtract, "OpISub", "int"},
+ BinaryTestCase{kI32, core::BinaryOp::kMultiply, "OpIMul", "int"},
+ BinaryTestCase{kI32, core::BinaryOp::kAnd, "OpBitwiseAnd", "int"},
+ BinaryTestCase{kI32, core::BinaryOp::kOr, "OpBitwiseOr", "int"},
+ BinaryTestCase{kI32, core::BinaryOp::kXor, "OpBitwiseXor", "int"},
+ BinaryTestCase{kI32, kI32, kU32, core::BinaryOp::kShiftLeft,
+ "OpShiftLeftLogical", "int"},
+ BinaryTestCase{kI32, kI32, kU32, core::BinaryOp::kShiftRight,
+ "OpShiftRightArithmetic", "int"}));
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_U32,
Arithmetic_Bitwise,
- testing::Values(
- BinaryTestCase{kU32, core::ir::BinaryOp::kAdd, "OpIAdd", "uint"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kSubtract, "OpISub", "uint"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kMultiply, "OpIMul", "uint"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kAnd, "OpBitwiseAnd", "uint"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kOr, "OpBitwiseOr", "uint"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kXor, "OpBitwiseXor", "uint"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kShiftLeft, "OpShiftLeftLogical", "uint"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kShiftRight, "OpShiftRightLogical", "uint"}));
+ testing::Values(BinaryTestCase{kU32, core::BinaryOp::kAdd, "OpIAdd", "uint"},
+ BinaryTestCase{kU32, core::BinaryOp::kSubtract, "OpISub", "uint"},
+ BinaryTestCase{kU32, core::BinaryOp::kMultiply, "OpIMul", "uint"},
+ BinaryTestCase{kU32, core::BinaryOp::kAnd, "OpBitwiseAnd", "uint"},
+ BinaryTestCase{kU32, core::BinaryOp::kOr, "OpBitwiseOr", "uint"},
+ BinaryTestCase{kU32, core::BinaryOp::kXor, "OpBitwiseXor", "uint"},
+ BinaryTestCase{kU32, core::BinaryOp::kShiftLeft, "OpShiftLeftLogical", "uint"},
+ BinaryTestCase{kU32, core::BinaryOp::kShiftRight, "OpShiftRightLogical",
+ "uint"}));
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_F32,
Arithmetic_Bitwise,
- testing::Values(BinaryTestCase{kF32, core::ir::BinaryOp::kAdd, "OpFAdd", "float"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kSubtract, "OpFSub", "float"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kMultiply, "OpFMul", "float"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kDivide, "OpFDiv", "float"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kModulo, "OpFRem", "float"}));
+ testing::Values(BinaryTestCase{kF32, core::BinaryOp::kAdd, "OpFAdd", "float"},
+ BinaryTestCase{kF32, core::BinaryOp::kSubtract, "OpFSub", "float"},
+ BinaryTestCase{kF32, core::BinaryOp::kMultiply, "OpFMul", "float"},
+ BinaryTestCase{kF32, core::BinaryOp::kDivide, "OpFDiv", "float"},
+ BinaryTestCase{kF32, core::BinaryOp::kModulo, "OpFRem", "float"}));
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_F16,
Arithmetic_Bitwise,
- testing::Values(BinaryTestCase{kF16, core::ir::BinaryOp::kAdd, "OpFAdd", "half"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kSubtract, "OpFSub", "half"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kMultiply, "OpFMul", "half"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kDivide, "OpFDiv", "half"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kModulo, "OpFRem", "half"}));
+ testing::Values(BinaryTestCase{kF16, core::BinaryOp::kAdd, "OpFAdd", "half"},
+ BinaryTestCase{kF16, core::BinaryOp::kSubtract, "OpFSub", "half"},
+ BinaryTestCase{kF16, core::BinaryOp::kMultiply, "OpFMul", "half"},
+ BinaryTestCase{kF16, core::BinaryOp::kDivide, "OpFDiv", "half"},
+ BinaryTestCase{kF16, core::BinaryOp::kModulo, "OpFRem", "half"}));
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_Bool,
Arithmetic_Bitwise,
- testing::Values(BinaryTestCase{kBool, core::ir::BinaryOp::kAnd, "OpLogicalAnd", "bool"},
- BinaryTestCase{kBool, core::ir::BinaryOp::kOr, "OpLogicalOr", "bool"}));
+ testing::Values(BinaryTestCase{kBool, core::BinaryOp::kAnd, "OpLogicalAnd", "bool"},
+ BinaryTestCase{kBool, core::BinaryOp::kOr, "OpLogicalOr", "bool"}));
TEST_F(SpirvWriterTest, Binary_ScalarTimesVector_F32) {
auto* scalar = b.FunctionParam("scalar", ty.f32());
@@ -236,8 +265,8 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
- auto* lhs = MakeScalarValue(params.type);
- auto* rhs = MakeScalarValue(params.type);
+ auto* lhs = MakeScalarValue(params.lhs_type);
+ auto* rhs = MakeScalarValue(params.rhs_type);
auto* result = b.Binary(params.op, ty.bool_(), lhs, rhs);
b.Return(func);
mod.SetName(result, "result");
@@ -252,8 +281,8 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
- auto* lhs = MakeVectorValue(params.type);
- auto* rhs = MakeVectorValue(params.type);
+ auto* lhs = MakeVectorValue(params.lhs_type);
+ auto* rhs = MakeVectorValue(params.rhs_type);
auto* result = b.Binary(params.op, ty.vec2<bool>(), lhs, rhs);
b.Return(func);
mod.SetName(result, "result");
@@ -266,50 +295,47 @@
SpirvWriterTest_Binary_I32,
Comparison,
testing::Values(
- BinaryTestCase{kI32, core::ir::BinaryOp::kEqual, "OpIEqual", "bool"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kNotEqual, "OpINotEqual", "bool"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kGreaterThan, "OpSGreaterThan", "bool"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kGreaterThanEqual, "OpSGreaterThanEqual", "bool"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kLessThan, "OpSLessThan", "bool"},
- BinaryTestCase{kI32, core::ir::BinaryOp::kLessThanEqual, "OpSLessThanEqual", "bool"}));
+ BinaryTestCase{kI32, core::BinaryOp::kEqual, "OpIEqual", "bool"},
+ BinaryTestCase{kI32, core::BinaryOp::kNotEqual, "OpINotEqual", "bool"},
+ BinaryTestCase{kI32, core::BinaryOp::kGreaterThan, "OpSGreaterThan", "bool"},
+ BinaryTestCase{kI32, core::BinaryOp::kGreaterThanEqual, "OpSGreaterThanEqual", "bool"},
+ BinaryTestCase{kI32, core::BinaryOp::kLessThan, "OpSLessThan", "bool"},
+ BinaryTestCase{kI32, core::BinaryOp::kLessThanEqual, "OpSLessThanEqual", "bool"}));
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_U32,
Comparison,
testing::Values(
- BinaryTestCase{kU32, core::ir::BinaryOp::kEqual, "OpIEqual", "bool"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kNotEqual, "OpINotEqual", "bool"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kGreaterThan, "OpUGreaterThan", "bool"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kGreaterThanEqual, "OpUGreaterThanEqual", "bool"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kLessThan, "OpULessThan", "bool"},
- BinaryTestCase{kU32, core::ir::BinaryOp::kLessThanEqual, "OpULessThanEqual", "bool"}));
+ BinaryTestCase{kU32, core::BinaryOp::kEqual, "OpIEqual", "bool"},
+ BinaryTestCase{kU32, core::BinaryOp::kNotEqual, "OpINotEqual", "bool"},
+ BinaryTestCase{kU32, core::BinaryOp::kGreaterThan, "OpUGreaterThan", "bool"},
+ BinaryTestCase{kU32, core::BinaryOp::kGreaterThanEqual, "OpUGreaterThanEqual", "bool"},
+ BinaryTestCase{kU32, core::BinaryOp::kLessThan, "OpULessThan", "bool"},
+ BinaryTestCase{kU32, core::BinaryOp::kLessThanEqual, "OpULessThanEqual", "bool"}));
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_F32,
Comparison,
testing::Values(
- BinaryTestCase{kF32, core::ir::BinaryOp::kEqual, "OpFOrdEqual", "bool"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kNotEqual, "OpFOrdNotEqual", "bool"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kGreaterThan, "OpFOrdGreaterThan", "bool"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual",
- "bool"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kLessThan, "OpFOrdLessThan", "bool"},
- BinaryTestCase{kF32, core::ir::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual", "bool"}));
+ BinaryTestCase{kF32, core::BinaryOp::kEqual, "OpFOrdEqual", "bool"},
+ BinaryTestCase{kF32, core::BinaryOp::kNotEqual, "OpFOrdNotEqual", "bool"},
+ BinaryTestCase{kF32, core::BinaryOp::kGreaterThan, "OpFOrdGreaterThan", "bool"},
+ BinaryTestCase{kF32, core::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual", "bool"},
+ BinaryTestCase{kF32, core::BinaryOp::kLessThan, "OpFOrdLessThan", "bool"},
+ BinaryTestCase{kF32, core::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual", "bool"}));
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest_Binary_F16,
Comparison,
testing::Values(
- BinaryTestCase{kF16, core::ir::BinaryOp::kEqual, "OpFOrdEqual", "bool"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kNotEqual, "OpFOrdNotEqual", "bool"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kGreaterThan, "OpFOrdGreaterThan", "bool"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual",
- "bool"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kLessThan, "OpFOrdLessThan", "bool"},
- BinaryTestCase{kF16, core::ir::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual", "bool"}));
-INSTANTIATE_TEST_SUITE_P(SpirvWriterTest_Binary_Bool,
- Comparison,
- testing::Values(BinaryTestCase{kBool, core::ir::BinaryOp::kEqual,
- "OpLogicalEqual", "bool"},
- BinaryTestCase{kBool, core::ir::BinaryOp::kNotEqual,
- "OpLogicalNotEqual", "bool"}));
+ BinaryTestCase{kF16, core::BinaryOp::kEqual, "OpFOrdEqual", "bool"},
+ BinaryTestCase{kF16, core::BinaryOp::kNotEqual, "OpFOrdNotEqual", "bool"},
+ BinaryTestCase{kF16, core::BinaryOp::kGreaterThan, "OpFOrdGreaterThan", "bool"},
+ BinaryTestCase{kF16, core::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual", "bool"},
+ BinaryTestCase{kF16, core::BinaryOp::kLessThan, "OpFOrdLessThan", "bool"},
+ BinaryTestCase{kF16, core::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual", "bool"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpirvWriterTest_Binary_Bool,
+ Comparison,
+ testing::Values(BinaryTestCase{kBool, core::BinaryOp::kEqual, "OpLogicalEqual", "bool"},
+ BinaryTestCase{kBool, core::BinaryOp::kNotEqual, "OpLogicalNotEqual", "bool"}));
TEST_F(SpirvWriterTest, Binary_Chain) {
auto* func = b.Function("foo", ty.void_());
@@ -334,7 +360,7 @@
auto* func = b.Function("foo", ty.u32());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kDivide, ty.u32(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kDivide, ty.u32(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
@@ -370,7 +396,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kDivide, ty.i32(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kDivide, ty.i32(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
@@ -410,7 +436,7 @@
auto* func = b.Function("foo", ty.vec4<i32>());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kDivide, ty.vec4<i32>(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kDivide, ty.vec4<i32>(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
@@ -452,7 +478,7 @@
auto* func = b.Function("foo", ty.vec4<i32>());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kDivide, ty.vec4<i32>(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kDivide, ty.vec4<i32>(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
@@ -494,7 +520,7 @@
auto* func = b.Function("foo", ty.u32());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kModulo, ty.u32(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kModulo, ty.u32(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
@@ -532,7 +558,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kModulo, ty.i32(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kModulo, ty.i32(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
@@ -574,7 +600,7 @@
auto* func = b.Function("foo", ty.vec4<i32>());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kModulo, ty.vec4<i32>(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kModulo, ty.vec4<i32>(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
@@ -618,7 +644,7 @@
auto* func = b.Function("foo", ty.vec4<i32>());
func->SetParams(args);
b.Append(func->Block(), [&] {
- auto* result = b.Binary(core::ir::BinaryOp::kModulo, ty.vec4<i32>(), args[0], args[1]);
+ auto* result = b.Binary(core::BinaryOp::kModulo, ty.vec4<i32>(), args[0], args[1]);
b.Return(func, result);
mod.SetName(result, "result");
});
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 7863fd8..8eec980 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -879,8 +879,8 @@
Switch(
inst, //
[&](core::ir::Access* a) { EmitAccess(a); }, //
- [&](core::ir::Binary* b) { EmitBinary(b); }, //
[&](core::ir::Bitcast* b) { EmitBitcast(b); }, //
+ [&](core::ir::CoreBinary* b) { EmitBinary(b); }, //
[&](core::ir::CoreBuiltinCall* b) { EmitCoreBuiltinCall(b); }, //
[&](spirv::ir::BuiltinCall* b) { EmitSpirvBuiltinCall(b); }, //
[&](core::ir::Construct* c) { EmitConstruct(c); }, //
@@ -1060,7 +1060,7 @@
/// Emit a binary instruction.
/// @param binary the binary instruction to emit
- void EmitBinary(core::ir::Binary* binary) {
+ void EmitBinary(core::ir::CoreBinary* binary) {
auto id = Value(binary);
auto lhs = Value(binary->LHS());
auto rhs = Value(binary->RHS());
@@ -1070,11 +1070,11 @@
// Determine the opcode.
spv::Op op = spv::Op::Max;
switch (binary->Op()) {
- case core::ir::BinaryOp::kAdd: {
+ case core::BinaryOp::kAdd: {
op = ty->is_integer_scalar_or_vector() ? spv::Op::OpIAdd : spv::Op::OpFAdd;
break;
}
- case core::ir::BinaryOp::kDivide: {
+ case core::BinaryOp::kDivide: {
if (ty->is_signed_integer_scalar_or_vector()) {
op = spv::Op::OpSDiv;
} else if (ty->is_unsigned_integer_scalar_or_vector()) {
@@ -1084,7 +1084,7 @@
}
break;
}
- case core::ir::BinaryOp::kMultiply: {
+ case core::BinaryOp::kMultiply: {
if (ty->is_integer_scalar_or_vector()) {
op = spv::Op::OpIMul;
} else if (ty->is_float_scalar_or_vector()) {
@@ -1092,11 +1092,11 @@
}
break;
}
- case core::ir::BinaryOp::kSubtract: {
+ case core::BinaryOp::kSubtract: {
op = ty->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub;
break;
}
- case core::ir::BinaryOp::kModulo: {
+ case core::BinaryOp::kModulo: {
if (ty->is_signed_integer_scalar_or_vector()) {
op = spv::Op::OpSRem;
} else if (ty->is_unsigned_integer_scalar_or_vector()) {
@@ -1107,7 +1107,7 @@
break;
}
- case core::ir::BinaryOp::kAnd: {
+ case core::BinaryOp::kAnd: {
if (ty->is_integer_scalar_or_vector()) {
op = spv::Op::OpBitwiseAnd;
} else if (ty->is_bool_scalar_or_vector()) {
@@ -1115,7 +1115,7 @@
}
break;
}
- case core::ir::BinaryOp::kOr: {
+ case core::BinaryOp::kOr: {
if (ty->is_integer_scalar_or_vector()) {
op = spv::Op::OpBitwiseOr;
} else if (ty->is_bool_scalar_or_vector()) {
@@ -1123,16 +1123,16 @@
}
break;
}
- case core::ir::BinaryOp::kXor: {
+ case core::BinaryOp::kXor: {
op = spv::Op::OpBitwiseXor;
break;
}
- case core::ir::BinaryOp::kShiftLeft: {
+ case core::BinaryOp::kShiftLeft: {
op = spv::Op::OpShiftLeftLogical;
break;
}
- case core::ir::BinaryOp::kShiftRight: {
+ case core::BinaryOp::kShiftRight: {
if (ty->is_signed_integer_scalar_or_vector()) {
op = spv::Op::OpShiftRightArithmetic;
} else if (ty->is_unsigned_integer_scalar_or_vector()) {
@@ -1141,7 +1141,7 @@
break;
}
- case core::ir::BinaryOp::kEqual: {
+ case core::BinaryOp::kEqual: {
if (lhs_ty->is_bool_scalar_or_vector()) {
op = spv::Op::OpLogicalEqual;
} else if (lhs_ty->is_float_scalar_or_vector()) {
@@ -1151,7 +1151,7 @@
}
break;
}
- case core::ir::BinaryOp::kNotEqual: {
+ case core::BinaryOp::kNotEqual: {
if (lhs_ty->is_bool_scalar_or_vector()) {
op = spv::Op::OpLogicalNotEqual;
} else if (lhs_ty->is_float_scalar_or_vector()) {
@@ -1161,7 +1161,7 @@
}
break;
}
- case core::ir::BinaryOp::kGreaterThan: {
+ case core::BinaryOp::kGreaterThan: {
if (lhs_ty->is_float_scalar_or_vector()) {
op = spv::Op::OpFOrdGreaterThan;
} else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
@@ -1171,7 +1171,7 @@
}
break;
}
- case core::ir::BinaryOp::kGreaterThanEqual: {
+ case core::BinaryOp::kGreaterThanEqual: {
if (lhs_ty->is_float_scalar_or_vector()) {
op = spv::Op::OpFOrdGreaterThanEqual;
} else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
@@ -1181,7 +1181,7 @@
}
break;
}
- case core::ir::BinaryOp::kLessThan: {
+ case core::BinaryOp::kLessThan: {
if (lhs_ty->is_float_scalar_or_vector()) {
op = spv::Op::OpFOrdLessThan;
} else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
@@ -1191,7 +1191,7 @@
}
break;
}
- case core::ir::BinaryOp::kLessThanEqual: {
+ case core::BinaryOp::kLessThanEqual: {
if (lhs_ty->is_float_scalar_or_vector()) {
op = spv::Op::OpFOrdLessThanEqual;
} else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
@@ -1201,6 +1201,9 @@
}
break;
}
+ default:
+ TINT_UNIMPLEMENTED() << binary->Op();
+ break;
}
// Emit the instruction.
diff --git a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
index 083a98b..3385ad0 100644
--- a/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
+++ b/src/tint/lang/spirv/writer/raise/expand_implicit_splats.cc
@@ -46,7 +46,7 @@
// Find the instructions that use implicit splats and either modify them in place or record them
// to be replaced in a second pass.
- Vector<core::ir::Binary*, 4> binary_worklist;
+ Vector<core::ir::CoreBinary*, 4> binary_worklist;
Vector<core::ir::CoreBuiltinCall*, 4> builtin_worklist;
for (auto* inst : ir.instructions.Objects()) {
if (!inst->Alive()) {
@@ -63,7 +63,7 @@
construct->AppendArg(construct->Args()[0]);
}
}
- } else if (auto* binary = inst->As<core::ir::Binary>()) {
+ } else if (auto* binary = inst->As<core::ir::CoreBinary>()) {
// A binary instruction that mixes vector and scalar operands needs to have the scalar
// operand replaced with an explicit vector constructor.
if (binary->Result(0)->Type()->Is<core::type::Vector>()) {
@@ -101,7 +101,7 @@
// Replace scalar operands to binary instructions that produce vectors.
for (auto* binary : binary_worklist) {
auto* result_ty = binary->Result(0)->Type();
- if (result_ty->is_float_vector() && binary->Op() == core::ir::BinaryOp::kMultiply) {
+ if (result_ty->is_float_vector() && binary->Op() == core::BinaryOp::kMultiply) {
// Use OpVectorTimesScalar for floating point multiply.
auto* vts =
b.Call<spirv::ir::BuiltinCall>(result_ty, spirv::BuiltinFn::kVectorTimesScalar);
@@ -121,9 +121,9 @@
} else {
// Expand the scalar argument into an explicitly constructed vector.
if (binary->LHS()->Type()->Is<core::type::Scalar>()) {
- expand_operand(binary, core::ir::Binary::kLhsOperandOffset);
+ expand_operand(binary, core::ir::CoreBinary::kLhsOperandOffset);
} else if (binary->RHS()->Type()->Is<core::type::Scalar>()) {
- expand_operand(binary, core::ir::Binary::kRhsOperandOffset);
+ expand_operand(binary, core::ir::CoreBinary::kRhsOperandOffset);
}
}
}
diff --git a/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc b/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc
index 6be5daf..3e77d21 100644
--- a/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc
+++ b/src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.cc
@@ -48,13 +48,13 @@
core::ir::Builder b{ir};
// Find the instructions that need to be modified.
- Vector<core::ir::Binary*, 4> binary_worklist;
+ Vector<core::ir::CoreBinary*, 4> binary_worklist;
Vector<core::ir::Convert*, 4> convert_worklist;
for (auto* inst : ir.instructions.Objects()) {
if (!inst->Alive()) {
continue;
}
- if (auto* binary = inst->As<core::ir::Binary>()) {
+ if (auto* binary = inst->As<core::ir::CoreBinary>()) {
TINT_ASSERT(binary->Operands().Length() == 2);
if (binary->LHS()->Type()->Is<core::type::Matrix>() ||
binary->RHS()->Type()->Is<core::type::Matrix>()) {
@@ -101,13 +101,13 @@
};
switch (binary->Op()) {
- case core::ir::BinaryOp::kAdd:
- column_wise(core::ir::BinaryOp::kAdd);
+ case core::BinaryOp::kAdd:
+ column_wise(core::BinaryOp::kAdd);
break;
- case core::ir::BinaryOp::kSubtract:
- column_wise(core::ir::BinaryOp::kSubtract);
+ case core::BinaryOp::kSubtract:
+ column_wise(core::BinaryOp::kSubtract);
break;
- case core::ir::BinaryOp::kMultiply:
+ case core::BinaryOp::kMultiply:
// Select the SPIR-V intrinsic that corresponds to the operation being performed.
if (lhs_ty->Is<core::type::Matrix>()) {
if (rhs_ty->Is<core::type::Scalar>()) {
diff --git a/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc b/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc
index 0ae2eea..740c918 100644
--- a/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc
+++ b/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc
@@ -950,7 +950,7 @@
if (!rhs) {
return;
}
- core::ir::Binary* inst = impl.BinaryOp(ty, lhs, rhs, b->op);
+ auto* inst = impl.BinaryOp(ty, lhs, rhs, b->op);
if (!inst) {
return;
}
@@ -1288,10 +1288,10 @@
TINT_ICE_ON_NO_MATCH);
}
- core::ir::Binary* BinaryOp(const core::type::Type* ty,
- core::ir::Value* lhs,
- core::ir::Value* rhs,
- core::BinaryOp op) {
+ core::ir::CoreBinary* BinaryOp(const core::type::Type* ty,
+ core::ir::Value* lhs,
+ core::ir::Value* rhs,
+ core::BinaryOp op) {
switch (op) {
case core::BinaryOp::kAnd:
return builder_.And(ty, lhs, rhs);
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
index 642617c..47b6468 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
@@ -812,7 +812,7 @@
}
void Binary(const core::ir::Binary* e) {
- if (e->Op() == core::ir::BinaryOp::kEqual) {
+ if (e->Op() == core::BinaryOp::kEqual) {
auto* rhs = e->RHS()->As<core::ir::Constant>();
if (rhs && rhs->Type()->Is<core::type::Bool>() &&
rhs->Value()->ValueAs<bool>() == false) {
@@ -825,54 +825,60 @@
auto* rhs = Expr(e->RHS());
const ast::Expression* expr = nullptr;
switch (e->Op()) {
- case core::ir::BinaryOp::kAdd:
+ case core::BinaryOp::kAdd:
expr = b.Add(lhs, rhs);
break;
- case core::ir::BinaryOp::kSubtract:
+ case core::BinaryOp::kSubtract:
expr = b.Sub(lhs, rhs);
break;
- case core::ir::BinaryOp::kMultiply:
+ case core::BinaryOp::kMultiply:
expr = b.Mul(lhs, rhs);
break;
- case core::ir::BinaryOp::kDivide:
+ case core::BinaryOp::kDivide:
expr = b.Div(lhs, rhs);
break;
- case core::ir::BinaryOp::kModulo:
+ case core::BinaryOp::kModulo:
expr = b.Mod(lhs, rhs);
break;
- case core::ir::BinaryOp::kAnd:
+ case core::BinaryOp::kAnd:
expr = b.And(lhs, rhs);
break;
- case core::ir::BinaryOp::kOr:
+ case core::BinaryOp::kOr:
expr = b.Or(lhs, rhs);
break;
- case core::ir::BinaryOp::kXor:
+ case core::BinaryOp::kXor:
expr = b.Xor(lhs, rhs);
break;
- case core::ir::BinaryOp::kEqual:
+ case core::BinaryOp::kEqual:
expr = b.Equal(lhs, rhs);
break;
- case core::ir::BinaryOp::kNotEqual:
+ case core::BinaryOp::kNotEqual:
expr = b.NotEqual(lhs, rhs);
break;
- case core::ir::BinaryOp::kLessThan:
+ case core::BinaryOp::kLessThan:
expr = b.LessThan(lhs, rhs);
break;
- case core::ir::BinaryOp::kGreaterThan:
+ case core::BinaryOp::kGreaterThan:
expr = b.GreaterThan(lhs, rhs);
break;
- case core::ir::BinaryOp::kLessThanEqual:
+ case core::BinaryOp::kLessThanEqual:
expr = b.LessThanEqual(lhs, rhs);
break;
- case core::ir::BinaryOp::kGreaterThanEqual:
+ case core::BinaryOp::kGreaterThanEqual:
expr = b.GreaterThanEqual(lhs, rhs);
break;
- case core::ir::BinaryOp::kShiftLeft:
+ case core::BinaryOp::kShiftLeft:
expr = b.Shl(lhs, rhs);
break;
- case core::ir::BinaryOp::kShiftRight:
+ case core::BinaryOp::kShiftRight:
expr = b.Shr(lhs, rhs);
break;
+ case core::BinaryOp::kLogicalAnd:
+ expr = b.LogicalAnd(lhs, rhs);
+ break;
+ case core::BinaryOp::kLogicalOr:
+ expr = b.LogicalOr(lhs, rhs);
+ break;
}
Bind(e->Result(0), expr);
}