[spirv-reader][ir] Split ShaderIO GetParameter apart.

Split the entry point specific code for `GetParameter` out of the SPIR-V
ShaderIO lowering code into its own method.

Bug: 42250952
Change-Id: Id4f1840d775dc85e27711714f2d803603b673ea6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/247154
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/lower/shader_io.cc b/src/tint/lang/spirv/reader/lower/shader_io.cc
index dbe23a5..a675fb6 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io.cc
@@ -409,60 +409,148 @@
         });
     }
 
+    core::ir::Value* GetEntryPointParameter(core::ir::Function* func, core::ir::Var* var) {
+        auto* var_type = var->Result()->Type()->UnwrapPtr();
+
+        // The SPIR_V type may not match the required WGSL entry point type, swap them as
+        // needed.
+        if (var->Attributes().builtin.has_value()) {
+            switch (var->Attributes().builtin.value()) {
+                case core::BuiltinValue::kSampleMask: {
+                    TINT_ASSERT(var_type->Is<core::type::Array>());
+                    TINT_ASSERT(var_type->As<core::type::Array>()->ConstantCount() == 1u);
+                    var_type = ty.u32();
+                    break;
+                }
+                case core::BuiltinValue::kInstanceIndex:
+                case core::BuiltinValue::kVertexIndex:
+                case core::BuiltinValue::kLocalInvocationIndex:
+                case core::BuiltinValue::kSubgroupInvocationId:
+                case core::BuiltinValue::kSubgroupSize:
+                case core::BuiltinValue::kSampleIndex: {
+                    var_type = ty.u32();
+                    break;
+                }
+                case core::BuiltinValue::kLocalInvocationId:
+                case core::BuiltinValue::kGlobalInvocationId:
+                case core::BuiltinValue::kWorkgroupId:
+                case core::BuiltinValue::kNumWorkgroups: {
+                    var_type = ty.vec3<u32>();
+                    break;
+                }
+                default: {
+                    break;
+                }
+            }
+        }
+
+        // Create a new function parameter for the input.
+        auto* param = b.FunctionParam(var_type);
+        func->AppendParam(param);
+        if (auto name = ir.NameOf(var)) {
+            ir.SetName(param, name);
+        }
+
+        // Add attributes to the parameter
+        if (auto* str = param->Type()->UnwrapPtr()->As<core::type::Struct>()) {
+            for (const auto* const_member : str->Members()) {
+                // TODO(crbug.com/tint/745): Remove the const_cast.
+                auto* member = const_cast<core::type::StructMember*>(const_member);
+
+                // Use the base variable attributes if not specified directly on the member.
+                auto member_attributes = member->Attributes();
+                if (auto base_loc = var->Attributes().location) {
+                    // Location values increment from the base location value on the variable.
+                    member->SetLocation(base_loc.value() + member->Index());
+                }
+                if (!member_attributes.interpolation) {
+                    member->SetInterpolation(var->Attributes().interpolation);
+                }
+            }
+        } else {
+            // Set attributes directly on the function parameter.
+            param->SetAttributes(var->Attributes());
+        }
+
+        core::ir::Value* result = param;
+        if (var->Attributes().builtin.has_value()) {
+            switch (var->Attributes().builtin.value()) {
+                case core::BuiltinValue::kSampleMask: {
+                    // Construct an array from the scalar sample_mask builtin value for entry
+                    // points.
+
+                    auto* mask_ty = var->Result()->Type()->UnwrapPtr()->As<core::type::Array>();
+                    TINT_ASSERT(mask_ty);
+
+                    // If the SPIR-V mask was an i32, need to convert from the u32 provided by
+                    // WGSL.
+                    if (mask_ty->ElemType()->IsSignedIntegerScalar()) {
+                        auto* conv = b.Convert(ty.i32(), result);
+                        func->Block()->Prepend(conv);
+
+                        auto* construct = b.Construct(mask_ty, conv);
+                        construct->InsertAfter(conv);
+                        result = construct->Result();
+                    } else {
+                        auto* construct = b.Construct(mask_ty, result);
+                        func->Block()->Prepend(construct);
+                        result = construct->Result();
+                    }
+                    break;
+                }
+                case core::BuiltinValue::kInstanceIndex:
+                case core::BuiltinValue::kVertexIndex:
+                case core::BuiltinValue::kLocalInvocationIndex:
+                case core::BuiltinValue::kSubgroupInvocationId:
+                case core::BuiltinValue::kSubgroupSize:
+                case core::BuiltinValue::kSampleIndex: {
+                    auto* idx_ty = var->Result()->Type()->UnwrapPtr();
+                    if (idx_ty->IsSignedIntegerScalar()) {
+                        auto* conv = b.Convert(ty.i32(), result);
+                        func->Block()->Prepend(conv);
+                        result = conv->Result();
+                    }
+                    break;
+                }
+                case core::BuiltinValue::kLocalInvocationId:
+                case core::BuiltinValue::kGlobalInvocationId:
+                case core::BuiltinValue::kWorkgroupId:
+                case core::BuiltinValue::kNumWorkgroups: {
+                    auto* idx_ty = var->Result()->Type()->UnwrapPtr();
+                    auto* elem_ty = idx_ty->DeepestElement();
+                    if (elem_ty->IsSignedIntegerScalar()) {
+                        auto* conv = b.Convert(ty.MatchWidth(ty.i32(), idx_ty), result);
+                        func->Block()->Prepend(conv);
+                        result = conv->Result();
+                    }
+                    break;
+                }
+                default: {
+                    break;
+                }
+            }
+        }
+        return result;
+    }
+
     /// Get or create a function parameter to replace a module-scope variable.
     /// @param func the function
     /// @param var the module-scope variable
     /// @returns the function parameter
     core::ir::Value* GetParameter(core::ir::Function* func, core::ir::Var* var) {
-        return function_parameter_map.GetOrAddZero(func).GetOrAdd(var, [&] {
-            const bool entry_point = func->IsEntryPoint();
-            auto* var_type = var->Result()->Type()->UnwrapPtr();
-
-            // The SPIR_V type may not match the required WGSL entry point type, swap them as
-            // needed.
-            if (entry_point && var->Attributes().builtin.has_value()) {
-                switch (var->Attributes().builtin.value()) {
-                    case core::BuiltinValue::kSampleMask: {
-                        TINT_ASSERT(var_type->Is<core::type::Array>());
-                        TINT_ASSERT(var_type->As<core::type::Array>()->ConstantCount() == 1u);
-                        var_type = ty.u32();
-                        break;
-                    }
-                    case core::BuiltinValue::kInstanceIndex:
-                    case core::BuiltinValue::kVertexIndex:
-                    case core::BuiltinValue::kLocalInvocationIndex:
-                    case core::BuiltinValue::kSubgroupInvocationId:
-                    case core::BuiltinValue::kSubgroupSize:
-                    case core::BuiltinValue::kSampleIndex: {
-                        var_type = ty.u32();
-                        break;
-                    }
-                    case core::BuiltinValue::kLocalInvocationId:
-                    case core::BuiltinValue::kGlobalInvocationId:
-                    case core::BuiltinValue::kWorkgroupId:
-                    case core::BuiltinValue::kNumWorkgroups: {
-                        var_type = ty.vec3<u32>();
-                        break;
-                    }
-                    default: {
-                        break;
-                    }
-                }
+        return function_parameter_map.GetOrAddZero(func).GetOrAdd(var, [&]() -> core::ir::Value* {
+            if (func->IsEntryPoint()) {
+                return GetEntryPointParameter(func, var);
             }
 
             // Create a new function parameter for the input.
-            auto* param = b.FunctionParam(var_type);
+            auto* param = b.FunctionParam(var->Result()->Type()->UnwrapPtr());
             func->AppendParam(param);
             if (auto name = ir.NameOf(var)) {
                 ir.SetName(param, name);
             }
 
-            // Add attributes to the parameter if this is an entry point function.
-            if (entry_point) {
-                AddEntryPointParameterAttributes(param, var->Attributes());
-            }
-
-            // Update the callsites of this function.
+            // Update the call sites of this function.
             func->ForEachUseUnsorted([&](core::ir::Usage use) {
                 if (auto* call = use.instruction->As<core::ir::UserCall>()) {
                     // Recurse into the calling function.
@@ -474,93 +562,9 @@
                 }
             });
 
-            core::ir::Value* result = param;
-            if (entry_point && var->Attributes().builtin.has_value()) {
-                switch (var->Attributes().builtin.value()) {
-                    case core::BuiltinValue::kSampleMask: {
-                        // Construct an array from the scalar sample_mask builtin value for entry
-                        // points.
-
-                        auto* mask_ty = var->Result()->Type()->UnwrapPtr()->As<core::type::Array>();
-                        TINT_ASSERT(mask_ty);
-
-                        // If the SPIR-V mask was an i32, need to convert from the u32 provided by
-                        // WGSL.
-                        if (mask_ty->ElemType()->IsSignedIntegerScalar()) {
-                            auto* conv = b.Convert(ty.i32(), result);
-                            func->Block()->Prepend(conv);
-
-                            auto* construct = b.Construct(mask_ty, conv);
-                            construct->InsertAfter(conv);
-                            result = construct->Result();
-                        } else {
-                            auto* construct = b.Construct(mask_ty, result);
-                            func->Block()->Prepend(construct);
-                            result = construct->Result();
-                        }
-                        break;
-                    }
-                    case core::BuiltinValue::kInstanceIndex:
-                    case core::BuiltinValue::kVertexIndex:
-                    case core::BuiltinValue::kLocalInvocationIndex:
-                    case core::BuiltinValue::kSubgroupInvocationId:
-                    case core::BuiltinValue::kSubgroupSize:
-                    case core::BuiltinValue::kSampleIndex: {
-                        auto* idx_ty = var->Result()->Type()->UnwrapPtr();
-                        if (idx_ty->IsSignedIntegerScalar()) {
-                            auto* conv = b.Convert(ty.i32(), result);
-                            func->Block()->Prepend(conv);
-                            result = conv->Result();
-                        }
-                        break;
-                    }
-                    case core::BuiltinValue::kLocalInvocationId:
-                    case core::BuiltinValue::kGlobalInvocationId:
-                    case core::BuiltinValue::kWorkgroupId:
-                    case core::BuiltinValue::kNumWorkgroups: {
-                        auto* idx_ty = var->Result()->Type()->UnwrapPtr();
-                        auto* elem_ty = idx_ty->DeepestElement();
-                        if (elem_ty->IsSignedIntegerScalar()) {
-                            auto* conv = b.Convert(ty.MatchWidth(ty.i32(), idx_ty), result);
-                            func->Block()->Prepend(conv);
-                            result = conv->Result();
-                        }
-                        break;
-                    }
-                    default: {
-                        break;
-                    }
-                }
-            }
-            return result;
+            return param;
         });
     }
-
-    /// Add attributes to an entry point function parameter.
-    /// @param param the parameter
-    /// @param attributes the attributes
-    void AddEntryPointParameterAttributes(core::ir::FunctionParam* param,
-                                          const core::IOAttributes& attributes) {
-        if (auto* str = param->Type()->UnwrapPtr()->As<core::type::Struct>()) {
-            for (const auto* const_member : str->Members()) {
-                // TODO(crbug.com/tint/745): Remove the const_cast.
-                auto* member = const_cast<core::type::StructMember*>(const_member);
-
-                // Use the base variable attributes if not specified directly on the member.
-                auto member_attributes = member->Attributes();
-                if (auto base_loc = attributes.location) {
-                    // Location values increment from the base location value on the variable.
-                    member->SetLocation(base_loc.value() + member->Index());
-                }
-                if (!member_attributes.interpolation) {
-                    member->SetInterpolation(attributes.interpolation);
-                }
-            }
-        } else {
-            // Set attributes directly on the function parameter.
-            param->SetAttributes(attributes);
-        }
-    }
 };
 
 }  // namespace