[ir] Add BinaryPolyfill transform
This is for generic binary instruction polyfills that use core IR
constructs. Add a polyfill for the RHS of a bitwise shift.
Use this in the SPIR-V writer.
Bug: tint:1718, tint:1906
Change-Id: I2165af4ffb78a9d284f5762e696a3b2c09a5d2f1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/147301
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/transform/BUILD.cmake b/src/tint/lang/core/ir/transform/BUILD.cmake
index 01ee55d..664b66c 100644
--- a/src/tint/lang/core/ir/transform/BUILD.cmake
+++ b/src/tint/lang/core/ir/transform/BUILD.cmake
@@ -31,6 +31,8 @@
lang/core/ir/transform/add_empty_entry_point.h
lang/core/ir/transform/bgra8unorm_polyfill.cc
lang/core/ir/transform/bgra8unorm_polyfill.h
+ lang/core/ir/transform/binary_polyfill.cc
+ lang/core/ir/transform/binary_polyfill.h
lang/core/ir/transform/binding_remapper.cc
lang/core/ir/transform/binding_remapper.h
lang/core/ir/transform/block_decorated_structs.cc
@@ -84,6 +86,7 @@
tint_add_target(tint_lang_core_ir_transform_test test
lang/core/ir/transform/add_empty_entry_point_test.cc
lang/core/ir/transform/bgra8unorm_polyfill_test.cc
+ lang/core/ir/transform/binary_polyfill_test.cc
lang/core/ir/transform/binding_remapper_test.cc
lang/core/ir/transform/block_decorated_structs_test.cc
lang/core/ir/transform/builtin_polyfill_test.cc
diff --git a/src/tint/lang/core/ir/transform/BUILD.gn b/src/tint/lang/core/ir/transform/BUILD.gn
index d7950c0..ad8e700 100644
--- a/src/tint/lang/core/ir/transform/BUILD.gn
+++ b/src/tint/lang/core/ir/transform/BUILD.gn
@@ -30,6 +30,8 @@
"add_empty_entry_point.h",
"bgra8unorm_polyfill.cc",
"bgra8unorm_polyfill.h",
+ "binary_polyfill.cc",
+ "binary_polyfill.h",
"binding_remapper.cc",
"binding_remapper.h",
"block_decorated_structs.cc",
@@ -77,6 +79,7 @@
sources = [
"add_empty_entry_point_test.cc",
"bgra8unorm_polyfill_test.cc",
+ "binary_polyfill_test.cc",
"binding_remapper_test.cc",
"block_decorated_structs_test.cc",
"builtin_polyfill_test.cc",
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill.cc b/src/tint/lang/core/ir/transform/binary_polyfill.cc
new file mode 100644
index 0000000..b566d68
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/binary_polyfill.cc
@@ -0,0 +1,145 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/validator.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::core::ir::transform {
+
+namespace {
+
+/// PIMPL state for the transform.
+struct State {
+ /// The polyfill config.
+ const BinaryPolyfillConfig& config;
+
+ /// The IR module.
+ Module* ir = nullptr;
+
+ /// The IR builder.
+ Builder b{*ir};
+
+ /// The type manager.
+ core::type::Manager& ty{ir->Types()};
+
+ /// The symbol table.
+ SymbolTable& sym{ir->symbols};
+
+ /// Process the module.
+ void Process() {
+ // Find the binary instructions that need to be polyfilled.
+ Vector<ir::Binary*, 64> worklist;
+ for (auto* inst : ir->instructions.Objects()) {
+ if (!inst->Alive()) {
+ continue;
+ }
+ if (auto* binary = inst->As<ir::Binary>()) {
+ switch (binary->Kind()) {
+ case ir::Binary::Kind::kShiftLeft:
+ case ir::Binary::Kind::kShiftRight:
+ if (config.bitshift_modulo) {
+ worklist.Push(binary);
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
+ // Polyfill the binary instructions that we found.
+ for (auto* binary : worklist) {
+ ir::Value* replacement = nullptr;
+ switch (binary->Kind()) {
+ case Binary::Kind::kShiftLeft:
+ case Binary::Kind::kShiftRight:
+ replacement = MaskShiftAmount(binary);
+ break;
+ default:
+ break;
+ }
+ TINT_ASSERT_OR_RETURN(replacement);
+
+ if (replacement != binary->Result()) {
+ // Replace the old binary instruction result with the new value.
+ if (auto name = ir->NameOf(binary->Result())) {
+ ir->SetName(replacement, name);
+ }
+ binary->Result()->ReplaceAllUsesWith(replacement);
+ binary->Destroy();
+ }
+ }
+ }
+
+ /// Return a type with element type @p type that has the same number of vector components as
+ /// @p match. If @p match is scalar just return @p type.
+ /// @param el_ty the type to extend
+ /// @param match the type to match the component count of
+ /// @returns a type with the same number of vector components as @p match
+ const core::type::Type* MatchWidth(const core::type::Type* el_ty,
+ const core::type::Type* match) {
+ if (auto* vec = match->As<core::type::Vector>()) {
+ return ty.vec(el_ty, vec->Width());
+ }
+ return el_ty;
+ }
+
+ /// Return a constant that has the same number of vector components as @p match, each with the
+ /// value @p element. If @p match is scalar just return @p element.
+ /// @param element the value to extend
+ /// @param match the type to match the component count of
+ /// @returns a value with the same number of vector components as @p match
+ ir::Constant* MatchWidth(ir::Constant* element, const core::type::Type* match) {
+ if (auto* vec = match->As<core::type::Vector>()) {
+ return b.Splat(MatchWidth(element->Type(), match), element, vec->Width());
+ }
+ return element;
+ }
+
+ /// Mask the RHS of a shift instruction to ensure it is modulo the bitwidth of the LHS.
+ /// @param binary the binary instruction
+ /// @returns the replacement value
+ ir::Value* MaskShiftAmount(ir::Binary* binary) {
+ auto* lhs = binary->LHS();
+ auto* rhs = binary->RHS();
+ auto* mask = b.Constant(u32(lhs->Type()->DeepestElement()->Size() * 8 - 1));
+ auto* masked = b.And(rhs->Type(), rhs, MatchWidth(mask, rhs->Type()));
+ masked->InsertBefore(binary);
+ binary->SetOperand(ir::Binary::kRhsOperandOffset, masked->Result());
+ return binary->Result();
+ }
+};
+
+} // namespace
+
+Result<SuccessType, std::string> BinaryPolyfill(Module* ir, const BinaryPolyfillConfig& config) {
+ auto result = ValidateAndDumpIfNeeded(*ir, "BinaryPolyfill transform");
+ if (!result) {
+ return result;
+ }
+
+ State{config, ir}.Process();
+
+ return Success;
+}
+
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill.h b/src/tint/lang/core/ir/transform/binary_polyfill.h
new file mode 100644
index 0000000..7c84a88
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/binary_polyfill.h
@@ -0,0 +1,44 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_BINARY_POLYFILL_H_
+#define SRC_TINT_LANG_CORE_IR_TRANSFORM_BINARY_POLYFILL_H_
+
+#include <string>
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::core::ir::transform {
+
+/// The set of polyfills that should be applied.
+struct BinaryPolyfillConfig {
+ /// Should the RHS of a shift be masked to make it modulo the bit-width of the LHS?
+ bool bitshift_modulo = false;
+};
+
+/// BinaryPolyfill is a transform that modifies binary instructions to prepare them for raising to
+/// backend dialects that may have different semantics.
+/// @param module the module to transform
+/// @param config the polyfill configuration
+/// @returns an error string on failure
+Result<SuccessType, std::string> BinaryPolyfill(Module* module, const BinaryPolyfillConfig& config);
+
+} // namespace tint::core::ir::transform
+
+#endif // SRC_TINT_LANG_CORE_IR_TRANSFORM_BINARY_POLYFILL_H_
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill_test.cc b/src/tint/lang/core/ir/transform/binary_polyfill_test.cc
new file mode 100644
index 0000000..a937f1f
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/binary_polyfill_test.cc
@@ -0,0 +1,317 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
+#include "src/tint/lang/core/ir/binary.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+
+namespace tint::core::ir::transform {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+class IR_BinaryPolyfillTest : public TransformTest {
+ protected:
+ /// Helper to build a function that executes a binary instruction.
+ /// @param kind the binary operation
+ /// @param result_ty the result type of the builtin call
+ /// @param lhs_ty the type of the LHS
+ /// @param rhs_ty the type of the RHS
+ void Build(enum ir::Binary::Kind kind,
+ const core::type::Type* result_ty,
+ const core::type::Type* lhs_ty,
+ const core::type::Type* rhs_ty) {
+ Vector<FunctionParam*, 4> args;
+ args.Push(b.FunctionParam("lhs", lhs_ty));
+ args.Push(b.FunctionParam("rhs", rhs_ty));
+ auto* func = b.Function("foo", result_ty);
+ func->SetParams(args);
+ b.Append(func->Block(), [&] {
+ auto* result = b.Binary(kind, result_ty, args[0], args[1]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+ }
+};
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_NoPolyfill) {
+ Build(Binary::Kind::kShiftLeft, ty.i32(), ty.i32(), ty.i32());
+ auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+ %b1 = block {
+ %result:i32 = shiftl %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = src;
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = false;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_NoPolyfill) {
+ Build(Binary::Kind::kShiftRight, ty.i32(), ty.i32(), ty.i32());
+ auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+ %b1 = block {
+ %result:i32 = shiftr %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = src;
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = false;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_I32) {
+ Build(Binary::Kind::kShiftLeft, ty.i32(), ty.i32(), ty.i32());
+ auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+ %b1 = block {
+ %result:i32 = shiftl %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:i32 = and %rhs, 31u
+ %result:i32 = shiftl %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_U32) {
+ Build(Binary::Kind::kShiftLeft, ty.u32(), ty.u32(), ty.u32());
+ auto* src = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+ %b1 = block {
+ %result:u32 = shiftl %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+ %b1 = block {
+ %4:u32 = and %rhs, 31u
+ %result:u32 = shiftl %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_Vec2I32) {
+ Build(Binary::Kind::kShiftLeft, ty.vec2<i32>(), ty.vec2<i32>(), ty.vec2<i32>());
+ auto* src = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+ %b1 = block {
+ %result:vec2<i32> = shiftl %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+ %b1 = block {
+ %4:vec2<i32> = and %rhs, vec2<u32>(31u)
+ %result:vec2<i32> = shiftl %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_Vec3U32) {
+ Build(Binary::Kind::kShiftLeft, ty.vec3<u32>(), ty.vec3<u32>(), ty.vec3<u32>());
+ auto* src = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+ %b1 = block {
+ %result:vec3<u32> = shiftl %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+ %b1 = block {
+ %4:vec3<u32> = and %rhs, vec3<u32>(31u)
+ %result:vec3<u32> = shiftl %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_I32) {
+ Build(Binary::Kind::kShiftRight, ty.i32(), ty.i32(), ty.i32());
+ auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+ %b1 = block {
+ %result:i32 = shiftr %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+ %b1 = block {
+ %4:i32 = and %rhs, 31u
+ %result:i32 = shiftr %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_U32) {
+ Build(Binary::Kind::kShiftRight, ty.u32(), ty.u32(), ty.u32());
+ auto* src = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+ %b1 = block {
+ %result:u32 = shiftr %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+ %b1 = block {
+ %4:u32 = and %rhs, 31u
+ %result:u32 = shiftr %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_Vec2I32) {
+ Build(Binary::Kind::kShiftRight, ty.vec2<i32>(), ty.vec2<i32>(), ty.vec2<i32>());
+ auto* src = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+ %b1 = block {
+ %result:vec2<i32> = shiftr %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+ %b1 = block {
+ %4:vec2<i32> = and %rhs, vec2<u32>(31u)
+ %result:vec2<i32> = shiftr %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_Vec3U32) {
+ Build(Binary::Kind::kShiftRight, ty.vec3<u32>(), ty.vec3<u32>(), ty.vec3<u32>());
+ auto* src = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+ %b1 = block {
+ %result:vec3<u32> = shiftr %lhs, %rhs
+ ret %result
+ }
+}
+)";
+ auto* expect = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+ %b1 = block {
+ %4:vec3<u32> = and %rhs, vec3<u32>(31u)
+ %result:vec3<u32> = shiftr %lhs, %4
+ ret %result
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+
+ BinaryPolyfillConfig config;
+ config.bitshift_modulo = true;
+ Run(BinaryPolyfill, config);
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 2d510db..2c519a3 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -18,6 +18,7 @@
#include "src/tint/lang/core/ir/transform/add_empty_entry_point.h"
#include "src/tint/lang/core/ir/transform/bgra8unorm_polyfill.h"
+#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
#include "src/tint/lang/core/ir/transform/block_decorated_structs.h"
#include "src/tint/lang/core/ir/transform/builtin_polyfill.h"
#include "src/tint/lang/core/ir/transform/demote_to_helper.h"
@@ -41,6 +42,10 @@
} \
} while (false)
+ core::ir::transform::BinaryPolyfillConfig binary_polyfills;
+ binary_polyfills.bitshift_modulo = true;
+ RUN_TRANSFORM(core::ir::transform::BinaryPolyfill, module, binary_polyfills);
+
core::ir::transform::BuiltinPolyfillConfig core_polyfills;
core_polyfills.count_leading_zeros = true;
core_polyfills.count_trailing_zeros = true;