Add ranking score threshold to similar

This commit is contained in:
Louis Dureuil
2024-05-30 10:34:09 +02:00
parent c26db7878c
commit 4f03b0cf5b

View File

@@ -17,6 +17,7 @@ pub struct Similar<'a> {
index: &'a Index,
embedder_name: String,
embedder: Arc<Embedder>,
ranking_score_threshold: Option<f64>,
}
impl<'a> Similar<'a> {
@@ -29,7 +30,17 @@ impl<'a> Similar<'a> {
embedder_name: String,
embedder: Arc<Embedder>,
) -> Self {
Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder }
Self {
id,
filter: None,
offset,
limit,
rtxn,
index,
embedder_name,
embedder,
ranking_score_threshold: None,
}
}
pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self {
@@ -37,8 +48,18 @@ impl<'a> Similar<'a> {
self
}
pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Self {
self.ranking_score_threshold = Some(ranking_score_threshold);
self
}
pub fn execute(&self) -> Result<SearchResult> {
let universe = filtered_universe(self.index, self.rtxn, &self.filter)?;
let mut universe = filtered_universe(self.index, self.rtxn, &self.filter)?;
// we never want to receive the docid
universe.remove(self.id);
let universe = universe;
let embedder_index =
self.index
@@ -77,6 +98,8 @@ impl<'a> Similar<'a> {
let mut documents_seen = RoaringBitmap::new();
documents_seen.insert(self.id);
let mut candidates = universe;
for (docid, distance) in results
.into_iter()
// skip documents we've already seen & mark that we saw the current document
@@ -85,8 +108,6 @@ impl<'a> Similar<'a> {
// take **after** filter and skip so that we get exactly limit elements if available
.take(self.limit)
{
documents_ids.push(docid);
let score = 1.0 - distance;
let score = self
.embedder
@@ -94,14 +115,28 @@ impl<'a> Similar<'a> {
.map(|distribution| distribution.shift(score))
.unwrap_or(score);
let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) });
let score_details =
vec![ScoreDetails::Vector(score_details::Vector { similarity: Some(score) })];
document_scores.push(vec![score]);
let score = ScoreDetails::global_score(score_details.iter());
if let Some(ranking_score_threshold) = &self.ranking_score_threshold {
if score < *ranking_score_threshold {
// this document is no longer a candidate
candidates.remove(docid);
// any document after this one is no longer a candidate either, so restrict the set to documents already seen.
candidates &= documents_seen;
break;
}
}
documents_ids.push(docid);
document_scores.push(score_details);
}
Ok(SearchResult {
matching_words: Default::default(),
candidates: universe,
candidates,
documents_ids,
document_scores,
degraded: false,