[tint][ir] Serialize function parameter attributes

Change-Id: I9dde39b1a14b37c4021d2e84c8f78b18205b991f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/165042
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index 7141a28..54a5544 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -171,13 +171,7 @@
             params_out.Push(ValueAs<ir::FunctionParam>(param_in));
         }
         if (fn_in.has_return_location()) {
-            auto& ret_loc_in = fn_in.return_location();
-            core::ir::Location ret_loc_out{};
-            ret_loc_out.value = ret_loc_in.value();
-            if (ret_loc_in.has_interpolation()) {
-                ret_loc_out.interpolation = Interpolation(ret_loc_in.interpolation());
-            }
-            fn_out->SetReturnLocation(ret_loc_out.value, std::move(ret_loc_out.interpolation));
+            fn_out->SetReturnLocation(Location(fn_in.return_location()));
         }
         fn_out->SetParams(std::move(params_out));
         fn_out->SetBlock(Block(fn_in.block()));
@@ -657,41 +651,21 @@
     ir::Value* CreateValue(const pb::Value& value_in) {
         ir::Value* value_out = nullptr;
         switch (value_in.kind_case()) {
-            case pb::Value::KindCase::kFunction: {
+            case pb::Value::KindCase::kFunction:
                 value_out = Function(value_in.function());
                 break;
-            }
-            case pb::Value::KindCase::kInstructionResult: {
-                auto& res_in = value_in.instruction_result();
-                auto* type = Type(res_in.type());
-                value_out = b.InstructionResult(type);
-                if (res_in.has_name()) {
-                    mod_out_.SetName(value_out, res_in.name());
-                }
+            case pb::Value::KindCase::kInstructionResult:
+                value_out = InstructionResult(value_in.instruction_result());
                 break;
-            }
-            case pb::Value::KindCase::kFunctionParameter: {
-                auto& param_in = value_in.function_parameter();
-                auto* type = Type(param_in.type());
-                value_out = b.FunctionParam(type);
-                if (param_in.has_name()) {
-                    mod_out_.SetName(value_out, param_in.name());
-                }
+            case pb::Value::KindCase::kFunctionParameter:
+                value_out = FunctionParameter(value_in.function_parameter());
                 break;
-            }
-            case pb::Value::KindCase::kBlockParameter: {
-                auto& param_in = value_in.block_parameter();
-                auto* type = Type(param_in.type());
-                value_out = b.BlockParam(type);
-                if (param_in.has_name()) {
-                    mod_out_.SetName(value_out, param_in.name());
-                }
+            case pb::Value::KindCase::kBlockParameter:
+                value_out = BlockParameter(value_in.block_parameter());
                 break;
-            }
-            case pb::Value::KindCase::kConstant: {
+            case pb::Value::KindCase::kConstant:
                 value_out = b.Constant(ConstantValue(value_in.constant()));
                 break;
-            }
             default:
                 TINT_ICE() << "invalid TypeDecl.kind: " << value_in.kind_case();
                 return nullptr;
@@ -699,6 +673,51 @@
         return value_out;
     }
 
+    ir::InstructionResult* InstructionResult(const pb::InstructionResult& res_in) {
+        auto* type = Type(res_in.type());
+        auto* res_out = b.InstructionResult(type);
+        if (res_in.has_name()) {
+            mod_out_.SetName(res_out, res_in.name());
+        }
+        return res_out;
+    }
+
+    ir::FunctionParam* FunctionParameter(const pb::FunctionParameter& param_in) {
+        auto* type = Type(param_in.type());
+        auto* param_out = b.FunctionParam(type);
+        if (param_in.has_name()) {
+            mod_out_.SetName(param_out, param_in.name());
+        }
+
+        if (param_in.has_attributes()) {
+            auto& attrs_in = param_in.attributes();
+            if (attrs_in.has_binding_point()) {
+                auto& bp_in = attrs_in.binding_point();
+                param_out->SetBindingPoint(bp_in.group(), bp_in.binding());
+            }
+            if (attrs_in.has_location()) {
+                param_out->SetLocation(Location(attrs_in.location()));
+            }
+            if (attrs_in.has_builtin()) {
+                param_out->SetBuiltin(BuiltinValue(attrs_in.builtin()));
+            }
+            if (attrs_in.invariant()) {
+                param_out->SetInvariant(true);
+            }
+        }
+
+        return param_out;
+    }
+
+    ir::BlockParam* BlockParameter(const pb::BlockParameter& param_in) {
+        auto* type = Type(param_in.type());
+        auto* param_out = b.BlockParam(type);
+        if (param_in.has_name()) {
+            mod_out_.SetName(param_out, param_in.name());
+        }
+        return param_out;
+    }
+
     ir::Value* Value(uint32_t id) { return id > 0 ? values_[id - 1] : nullptr; }
 
     template <typename T>
@@ -770,6 +789,15 @@
     ////////////////////////////////////////////////////////////////////////////
     // Attributes
     ////////////////////////////////////////////////////////////////////////////
+    ir::Location Location(const pb::Location& location_in) {
+        core::ir::Location location_out{};
+        location_out.value = location_in.value();
+        if (location_in.has_interpolation()) {
+            location_out.interpolation = Interpolation(location_in.interpolation());
+        }
+        return location_out;
+    }
+
     core::Interpolation Interpolation(const pb::Interpolation& interpolation_in) {
         core::Interpolation interpolation_out{};
         interpolation_out.type = InterpolationType(interpolation_in.type());
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index 3677967..9618179 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -136,11 +136,7 @@
         }
         if (auto ret_loc_in = fn_in->ReturnLocation()) {
             auto& ret_loc_out = *fn_out->mutable_return_location();
-            if (auto interpolation_in = ret_loc_in->interpolation) {
-                auto& interpolation_out = *ret_loc_out.mutable_interpolation();
-                Interpolation(interpolation_out, *interpolation_in);
-            }
-            ret_loc_out.set_value(ret_loc_in->value);
+            Location(ret_loc_out, *ret_loc_in);
         }
         fn_out->set_block(Block(fn_in->Block()));
     }
@@ -325,8 +321,7 @@
     void InstructionVar(pb::InstructionVar& var_out, const ir::Var* var_in) {
         if (auto bp_in = var_in->BindingPoint()) {
             auto& bp_out = *var_out.mutable_binding_point();
-            bp_out.set_group(bp_in->group);
-            bp_out.set_binding(bp_in->binding);
+            BindingPoint(bp_out, *bp_in);
         }
     }
 
@@ -497,6 +492,20 @@
         if (auto name = mod_in_.NameOf(param_in); name.IsValid()) {
             param_out.set_name(name.Name());
         }
+        if (auto bp_in = param_in->BindingPoint()) {
+            auto& bp_out = *param_out.mutable_attributes()->mutable_binding_point();
+            BindingPoint(bp_out, *bp_in);
+        }
+        if (auto location_in = param_in->Location()) {
+            auto& location_out = *param_out.mutable_attributes()->mutable_location();
+            Location(location_out, *location_in);
+        }
+        if (auto builtin_in = param_in->Builtin()) {
+            param_out.mutable_attributes()->set_builtin(BuiltinValue(*builtin_in));
+        }
+        if (param_in->Invariant()) {
+            param_out.mutable_attributes()->set_invariant(true);
+        }
     }
 
     void BlockParameter(pb::BlockParameter& param_out, const ir::BlockParam* param_in) {
@@ -563,6 +572,14 @@
     ////////////////////////////////////////////////////////////////////////////
     // Attributes
     ////////////////////////////////////////////////////////////////////////////
+    void Location(pb::Location& location_out, const ir::Location& location_in) {
+        if (auto interpolation_in = location_in.interpolation) {
+            auto& interpolation_out = *location_out.mutable_interpolation();
+            Interpolation(interpolation_out, *interpolation_in);
+        }
+        location_out.set_value(location_in.value);
+    }
+
     void Interpolation(pb::Interpolation& interpolation_out,
                        const core::Interpolation& interpolation_in) {
         interpolation_out.set_type(InterpolationType(interpolation_in.type));
@@ -571,6 +588,11 @@
         }
     }
 
+    void BindingPoint(pb::BindingPoint& binding_point_out, const BindingPoint& binding_point_in) {
+        binding_point_out.set_group(binding_point_in.group);
+        binding_point_out.set_binding(binding_point_in.binding);
+    }
+
     ////////////////////////////////////////////////////////////////////////////
     // Enums
     ////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 6fe9364..e1d600d 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -147,6 +147,7 @@
 message FunctionParameter {
     uint32 type = 1;  // Module.types
     optional string name = 2;
+    optional AttributesFunctionParameter attributes = 3;
 }
 
 message BlockParameter {
@@ -196,7 +197,7 @@
     optional PipelineStage pipeline_stage = 4;
     optional WorkgroupSize workgroup_size = 5;
     repeated uint32 parameters = 6;  // Module.values
-    optional ReturnLocation return_location = 7;
+    optional Location return_location = 7;
 }
 
 enum PipelineStage {
@@ -211,11 +212,6 @@
     uint32 z = 3;
 }
 
-message ReturnLocation {
-    uint32 value = 1;
-    optional Interpolation interpolation = 2;
-}
-
 ////////////////////////////////////////////////////////////////////////////////
 // Blocks
 ////////////////////////////////////////////////////////////////////////////////
@@ -357,11 +353,23 @@
     bool invariant = 6;
 }
 
+message AttributesFunctionParameter {
+    optional BuiltinValue builtin = 1;
+    optional Location location = 2;
+    optional BindingPoint binding_point = 3;
+    bool invariant = 4;
+}
+
 message Interpolation {
     InterpolationType type = 1;
     optional InterpolationSampling sampling = 2;
 }
 
+message Location {
+    uint32 value = 1;
+    optional Interpolation interpolation = 2;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Enums
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 4e14b26..095827e 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -126,6 +126,20 @@
     RUN_TEST();
 }
 
+TEST_F(IRBinaryRoundtripTest, Fn_ParameterAttributes) {
+    auto* fn = b.Function("Function", ty.void_());
+    auto* p0 = b.FunctionParam(ty.i32());
+    auto* p1 = b.FunctionParam(ty.u32());
+    auto* p2 = b.FunctionParam(ty.f32());
+    auto* p3 = b.FunctionParam(ty.bool_());
+    p0->SetBuiltin(BuiltinValue::kGlobalInvocationId);
+    p1->SetInvariant(true);
+    p2->SetLocation(10, Interpolation{InterpolationType::kFlat, InterpolationSampling::kCenter});
+    p3->SetBindingPoint(20, 30);
+    fn->SetParams({p0, p1, p2, p3});
+    RUN_TEST();
+}
+
 TEST_F(IRBinaryRoundtripTest, Fn_ReturnLocation) {
     auto* fn = b.Function("Function", ty.void_());
     fn->SetReturnLocation(42, std::nullopt);
diff --git a/src/tint/lang/core/ir/function.h b/src/tint/lang/core/ir/function.h
index 0193884..a03e510 100644
--- a/src/tint/lang/core/ir/function.h
+++ b/src/tint/lang/core/ir/function.h
@@ -120,23 +120,31 @@
     }
     /// @returns the return builtin attribute
     std::optional<enum ReturnBuiltin> ReturnBuiltin() const { return return_.builtin; }
+
     /// Clears the return builtin attribute.
     void ClearReturnBuiltin() { return_.builtin = {}; }
 
     /// Sets the return location
+    /// @param location the location to set
+    void SetReturnLocation(Location location) { return_.location = std::move(location); }
+
+    /// Sets the return location
     /// @param loc the location to set
     /// @param interp the interpolation
     void SetReturnLocation(uint32_t loc, std::optional<core::Interpolation> interp) {
         return_.location = {loc, interp};
     }
+
     /// @returns the return location
     std::optional<Location> ReturnLocation() const { return return_.location; }
+
     /// Clears the return location attribute.
     void ClearReturnLocation() { return_.location = {}; }
 
     /// Sets the return as invariant
     /// @param val the invariant value to set
     void SetReturnInvariant(bool val) { return_.invariant = val; }
+
     /// @returns the return invariant value
     bool ReturnInvariant() const { return return_.invariant; }
 
diff --git a/src/tint/lang/core/ir/function_param.h b/src/tint/lang/core/ir/function_param.h
index ea989b3..6cefb10 100644
--- a/src/tint/lang/core/ir/function_param.h
+++ b/src/tint/lang/core/ir/function_param.h
@@ -62,23 +62,31 @@
     }
     /// @returns the builtin set for the parameter
     std::optional<core::BuiltinValue> Builtin() const { return builtin_; }
+
     /// Clears the builtin attribute.
     void ClearBuiltin() { builtin_ = {}; }
 
     /// Sets the parameter as invariant
     /// @param val the value to set for invariant
     void SetInvariant(bool val) { invariant_ = val; }
+
     /// @returns true if parameter is invariant
     bool Invariant() const { return invariant_; }
 
     /// Sets the location
+    /// @param location the location
+    void SetLocation(ir::Location location) { location_ = std::move(location); }
+
+    /// Sets the location
     /// @param loc the location value
     /// @param interpolation if the location interpolation settings
     void SetLocation(uint32_t loc, std::optional<core::Interpolation> interpolation) {
         location_ = {loc, interpolation};
     }
+
     /// @returns the location if `Attributes` contains `kLocation`
-    std::optional<struct Location> Location() const { return location_; }
+    std::optional<ir::Location> Location() const { return location_; }
+
     /// Clears the location attribute.
     void ClearLocation() { location_ = {}; }
 
@@ -86,6 +94,7 @@
     /// @param group the group
     /// @param binding the binding
     void SetBindingPoint(uint32_t group, uint32_t binding) { binding_point_ = {group, binding}; }
+
     /// @returns the binding points if `Attributes` contains `kBindingPoint`
     std::optional<struct BindingPoint> BindingPoint() const { return binding_point_; }