mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 21:16:28 +00:00 
			
		
		
		
	hybrid search uses semantic ratio, error handling
This commit is contained in:
		| @@ -299,6 +299,7 @@ MissingFacetSearchFacetName           , InvalidRequest       , BAD_REQUEST ; | |||||||
| MissingIndexUid                       , InvalidRequest       , BAD_REQUEST ; | MissingIndexUid                       , InvalidRequest       , BAD_REQUEST ; | ||||||
| MissingMasterKey                      , Auth                 , UNAUTHORIZED ; | MissingMasterKey                      , Auth                 , UNAUTHORIZED ; | ||||||
| MissingPayload                        , InvalidRequest       , BAD_REQUEST ; | MissingPayload                        , InvalidRequest       , BAD_REQUEST ; | ||||||
|  | MissingSearchHybrid                   , InvalidRequest       , BAD_REQUEST ; | ||||||
| MissingSwapIndexes                    , InvalidRequest       , BAD_REQUEST ; | MissingSwapIndexes                    , InvalidRequest       , BAD_REQUEST ; | ||||||
| MissingTaskFilters                    , InvalidRequest       , BAD_REQUEST ; | MissingTaskFilters                    , InvalidRequest       , BAD_REQUEST ; | ||||||
| NoSpaceLeftOnDevice                   , System               , UNPROCESSABLE_ENTITY; | NoSpaceLeftOnDevice                   , System               , UNPROCESSABLE_ENTITY; | ||||||
|   | |||||||
| @@ -692,7 +692,7 @@ impl SearchAggregator { | |||||||
|             ret.max_terms_number = q.split_whitespace().count(); |             ret.max_terms_number = q.split_whitespace().count(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector { |         if let Some(ref vector) = vector { | ||||||
|             ret.max_vector_size = vector.len(); |             ret.max_vector_size = vector.len(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -51,6 +51,8 @@ pub enum MeilisearchHttpError { | |||||||
|     DocumentFormat(#[from] DocumentFormatError), |     DocumentFormat(#[from] DocumentFormatError), | ||||||
|     #[error(transparent)] |     #[error(transparent)] | ||||||
|     Join(#[from] JoinError), |     Join(#[from] JoinError), | ||||||
|  |     #[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")] | ||||||
|  |     MissingSearchHybrid, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl ErrorCode for MeilisearchHttpError { | impl ErrorCode for MeilisearchHttpError { | ||||||
| @@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError { | |||||||
|             MeilisearchHttpError::FileStore(_) => Code::Internal, |             MeilisearchHttpError::FileStore(_) => Code::Internal, | ||||||
|             MeilisearchHttpError::DocumentFormat(e) => e.error_code(), |             MeilisearchHttpError::DocumentFormat(e) => e.error_code(), | ||||||
|             MeilisearchHttpError::Join(_) => Code::Internal, |             MeilisearchHttpError::Join(_) => Code::Internal, | ||||||
|  |             MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -7,7 +7,6 @@ use meilisearch_types::deserr::DeserrJsonError; | |||||||
| use meilisearch_types::error::deserr_codes::*; | use meilisearch_types::error::deserr_codes::*; | ||||||
| use meilisearch_types::error::ResponseError; | use meilisearch_types::error::ResponseError; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
| use meilisearch_types::milli::VectorQuery; |  | ||||||
| use serde_json::Value; | use serde_json::Value; | ||||||
|  |  | ||||||
| use crate::analytics::{Analytics, FacetSearchAggregator}; | use crate::analytics::{Analytics, FacetSearchAggregator}; | ||||||
| @@ -121,7 +120,7 @@ impl From<FacetSearchQuery> for SearchQuery { | |||||||
|             highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), |             highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), | ||||||
|             crop_marker: DEFAULT_CROP_MARKER(), |             crop_marker: DEFAULT_CROP_MARKER(), | ||||||
|             matching_strategy, |             matching_strategy, | ||||||
|             vector: vector.map(VectorQuery::Vector), |             vector, | ||||||
|             attributes_to_search_on, |             attributes_to_search_on, | ||||||
|             hybrid, |             hybrid, | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; | |||||||
| use meilisearch_types::error::deserr_codes::*; | use meilisearch_types::error::deserr_codes::*; | ||||||
| use meilisearch_types::error::ResponseError; | use meilisearch_types::error::ResponseError; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
| use meilisearch_types::milli::{self, VectorQuery}; | use meilisearch_types::milli; | ||||||
| use meilisearch_types::serde_cs::vec::CS; | use meilisearch_types::serde_cs::vec::CS; | ||||||
| use serde_json::Value; | use serde_json::Value; | ||||||
|  |  | ||||||
| @@ -128,7 +128,7 @@ impl From<SearchQueryGet> for SearchQuery { | |||||||
|  |  | ||||||
|         Self { |         Self { | ||||||
|             q: other.q, |             q: other.q, | ||||||
|             vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), |             vector: other.vector.map(CS::into_inner), | ||||||
|             offset: other.offset.0, |             offset: other.offset.0, | ||||||
|             limit: other.limit.0, |             limit: other.limit.0, | ||||||
|             page: other.page.as_deref().copied(), |             page: other.page.as_deref().copied(), | ||||||
| @@ -258,21 +258,13 @@ pub async fn embed( | |||||||
|     index_scheduler: &IndexScheduler, |     index_scheduler: &IndexScheduler, | ||||||
|     index: &milli::Index, |     index: &milli::Index, | ||||||
| ) -> Result<(), ResponseError> { | ) -> Result<(), ResponseError> { | ||||||
|     match query.vector.take() { |     if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) = | ||||||
|         Some(VectorQuery::String(prompt)) => { |         (&query.vector, &query.q, &query.hybrid) | ||||||
|  |     { | ||||||
|         let embedder_configs = index.embedding_configs(&index.read_txn()?)?; |         let embedder_configs = index.embedding_configs(&index.read_txn()?)?; | ||||||
|         let embedders = index_scheduler.embedders(embedder_configs)?; |         let embedders = index_scheduler.embedders(embedder_configs)?; | ||||||
|  |  | ||||||
|             let embedder_name = |         let embedder = if let Some(embedder_name) = embedder { | ||||||
|                 if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) = |  | ||||||
|                     &query.hybrid |  | ||||||
|                 { |  | ||||||
|                     Some(embedder) |  | ||||||
|                 } else { |  | ||||||
|                     None |  | ||||||
|                 }; |  | ||||||
|  |  | ||||||
|             let embedder = if let Some(embedder_name) = embedder_name { |  | ||||||
|             embedders.get(embedder_name) |             embedders.get(embedder_name) | ||||||
|         } else { |         } else { | ||||||
|             embedders.get_default() |             embedders.get_default() | ||||||
| @@ -283,7 +275,7 @@ pub async fn embed( | |||||||
|             .map_err(milli::Error::from)? |             .map_err(milli::Error::from)? | ||||||
|             .0; |             .0; | ||||||
|         let embeddings = embedder |         let embeddings = embedder | ||||||
|                 .embed(vec![prompt]) |             .embed(vec![q.to_owned()]) | ||||||
|             .await |             .await | ||||||
|             .map_err(milli::vector::Error::from) |             .map_err(milli::vector::Error::from) | ||||||
|             .map_err(milli::Error::from)? |             .map_err(milli::Error::from)? | ||||||
| @@ -292,15 +284,11 @@ pub async fn embed( | |||||||
|  |  | ||||||
|         if embeddings.iter().nth(1).is_some() { |         if embeddings.iter().nth(1).is_some() { | ||||||
|             warn!("Ignoring embeddings past the first one in long search query"); |             warn!("Ignoring embeddings past the first one in long search query"); | ||||||
|                 query.vector = |             query.vector = Some(embeddings.iter().next().unwrap().to_vec()); | ||||||
|                     Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); |  | ||||||
|         } else { |         } else { | ||||||
|                 query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); |             query.vector = Some(embeddings.into_inner()); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|         Some(vector) => query.vector = Some(vector), |  | ||||||
|         None => {} |  | ||||||
|     }; |  | ||||||
|     Ok(()) |     Ok(()) | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,14 +7,13 @@ use deserr::Deserr; | |||||||
| use either::Either; | use either::Either; | ||||||
| use index_scheduler::RoFeatures; | use index_scheduler::RoFeatures; | ||||||
| use indexmap::IndexMap; | use indexmap::IndexMap; | ||||||
| use log::warn; |  | ||||||
| use meilisearch_auth::IndexSearchRules; | use meilisearch_auth::IndexSearchRules; | ||||||
| use meilisearch_types::deserr::DeserrJsonError; | use meilisearch_types::deserr::DeserrJsonError; | ||||||
| use meilisearch_types::error::deserr_codes::*; | use meilisearch_types::error::deserr_codes::*; | ||||||
| use meilisearch_types::heed::RoTxn; | use meilisearch_types::heed::RoTxn; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
| use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; | use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; | ||||||
| use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery}; | use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues}; | ||||||
| use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; | use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; | ||||||
| use meilisearch_types::{milli, Document}; | use meilisearch_types::{milli, Document}; | ||||||
| use milli::tokenizer::TokenizerBuilder; | use milli::tokenizer::TokenizerBuilder; | ||||||
| @@ -44,7 +43,7 @@ pub struct SearchQuery { | |||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] | ||||||
|     pub q: Option<String>, |     pub q: Option<String>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] | ||||||
|     pub vector: Option<milli::VectorQuery>, |     pub vector: Option<Vec<f32>>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] |     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] | ||||||
|     pub hybrid: Option<HybridQuery>, |     pub hybrid: Option<HybridQuery>, | ||||||
|     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] |     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] | ||||||
| @@ -105,6 +104,8 @@ impl std::convert::TryFrom<f32> for SemanticRatio { | |||||||
|     type Error = InvalidSearchSemanticRatio; |     type Error = InvalidSearchSemanticRatio; | ||||||
|  |  | ||||||
|     fn try_from(f: f32) -> Result<Self, Self::Error> { |     fn try_from(f: f32) -> Result<Self, Self::Error> { | ||||||
|  |         // the suggested "fix" is: `!(0.0..=1.0).contains(&f)` which is allegedly less readable | ||||||
|  |         #[allow(clippy::manual_range_contains)] | ||||||
|         if f > 1.0 || f < 0.0 { |         if f > 1.0 || f < 0.0 { | ||||||
|             Err(InvalidSearchSemanticRatio) |             Err(InvalidSearchSemanticRatio) | ||||||
|         } else { |         } else { | ||||||
| @@ -139,7 +140,7 @@ pub struct SearchQueryWithIndex { | |||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] | ||||||
|     pub q: Option<String>, |     pub q: Option<String>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] | ||||||
|     pub vector: Option<VectorQuery>, |     pub vector: Option<Vec<f32>>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] |     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] | ||||||
|     pub hybrid: Option<HybridQuery>, |     pub hybrid: Option<HybridQuery>, | ||||||
|     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] |     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] | ||||||
| @@ -376,8 +377,16 @@ fn prepare_search<'t>( | |||||||
| ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { | ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { | ||||||
|     let mut search = index.search(rtxn); |     let mut search = index.search(rtxn); | ||||||
|  |  | ||||||
|     if query.vector.is_some() && query.q.is_some() { |     if query.vector.is_some() { | ||||||
|         warn!("Attempting hybrid search"); |         features.check_vector("Passing `vector` as a query parameter")?; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if query.hybrid.is_some() { | ||||||
|  |         features.check_vector("Passing `hybrid` as a query parameter")?; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() { | ||||||
|  |         return Err(MeilisearchHttpError::MissingSearchHybrid); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if let Some(ref vector) = query.vector { |     if let Some(ref vector) = query.vector { | ||||||
| @@ -385,14 +394,9 @@ fn prepare_search<'t>( | |||||||
|             // If semantic ratio is 0.0, only the query search will impact the search results, |             // If semantic ratio is 0.0, only the query search will impact the search results, | ||||||
|             // skip the vector |             // skip the vector | ||||||
|             Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), |             Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), | ||||||
|             _otherwise => match vector { |             _otherwise => { | ||||||
|                 VectorQuery::Vector(vector) => { |  | ||||||
|                 search.vector(vector.clone()); |                 search.vector(vector.clone()); | ||||||
|             } |             } | ||||||
|                 VectorQuery::String(_) => { |  | ||||||
|                     panic!("Failed while preparing search; caller did not generate embedding for query") |  | ||||||
|                 } |  | ||||||
|             }, |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -431,10 +435,6 @@ fn prepare_search<'t>( | |||||||
|         features.check_score_details()?; |         features.check_score_details()?; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if query.vector.is_some() { |  | ||||||
|         features.check_vector("Passing `vector` as a query parameter")?; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { |     if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { | ||||||
|         search.embedder_name(embedder); |         search.embedder_name(embedder); | ||||||
|     } |     } | ||||||
| @@ -492,7 +492,7 @@ pub fn perform_search( | |||||||
|     let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = |     let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = | ||||||
|         match &query.hybrid { |         match &query.hybrid { | ||||||
|             Some(hybrid) => match *hybrid.semantic_ratio { |             Some(hybrid) => match *hybrid.semantic_ratio { | ||||||
|                 0.0 | 1.0 => search.execute()?, |                 ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?, | ||||||
|                 ratio => search.execute_hybrid(ratio)?, |                 ratio => search.execute_hybrid(ratio)?, | ||||||
|             }, |             }, | ||||||
|             None => search.execute()?, |             None => search.execute()?, | ||||||
| @@ -700,10 +700,7 @@ pub fn perform_search( | |||||||
|         hits: documents, |         hits: documents, | ||||||
|         hits_info, |         hits_info, | ||||||
|         query: query.q.unwrap_or_default(), |         query: query.q.unwrap_or_default(), | ||||||
|         vector: match query.vector { |         vector: query.vector, | ||||||
|             Some(VectorQuery::Vector(vector)) => Some(vector), |  | ||||||
|             _ => None, |  | ||||||
|         }, |  | ||||||
|         processing_time_ms: before_search.elapsed().as_millis(), |         processing_time_ms: before_search.elapsed().as_millis(), | ||||||
|         facet_distribution, |         facet_distribution, | ||||||
|         facet_stats, |         facet_stats, | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| use meili_snap::{json_string, snapshot}; | use meili_snap::snapshot; | ||||||
| use once_cell::sync::Lazy; | use once_cell::sync::Lazy; | ||||||
|  |  | ||||||
| use crate::common::index::Index; | use crate::common::index::Index; | ||||||
|   | |||||||
| @@ -59,7 +59,7 @@ pub use self::index::Index; | |||||||
| pub use self::search::{ | pub use self::search::{ | ||||||
|     FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder, |     FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder, | ||||||
|     MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy, |     MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy, | ||||||
|     VectorQuery, DEFAULT_VALUES_PER_FACET, |     DEFAULT_VALUES_PER_FACET, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| pub type Result<T> = std::result::Result<T, error::Error>; | pub type Result<T> = std::result::Result<T, error::Error>; | ||||||
|   | |||||||
| @@ -1,49 +1,37 @@ | |||||||
| use std::cmp::Ordering; | use std::cmp::Ordering; | ||||||
| use std::collections::HashMap; |  | ||||||
|  |  | ||||||
| use itertools::Itertools; | use itertools::Itertools; | ||||||
| use roaring::RoaringBitmap; | use roaring::RoaringBitmap; | ||||||
|  |  | ||||||
| use super::new::{execute_vector_search, PartialSearchResult}; |  | ||||||
| use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; | use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; | ||||||
| use crate::{ | use crate::{MatchingWords, Result, Search, SearchResult}; | ||||||
|     execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| struct CombinedSearchResult { | struct ScoreWithRatioResult { | ||||||
|     matching_words: MatchingWords, |     matching_words: MatchingWords, | ||||||
|     candidates: RoaringBitmap, |     candidates: RoaringBitmap, | ||||||
|     document_scores: Vec<(u32, CombinedScore)>, |     document_scores: Vec<(u32, ScoreWithRatio)>, | ||||||
| } | } | ||||||
|  |  | ||||||
| type CombinedScore = (Vec<ScoreDetails>, Option<Vec<ScoreDetails>>); | type ScoreWithRatio = (Vec<ScoreDetails>, f32); | ||||||
|  |  | ||||||
| fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { | fn compare_scores( | ||||||
|     let mut left_main_it = ScoreDetails::score_values(left.0.iter()); |     &(ref left_scores, left_ratio): &ScoreWithRatio, | ||||||
|     let mut left_sub_it = |     &(ref right_scores, right_ratio): &ScoreWithRatio, | ||||||
|         ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten()); | ) -> Ordering { | ||||||
|  |     let mut left_it = ScoreDetails::score_values(left_scores.iter()); | ||||||
|     let mut right_main_it = ScoreDetails::score_values(right.0.iter()); |     let mut right_it = ScoreDetails::score_values(right_scores.iter()); | ||||||
|     let mut right_sub_it = |  | ||||||
|         ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten()); |  | ||||||
|  |  | ||||||
|     let mut left_main = left_main_it.next(); |  | ||||||
|     let mut left_sub = left_sub_it.next(); |  | ||||||
|     let mut right_main = right_main_it.next(); |  | ||||||
|     let mut right_sub = right_sub_it.next(); |  | ||||||
|  |  | ||||||
|     loop { |     loop { | ||||||
|         let left = |         let left = left_it.next(); | ||||||
|             take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it); |         let right = right_it.next(); | ||||||
|  |  | ||||||
|         let right = |  | ||||||
|             take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it); |  | ||||||
|  |  | ||||||
|         match (left, right) { |         match (left, right) { | ||||||
|             (None, None) => return Ordering::Equal, |             (None, None) => return Ordering::Equal, | ||||||
|             (None, Some(_)) => return Ordering::Less, |             (None, Some(_)) => return Ordering::Less, | ||||||
|             (Some(_), None) => return Ordering::Greater, |             (Some(_), None) => return Ordering::Greater, | ||||||
|             (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { |             (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { | ||||||
|  |                 let left = left * left_ratio as f64; | ||||||
|  |                 let right = right * right_ratio as f64; | ||||||
|                 if (left - right).abs() <= f64::EPSILON { |                 if (left - right).abs() <= f64::EPSILON { | ||||||
|                     continue; |                     continue; | ||||||
|                 } |                 } | ||||||
| @@ -72,94 +60,17 @@ fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| fn take_best_score<'a>( | impl ScoreWithRatioResult { | ||||||
|     main_score: &mut Option<ScoreValue<'a>>, |     fn new(results: SearchResult, ratio: f32) -> Self { | ||||||
|     sub_score: &mut Option<ScoreValue<'a>>, |         let document_scores = results | ||||||
|     main_it: &mut impl Iterator<Item = ScoreValue<'a>>, |  | ||||||
|     sub_it: &mut impl Iterator<Item = ScoreValue<'a>>, |  | ||||||
| ) -> Option<ScoreValue<'a>> { |  | ||||||
|     match (*main_score, *sub_score) { |  | ||||||
|         (Some(main), None) => { |  | ||||||
|             *main_score = main_it.next(); |  | ||||||
|             Some(main) |  | ||||||
|         } |  | ||||||
|         (None, Some(sub)) => { |  | ||||||
|             *sub_score = sub_it.next(); |  | ||||||
|             Some(sub) |  | ||||||
|         } |  | ||||||
|         (main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => { |  | ||||||
|             // take max, both advance |  | ||||||
|             *main_score = main_it.next(); |  | ||||||
|             *sub_score = sub_it.next(); |  | ||||||
|             if main_f >= sub_v { |  | ||||||
|                 main |  | ||||||
|             } else { |  | ||||||
|                 sub |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         (main @ Some(ScoreValue::Score(_)), _) => { |  | ||||||
|             *main_score = main_it.next(); |  | ||||||
|             main |  | ||||||
|         } |  | ||||||
|         (_, sub @ Some(ScoreValue::Score(_))) => { |  | ||||||
|             *sub_score = sub_it.next(); |  | ||||||
|             sub |  | ||||||
|         } |  | ||||||
|         (main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => { |  | ||||||
|             // take best advance both |  | ||||||
|             *main_score = main_it.next(); |  | ||||||
|             *sub_score = sub_it.next(); |  | ||||||
|             if main_geo >= sub_geo { |  | ||||||
|                 main |  | ||||||
|             } else { |  | ||||||
|                 sub |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         (main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => { |  | ||||||
|             // take best advance both |  | ||||||
|             *main_score = main_it.next(); |  | ||||||
|             *sub_score = sub_it.next(); |  | ||||||
|             if main_sort >= sub_sort { |  | ||||||
|                 main |  | ||||||
|             } else { |  | ||||||
|                 sub |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         ( |  | ||||||
|             Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), |  | ||||||
|             Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), |  | ||||||
|         ) => None, |  | ||||||
|  |  | ||||||
|         (None, None) => None, |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl CombinedSearchResult { |  | ||||||
|     fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self { |  | ||||||
|         let mut docid_scores = HashMap::new(); |  | ||||||
|         for (docid, score) in |  | ||||||
|             main_results.documents_ids.iter().zip(main_results.document_scores.into_iter()) |  | ||||||
|         { |  | ||||||
|             docid_scores.insert(*docid, (score, None)); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         for (docid, score) in ancillary_results |  | ||||||
|             .documents_ids |             .documents_ids | ||||||
|             .iter() |             .into_iter() | ||||||
|             .zip(ancillary_results.document_scores.into_iter()) |             .zip(results.document_scores.into_iter().map(|scores| (scores, ratio))) | ||||||
|         { |             .collect(); | ||||||
|             docid_scores |  | ||||||
|                 .entry(*docid) |  | ||||||
|                 .and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score)); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let mut document_scores: Vec<_> = docid_scores.into_iter().collect(); |  | ||||||
|  |  | ||||||
|         document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse()); |  | ||||||
|  |  | ||||||
|         Self { |         Self { | ||||||
|             matching_words: main_results.matching_words, |             matching_words: results.matching_words, | ||||||
|             candidates: main_results.candidates, |             candidates: results.candidates, | ||||||
|             document_scores, |             document_scores, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -200,7 +111,7 @@ impl CombinedSearchResult { | |||||||
| } | } | ||||||
|  |  | ||||||
| impl<'a> Search<'a> { | impl<'a> Search<'a> { | ||||||
|     pub fn execute_hybrid(&self) -> Result<SearchResult> { |     pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> { | ||||||
|         // TODO: find classier way to achieve that than to reset vector and query params |         // TODO: find classier way to achieve that than to reset vector and query params | ||||||
|         // create separate keyword and semantic searches |         // create separate keyword and semantic searches | ||||||
|         let mut search = Search { |         let mut search = Search { | ||||||
| @@ -223,8 +134,6 @@ impl<'a> Search<'a> { | |||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         let vector_query = search.vector.take(); |         let vector_query = search.vector.take(); | ||||||
|         let keyword_query = self.query.as_deref(); |  | ||||||
|  |  | ||||||
|         let keyword_results = search.execute()?; |         let keyword_results = search.execute()?; | ||||||
|  |  | ||||||
|         // skip semantic search if we don't have a vector query (placeholder search) |         // skip semantic search if we don't have a vector query (placeholder search) | ||||||
| @@ -233,7 +142,7 @@ impl<'a> Search<'a> { | |||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         // completely skip semantic search if the results of the keyword search are good enough |         // completely skip semantic search if the results of the keyword search are good enough | ||||||
|         if self.results_good_enough(&keyword_results) { |         if self.results_good_enough(&keyword_results, semantic_ratio) { | ||||||
|             return Ok(keyword_results); |             return Ok(keyword_results); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -243,94 +152,18 @@ impl<'a> Search<'a> { | |||||||
|         // TODO: would be better to have two distinct functions at this point |         // TODO: would be better to have two distinct functions at this point | ||||||
|         let vector_results = search.execute()?; |         let vector_results = search.execute()?; | ||||||
|  |  | ||||||
|         // Compute keyword scores for vector_results |         let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); | ||||||
|         let keyword_results_for_vector = |         let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); | ||||||
|             self.keyword_results_for_vector(keyword_query, &vector_results)?; |  | ||||||
|  |  | ||||||
|         // compute vector scores for keyword_results |  | ||||||
|         let vector_results_for_keyword = |  | ||||||
|             // can unwrap because we returned already if there was no vector query |  | ||||||
|             self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?; |  | ||||||
|  |  | ||||||
|         /// TODO apply semantic ratio |  | ||||||
|         let keyword_results = |  | ||||||
|             CombinedSearchResult::new(keyword_results, vector_results_for_keyword); |  | ||||||
|         let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector); |  | ||||||
|  |  | ||||||
|         let merge_results = |         let merge_results = | ||||||
|             CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit); |             ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); | ||||||
|         assert!(merge_results.documents_ids.len() <= self.limit); |         assert!(merge_results.documents_ids.len() <= self.limit); | ||||||
|         Ok(merge_results) |         Ok(merge_results) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn vector_results_for_keyword( |     fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool { | ||||||
|         &self, |         // A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5 | ||||||
|         vector: &[f32], |         const GOOD_ENOUGH_SCORE: f64 = 0.45; | ||||||
|         keyword_results: &SearchResult, |  | ||||||
|     ) -> Result<PartialSearchResult> { |  | ||||||
|         let embedder_name; |  | ||||||
|         let embedder_name = match &self.embedder_name { |  | ||||||
|             Some(embedder_name) => embedder_name, |  | ||||||
|             None => { |  | ||||||
|                 embedder_name = self.index.default_embedding_name(self.rtxn)?; |  | ||||||
|                 &embedder_name |  | ||||||
|             } |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         let mut ctx = SearchContext::new(self.index, self.rtxn); |  | ||||||
|  |  | ||||||
|         if let Some(searchable_attributes) = self.searchable_attributes { |  | ||||||
|             ctx.searchable_attributes(searchable_attributes)?; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let universe = keyword_results.documents_ids.iter().collect(); |  | ||||||
|  |  | ||||||
|         execute_vector_search( |  | ||||||
|             &mut ctx, |  | ||||||
|             vector, |  | ||||||
|             ScoringStrategy::Detailed, |  | ||||||
|             universe, |  | ||||||
|             &self.sort_criteria, |  | ||||||
|             self.geo_strategy, |  | ||||||
|             0, |  | ||||||
|             self.limit + self.offset, |  | ||||||
|             self.distribution_shift, |  | ||||||
|             embedder_name, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn keyword_results_for_vector( |  | ||||||
|         &self, |  | ||||||
|         query: Option<&str>, |  | ||||||
|         vector_results: &SearchResult, |  | ||||||
|     ) -> Result<PartialSearchResult> { |  | ||||||
|         let mut ctx = SearchContext::new(self.index, self.rtxn); |  | ||||||
|  |  | ||||||
|         if let Some(searchable_attributes) = self.searchable_attributes { |  | ||||||
|             ctx.searchable_attributes(searchable_attributes)?; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let universe = vector_results.documents_ids.iter().collect(); |  | ||||||
|  |  | ||||||
|         execute_search( |  | ||||||
|             &mut ctx, |  | ||||||
|             query, |  | ||||||
|             self.terms_matching_strategy, |  | ||||||
|             ScoringStrategy::Detailed, |  | ||||||
|             self.exhaustive_number_hits, |  | ||||||
|             universe, |  | ||||||
|             &self.sort_criteria, |  | ||||||
|             self.geo_strategy, |  | ||||||
|             0, |  | ||||||
|             self.limit + self.offset, |  | ||||||
|             Some(self.words_limit), |  | ||||||
|             &mut DefaultSearchLogger, |  | ||||||
|             &mut DefaultSearchLogger, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn results_good_enough(&self, keyword_results: &SearchResult) -> bool { |  | ||||||
|         const GOOD_ENOUGH_SCORE: f64 = 0.9; |  | ||||||
|  |  | ||||||
|         // 1. we check that we got a sufficient number of results |         // 1. we check that we got a sufficient number of results | ||||||
|         if keyword_results.document_scores.len() < self.limit + self.offset { |         if keyword_results.document_scores.len() < self.limit + self.offset { | ||||||
| @@ -341,7 +174,7 @@ impl<'a> Search<'a> { | |||||||
|         // we need to check all results because due to sort like rules, they're not necessarily in relevancy order |         // we need to check all results because due to sort like rules, they're not necessarily in relevancy order | ||||||
|         for score in &keyword_results.document_scores { |         for score in &keyword_results.document_scores { | ||||||
|             let score = ScoreDetails::global_score(score.iter()); |             let score = ScoreDetails::global_score(score.iter()); | ||||||
|             if score < GOOD_ENOUGH_SCORE { |             if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE { | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -3,7 +3,6 @@ use std::ops::ControlFlow; | |||||||
|  |  | ||||||
| use charabia::normalizer::NormalizerOption; | use charabia::normalizer::NormalizerOption; | ||||||
| use charabia::Normalize; | use charabia::Normalize; | ||||||
| use deserr::{DeserializeError, Deserr, Sequence}; |  | ||||||
| use fst::automaton::{Automaton, Str}; | use fst::automaton::{Automaton, Str}; | ||||||
| use fst::{IntoStreamer, Streamer}; | use fst::{IntoStreamer, Streamer}; | ||||||
| use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; | use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; | ||||||
| @@ -57,53 +56,6 @@ pub struct Search<'a> { | |||||||
|     embedder_name: Option<String>, |     embedder_name: Option<String>, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug, Clone, PartialEq)] |  | ||||||
| pub enum VectorQuery { |  | ||||||
|     Vector(Vec<f32>), |  | ||||||
|     String(String), |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<E> Deserr<E> for VectorQuery |  | ||||||
| where |  | ||||||
|     E: DeserializeError, |  | ||||||
| { |  | ||||||
|     fn deserialize_from_value<V: deserr::IntoValue>( |  | ||||||
|         value: deserr::Value<V>, |  | ||||||
|         location: deserr::ValuePointerRef, |  | ||||||
|     ) -> std::result::Result<Self, E> { |  | ||||||
|         match value { |  | ||||||
|             deserr::Value::String(s) => Ok(VectorQuery::String(s)), |  | ||||||
|             deserr::Value::Sequence(seq) => { |  | ||||||
|                 let v: std::result::Result<Vec<f32>, _> = seq |  | ||||||
|                     .into_iter() |  | ||||||
|                     .enumerate() |  | ||||||
|                     .map(|(index, v)| match v.into_value() { |  | ||||||
|                         deserr::Value::Float(f) => Ok(f as f32), |  | ||||||
|                         deserr::Value::Integer(i) => Ok(i as f32), |  | ||||||
|                         v => Err(deserr::take_cf_content(E::error::<V>( |  | ||||||
|                             None, |  | ||||||
|                             deserr::ErrorKind::IncorrectValueKind { |  | ||||||
|                                 actual: v, |  | ||||||
|                                 accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer], |  | ||||||
|                             }, |  | ||||||
|                             location.push_index(index), |  | ||||||
|                         ))), |  | ||||||
|                     }) |  | ||||||
|                     .collect(); |  | ||||||
|                 Ok(VectorQuery::Vector(v?)) |  | ||||||
|             } |  | ||||||
|             _ => Err(deserr::take_cf_content(E::error::<V>( |  | ||||||
|                 None, |  | ||||||
|                 deserr::ErrorKind::IncorrectValueKind { |  | ||||||
|                     actual: value, |  | ||||||
|                     accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence], |  | ||||||
|                 }, |  | ||||||
|                 location, |  | ||||||
|             ))), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<'a> Search<'a> { | impl<'a> Search<'a> { | ||||||
|     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { |     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { | ||||||
|         Search { |         Search { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user