[tint][intrinsics] Optimize overload matching.
Fail fast as a first pass. Do the expensive 'full' scoring when
displaying ranked overloads.
Change-Id: I2b4a2ae1612143ebcc2af726fafcf7ef44b3f688
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/175246
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/intrinsic/table.cc b/src/tint/lang/core/intrinsic/table.cc
index 5d749ed..d7c080e 100644
--- a/src/tint/lang/core/intrinsic/table.cc
+++ b/src/tint/lang/core/intrinsic/table.cc
@@ -69,17 +69,17 @@
/// Candidate holds information about an overload evaluated for resolution.
struct Candidate {
- /// The candidate overload
- const OverloadInfo* overload;
- /// The template types and numbers
- TemplateState templates;
- /// The parameter types for the candidate overload
- Vector<Overload::Parameter, kNumFixedParams> parameters;
/// The match-score of the candidate overload.
/// A score of zero indicates an exact match.
/// Non-zero scores are used for diagnostics when no overload matches.
/// Lower scores are displayed first (top-most).
- size_t score;
+ size_t score = 0;
+ /// The candidate overload
+ const OverloadInfo* overload = nullptr;
+ /// The template types and numbers
+ TemplateState templates{};
+ /// The parameter types for the candidate overload
+ Vector<Overload::Parameter, kNumFixedParams> parameters{};
};
/// A list of candidates
@@ -122,12 +122,25 @@
EvaluationStage earliest_eval_stage,
const OnNoMatch& on_no_match);
+/// The scoring mode for ScoreOverload()
+enum class ScoreMode {
+ /// If the overload doesn't match, then the returned Candidate will simply have a score of 1.
+ /// No other fields will be populated.
+ kEarlyReject,
+ /// A more expensive score calculations will be made for the overload, which can be used
+ /// to rank potential overloads
+ kFull
+};
+
/// Evaluates the single overload for the provided argument types.
/// @param context the intrinsic context
/// @param overload the overload being considered
/// @param template_args the template argument types
/// @param args the argument types
+/// @tparam MODE the scoring mode to use. Passed as a template argument to ensure that the
+/// extremely-hot function is specialized without scoring logic for the common code path.
/// @returns the evaluated Candidate information.
+template <ScoreMode MODE>
Candidate ScoreOverload(Context& context,
const OverloadInfo& overload,
VectorRef<const core::type::Type*> template_args,
@@ -198,14 +211,15 @@
VectorRef<const core::type::Type*> args,
EvaluationStage earliest_eval_stage,
const OnNoMatch& on_no_match) {
+ const size_t num_overloads = static_cast<size_t>(intrinsic.num_overloads);
size_t num_matched = 0;
size_t match_idx = 0;
Vector<Candidate, kNumFixedCandidates> candidates;
candidates.Reserve(intrinsic.num_overloads);
- for (size_t overload_idx = 0; overload_idx < static_cast<size_t>(intrinsic.num_overloads);
- overload_idx++) {
+ for (size_t overload_idx = 0; overload_idx < num_overloads; overload_idx++) {
auto& overload = context.data[intrinsic.overloads + overload_idx];
- auto candidate = ScoreOverload(context, overload, template_args, args, earliest_eval_stage);
+ auto candidate = ScoreOverload<ScoreMode::kEarlyReject>(context, overload, template_args,
+ args, earliest_eval_stage);
if (candidate.score == 0) {
match_idx = overload_idx;
num_matched++;
@@ -214,7 +228,13 @@
}
// How many candidates matched?
- if (num_matched == 0) {
+ if (TINT_UNLIKELY(num_matched == 0)) {
+ // Perform the full scoring of each overload
+ for (size_t overload_idx = 0; overload_idx < num_overloads; overload_idx++) {
+ auto& overload = context.data[intrinsic.overloads + overload_idx];
+ candidates[overload_idx] = ScoreOverload<ScoreMode::kFull>(
+ context, overload, template_args, args, earliest_eval_stage);
+ }
// Sort the candidates with the most promising first
SortCandidates(candidates);
return on_no_match(std::move(candidates));
@@ -253,11 +273,21 @@
context.data[match.overload->const_eval_fn]};
}
+template <ScoreMode MODE>
Candidate ScoreOverload(Context& context,
const OverloadInfo& overload,
VectorRef<const core::type::Type*> template_args,
VectorRef<const core::type::Type*> args,
EvaluationStage earliest_eval_stage) {
+#define MATCH_FAILURE(PENALTY) \
+ do { \
+ if constexpr (MODE == ScoreMode::kEarlyReject) { \
+ return Candidate{1}; \
+ } else { \
+ score += PENALTY; \
+ } \
+ } while (false)
+
// Penalty weights for overload mismatching.
// This scoring is used to order the suggested overloads in diagnostic on overload mismatch, and
// has no impact for a correct program.
@@ -275,8 +305,8 @@
size_t score = 0;
if (num_parameters != num_arguments) {
- score += kMismatchedParamCountPenalty * (std::max(num_parameters, num_arguments) -
- std::min(num_parameters, num_arguments));
+ MATCH_FAILURE(kMismatchedParamCountPenalty * (std::max(num_parameters, num_arguments) -
+ std::min(num_parameters, num_arguments)));
}
if (score == 0) {
@@ -284,9 +314,9 @@
const size_t expected_templates = overload.num_explicit_templates;
const size_t provided_templates = template_args.Length();
if (provided_templates != expected_templates) {
- score += kMismatchedExplicitTemplateCountPenalty *
- (std::max(expected_templates, provided_templates) -
- std::min(expected_templates, provided_templates));
+ MATCH_FAILURE(kMismatchedExplicitTemplateCountPenalty *
+ (std::max(expected_templates, provided_templates) -
+ std::min(expected_templates, provided_templates)));
}
}
@@ -303,7 +333,7 @@
type = Match(context, templates, overload, matcher_indices, earliest_eval_stage)
.Type(type);
if (!type) {
- score += kMismatchedExplicitTemplateTypePenalty;
+ MATCH_FAILURE(kMismatchedExplicitTemplateTypePenalty);
continue;
}
}
@@ -325,7 +355,7 @@
auto* matcher_indices = context.data[parameter.matcher_indices];
if (!Match(context, templates, overload, matcher_indices, earliest_eval_stage)
.Type(args[p])) {
- score += kMismatchedParamTypePenalty;
+ MATCH_FAILURE(kMismatchedParamTypePenalty);
}
}
@@ -359,7 +389,7 @@
continue;
}
}
- score += kMismatchedImplicitTemplateTypePenalty;
+ MATCH_FAILURE(kMismatchedImplicitTemplateTypePenalty);
break;
}
@@ -369,7 +399,7 @@
// constraint matchers.
auto number = templates.Num(i);
if (!number.IsValid() || !matcher.Num(number).IsValid()) {
- score += kMismatchedImplicitTemplateNumberPenalty;
+ MATCH_FAILURE(kMismatchedImplicitTemplateNumberPenalty);
}
}
}
@@ -389,7 +419,8 @@
}
}
- return Candidate{&overload, templates, parameters, score};
+ return Candidate{score, &overload, templates, parameters};
+#undef MATCH_FAILURE
}
Result<Candidate, std::string> ResolveCandidate(Context& context,