Support distinct in hybrid search

This commit is contained in:
Louis Dureuil 2025-05-28 17:58:02 +02:00
parent fd4b192a39
commit 54f5e74744
No known key found for this signature in database

View File

@ -1,11 +1,13 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use heed::RoTxn;
use itertools::Itertools; use itertools::Itertools;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::search::new::{distinct_fid, distinct_single_docid};
use crate::search::SemanticSearch; use crate::search::SemanticSearch;
use crate::{MatchingWords, Result, Search, SearchResult}; use crate::{Index, MatchingWords, Result, Search, SearchResult};
struct ScoreWithRatioResult { struct ScoreWithRatioResult {
matching_words: MatchingWords, matching_words: MatchingWords,
@ -91,7 +93,10 @@ impl ScoreWithRatioResult {
keyword_results: Self, keyword_results: Self,
from: usize, from: usize,
length: usize, length: usize,
) -> (SearchResult, u32) { distinct: Option<&str>,
index: &Index,
rtxn: &RoTxn<'_>,
) -> Result<(SearchResult, u32)> {
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
enum ResultSource { enum ResultSource {
Semantic, Semantic,
@ -106,8 +111,9 @@ impl ScoreWithRatioResult {
vector_results.document_scores.len() + keyword_results.document_scores.len(), vector_results.document_scores.len() + keyword_results.document_scores.len(),
); );
let mut documents_seen = RoaringBitmap::new(); let distinct_fid = distinct_fid(distinct, index, rtxn)?;
for ((docid, (main_score, _sub_score)), source) in vector_results let mut excluded_documents = RoaringBitmap::new();
for res in vector_results
.document_scores .document_scores
.into_iter() .into_iter()
.zip(std::iter::repeat(ResultSource::Semantic)) .zip(std::iter::repeat(ResultSource::Semantic))
@ -121,13 +127,33 @@ impl ScoreWithRatioResult {
compare_scores(left, right).is_ge() compare_scores(left, right).is_ge()
}, },
) )
// remove documents we already saw // remove documents we already saw and apply distinct rule
.filter(|((docid, _), _)| documents_seen.insert(*docid)) .filter_map(|item @ ((docid, _), _)| {
if !excluded_documents.insert(docid) {
// the document was already added, or is indistinct from an already-added document.
return None;
}
if let Some(distinct_fid) = distinct_fid {
if let Err(error) = distinct_single_docid(
index,
rtxn,
distinct_fid,
docid,
&mut excluded_documents,
) {
return Some(Err(error));
}
}
Some(Ok(item))
})
// start skipping **after** the filter // start skipping **after** the filter
.skip(from) .skip(from)
// take **after** skipping // take **after** skipping
.take(length) .take(length)
{ {
let ((docid, (main_score, _sub_score)), source) = res?;
if let ResultSource::Semantic = source { if let ResultSource::Semantic = source {
semantic_hit_count += 1; semantic_hit_count += 1;
} }
@ -136,10 +162,24 @@ impl ScoreWithRatioResult {
document_scores.push(main_score); document_scores.push(main_score);
} }
( // compute the set of candidates from both sets
let candidates = vector_results.candidates | keyword_results.candidates;
let must_remove_redundant_candidates = distinct_fid.is_some();
let candidates = if must_remove_redundant_candidates {
// patch-up the candidates to remove the indistinct documents, then add back the actual hits
let mut candidates = candidates - excluded_documents;
for docid in &documents_ids {
candidates.insert(*docid);
}
candidates
} else {
candidates
};
Ok((
SearchResult { SearchResult {
matching_words: keyword_results.matching_words, matching_words: keyword_results.matching_words,
candidates: vector_results.candidates | keyword_results.candidates, candidates,
documents_ids, documents_ids,
document_scores, document_scores,
degraded: vector_results.degraded | keyword_results.degraded, degraded: vector_results.degraded | keyword_results.degraded,
@ -147,7 +187,7 @@ impl ScoreWithRatioResult {
| keyword_results.used_negative_operator, | keyword_results.used_negative_operator,
}, },
semantic_hit_count, semantic_hit_count,
) ))
} }
} }
@ -226,8 +266,15 @@ impl Search<'_> {
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
let (merge_results, semantic_hit_count) = let (merge_results, semantic_hit_count) = ScoreWithRatioResult::merge(
ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); vector_results,
keyword_results,
self.offset,
self.limit,
search.distinct.as_deref(),
search.index,
search.rtxn,
)?;
assert!(merge_results.documents_ids.len() <= self.limit); assert!(merge_results.documents_ids.len() <= self.limit);
Ok((merge_results, Some(semantic_hit_count))) Ok((merge_results, Some(semantic_hit_count)))
} }