[ir] Add BinaryPolyfill transform

This is for generic binary instruction polyfills that use core IR
constructs. Add a polyfill for the RHS of a bitwise shift.

Use this in the SPIR-V writer.

Bug: tint:1718, tint:1906
Change-Id: I2165af4ffb78a9d284f5762e696a3b2c09a5d2f1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/147301
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/transform/BUILD.cmake b/src/tint/lang/core/ir/transform/BUILD.cmake
index 01ee55d..664b66c 100644
--- a/src/tint/lang/core/ir/transform/BUILD.cmake
+++ b/src/tint/lang/core/ir/transform/BUILD.cmake
@@ -31,6 +31,8 @@
   lang/core/ir/transform/add_empty_entry_point.h
   lang/core/ir/transform/bgra8unorm_polyfill.cc
   lang/core/ir/transform/bgra8unorm_polyfill.h
+  lang/core/ir/transform/binary_polyfill.cc
+  lang/core/ir/transform/binary_polyfill.h
   lang/core/ir/transform/binding_remapper.cc
   lang/core/ir/transform/binding_remapper.h
   lang/core/ir/transform/block_decorated_structs.cc
@@ -84,6 +86,7 @@
 tint_add_target(tint_lang_core_ir_transform_test test
   lang/core/ir/transform/add_empty_entry_point_test.cc
   lang/core/ir/transform/bgra8unorm_polyfill_test.cc
+  lang/core/ir/transform/binary_polyfill_test.cc
   lang/core/ir/transform/binding_remapper_test.cc
   lang/core/ir/transform/block_decorated_structs_test.cc
   lang/core/ir/transform/builtin_polyfill_test.cc
diff --git a/src/tint/lang/core/ir/transform/BUILD.gn b/src/tint/lang/core/ir/transform/BUILD.gn
index d7950c0..ad8e700 100644
--- a/src/tint/lang/core/ir/transform/BUILD.gn
+++ b/src/tint/lang/core/ir/transform/BUILD.gn
@@ -30,6 +30,8 @@
       "add_empty_entry_point.h",
       "bgra8unorm_polyfill.cc",
       "bgra8unorm_polyfill.h",
+      "binary_polyfill.cc",
+      "binary_polyfill.h",
       "binding_remapper.cc",
       "binding_remapper.h",
       "block_decorated_structs.cc",
@@ -77,6 +79,7 @@
     sources = [
       "add_empty_entry_point_test.cc",
       "bgra8unorm_polyfill_test.cc",
+      "binary_polyfill_test.cc",
       "binding_remapper_test.cc",
       "block_decorated_structs_test.cc",
       "builtin_polyfill_test.cc",
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill.cc b/src/tint/lang/core/ir/transform/binary_polyfill.cc
new file mode 100644
index 0000000..b566d68
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/binary_polyfill.cc
@@ -0,0 +1,145 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/validator.h"
+
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
+
+namespace tint::core::ir::transform {
+
+namespace {
+
+/// PIMPL state for the transform.
+struct State {
+    /// The polyfill config.
+    const BinaryPolyfillConfig& config;
+
+    /// The IR module.
+    Module* ir = nullptr;
+
+    /// The IR builder.
+    Builder b{*ir};
+
+    /// The type manager.
+    core::type::Manager& ty{ir->Types()};
+
+    /// The symbol table.
+    SymbolTable& sym{ir->symbols};
+
+    /// Process the module.
+    void Process() {
+        // Find the binary instructions that need to be polyfilled.
+        Vector<ir::Binary*, 64> worklist;
+        for (auto* inst : ir->instructions.Objects()) {
+            if (!inst->Alive()) {
+                continue;
+            }
+            if (auto* binary = inst->As<ir::Binary>()) {
+                switch (binary->Kind()) {
+                    case ir::Binary::Kind::kShiftLeft:
+                    case ir::Binary::Kind::kShiftRight:
+                        if (config.bitshift_modulo) {
+                            worklist.Push(binary);
+                        }
+                        break;
+                    default:
+                        break;
+                }
+            }
+        }
+
+        // Polyfill the binary instructions that we found.
+        for (auto* binary : worklist) {
+            ir::Value* replacement = nullptr;
+            switch (binary->Kind()) {
+                case Binary::Kind::kShiftLeft:
+                case Binary::Kind::kShiftRight:
+                    replacement = MaskShiftAmount(binary);
+                    break;
+                default:
+                    break;
+            }
+            TINT_ASSERT_OR_RETURN(replacement);
+
+            if (replacement != binary->Result()) {
+                // Replace the old binary instruction result with the new value.
+                if (auto name = ir->NameOf(binary->Result())) {
+                    ir->SetName(replacement, name);
+                }
+                binary->Result()->ReplaceAllUsesWith(replacement);
+                binary->Destroy();
+            }
+        }
+    }
+
+    /// Return a type with element type @p type that has the same number of vector components as
+    /// @p match. If @p match is scalar just return @p type.
+    /// @param el_ty the type to extend
+    /// @param match the type to match the component count of
+    /// @returns a type with the same number of vector components as @p match
+    const core::type::Type* MatchWidth(const core::type::Type* el_ty,
+                                       const core::type::Type* match) {
+        if (auto* vec = match->As<core::type::Vector>()) {
+            return ty.vec(el_ty, vec->Width());
+        }
+        return el_ty;
+    }
+
+    /// Return a constant that has the same number of vector components as @p match, each with the
+    /// value @p element. If @p match is scalar just return @p element.
+    /// @param element the value to extend
+    /// @param match the type to match the component count of
+    /// @returns a value with the same number of vector components as @p match
+    ir::Constant* MatchWidth(ir::Constant* element, const core::type::Type* match) {
+        if (auto* vec = match->As<core::type::Vector>()) {
+            return b.Splat(MatchWidth(element->Type(), match), element, vec->Width());
+        }
+        return element;
+    }
+
+    /// Mask the RHS of a shift instruction to ensure it is modulo the bitwidth of the LHS.
+    /// @param binary the binary instruction
+    /// @returns the replacement value
+    ir::Value* MaskShiftAmount(ir::Binary* binary) {
+        auto* lhs = binary->LHS();
+        auto* rhs = binary->RHS();
+        auto* mask = b.Constant(u32(lhs->Type()->DeepestElement()->Size() * 8 - 1));
+        auto* masked = b.And(rhs->Type(), rhs, MatchWidth(mask, rhs->Type()));
+        masked->InsertBefore(binary);
+        binary->SetOperand(ir::Binary::kRhsOperandOffset, masked->Result());
+        return binary->Result();
+    }
+};
+
+}  // namespace
+
+Result<SuccessType, std::string> BinaryPolyfill(Module* ir, const BinaryPolyfillConfig& config) {
+    auto result = ValidateAndDumpIfNeeded(*ir, "BinaryPolyfill transform");
+    if (!result) {
+        return result;
+    }
+
+    State{config, ir}.Process();
+
+    return Success;
+}
+
+}  // namespace tint::core::ir::transform
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill.h b/src/tint/lang/core/ir/transform/binary_polyfill.h
new file mode 100644
index 0000000..7c84a88
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/binary_polyfill.h
@@ -0,0 +1,44 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_BINARY_POLYFILL_H_
+#define SRC_TINT_LANG_CORE_IR_TRANSFORM_BINARY_POLYFILL_H_
+
+#include <string>
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::core::ir::transform {
+
+/// The set of polyfills that should be applied.
+struct BinaryPolyfillConfig {
+    /// Should the RHS of a shift be masked to make it modulo the bit-width of the LHS?
+    bool bitshift_modulo = false;
+};
+
+/// BinaryPolyfill is a transform that modifies binary instructions to prepare them for raising to
+/// backend dialects that may have different semantics.
+/// @param module the module to transform
+/// @param config the polyfill configuration
+/// @returns an error string on failure
+Result<SuccessType, std::string> BinaryPolyfill(Module* module, const BinaryPolyfillConfig& config);
+
+}  // namespace tint::core::ir::transform
+
+#endif  // SRC_TINT_LANG_CORE_IR_TRANSFORM_BINARY_POLYFILL_H_
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill_test.cc b/src/tint/lang/core/ir/transform/binary_polyfill_test.cc
new file mode 100644
index 0000000..a937f1f
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/binary_polyfill_test.cc
@@ -0,0 +1,317 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
+#include "src/tint/lang/core/ir/binary.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+
+namespace tint::core::ir::transform {
+namespace {
+
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
+
+class IR_BinaryPolyfillTest : public TransformTest {
+  protected:
+    /// Helper to build a function that executes a binary instruction.
+    /// @param kind the binary operation
+    /// @param result_ty the result type of the builtin call
+    /// @param lhs_ty the type of the LHS
+    /// @param rhs_ty the type of the RHS
+    void Build(enum ir::Binary::Kind kind,
+               const core::type::Type* result_ty,
+               const core::type::Type* lhs_ty,
+               const core::type::Type* rhs_ty) {
+        Vector<FunctionParam*, 4> args;
+        args.Push(b.FunctionParam("lhs", lhs_ty));
+        args.Push(b.FunctionParam("rhs", rhs_ty));
+        auto* func = b.Function("foo", result_ty);
+        func->SetParams(args);
+        b.Append(func->Block(), [&] {
+            auto* result = b.Binary(kind, result_ty, args[0], args[1]);
+            b.Return(func, result);
+            mod.SetName(result, "result");
+        });
+    }
+};
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_NoPolyfill) {
+    Build(Binary::Kind::kShiftLeft, ty.i32(), ty.i32(), ty.i32());
+    auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+  %b1 = block {
+    %result:i32 = shiftl %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = src;
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = false;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_NoPolyfill) {
+    Build(Binary::Kind::kShiftRight, ty.i32(), ty.i32(), ty.i32());
+    auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+  %b1 = block {
+    %result:i32 = shiftr %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = src;
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = false;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_I32) {
+    Build(Binary::Kind::kShiftLeft, ty.i32(), ty.i32(), ty.i32());
+    auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+  %b1 = block {
+    %result:i32 = shiftl %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+  %b1 = block {
+    %4:i32 = and %rhs, 31u
+    %result:i32 = shiftl %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_U32) {
+    Build(Binary::Kind::kShiftLeft, ty.u32(), ty.u32(), ty.u32());
+    auto* src = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+  %b1 = block {
+    %result:u32 = shiftl %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+  %b1 = block {
+    %4:u32 = and %rhs, 31u
+    %result:u32 = shiftl %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_Vec2I32) {
+    Build(Binary::Kind::kShiftLeft, ty.vec2<i32>(), ty.vec2<i32>(), ty.vec2<i32>());
+    auto* src = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+  %b1 = block {
+    %result:vec2<i32> = shiftl %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+  %b1 = block {
+    %4:vec2<i32> = and %rhs, vec2<u32>(31u)
+    %result:vec2<i32> = shiftl %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftLeft_Vec3U32) {
+    Build(Binary::Kind::kShiftLeft, ty.vec3<u32>(), ty.vec3<u32>(), ty.vec3<u32>());
+    auto* src = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+  %b1 = block {
+    %result:vec3<u32> = shiftl %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+  %b1 = block {
+    %4:vec3<u32> = and %rhs, vec3<u32>(31u)
+    %result:vec3<u32> = shiftl %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_I32) {
+    Build(Binary::Kind::kShiftRight, ty.i32(), ty.i32(), ty.i32());
+    auto* src = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+  %b1 = block {
+    %result:i32 = shiftr %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:i32, %rhs:i32):i32 -> %b1 {
+  %b1 = block {
+    %4:i32 = and %rhs, 31u
+    %result:i32 = shiftr %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_U32) {
+    Build(Binary::Kind::kShiftRight, ty.u32(), ty.u32(), ty.u32());
+    auto* src = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+  %b1 = block {
+    %result:u32 = shiftr %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:u32, %rhs:u32):u32 -> %b1 {
+  %b1 = block {
+    %4:u32 = and %rhs, 31u
+    %result:u32 = shiftr %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_Vec2I32) {
+    Build(Binary::Kind::kShiftRight, ty.vec2<i32>(), ty.vec2<i32>(), ty.vec2<i32>());
+    auto* src = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+  %b1 = block {
+    %result:vec2<i32> = shiftr %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:vec2<i32>, %rhs:vec2<i32>):vec2<i32> -> %b1 {
+  %b1 = block {
+    %4:vec2<i32> = and %rhs, vec2<u32>(31u)
+    %result:vec2<i32> = shiftr %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_BinaryPolyfillTest, ShiftRight_Vec3U32) {
+    Build(Binary::Kind::kShiftRight, ty.vec3<u32>(), ty.vec3<u32>(), ty.vec3<u32>());
+    auto* src = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+  %b1 = block {
+    %result:vec3<u32> = shiftr %lhs, %rhs
+    ret %result
+  }
+}
+)";
+    auto* expect = R"(
+%foo = func(%lhs:vec3<u32>, %rhs:vec3<u32>):vec3<u32> -> %b1 {
+  %b1 = block {
+    %4:vec3<u32> = and %rhs, vec3<u32>(31u)
+    %result:vec3<u32> = shiftr %lhs, %4
+    ret %result
+  }
+}
+)";
+
+    EXPECT_EQ(src, str());
+
+    BinaryPolyfillConfig config;
+    config.bitshift_modulo = true;
+    Run(BinaryPolyfill, config);
+    EXPECT_EQ(expect, str());
+}
+
+}  // namespace
+}  // namespace tint::core::ir::transform
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 2d510db..2c519a3 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -18,6 +18,7 @@
 
 #include "src/tint/lang/core/ir/transform/add_empty_entry_point.h"
 #include "src/tint/lang/core/ir/transform/bgra8unorm_polyfill.h"
+#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
 #include "src/tint/lang/core/ir/transform/block_decorated_structs.h"
 #include "src/tint/lang/core/ir/transform/builtin_polyfill.h"
 #include "src/tint/lang/core/ir/transform/demote_to_helper.h"
@@ -41,6 +42,10 @@
         }                                \
     } while (false)
 
+    core::ir::transform::BinaryPolyfillConfig binary_polyfills;
+    binary_polyfills.bitshift_modulo = true;
+    RUN_TRANSFORM(core::ir::transform::BinaryPolyfill, module, binary_polyfills);
+
     core::ir::transform::BuiltinPolyfillConfig core_polyfills;
     core_polyfills.count_leading_zeros = true;
     core_polyfills.count_trailing_zeros = true;