Remove stuff, add distribution shift (WIP)

This commit is contained in:
Louis Dureuil
2023-12-12 10:05:06 +01:00
parent e56f160032
commit 65e49b7092
10 changed files with 126 additions and 278 deletions

View File

@ -50,6 +50,7 @@ use self::vector_sort::VectorSort;
use crate::error::FieldIdMapMissingEntry;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule;
use crate::vector::DistributionShift;
use crate::{
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
};
@ -264,6 +265,7 @@ fn get_ranking_rules_for_vector<'ctx>(
geo_strategy: geo_sort::Strategy,
limit_plus_offset: usize,
target: &[f32],
distribution_shift: Option<DistributionShift>,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search
@ -289,6 +291,7 @@ fn get_ranking_rules_for_vector<'ctx>(
target.to_vec(),
vector_candidates,
limit_plus_offset,
distribution_shift,
)?;
ranking_rules.push(Box::new(vector_sort));
vector = true;
@ -515,8 +518,14 @@ pub fn execute_vector_search(
/// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules =
get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?;
let ranking_rules = get_ranking_rules_for_vector(
ctx,
sort_criteria,
geo_strategy,
from + length,
vector,
None,
)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =

View File

@ -5,6 +5,7 @@ use roaring::RoaringBitmap;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::score_details::{self, ScoreDetails};
use crate::vector::DistributionShift;
use crate::{DocumentId, Result, SearchContext, SearchLogger};
pub struct VectorSort<Q: RankingRuleQueryTrait> {
@ -13,6 +14,7 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
vector_candidates: RoaringBitmap,
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
limit: usize,
distribution_shift: Option<DistributionShift>,
}
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
@ -21,6 +23,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
target: Vec<f32>,
vector_candidates: RoaringBitmap,
limit: usize,
distribution_shift: Option<DistributionShift>,
) -> Result<Self> {
Ok(Self {
query: None,
@ -28,6 +31,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
vector_candidates,
cached_sorted_docids: Default::default(),
limit,
distribution_shift,
})
}
@ -52,7 +56,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
for reader in readers.iter() {
let nns_by_vector = reader.nns_by_vector(
ctx.txn,
&target,
target,
self.limit,
None,
Some(&self.vector_candidates),
@ -66,6 +70,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
}
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
self.cached_sorted_docids = results.into_iter();
Ok(())
}
}
@ -111,14 +116,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
}));
}
while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() {
for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
if self.vector_candidates.contains(docid) {
let score = 1.0 - distance;
let score = self
.distribution_shift
.map(|distribution| distribution.shift(score))
.unwrap_or(score);
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([docid]),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: Some((vector, 1.0 - distance)),
value_similarity: Some((vector, score)),
}),
}));
}