This commit is contained in:
Louis Dureuil
2023-12-07 17:03:10 +01:00
parent dde3a04679
commit cb4ebe163e
8 changed files with 185 additions and 157 deletions

View File

@ -262,6 +262,7 @@ fn get_ranking_rules_for_vector<'ctx>(
ctx: &SearchContext<'ctx>,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
limit_plus_offset: usize,
target: &[f32],
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search
@ -283,7 +284,12 @@ fn get_ranking_rules_for_vector<'ctx>(
| crate::Criterion::Exactness => {
if !vector {
let vector_candidates = ctx.index.documents_ids(ctx.txn)?;
let vector_sort = VectorSort::new(ctx, target.to_vec(), vector_candidates)?;
let vector_sort = VectorSort::new(
ctx,
target.to_vec(),
vector_candidates,
limit_plus_offset,
)?;
ranking_rules.push(Box::new(vector_sort));
vector = true;
}
@ -509,7 +515,8 @@ pub fn execute_vector_search(
/// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules = get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, vector)?;
let ranking_rules =
get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =

View File

@ -1,48 +1,83 @@
use std::future::Future;
use std::iter::FromIterator;
use std::pin::Pin;
use nolife::DynBoxScope;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::distance::NDotProductPoint;
use crate::index::Hnsw;
use crate::score_details::{self, ScoreDetails};
use crate::{Result, SearchContext, SearchLogger, UserError};
use crate::{DocumentId, Result, SearchContext, SearchLogger};
pub struct VectorSort<'ctx, Q: RankingRuleQueryTrait> {
pub struct VectorSort<Q: RankingRuleQueryTrait> {
query: Option<Q>,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
reader: arroy::Reader<'ctx, arroy::distances::DotProduct>,
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
limit: usize,
}
impl<'ctx, Q: RankingRuleQueryTrait> VectorSort<'ctx, Q> {
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
pub fn new(
ctx: &'ctx SearchContext,
_ctx: &SearchContext,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
limit: usize,
) -> Result<Self> {
/// FIXME? what to do in case of missing metadata
let reader = arroy::Reader::open(ctx.txn, 0, ctx.index.vector_arroy)?;
Ok(Self {
query: None,
target,
vector_candidates,
cached_sorted_docids: Default::default(),
limit,
})
}
let target_clone = target.clone();
fn fill_buffer(&mut self, ctx: &mut SearchContext<'_>) -> Result<()> {
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map_while(|k| {
arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy)
.map(Some)
.or_else(|e| match e {
arroy::Error::MissingMetadata => Ok(None),
e => Err(e),
})
.transpose()
})
.collect();
Ok(Self { query: None, target, vector_candidates, reader, limit })
let readers = readers?;
let target = &self.target;
let mut results = Vec::new();
for reader in readers.iter() {
let nns_by_vector = reader.nns_by_vector(
ctx.txn,
&target,
self.limit,
None,
Some(&self.vector_candidates),
)?;
let vectors: std::result::Result<Vec<_>, _> = nns_by_vector
.iter()
.map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap())
.collect();
let vectors = vectors?;
results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z)));
}
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
self.cached_sorted_docids = results.into_iter();
Ok(())
}
}
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<'ctx, Q> {
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
fn id(&self) -> String {
"vector_sort".to_owned()
}
fn start_iteration(
&mut self,
_ctx: &mut SearchContext<'ctx>,
ctx: &mut SearchContext<'ctx>,
_logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
query: &Q,
@ -51,7 +86,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<'ctx, Q
self.query = Some(query.clone());
self.vector_candidates &= universe;
self.fill_buffer(ctx)?;
Ok(())
}
@ -75,40 +110,24 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<'ctx, Q
}),
}));
}
let target = &self.target;
let vector_candidates = &self.vector_candidates;
let result = self.reader.nns_by_vector(ctx.txn, &target, count, search_k, candidates)
scope.enter(|it| {
for item in it.by_ref() {
let item: Item = item;
let index = item.pid.into_inner();
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();
if vector_candidates.contains(docid) {
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([docid]),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: target.clone(),
value_similarity: Some((
item.point.clone().into_inner(),
1.0 - item.distance,
)),
}),
}));
}
while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() {
if self.vector_candidates.contains(docid) {
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([docid]),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: Some((vector, 1.0 - distance)),
}),
}));
}
Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: target.clone(),
value_similarity: None,
}),
}))
})
}
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx)?;
self.next_bucket(ctx, _logger, universe)
}
fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) {