Set search_k to max_hits * n_trees with finite pagination

This commit is contained in:
Mubelotix
2025-07-07 12:53:48 +02:00
parent 70a860a0f0
commit 68362cf5dd
8 changed files with 52 additions and 15 deletions

View File

@ -1050,7 +1050,7 @@ pub fn prepare_search<'t>(
.map(|x| x as usize) .map(|x| x as usize)
.unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS);
search.exhaustive_number_hits(is_finite_pagination); search.is_exhaustive_pagination(is_finite_pagination);
search.max_total_hits(Some(max_total_hits)); search.max_total_hits(Some(max_total_hits));
search.scoring_strategy( search.scoring_strategy(
if query.show_ranking_score if query.show_ranking_score

View File

@ -209,7 +209,7 @@ impl Search<'_> {
terms_matching_strategy: self.terms_matching_strategy, terms_matching_strategy: self.terms_matching_strategy,
scoring_strategy: ScoringStrategy::Detailed, scoring_strategy: ScoringStrategy::Detailed,
words_limit: self.words_limit, words_limit: self.words_limit,
exhaustive_number_hits: self.exhaustive_number_hits, is_exhaustive_pagination: self.is_exhaustive_pagination,
max_total_hits: self.max_total_hits, max_total_hits: self.max_total_hits,
rtxn: self.rtxn, rtxn: self.rtxn,
index: self.index, index: self.index,

View File

@ -51,7 +51,7 @@ pub struct Search<'a> {
terms_matching_strategy: TermsMatchingStrategy, terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy, scoring_strategy: ScoringStrategy,
words_limit: usize, words_limit: usize,
exhaustive_number_hits: bool, is_exhaustive_pagination: bool,
max_total_hits: Option<usize>, max_total_hits: Option<usize>,
rtxn: &'a heed::RoTxn<'a>, rtxn: &'a heed::RoTxn<'a>,
index: &'a Index, index: &'a Index,
@ -74,7 +74,7 @@ impl<'a> Search<'a> {
geo_param: new::GeoSortParameter::default(), geo_param: new::GeoSortParameter::default(),
terms_matching_strategy: TermsMatchingStrategy::default(), terms_matching_strategy: TermsMatchingStrategy::default(),
scoring_strategy: Default::default(), scoring_strategy: Default::default(),
exhaustive_number_hits: false, is_exhaustive_pagination: false,
max_total_hits: None, max_total_hits: None,
words_limit: 10, words_limit: 10,
rtxn, rtxn,
@ -162,8 +162,8 @@ impl<'a> Search<'a> {
/// Forces the search to exhaustively compute the number of candidates, /// Forces the search to exhaustively compute the number of candidates,
/// this will increase the search time but allows finite pagination. /// this will increase the search time but allows finite pagination.
pub fn exhaustive_number_hits(&mut self, exhaustive_number_hits: bool) -> &mut Search<'a> { pub fn is_exhaustive_pagination(&mut self, is_exhaustive_pagination: bool) -> &mut Search<'a> {
self.exhaustive_number_hits = exhaustive_number_hits; self.is_exhaustive_pagination = is_exhaustive_pagination;
self self
} }
@ -231,6 +231,13 @@ impl<'a> Search<'a> {
} }
} }
let mut search_k_div_trees = None;
if self.is_exhaustive_pagination {
if let Some(max_total_hits) = self.max_total_hits {
search_k_div_trees = Some(max_total_hits);
}
}
let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?;
let PartialSearchResult { let PartialSearchResult {
located_query_terms, located_query_terms,
@ -250,7 +257,7 @@ impl<'a> Search<'a> {
&mut ctx, &mut ctx,
vector, vector,
self.scoring_strategy, self.scoring_strategy,
self.exhaustive_number_hits, self.is_exhaustive_pagination,
self.max_total_hits, self.max_total_hits,
universe, universe,
&self.sort_criteria, &self.sort_criteria,
@ -261,6 +268,7 @@ impl<'a> Search<'a> {
embedder_name, embedder_name,
embedder, embedder,
*quantized, *quantized,
search_k_div_trees,
self.time_budget.clone(), self.time_budget.clone(),
self.ranking_score_threshold, self.ranking_score_threshold,
)?, )?,
@ -269,7 +277,7 @@ impl<'a> Search<'a> {
self.query.as_deref(), self.query.as_deref(),
self.terms_matching_strategy, self.terms_matching_strategy,
self.scoring_strategy, self.scoring_strategy,
self.exhaustive_number_hits, self.is_exhaustive_pagination,
self.max_total_hits, self.max_total_hits,
universe, universe,
&self.sort_criteria, &self.sort_criteria,
@ -323,7 +331,7 @@ impl fmt::Debug for Search<'_> {
terms_matching_strategy, terms_matching_strategy,
scoring_strategy, scoring_strategy,
words_limit, words_limit,
exhaustive_number_hits, is_exhaustive_pagination,
max_total_hits, max_total_hits,
rtxn: _, rtxn: _,
index: _, index: _,
@ -343,7 +351,7 @@ impl fmt::Debug for Search<'_> {
.field("searchable_attributes", searchable_attributes) .field("searchable_attributes", searchable_attributes)
.field("terms_matching_strategy", terms_matching_strategy) .field("terms_matching_strategy", terms_matching_strategy)
.field("scoring_strategy", scoring_strategy) .field("scoring_strategy", scoring_strategy)
.field("exhaustive_number_hits", exhaustive_number_hits) .field("is_exhaustive_pagination", is_exhaustive_pagination)
.field("max_total_hits", max_total_hits) .field("max_total_hits", max_total_hits)
.field("words_limit", words_limit) .field("words_limit", words_limit)
.field( .field(

View File

@ -377,6 +377,7 @@ fn get_ranking_rules_for_vector<'ctx>(
embedder_name: &str, embedder_name: &str,
embedder: &Embedder, embedder: &Embedder,
quantized: bool, quantized: bool,
search_k_div_trees: Option<usize>,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> { ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search // query graph search
@ -405,6 +406,7 @@ fn get_ranking_rules_for_vector<'ctx>(
embedder_name, embedder_name,
embedder, embedder,
quantized, quantized,
search_k_div_trees,
)?; )?;
ranking_rules.push(Box::new(vector_sort)); ranking_rules.push(Box::new(vector_sort));
vector = true; vector = true;
@ -637,6 +639,7 @@ pub fn execute_vector_search(
embedder_name: &str, embedder_name: &str,
embedder: &Embedder, embedder: &Embedder,
quantized: bool, quantized: bool,
search_k_div_trees: Option<usize>,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>, ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> { ) -> Result<PartialSearchResult> {
@ -653,6 +656,7 @@ pub fn execute_vector_search(
embedder_name, embedder_name,
embedder, embedder,
quantized, quantized,
search_k_div_trees,
)?; )?;
let mut placeholder_search_logger = logger::DefaultSearchLogger; let mut placeholder_search_logger = logger::DefaultSearchLogger;

View File

@ -572,7 +572,7 @@ fn test_distinct_all_candidates() {
let mut s = Search::new(&txn, &index); let mut s = Search::new(&txn, &index);
s.terms_matching_strategy(TermsMatchingStrategy::Last); s.terms_matching_strategy(TermsMatchingStrategy::Last);
s.sort_criteria(vec![AscDesc::Desc(Member::Field(S("rank1")))]); s.sort_criteria(vec![AscDesc::Desc(Member::Field(S("rank1")))]);
s.exhaustive_number_hits(true); s.is_exhaustive_pagination(true);
let SearchResult { documents_ids, candidates, .. } = s.execute().unwrap(); let SearchResult { documents_ids, candidates, .. } = s.execute().unwrap();
let candidates = candidates.iter().collect::<Vec<_>>(); let candidates = candidates.iter().collect::<Vec<_>>();

View File

@ -18,9 +18,11 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
distribution_shift: Option<DistributionShift>, distribution_shift: Option<DistributionShift>,
embedder_index: u8, embedder_index: u8,
quantized: bool, quantized: bool,
search_k_div_trees: Option<usize>,
} }
impl<Q: RankingRuleQueryTrait> VectorSort<Q> { impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
ctx: &SearchContext<'_>, ctx: &SearchContext<'_>,
target: Vec<f32>, target: Vec<f32>,
@ -29,6 +31,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
embedder_name: &str, embedder_name: &str,
embedder: &Embedder, embedder: &Embedder,
quantized: bool, quantized: bool,
search_k_div_trees: Option<usize>,
) -> Result<Self> { ) -> Result<Self> {
let embedder_index = ctx let embedder_index = ctx
.index .index
@ -42,6 +45,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
vector_candidates, vector_candidates,
cached_sorted_docids: Default::default(), cached_sorted_docids: Default::default(),
limit, limit,
search_k_div_trees,
distribution_shift: embedder.distribution(), distribution_shift: embedder.distribution(),
embedder_index, embedder_index,
quantized, quantized,
@ -57,7 +61,13 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
let before = Instant::now(); let before = Instant::now();
let reader = ArroyWrapper::new(ctx.index.vector_arroy, self.embedder_index, self.quantized); let reader = ArroyWrapper::new(ctx.index.vector_arroy, self.embedder_index, self.quantized);
let results = reader.nns_by_vector(ctx.txn, target, self.limit, Some(vector_candidates))?; let results = reader.nns_by_vector(
ctx.txn,
target,
self.limit,
self.search_k_div_trees,
Some(vector_candidates),
)?;
self.cached_sorted_docids = results.into_iter(); self.cached_sorted_docids = results.into_iter();
*ctx.vector_store_stats.get_or_insert_default() += VectorStoreStats { *ctx.vector_store_stats.get_or_insert_default() += VectorStoreStats {
total_time: before.elapsed(), total_time: before.elapsed(),

View File

@ -483,12 +483,20 @@ impl ArroyWrapper {
rtxn: &RoTxn, rtxn: &RoTxn,
vector: &[f32], vector: &[f32],
limit: usize, limit: usize,
search_k_div_trees: Option<usize>,
filter: Option<&RoaringBitmap>, filter: Option<&RoaringBitmap>,
) -> Result<Vec<(ItemId, f32)>, arroy::Error> { ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
if self.quantized { if self.quantized {
self._nns_by_vector(rtxn, self.quantized_db(), vector, limit, filter) self._nns_by_vector(
rtxn,
self.quantized_db(),
vector,
limit,
search_k_div_trees,
filter,
)
} else { } else {
self._nns_by_vector(rtxn, self.angular_db(), vector, limit, filter) self._nns_by_vector(rtxn, self.angular_db(), vector, limit, search_k_div_trees, filter)
} }
} }
@ -498,6 +506,7 @@ impl ArroyWrapper {
db: arroy::Database<D>, db: arroy::Database<D>,
vector: &[f32], vector: &[f32],
limit: usize, limit: usize,
search_k_div_trees: Option<usize>,
filter: Option<&RoaringBitmap>, filter: Option<&RoaringBitmap>,
) -> Result<Vec<(ItemId, f32)>, arroy::Error> { ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
let mut results = Vec::new(); let mut results = Vec::new();
@ -509,6 +518,12 @@ impl ArroyWrapper {
if reader.item_ids().is_disjoint(filter) { if reader.item_ids().is_disjoint(filter) {
continue; continue;
} }
if let Some(mut search_k) = search_k_div_trees {
search_k *= reader.n_trees();
if let Ok(search_k) = search_k.try_into() {
searcher.search_k(search_k);
}
}
searcher.candidates(filter); searcher.candidates(filter);
} }

View File

@ -29,7 +29,7 @@ macro_rules! test_distinct {
search.query(search::TEST_QUERY); search.query(search::TEST_QUERY);
search.limit($limit); search.limit($limit);
search.offset($offset); search.offset($offset);
search.exhaustive_number_hits($exhaustive); search.is_exhaustive_pagination($exhaustive);
search.terms_matching_strategy(TermsMatchingStrategy::default()); search.terms_matching_strategy(TermsMatchingStrategy::default());