spirv-reader: Support texture and sampler args to user-defined functions

Fixed: tint:1039
Change-Id: If0cb28679cc73f54025c2c142bdc32852bf4ae28
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109820
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index e53238f..07b5980 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -1514,19 +1514,12 @@
 
     ParameterList ast_params;
     function_.ForEachParam([this, &ast_params](const spvtools::opt::Instruction* param) {
-        const Type* type = nullptr;
-        auto* spirv_type = type_mgr_->GetType(param->type_id());
-        TINT_ASSERT(Reader, spirv_type);
-        if (spirv_type->AsImage() || spirv_type->AsSampler() ||
-            (spirv_type->AsPointer() &&
-             (static_cast<spv::StorageClass>(spirv_type->AsPointer()->storage_class()) ==
-              spv::StorageClass::UniformConstant))) {
-            // When we see image, sampler, pointer-to-image, or pointer-to-sampler, use the
-            // handle type deduced according to usage.
-            type = parser_impl_.GetHandleTypeForSpirvHandle(*param);
-        } else {
-            type = parser_impl_.ConvertType(param->type_id());
-        }
+        // Valid SPIR-V requires function call parameters to be non-null
+        // instructions.
+        TINT_ASSERT(Reader, param != nullptr);
+        const Type* const type = IsHandleObj(*param)
+                                     ? parser_impl_.GetHandleTypeForSpirvHandle(*param)
+                                     : parser_impl_.ConvertType(param->type_id());
 
         if (type != nullptr) {
             auto* ast_param = parser_impl_.MakeParameter(param->result_id(), type, AttributeList{});
@@ -1550,6 +1543,20 @@
     return success();
 }
 
+bool FunctionEmitter::IsHandleObj(const spvtools::opt::Instruction& obj) {
+    TINT_ASSERT(Reader, obj.type_id() != 0u);
+    auto* spirv_type = type_mgr_->GetType(obj.type_id());
+    TINT_ASSERT(Reader, spirv_type);
+    return spirv_type->AsImage() || spirv_type->AsSampler() ||
+           (spirv_type->AsPointer() &&
+            (static_cast<spv::StorageClass>(spirv_type->AsPointer()->storage_class()) ==
+             spv::StorageClass::UniformConstant));
+}
+
+bool FunctionEmitter::IsHandleObj(const spvtools::opt::Instruction* obj) {
+    return (obj != nullptr) && IsHandleObj(*obj);
+}
+
 const Type* FunctionEmitter::GetVariableStoreType(const spvtools::opt::Instruction& var_decl_inst) {
     const auto type_id = var_decl_inst.type_id();
     // Normally we use the SPIRV-Tools optimizer to manage types.
@@ -5278,7 +5285,21 @@
 
     ExpressionList args;
     for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) {
-        auto expr = MakeOperand(inst, iarg);
+        uint32_t arg_id = inst.GetSingleWordInOperand(iarg);
+        TypedExpression expr;
+
+        if (IsHandleObj(def_use_mgr_->GetDef(arg_id))) {
+            // For textures and samplers, use the memory object declaration
+            // instead.
+            const auto usage = parser_impl_.GetHandleUsage(arg_id);
+            const auto* mem_obj_decl =
+                parser_impl_.GetMemoryObjectDeclarationForHandle(arg_id, usage.IsTexture());
+            expr = MakeExpression(mem_obj_decl->result_id());
+            // Pass the handle through instead of a pointer to the handle.
+            expr.type = parser_impl_.GetHandleTypeForSpirvHandle(*mem_obj_decl);
+        } else {
+            expr = MakeOperand(inst, iarg);
+        }
         if (!expr) {
             return false;
         }
diff --git a/src/tint/reader/spirv/function.h b/src/tint/reader/spirv/function.h
index 755ecc9..8d06735 100644
--- a/src/tint/reader/spirv/function.h
+++ b/src/tint/reader/spirv/function.h
@@ -1002,6 +1002,16 @@
     /// @returns true if emission has not yet failed.
     bool ParseFunctionDeclaration(FunctionDeclaration* decl);
 
+    /// @param obj a SPIR-V instruction with a result ID and a type ID
+    /// @returns true if the object is an image, a sampler, or a pointer to
+    /// an image or a sampler
+    bool IsHandleObj(const spvtools::opt::Instruction& obj);
+
+    /// @param obj a SPIR-V instruction with a result ID and a type ID
+    /// @returns true if the object is an image, a sampler, or a pointer to
+    /// an image or a sampler
+    bool IsHandleObj(const spvtools::opt::Instruction* obj);
+
     /// @returns the store type for the OpVariable instruction, or
     /// null on failure.
     const Type* GetVariableStoreType(const spvtools::opt::Instruction& var_decl_inst);
diff --git a/src/tint/reader/spirv/function_call_test.cc b/src/tint/reader/spirv/function_call_test.cc
index 6b12ece..145a6fb 100644
--- a/src/tint/reader/spirv/function_call_test.cc
+++ b/src/tint/reader/spirv/function_call_test.cc
@@ -32,6 +32,42 @@
 )";
 }
 
+std::string CommonTypes() {
+    return R"(
+    %void = OpTypeVoid
+    %voidfn = OpTypeFunction %void
+    %float = OpTypeFloat 32
+    %uint = OpTypeInt 32 0
+    %int = OpTypeInt 32 1
+    %float_0 = OpConstant %float 0.0
+  )";
+}
+
+std::string CommonHandleTypes() {
+    return R"(
+    OpName %t "t"
+    OpName %s "s"
+    OpDecorate %t DescriptorSet 0
+    OpDecorate %t Binding 0
+    OpDecorate %s DescriptorSet 0
+    OpDecorate %s Binding 1
+    )" + CommonTypes() +
+           R"(
+
+    %v2float = OpTypeVector %float 2
+    %v4float = OpTypeVector %float 4
+    %v2_0 = OpConstantNull %v2float
+    %sampler = OpTypeSampler
+    %tex2d_f32 = OpTypeImage %float 2D 0 0 0 1 Unknown
+    %sampled_image_2d_f32 = OpTypeSampledImage %tex2d_f32
+    %ptr_sampler = OpTypePointer UniformConstant %sampler
+    %ptr_tex2d_f32 = OpTypePointer UniformConstant %tex2d_f32
+
+    %t = OpVariable %ptr_tex2d_f32 UniformConstant
+    %s = OpVariable %ptr_sampler UniformConstant
+  )";
+}
+
 TEST_F(SpvParserTest, EmitStatement_VoidCallNoParams) {
     auto p = parser(test::Assemble(Preamble() + R"(
      %void = OpTypeVoid
@@ -193,5 +229,142 @@
     EXPECT_EQ(program_ast_str, expected);
 }
 
+std::string HelperFunctionPtrHandle() {
+    return R"(
+     ; This is how Glslang generates functions that take texture and sampler arguments.
+     ; It passes them by pointer.
+     %fn_ty = OpTypeFunction %void %ptr_tex2d_f32 %ptr_sampler
+
+     %200 = OpFunction %void None %fn_ty
+     %14 = OpFunctionParameter %ptr_tex2d_f32
+     %15 = OpFunctionParameter %ptr_sampler
+     %helper_entry = OpLabel
+     ; access the texture, to give the handles usages.
+     %helper_im = OpLoad %tex2d_f32 %14
+     %helper_sam = OpLoad %sampler %15
+     %helper_imsam = OpSampledImage %sampled_image_2d_f32 %helper_im %helper_sam
+     %20 = OpImageSampleImplicitLod %v4float %helper_imsam %v2_0
+     OpReturn
+     OpFunctionEnd
+     )";
+}
+
+std::string HelperFunctionHandle() {
+    return R"(
+     ; It is valid in SPIR-V to pass textures and samplers by value.
+     %fn_ty = OpTypeFunction %void %tex2d_f32 %sampler
+
+     %200 = OpFunction %void None %fn_ty
+     %14 = OpFunctionParameter %tex2d_f32
+     %15 = OpFunctionParameter %sampler
+     %helper_entry = OpLabel
+     ; access the texture, to give the handles usages.
+     %helper_imsam = OpSampledImage %sampled_image_2d_f32 %14 %15
+     %20 = OpImageSampleImplicitLod %v4float %helper_imsam %v2_0
+     OpReturn
+     OpFunctionEnd
+     )";
+}
+
+TEST_F(SpvParserTest, Emit_FunctionCall_HandlePtrParams_Direct) {
+    auto assembly = Preamble() + CommonHandleTypes() + HelperFunctionPtrHandle() + R"(
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpFunctionCall %void %200 %t %s
+     OpReturn
+     OpFunctionEnd
+  )";
+
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    const auto got = test::ToString(p->program(), fe.ast_body());
+
+    const std::string expect = R"(x_200(t, s);
+return;
+)";
+    EXPECT_EQ(got, expect);
+}
+
+TEST_F(SpvParserTest, Emit_FunctionCall_HandlePtrParams_CopyObject) {
+    auto assembly = Preamble() + CommonHandleTypes() + HelperFunctionPtrHandle() + R"(
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+
+     %copy_t = OpCopyObject %ptr_tex2d_f32 %t
+     %copy_s = OpCopyObject %ptr_sampler %s
+     %1 = OpFunctionCall %void %200 %copy_t %copy_s
+     OpReturn
+     OpFunctionEnd
+  )";
+
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    const auto got = test::ToString(p->program(), fe.ast_body());
+
+    const std::string expect = R"(x_200(t, s);
+return;
+)";
+    EXPECT_EQ(got, expect);
+}
+
+TEST_F(SpvParserTest, Emit_FunctionCall_HandleParams_Load) {
+    auto assembly = Preamble() + CommonHandleTypes() + HelperFunctionHandle() + R"(
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %im = OpLoad %tex2d_f32 %t
+     %sam = OpLoad %sampler %s
+     %1 = OpFunctionCall %void %200 %im %sam
+     OpReturn
+     OpFunctionEnd
+  )";
+
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    const auto got = test::ToString(p->program(), fe.ast_body());
+
+    const std::string expect = R"(x_200(t, s);
+return;
+)";
+    EXPECT_EQ(got, expect);
+}
+
+TEST_F(SpvParserTest, Emit_FunctionCall_HandleParams_LoadsAndCopyObject) {
+    auto assembly = Preamble() + CommonHandleTypes() + HelperFunctionHandle() + R"(
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+
+     %copy_t = OpCopyObject %ptr_tex2d_f32 %t
+     %copy_s = OpCopyObject %ptr_sampler %s
+     %im = OpLoad %tex2d_f32 %copy_t
+     %sam = OpLoad %sampler %copy_s
+     %copy_im = OpCopyObject %tex2d_f32 %im
+     %copy_sam = OpCopyObject %sampler %sam
+     %1 = OpFunctionCall %void %200 %copy_im %copy_sam
+     OpReturn
+     OpFunctionEnd
+  )";
+
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    const auto got = test::ToString(p->program(), fe.ast_body());
+
+    const std::string expect = R"(x_200(t, s);
+return;
+)";
+    EXPECT_EQ(got, expect);
+}
+
 }  // namespace
 }  // namespace tint::reader::spirv
diff --git a/src/tint/reader/spirv/function_decl_test.cc b/src/tint/reader/spirv/function_decl_test.cc
index 9413d53..0c4834b 100644
--- a/src/tint/reader/spirv/function_decl_test.cc
+++ b/src/tint/reader/spirv/function_decl_test.cc
@@ -160,9 +160,6 @@
     EXPECT_THAT(got, HasSubstr(expect));
 }
 
-//     ;%s = OpVariable %ptr_sampler UniformConstant
-//     ;%t = OpVariable %ptr_tex2d_f32 UniformConstant
-
 TEST_F(SpvParserTest, Emit_FunctionDecl_ParamPtrTexture_ParamPtrSampler) {
     auto p = parser(test::Assemble(Preamble() + CommonHandleTypes() + R"(