tint: const eval of binary XOR
Bug: tint:1581
Change-Id: I5605426f0c4b9447ce770092de4ab2f639d0218d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/102580
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def
index e32a0f4..23e0dc8 100644
--- a/src/tint/intrinsics.def
+++ b/src/tint/intrinsics.def
@@ -920,8 +920,8 @@
op % <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
op % <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
-op ^ <T: iu32>(T, T) -> T
-op ^ <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
+@const op ^ <T: ia_iu32>(T, T) -> T
+@const op ^ <T: ia_iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
@const op & (bool, bool) -> bool
@const op & <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 8e46f9a..7e33244 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -1445,6 +1445,23 @@
return r;
}
+ConstEval::ConstantResult ConstEval::OpXor(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
+ auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto create = [&](auto i, auto j) -> const Constant* {
+ return CreateElement(builder, sem::Type::DeepestElementOf(ty), decltype(i){i ^ j});
+ };
+ return Dispatch_ia_iu32(create, c0, c1);
+ };
+
+ auto r = TransformElements(builder, ty, transform, args[0], args[1]);
+ if (builder.Diagnostics().contains_errors()) {
+ return utils::Failure;
+ }
+ return r;
+}
+
ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index 6b57556..04e2282 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -356,6 +356,15 @@
utils::VectorRef<const sem::Constant*> args,
const Source& source);
+ /// Bitwise xor operator '^'
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the conversion
+ /// @return the result value, or null if the value cannot be calculated
+ ConstantResult OpXor(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc
index f06fa32..3b1d7ee 100644
--- a/src/tint/resolver/const_eval_test.cc
+++ b/src/tint/resolver/const_eval_test.cc
@@ -3675,6 +3675,43 @@
});
}
+template <typename T>
+std::vector<Case> XorCases() {
+ using B = BitValues<T>;
+ return {
+ C(T{0b1010}, T{0b1111}, T{0b0101}),
+ C(T{0b1010}, T{0b0000}, T{0b1010}),
+ C(T{0b1010}, T{0b0011}, T{0b1001}),
+ C(T{0b1010}, T{0b1100}, T{0b0110}),
+ C(T{0b1010}, T{0b0101}, T{0b1111}),
+ C(B::All, B::All, T{0}),
+ C(B::LeftMost, B::LeftMost, T{0}),
+ C(B::RightMost, B::RightMost, T{0}),
+ C(B::All, T{0}, B::All),
+ C(T{0}, B::All, B::All),
+ C(B::LeftMost, B::AllButLeftMost, B::All),
+ C(B::AllButLeftMost, B::LeftMost, B::All),
+ C(B::RightMost, B::AllButRightMost, B::All),
+ C(B::AllButRightMost, B::RightMost, B::All),
+ C(Vec(B::All, B::LeftMost, B::RightMost), //
+ Vec(B::All, B::All, B::All), //
+ Vec(T{0}, B::AllButLeftMost, B::AllButRightMost)), //
+ C(Vec(B::All, B::LeftMost, B::RightMost), //
+ Vec(T{0}, T{0}, T{0}), //
+ Vec(B::All, B::LeftMost, B::RightMost)), //
+ C(Vec(B::LeftMost, B::RightMost), //
+ Vec(B::AllButLeftMost, B::AllButRightMost), //
+ Vec(B::All, B::All)),
+ };
+}
+INSTANTIATE_TEST_SUITE_P(Xor,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kXor),
+ testing::ValuesIn(Concat(XorCases<AInt>(), //
+ XorCases<i32>(), //
+ XorCases<u32>()))));
+
// Tests for errors on overflow/underflow of binary operations with abstract numbers
struct OverflowCase {
ast::BinaryOp op;
diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl
index 559466d..f74cb25 100644
--- a/src/tint/resolver/intrinsic_table.inl
+++ b/src/tint/resolver/intrinsic_table.inl
@@ -13122,24 +13122,24 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[14],
+ /* template types */ &kTemplateTypes[10],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[689],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpXor,
},
{
/* [413] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[14],
+ /* template types */ &kTemplateTypes[10],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[687],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpXor,
},
{
/* [414] */
@@ -14703,8 +14703,8 @@
},
{
/* [5] */
- /* op ^<T : iu32>(T, T) -> T */
- /* op ^<T : iu32, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
+ /* op ^<T : ia_iu32>(T, T) -> T */
+ /* op ^<T : ia_iu32, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[412],
},
diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl
index f3a33ad..6202d8b 100644
--- a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl
+++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl
@@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
- const int r = (1 ^ 2);
+ const int r = 3;
return;
}
diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl
index f3a33ad..6202d8b 100644
--- a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl
+++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl
@@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
- const int r = (1 ^ 2);
+ const int r = 3;
return;
}
diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl
index 235697c..aa5e335 100644
--- a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl
+++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl
@@ -1,7 +1,7 @@
#version 310 es
void f() {
- int r = (1 ^ 2);
+ int r = 3;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl
index 5e4b909..8a5e655 100644
--- a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl
+++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl
@@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
- const uint r = (1u ^ 2u);
+ const uint r = 3u;
return;
}
diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl
index 5e4b909..8a5e655 100644
--- a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl
+++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl
@@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
- const uint r = (1u ^ 2u);
+ const uint r = 3u;
return;
}
diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl
index c6cce08..936c36d 100644
--- a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl
+++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl
@@ -1,7 +1,7 @@
#version 310 es
void f() {
- uint r = (1u ^ 2u);
+ uint r = 3u;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;