Small commit to add hybrid search and autoembedding

This commit is contained in:
Louis Dureuil
2023-11-15 15:46:37 +01:00
parent 21bcf32109
commit 13c2c6c16b
42 changed files with 4045 additions and 246 deletions

336
milli/src/search/hybrid.rs Normal file
View File

@ -0,0 +1,336 @@
use std::cmp::Ordering;
use std::collections::HashMap;
use itertools::Itertools;
use roaring::RoaringBitmap;
use super::new::{execute_vector_search, PartialSearchResult};
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::{
execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult,
};
/// Results of one of the two searches of a hybrid search, where each document carries a
/// [`CombinedScore`] rather than a single list of score details.
struct CombinedSearchResult {
/// Matching words, taken as-is from the main search results.
matching_words: MatchingWords,
/// Candidates, taken as-is from the main search results.
candidates: RoaringBitmap,
/// `(docid, score)` pairs; `CombinedSearchResult::new` sorts them from best to worst
/// according to `compare_scores`.
document_scores: Vec<(u32, CombinedScore)>,
}
/// Score of a document as a `(main, ancillary)` pair.
///
/// The first element holds the score details from the search that produced the document.
/// The second element holds the score details from the other search restricted to the same
/// documents, or `None` when that search did not return the document.
type CombinedScore = (Vec<ScoreDetails>, Option<Vec<ScoreDetails>>);
/// Orders two combined scores; `Ordering::Greater` means that `left` is the better score.
///
/// The main and ancillary score details of each side are turned into score values and walked
/// in lockstep through `take_best_score`. The first pair of values that differs decides the
/// ordering. A side that runs out of values first loses. A relevancy score always wins
/// against a sort or geo sort value.
///
/// # Panics
///
/// - if two values of the same kind cannot be ordered by `partial_cmp` (e.g. a NaN score),
/// - if a sort value ends up being compared with a geo sort value.
fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering {
    let mut left_main_values = ScoreDetails::score_values(left.0.iter());
    let mut left_sub_values = ScoreDetails::score_values(left.1.iter().flatten());
    let mut right_main_values = ScoreDetails::score_values(right.0.iter());
    let mut right_sub_values = ScoreDetails::score_values(right.1.iter().flatten());

    // One value of look-ahead per stream, refilled by `take_best_score`.
    let mut left_main = left_main_values.next();
    let mut left_sub = left_sub_values.next();
    let mut right_main = right_main_values.next();
    let mut right_sub = right_sub_values.next();

    loop {
        let left_value = take_best_score(
            &mut left_main,
            &mut left_sub,
            &mut left_main_values,
            &mut left_sub_values,
        );
        let right_value = take_best_score(
            &mut right_main,
            &mut right_sub,
            &mut right_main_values,
            &mut right_sub_values,
        );

        let ordering = match (left_value, right_value) {
            // Both sides are exhausted without any difference.
            (None, None) => return Ordering::Equal,
            (None, Some(_)) => return Ordering::Less,
            (Some(_), None) => return Ordering::Greater,
            (Some(ScoreValue::Score(lhs)), Some(ScoreValue::Score(rhs))) => {
                if (lhs - rhs).abs() <= f64::EPSILON {
                    // Close enough to be a tie: look at the next values.
                    Ordering::Equal
                } else {
                    return lhs.partial_cmp(&rhs).unwrap();
                }
            }
            (Some(ScoreValue::Sort(lhs)), Some(ScoreValue::Sort(rhs))) => {
                lhs.partial_cmp(rhs).unwrap()
            }
            (Some(ScoreValue::GeoSort(lhs)), Some(ScoreValue::GeoSort(rhs))) => {
                lhs.partial_cmp(rhs).unwrap()
            }
            // A relevancy score beats any kind of sort value.
            (Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater,
            (Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less,
            // Both searches use the same sort criteria, so this mix is not expected.
            (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
            | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
                unreachable!("Unexpected geo and sort comparison")
            }
        };

        if ordering != Ordering::Equal {
            return ordering;
        }
    }
}
/// Picks the next value to compare out of the pending main and ancillary values of one side,
/// then refills whichever pending value was consumed from its iterator.
///
/// - Only one stream has a pending value: that value is taken.
/// - Both pending values are of the same kind: the greater one is taken (the main one on
///   ties) and both streams advance.
/// - Exactly one pending value is a relevancy score: it is taken and only its stream
///   advances.
/// - A sort value faces a geo sort value: `None` is returned and nothing advances.
/// - Both streams are exhausted: `None` is returned.
fn take_best_score<'a>(
    main_score: &mut Option<ScoreValue<'a>>,
    sub_score: &mut Option<ScoreValue<'a>>,
    main_it: &mut impl Iterator<Item = ScoreValue<'a>>,
    sub_it: &mut impl Iterator<Item = ScoreValue<'a>>,
) -> Option<ScoreValue<'a>> {
    // Decide first which value wins and which streams it consumes, advance afterwards.
    let (winner, advance_main, advance_sub) = match (*main_score, *sub_score) {
        (None, None) => return None,
        (main @ Some(_), None) => (main, true, false),
        (None, sub @ Some(_)) => (sub, false, true),
        (main @ Some(ScoreValue::Score(main_value)), sub @ Some(ScoreValue::Score(sub_value))) => {
            (if main_value >= sub_value { main } else { sub }, true, true)
        }
        (main @ Some(ScoreValue::Score(_)), Some(_)) => (main, true, false),
        (Some(_), sub @ Some(ScoreValue::Score(_))) => (sub, false, true),
        (
            main @ Some(ScoreValue::GeoSort(main_value)),
            sub @ Some(ScoreValue::GeoSort(sub_value)),
        ) => (if main_value >= sub_value { main } else { sub }, true, true),
        (main @ Some(ScoreValue::Sort(main_value)), sub @ Some(ScoreValue::Sort(sub_value))) => {
            (if main_value >= sub_value { main } else { sub }, true, true)
        }
        // Mixed sort / geo sort: no winner, and nothing is consumed.
        (
            Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
            Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
        ) => return None,
    };

    if advance_main {
        *main_score = main_it.next();
    }
    if advance_sub {
        *sub_score = sub_it.next();
    }
    winner
}
impl CombinedSearchResult {
fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self {
let mut docid_scores = HashMap::new();
for (docid, score) in
main_results.documents_ids.iter().zip(main_results.document_scores.into_iter())
{
docid_scores.insert(*docid, (score, None));
}
for (docid, score) in ancillary_results
.documents_ids
.iter()
.zip(ancillary_results.document_scores.into_iter())
{
docid_scores
.entry(*docid)
.and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score));
}
let mut document_scores: Vec<_> = docid_scores.into_iter().collect();
document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse());
Self {
matching_words: main_results.matching_words,
candidates: main_results.candidates,
document_scores,
}
}
fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult {
let mut documents_ids =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut document_scores =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut documents_seen = RoaringBitmap::new();
for (docid, (main_score, _sub_score)) in left
.document_scores
.into_iter()
.merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| {
// the first value is the one with the greatest score
compare_scores(left, right).is_ge()
})
// remove documents we already saw
.filter(|(docid, _)| documents_seen.insert(*docid))
// start skipping **after** the filter
.skip(from)
// take **after** skipping
.take(length)
{
documents_ids.push(docid);
// TODO: pass both scores to documents_score in some way?
document_scores.push(main_score);
}
SearchResult {
matching_words: left.matching_words,
candidates: left.candidates | right.candidates,
documents_ids,
document_scores,
}
}
}
impl<'a> Search<'a> {
    /// Runs a hybrid search: a keyword search and a vector search whose results are merged
    /// according to their scores.
    ///
    /// Both searches are run from offset 0 with a limit of `limit + offset`, always with
    /// detailed scores, so that the merged list can be paginated afterwards.
    ///
    /// The vector search is skipped, and the keyword results alone are returned, when there
    /// is no vector query or when the keyword results are already good enough (see
    /// `results_good_enough`). In every case the returned results honor `self.offset` and
    /// `self.limit`.
    pub fn execute_hybrid(&self) -> Result<SearchResult> {
        // TODO: find classier way to achieve that than to reset vector and query params
        // create separate keyword and semantic searches
        let mut search = Search {
            query: self.query.clone(),
            vector: self.vector.clone(),
            filter: self.filter.clone(),
            offset: 0,
            // saturate rather than overflow on very large limits
            limit: self.limit.saturating_add(self.offset),
            sort_criteria: self.sort_criteria.clone(),
            searchable_attributes: self.searchable_attributes,
            geo_strategy: self.geo_strategy,
            terms_matching_strategy: self.terms_matching_strategy,
            scoring_strategy: ScoringStrategy::Detailed,
            words_limit: self.words_limit,
            exhaustive_number_hits: self.exhaustive_number_hits,
            rtxn: self.rtxn,
            index: self.index,
        };

        // Removing the vector makes `execute` run the keyword search.
        let vector_query = search.vector.take();
        let keyword_query = self.query.as_deref();
        let keyword_results = search.execute()?;

        // skip semantic search if we don't have a vector query (placeholder search)
        let Some(vector_query) = vector_query else {
            // the keyword search ran from offset 0: apply the requested window now
            return Ok(paginate_keyword_results(keyword_results, self.offset, self.limit));
        };

        // completely skip semantic search if the results of the keyword search are good enough
        if self.results_good_enough(&keyword_results) {
            return Ok(paginate_keyword_results(keyword_results, self.offset, self.limit));
        }

        search.vector = Some(vector_query);
        search.query = None;

        // TODO: would be better to have two distinct functions at this point
        let vector_results = search.execute()?;

        // Compute keyword scores for vector_results
        let keyword_results_for_vector =
            self.keyword_results_for_vector(keyword_query, &vector_results)?;

        // compute vector scores for keyword_results
        let vector_results_for_keyword = self.vector_results_for_keyword(
            search.vector.as_ref().expect("the vector query was put back just above"),
            &keyword_results,
        )?;

        let keyword_results =
            CombinedSearchResult::new(keyword_results, vector_results_for_keyword);
        let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector);

        let merge_results =
            CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit);
        assert!(merge_results.documents_ids.len() <= self.limit);
        Ok(merge_results)
    }

    /// Runs the vector search restricted to the documents returned by the keyword search, to
    /// get a vector score for each of them.
    fn vector_results_for_keyword(
        &self,
        vector: &[f32],
        keyword_results: &SearchResult,
    ) -> Result<PartialSearchResult> {
        let mut ctx = SearchContext::new(self.index, self.rtxn);

        if let Some(searchable_attributes) = self.searchable_attributes {
            ctx.searchable_attributes(searchable_attributes)?;
        }

        // only the documents found by the keyword search are scored
        let universe = keyword_results.documents_ids.iter().collect();

        execute_vector_search(
            &mut ctx,
            vector,
            ScoringStrategy::Detailed,
            universe,
            &self.sort_criteria,
            self.geo_strategy,
            0,
            self.limit.saturating_add(self.offset),
        )
    }

    /// Runs the keyword search restricted to the documents returned by the vector search, to
    /// get a keyword score for each of them.
    fn keyword_results_for_vector(
        &self,
        query: Option<&str>,
        vector_results: &SearchResult,
    ) -> Result<PartialSearchResult> {
        let mut ctx = SearchContext::new(self.index, self.rtxn);

        if let Some(searchable_attributes) = self.searchable_attributes {
            ctx.searchable_attributes(searchable_attributes)?;
        }

        // only the documents found by the vector search are scored
        let universe = vector_results.documents_ids.iter().collect();

        execute_search(
            &mut ctx,
            query,
            self.terms_matching_strategy,
            ScoringStrategy::Detailed,
            self.exhaustive_number_hits,
            universe,
            &self.sort_criteria,
            self.geo_strategy,
            0,
            self.limit.saturating_add(self.offset),
            Some(self.words_limit),
            &mut DefaultSearchLogger,
            &mut DefaultSearchLogger,
        )
    }

    /// Tells whether the keyword results are sufficient on their own: there are at least
    /// `limit + offset` of them and none has a global score under `GOOD_ENOUGH_SCORE`.
    fn results_good_enough(&self, keyword_results: &SearchResult) -> bool {
        const GOOD_ENOUGH_SCORE: f64 = 0.9;

        // 1. we check that we got a sufficient number of results
        if keyword_results.document_scores.len() < self.limit.saturating_add(self.offset) {
            return false;
        }

        // 2. and that all results have a good enough score.
        // we need to check all results because due to sort like rules, they're not necessarily in relevancy order
        !keyword_results
            .document_scores
            .iter()
            .any(|score| ScoreDetails::global_score(score.iter()) < GOOD_ENOUGH_SCORE)
    }
}

/// Restricts keyword results that were computed from offset 0 to the window requested by the
/// caller: the first `offset` documents are dropped and at most `limit` documents are kept.
///
/// Matching words and candidates are left untouched.
fn paginate_keyword_results(
    mut results: SearchResult,
    offset: usize,
    limit: usize,
) -> SearchResult {
    let skipped = offset.min(results.documents_ids.len());
    results.documents_ids.drain(..skipped);
    results.documents_ids.truncate(limit);

    // Clamped separately so a length mismatch between the two lists cannot panic.
    let skipped = offset.min(results.document_scores.len());
    results.document_scores.drain(..skipped);
    results.document_scores.truncate(limit);

    results
}

View File

@ -3,6 +3,7 @@ use std::ops::ControlFlow;
use charabia::normalizer::NormalizerOption;
use charabia::Normalize;
use deserr::{DeserializeError, Deserr, Sequence};
use fst::automaton::{Automaton, Str};
use fst::{IntoStreamer, Streamer};
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
@ -12,12 +13,13 @@ use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::PartialSearchResult;
use self::new::{execute_vector_search, PartialSearchResult};
use crate::error::UserError;
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::{
execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext,
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
Result, SearchContext,
};
// Building these factories is not free.
@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100;
pub mod facet;
mod fst_utils;
pub mod hybrid;
pub mod new;
pub struct Search<'a> {
@ -50,6 +53,53 @@ pub struct Search<'a> {
index: &'a Index,
}
/// The vector part of a search query as it can be deserialized: either a sequence of numbers
/// or a string.
#[derive(Debug, Clone, PartialEq)]
pub enum VectorQuery {
/// The query was given as a sequence of numbers, stored as `f32`.
Vector(Vec<f32>),
/// The query was given as a string.
// NOTE(review): presumably text to be embedded into a vector later on — confirm.
String(String),
}
impl<E> Deserr<E> for VectorQuery
where
E: DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> std::result::Result<Self, E> {
match value {
deserr::Value::String(s) => Ok(VectorQuery::String(s)),
deserr::Value::Sequence(seq) => {
let v: std::result::Result<Vec<f32>, _> = seq
.into_iter()
.enumerate()
.map(|(index, v)| match v.into_value() {
deserr::Value::Float(f) => Ok(f as f32),
deserr::Value::Integer(i) => Ok(i as f32),
v => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: v,
accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer],
},
location.push_index(index),
))),
})
.collect();
Ok(VectorQuery::Vector(v?))
}
_ => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence],
},
location,
))),
}
}
}
impl<'a> Search<'a> {
pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
Search {
@ -75,8 +125,8 @@ impl<'a> Search<'a> {
self
}
pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> {
self.vector = Some(vector.into());
pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> {
self.vector = Some(vector);
self
}
@ -140,23 +190,35 @@ impl<'a> Search<'a> {
ctx.searchable_attributes(searchable_attributes)?;
}
let universe = filtered_universe(&ctx, &self.filter)?;
let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
execute_search(
&mut ctx,
&self.query,
&self.vector,
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
&self.filter,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)?;
match self.vector.as_ref() {
Some(vector) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
)?,
None => execute_search(
&mut ctx,
self.query.as_deref(),
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)?,
};
// consume context and located_query_terms to build MatchingWords.
let matching_words = match located_query_terms {

View File

@ -498,19 +498,19 @@ mod tests {
use super::*;
use crate::index::tests::TempIndex;
use crate::{execute_search, SearchContext};
use crate::{execute_search, filtered_universe, SearchContext};
impl<'a> MatcherBuilder<'a> {
fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self {
let mut ctx = SearchContext::new(index, rtxn);
let universe = filtered_universe(&ctx, &None).unwrap();
let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
&mut ctx,
&Some(query.to_string()),
&None,
Some(query),
crate::TermsMatchingStrategy::default(),
crate::score_details::ScoringStrategy::Skip,
false,
&None,
universe,
&None,
crate::search::new::GeoSortStrategy::default(),
0,

View File

@ -16,6 +16,7 @@ mod small_bitmap;
mod exact_attribute;
mod sort;
mod vector_sort;
#[cfg(test)]
mod tests;
@ -28,7 +29,6 @@ use db_cache::DatabaseCache;
use exact_attribute::ExactAttribute;
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
use heed::RoTxn;
use instant_distance::Search;
use interner::{DedupInterner, Interner};
pub use logger::visual::VisualSearchLogger;
pub use logger::{DefaultSearchLogger, SearchLogger};
@ -46,7 +46,7 @@ use self::geo_sort::GeoSort;
pub use self::geo_sort::Strategy as GeoSortStrategy;
use self::graph_based_ranking_rule::Words;
use self::interner::Interned;
use crate::distance::NDotProductPoint;
use self::vector_sort::VectorSort;
use crate::error::FieldIdMapMissingEntry;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule;
@ -258,6 +258,70 @@ fn get_ranking_rules_for_placeholder_search<'ctx>(
Ok(ranking_rules)
}
/// Return the list of initialised ranking rules to be used for a vector search.
///
/// The ranking rules of the index settings are visited in order:
/// - the first of `words`, `typo`, `proximity`, `attribute` or `exactness` is replaced by a
///   single [`VectorSort`] towards `target`; the following ones are skipped,
/// - the first `sort` rule is resolved from `sort_criteria`; the following ones are skipped,
/// - each `asc` / `desc` rule adds a [`Sort`] on its field, unless that field is already
///   sorted on.
fn get_ranking_rules_for_vector<'ctx>(
    ctx: &SearchContext<'ctx>,
    sort_criteria: &Option<Vec<AscDesc>>,
    geo_strategy: geo_sort::Strategy,
    target: &[f32],
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
    let mut vector_sort_added = false;
    let mut sort_resolved = false;
    let mut geo_sorted = false;
    let mut sorted_fields = HashSet::new();
    let mut ranking_rules: Vec<BoxRankingRule<PlaceholderQuery>> = Vec::new();

    for criterion in ctx.index.criteria(ctx.txn)? {
        match criterion {
            crate::Criterion::Words
            | crate::Criterion::Typo
            | crate::Criterion::Proximity
            | crate::Criterion::Attribute
            | crate::Criterion::Exactness => {
                if vector_sort_added {
                    continue;
                }
                let vector_candidates = ctx.index.documents_ids(ctx.txn)?;
                let vector_sort = VectorSort::new(ctx, target.to_vec(), vector_candidates)?;
                ranking_rules.push(Box::new(vector_sort));
                vector_sort_added = true;
            }
            crate::Criterion::Sort => {
                if sort_resolved {
                    continue;
                }
                resolve_sort_criteria(
                    sort_criteria,
                    ctx,
                    &mut ranking_rules,
                    &mut sorted_fields,
                    &mut geo_sorted,
                    geo_strategy,
                )?;
                sort_resolved = true;
            }
            crate::Criterion::Asc(field_name) => {
                // `insert` returns `false` when the field is already sorted on
                if !sorted_fields.insert(field_name.clone()) {
                    continue;
                }
                ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?));
            }
            crate::Criterion::Desc(field_name) => {
                if !sorted_fields.insert(field_name.clone()) {
                    continue;
                }
                ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?));
            }
        }
    }

    Ok(ranking_rules)
}
/// Return the list of initialised ranking rules to be used for a query graph search.
fn get_ranking_rules_for_query_graph_search<'ctx>(
ctx: &SearchContext<'ctx>,
@ -422,15 +486,62 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
Ok(())
}
/// Returns the documents a search starts from: the documents selected by `filters` when
/// there are filters, all the documents of the index otherwise.
pub fn filtered_universe(ctx: &SearchContext, filters: &Option<Filter>) -> Result<RoaringBitmap> {
    match filters {
        Some(filters) => Ok(filters.evaluate(ctx.txn, ctx.index)?),
        None => Ok(ctx.index.documents_ids(ctx.txn)?),
    }
}
#[allow(clippy::too_many_arguments)]
pub fn execute_vector_search(
ctx: &mut SearchContext,
vector: &[f32],
scoring_strategy: ScoringStrategy,
universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
from: usize,
length: usize,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
/// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules = get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, vector)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
&mut placeholder_search_logger;
let BucketSortOutput { docids, scores, all_candidates } = bucket_sort(
ctx,
ranking_rules,
&PlaceholderQuery,
&universe,
from,
length,
scoring_strategy,
placeholder_search_logger,
)?;
Ok(PartialSearchResult {
candidates: all_candidates,
document_scores: scores,
documents_ids: docids,
located_query_terms: None,
})
}
#[allow(clippy::too_many_arguments)]
pub fn execute_search(
ctx: &mut SearchContext,
query: &Option<String>,
vector: &Option<Vec<f32>>,
query: Option<&str>,
terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy,
exhaustive_number_hits: bool,
filters: &Option<Filter>,
mut universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
from: usize,
@ -439,60 +550,8 @@ pub fn execute_search(
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
) -> Result<PartialSearchResult> {
let mut universe = if let Some(filters) = filters {
filters.evaluate(ctx.txn, ctx.index)?
} else {
ctx.index.documents_ids(ctx.txn)?
};
check_sort_criteria(ctx, sort_criteria.as_ref())?;
if let Some(vector) = vector {
let mut search = Search::default();
let docids = match ctx.index.vector_hnsw(ctx.txn)? {
Some(hnsw) => {
if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() {
if vector.len() != expected_size {
return Err(UserError::InvalidVectorDimensions {
expected: expected_size,
found: vector.len(),
}
.into());
}
}
let vector = NDotProductPoint::new(vector.clone());
let neighbors = hnsw.search(&vector, &mut search);
let mut docids = Vec::new();
let mut uniq_docids = RoaringBitmap::new();
for instant_distance::Item { distance: _, pid, point: _ } in neighbors {
let index = pid.into_inner();
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();
if universe.contains(docid) && uniq_docids.insert(docid) {
docids.push(docid);
if docids.len() == (from + length) {
break;
}
}
}
// return the nearest documents that are also part of the candidates
// along with a dummy list of scores that are useless in this context.
docids.into_iter().skip(from).take(length).collect()
}
None => Vec::new(),
};
return Ok(PartialSearchResult {
candidates: universe,
document_scores: vec![Vec::new(); docids.len()],
documents_ids: docids,
located_query_terms: None,
});
}
let mut located_query_terms = None;
let query_terms = if let Some(query) = query {
// We make sure that the analyzer is aware of the stop words
@ -546,7 +605,7 @@ pub fn execute_search(
terms_matching_strategy,
)?;
universe =
universe &=
resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?;
bucket_sort(

View File

@ -0,0 +1,150 @@
use std::future::Future;
use std::iter::FromIterator;
use std::pin::Pin;
use nolife::DynBoxScope;
use roaring::RoaringBitmap;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::distance::NDotProductPoint;
use crate::index::Hnsw;
use crate::score_details::{self, ScoreDetails};
use crate::{Result, SearchContext, SearchLogger, UserError};
/// Ranking rule that yields documents one at a time, in the order in which the HNSW search
/// returns them for `target`.
pub struct VectorSort<Q: RankingRuleQueryTrait> {
/// Query of the current iteration; `None` outside of an iteration.
query: Option<Q>,
/// Vector the documents are compared to.
target: Vec<f32>,
/// Documents that can still be returned; intersected with the universe on every call.
vector_candidates: RoaringBitmap,
/// Scope that owns the HNSW and gives access to the iterator over its search results.
scope: nolife::DynBoxScope<SearchFamily>,
}
/// A search result of the HNSW: a point together with its id and its distance to the target.
type Item<'a> = instant_distance::Item<'a, NDotProductPoint>;
/// Boxed future that runs `search_scope`; its output is `nolife::Never`.
type SearchFut = Pin<Box<dyn Future<Output = nolife::Never>>>;
/// Marker type describing what the `nolife` scope exposes: for a borrow of lifetime `'a`,
/// a boxed iterator over [`Item<'a>`].
struct SearchFamily;
impl<'a> nolife::Family<'a> for SearchFamily {
type Family = Box<dyn Iterator<Item = Item<'a>> + 'a>;
}
/// Body of the `nolife` scope: takes ownership of the HNSW, starts the search for `target`
/// and keeps handing the iterator over the search results to the time capsule, forever.
async fn search_scope(
    mut time_capsule: nolife::TimeCapsule<SearchFamily>,
    hnsw: Hnsw,
    target: Vec<f32>,
) -> nolife::Never {
    let mut search_state = instant_distance::Search::default();
    let mut neighbors: Box<dyn Iterator<Item = Item>> =
        Box::new(hnsw.search(&NDotProductPoint::new(target), &mut search_state));
    loop {
        time_capsule.freeze(&mut neighbors).await;
    }
}
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
    /// Creates the ranking rule that sorts `vector_candidates` with respect to `target`.
    ///
    /// The HNSW of the index is loaded; when the index has none, an empty one is used
    /// instead. The search itself only starts once the rule is asked for its buckets.
    ///
    /// # Errors
    ///
    /// Returns `UserError::InvalidVectorDimensions` when the length of `target` differs from
    /// the length of the first point of the HNSW, and forwards any error met while loading
    /// the HNSW.
    pub fn new(
        ctx: &SearchContext,
        target: Vec<f32>,
        vector_candidates: RoaringBitmap,
    ) -> Result<Self> {
        // Only build the empty fallback when the index has no HNSW.
        let hnsw = ctx
            .index
            .vector_hnsw(ctx.txn)?
            .unwrap_or_else(|| Hnsw::builder().build_hnsw(Vec::default()).0);

        // The first point, if any, tells the number of dimensions of the stored vectors.
        if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() {
            if target.len() != expected_size {
                return Err(UserError::InvalidVectorDimensions {
                    expected: expected_size,
                    found: target.len(),
                }
                .into());
            }
        }

        // The scope takes ownership of its own copy of the target.
        let target_clone = target.clone();
        let producer = move |time_capsule| -> SearchFut {
            Box::pin(search_scope(time_capsule, hnsw, target_clone))
        };
        let scope = DynBoxScope::new(producer);

        Ok(Self { query: None, target, vector_candidates, scope })
    }
}
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
    fn id(&self) -> String {
        String::from("vector_sort")
    }

    /// Remembers the query of the iteration and restricts the vector candidates to
    /// `universe`.
    ///
    /// # Panics
    ///
    /// Panics if an iteration is already in progress.
    fn start_iteration(
        &mut self,
        _ctx: &mut SearchContext<'ctx>,
        _logger: &mut dyn SearchLogger<Q>,
        universe: &RoaringBitmap,
        query: &Q,
    ) -> Result<()> {
        assert!(self.query.is_none());

        self.query = Some(query.clone());
        self.vector_candidates &= universe;

        Ok(())
    }

    /// Returns the next bucket.
    ///
    /// The search results are consumed until one maps to a document that is still a
    /// candidate; that document is returned alone in its bucket, with its vector and its
    /// similarity (`1.0 - distance`). When no candidate is left, or once the search results
    /// are exhausted, the whole `universe` is returned in a bucket without similarity.
    ///
    /// # Panics
    ///
    /// Panics when called outside of an iteration, or when a search result has no
    /// associated document id.
    #[allow(clippy::only_used_in_recursion)]
    fn next_bucket(
        &mut self,
        ctx: &mut SearchContext<'ctx>,
        _logger: &mut dyn SearchLogger<Q>,
        universe: &RoaringBitmap,
    ) -> Result<Option<RankingRuleOutput<Q>>> {
        let query = self.query.clone().unwrap();
        self.vector_candidates &= universe;

        let target = &self.target;
        let remaining = &self.vector_candidates;

        if remaining.is_empty() {
            return Ok(Some(unscored_bucket(query, universe, target)));
        }

        self.scope.enter(|neighbors| {
            for item in neighbors.by_ref() {
                let item: Item = item;
                let index = item.pid.into_inner();
                let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();

                // not a candidate (anymore): move on to the next search result
                if !remaining.contains(docid) {
                    continue;
                }

                return Ok(Some(RankingRuleOutput {
                    query,
                    candidates: RoaringBitmap::from_iter([docid]),
                    score: ScoreDetails::Vector(score_details::Vector {
                        target_vector: target.clone(),
                        value_similarity: Some((
                            item.point.clone().into_inner(),
                            1.0 - item.distance,
                        )),
                    }),
                }));
            }

            // the search results are exhausted
            Ok(Some(unscored_bucket(query, universe, target)))
        })
    }

    fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) {
        self.query = None;
    }
}

/// Builds the bucket that contains the whole `universe`, with a vector score that carries
/// no similarity.
fn unscored_bucket<Q: RankingRuleQueryTrait>(
    query: Q,
    universe: &RoaringBitmap,
    target: &[f32],
) -> RankingRuleOutput<Q> {
    RankingRuleOutput {
        query,
        candidates: universe.clone(),
        score: ScoreDetails::Vector(score_details::Vector {
            target_vector: target.to_vec(),
            value_similarity: None,
        }),
    }
}