[ir][msl] Emit Lets

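Add emission of `let` instructions to the MSL IR generator. Also add
`ConstantValue()` and `Composite()` helpers to the IR builder for
constructing composite constants.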

Bug: tint:1967
Change-Id: I5b6af981bce7ee4e497b101c95a1fc662f70cc3e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144042
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 16920a1..0c554e2 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2323,6 +2323,7 @@
         "lang/msl/writer/printer/function_test.cc",
         "lang/msl/writer/printer/helper_test.h",
         "lang/msl/writer/printer/if_test.cc",
+        "lang/msl/writer/printer/let_test.cc",
         "lang/msl/writer/printer/return_test.cc",
         "lang/msl/writer/printer/type_test.cc",
       ]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 7c52c4b..89e7d6c 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1556,6 +1556,7 @@
         lang/msl/writer/printer/function_test.cc
         lang/msl/writer/printer/helper_test.h
         lang/msl/writer/printer/if_test.cc
+        lang/msl/writer/printer/let_test.cc
         lang/msl/writer/printer/return_test.cc
         lang/msl/writer/printer/type_test.cc
       )
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 2048955..a71257c 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -57,10 +57,12 @@
 #include "src/tint/lang/core/ir/user_call.h"
 #include "src/tint/lang/core/ir/value.h"
 #include "src/tint/lang/core/ir/var.h"
+#include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/bool.h"
 #include "src/tint/lang/core/type/f16.h"
 #include "src/tint/lang/core/type/f32.h"
 #include "src/tint/lang/core/type/i32.h"
+#include "src/tint/lang/core/type/matrix.h"
 #include "src/tint/lang/core/type/pointer.h"
 #include "src/tint/lang/core/type/u32.h"
 #include "src/tint/lang/core/type/vector.h"
@@ -219,29 +221,72 @@
     /// Creates a ir::Constant for an i32 Scalar
     /// @param v the value
     /// @returns the new constant
-    ir::Constant* Constant(i32 v) { return Constant(ir.constant_values.Get(v)); }
+    ir::Constant* Constant(i32 v) { return Constant(ConstantValue(v)); }
 
     /// Creates a ir::Constant for a u32 Scalar
     /// @param v the value
     /// @returns the new constant
-    ir::Constant* Constant(u32 v) { return Constant(ir.constant_values.Get(v)); }
+    ir::Constant* Constant(u32 v) { return Constant(ConstantValue(v)); }
 
     /// Creates a ir::Constant for a f32 Scalar
     /// @param v the value
     /// @returns the new constant
-    ir::Constant* Constant(f32 v) { return Constant(ir.constant_values.Get(v)); }
+    ir::Constant* Constant(f32 v) { return Constant(ConstantValue(v)); }
 
     /// Creates a ir::Constant for a f16 Scalar
     /// @param v the value
     /// @returns the new constant
-    ir::Constant* Constant(f16 v) { return Constant(ir.constant_values.Get(v)); }
+    ir::Constant* Constant(f16 v) { return Constant(ConstantValue(v)); }
 
     /// Creates a ir::Constant for a bool Scalar
     /// @param v the value
     /// @returns the new constant
     template <typename BOOL, typename = std::enable_if_t<std::is_same_v<BOOL, bool>>>
     ir::Constant* Constant(BOOL v) {
-        return Constant(ir.constant_values.Get(v));
+        return Constant(ConstantValue(v));
+    }
+
+    /// Unwraps an ir::Constant so that Composite() can take nested ir::Constant elements
+    /// @param constant the ir constant
+    /// @returns the constant::Value held by `constant`
+    const constant::Value* ConstantValue(ir::Constant* constant) { return constant->Value(); }
+
+    /// Gets the constant::Value for an i32 Scalar from `ir.constant_values`
+    /// @param v the value
+    /// @returns the constant::Value for `v`
+    const constant::Value* ConstantValue(i32 v) { return ir.constant_values.Get(v); }
+
+    /// Gets the constant::Value for a u32 Scalar from `ir.constant_values`
+    /// @param v the value
+    /// @returns the constant::Value for `v`
+    const constant::Value* ConstantValue(u32 v) { return ir.constant_values.Get(v); }
+
+    /// Gets the constant::Value for a f32 Scalar from `ir.constant_values`
+    /// @param v the value
+    /// @returns the constant::Value for `v`
+    const constant::Value* ConstantValue(f32 v) { return ir.constant_values.Get(v); }
+
+    /// Gets the constant::Value for a f16 Scalar from `ir.constant_values`
+    /// @param v the value
+    /// @returns the constant::Value for `v`
+    const constant::Value* ConstantValue(f16 v) { return ir.constant_values.Get(v); }
+
+    /// Gets the constant::Value for a bool Scalar (SFINAE restricts BOOL to exactly `bool`)
+    /// @param v the value
+    /// @returns the constant::Value for `v`
+    template <typename BOOL, typename = std::enable_if_t<std::is_same_v<BOOL, bool>>>
+    const constant::Value* ConstantValue(BOOL v) {
+        return ir.constant_values.Get(v);
+    }
+
+    /// Creates a new composite ir::Constant (e.g. a vector, matrix or array constant)
+    /// @param ty the composite type
+    /// @param values the element values, each converted with ConstantValue()
+    /// @returns the new constant
+    template <typename... ARGS, typename = DisableIfVectorLike<ARGS...>>
+    ir::Constant* Composite(const type::Type* ty, ARGS&&... values) {
+        return Constant(
+            ir.constant_values.Composite(ty, Vector{ConstantValue(std::forward<ARGS>(values))...}));
     }
 
     /// @param in the input value. One of: nullptr, ir::Value*, ir::Instruction* or a numeric value.
diff --git a/src/tint/lang/msl/writer/printer/helper_test.h b/src/tint/lang/msl/writer/printer/helper_test.h
index 152cf3b..9569f75 100644
--- a/src/tint/lang/msl/writer/printer/helper_test.h
+++ b/src/tint/lang/msl/writer/printer/helper_test.h
@@ -29,6 +29,20 @@
 using namespace metal;
 )";
 
+/// Expected MSL `tint_array` helper, emitted after the header when a module uses arrays
+constexpr auto kMetalArray = R"(template<typename T, size_t N>
+struct tint_array {
+  const constant T& operator[](size_t i) const constant { return elements[i]; }
+  device T& operator[](size_t i) device { return elements[i]; }
+  const device T& operator[](size_t i) const device { return elements[i]; }
+  thread T& operator[](size_t i) thread { return elements[i]; }
+  const thread T& operator[](size_t i) const thread { return elements[i]; }
+  threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+  const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+  T elements[N];
+};
+
+)";
 /// Base helper class for testing the MSL generator implementation.
 template <typename BASE>
 class MslPrinterTestHelperBase : public BASE {
@@ -44,6 +58,8 @@
 
     /// @returns the metal header string
     std::string MetalHeader() const { return kMetalHeader; }
+    /// @returns the expected `tint_array` helper string
+    std::string MetalArray() const { return kMetalArray; }
 
   protected:
     /// The MSL generator.
diff --git a/src/tint/lang/msl/writer/printer/let_test.cc b/src/tint/lang/msl/writer/printer/let_test.cc
new file mode 100644
index 0000000..888ac16
--- /dev/null
+++ b/src/tint/lang/msl/writer/printer/let_test.cc
@@ -0,0 +1,201 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/msl/writer/printer/helper_test.h"
+
+namespace tint::msl::writer {
+namespace {
+// MSL printer tests for emitting core IR `let` instructions as `const` locals.
+using namespace tint::builtin::fluent_types;  // NOLINT
+using namespace tint::number_suffixes;        // NOLINT
+
+TEST_F(MslPrinterTest, LetU32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", 42_u);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  uint const l = 42u;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetDuplicate) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l1", 42_u);
+        b.Let("l2", 42_u);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  uint const l1 = 42u;
+  uint const l2 = 42u;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetF32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", 42.0_f);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float const l = 42.0f;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetI32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", 42_i);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  int const l = 42;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetF16) {
+    // TODO: Determine whether the f16 extension must be enabled for this test.
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", 42_h);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  half const l = 42.0h;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetVec3F32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", b.Composite(ty.vec3<f32>(), 1_f, 2_f, 3_f));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float3 const l = float3(1.0f, 2.0f, 3.0f);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetVec3F16) {
+    // TODO: Determine whether the f16 extension must be enabled for this test.
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", b.Composite(ty.vec3<f16>(), 1_h, 2_h, 3_h));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  half3 const l = half3(1.0h, 2.0h, 3.0h);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetMat2x3F32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", b.Composite(ty.mat2x3<f32>(),  //
+                               b.Composite(ty.vec3<f32>(), 1_f, 2_f, 3_f),
+                               b.Composite(ty.vec3<f32>(), 4_f, 5_f, 6_f)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float2x3 const l = float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f));
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetMat2x3F16) {
+    // TODO: Determine whether the f16 extension must be enabled for this test.
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", b.Composite(ty.mat2x3<f16>(),  //
+                               b.Composite(ty.vec3<f16>(), 1_h, 2_h, 3_h),
+                               b.Composite(ty.vec3<f16>(), 4_h, 5_h, 6_h)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h));
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetArrF32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", b.Composite(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + MetalArray() + R"(
+void foo() {
+  tint_array<float, 3> const l = tint_array<float, 3>{1.0f, 2.0f, 3.0f};
+}
+)");
+}
+
+TEST_F(MslPrinterTest, LetArrVec2Bool) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("l", b.Composite(ty.array<vec2<bool>, 3>(),  //
+                               b.Composite(ty.vec2<bool>(), true, false),
+                               b.Composite(ty.vec2<bool>(), false, true),
+                               b.Composite(ty.vec2<bool>(), true, false)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + MetalArray() + R"(
+void foo() {
+  tint_array<bool2, 3> const l = tint_array<bool2, 3>{bool2(true, false), bool2(false, true), bool2(true, false)};
+}
+)");
+}
+
+}  // namespace
+}  // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 1f0a5cc..efa20c1 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -21,6 +21,7 @@
 #include "src/tint/lang/core/ir/constant.h"
 #include "src/tint/lang/core/ir/exit_if.h"
 #include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/let.h"
 #include "src/tint/lang/core/ir/multi_in_block.h"
 #include "src/tint/lang/core/ir/return.h"
 #include "src/tint/lang/core/ir/unreachable.h"
@@ -168,12 +169,17 @@
             inst,                                          //
             [&](ir::ExitIf* e) { EmitExitIf(e); },         //
             [&](ir::If* if_) { EmitIf(if_); },             //
+            [&](ir::Let* l) { EmitLet(l); },               //
             [&](ir::Return* r) { EmitReturn(r); },         //
             [&](ir::Unreachable*) { EmitUnreachable(); },  //
             [&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
     }
 }
 
+void Printer::EmitLet(ir::Let* let) {  // Bind the let's result to its value expression.
+    Bind(let->Result(), Expr(let->Value(), PtrKind::kPtr), PtrKind::kPtr);
+}
+
 void Printer::EmitIf(ir::If* if_) {
     // Emit any nodes that need to be used as PHI nodes
     for (auto* phi : if_->Results()) {
@@ -723,6 +729,7 @@
             } else {
                 out << expr;
             }
+            out << ";";
 
             Bind(value, mod_name, PtrKind::kPtr);
         }
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 3144147..2bf3b76 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -29,6 +29,7 @@
 namespace tint::ir {
 class ExitIf;
 class If;
+class Let;
 class Return;
 class Unreachable;
 }  // namespace tint::ir
@@ -67,6 +68,10 @@
     /// @param e the exit-if instruction
     void EmitExitIf(ir::ExitIf* e);
 
+    /// Emit a let instruction by binding its result to the expression for its value
+    /// @param l the let instruction
+    void EmitLet(ir::Let* l);
+
     /// Emit a return instruction
     /// @param r the return instruction
     void EmitReturn(ir::Return* r);