Various changes

- Add a DistributionShift to the Search object (to be set from the model in embed?)
- Fix issue where embedder index wasn't computed at search time
- Accept as the default embedder either the one named "default" or, when exactly one embedder exists, that sole embedder
This commit is contained in:
Louis Dureuil
2023-12-13 15:38:44 +01:00
parent 12940d79a9
commit e0cc775dc4
12 changed files with 141 additions and 33 deletions

View File

@ -218,6 +218,8 @@ impl<'a> Search<'a> {
exhaustive_number_hits: self.exhaustive_number_hits,
rtxn: self.rtxn,
index: self.index,
distribution_shift: self.distribution_shift,
embedder_name: self.embedder_name.clone(),
};
let vector_query = search.vector.take();
@ -265,6 +267,15 @@ impl<'a> Search<'a> {
vector: &[f32],
keyword_results: &SearchResult,
) -> Result<PartialSearchResult> {
let embedder_name;
let embedder_name = match &self.embedder_name {
Some(embedder_name) => embedder_name,
None => {
embedder_name = self.index.default_embedding_name(self.rtxn)?;
&embedder_name
}
};
let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes {
@ -282,6 +293,8 @@ impl<'a> Search<'a> {
self.geo_strategy,
0,
self.limit + self.offset,
self.distribution_shift,
embedder_name,
)
}

View File

@ -17,6 +17,7 @@ use self::new::{execute_vector_search, PartialSearchResult};
use crate::error::UserError;
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::DistributionShift;
use crate::{
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
Result, SearchContext,
@ -51,6 +52,8 @@ pub struct Search<'a> {
exhaustive_number_hits: bool,
rtxn: &'a heed::RoTxn<'a>,
index: &'a Index,
distribution_shift: Option<DistributionShift>,
embedder_name: Option<String>,
}
#[derive(Debug, Clone, PartialEq)]
@ -117,6 +120,8 @@ impl<'a> Search<'a> {
words_limit: 10,
rtxn,
index,
distribution_shift: None,
embedder_name: None,
}
}
@ -183,7 +188,29 @@ impl<'a> Search<'a> {
self
}
/// Stores the distribution shift to use for this search.
///
/// The given value overwrites whatever was previously stored; passing `None`
/// clears it. The stored value is read back when the search is executed.
///
/// Returns the same `Search` so that setter calls can be chained.
pub fn distribution_shift(
    &mut self,
    distribution_shift: Option<DistributionShift>,
) -> &mut Search<'a> {
    let this = self;
    this.distribution_shift = distribution_shift;
    this
}
/// Selects, by name, the embedder this search should use.
///
/// Accepts anything convertible into a `String`. When this setter is never
/// called the name stays `None`, and `execute` falls back to the index's
/// default embedding name.
///
/// Returns the same `Search` so that setter calls can be chained.
pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
    let name: String = embedder_name.into();
    self.embedder_name = Some(name);
    self
}
pub fn execute(&self) -> Result<SearchResult> {
let embedder_name;
let embedder_name = match &self.embedder_name {
Some(embedder_name) => embedder_name,
None => {
embedder_name = self.index.default_embedding_name(self.rtxn)?;
&embedder_name
}
};
let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes {
@ -202,6 +229,8 @@ impl<'a> Search<'a> {
self.geo_strategy,
self.offset,
self.limit,
self.distribution_shift,
embedder_name,
)?,
None => execute_search(
&mut ctx,
@ -247,6 +276,8 @@ impl fmt::Debug for Search<'_> {
exhaustive_number_hits,
rtxn: _,
index: _,
distribution_shift,
embedder_name,
} = self;
f.debug_struct("Search")
.field("query", query)
@ -260,6 +291,8 @@ impl fmt::Debug for Search<'_> {
.field("scoring_strategy", scoring_strategy)
.field("exhaustive_number_hits", exhaustive_number_hits)
.field("words_limit", words_limit)
.field("distribution_shift", distribution_shift)
.field("embedder_name", embedder_name)
.finish()
}
}

View File

@ -266,6 +266,7 @@ fn get_ranking_rules_for_vector<'ctx>(
limit_plus_offset: usize,
target: &[f32],
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search
@ -292,6 +293,7 @@ fn get_ranking_rules_for_vector<'ctx>(
vector_candidates,
limit_plus_offset,
distribution_shift,
embedder_name,
)?;
ranking_rules.push(Box::new(vector_sort));
vector = true;
@ -513,6 +515,8 @@ pub fn execute_vector_search(
geo_strategy: geo_sort::Strategy,
from: usize,
length: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -524,7 +528,8 @@ pub fn execute_vector_search(
geo_strategy,
from + length,
vector,
None,
distribution_shift,
embedder_name,
)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;

View File

@ -15,16 +15,21 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_index: u8,
}
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
pub fn new(
_ctx: &SearchContext,
ctx: &SearchContext,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Self> {
/// FIXME: unwrap
let embedder_index = ctx.index.embedder_category_id.get(ctx.txn, embedder_name)?.unwrap();
Ok(Self {
query: None,
target,
@ -32,6 +37,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
cached_sorted_docids: Default::default(),
limit,
distribution_shift,
embedder_index,
})
}
@ -40,9 +46,10 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
ctx: &mut SearchContext<'_>,
vector_candidates: &RoaringBitmap,
) -> Result<()> {
let writer_index = (self.embedder_index as u16) << 8;
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map_while(|k| {
arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy)
arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy)
.map(Some)
.or_else(|e| match e {
arroy::Error::MissingMetadata => Ok(None),