[ir] Add parameter attributes.
This Cl adds attributes to function parameters. This includes all the
ones that can be written in WGSL along with BindingPoint which is added
by some transforms.
Bug: tint:1915
Change-Id: Id2e506318c255562bb1810281cd81575fde1236f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134820
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index bc74d5d..5d9eb86 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1253,6 +1253,7 @@
"ir/instruction.h",
"ir/load.cc",
"ir/load.h",
+ "ir/location.h",
"ir/loop.cc",
"ir/loop.h",
"ir/module.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 277d819..d444066 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -762,6 +762,7 @@
ir/instruction.h
ir/load.cc
ir/load.h
+ ir/location.h
ir/loop.cc
ir/loop.h
ir/module.cc
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 75cc2aa..2ff623e 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -139,6 +139,54 @@
Indent() << "}" << std::endl;
}
+void Disassembler::EmitParamAttributes(FunctionParam* p) {
+ if (!p->Invariant() && !p->Location().has_value() && !p->BindingPoint().has_value() &&
+ !p->Builtin().has_value()) {
+ return;
+ }
+
+ out_ << " [";
+
+ bool need_comma = false;
+ auto comma = [&]() {
+ if (need_comma) {
+ out_ << ", ";
+ }
+ };
+
+ if (p->Invariant()) {
+ comma();
+ out_ << "@invariant";
+ need_comma = true;
+ }
+ if (p->Location().has_value()) {
+ out_ << "@location(" << p->Location()->value << ")";
+ if (p->Location()->interpolation.has_value()) {
+ out_ << ", @interpolate(";
+ out_ << p->Location()->interpolation->type;
+ if (p->Location()->interpolation->sampling !=
+ builtin::InterpolationSampling::kUndefined) {
+ out_ << ", ";
+ out_ << p->Location()->interpolation->sampling;
+ }
+ out_ << ")";
+ }
+ need_comma = true;
+ }
+ if (p->BindingPoint().has_value()) {
+ comma();
+ out_ << "@binding_point(" << p->BindingPoint()->group << ", " << p->BindingPoint()->binding
+ << ")";
+ need_comma = true;
+ }
+ if (p->Builtin().has_value()) {
+ comma();
+ out_ << "@" << p->Builtin().value();
+ need_comma = true;
+ }
+ out_ << "]";
+}
+
void Disassembler::EmitFunction(const Function* func) {
in_function_ = true;
@@ -148,6 +196,8 @@
out_ << ", ";
}
out_ << "%" << IdOf(p) << ":" << p->Type()->FriendlyName();
+
+ EmitParamAttributes(p);
}
out_ << "):" << func->ReturnType()->FriendlyName();
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 06f5b08..04e258c 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -59,6 +59,7 @@
void Walk(const Block* blk);
void WalkInternal(const Block* blk);
void EmitFunction(const Function* func);
+ void EmitParamAttributes(FunctionParam* p);
void EmitInstruction(const Instruction* inst);
void EmitValueWithType(const Value* val);
void EmitValue(const Value* val);
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 20555d0..ade2aa6 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -43,6 +43,7 @@
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/ast/int_literal_expression.h"
+#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/invariant_attribute.h"
#include "src/tint/ast/let.h"
#include "src/tint/ast/literal_expression.h"
@@ -311,10 +312,98 @@
utils::Vector<FunctionParam*, 1> params;
for (auto* p : ast_func->params) {
- const auto* param_sem = program_->Sem().Get(p);
+ const auto* param_sem = program_->Sem().Get(p)->As<sem::Parameter>();
auto* ty = param_sem->Type()->Clone(clone_ctx_.type_ctx);
auto* param = builder_.FunctionParam(ty);
+ // Note, interpolated is only valid when paired with Location, so it will only be set
+ // when the location is set.
+ std::optional<builtin::Interpolation> interpolation;
+ for (auto* attr : p->attributes) {
+ tint::Switch(
+ attr, //
+ [&](const ast::InterpolateAttribute* interp) {
+ auto type =
+ program_->Sem()
+ .Get(interp->type)
+ ->As<sem::BuiltinEnumExpression<builtin::InterpolationType>>();
+ builtin::InterpolationType interpolation_type = type->Value();
+
+ builtin::InterpolationSampling interpolation_sampling =
+ builtin::InterpolationSampling::kUndefined;
+ if (interp->sampling) {
+ auto sampling = program_->Sem()
+ .Get(interp->sampling)
+ ->As<sem::BuiltinEnumExpression<
+ builtin::InterpolationSampling>>();
+ interpolation_sampling = sampling->Value();
+ }
+
+ interpolation =
+ builtin::Interpolation{interpolation_type, interpolation_sampling};
+ },
+ [&](const ast::InvariantAttribute*) { param->SetInvariant(true); },
+ [&](const ast::BuiltinAttribute* b) {
+ if (auto* ident_sem =
+ program_->Sem()
+ .Get(b)
+ ->As<sem::BuiltinEnumExpression<builtin::BuiltinValue>>()) {
+ switch (ident_sem->Value()) {
+ case builtin::BuiltinValue::kVertexIndex:
+ param->SetBuiltin(FunctionParam::Builtin::kVertexIndex);
+ break;
+ case builtin::BuiltinValue::kInstanceIndex:
+ param->SetBuiltin(FunctionParam::Builtin::kInstanceIndex);
+ break;
+ case builtin::BuiltinValue::kPosition:
+ param->SetBuiltin(FunctionParam::Builtin::kPosition);
+ break;
+ case builtin::BuiltinValue::kFrontFacing:
+ param->SetBuiltin(FunctionParam::Builtin::kFrontFacing);
+ break;
+ case builtin::BuiltinValue::kLocalInvocationId:
+ param->SetBuiltin(FunctionParam::Builtin::kLocalInvocationId);
+ break;
+ case builtin::BuiltinValue::kLocalInvocationIndex:
+ param->SetBuiltin(
+ FunctionParam::Builtin::kLocalInvocationIndex);
+ break;
+ case builtin::BuiltinValue::kGlobalInvocationId:
+ param->SetBuiltin(FunctionParam::Builtin::kGlobalInvocationId);
+ break;
+ case builtin::BuiltinValue::kWorkgroupId:
+ param->SetBuiltin(FunctionParam::Builtin::kWorkgroupId);
+ break;
+ case builtin::BuiltinValue::kNumWorkgroups:
+ param->SetBuiltin(FunctionParam::Builtin::kNumWorkgroups);
+ break;
+ case builtin::BuiltinValue::kSampleIndex:
+ param->SetBuiltin(FunctionParam::Builtin::kSampleIndex);
+ break;
+ case builtin::BuiltinValue::kSampleMask:
+ param->SetBuiltin(FunctionParam::Builtin::kSampleMask);
+ break;
+ default:
+ TINT_ICE(IR, diagnostics_)
+ << "Unknown builtin value in parameter attributes "
+ << ident_sem->Value();
+ return;
+ }
+ } else {
+ TINT_ICE(IR, diagnostics_) << "Builtin attribute sem invalid";
+ return;
+ }
+ });
+
+ if (param_sem->Location().has_value()) {
+ param->SetLocation(param_sem->Location().value(), interpolation);
+ }
+ if (param_sem->BindingPoint().has_value()) {
+ param->SetBindingPoint(param_sem->BindingPoint()->group,
+ param_sem->BindingPoint()->binding);
+ }
+ }
+
scopes_.Set(p->name->symbol, param);
builder_.ir.SetName(param, p->name->symbol.NameView());
params.Push(param);
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index f9be668..c8dbd39 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -1264,5 +1264,85 @@
)");
}
+TEST_F(IR_BuilderImplTest, Func_WithParam_WithAttribute_Invariant) {
+ Func(
+ "f",
+ utils::Vector{Param("a", ty.vec4<f32>(),
+ utils::Vector{Invariant(), Builtin(builtin::BuiltinValue::kPosition)})},
+ ty.vec4<f32>(), utils::Vector{Return("a")},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)}, utils::Vector{Location(1_i)});
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(
+ Disassemble(m.Get()),
+ R"(%f = func(%a:vec4<f32> [@invariant, @position]):vec4<f32> [@fragment ra: @location(1)] -> %b1 {
+ %b1 = block {
+ ret %a
+ }
+}
+)");
+}
+
+TEST_F(IR_BuilderImplTest, Func_WithParam_WithAttribute_Location) {
+ Func("f", utils::Vector{Param("a", ty.f32(), utils::Vector{Location(2_i)})}, ty.f32(),
+ utils::Vector{Return("a")}, utils::Vector{Stage(ast::PipelineStage::kFragment)},
+ utils::Vector{Location(1_i)});
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%f = func(%a:f32 [@location(2)]):f32 [@fragment ra: @location(1)] -> %b1 {
+ %b1 = block {
+ ret %a
+ }
+}
+)");
+}
+
+TEST_F(IR_BuilderImplTest, Func_WithParam_WithAttribute_Location_WithInterpolation_LinearCentroid) {
+ Func("f",
+ utils::Vector{Param(
+ "a", ty.f32(),
+ utils::Vector{Location(2_i), Interpolate(builtin::InterpolationType::kLinear,
+ builtin::InterpolationSampling::kCentroid)})},
+ ty.f32(), utils::Vector{Return("a")}, utils::Vector{Stage(ast::PipelineStage::kFragment)},
+ utils::Vector{Location(1_i)});
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(
+ Disassemble(m.Get()),
+ R"(%f = func(%a:f32 [@location(2), @interpolate(linear, centroid)]):f32 [@fragment ra: @location(1)] -> %b1 {
+ %b1 = block {
+ ret %a
+ }
+}
+)");
+}
+
+TEST_F(IR_BuilderImplTest, Func_WithParam_WithAttribute_Location_WithInterpolation_Flat) {
+ Func("f",
+ utils::Vector{
+ Param("a", ty.f32(),
+ utils::Vector{Location(2_i), Interpolate(builtin::InterpolationType::kFlat)})},
+ ty.f32(), utils::Vector{Return("a")}, utils::Vector{Stage(ast::PipelineStage::kFragment)},
+ utils::Vector{Location(1_i)});
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(
+ Disassemble(m.Get()),
+ R"(%f = func(%a:f32 [@location(2), @interpolate(flat)]):f32 [@fragment ra: @location(1)] -> %b1 {
+ %b1 = block {
+ ret %a
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/function.cc b/src/tint/ir/function.cc
index 4e8ea4c..ef09065 100644
--- a/src/tint/ir/function.cc
+++ b/src/tint/ir/function.cc
@@ -51,8 +51,6 @@
return out << "position";
case Function::ReturnAttribute::kInvariant:
return out << "invariant";
- default:
- break;
}
return out << "<unknown>";
}
diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h
index 172afdc..dbee570 100644
--- a/src/tint/ir/function.h
+++ b/src/tint/ir/function.h
@@ -48,8 +48,6 @@
/// Attributes attached to return types
enum class ReturnAttribute {
- /// No return attribute
- kNone,
/// Location attribute
kLocation,
/// Builtin Position attribute
diff --git a/src/tint/ir/function_param.cc b/src/tint/ir/function_param.cc
index a9570fa..ae94f24 100644
--- a/src/tint/ir/function_param.cc
+++ b/src/tint/ir/function_param.cc
@@ -22,4 +22,43 @@
FunctionParam::~FunctionParam() = default;
+utils::StringStream& operator<<(utils::StringStream& out, enum FunctionParam::Builtin value) {
+ switch (value) {
+ case FunctionParam::Builtin::kVertexIndex:
+ out << "vertex_index";
+ break;
+ case FunctionParam::Builtin::kInstanceIndex:
+ out << "instance_index";
+ break;
+ case FunctionParam::Builtin::kPosition:
+ out << "position";
+ break;
+ case FunctionParam::Builtin::kFrontFacing:
+ out << "front_facing";
+ break;
+ case FunctionParam::Builtin::kLocalInvocationId:
+ out << "local_invocation_id";
+ break;
+ case FunctionParam::Builtin::kLocalInvocationIndex:
+ out << "local_invocation_index";
+ break;
+ case FunctionParam::Builtin::kGlobalInvocationId:
+ out << "global_invocation_id";
+ break;
+ case FunctionParam::Builtin::kWorkgroupId:
+ out << "workgroup_id";
+ break;
+ case FunctionParam::Builtin::kNumWorkgroups:
+ out << "num_workgroups";
+ break;
+ case FunctionParam::Builtin::kSampleIndex:
+ out << "sample_index";
+ break;
+ case FunctionParam::Builtin::kSampleMask:
+ out << "sample_mask";
+ break;
+ }
+ return out;
+}
+
} // namespace tint::ir
diff --git a/src/tint/ir/function_param.h b/src/tint/ir/function_param.h
index 2da0584..198d314 100644
--- a/src/tint/ir/function_param.h
+++ b/src/tint/ir/function_param.h
@@ -15,6 +15,9 @@
#ifndef SRC_TINT_IR_FUNCTION_PARAM_H_
#define SRC_TINT_IR_FUNCTION_PARAM_H_
+#include <utility>
+
+#include "src/tint/ir/location.h"
#include "src/tint/ir/value.h"
#include "src/tint/utils/castable.h"
@@ -23,6 +26,40 @@
/// A function parameter in the IR.
class FunctionParam : public utils::Castable<FunctionParam, Value> {
public:
+ /// Builtin attribute
+ enum class Builtin {
+ /// Builtin Vertex index
+ kVertexIndex,
+ /// Builtin Instance index
+ kInstanceIndex,
+ /// Builtin Position
+ kPosition,
+ /// Builtin FrontFacing
+ kFrontFacing,
+ /// Builtin Local invocation id
+ kLocalInvocationId,
+ /// Builtin Local invocation index
+ kLocalInvocationIndex,
+ /// Builtin Global invocation id
+ kGlobalInvocationId,
+ /// Builtin Workgroup id
+ kWorkgroupId,
+ /// Builtin Num workgroups
+ kNumWorkgroups,
+ /// Builtin Sample index
+ kSampleIndex,
+ /// Builtin Sample mask
+ kSampleMask,
+ };
+
+ /// Binding information
+ struct BindingPoint {
+ /// The `@group` part of the binding point
+ uint32_t group = 0;
+ /// The `@binding` part of the binding point
+ uint32_t binding = 0;
+ };
+
/// Constructor
/// @param type the type of the var
explicit FunctionParam(const type::Type* type);
@@ -31,11 +68,47 @@
/// @returns the type of the var
const type::Type* Type() const override { return type_; }
+ /// Sets the builtin information. Note, it is currently an error if the builtin is already set.
+ /// @param val the builtin to set
+ void SetBuiltin(FunctionParam::Builtin val) {
+ TINT_ASSERT(IR, !builtin_.has_value());
+ builtin_ = val;
+ }
+ /// @returns the builtin set for the parameter
+ std::optional<FunctionParam::Builtin> Builtin() const { return builtin_; }
+
+ /// Sets the parameter as invariant
+ /// @param val the value to set for invariant
+ void SetInvariant(bool val) { invariant_ = val; }
+ /// @returns true if parameter is invariant
+ bool Invariant() const { return invariant_; }
+
+ /// Sets the location
+ /// @param loc the location value
+ /// @param interpolation if the location interpolation settings
+ void SetLocation(uint32_t loc, std::optional<builtin::Interpolation> interpolation) {
+ location_ = {loc, interpolation};
+ }
+ /// @returns the location if `Attributes` contains `kLocation`
+ std::optional<struct Location> Location() const { return location_; }
+
+ /// Sets the binding point
+ /// @param group the group
+ /// @param binding the binding
+ void SetBindingPoint(uint32_t group, uint32_t binding) { binding_point_ = {group, binding}; }
+ /// @returns the binding points if `Attributes` contains `kBindingPoint`
+ std::optional<struct BindingPoint> BindingPoint() const { return binding_point_; }
+
private:
- /// The type of the parameter
- const type::Type* type_;
+ const type::Type* type_ = nullptr;
+ std::optional<enum FunctionParam::Builtin> builtin_;
+ std::optional<struct Location> location_;
+ std::optional<struct BindingPoint> binding_point_;
+ bool invariant_ = false;
};
+utils::StringStream& operator<<(utils::StringStream& out, enum FunctionParam::Builtin value);
+
} // namespace tint::ir
#endif // SRC_TINT_IR_FUNCTION_PARAM_H_
diff --git a/src/tint/ir/location.h b/src/tint/ir/location.h
new file mode 100644
index 0000000..5edb882
--- /dev/null
+++ b/src/tint/ir/location.h
@@ -0,0 +1,34 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_LOCATION_H_
+#define SRC_TINT_IR_LOCATION_H_
+
+#include <optional>
+
+#include "src/tint/builtin/interpolation.h"
+
+namespace tint::ir {
+
+/// A function parameter in the IR.
+struct Location {
+ /// The location value
+ uint32_t value = 0;
+ /// The interpolation settings
+ std::optional<builtin::Interpolation> interpolation;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_LOCATION_H_