tint: Cleanup of IntrinsicTable
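Remove the ProgramBuilder from ClosedState and use a pointer for the
'overload' field instead of a reference. This lets the Candidate be
copy-assignable, which, in turn, allows the Candidates vector to be
sorted directly, instead of jumping through hoops to use moves.
Without the ProgramBuilder, ClosedState::Num() can no longer raise an
ICE; it now returns Number::invalid for an unclosed number, and
ScoreOverload() treats that as a mismatch.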
Replace the random mix of 'int', 'uint8_t' and 'uint32_t' with 'size_t'
(everywhere except the constant table data). This reduces the fragile,
weak binding between distant pieces of code.
Swap the overload scoring order (high-best -> low-best). Remove the
'matched' field - we can now just check whether the 'score' is 0.
This further simplifies sorting.
Change-Id: I4a4b7934be337306202647d096c546eab5c8498f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90641
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index 5ca8ee2..58ba84f 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -104,12 +104,10 @@
/// Used by the MatchState.
class ClosedState {
public:
- explicit ClosedState(ProgramBuilder& b) : builder(b) {}
-
/// If the type with index `idx` is open, then it is closed with type `ty` and
/// Type() returns true. If the type is closed, then `Type()` returns true iff
/// it is equal to `ty`.
- bool Type(uint32_t idx, const sem::Type* ty) {
+ bool Type(size_t idx, const sem::Type* ty) {
auto res = types_.emplace(idx, ty);
return res.second || res.first->second == ty;
}
@@ -117,33 +115,27 @@
/// If the number with index `idx` is open, then it is closed with number
/// `number` and Num() returns true. If the number is closed, then `Num()`
/// returns true iff it is equal to `ty`.
- bool Num(uint32_t idx, Number number) {
+ bool Num(size_t idx, Number number) {
auto res = numbers_.emplace(idx, number.Value());
return res.second || res.first->second == number.Value();
}
/// Type returns the closed type with index `idx`, or nullptr if the type was not closed.
- const sem::Type* Type(uint32_t idx) const {
+ const sem::Type* Type(size_t idx) const {
auto it = types_.find(idx);
return (it != types_.end()) ? it->second : nullptr;
}
/// Type returns the number type with index `idx`.
/// An ICE is raised if the number is not closed.
- Number Num(uint32_t idx) const {
+ Number Num(size_t idx) const {
auto it = numbers_.find(idx);
- if (it == numbers_.end()) {
- TINT_ICE(Resolver, builder.Diagnostics())
- << "number with index " << idx << " is not closed";
- return Number::invalid;
- }
- return Number(it->second);
+ return (it != numbers_.end()) ? Number(it->second) : Number::invalid;
}
private:
- ProgramBuilder& builder;
- std::unordered_map<uint32_t, const sem::Type*> types_;
- std::unordered_map<uint32_t, uint32_t> numbers_;
+ std::unordered_map<size_t, const sem::Type*> types_;
+ std::unordered_map<size_t, uint32_t> numbers_;
};
/// Index type used for matcher indices
@@ -158,7 +150,7 @@
MatchState(ProgramBuilder& b,
ClosedState& c,
const Matchers& m,
- const OverloadInfo& o,
+ const OverloadInfo* o,
MatcherIndex const* matcher_indices)
: builder(b), closed(c), matchers(m), overload(o), matcher_indices_(matcher_indices) {}
@@ -169,7 +161,7 @@
/// The type and number matchers
Matchers const& matchers;
/// The current overload being evaluated
- OverloadInfo const& overload;
+ OverloadInfo const* overload;
/// Type uses the next TypeMatcher from the matcher indices to match the type
/// `ty`. If the type matches, the canonical expected type is returned. If the
@@ -240,7 +232,7 @@
class OpenTypeMatcher : public TypeMatcher {
public:
/// Constructor
- explicit OpenTypeMatcher(uint32_t index) : index_(index) {}
+ explicit OpenTypeMatcher(size_t index) : index_(index) {}
const sem::Type* Match(MatchState& state, const sem::Type* type) const override {
if (type->Is<Any>()) {
@@ -252,7 +244,7 @@
std::string String(MatchState& state) const override;
private:
- uint32_t index_;
+ size_t index_;
};
/// OpenNumberMatcher is a Matcher for an open number.
@@ -260,7 +252,7 @@
/// consistent for the overload)
class OpenNumberMatcher : public NumberMatcher {
public:
- explicit OpenNumberMatcher(uint32_t index) : index_(index) {}
+ explicit OpenNumberMatcher(size_t index) : index_(index) {}
Number Match(MatchState& state, Number number) const override {
if (number.IsAny()) {
@@ -272,7 +264,7 @@
std::string String(MatchState& state) const override;
private:
- uint32_t index_;
+ size_t index_;
};
////////////////////////////////////////////////////////////////////////////////
@@ -879,16 +871,16 @@
/// Candidate holds information about an overload evaluated for resolution.
struct Candidate {
/// The candidate overload
- const OverloadInfo& overload;
+ const OverloadInfo* overload;
/// The closed types and numbers
ClosedState closed;
/// The parameter types for the candidate overload
std::vector<IntrinsicPrototype::Parameter> parameters;
- /// True if the candidate is a viable match for the call
- bool matched;
- /// The match-score of the candidate overload. Used for diagnostics when no overload
- /// matches. Higher scores are displayed first (top-most).
- int score;
+ /// The match-score of the candidate overload.
+ /// A score of zero indicates an exact match.
+ /// Non-zero scores are used for diagnostics when no overload matches.
+ /// Lower scores are displayed first (top-most).
+ size_t score;
};
/// A list of candidates
@@ -922,7 +914,7 @@
/// arguments. For example `vec3<f32>()` would have the first template-type closed
/// as `f32`.
/// @returns the evaluated Candidate information.
- Candidate ScoreOverload(const OverloadInfo& overload,
+ Candidate ScoreOverload(const OverloadInfo* overload,
const std::vector<const sem::Type*>& args,
ClosedState closed) const;
@@ -931,12 +923,12 @@
/// @param overload the overload being evaluated
/// @param matcher_indices pointer to a list of matcher indices
MatchState Match(ClosedState& closed,
- const OverloadInfo& overload,
+ const OverloadInfo* overload,
MatcherIndex const* matcher_indices) const;
// Prints the overload for emitting diagnostics
void PrintOverload(std::ostream& ss,
- const OverloadInfo& overload,
+ const OverloadInfo* overload,
const char* intrinsic_name) const;
// Prints the list of candidates for emitting diagnostics
@@ -945,7 +937,7 @@
const char* intrinsic_name) const;
/// Raises an ICE when multiple overload candidates match, as this should never happen.
- void ErrMultipleOverloadsMatched(uint32_t num_matched,
+ void ErrMultipleOverloadsMatched(size_t num_matched,
const char* intrinsic_name,
const std::vector<const sem::Type*>& args,
ClosedState closed,
@@ -988,11 +980,11 @@
}
std::string OpenTypeMatcher::String(MatchState& state) const {
- return state.overload.open_types[index_].name;
+ return state.overload->open_types[index_].name;
}
std::string OpenNumberMatcher::String(MatchState& state) const {
- return state.overload.open_numbers[index_].name;
+ return state.overload->open_numbers[index_].name;
}
Impl::Impl(ProgramBuilder& b) : builder(b) {}
@@ -1016,8 +1008,8 @@
};
// Resolve the intrinsic overload
- auto match = MatchIntrinsic(kBuiltins[static_cast<uint32_t>(builtin_type)], intrinsic_name,
- args, ClosedState(builder), on_no_match);
+ auto match = MatchIntrinsic(kBuiltins[static_cast<size_t>(builtin_type)], intrinsic_name, args,
+ ClosedState{}, on_no_match);
if (!match.overload) {
return {};
}
@@ -1050,7 +1042,7 @@
IntrinsicTable::UnaryOperator Impl::Lookup(ast::UnaryOp op,
const sem::Type* arg,
const Source& source) {
- auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<uint32_t, const char*> {
+ auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<size_t, const char*> {
switch (op) {
case ast::UnaryOp::kComplement:
return {kUnaryOperatorComplement, "operator ~ "};
@@ -1078,7 +1070,7 @@
// Resolve the intrinsic overload
auto match = MatchIntrinsic(kUnaryOperators[intrinsic_index], intrinsic_name, {arg},
- ClosedState(builder), on_no_match);
+ ClosedState{}, on_no_match);
if (!match.overload) {
return {};
}
@@ -1091,7 +1083,7 @@
const sem::Type* rhs,
const Source& source,
bool is_compound) {
- auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<uint32_t, const char*> {
+ auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<size_t, const char*> {
switch (op) {
case ast::BinaryOp::kAnd:
return {kBinaryOperatorAnd, is_compound ? "operator &= " : "operator & "};
@@ -1149,7 +1141,7 @@
// Resolve the intrinsic overload
auto match = MatchIntrinsic(kBinaryOperators[intrinsic_index], intrinsic_name, {lhs, rhs},
- ClosedState(builder), on_no_match);
+ ClosedState{}, on_no_match);
if (!match.overload) {
return {};
}
@@ -1170,7 +1162,7 @@
<< std::endl;
Candidates ctor, conv;
for (auto candidate : candidates) {
- if (candidate.overload.flags.Contains(OverloadFlag::kIsConstructor)) {
+ if (candidate.overload->flags.Contains(OverloadFlag::kIsConstructor)) {
ctor.emplace_back(candidate);
} else {
conv.emplace_back(candidate);
@@ -1192,13 +1184,13 @@
};
// If a template type was provided, then close the 0'th type with this.
- ClosedState closed(builder);
+ ClosedState closed;
if (template_arg) {
closed.Type(0, template_arg);
}
// Resolve the intrinsic overload
- auto match = MatchIntrinsic(kConstructorsAndConverters[static_cast<uint32_t>(type)], name, args,
+ auto match = MatchIntrinsic(kConstructorsAndConverters[static_cast<size_t>(type)], name, args,
closed, on_no_match);
if (!match.overload) {
return {};
@@ -1232,40 +1224,21 @@
const std::vector<const sem::Type*>& args,
ClosedState closed,
OnNoMatch on_no_match) const {
- uint32_t num_matched = 0;
+ size_t num_matched = 0;
Candidates candidates;
candidates.reserve(intrinsic.num_overloads);
- for (uint8_t overload_idx = 0; overload_idx < intrinsic.num_overloads; overload_idx++) {
- auto candidate = ScoreOverload(intrinsic.overloads[overload_idx], args, closed);
- if (candidate.matched) {
+ for (size_t overload_idx = 0; overload_idx < static_cast<size_t>(intrinsic.num_overloads);
+ overload_idx++) {
+ auto candidate = ScoreOverload(&intrinsic.overloads[overload_idx], args, closed);
+ if (candidate.score == 0) {
num_matched++;
}
candidates.emplace_back(std::move(candidate));
}
// Sort the candidates with the most promising first
- {
- std::vector<size_t> candidate_indices(candidates.size());
- for (size_t i = 0; i < candidate_indices.size(); i++) {
- candidate_indices[i] = i;
- }
- std::stable_sort(candidate_indices.begin(), candidate_indices.end(),
- [&](size_t a, size_t b) {
- if (candidates[a].matched && !candidates[b].matched) {
- return true;
- }
- if (candidates[b].matched && !candidates[a].matched) {
- return false;
- }
- return candidates[a].score > candidates[b].score;
- });
- Candidates candidates_sorted;
- candidates_sorted.reserve(candidate_indices.size());
- for (size_t idx : candidate_indices) {
- candidates_sorted.emplace_back(std::move(candidates[idx]));
- }
- std::swap(candidates, candidates_sorted);
- }
+ std::stable_sort(candidates.begin(), candidates.end(),
+ [&](const Candidate& a, const Candidate& b) { return a.score < b.score; });
// How many candidates matched?
switch (num_matched) {
@@ -1282,7 +1255,7 @@
// Build the return type
const sem::Type* return_type = nullptr;
- if (auto* indices = match.overload.return_matcher_indices) {
+ if (auto* indices = match.overload->return_matcher_indices) {
Any any;
return_type = Match(match.closed, match.overload, indices).Type(&any);
if (!return_type) {
@@ -1293,101 +1266,91 @@
return_type = builder.create<sem::Void>();
}
- return IntrinsicPrototype{&match.overload, return_type, std::move(match.parameters)};
+ return IntrinsicPrototype{match.overload, return_type, std::move(match.parameters)};
}
-Impl::Candidate Impl::ScoreOverload(const OverloadInfo& overload,
+Impl::Candidate Impl::ScoreOverload(const OverloadInfo* overload,
const std::vector<const sem::Type*>& args,
ClosedState closed) const {
- // Score weight for argument <-> parameter count matches / mismatches
+ // Penalty weights for overload mismatching.
// This scoring is used to order the suggested overloads in diagnostic on overload mismatch, and
// has no impact for a correct program.
- // The overloads with the highest score will be displayed first (top-most).
- constexpr int kScorePerParamArgMismatch = -1;
- constexpr int kScorePerMatchedParam = 2;
- constexpr int kScorePerMatchedOpenType = 1;
- constexpr int kScorePerMatchedOpenNumber = 1;
+ // The overloads with the lowest score will be displayed first (top-most).
+ constexpr int kMismatchedParamCountPenalty = 3;
+ constexpr int kMismatchedParamTypePenalty = 2;
+ constexpr int kMismatchedOpenTypePenalty = 1;
+ constexpr int kMismatchedOpenNumberPenalty = 1;
- uint32_t num_parameters = static_cast<uint32_t>(overload.num_parameters);
- uint32_t num_arguments = static_cast<uint32_t>(args.size());
+ size_t num_parameters = static_cast<size_t>(overload->num_parameters);
+ size_t num_arguments = static_cast<size_t>(args.size());
- bool overload_matched = true;
- int overload_score = 0;
-
- if (static_cast<uint64_t>(args.size()) >
- static_cast<uint64_t>(std::numeric_limits<uint32_t>::max())) {
- overload_matched = false; // No overload has this number of arguments.
- }
+ size_t score = 0;
if (num_parameters != num_arguments) {
- overload_score += kScorePerParamArgMismatch * (std::max(num_parameters, num_arguments) -
- std::min(num_parameters, num_arguments));
- overload_matched = false;
+ score += kMismatchedParamCountPenalty * (std::max(num_parameters, num_arguments) -
+ std::min(num_parameters, num_arguments));
}
std::vector<IntrinsicPrototype::Parameter> parameters;
auto num_params = std::min(num_parameters, num_arguments);
- for (uint32_t p = 0; p < num_params; p++) {
- auto& parameter = overload.parameters[p];
+ for (size_t p = 0; p < num_params; p++) {
+ auto& parameter = overload->parameters[p];
auto* indices = parameter.matcher_indices;
auto* type = Match(closed, overload, indices).Type(args[p]->UnwrapRef());
if (type) {
parameters.emplace_back(IntrinsicPrototype::Parameter{type, parameter.usage});
- overload_score += kScorePerMatchedParam;
} else {
- overload_matched = false;
+ score += kMismatchedParamTypePenalty;
}
}
- if (overload_matched) {
+ if (score == 0) {
// Check all constrained open types matched
- for (uint32_t ot = 0; ot < overload.num_open_types; ot++) {
- auto& open_type = overload.open_types[ot];
+ for (size_t ot = 0; ot < overload->num_open_types; ot++) {
+ auto& open_type = overload->open_types[ot];
if (open_type.matcher_index != kNoMatcher) {
auto* closed_type = closed.Type(ot);
auto* matcher_index = &open_type.matcher_index;
- if (closed_type && Match(closed, overload, matcher_index).Type(closed_type)) {
- overload_score += kScorePerMatchedOpenType;
- } else {
- overload_matched = false;
+ if (!closed_type || !Match(closed, overload, matcher_index).Type(closed_type)) {
+ score += kMismatchedOpenTypePenalty;
}
}
}
}
- if (overload_matched) {
+ if (score == 0) {
// Check all constrained open numbers matched
- for (uint32_t on = 0; on < overload.num_open_numbers; on++) {
- auto& open_number = overload.open_numbers[on];
+ for (size_t on = 0; on < overload->num_open_numbers; on++) {
+ auto& open_number = overload->open_numbers[on];
if (open_number.matcher_index != kNoMatcher) {
+ auto closed_num = closed.Num(on);
auto* index = &open_number.matcher_index;
- if (Match(closed, overload, index).Num(closed.Num(on)).IsValid()) {
- overload_score += kScorePerMatchedOpenNumber;
- } else {
- overload_matched = false;
+ if (!closed_num.IsValid() ||
+ !Match(closed, overload, index).Num(closed_num).IsValid()) {
+ score += kMismatchedOpenNumberPenalty;
}
}
}
}
- return Candidate{overload, closed, parameters, overload_matched, overload_score};
+ return Candidate{overload, closed, parameters, score};
}
MatchState Impl::Match(ClosedState& closed,
- const OverloadInfo& overload,
+ const OverloadInfo* overload,
MatcherIndex const* matcher_indices) const {
return MatchState(builder, closed, matchers, overload, matcher_indices);
}
void Impl::PrintOverload(std::ostream& ss,
- const OverloadInfo& overload,
+ const OverloadInfo* overload,
const char* intrinsic_name) const {
- ClosedState closed(builder);
+ ClosedState closed;
ss << intrinsic_name << "(";
- for (uint32_t p = 0; p < overload.num_parameters; p++) {
- auto& parameter = overload.parameters[p];
+ for (size_t p = 0; p < overload->num_parameters; p++) {
+ auto& parameter = overload->parameters[p];
if (p > 0) {
ss << ", ";
}
@@ -1398,9 +1361,9 @@
ss << Match(closed, overload, indices).TypeName();
}
ss << ")";
- if (overload.return_matcher_indices) {
+ if (overload->return_matcher_indices) {
ss << " -> ";
- auto* indices = overload.return_matcher_indices;
+ auto* indices = overload->return_matcher_indices;
ss << Match(closed, overload, indices).TypeName();
}
@@ -1409,8 +1372,8 @@
ss << (first ? " where: " : ", ");
first = false;
};
- for (uint32_t i = 0; i < overload.num_open_types; i++) {
- auto& open_type = overload.open_types[i];
+ for (size_t i = 0; i < overload->num_open_types; i++) {
+ auto& open_type = overload->open_types[i];
if (open_type.matcher_index != kNoMatcher) {
separator();
ss << open_type.name;
@@ -1418,8 +1381,8 @@
ss << " is " << Match(closed, overload, index).TypeName();
}
}
- for (uint32_t i = 0; i < overload.num_open_numbers; i++) {
- auto& open_number = overload.open_numbers[i];
+ for (size_t i = 0; i < overload->num_open_numbers; i++) {
+ auto& open_number = overload->open_numbers[i];
if (open_number.matcher_index != kNoMatcher) {
separator();
ss << open_number.name;
@@ -1463,14 +1426,14 @@
return matcher->String(*this);
}
-void Impl::ErrMultipleOverloadsMatched(uint32_t num_matched,
+void Impl::ErrMultipleOverloadsMatched(size_t num_matched,
const char* intrinsic_name,
const std::vector<const sem::Type*>& args,
ClosedState closed,
Candidates candidates) const {
std::stringstream ss;
ss << num_matched << " overloads matched " << intrinsic_name;
- for (uint32_t i = 0; i < 0xffffffffu; i++) {
+ for (size_t i = 0; i < std::numeric_limits<size_t>::max(); i++) {
if (auto* ty = closed.Type(i)) {
ss << ((i == 0) ? "<" : ", ") << ty->FriendlyName(builder.Symbols());
} else if (i > 0) {
@@ -1489,7 +1452,7 @@
}
ss << "):\n";
for (auto& candidate : candidates) {
- if (candidate.matched) {
+ if (candidate.score == 0) {
ss << " ";
PrintOverload(ss, candidate.overload, intrinsic_name);
ss << std::endl;