[ir][msl] Emit `var` instructions.

Add emission of `var` instructions to the MSL IR Generator.

Bug: tint:1967
Change-Id: I32e0a22262e1959872a1fc196152528ee7099b95
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144402
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index d60dc59..372bfae 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2328,6 +2328,7 @@
         "lang/msl/writer/printer/let_test.cc",
         "lang/msl/writer/printer/return_test.cc",
         "lang/msl/writer/printer/type_test.cc",
+        "lang/msl/writer/printer/var_test.cc",
       ]
       deps += [
         ":libtint_ir_src",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 890134e..886cac2 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1561,6 +1561,7 @@
         lang/msl/writer/printer/let_test.cc
         lang/msl/writer/printer/return_test.cc
         lang/msl/writer/printer/type_test.cc
+        lang/msl/writer/printer/var_test.cc
       )
     endif()
   endif()
diff --git a/src/tint/lang/msl/writer/printer/let_test.cc b/src/tint/lang/msl/writer/printer/let_test.cc
index 888ac16..519d905 100644
--- a/src/tint/lang/msl/writer/printer/let_test.cc
+++ b/src/tint/lang/msl/writer/printer/let_test.cc
@@ -179,7 +179,7 @@
 )");
 }
 
-TEST_F(MslPrinterTest, LetArrVec2BoolEmit_VariableDeclStatement_Const_arr_vec2_bool) {
+TEST_F(MslPrinterTest, LetArrVec2Bool) {
     auto* func = b.Function("foo", ty.void_());
     b.Append(func->Block(), [&] {
         b.Let("l", b.Composite(ty.array<vec2<bool>, 3>(),  //
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index efa20c1..831efff 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -22,10 +22,12 @@
 #include "src/tint/lang/core/ir/exit_if.h"
 #include "src/tint/lang/core/ir/if.h"
 #include "src/tint/lang/core/ir/let.h"
+#include "src/tint/lang/core/ir/load.h"
 #include "src/tint/lang/core/ir/multi_in_block.h"
 #include "src/tint/lang/core/ir/return.h"
 #include "src/tint/lang/core/ir/unreachable.h"
 #include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/core/ir/var.h"
 #include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/atomic.h"
 #include "src/tint/lang/core/type/bool.h"
@@ -84,7 +86,7 @@
 
     // Emit module-scope declarations.
     if (ir_->root_block) {
-        // EmitRootBlock(ir_->root_block);
+        EmitBlockInstructions(ir_->root_block);
     }
 
     // Emit functions.
@@ -170,12 +172,59 @@
             [&](ir::ExitIf* e) { EmitExitIf(e); },         //
             [&](ir::If* if_) { EmitIf(if_); },             //
             [&](ir::Let* l) { EmitLet(l); },               //
+            [&](ir::Load* l) { EmitLoad(l); },             //
             [&](ir::Return* r) { EmitReturn(r); },         //
             [&](ir::Unreachable*) { EmitUnreachable(); },  //
+            [&](ir::Var* v) { EmitVar(v); },               //
             [&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
     }
 }
 
+void Printer::EmitLoad(ir::Load* l) {
+    // Force loads to be bound as inlines
+    bindings_.Add(l->Result(), InlinedValue{Expr(l->From()), PtrKind::kRef});
+}
+
+void Printer::EmitVar(ir::Var* v) {
+    auto out = Line();
+
+    auto* ptr = v->Result()->Type()->As<type::Pointer>();
+    TINT_ASSERT_OR_RETURN(ptr);
+
+    auto space = ptr->AddressSpace();
+    switch (space) {
+        case builtin::AddressSpace::kFunction:
+        case builtin::AddressSpace::kHandle:
+            break;
+        case builtin::AddressSpace::kPrivate:
+            out << "thread ";
+            break;
+        case builtin::AddressSpace::kWorkgroup:
+            out << "threadgroup ";
+            break;
+        default:
+            TINT_ICE() << "unhandled variable address space";
+            return;
+    }
+
+    auto name = ir_->NameOf(v);
+
+    EmitType(out, ptr->UnwrapPtr());
+    out << " " << name.Name();
+
+    if (v->Initializer()) {
+        out << " = " << Expr(v->Initializer());
+    } else if (space == builtin::AddressSpace::kPrivate ||
+               space == builtin::AddressSpace::kFunction ||
+               space == builtin::AddressSpace::kUndefined) {
+        out << " = ";
+        EmitZeroValue(out, ptr->UnwrapPtr());
+    }
+    out << ";";
+
+    Bind(v->Result(), name, PtrKind::kRef);
+}
+
 void Printer::EmitLet(ir::Let* l) {
     Bind(l->Result(), Expr(l->Value(), PtrKind::kPtr), PtrKind::kPtr);
 }
@@ -225,8 +274,8 @@
 }
 
 void Printer::EmitReturn(ir::Return* r) {
-    // If this return has no arguments and the current block is for the function which is being
-    // returned, skip the return.
+    // If this return has no arguments and the current block is for the function which is
+    // being returned, skip the return.
     if (current_block_ == current_function_->Block() && r->Args().IsEmpty()) {
         return;
     }
@@ -419,10 +468,10 @@
         return;
     }
 
-    // This does not append directly to the preamble because a struct may require other structs, or
-    // the array template, to get emitted before it. So, the struct emits into a temporary text
-    // buffer, then anything it depends on will emit to the preamble first, and then it copies the
-    // text buffer into the preamble.
+    // This does not append directly to the preamble because a struct may require other
+    // structs, or the array template, to get emitted before it. So, the struct emits into a
+    // temporary text buffer, then anything it depends on will emit to the preamble first,
+    // and then it copies the text buffer into the preamble.
     TextBuffer str_buf;
     Line(&str_buf) << "struct " << StructName(str) << " {";
 
@@ -623,6 +672,25 @@
         [&](Default) { UNHANDLED_CASE(c->Type()); });
 }
 
+void Printer::EmitZeroValue(StringStream& out, const type::Type* ty) {
+    Switch(
+        ty, [&](const type::Bool*) { out << "false"; },                     //
+        [&](const type::F16*) { out << "0.0h"; },                           //
+        [&](const type::F32*) { out << "0.0f"; },                           //
+        [&](const type::I32*) { out << "0"; },                              //
+        [&](const type::U32*) { out << "0u"; },                             //
+        [&](const type::Vector* vec) { EmitZeroValue(out, vec->type()); },  //
+        [&](const type::Matrix* mat) {
+            EmitType(out, mat);
+
+            ScopedParen sp(out);
+            EmitZeroValue(out, mat->type());
+        },
+        [&](const type::Array*) { out << "{}"; },   //
+        [&](const type::Struct*) { out << "{}"; },  //
+        [&](Default) { TINT_ICE() << "Invalid type for zero emission: " << ty->FriendlyName(); });
+}
+
 std::string Printer::StructName(const type::Struct* s) {
     auto name = s->Name().Name();
     if (HasPrefix(name, "__")) {
@@ -665,10 +733,12 @@
                     }
 
                     if constexpr (std::is_same_v<T, InlinedValue>) {
+                        auto result = ExprAndPtrKind{got.expr, got.ptr_kind};
+
                         // Single use (inlined) expression.
                         // Mark the bindings_ map entry as consumed.
                         *lookup = ConsumedValue{};
-                        return {got.expr, got.ptr_kind};
+                        return result;
                     }
 
                     if constexpr (std::is_same_v<T, ConsumedValue>) {
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 2bf3b76..d396965 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -30,8 +30,10 @@
 class ExitIf;
 class If;
 class Let;
+class Load;
 class Return;
 class Unreachable;
+class Var;
 }  // namespace tint::ir
 
 namespace tint::msl::writer {
@@ -71,6 +73,12 @@
     /// Emit a let instruction
     /// @param l the let instruction
     void EmitLet(ir::Let* l);
+    /// Emit a var instruction
+    /// @param v the var instruction
+    void EmitVar(ir::Var* v);
+    /// Emit a load instruction
+    /// @param l the load instruction
+    void EmitLoad(ir::Load* l);
 
     /// Emit a return instruction
     /// @param r the return instruction
@@ -126,6 +134,11 @@
     /// @param c the constant to emit
     void EmitConstant(StringStream& out, const constant::Value* c);
 
+    /// Emits the zero value for the given type
+    /// @param out the stream to emit too
+    /// @param ty the type
+    void EmitZeroValue(StringStream& out, const type::Type* ty);
+
     /// @returns the name of the templated `tint_array` helper type, generating it if needed
     const std::string& ArrayTemplateName();
 
diff --git a/src/tint/lang/msl/writer/printer/var_test.cc b/src/tint/lang/msl/writer/printer/var_test.cc
new file mode 100644
index 0000000..b13529a
--- /dev/null
+++ b/src/tint/lang/msl/writer/printer/var_test.cc
@@ -0,0 +1,288 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/msl/writer/printer/helper_test.h"
+
+namespace tint::msl::writer {
+namespace {
+
+using namespace tint::builtin::fluent_types;  // NOLINT
+using namespace tint::number_suffixes;        // NOLINT
+
+TEST_F(MslPrinterTest, VarF32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, f32>());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float a = 0.0f;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarI32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, i32>());
+        v->SetInitializer(b.Constant(1_i));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  int a = 1;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarU32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, u32>());
+        v->SetInitializer(b.Constant(1_u));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  uint a = 1u;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarArrayF32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, array<f32, 5>>());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + MetalArray() + R"(
+void foo() {
+  tint_array<float, 5> a = {};
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarStruct) {
+    auto* s = ty.Struct(mod.symbols.New("MyStruct"), {{mod.symbols.Register("a"), ty.f32()},  //
+                                                      {mod.symbols.Register("b"), ty.vec4<i32>()}});
+
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Var("a", ty.ptr(builtin::AddressSpace::kFunction, s));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(struct MyStruct {
+  float a;
+  int4 b;
+};
+
+void foo() {
+  MyStruct a = {};
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec2<f32>>());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float2 a = 0.0f;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF16) {
+    // Enable f16?
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec2<f16>>());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  half2 a = 0.0h;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF32) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat3x2<f32>>());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float3x2 a = float3x2(0.0f);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF16) {
+    // Enable f16?
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat3x2<f16>>());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  half3x2 a = half3x2(0.0h);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF32SplatZero) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec3<f32>>());
+        v->SetInitializer(b.Splat(ty.vec3<f32>(), 0_f, 3));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float3 a = float3(0.0f);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarVecF16SplatZero) {
+    // Enable f16
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, vec3<f16>>());
+        v->SetInitializer(b.Splat(ty.vec3<f16>(), 0_h, 3));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  half3 a = half3(0.0h);
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF32SplatZero) {
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat2x3<f32>>());
+        v->SetInitializer(b.Composite(ty.mat2x3<f32>(), b.Splat(ty.vec3<f32>(), 0_f, 3),
+                                      b.Splat(ty.vec3<f32>(), 0_f, 3)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  float2x3 a = float2x3(float3(0.0f), float3(0.0f));
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarMatF16SplatZero) {
+    // Enable f16?
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* v = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, mat2x3<f16>>());
+        v->SetInitializer(b.Composite(ty.mat2x3<f16>(), b.Splat(ty.vec3<f16>(), 0_h, 3),
+                                      b.Splat(ty.vec3<f16>(), 0_h, 3)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+void foo() {
+  half2x3 a = half2x3(half3(0.0h), half3(0.0h));
+}
+)");
+}
+
+// TODO(dsinclair): Requires ModuleScopeVarToEntryPointParam transform
+TEST_F(MslPrinterTest, DISABLED_VarGlobalPrivate) {
+    ir::Var* v = nullptr;
+    b.Append(b.RootBlock(),
+             [&] { v = b.Var("v", ty.ptr<builtin::AddressSpace::kPrivate, f32>()); });
+
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* ld = b.Load(v->Result());
+        auto* a = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, f32>());
+        a->SetInitializer(ld->Result());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+struct tint_private_vars_struct {
+  float a;
+};
+
+void foo() {
+    thread tint_private_vars_struct tint_private_vars = {};
+    float const a = tint_private_vars.a;
+    return;
+}
+)");
+}
+
+TEST_F(MslPrinterTest, VarGlobalWorkgroup) {
+    ir::Var* v = nullptr;
+    b.Append(b.RootBlock(),
+             [&] { v = b.Var("v", ty.ptr<builtin::AddressSpace::kWorkgroup, f32>()); });
+
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* ld = b.Load(v->Result());
+        auto* a = b.Var("a", ty.ptr<builtin::AddressSpace::kFunction, f32>());
+        a->SetInitializer(ld->Result());
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
+    EXPECT_EQ(generator_.Result(), MetalHeader() + R"(
+threadgroup float v;
+void foo() {
+  float a = v;
+}
+)");
+}
+
+}  // namespace
+}  // namespace tint::msl::writer