[msl] Support handle types in ModuleScopeVars

These are passed around by value instead of as pointers, so add some
extra logic to drop the pointers and remove load instructions.

Bug: 42251016
Change-Id: Ie395f12fce678f62942610d2a6b50b5e6986cbb1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/188842
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/msl/writer/printer/function_test.cc b/src/tint/lang/msl/writer/printer/function_test.cc
index 0513a48..8ba60c2 100644
--- a/src/tint/lang/msl/writer/printer/function_test.cc
+++ b/src/tint/lang/msl/writer/printer/function_test.cc
@@ -25,6 +25,7 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+#include "src/tint/lang/core/type/sampled_texture.h"
 #include "src/tint/lang/msl/writer/printer/helper_test.h"
 
 namespace tint::msl::writer {
@@ -41,7 +42,7 @@
 )");
 }
 
-TEST_F(MslPrinterTest, EntryPointParameterBindingPoint) {
+TEST_F(MslPrinterTest, EntryPointParameterBufferBindingPoint) {
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     auto* storage = b.FunctionParam("storage", ty.ptr(core::AddressSpace::kStorage, ty.i32()));
     auto* uniform = b.FunctionParam("uniform", ty.ptr(core::AddressSpace::kUniform, ty.i32()));
@@ -57,5 +58,22 @@
 )");
 }
 
+TEST_F(MslPrinterTest, EntryPointParameterHandleBindingPoint) {
+    auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    auto* texture = b.FunctionParam("texture", t);
+    auto* sampler = b.FunctionParam("sampler", ty.sampler());
+    texture->SetBindingPoint(0, 1);
+    sampler->SetBindingPoint(0, 2);
+    func->SetParams({texture, sampler});
+    func->Block()->Append(b.Return(func));
+
+    ASSERT_TRUE(Generate()) << err_ << output_;
+    EXPECT_EQ(output_, MetalHeader() + R"(
+fragment void foo(texture2d<float, access::sample> texture [[texture(1)]], sampler sampler [[sampler(2)]]) {
+}
+)");
+}
+
 }  // namespace
 }  // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 7c58cd4..0ae4046 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -329,16 +329,28 @@
                 }
 
                 if (auto binding_point = param->BindingPoint()) {
-                    auto ptr = param->Type()->As<core::type::Pointer>();
                     TINT_ASSERT(binding_point->group == 0);
-                    switch (ptr->AddressSpace()) {
-                        case core::AddressSpace::kStorage:
-                        case core::AddressSpace::kUniform:
-                            out << " [[buffer(" << binding_point->binding << ")]]";
-                            break;
-                        default:
-                            TINT_UNREACHABLE() << "invalid address space with binding point: "
-                                               << ptr->AddressSpace();
+                    if (auto ptr = param->Type()->As<core::type::Pointer>()) {
+                        switch (ptr->AddressSpace()) {
+                            case core::AddressSpace::kStorage:
+                            case core::AddressSpace::kUniform:
+                                out << " [[buffer(" << binding_point->binding << ")]]";
+                                break;
+                            default:
+                                TINT_UNREACHABLE() << "invalid address space with binding point: "
+                                                   << ptr->AddressSpace();
+                        }
+                    } else {
+                        // Handle types are declared by value instead of by pointer.
+                        Switch(
+                            param->Type(),
+                            [&](const core::type::Texture*) {
+                                out << " [[texture(" << binding_point->binding << ")]]";
+                            },
+                            [&](const core::type::Sampler*) {
+                                out << " [[sampler(" << binding_point->binding << ")]]";
+                            },
+                            TINT_ICE_ON_NO_MATCH);
                     }
                 }
             }
@@ -1040,7 +1052,6 @@
         switch (sc) {
             case core::AddressSpace::kFunction:
             case core::AddressSpace::kPrivate:
-            case core::AddressSpace::kHandle:
                 out << "thread";
                 break;
             case core::AddressSpace::kWorkgroup:
diff --git a/src/tint/lang/msl/writer/raise/module_scope_vars.cc b/src/tint/lang/msl/writer/raise/module_scope_vars.cc
index 0cb07c1..00a569f 100644
--- a/src/tint/lang/msl/writer/raise/module_scope_vars.cc
+++ b/src/tint/lang/msl/writer/raise/module_scope_vars.cc
@@ -88,14 +88,32 @@
             ProcessFunction(*func);
         }
 
-        // Replace uses of each module-scope variable with pointers extracted from the structure.
+        // Replace uses of each module-scope variable with values extracted from the structure.
         uint32_t index = 0;
         for (auto& var : module_vars) {
-            var->Result(0)->ReplaceAllUsesWith([&](core::ir::Usage use) {  //
-                return GetPointerFromStruct(var, use.instruction, index);
+            Vector<core::ir::Instruction*, 16> to_destroy;
+            auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
+            var->Result(0)->ForEachUse([&](core::ir::Usage use) {  //
+                auto* extracted_variable = GetVariableFromStruct(var, use.instruction, index);
+
+                // We drop the pointer from handle variables and store them in the struct by value
+                // instead, so remove any load instructions for the handle address space.
+                if (use.instruction->Is<core::ir::Load>() &&
+                    ptr->AddressSpace() == core::AddressSpace::kHandle) {
+                    use.instruction->Result(0)->ReplaceAllUsesWith(extracted_variable);
+                    to_destroy.Push(use.instruction);
+                    return;
+                }
+
+                use.instruction->SetOperand(use.operand_index, extracted_variable);
             });
             var->Destroy();
             index++;
+
+            // Clean up instructions that need to be removed.
+            for (auto* inst : to_destroy) {
+                inst->Destroy();
+            }
         }
     }
 
@@ -106,6 +124,13 @@
         for (auto* global : *ir.root_block) {
             if (auto* var = global->As<core::ir::Var>()) {
                 auto* type = var->Result(0)->Type();
+
+                // Handle types drop the pointer and are passed around by value.
+                auto* ptr = type->As<core::type::Pointer>();
+                if (ptr->AddressSpace() == core::AddressSpace::kHandle) {
+                    type = ptr->StoreType();
+                }
+
                 auto name = ir.NameOf(var);
                 if (!name) {
                     name = ir.symbols.New();
@@ -181,6 +206,14 @@
                         decl = param;
                         break;
                     }
+                    case core::AddressSpace::kHandle: {
+                        // Handle types become function parameters and drop the pointer.
+                        auto* param = b.FunctionParam(ptr->UnwrapPtr());
+                        param->SetBindingPoint(var->BindingPoint());
+                        func->AppendParam(param);
+                        decl = param;
+                        break;
+                    }
                     default:
                         TINT_UNREACHABLE() << "unhandled address space: " << ptr->AddressSpace();
                 }
@@ -218,18 +251,26 @@
         return param;
     }
 
-    /// Get a pointer from the module-scope variable replacement structure, inserting new access
+    /// Get a variable from the module-scope variable replacement structure, inserting new access
     /// instructions before @p inst.
     /// @param var the variable to get the replacement for
     /// @param inst the instruction that uses the variable
     /// @param index the index of the variable in the structure member list
-    /// @returns the pointer extracted from the structure
-    core::ir::Value* GetPointerFromStruct(core::ir::Var* var,
-                                          core::ir::Instruction* inst,
-                                          uint32_t index) {
+    /// @returns the variable extracted from the structure
+    core::ir::Value* GetVariableFromStruct(core::ir::Var* var,
+                                           core::ir::Instruction* inst,
+                                           uint32_t index) {
         auto* func = ContainingFunction(inst);
         auto* struct_value = function_to_struct_value.GetOr(func, nullptr);
-        auto* access = b.Access(var->Result(0)->Type(), struct_value, u32(index));
+        auto* type = var->Result(0)->Type();
+
+        // Handle types drop the pointer and are passed around by value.
+        auto* ptr = type->As<core::type::Pointer>();
+        if (ptr->AddressSpace() == core::AddressSpace::kHandle) {
+            type = ptr->StoreType();
+        }
+
+        auto* access = b.Access(type, struct_value, u32(index));
         access->InsertBefore(inst);
         return access->Result(0);
     }
diff --git a/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc b/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
index 17ffdc1..3bef228 100644
--- a/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
+++ b/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
@@ -30,6 +30,7 @@
 #include <utility>
 
 #include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/type/sampled_texture.h"
 
 using namespace tint::core::fluent_types;     // NOLINT
 using namespace tint::core::number_suffixes;  // NOLINT
@@ -306,6 +307,63 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(MslWriter_ModuleScopeVarsTest, HandleTypes) {
+    auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+    auto* var_t = b.Var("t", ty.ptr<handle>(t));
+    auto* var_s = b.Var("s", ty.ptr<handle>(ty.sampler()));
+    var_t->SetBindingPoint(1, 2);
+    var_s->SetBindingPoint(3, 4);
+    mod.root_block->Append(var_t);
+    mod.root_block->Append(var_s);
+
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* load_t = b.Load(var_t);
+        auto* load_s = b.Load(var_s);
+        b.Call<vec4<f32>>(core::BuiltinFn::kTextureSample, load_t, load_s, b.Splat<vec2<f32>>(0_f));
+        b.Return(func);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %t:ptr<handle, texture_2d<f32>, read> = var @binding_point(1, 2)
+  %s:ptr<handle, sampler, read> = var @binding_point(3, 4)
+}
+
+%foo = @fragment func():void {
+  $B2: {
+    %4:texture_2d<f32> = load %t
+    %5:sampler = load %s
+    %6:vec4<f32> = textureSample %4, %5, vec2<f32>(0.0f)
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+tint_module_vars_struct = struct @align(1) {
+  t:texture_2d<f32> @offset(0)
+  s:sampler @offset(0)
+}
+
+%foo = @fragment func(%t:texture_2d<f32> [@binding_point(1, 2)], %s:sampler [@binding_point(3, 4)]):void {
+  $B1: {
+    %4:tint_module_vars_struct = construct %t, %s
+    %tint_module_vars:tint_module_vars_struct = let %4
+    %6:texture_2d<f32> = access %tint_module_vars, 0u
+    %7:sampler = access %tint_module_vars, 1u
+    %8:vec4<f32> = textureSample %6, %7, vec2<f32>(0.0f)
+    ret
+  }
+}
+)";
+
+    Run(ModuleScopeVars);
+
+    EXPECT_EQ(expect, str());
+}
+
 TEST_F(MslWriter_ModuleScopeVarsTest, MultipleAddressSpaces) {
     auto* var_a = b.Var("a", ty.ptr<uniform, i32, core::Access::kRead>());
     auto* var_b = b.Var("b", ty.ptr<storage, i32, core::Access::kReadWrite>());
@@ -607,6 +665,84 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(MslWriter_ModuleScopeVarsTest, CallFunctionThatUsesVars_HandleTypes) {
+    auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+    auto* var_t = b.Var("t", ty.ptr<handle>(t));
+    auto* var_s = b.Var("s", ty.ptr<handle>(ty.sampler()));
+    var_t->SetBindingPoint(1, 2);
+    var_s->SetBindingPoint(3, 4);
+    mod.root_block->Append(var_t);
+    mod.root_block->Append(var_s);
+
+    auto* foo = b.Function("foo", ty.vec4<f32>());
+    auto* param = b.FunctionParam<i32>("param");
+    foo->SetParams({param});
+    b.Append(foo->Block(), [&] {
+        auto* load_t = b.Load(var_t);
+        auto* load_s = b.Load(var_s);
+        auto* result = b.Call<vec4<f32>>(core::BuiltinFn::kTextureSample, load_t, load_s,
+                                         b.Splat<vec2<f32>>(0_f));
+        b.Return(foo, result);
+    });
+
+    auto* func = b.Function("main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Call(foo, 42_i);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %t:ptr<handle, texture_2d<f32>, read> = var @binding_point(1, 2)
+  %s:ptr<handle, sampler, read> = var @binding_point(3, 4)
+}
+
+%foo = func(%param:i32):vec4<f32> {
+  $B2: {
+    %5:texture_2d<f32> = load %t
+    %6:sampler = load %s
+    %7:vec4<f32> = textureSample %5, %6, vec2<f32>(0.0f)
+    ret %7
+  }
+}
+%main = @fragment func():void {
+  $B3: {
+    %9:vec4<f32> = call %foo, 42i
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+tint_module_vars_struct = struct @align(1) {
+  t:texture_2d<f32> @offset(0)
+  s:sampler @offset(0)
+}
+
+%foo = func(%param:i32, %tint_module_vars:tint_module_vars_struct):vec4<f32> {
+  $B1: {
+    %4:texture_2d<f32> = access %tint_module_vars, 0u
+    %5:sampler = access %tint_module_vars, 1u
+    %6:vec4<f32> = textureSample %4, %5, vec2<f32>(0.0f)
+    ret %6
+  }
+}
+%main = @fragment func(%t:texture_2d<f32> [@binding_point(1, 2)], %s:sampler [@binding_point(3, 4)]):void {
+  $B2: {
+    %10:tint_module_vars_struct = construct %t, %s
+    %tint_module_vars_1:tint_module_vars_struct = let %10  # %tint_module_vars_1: 'tint_module_vars'
+    %12:vec4<f32> = call %foo, 42i, %tint_module_vars_1
+    ret
+  }
+}
+)";
+
+    Run(ModuleScopeVars);
+
+    EXPECT_EQ(expect, str());
+}
+
 TEST_F(MslWriter_ModuleScopeVarsTest, CallFunctionThatUsesVars_OutOfOrder) {
     auto* var_a = b.Var("a", ty.ptr<storage, i32, core::Access::kRead>());
     auto* var_b = b.Var("b", ty.ptr<storage, i32, core::Access::kReadWrite>());