[ir][spirv-writer] Expand implicit vector splats

Add a new transform that looks for implicit vector splats in construct
or binary instructions and expands them to be explicitly constructed
vectors instead.

Add an intrinsic for OpVectorTimesScalar and use that for the floating
point vector * scalar case, instead of pattern matching in the writer.

Bug: tint:1906
Change-Id: I0452ec375916d603ac28b19e362ae8d6c2f4f88e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/141234
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 729d601..cb88576 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -540,6 +540,8 @@
       sources += [
         "ir/transform/builtin_polyfill_spirv.cc",
         "ir/transform/builtin_polyfill_spirv.h",
+        "ir/transform/expand_implicit_splats.cc",
+        "ir/transform/expand_implicit_splats.h",
         "ir/transform/merge_return.cc",
         "ir/transform/merge_return.h",
         "ir/transform/shader_io_spirv.cc",
@@ -1908,6 +1910,7 @@
       if (tint_build_spv_writer) {
         sources += [
           "ir/transform/builtin_polyfill_spirv_test.cc",
+          "ir/transform/expand_implicit_splats_test.cc",
           "ir/transform/merge_return_test.cc",
           "ir/transform/shader_io_spirv_test.cc",
           "ir/transform/var_for_dynamic_index_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index e7b6cb2..996574d 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -679,6 +679,8 @@
     list(APPEND TINT_LIB_SRCS
       ir/transform/builtin_polyfill_spirv.cc
       ir/transform/builtin_polyfill_spirv.h
+      ir/transform/expand_implicit_splats.cc
+      ir/transform/expand_implicit_splats.h
       ir/transform/merge_return.cc
       ir/transform/merge_return.h
       ir/transform/shader_io_spirv.cc
@@ -1332,6 +1334,7 @@
     if(${TINT_BUILD_IR})
       list(APPEND TINT_TEST_SRCS
         ir/transform/builtin_polyfill_spirv_test.cc
+        ir/transform/expand_implicit_splats_test.cc
         ir/transform/merge_return_test.cc
         ir/transform/shader_io_spirv_test.cc
         ir/transform/var_for_dynamic_index_test.cc
diff --git a/src/tint/ir/call.h b/src/tint/ir/call.h
index 2d1ff35..ea7e14e 100644
--- a/src/tint/ir/call.h
+++ b/src/tint/ir/call.h
@@ -28,6 +28,10 @@
     /// @returns the call arguments
     virtual utils::Slice<Value*> Args() { return operands_.Slice(); }
 
+    /// Append a new argument to the argument list for this call instruction.
+    /// @param arg the argument value to append
+    void AppendArg(ir::Value* arg) { AddOperand(operands_.Length(), arg); }
+
   protected:
     /// Constructor
     Call();
diff --git a/src/tint/ir/intrinsic_call.cc b/src/tint/ir/intrinsic_call.cc
index 1170418..b75eb9f 100644
--- a/src/tint/ir/intrinsic_call.cc
+++ b/src/tint/ir/intrinsic_call.cc
@@ -37,12 +37,6 @@
         case IntrinsicCall::Kind::kSpirvDot:
             out << "spirv.dot";
             break;
-        case IntrinsicCall::Kind::kSpirvSelect:
-            out << "spirv.select";
-            break;
-        case IntrinsicCall::Kind::kSpirvSampledImage:
-            out << "spirv.sampled_image";
-            break;
         case IntrinsicCall::Kind::kSpirvImageSampleImplicitLod:
             out << "spirv.image_sample_implicit_lod";
             break;
@@ -55,6 +49,15 @@
         case IntrinsicCall::Kind::kSpirvImageSampleDrefExplicitLod:
-            out << "spirv.image_sample_dref_implicit_lod";
+            out << "spirv.image_sample_dref_explicit_lod";
             break;
+        case IntrinsicCall::Kind::kSpirvSampledImage:
+            out << "spirv.sampled_image";
+            break;
+        case IntrinsicCall::Kind::kSpirvSelect:
+            out << "spirv.select";
+            break;
+        case IntrinsicCall::Kind::kSpirvVectorTimesScalar:
+            out << "spirv.vector_times_scalar";
+            break;
     }
     return out;
 }
diff --git a/src/tint/ir/intrinsic_call.h b/src/tint/ir/intrinsic_call.h
index f535972..3bf93e0 100644
--- a/src/tint/ir/intrinsic_call.h
+++ b/src/tint/ir/intrinsic_call.h
@@ -30,12 +30,13 @@
     enum class Kind {
         // SPIR-V backend intrinsics.
         kSpirvDot,
-        kSpirvSelect,
-        kSpirvSampledImage,
         kSpirvImageSampleImplicitLod,
         kSpirvImageSampleExplicitLod,
         kSpirvImageSampleDrefImplicitLod,
         kSpirvImageSampleDrefExplicitLod,
+        kSpirvSampledImage,
+        kSpirvSelect,
+        kSpirvVectorTimesScalar,
     };
 
     /// Constructor
diff --git a/src/tint/ir/transform/expand_implicit_splats.cc b/src/tint/ir/transform/expand_implicit_splats.cc
new file mode 100644
index 0000000..9e052ce
--- /dev/null
+++ b/src/tint/ir/transform/expand_implicit_splats.cc
@@ -0,0 +1,108 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/expand_implicit_splats.h"
+
+#include <utility>
+
+#include "src/tint/ir/builder.h"
+#include "src/tint/ir/module.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ExpandImplicitSplats);
+
+using namespace tint::number_suffixes;  // NOLINT
+
+namespace tint::ir::transform {
+
+ExpandImplicitSplats::ExpandImplicitSplats() = default;
+
+ExpandImplicitSplats::~ExpandImplicitSplats() = default;
+
+void ExpandImplicitSplats::Run(ir::Module* ir, const DataMap&, DataMap&) const {
+    ir::Builder b(*ir);
+
+    // Find the instructions that use implicit splats and either modify them in place or record them
+    // to be replaced in a second pass.
+    utils::Vector<Binary*, 4> binary_worklist;
+    for (auto* inst : ir->instructions.Objects()) {
+        if (!inst->Alive()) {
+            continue;
+        }
+        if (auto* construct = inst->As<Construct>()) {
+            // A vector constructor with a single scalar argument needs to be modified to replicate
+            // the argument N times.
+            auto* vec = construct->Result()->Type()->As<type::Vector>();
+            if (vec &&  //
+                construct->Args().Length() == 1 &&
+                construct->Args()[0]->Type()->Is<type::Scalar>()) {
+                for (uint32_t i = 1; i < vec->Width(); i++) {
+                    construct->AppendArg(construct->Args()[0]);
+                }
+            }
+        } else if (auto* binary = inst->As<Binary>()) {
+            // A binary instruction that mixes vector and scalar operands needs to have the scalar
+            // operand replaced with an explicit vector constructor.
+            if (binary->Result()->Type()->Is<type::Vector>()) {
+                if (binary->LHS()->Type()->Is<type::Scalar>() ||
+                    binary->RHS()->Type()->Is<type::Scalar>()) {
+                    binary_worklist.Push(binary);
+                }
+            }
+        }
+    }
+
+    // Helper to expand a scalar operand of an instruction by replacing it with an explicitly
+    // constructed vector that matches the result type.
+    auto expand_operand = [&](Instruction* inst, size_t operand_idx) {
+        auto* vec = inst->Result()->Type()->As<type::Vector>();
+
+        utils::Vector<Value*, 4> args;
+        args.Resize(vec->Width(), inst->Operands()[operand_idx]);
+
+        auto* construct = b.Construct(vec, std::move(args));
+        construct->InsertBefore(inst);
+        inst->SetOperand(operand_idx, construct->Result());
+    };
+
+    // Replace scalar operands to binary instructions that produce vectors.
+    for (auto* binary : binary_worklist) {
+        auto* result_ty = binary->Result()->Type();
+        if (result_ty->is_float_vector() && binary->Kind() == Binary::Kind::kMultiply) {
+            // Use OpVectorTimesScalar for floating point multiply.
+            auto* vts = b.Call(result_ty, IntrinsicCall::Kind::kSpirvVectorTimesScalar);
+            if (binary->LHS()->Type()->Is<type::Scalar>()) {
+                vts->AppendArg(binary->RHS());
+                vts->AppendArg(binary->LHS());
+            } else {
+                vts->AppendArg(binary->LHS());
+                vts->AppendArg(binary->RHS());
+            }
+            if (auto name = ir->NameOf(binary)) {
+                ir->SetName(vts->Result(), name);
+            }
+            binary->Result()->ReplaceAllUsesWith(vts->Result());
+            binary->ReplaceWith(vts);
+            binary->Destroy();
+        } else {
+            // Expand the scalar argument into an explicitly constructed vector.
+            if (binary->LHS()->Type()->Is<type::Scalar>()) {
+                expand_operand(binary, Binary::kLhsOperandOffset);
+            } else if (binary->RHS()->Type()->Is<type::Scalar>()) {
+                expand_operand(binary, Binary::kRhsOperandOffset);
+            }
+        }
+    }
+}
+
+}  // namespace tint::ir::transform
diff --git a/src/tint/ir/transform/expand_implicit_splats.h b/src/tint/ir/transform/expand_implicit_splats.h
new file mode 100644
index 0000000..ec1ae18
--- /dev/null
+++ b/src/tint/ir/transform/expand_implicit_splats.h
@@ -0,0 +1,37 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_TRANSFORM_EXPAND_IMPLICIT_SPLATS_H_
+#define SRC_TINT_IR_TRANSFORM_EXPAND_IMPLICIT_SPLATS_H_
+
+#include "src/tint/ir/transform/transform.h"
+
+namespace tint::ir::transform {
+
+/// ExpandImplicitSplats is a transform that expands implicit vector splat operands in construct
+/// instructions and binary instructions where not supported by SPIR-V.
+class ExpandImplicitSplats final : public utils::Castable<ExpandImplicitSplats, Transform> {
+  public:
+    /// Constructor
+    ExpandImplicitSplats();
+    /// Destructor
+    ~ExpandImplicitSplats() override;
+
+    /// @copydoc Transform::Run
+    void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+}  // namespace tint::ir::transform
+
+#endif  // SRC_TINT_IR_TRANSFORM_EXPAND_IMPLICIT_SPLATS_H_
diff --git a/src/tint/ir/transform/expand_implicit_splats_test.cc b/src/tint/ir/transform/expand_implicit_splats_test.cc
new file mode 100644
index 0000000..33e86b2
--- /dev/null
+++ b/src/tint/ir/transform/expand_implicit_splats_test.cc
@@ -0,0 +1,635 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/transform/expand_implicit_splats.h"
+
+#include <utility>
+
+#include "src/tint/ir/transform/test_helper.h"
+
+namespace tint::ir::transform {
+namespace {
+
+using namespace tint::builtin::fluent_types;  // NOLINT
+using namespace tint::number_suffixes;        // NOLINT
+
+using IR_ExpandImplicitSplatsTest = TransformTest;
+
+TEST_F(IR_ExpandImplicitSplatsTest, NoModify_Construct_VectorIdentity) {
+    auto* vector = b.FunctionParam("vector", ty.vec2<i32>());
+    auto* func = b.Function("foo", ty.vec2<i32>());
+    func->SetParams({vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Construct(ty.vec2<i32>(), vector);
+        b.Return(func, result);
+    });
+
+    auto* expect = R"(
+%foo = func(%vector:vec2<i32>):vec2<i32> -> %b1 {
+  %b1 = block {
+    %3:vec2<i32> = construct %vector
+    ret %3
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, NoModify_Construct_MixedScalarVector) {
+    auto* scalar = b.FunctionParam("scalar", ty.i32());
+    auto* vector = b.FunctionParam("vector", ty.vec2<i32>());
+    auto* func = b.Function("foo", ty.vec3<i32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Construct(ty.vec3<i32>(), scalar, vector);
+        b.Return(func, result);
+    });
+
+    auto* expect = R"(
+%foo = func(%scalar:i32, %vector:vec2<i32>):vec3<i32> -> %b1 {
+  %b1 = block {
+    %4:vec3<i32> = construct %scalar, %vector
+    ret %4
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, NoModify_Construct_AllScalars) {
+    auto* scalar = b.FunctionParam("scalar", ty.i32());
+    auto* func = b.Function("foo", ty.vec3<i32>());
+    func->SetParams({scalar});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Construct(ty.vec3<i32>(), scalar, scalar, scalar);
+        b.Return(func, result);
+    });
+
+    auto* expect = R"(
+%foo = func(%scalar:i32):vec3<i32> -> %b1 {
+  %b1 = block {
+    %3:vec3<i32> = construct %scalar, %scalar, %scalar
+    ret %3
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, Construct_Splat_Vec2i) {
+    auto* scalar = b.FunctionParam("scalar", ty.i32());
+    auto* func = b.Function("foo", ty.vec2<i32>());
+    func->SetParams({scalar});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Construct(ty.vec2<i32>(), scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:i32):vec2<i32> -> %b1 {
+  %b1 = block {
+    %3:vec2<i32> = construct %scalar
+    ret %3
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:i32):vec2<i32> -> %b1 {
+  %b1 = block {
+    %3:vec2<i32> = construct %scalar, %scalar
+    ret %3
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, Construct_Splat_Vec3u) {
+    auto* scalar = b.FunctionParam("scalar", ty.u32());
+    auto* func = b.Function("foo", ty.vec3<u32>());
+    func->SetParams({scalar});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Construct(ty.vec3<u32>(), scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:u32):vec3<u32> -> %b1 {
+  %b1 = block {
+    %3:vec3<u32> = construct %scalar
+    ret %3
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:u32):vec3<u32> -> %b1 {
+  %b1 = block {
+    %3:vec3<u32> = construct %scalar, %scalar, %scalar
+    ret %3
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, Construct_Splat_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Construct(ty.vec4<f32>(), scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32):vec4<f32> -> %b1 {
+  %b1 = block {
+    %3:vec4<f32> = construct %scalar
+    ret %3
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32):vec4<f32> -> %b1 {
+  %b1 = block {
+    %3:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    ret %3
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryAdd_VectorScalar_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Add(ty.vec4<f32>(), vector, scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = add %vector, %scalar
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = add %vector, %4
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryAdd_ScalarVector_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Add(ty.vec4<f32>(), scalar, vector);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = add %scalar, %vector
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = add %4, %vector
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinarySubtract_VectorScalar_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Subtract(ty.vec4<f32>(), vector, scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = sub %vector, %scalar
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = sub %vector, %4
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinarySubtract_ScalarVector_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Subtract(ty.vec4<f32>(), scalar, vector);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = sub %scalar, %vector
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = sub %4, %vector
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryDivide_VectorScalar_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Divide(ty.vec4<f32>(), vector, scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = div %vector, %scalar
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = div %vector, %4
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryDivide_ScalarVector_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Divide(ty.vec4<f32>(), scalar, vector);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = div %scalar, %vector
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = div %4, %vector
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryModulo_VectorScalar_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Modulo(ty.vec4<f32>(), vector, scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = mod %vector, %scalar
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = mod %vector, %4
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryModulo_ScalarVector_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Modulo(ty.vec4<f32>(), scalar, vector);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = mod %scalar, %vector
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<f32> = mod %4, %vector
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_VectorScalar_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Multiply(ty.vec4<f32>(), vector, scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = mul %vector, %scalar
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = spirv.vector_times_scalar %vector, %scalar
+    ret %4
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_ScalarVector_Vec4f) {
+    auto* scalar = b.FunctionParam("scalar", ty.f32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+    auto* func = b.Function("foo", ty.vec4<f32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Multiply(ty.vec4<f32>(), scalar, vector);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = mul %scalar, %vector
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:f32, %vector:vec4<f32>):vec4<f32> -> %b1 {
+  %b1 = block {
+    %4:vec4<f32> = spirv.vector_times_scalar %vector, %scalar
+    ret %4
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_VectorScalar_Vec4i) {
+    auto* scalar = b.FunctionParam("scalar", ty.i32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<i32>());
+    auto* func = b.Function("foo", ty.vec4<i32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Multiply(ty.vec4<i32>(), vector, scalar);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+  %b1 = block {
+    %4:vec4<i32> = mul %vector, %scalar
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+  %b1 = block {
+    %4:vec4<i32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<i32> = mul %vector, %4
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ExpandImplicitSplatsTest, BinaryMultiply_ScalarVector_Vec4i) {
+    auto* scalar = b.FunctionParam("scalar", ty.i32());
+    auto* vector = b.FunctionParam("vector", ty.vec4<i32>());
+    auto* func = b.Function("foo", ty.vec4<i32>());
+    func->SetParams({scalar, vector});
+
+    b.With(func->Block(), [&] {
+        auto* result = b.Multiply(ty.vec4<i32>(), scalar, vector);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+  %b1 = block {
+    %4:vec4<i32> = mul %scalar, %vector
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%foo = func(%scalar:i32, %vector:vec4<i32>):vec4<i32> -> %b1 {
+  %b1 = block {
+    %4:vec4<i32> = construct %scalar, %scalar, %scalar, %scalar
+    %5:vec4<i32> = mul %4, %vector
+    ret %5
+  }
+}
+)";
+
+    Run<ExpandImplicitSplats>();
+
+    EXPECT_EQ(expect, str());
+}
+
+}  // namespace
+}  // namespace tint::ir::transform
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index bae2f8e..0ae9796 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -46,6 +46,7 @@
 #include "src/tint/ir/transform/block_decorated_structs.h"
 #include "src/tint/ir/transform/builtin_polyfill_spirv.h"
 #include "src/tint/ir/transform/demote_to_helper.h"
+#include "src/tint/ir/transform/expand_implicit_splats.h"
 #include "src/tint/ir/transform/merge_return.h"
 #include "src/tint/ir/transform/shader_io_spirv.h"
 #include "src/tint/ir/transform/var_for_dynamic_index.h"
@@ -94,6 +95,7 @@
     manager.Add<ir::transform::BlockDecoratedStructs>();
     manager.Add<ir::transform::BuiltinPolyfillSpirv>();
     manager.Add<ir::transform::DemoteToHelper>();
+    manager.Add<ir::transform::ExpandImplicitSplats>();
     manager.Add<ir::transform::MergeReturn>();
     manager.Add<ir::transform::ShaderIOSpirv>();
     manager.Add<ir::transform::VarForDynamicIndex>();
@@ -978,13 +980,6 @@
             } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_vector()) {
                 // Two float vectors multiply with OpFMul.
                 op = spv::Op::OpFMul;
-            } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_vector()) {
-                // Use OpVectorTimesScalar for scalar * vector, and swap the operand order.
-                std::swap(lhs, rhs);
-                op = spv::Op::OpVectorTimesScalar;
-            } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_scalar()) {
-                // Use OpVectorTimesScalar for scalar * vector.
-                op = spv::Op::OpVectorTimesScalar;
             } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_matrix()) {
                 // Use OpMatrixTimesScalar for scalar * matrix, and swap the operand order.
                 std::swap(lhs, rhs);
@@ -1446,12 +1441,6 @@
         case ir::IntrinsicCall::Kind::kSpirvDot:
             op = spv::Op::OpDot;
             break;
-        case ir::IntrinsicCall::Kind::kSpirvSelect:
-            op = spv::Op::OpSelect;
-            break;
-        case ir::IntrinsicCall::Kind::kSpirvSampledImage:
-            op = spv::Op::OpSampledImage;
-            break;
         case ir::IntrinsicCall::Kind::kSpirvImageSampleImplicitLod:
             op = spv::Op::OpImageSampleImplicitLod;
             break;
@@ -1464,6 +1453,15 @@
         case ir::IntrinsicCall::Kind::kSpirvImageSampleDrefExplicitLod:
             op = spv::Op::OpImageSampleDrefExplicitLod;
             break;
+        case ir::IntrinsicCall::Kind::kSpirvSampledImage:
+            op = spv::Op::OpSampledImage;
+            break;
+        case ir::IntrinsicCall::Kind::kSpirvSelect:
+            op = spv::Op::OpSelect;
+            break;
+        case ir::IntrinsicCall::Kind::kSpirvVectorTimesScalar:
+            op = spv::Op::OpVectorTimesScalar;
+            break;
     }
 
     OperandList operands = {Type(call->Result()->Type()), id};