Merge branch 'main' into change-proximity-precision-settings

This commit is contained in:
Many the fish
2023-12-18 09:08:47 +01:00
committed by GitHub
55 changed files with 5801 additions and 723 deletions

183
milli/src/search/hybrid.rs Normal file
View File

@ -0,0 +1,183 @@
use std::cmp::Ordering;
use itertools::Itertools;
use roaring::RoaringBitmap;
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::{MatchingWords, Result, Search, SearchResult};
struct ScoreWithRatioResult {
matching_words: MatchingWords,
candidates: RoaringBitmap,
document_scores: Vec<(u32, ScoreWithRatio)>,
}
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
fn compare_scores(
&(ref left_scores, left_ratio): &ScoreWithRatio,
&(ref right_scores, right_ratio): &ScoreWithRatio,
) -> Ordering {
let mut left_it = ScoreDetails::score_values(left_scores.iter());
let mut right_it = ScoreDetails::score_values(right_scores.iter());
loop {
let left = left_it.next();
let right = right_it.next();
match (left, right) {
(None, None) => return Ordering::Equal,
(None, Some(_)) => return Ordering::Less,
(Some(_), None) => return Ordering::Greater,
(Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
let left = left * left_ratio as f64;
let right = right * right_ratio as f64;
if (left - right).abs() <= f64::EPSILON {
continue;
}
return left.partial_cmp(&right).unwrap();
}
(Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => {
match left.partial_cmp(right).unwrap() {
Ordering::Equal => continue,
order => return order,
}
}
(Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => {
match left.partial_cmp(right).unwrap() {
Ordering::Equal => continue,
order => return order,
}
}
(Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater,
(Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less,
// if we have this, we're bad
(Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
| (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
unreachable!("Unexpected geo and sort comparison")
}
}
}
}
impl ScoreWithRatioResult {
fn new(results: SearchResult, ratio: f32) -> Self {
let document_scores = results
.documents_ids
.into_iter()
.zip(results.document_scores.into_iter().map(|scores| (scores, ratio)))
.collect();
Self {
matching_words: results.matching_words,
candidates: results.candidates,
document_scores,
}
}
fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult {
let mut documents_ids =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut document_scores =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut documents_seen = RoaringBitmap::new();
for (docid, (main_score, _sub_score)) in left
.document_scores
.into_iter()
.merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| {
// the first value is the one with the greatest score
compare_scores(left, right).is_ge()
})
// remove documents we already saw
.filter(|(docid, _)| documents_seen.insert(*docid))
// start skipping **after** the filter
.skip(from)
// take **after** skipping
.take(length)
{
documents_ids.push(docid);
// TODO: pass both scores to documents_score in some way?
document_scores.push(main_score);
}
SearchResult {
matching_words: left.matching_words,
candidates: left.candidates | right.candidates,
documents_ids,
document_scores,
}
}
}
impl<'a> Search<'a> {
pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> {
// TODO: find classier way to achieve that than to reset vector and query params
// create separate keyword and semantic searches
let mut search = Search {
query: self.query.clone(),
vector: self.vector.clone(),
filter: self.filter.clone(),
offset: 0,
limit: self.limit + self.offset,
sort_criteria: self.sort_criteria.clone(),
searchable_attributes: self.searchable_attributes,
geo_strategy: self.geo_strategy,
terms_matching_strategy: self.terms_matching_strategy,
scoring_strategy: ScoringStrategy::Detailed,
words_limit: self.words_limit,
exhaustive_number_hits: self.exhaustive_number_hits,
rtxn: self.rtxn,
index: self.index,
distribution_shift: self.distribution_shift,
embedder_name: self.embedder_name.clone(),
};
let vector_query = search.vector.take();
let keyword_results = search.execute()?;
// skip semantic search if we don't have a vector query (placeholder search)
let Some(vector_query) = vector_query else {
return Ok(keyword_results);
};
// completely skip semantic search if the results of the keyword search are good enough
if self.results_good_enough(&keyword_results, semantic_ratio) {
return Ok(keyword_results);
}
search.vector = Some(vector_query);
search.query = None;
// TODO: would be better to have two distinct functions at this point
let vector_results = search.execute()?;
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
let merge_results =
ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit);
assert!(merge_results.documents_ids.len() <= self.limit);
Ok(merge_results)
}
fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool {
// A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5
const GOOD_ENOUGH_SCORE: f64 = 0.45;
// 1. we check that we got a sufficient number of results
if keyword_results.document_scores.len() < self.limit + self.offset {
return false;
}
// 2. and that all results have a good enough score.
// we need to check all results because due to sort like rules, they're not necessarily in relevancy order
for score in &keyword_results.document_scores {
let score = ScoreDetails::global_score(score.iter());
if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE {
return false;
}
}
true
}
}

View File

@ -12,12 +12,14 @@ use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::PartialSearchResult;
use self::new::{execute_vector_search, PartialSearchResult};
use crate::error::UserError;
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::DistributionShift;
use crate::{
execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext,
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
Result, SearchContext,
};
// Building these factories is not free.
@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100;
pub mod facet;
mod fst_utils;
pub mod hybrid;
pub mod new;
pub struct Search<'a> {
@ -46,8 +49,11 @@ pub struct Search<'a> {
scoring_strategy: ScoringStrategy,
words_limit: usize,
exhaustive_number_hits: bool,
/// TODO: Add semantic ratio or pass it directly to execute_hybrid()
rtxn: &'a heed::RoTxn<'a>,
index: &'a Index,
distribution_shift: Option<DistributionShift>,
embedder_name: Option<String>,
}
impl<'a> Search<'a> {
@ -67,6 +73,8 @@ impl<'a> Search<'a> {
words_limit: 10,
rtxn,
index,
distribution_shift: None,
embedder_name: None,
}
}
@ -75,8 +83,8 @@ impl<'a> Search<'a> {
self
}
pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> {
self.vector = Some(vector.into());
pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> {
self.vector = Some(vector);
self
}
@ -133,30 +141,75 @@ impl<'a> Search<'a> {
self
}
pub fn distribution_shift(
&mut self,
distribution_shift: Option<DistributionShift>,
) -> &mut Search<'a> {
self.distribution_shift = distribution_shift;
self
}
pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
self.embedder_name = Some(embedder_name.into());
self
}
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
if has_vector_search {
let ctx = SearchContext::new(self.index, self.rtxn);
filtered_universe(&ctx, &self.filter)
} else {
Ok(self.execute()?.candidates)
}
}
pub fn execute(&self) -> Result<SearchResult> {
let embedder_name;
let embedder_name = match &self.embedder_name {
Some(embedder_name) => embedder_name,
None => {
embedder_name = self.index.default_embedding_name(self.rtxn)?;
&embedder_name
}
};
let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes {
ctx.searchable_attributes(searchable_attributes)?;
}
let universe = filtered_universe(&ctx, &self.filter)?;
let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
execute_search(
&mut ctx,
&self.query,
&self.vector,
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
&self.filter,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)?;
match self.vector.as_ref() {
Some(vector) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
self.distribution_shift,
embedder_name,
)?,
None => execute_search(
&mut ctx,
self.query.as_deref(),
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)?,
};
// consume context and located_query_terms to build MatchingWords.
let matching_words = match located_query_terms {
@ -185,6 +238,8 @@ impl fmt::Debug for Search<'_> {
exhaustive_number_hits,
rtxn: _,
index: _,
distribution_shift,
embedder_name,
} = self;
f.debug_struct("Search")
.field("query", query)
@ -198,6 +253,8 @@ impl fmt::Debug for Search<'_> {
.field("scoring_strategy", scoring_strategy)
.field("exhaustive_number_hits", exhaustive_number_hits)
.field("words_limit", words_limit)
.field("distribution_shift", distribution_shift)
.field("embedder_name", embedder_name)
.finish()
}
}
@ -249,11 +306,16 @@ pub struct SearchForFacetValues<'a> {
query: Option<String>,
facet: String,
search_query: Search<'a>,
is_hybrid: bool,
}
impl<'a> SearchForFacetValues<'a> {
pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> {
SearchForFacetValues { query: None, facet, search_query }
pub fn new(
facet: String,
search_query: Search<'a>,
is_hybrid: bool,
) -> SearchForFacetValues<'a> {
SearchForFacetValues { query: None, facet, search_query, is_hybrid }
}
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
@ -303,7 +365,9 @@ impl<'a> SearchForFacetValues<'a> {
None => return Ok(vec![]),
};
let search_candidates = self.search_query.execute()?.candidates;
let search_candidates = self
.search_query
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
match self.query.as_ref() {
Some(query) => {

View File

@ -107,12 +107,16 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
/// Refill the internal buffer of cached docids based on the strategy.
/// Drop the rtree if we don't need it anymore.
fn fill_buffer(&mut self, ctx: &mut SearchContext) -> Result<()> {
fn fill_buffer(
&mut self,
ctx: &mut SearchContext,
geo_candidates: &RoaringBitmap,
) -> Result<()> {
debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng");
debug_assert!(self.cached_sorted_docids.is_empty());
// lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree`
let rtree = if self.strategy.use_rtree(self.geo_candidates.len() as usize) {
let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) {
if let Some(rtree) = self.rtree.as_ref() {
// get rtree from cache
Some(rtree)
@ -131,7 +135,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
if self.ascending {
let point = lat_lng_to_xyz(&self.point);
for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) {
if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_back(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
@ -143,7 +147,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// and we insert the points in reverse order they get reversed when emptying the cache later on
let point = lat_lng_to_xyz(&opposite_of(self.point));
for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) {
if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_front(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
@ -155,8 +159,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// the iterative version
let [lat, lng] = self.field_ids.unwrap();
let mut documents = self
.geo_candidates
let mut documents = geo_candidates
.iter()
.map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) })
.collect::<Result<Vec<(u32, [f64; 2])>>>()?;
@ -216,9 +219,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
assert!(self.query.is_none());
self.query = Some(query.clone());
self.geo_candidates &= universe;
if self.geo_candidates.is_empty() {
let geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(());
}
@ -226,7 +230,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat");
let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng");
self.field_ids = Some([lat, lng]);
self.fill_buffer(ctx)?;
self.fill_buffer(ctx, &geo_candidates)?;
Ok(())
}
@ -238,9 +242,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone();
self.geo_candidates &= universe;
if self.geo_candidates.is_empty() {
let geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
@ -261,7 +266,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
}
};
while let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if self.geo_candidates.contains(id) {
if geo_candidates.contains(id) {
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([id]),
@ -276,7 +281,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx)?;
self.fill_buffer(ctx, &geo_candidates)?;
self.next_bucket(ctx, logger, universe)
}

View File

@ -498,19 +498,19 @@ mod tests {
use super::*;
use crate::index::tests::TempIndex;
use crate::{execute_search, SearchContext};
use crate::{execute_search, filtered_universe, SearchContext};
impl<'a> MatcherBuilder<'a> {
fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self {
let mut ctx = SearchContext::new(index, rtxn);
let universe = filtered_universe(&ctx, &None).unwrap();
let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
&mut ctx,
&Some(query.to_string()),
&None,
Some(query),
crate::TermsMatchingStrategy::default(),
crate::score_details::ScoringStrategy::Skip,
false,
&None,
universe,
&None,
crate::search::new::GeoSortStrategy::default(),
0,

View File

@ -16,6 +16,7 @@ mod small_bitmap;
mod exact_attribute;
mod sort;
mod vector_sort;
#[cfg(test)]
mod tests;
@ -28,7 +29,6 @@ use db_cache::DatabaseCache;
use exact_attribute::ExactAttribute;
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
use heed::RoTxn;
use instant_distance::Search;
use interner::{DedupInterner, Interner};
pub use logger::visual::VisualSearchLogger;
pub use logger::{DefaultSearchLogger, SearchLogger};
@ -46,10 +46,11 @@ use self::geo_sort::GeoSort;
pub use self::geo_sort::Strategy as GeoSortStrategy;
use self::graph_based_ranking_rule::Words;
use self::interner::Interned;
use crate::distance::NDotProductPoint;
use self::vector_sort::VectorSort;
use crate::error::FieldIdMapMissingEntry;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule;
use crate::vector::DistributionShift;
use crate::{
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
};
@ -258,6 +259,80 @@ fn get_ranking_rules_for_placeholder_search<'ctx>(
Ok(ranking_rules)
}
fn get_ranking_rules_for_vector<'ctx>(
ctx: &SearchContext<'ctx>,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
limit_plus_offset: usize,
target: &[f32],
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search
let mut sort = false;
let mut sorted_fields = HashSet::new();
let mut geo_sorted = false;
let mut vector = false;
let mut ranking_rules: Vec<BoxRankingRule<PlaceholderQuery>> = vec![];
let settings_ranking_rules = ctx.index.criteria(ctx.txn)?;
for rr in settings_ranking_rules {
match rr {
crate::Criterion::Words
| crate::Criterion::Typo
| crate::Criterion::Proximity
| crate::Criterion::Attribute
| crate::Criterion::Exactness => {
if !vector {
let vector_candidates = ctx.index.documents_ids(ctx.txn)?;
let vector_sort = VectorSort::new(
ctx,
target.to_vec(),
vector_candidates,
limit_plus_offset,
distribution_shift,
embedder_name,
)?;
ranking_rules.push(Box::new(vector_sort));
vector = true;
}
}
crate::Criterion::Sort => {
if sort {
continue;
}
resolve_sort_criteria(
sort_criteria,
ctx,
&mut ranking_rules,
&mut sorted_fields,
&mut geo_sorted,
geo_strategy,
)?;
sort = true;
}
crate::Criterion::Asc(field_name) => {
if sorted_fields.contains(&field_name) {
continue;
}
sorted_fields.insert(field_name.clone());
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?));
}
crate::Criterion::Desc(field_name) => {
if sorted_fields.contains(&field_name) {
continue;
}
sorted_fields.insert(field_name.clone());
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?));
}
}
}
Ok(ranking_rules)
}
/// Return the list of initialised ranking rules to be used for a query graph search.
fn get_ranking_rules_for_query_graph_search<'ctx>(
ctx: &SearchContext<'ctx>,
@ -422,15 +497,72 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
Ok(())
}
pub fn filtered_universe(ctx: &SearchContext, filters: &Option<Filter>) -> Result<RoaringBitmap> {
Ok(if let Some(filters) = filters {
filters.evaluate(ctx.txn, ctx.index)?
} else {
ctx.index.documents_ids(ctx.txn)?
})
}
#[allow(clippy::too_many_arguments)]
pub fn execute_vector_search(
ctx: &mut SearchContext,
vector: &[f32],
scoring_strategy: ScoringStrategy,
universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
from: usize,
length: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules = get_ranking_rules_for_vector(
ctx,
sort_criteria,
geo_strategy,
from + length,
vector,
distribution_shift,
embedder_name,
)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
&mut placeholder_search_logger;
let BucketSortOutput { docids, scores, all_candidates } = bucket_sort(
ctx,
ranking_rules,
&PlaceholderQuery,
&universe,
from,
length,
scoring_strategy,
placeholder_search_logger,
)?;
Ok(PartialSearchResult {
candidates: all_candidates,
document_scores: scores,
documents_ids: docids,
located_query_terms: None,
})
}
#[allow(clippy::too_many_arguments)]
pub fn execute_search(
ctx: &mut SearchContext,
query: &Option<String>,
vector: &Option<Vec<f32>>,
query: Option<&str>,
terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy,
exhaustive_number_hits: bool,
filters: &Option<Filter>,
mut universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
from: usize,
@ -439,60 +571,8 @@ pub fn execute_search(
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
) -> Result<PartialSearchResult> {
let mut universe = if let Some(filters) = filters {
filters.evaluate(ctx.txn, ctx.index)?
} else {
ctx.index.documents_ids(ctx.txn)?
};
check_sort_criteria(ctx, sort_criteria.as_ref())?;
if let Some(vector) = vector {
let mut search = Search::default();
let docids = match ctx.index.vector_hnsw(ctx.txn)? {
Some(hnsw) => {
if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() {
if vector.len() != expected_size {
return Err(UserError::InvalidVectorDimensions {
expected: expected_size,
found: vector.len(),
}
.into());
}
}
let vector = NDotProductPoint::new(vector.clone());
let neighbors = hnsw.search(&vector, &mut search);
let mut docids = Vec::new();
let mut uniq_docids = RoaringBitmap::new();
for instant_distance::Item { distance: _, pid, point: _ } in neighbors {
let index = pid.into_inner();
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();
if universe.contains(docid) && uniq_docids.insert(docid) {
docids.push(docid);
if docids.len() == (from + length) {
break;
}
}
}
// return the nearest documents that are also part of the candidates
// along with a dummy list of scores that are useless in this context.
docids.into_iter().skip(from).take(length).collect()
}
None => Vec::new(),
};
return Ok(PartialSearchResult {
candidates: universe,
document_scores: vec![Vec::new(); docids.len()],
documents_ids: docids,
located_query_terms: None,
});
}
let mut located_query_terms = None;
let query_terms = if let Some(query) = query {
// We make sure that the analyzer is aware of the stop words
@ -546,7 +626,7 @@ pub fn execute_search(
terms_matching_strategy,
)?;
universe =
universe &=
resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?;
bucket_sort(

View File

@ -0,0 +1,170 @@
use std::iter::FromIterator;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::score_details::{self, ScoreDetails};
use crate::vector::DistributionShift;
use crate::{DocumentId, Result, SearchContext, SearchLogger};
pub struct VectorSort<Q: RankingRuleQueryTrait> {
query: Option<Q>,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_index: u8,
}
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
pub fn new(
ctx: &SearchContext,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Self> {
let embedder_index = ctx
.index
.embedder_category_id
.get(ctx.txn, embedder_name)?
.ok_or_else(|| crate::UserError::InvalidEmbedder(embedder_name.to_owned()))?;
Ok(Self {
query: None,
target,
vector_candidates,
cached_sorted_docids: Default::default(),
limit,
distribution_shift,
embedder_index,
})
}
fn fill_buffer(
&mut self,
ctx: &mut SearchContext<'_>,
vector_candidates: &RoaringBitmap,
) -> Result<()> {
let writer_index = (self.embedder_index as u16) << 8;
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map_while(|k| {
arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy)
.map(Some)
.or_else(|e| match e {
arroy::Error::MissingMetadata => Ok(None),
e => Err(e),
})
.transpose()
})
.collect();
let readers = readers?;
let target = &self.target;
let mut results = Vec::new();
for reader in readers.iter() {
let nns_by_vector =
reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?;
let vectors: std::result::Result<Vec<_>, _> = nns_by_vector
.iter()
.map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap())
.collect();
let vectors = vectors?;
results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z)));
}
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
self.cached_sorted_docids = results.into_iter();
Ok(())
}
}
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
fn id(&self) -> String {
"vector_sort".to_owned()
}
fn start_iteration(
&mut self,
ctx: &mut SearchContext<'ctx>,
_logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
query: &Q,
) -> Result<()> {
assert!(self.query.is_none());
self.query = Some(query.clone());
let vector_candidates = &self.vector_candidates & universe;
self.fill_buffer(ctx, &vector_candidates)?;
Ok(())
}
#[allow(clippy::only_used_in_recursion)]
fn next_bucket(
&mut self,
ctx: &mut SearchContext<'ctx>,
_logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone();
let vector_candidates = &self.vector_candidates & universe;
if vector_candidates.is_empty() {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: None,
}),
}));
}
for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
if vector_candidates.contains(docid) {
let score = 1.0 - distance;
let score = self
.distribution_shift
.map(|distribution| distribution.shift(score))
.unwrap_or(score);
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([docid]),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: Some((vector, score)),
}),
}));
}
}
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx, &vector_candidates)?;
// we tried filling the buffer, but it remained empty 😢
// it means we don't actually have any document remaining in the universe with a vector.
// => exit
if self.cached_sorted_docids.len() == 0 {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: None,
}),
}));
}
self.next_bucket(ctx, _logger, universe)
}
fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) {
self.query = None;
}
}