Add ranking_score_threshold to milli

This commit is contained in:
Louis Dureuil
2024-04-11 19:04:06 +02:00
parent 75d5c0ae1f
commit aac1d769a7
5 changed files with 46 additions and 0 deletions

View File

@ -28,6 +28,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
scoring_strategy: ScoringStrategy,
logger: &mut dyn SearchLogger<Q>,
time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<BucketSortOutput> {
logger.initial_query(query);
logger.ranking_rules(&ranking_rules);
@ -144,6 +145,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
ctx,
from,
length,
ranking_score_threshold,
logger,
&mut valid_docids,
&mut valid_scores,
@ -164,7 +166,9 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
loop {
let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
ranking_rule_scores.push(ScoreDetails::Skipped);
maybe_add_to_results!(bucket);
ranking_rule_scores.pop();
if cur_ranking_rule_index == 0 {
@ -220,6 +224,17 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
debug_assert!(
ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates)
);
if let Some(ranking_score_threshold) = ranking_score_threshold {
let current_score = ScoreDetails::global_score(ranking_rule_scores.iter());
if current_score < ranking_score_threshold {
all_candidates -=
next_bucket.candidates | &ranking_rule_universes[cur_ranking_rule_index];
back!();
continue;
}
}
ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates;
if cur_ranking_rule_index == ranking_rules_len - 1
@ -262,6 +277,7 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
ctx: &mut SearchContext<'ctx>,
from: usize,
length: usize,
ranking_score_threshold: Option<f64>,
logger: &mut dyn SearchLogger<Q>,
valid_docids: &mut Vec<u32>,
@ -279,6 +295,15 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
ranking_rule_scores: &[ScoreDetails],
candidates: RoaringBitmap,
) -> Result<()> {
// remove candidates from the universe without adding them to result if their score is below the threshold
if let Some(ranking_score_threshold) = ranking_score_threshold {
let score = ScoreDetails::global_score(ranking_rule_scores.iter());
if score < ranking_score_threshold {
*all_candidates -= candidates | &ranking_rule_universes[cur_ranking_rule_index];
return Ok(());
}
}
// First apply the distinct rule on the candidates, reducing the universes if necessary
let candidates = if let Some(distinct_fid) = distinct_fid {
let DistinctOutput { remaining, excluded } =

View File

@ -523,6 +523,7 @@ mod tests {
&mut crate::DefaultSearchLogger,
&mut crate::DefaultSearchLogger,
TimeBudget::max(),
None,
)
.unwrap();

View File

@ -568,6 +568,7 @@ pub fn execute_vector_search(
embedder_name: &str,
embedder: &Embedder,
time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -597,6 +598,7 @@ pub fn execute_vector_search(
scoring_strategy,
placeholder_search_logger,
time_budget,
ranking_score_threshold,
)?;
Ok(PartialSearchResult {
@ -626,6 +628,7 @@ pub fn execute_search(
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -714,6 +717,7 @@ pub fn execute_search(
scoring_strategy,
query_graph_logger,
time_budget,
ranking_score_threshold,
)?
} else {
let ranking_rules =
@ -728,6 +732,7 @@ pub fn execute_search(
scoring_strategy,
placeholder_search_logger,
time_budget,
ranking_score_threshold,
)?
};