Lazily embed, don't fail hybrid search on embedding failure

This commit is contained in:
Louis Dureuil
2024-03-28 11:50:53 +01:00
parent fabc9cf14a
commit 6ebb6b55a6
11 changed files with 237 additions and 203 deletions

View File

@ -1,4 +1,5 @@
use std::fmt;
use std::sync::Arc;
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
use once_cell::sync::Lazy;
@ -8,7 +9,7 @@ pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FAC
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::{execute_vector_search, PartialSearchResult};
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::DistributionShift;
use crate::vector::Embedder;
use crate::{
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Index, Result,
SearchContext, TimeBudget,
@ -24,9 +25,15 @@ mod fst_utils;
pub mod hybrid;
pub mod new;
#[derive(Debug, Clone)]
pub struct SemanticSearch {
vector: Option<Vec<f32>>,
embedder_name: String,
embedder: Arc<Embedder>,
}
pub struct Search<'a> {
query: Option<String>,
vector: Option<Vec<f32>>,
// this should be linked to the String in the query
filter: Option<Filter<'a>>,
offset: usize,
@ -38,12 +45,9 @@ pub struct Search<'a> {
scoring_strategy: ScoringStrategy,
words_limit: usize,
exhaustive_number_hits: bool,
/// TODO: Add semantic ratio or pass it directly to execute_hybrid()
rtxn: &'a heed::RoTxn<'a>,
index: &'a Index,
distribution_shift: Option<DistributionShift>,
embedder_name: Option<String>,
semantic: Option<SemanticSearch>,
time_budget: TimeBudget,
}
@ -51,7 +55,6 @@ impl<'a> Search<'a> {
pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
Search {
query: None,
vector: None,
filter: None,
offset: 0,
limit: 20,
@ -64,8 +67,7 @@ impl<'a> Search<'a> {
words_limit: 10,
rtxn,
index,
distribution_shift: None,
embedder_name: None,
semantic: None,
time_budget: TimeBudget::max(),
}
}
@ -75,8 +77,13 @@ impl<'a> Search<'a> {
self
}
pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> {
self.vector = Some(vector);
pub fn semantic(
&mut self,
embedder_name: String,
embedder: Arc<Embedder>,
vector: Option<Vec<f32>>,
) -> &mut Search<'a> {
self.semantic = Some(SemanticSearch { embedder_name, embedder, vector });
self
}
@ -133,19 +140,6 @@ impl<'a> Search<'a> {
self
}
pub fn distribution_shift(
&mut self,
distribution_shift: Option<DistributionShift>,
) -> &mut Search<'a> {
self.distribution_shift = distribution_shift;
self
}
pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
self.embedder_name = Some(embedder_name.into());
self
}
pub fn time_budget(&mut self, time_budget: TimeBudget) -> &mut Search<'a> {
self.time_budget = time_budget;
self
@ -161,15 +155,6 @@ impl<'a> Search<'a> {
}
pub fn execute(&self) -> Result<SearchResult> {
let embedder_name;
let embedder_name = match &self.embedder_name {
Some(embedder_name) => embedder_name,
None => {
embedder_name = self.index.default_embedding_name(self.rtxn)?;
&embedder_name
}
};
let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes {
@ -184,21 +169,23 @@ impl<'a> Search<'a> {
document_scores,
degraded,
used_negative_operator,
} = match self.vector.as_ref() {
Some(vector) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
self.distribution_shift,
embedder_name,
self.time_budget.clone(),
)?,
None => execute_search(
} = match self.semantic.as_ref() {
Some(SemanticSearch { vector: Some(vector), embedder_name, embedder }) => {
execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
embedder_name,
embedder,
self.time_budget.clone(),
)?
}
_ => execute_search(
&mut ctx,
self.query.as_deref(),
self.terms_matching_strategy,
@ -237,7 +224,6 @@ impl fmt::Debug for Search<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let Search {
query,
vector: _,
filter,
offset,
limit,
@ -250,8 +236,7 @@ impl fmt::Debug for Search<'_> {
exhaustive_number_hits,
rtxn: _,
index: _,
distribution_shift,
embedder_name,
semantic,
time_budget,
} = self;
f.debug_struct("Search")
@ -266,8 +251,10 @@ impl fmt::Debug for Search<'_> {
.field("scoring_strategy", scoring_strategy)
.field("exhaustive_number_hits", exhaustive_number_hits)
.field("words_limit", words_limit)
.field("distribution_shift", distribution_shift)
.field("embedder_name", embedder_name)
.field(
"semantic.embedder_name",
&semantic.as_ref().map(|semantic| &semantic.embedder_name),
)
.field("time_budget", time_budget)
.finish()
}