hybrid search uses semantic ratio, error handling

This commit is contained in:
Louis Dureuil
2023-12-14 12:42:37 +01:00
parent 1b7c164a55
commit 217105b7da
10 changed files with 89 additions and 316 deletions

View File

@@ -299,6 +299,7 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
MissingIndexUid , InvalidRequest , BAD_REQUEST ; MissingIndexUid , InvalidRequest , BAD_REQUEST ;
MissingMasterKey , Auth , UNAUTHORIZED ; MissingMasterKey , Auth , UNAUTHORIZED ;
MissingPayload , InvalidRequest , BAD_REQUEST ; MissingPayload , InvalidRequest , BAD_REQUEST ;
MissingSearchHybrid , InvalidRequest , BAD_REQUEST ;
MissingSwapIndexes , InvalidRequest , BAD_REQUEST ; MissingSwapIndexes , InvalidRequest , BAD_REQUEST ;
MissingTaskFilters , InvalidRequest , BAD_REQUEST ; MissingTaskFilters , InvalidRequest , BAD_REQUEST ;
NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY; NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY;

View File

@@ -692,7 +692,7 @@ impl SearchAggregator {
ret.max_terms_number = q.split_whitespace().count(); ret.max_terms_number = q.split_whitespace().count();
} }
if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector { if let Some(ref vector) = vector {
ret.max_vector_size = vector.len(); ret.max_vector_size = vector.len();
} }

View File

@@ -51,6 +51,8 @@ pub enum MeilisearchHttpError {
DocumentFormat(#[from] DocumentFormatError), DocumentFormat(#[from] DocumentFormatError),
#[error(transparent)] #[error(transparent)]
Join(#[from] JoinError), Join(#[from] JoinError),
#[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")]
MissingSearchHybrid,
} }
impl ErrorCode for MeilisearchHttpError { impl ErrorCode for MeilisearchHttpError {
@@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError {
MeilisearchHttpError::FileStore(_) => Code::Internal, MeilisearchHttpError::FileStore(_) => Code::Internal,
MeilisearchHttpError::DocumentFormat(e) => e.error_code(), MeilisearchHttpError::DocumentFormat(e) => e.error_code(),
MeilisearchHttpError::Join(_) => Code::Internal, MeilisearchHttpError::Join(_) => Code::Internal,
MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid,
} }
} }
} }

View File

@@ -7,7 +7,6 @@ use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::VectorQuery;
use serde_json::Value; use serde_json::Value;
use crate::analytics::{Analytics, FacetSearchAggregator}; use crate::analytics::{Analytics, FacetSearchAggregator};
@@ -121,7 +120,7 @@ impl From<FacetSearchQuery> for SearchQuery {
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
crop_marker: DEFAULT_CROP_MARKER(), crop_marker: DEFAULT_CROP_MARKER(),
matching_strategy, matching_strategy,
vector: vector.map(VectorQuery::Vector), vector,
attributes_to_search_on, attributes_to_search_on,
hybrid, hybrid,
} }

View File

@@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::{self, VectorQuery}; use meilisearch_types::milli;
use meilisearch_types::serde_cs::vec::CS; use meilisearch_types::serde_cs::vec::CS;
use serde_json::Value; use serde_json::Value;
@@ -128,7 +128,7 @@ impl From<SearchQueryGet> for SearchQuery {
Self { Self {
q: other.q, q: other.q,
vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), vector: other.vector.map(CS::into_inner),
offset: other.offset.0, offset: other.offset.0,
limit: other.limit.0, limit: other.limit.0,
page: other.page.as_deref().copied(), page: other.page.as_deref().copied(),
@@ -258,21 +258,13 @@ pub async fn embed(
index_scheduler: &IndexScheduler, index_scheduler: &IndexScheduler,
index: &milli::Index, index: &milli::Index,
) -> Result<(), ResponseError> { ) -> Result<(), ResponseError> {
match query.vector.take() { if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) =
Some(VectorQuery::String(prompt)) => { (&query.vector, &query.q, &query.hybrid)
{
let embedder_configs = index.embedding_configs(&index.read_txn()?)?; let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?; let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder_name = let embedder = if let Some(embedder_name) = embedder {
if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
&query.hybrid
{
Some(embedder)
} else {
None
};
let embedder = if let Some(embedder_name) = embedder_name {
embedders.get(embedder_name) embedders.get(embedder_name)
} else { } else {
embedders.get_default() embedders.get_default()
@@ -283,7 +275,7 @@ pub async fn embed(
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
.0; .0;
let embeddings = embedder let embeddings = embedder
.embed(vec![prompt]) .embed(vec![q.to_owned()])
.await .await
.map_err(milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
@@ -292,15 +284,11 @@ pub async fn embed(
if embeddings.iter().nth(1).is_some() { if embeddings.iter().nth(1).is_some() {
warn!("Ignoring embeddings past the first one in long search query"); warn!("Ignoring embeddings past the first one in long search query");
query.vector = query.vector = Some(embeddings.iter().next().unwrap().to_vec());
Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
} else { } else {
query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); query.vector = Some(embeddings.into_inner());
} }
} }
Some(vector) => query.vector = Some(vector),
None => {}
};
Ok(()) Ok(())
} }

View File

@@ -7,14 +7,13 @@ use deserr::Deserr;
use either::Either; use either::Either;
use index_scheduler::RoFeatures; use index_scheduler::RoFeatures;
use indexmap::IndexMap; use indexmap::IndexMap;
use log::warn;
use meilisearch_auth::IndexSearchRules; use meilisearch_auth::IndexSearchRules;
use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::heed::RoTxn; use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery}; use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document}; use meilisearch_types::{milli, Document};
use milli::tokenizer::TokenizerBuilder; use milli::tokenizer::TokenizerBuilder;
@@ -44,7 +43,7 @@ pub struct SearchQuery {
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<milli::VectorQuery>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>, pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
@@ -105,6 +104,8 @@ impl std::convert::TryFrom<f32> for SemanticRatio {
type Error = InvalidSearchSemanticRatio; type Error = InvalidSearchSemanticRatio;
fn try_from(f: f32) -> Result<Self, Self::Error> { fn try_from(f: f32) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 { if f > 1.0 || f < 0.0 {
Err(InvalidSearchSemanticRatio) Err(InvalidSearchSemanticRatio)
} else { } else {
@@ -139,7 +140,7 @@ pub struct SearchQueryWithIndex {
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub vector: Option<VectorQuery>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>, pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
@@ -376,8 +377,16 @@ fn prepare_search<'t>(
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
let mut search = index.search(rtxn); let mut search = index.search(rtxn);
if query.vector.is_some() && query.q.is_some() { if query.vector.is_some() {
warn!("Attempting hybrid search"); features.check_vector("Passing `vector` as a query parameter")?;
}
if query.hybrid.is_some() {
features.check_vector("Passing `hybrid` as a query parameter")?;
}
if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
return Err(MeilisearchHttpError::MissingSearchHybrid);
} }
if let Some(ref vector) = query.vector { if let Some(ref vector) = query.vector {
@@ -385,14 +394,9 @@ fn prepare_search<'t>(
// If semantic ratio is 0.0, only the query search will impact the search results, // If semantic ratio is 0.0, only the query search will impact the search results,
// skip the vector // skip the vector
Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
_otherwise => match vector { _otherwise => {
VectorQuery::Vector(vector) => {
search.vector(vector.clone()); search.vector(vector.clone());
} }
VectorQuery::String(_) => {
panic!("Failed while preparing search; caller did not generate embedding for query")
}
},
} }
} }
@@ -431,10 +435,6 @@ fn prepare_search<'t>(
features.check_score_details()?; features.check_score_details()?;
} }
if query.vector.is_some() {
features.check_vector("Passing `vector` as a query parameter")?;
}
if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
search.embedder_name(embedder); search.embedder_name(embedder);
} }
@@ -492,7 +492,7 @@ pub fn perform_search(
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
match &query.hybrid { match &query.hybrid {
Some(hybrid) => match *hybrid.semantic_ratio { Some(hybrid) => match *hybrid.semantic_ratio {
0.0 | 1.0 => search.execute()?, ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
ratio => search.execute_hybrid(ratio)?, ratio => search.execute_hybrid(ratio)?,
}, },
None => search.execute()?, None => search.execute()?,
@@ -700,10 +700,7 @@ pub fn perform_search(
hits: documents, hits: documents,
hits_info, hits_info,
query: query.q.unwrap_or_default(), query: query.q.unwrap_or_default(),
vector: match query.vector { vector: query.vector,
Some(VectorQuery::Vector(vector)) => Some(vector),
_ => None,
},
processing_time_ms: before_search.elapsed().as_millis(), processing_time_ms: before_search.elapsed().as_millis(),
facet_distribution, facet_distribution,
facet_stats, facet_stats,

View File

@@ -1,4 +1,4 @@
use meili_snap::{json_string, snapshot}; use meili_snap::snapshot;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use crate::common::index::Index; use crate::common::index::Index;

View File

@@ -59,7 +59,7 @@ pub use self::index::Index;
pub use self::search::{ pub use self::search::{
FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder, FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder,
MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy, MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy,
VectorQuery, DEFAULT_VALUES_PER_FACET, DEFAULT_VALUES_PER_FACET,
}; };
pub type Result<T> = std::result::Result<T, error::Error>; pub type Result<T> = std::result::Result<T, error::Error>;

View File

@@ -1,49 +1,37 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use std::collections::HashMap;
use itertools::Itertools; use itertools::Itertools;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use super::new::{execute_vector_search, PartialSearchResult};
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::{ use crate::{MatchingWords, Result, Search, SearchResult};
execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult,
};
struct CombinedSearchResult { struct ScoreWithRatioResult {
matching_words: MatchingWords, matching_words: MatchingWords,
candidates: RoaringBitmap, candidates: RoaringBitmap,
document_scores: Vec<(u32, CombinedScore)>, document_scores: Vec<(u32, ScoreWithRatio)>,
} }
type CombinedScore = (Vec<ScoreDetails>, Option<Vec<ScoreDetails>>); type ScoreWithRatio = (Vec<ScoreDetails>, f32);
fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { fn compare_scores(
let mut left_main_it = ScoreDetails::score_values(left.0.iter()); &(ref left_scores, left_ratio): &ScoreWithRatio,
let mut left_sub_it = &(ref right_scores, right_ratio): &ScoreWithRatio,
ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten()); ) -> Ordering {
let mut left_it = ScoreDetails::score_values(left_scores.iter());
let mut right_main_it = ScoreDetails::score_values(right.0.iter()); let mut right_it = ScoreDetails::score_values(right_scores.iter());
let mut right_sub_it =
ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten());
let mut left_main = left_main_it.next();
let mut left_sub = left_sub_it.next();
let mut right_main = right_main_it.next();
let mut right_sub = right_sub_it.next();
loop { loop {
let left = let left = left_it.next();
take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it); let right = right_it.next();
let right =
take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it);
match (left, right) { match (left, right) {
(None, None) => return Ordering::Equal, (None, None) => return Ordering::Equal,
(None, Some(_)) => return Ordering::Less, (None, Some(_)) => return Ordering::Less,
(Some(_), None) => return Ordering::Greater, (Some(_), None) => return Ordering::Greater,
(Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
let left = left * left_ratio as f64;
let right = right * right_ratio as f64;
if (left - right).abs() <= f64::EPSILON { if (left - right).abs() <= f64::EPSILON {
continue; continue;
} }
@@ -72,94 +60,17 @@ fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering {
} }
} }
fn take_best_score<'a>( impl ScoreWithRatioResult {
main_score: &mut Option<ScoreValue<'a>>, fn new(results: SearchResult, ratio: f32) -> Self {
sub_score: &mut Option<ScoreValue<'a>>, let document_scores = results
main_it: &mut impl Iterator<Item = ScoreValue<'a>>,
sub_it: &mut impl Iterator<Item = ScoreValue<'a>>,
) -> Option<ScoreValue<'a>> {
match (*main_score, *sub_score) {
(Some(main), None) => {
*main_score = main_it.next();
Some(main)
}
(None, Some(sub)) => {
*sub_score = sub_it.next();
Some(sub)
}
(main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => {
// take max, both advance
*main_score = main_it.next();
*sub_score = sub_it.next();
if main_f >= sub_v {
main
} else {
sub
}
}
(main @ Some(ScoreValue::Score(_)), _) => {
*main_score = main_it.next();
main
}
(_, sub @ Some(ScoreValue::Score(_))) => {
*sub_score = sub_it.next();
sub
}
(main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => {
// take best advance both
*main_score = main_it.next();
*sub_score = sub_it.next();
if main_geo >= sub_geo {
main
} else {
sub
}
}
(main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => {
// take best advance both
*main_score = main_it.next();
*sub_score = sub_it.next();
if main_sort >= sub_sort {
main
} else {
sub
}
}
(
Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
) => None,
(None, None) => None,
}
}
impl CombinedSearchResult {
fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self {
let mut docid_scores = HashMap::new();
for (docid, score) in
main_results.documents_ids.iter().zip(main_results.document_scores.into_iter())
{
docid_scores.insert(*docid, (score, None));
}
for (docid, score) in ancillary_results
.documents_ids .documents_ids
.iter() .into_iter()
.zip(ancillary_results.document_scores.into_iter()) .zip(results.document_scores.into_iter().map(|scores| (scores, ratio)))
{ .collect();
docid_scores
.entry(*docid)
.and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score));
}
let mut document_scores: Vec<_> = docid_scores.into_iter().collect();
document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse());
Self { Self {
matching_words: main_results.matching_words, matching_words: results.matching_words,
candidates: main_results.candidates, candidates: results.candidates,
document_scores, document_scores,
} }
} }
@@ -200,7 +111,7 @@ impl CombinedSearchResult {
} }
impl<'a> Search<'a> { impl<'a> Search<'a> {
pub fn execute_hybrid(&self) -> Result<SearchResult> { pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> {
// TODO: find classier way to achieve that than to reset vector and query params // TODO: find classier way to achieve that than to reset vector and query params
// create separate keyword and semantic searches // create separate keyword and semantic searches
let mut search = Search { let mut search = Search {
@@ -223,8 +134,6 @@ impl<'a> Search<'a> {
}; };
let vector_query = search.vector.take(); let vector_query = search.vector.take();
let keyword_query = self.query.as_deref();
let keyword_results = search.execute()?; let keyword_results = search.execute()?;
// skip semantic search if we don't have a vector query (placeholder search) // skip semantic search if we don't have a vector query (placeholder search)
@@ -233,7 +142,7 @@ impl<'a> Search<'a> {
}; };
// completely skip semantic search if the results of the keyword search are good enough // completely skip semantic search if the results of the keyword search are good enough
if self.results_good_enough(&keyword_results) { if self.results_good_enough(&keyword_results, semantic_ratio) {
return Ok(keyword_results); return Ok(keyword_results);
} }
@@ -243,94 +152,18 @@ impl<'a> Search<'a> {
// TODO: would be better to have two distinct functions at this point // TODO: would be better to have two distinct functions at this point
let vector_results = search.execute()?; let vector_results = search.execute()?;
// Compute keyword scores for vector_results let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
let keyword_results_for_vector = let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
self.keyword_results_for_vector(keyword_query, &vector_results)?;
// compute vector scores for keyword_results
let vector_results_for_keyword =
// can unwrap because we returned already if there was no vector query
self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?;
/// TODO apply sementic ratio
let keyword_results =
CombinedSearchResult::new(keyword_results, vector_results_for_keyword);
let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector);
let merge_results = let merge_results =
CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit); ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit);
assert!(merge_results.documents_ids.len() <= self.limit); assert!(merge_results.documents_ids.len() <= self.limit);
Ok(merge_results) Ok(merge_results)
} }
fn vector_results_for_keyword( fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool {
&self, // A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5
vector: &[f32], const GOOD_ENOUGH_SCORE: f64 = 0.45;
keyword_results: &SearchResult,
) -> Result<PartialSearchResult> {
let embedder_name;
let embedder_name = match &self.embedder_name {
Some(embedder_name) => embedder_name,
None => {
embedder_name = self.index.default_embedding_name(self.rtxn)?;
&embedder_name
}
};
let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes {
ctx.searchable_attributes(searchable_attributes)?;
}
let universe = keyword_results.documents_ids.iter().collect();
execute_vector_search(
&mut ctx,
vector,
ScoringStrategy::Detailed,
universe,
&self.sort_criteria,
self.geo_strategy,
0,
self.limit + self.offset,
self.distribution_shift,
embedder_name,
)
}
fn keyword_results_for_vector(
&self,
query: Option<&str>,
vector_results: &SearchResult,
) -> Result<PartialSearchResult> {
let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes {
ctx.searchable_attributes(searchable_attributes)?;
}
let universe = vector_results.documents_ids.iter().collect();
execute_search(
&mut ctx,
query,
self.terms_matching_strategy,
ScoringStrategy::Detailed,
self.exhaustive_number_hits,
universe,
&self.sort_criteria,
self.geo_strategy,
0,
self.limit + self.offset,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)
}
fn results_good_enough(&self, keyword_results: &SearchResult) -> bool {
const GOOD_ENOUGH_SCORE: f64 = 0.9;
// 1. we check that we got a sufficient number of results // 1. we check that we got a sufficient number of results
if keyword_results.document_scores.len() < self.limit + self.offset { if keyword_results.document_scores.len() < self.limit + self.offset {
@@ -341,7 +174,7 @@ impl<'a> Search<'a> {
// we need to check all results because due to sort like rules, they're not necessarily in relevancy order // we need to check all results because due to sort like rules, they're not necessarily in relevancy order
for score in &keyword_results.document_scores { for score in &keyword_results.document_scores {
let score = ScoreDetails::global_score(score.iter()); let score = ScoreDetails::global_score(score.iter());
if score < GOOD_ENOUGH_SCORE { if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE {
return false; return false;
} }
} }

View File

@@ -3,7 +3,6 @@ use std::ops::ControlFlow;
use charabia::normalizer::NormalizerOption; use charabia::normalizer::NormalizerOption;
use charabia::Normalize; use charabia::Normalize;
use deserr::{DeserializeError, Deserr, Sequence};
use fst::automaton::{Automaton, Str}; use fst::automaton::{Automaton, Str};
use fst::{IntoStreamer, Streamer}; use fst::{IntoStreamer, Streamer};
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
@@ -57,53 +56,6 @@ pub struct Search<'a> {
embedder_name: Option<String>, embedder_name: Option<String>,
} }
#[derive(Debug, Clone, PartialEq)]
pub enum VectorQuery {
Vector(Vec<f32>),
String(String),
}
impl<E> Deserr<E> for VectorQuery
where
E: DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> std::result::Result<Self, E> {
match value {
deserr::Value::String(s) => Ok(VectorQuery::String(s)),
deserr::Value::Sequence(seq) => {
let v: std::result::Result<Vec<f32>, _> = seq
.into_iter()
.enumerate()
.map(|(index, v)| match v.into_value() {
deserr::Value::Float(f) => Ok(f as f32),
deserr::Value::Integer(i) => Ok(i as f32),
v => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: v,
accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer],
},
location.push_index(index),
))),
})
.collect();
Ok(VectorQuery::Vector(v?))
}
_ => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence],
},
location,
))),
}
}
}
impl<'a> Search<'a> { impl<'a> Search<'a> {
pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
Search { Search {