[tint][ir] Serialize function parameter attributes
Change-Id: I9dde39b1a14b37c4021d2e84c8f78b18205b991f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/165042
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index 7141a28..54a5544 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -171,13 +171,7 @@
params_out.Push(ValueAs<ir::FunctionParam>(param_in));
}
if (fn_in.has_return_location()) {
- auto& ret_loc_in = fn_in.return_location();
- core::ir::Location ret_loc_out{};
- ret_loc_out.value = ret_loc_in.value();
- if (ret_loc_in.has_interpolation()) {
- ret_loc_out.interpolation = Interpolation(ret_loc_in.interpolation());
- }
- fn_out->SetReturnLocation(ret_loc_out.value, std::move(ret_loc_out.interpolation));
+ fn_out->SetReturnLocation(Location(fn_in.return_location()));
}
fn_out->SetParams(std::move(params_out));
fn_out->SetBlock(Block(fn_in.block()));
@@ -657,41 +651,21 @@
ir::Value* CreateValue(const pb::Value& value_in) {
ir::Value* value_out = nullptr;
switch (value_in.kind_case()) {
- case pb::Value::KindCase::kFunction: {
+ case pb::Value::KindCase::kFunction:
value_out = Function(value_in.function());
break;
- }
- case pb::Value::KindCase::kInstructionResult: {
- auto& res_in = value_in.instruction_result();
- auto* type = Type(res_in.type());
- value_out = b.InstructionResult(type);
- if (res_in.has_name()) {
- mod_out_.SetName(value_out, res_in.name());
- }
+ case pb::Value::KindCase::kInstructionResult:
+ value_out = InstructionResult(value_in.instruction_result());
break;
- }
- case pb::Value::KindCase::kFunctionParameter: {
- auto& param_in = value_in.function_parameter();
- auto* type = Type(param_in.type());
- value_out = b.FunctionParam(type);
- if (param_in.has_name()) {
- mod_out_.SetName(value_out, param_in.name());
- }
+ case pb::Value::KindCase::kFunctionParameter:
+ value_out = FunctionParameter(value_in.function_parameter());
break;
- }
- case pb::Value::KindCase::kBlockParameter: {
- auto& param_in = value_in.block_parameter();
- auto* type = Type(param_in.type());
- value_out = b.BlockParam(type);
- if (param_in.has_name()) {
- mod_out_.SetName(value_out, param_in.name());
- }
+ case pb::Value::KindCase::kBlockParameter:
+ value_out = BlockParameter(value_in.block_parameter());
break;
- }
- case pb::Value::KindCase::kConstant: {
+ case pb::Value::KindCase::kConstant:
value_out = b.Constant(ConstantValue(value_in.constant()));
break;
- }
default:
TINT_ICE() << "invalid TypeDecl.kind: " << value_in.kind_case();
return nullptr;
@@ -699,6 +673,51 @@
return value_out;
}
+ ir::InstructionResult* InstructionResult(const pb::InstructionResult& res_in) {
+ auto* type = Type(res_in.type());
+ auto* res_out = b.InstructionResult(type);
+ if (res_in.has_name()) {
+ mod_out_.SetName(res_out, res_in.name());
+ }
+ return res_out;
+ }
+
+ ir::FunctionParam* FunctionParameter(const pb::FunctionParameter& param_in) {
+ auto* type = Type(param_in.type());
+ auto* param_out = b.FunctionParam(type);
+ if (param_in.has_name()) {
+ mod_out_.SetName(param_out, param_in.name());
+ }
+
+ if (param_in.has_attributes()) {
+ auto& attrs_in = param_in.attributes();
+ if (attrs_in.has_binding_point()) {
+ auto& bp_in = attrs_in.binding_point();
+ param_out->SetBindingPoint(bp_in.group(), bp_in.binding());
+ }
+ if (attrs_in.has_location()) {
+ param_out->SetLocation(Location(attrs_in.location()));
+ }
+ if (attrs_in.has_builtin()) {
+ param_out->SetBuiltin(BuiltinValue(attrs_in.builtin()));
+ }
+ if (attrs_in.invariant()) {
+ param_out->SetInvariant(true);
+ }
+ }
+
+ return param_out;
+ }
+
+ ir::BlockParam* BlockParameter(const pb::BlockParameter& param_in) {
+ auto* type = Type(param_in.type());
+ auto* param_out = b.BlockParam(type);
+ if (param_in.has_name()) {
+ mod_out_.SetName(param_out, param_in.name());
+ }
+ return param_out;
+ }
+
ir::Value* Value(uint32_t id) { return id > 0 ? values_[id - 1] : nullptr; }
template <typename T>
@@ -770,6 +789,15 @@
////////////////////////////////////////////////////////////////////////////
// Attributes
////////////////////////////////////////////////////////////////////////////
+ ir::Location Location(const pb::Location& location_in) {
+ core::ir::Location location_out{};
+ location_out.value = location_in.value();
+ if (location_in.has_interpolation()) {
+ location_out.interpolation = Interpolation(location_in.interpolation());
+ }
+ return location_out;
+ }
+
core::Interpolation Interpolation(const pb::Interpolation& interpolation_in) {
core::Interpolation interpolation_out{};
interpolation_out.type = InterpolationType(interpolation_in.type());
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index 3677967..9618179 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -136,11 +136,7 @@
}
if (auto ret_loc_in = fn_in->ReturnLocation()) {
auto& ret_loc_out = *fn_out->mutable_return_location();
- if (auto interpolation_in = ret_loc_in->interpolation) {
- auto& interpolation_out = *ret_loc_out.mutable_interpolation();
- Interpolation(interpolation_out, *interpolation_in);
- }
- ret_loc_out.set_value(ret_loc_in->value);
+ Location(ret_loc_out, *ret_loc_in);
}
fn_out->set_block(Block(fn_in->Block()));
}
@@ -325,8 +321,7 @@
void InstructionVar(pb::InstructionVar& var_out, const ir::Var* var_in) {
if (auto bp_in = var_in->BindingPoint()) {
auto& bp_out = *var_out.mutable_binding_point();
- bp_out.set_group(bp_in->group);
- bp_out.set_binding(bp_in->binding);
+ BindingPoint(bp_out, *bp_in);
}
}
@@ -497,6 +492,20 @@
if (auto name = mod_in_.NameOf(param_in); name.IsValid()) {
param_out.set_name(name.Name());
}
+ if (auto bp_in = param_in->BindingPoint()) {
+ auto& bp_out = *param_out.mutable_attributes()->mutable_binding_point();
+ BindingPoint(bp_out, *bp_in);
+ }
+ if (auto location_in = param_in->Location()) {
+ auto& location_out = *param_out.mutable_attributes()->mutable_location();
+ Location(location_out, *location_in);
+ }
+ if (auto builtin_in = param_in->Builtin()) {
+ param_out.mutable_attributes()->set_builtin(BuiltinValue(*builtin_in));
+ }
+ if (param_in->Invariant()) {
+ param_out.mutable_attributes()->set_invariant(true);
+ }
}
void BlockParameter(pb::BlockParameter& param_out, const ir::BlockParam* param_in) {
@@ -563,6 +572,14 @@
////////////////////////////////////////////////////////////////////////////
// Attributes
////////////////////////////////////////////////////////////////////////////
+ void Location(pb::Location& location_out, const ir::Location& location_in) {
+ if (auto interpolation_in = location_in.interpolation) {
+ auto& interpolation_out = *location_out.mutable_interpolation();
+ Interpolation(interpolation_out, *interpolation_in);
+ }
+ location_out.set_value(location_in.value);
+ }
+
void Interpolation(pb::Interpolation& interpolation_out,
const core::Interpolation& interpolation_in) {
interpolation_out.set_type(InterpolationType(interpolation_in.type));
@@ -571,6 +588,11 @@
}
}
+ void BindingPoint(pb::BindingPoint& binding_point_out, const BindingPoint& binding_point_in) {
+ binding_point_out.set_group(binding_point_in.group);
+ binding_point_out.set_binding(binding_point_in.binding);
+ }
+
////////////////////////////////////////////////////////////////////////////
// Enums
////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 6fe9364..e1d600d 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -147,6 +147,7 @@
message FunctionParameter {
uint32 type = 1; // Module.types
optional string name = 2;
+ optional AttributesFunctionParameter attributes = 3;
}
message BlockParameter {
@@ -196,7 +197,7 @@
optional PipelineStage pipeline_stage = 4;
optional WorkgroupSize workgroup_size = 5;
repeated uint32 parameters = 6; // Module.values
- optional ReturnLocation return_location = 7;
+ optional Location return_location = 7;
}
enum PipelineStage {
@@ -211,11 +212,6 @@
uint32 z = 3;
}
-message ReturnLocation {
- uint32 value = 1;
- optional Interpolation interpolation = 2;
-}
-
////////////////////////////////////////////////////////////////////////////////
// Blocks
////////////////////////////////////////////////////////////////////////////////
@@ -357,11 +353,23 @@
bool invariant = 6;
}
+message AttributesFunctionParameter {
+ optional BuiltinValue builtin = 1;
+ optional Location location = 2;
+ optional BindingPoint binding_point = 3;
+ bool invariant = 4;
+}
+
message Interpolation {
InterpolationType type = 1;
optional InterpolationSampling sampling = 2;
}
+message Location {
+ uint32 value = 1;
+ optional Interpolation interpolation = 2;
+}
+
////////////////////////////////////////////////////////////////////////////////
// Enums
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 4e14b26..095827e 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -126,6 +126,20 @@
RUN_TEST();
}
+TEST_F(IRBinaryRoundtripTest, Fn_ParameterAttributes) {
+ auto* fn = b.Function("Function", ty.void_());
+ auto* p0 = b.FunctionParam(ty.i32());
+ auto* p1 = b.FunctionParam(ty.u32());
+ auto* p2 = b.FunctionParam(ty.f32());
+ auto* p3 = b.FunctionParam(ty.bool_());
+ p0->SetBuiltin(BuiltinValue::kGlobalInvocationId);
+ p1->SetInvariant(true);
+ p2->SetLocation(10, Interpolation{InterpolationType::kFlat, InterpolationSampling::kCenter});
+ p3->SetBindingPoint(20, 30);
+ fn->SetParams({p0, p1, p2, p3});
+ RUN_TEST();
+}
+
TEST_F(IRBinaryRoundtripTest, Fn_ReturnLocation) {
auto* fn = b.Function("Function", ty.void_());
fn->SetReturnLocation(42, std::nullopt);
diff --git a/src/tint/lang/core/ir/function.h b/src/tint/lang/core/ir/function.h
index 0193884..a03e510 100644
--- a/src/tint/lang/core/ir/function.h
+++ b/src/tint/lang/core/ir/function.h
@@ -120,23 +120,31 @@
}
/// @returns the return builtin attribute
std::optional<enum ReturnBuiltin> ReturnBuiltin() const { return return_.builtin; }
+
/// Clears the return builtin attribute.
void ClearReturnBuiltin() { return_.builtin = {}; }
/// Sets the return location
+ /// @param location the location to set
+ void SetReturnLocation(Location location) { return_.location = std::move(location); }
+
+ /// Sets the return location
/// @param loc the location to set
/// @param interp the interpolation
void SetReturnLocation(uint32_t loc, std::optional<core::Interpolation> interp) {
return_.location = {loc, interp};
}
+
/// @returns the return location
std::optional<Location> ReturnLocation() const { return return_.location; }
+
/// Clears the return location attribute.
void ClearReturnLocation() { return_.location = {}; }
/// Sets the return as invariant
/// @param val the invariant value to set
void SetReturnInvariant(bool val) { return_.invariant = val; }
+
/// @returns the return invariant value
bool ReturnInvariant() const { return return_.invariant; }
diff --git a/src/tint/lang/core/ir/function_param.h b/src/tint/lang/core/ir/function_param.h
index ea989b3..6cefb10 100644
--- a/src/tint/lang/core/ir/function_param.h
+++ b/src/tint/lang/core/ir/function_param.h
@@ -62,23 +62,31 @@
}
/// @returns the builtin set for the parameter
std::optional<core::BuiltinValue> Builtin() const { return builtin_; }
+
/// Clears the builtin attribute.
void ClearBuiltin() { builtin_ = {}; }
/// Sets the parameter as invariant
/// @param val the value to set for invariant
void SetInvariant(bool val) { invariant_ = val; }
+
/// @returns true if parameter is invariant
bool Invariant() const { return invariant_; }
/// Sets the location
+ /// @param location the location
+ void SetLocation(ir::Location location) { location_ = std::move(location); }
+
+ /// Sets the location
/// @param loc the location value
/// @param interpolation if the location interpolation settings
void SetLocation(uint32_t loc, std::optional<core::Interpolation> interpolation) {
location_ = {loc, interpolation};
}
+
/// @returns the location if `Attributes` contains `kLocation`
- std::optional<struct Location> Location() const { return location_; }
+ std::optional<ir::Location> Location() const { return location_; }
+
/// Clears the location attribute.
void ClearLocation() { location_ = {}; }
@@ -86,6 +94,7 @@
/// @param group the group
/// @param binding the binding
void SetBindingPoint(uint32_t group, uint32_t binding) { binding_point_ = {group, binding}; }
+
/// @returns the binding points if `Attributes` contains `kBindingPoint`
std::optional<struct BindingPoint> BindingPoint() const { return binding_point_; }