mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-26 05:26:27 +00:00 
			
		
		
		
	Merge #4548
4548: v1.8 hybrid search changes r=dureuill a=dureuill Implements the search changes from the [usage page](https://meilisearch.notion.site/v1-8-AI-search-API-usage-135552d6e85a4a52bc7109be82aeca42#40f24df3da694428a39cc8043c9cfc64) ### ⚠️ Breaking changes in an experimental feature: - Removed the `_semanticScore`. Use the `_rankingScore` instead. - Removed `vector` in the response of the search (output was too big). - Removed all the vectors from the `vectorSort` ranking score details - target vector appearing in the name of the rule - matched vector appearing in the details of the rule ### Other user-facing changes - Added `semanticHitCount`, indicating how many hits were returned from the semantic search. This is especially useful in the hybrid search. - Embed lazily: Meilisearch no longer generates an embedding when the keyword results are "good enough". - Graceful embedding failure in hybrid search: when doing hybrid search (`semanticRatio in ]0.0, 1.0[`), an embedding failure no longer causes the search request to fail. Instead, only the keyword search is performed. When doing a full vector search (`semanticRatio==1.0`), a failure to embed will still result in failing that search. Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
		| @@ -758,9 +758,9 @@ impl SearchAggregator { | |||||||
|         let SearchResult { |         let SearchResult { | ||||||
|             hits: _, |             hits: _, | ||||||
|             query: _, |             query: _, | ||||||
|             vector: _, |  | ||||||
|             processing_time_ms, |             processing_time_ms, | ||||||
|             hits_info: _, |             hits_info: _, | ||||||
|  |             semantic_hit_count: _, | ||||||
|             facet_distribution: _, |             facet_distribution: _, | ||||||
|             facet_stats: _, |             facet_stats: _, | ||||||
|             degraded, |             degraded, | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ use tracing::debug; | |||||||
| use crate::analytics::{Analytics, FacetSearchAggregator}; | use crate::analytics::{Analytics, FacetSearchAggregator}; | ||||||
| use crate::extractors::authentication::policies::*; | use crate::extractors::authentication::policies::*; | ||||||
| use crate::extractors::authentication::GuardedData; | use crate::extractors::authentication::GuardedData; | ||||||
|  | use crate::routes::indexes::search::search_kind; | ||||||
| use crate::search::{ | use crate::search::{ | ||||||
|     add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, |     add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, | ||||||
|     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, |     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, | ||||||
| @@ -73,9 +74,10 @@ pub async fn search( | |||||||
|  |  | ||||||
|     let index = index_scheduler.index(&index_uid)?; |     let index = index_scheduler.index(&index_uid)?; | ||||||
|     let features = index_scheduler.features(); |     let features = index_scheduler.features(); | ||||||
|  |     let search_kind = search_kind(&search_query, &index_scheduler, &index, features)?; | ||||||
|     let _permit = search_queue.try_get_search_permit().await?; |     let _permit = search_queue.try_get_search_permit().await?; | ||||||
|     let search_result = tokio::task::spawn_blocking(move || { |     let search_result = tokio::task::spawn_blocking(move || { | ||||||
|         perform_facet_search(&index, search_query, facet_query, facet_name, features) |         perform_facet_search(&index, search_query, facet_query, facet_name, search_kind) | ||||||
|     }) |     }) | ||||||
|     .await?; |     .await?; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,26 +1,26 @@ | |||||||
| use actix_web::web::Data; | use actix_web::web::Data; | ||||||
| use actix_web::{web, HttpRequest, HttpResponse}; | use actix_web::{web, HttpRequest, HttpResponse}; | ||||||
| use deserr::actix_web::{AwebJson, AwebQueryParameter}; | use deserr::actix_web::{AwebJson, AwebQueryParameter}; | ||||||
| use index_scheduler::IndexScheduler; | use index_scheduler::{IndexScheduler, RoFeatures}; | ||||||
| use meilisearch_types::deserr::query_params::Param; | use meilisearch_types::deserr::query_params::Param; | ||||||
| use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; | use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; | ||||||
| use meilisearch_types::error::deserr_codes::*; | use meilisearch_types::error::deserr_codes::*; | ||||||
| use meilisearch_types::error::ResponseError; | use meilisearch_types::error::ResponseError; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
| use meilisearch_types::milli; | use meilisearch_types::milli; | ||||||
| use meilisearch_types::milli::vector::DistributionShift; |  | ||||||
| use meilisearch_types::serde_cs::vec::CS; | use meilisearch_types::serde_cs::vec::CS; | ||||||
| use serde_json::Value; | use serde_json::Value; | ||||||
| use tracing::{debug, warn}; | use tracing::debug; | ||||||
|  |  | ||||||
| use crate::analytics::{Analytics, SearchAggregator}; | use crate::analytics::{Analytics, SearchAggregator}; | ||||||
|  | use crate::error::MeilisearchHttpError; | ||||||
| use crate::extractors::authentication::policies::*; | use crate::extractors::authentication::policies::*; | ||||||
| use crate::extractors::authentication::GuardedData; | use crate::extractors::authentication::GuardedData; | ||||||
| use crate::extractors::sequential_extractor::SeqHandler; | use crate::extractors::sequential_extractor::SeqHandler; | ||||||
| use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; | use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; | ||||||
| use crate::search::{ | use crate::search::{ | ||||||
|     add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio, |     add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchKind, SearchQuery, | ||||||
|     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, |     SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, | ||||||
|     DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, |     DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, | ||||||
| }; | }; | ||||||
| use crate::search_queue::SearchQueue; | use crate::search_queue::SearchQueue; | ||||||
| @@ -204,12 +204,11 @@ pub async fn search_with_url_query( | |||||||
|     let index = index_scheduler.index(&index_uid)?; |     let index = index_scheduler.index(&index_uid)?; | ||||||
|     let features = index_scheduler.features(); |     let features = index_scheduler.features(); | ||||||
|  |  | ||||||
|     let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?; |     let search_kind = search_kind(&query, index_scheduler.get_ref(), &index, features)?; | ||||||
|  |  | ||||||
|     let _permit = search_queue.try_get_search_permit().await?; |     let _permit = search_queue.try_get_search_permit().await?; | ||||||
|     let search_result = |     let search_result = | ||||||
|         tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) |         tokio::task::spawn_blocking(move || perform_search(&index, query, search_kind)).await?; | ||||||
|             .await?; |  | ||||||
|     if let Ok(ref search_result) = search_result { |     if let Ok(ref search_result) = search_result { | ||||||
|         aggregate.succeed(search_result); |         aggregate.succeed(search_result); | ||||||
|     } |     } | ||||||
| @@ -245,12 +244,11 @@ pub async fn search_with_post( | |||||||
|  |  | ||||||
|     let features = index_scheduler.features(); |     let features = index_scheduler.features(); | ||||||
|  |  | ||||||
|     let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?; |     let search_kind = search_kind(&query, index_scheduler.get_ref(), &index, features)?; | ||||||
|  |  | ||||||
|     let _permit = search_queue.try_get_search_permit().await?; |     let _permit = search_queue.try_get_search_permit().await?; | ||||||
|     let search_result = |     let search_result = | ||||||
|         tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) |         tokio::task::spawn_blocking(move || perform_search(&index, query, search_kind)).await?; | ||||||
|             .await?; |  | ||||||
|     if let Ok(ref search_result) = search_result { |     if let Ok(ref search_result) = search_result { | ||||||
|         aggregate.succeed(search_result); |         aggregate.succeed(search_result); | ||||||
|         if search_result.degraded { |         if search_result.degraded { | ||||||
| @@ -265,76 +263,58 @@ pub async fn search_with_post( | |||||||
|     Ok(HttpResponse::Ok().json(search_result)) |     Ok(HttpResponse::Ok().json(search_result)) | ||||||
| } | } | ||||||
|  |  | ||||||
| pub fn embed( | pub fn search_kind( | ||||||
|     query: &mut SearchQuery, |     query: &SearchQuery, | ||||||
|     index_scheduler: &IndexScheduler, |     index_scheduler: &IndexScheduler, | ||||||
|     index: &milli::Index, |     index: &milli::Index, | ||||||
| ) -> Result<Option<DistributionShift>, ResponseError> { |     features: RoFeatures, | ||||||
|     match (&query.hybrid, &query.vector, &query.q) { | ) -> Result<SearchKind, ResponseError> { | ||||||
|         (Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q)) |     if query.vector.is_some() { | ||||||
|             if !q.trim().is_empty() => |         features.check_vector("Passing `vector` as a query parameter")?; | ||||||
|         { |  | ||||||
|             let embedder_configs = index.embedding_configs(&index.read_txn()?)?; |  | ||||||
|             let embedders = index_scheduler.embedders(embedder_configs)?; |  | ||||||
|  |  | ||||||
|             let embedder = if let Some(embedder_name) = embedder { |  | ||||||
|                 embedders.get(embedder_name) |  | ||||||
|             } else { |  | ||||||
|                 embedders.get_default() |  | ||||||
|             }; |  | ||||||
|  |  | ||||||
|             let embedder = embedder |  | ||||||
|                 .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) |  | ||||||
|                 .map_err(milli::Error::from)? |  | ||||||
|                 .0; |  | ||||||
|  |  | ||||||
|             let distribution = embedder.distribution(); |  | ||||||
|  |  | ||||||
|             let embeddings = embedder |  | ||||||
|                 .embed(vec![q.to_owned()]) |  | ||||||
|                 .map_err(milli::vector::Error::from) |  | ||||||
|                 .map_err(milli::Error::from)? |  | ||||||
|                 .pop() |  | ||||||
|                 .expect("No vector returned from embedding"); |  | ||||||
|  |  | ||||||
|             if embeddings.iter().nth(1).is_some() { |  | ||||||
|                 warn!("Ignoring embeddings past the first one in long search query"); |  | ||||||
|                 query.vector = Some(embeddings.iter().next().unwrap().to_vec()); |  | ||||||
|             } else { |  | ||||||
|                 query.vector = Some(embeddings.into_inner()); |  | ||||||
|     } |     } | ||||||
|             Ok(distribution) |  | ||||||
|  |     if query.hybrid.is_some() { | ||||||
|  |         features.check_vector("Passing `hybrid` as a query parameter")?; | ||||||
|     } |     } | ||||||
|         (Some(hybrid), vector, _) => { |  | ||||||
|             let embedder_configs = index.embedding_configs(&index.read_txn()?)?; |  | ||||||
|             let embedders = index_scheduler.embedders(embedder_configs)?; |  | ||||||
|  |  | ||||||
|             let embedder = if let Some(embedder_name) = &hybrid.embedder { |     // regardless of anything, always do a keyword search when we don't have a vector and the query is whitespace or missing | ||||||
|                 embedders.get(embedder_name) |     if query.vector.is_none() { | ||||||
|             } else { |         match &query.q { | ||||||
|                 embedders.get_default() |             Some(q) if q.trim().is_empty() => return Ok(SearchKind::KeywordOnly), | ||||||
|             }; |             None => return Ok(SearchKind::KeywordOnly), | ||||||
|  |             _ => {} | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|             let embedder = embedder |     match &query.hybrid { | ||||||
|                 .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) |         Some(HybridQuery { semantic_ratio, embedder }) if **semantic_ratio == 1.0 => { | ||||||
|                 .map_err(milli::Error::from)? |             Ok(SearchKind::semantic( | ||||||
|                 .0; |                 index_scheduler, | ||||||
|  |                 index, | ||||||
|             if let Some(vector) = vector { |                 embedder.as_deref(), | ||||||
|                 if vector.len() != embedder.dimensions() { |                 query.vector.as_ref().map(Vec::len), | ||||||
|                     return Err(meilisearch_types::milli::Error::UserError( |             )?) | ||||||
|                         meilisearch_types::milli::UserError::InvalidVectorDimensions { |         } | ||||||
|                             expected: embedder.dimensions(), |         Some(HybridQuery { semantic_ratio, embedder: _ }) if **semantic_ratio == 0.0 => { | ||||||
|                             found: vector.len(), |             Ok(SearchKind::KeywordOnly) | ||||||
|  |         } | ||||||
|  |         Some(HybridQuery { semantic_ratio, embedder }) => Ok(SearchKind::hybrid( | ||||||
|  |             index_scheduler, | ||||||
|  |             index, | ||||||
|  |             embedder.as_deref(), | ||||||
|  |             **semantic_ratio, | ||||||
|  |             query.vector.as_ref().map(Vec::len), | ||||||
|  |         )?), | ||||||
|  |         None => match (query.q.as_deref(), query.vector.as_deref()) { | ||||||
|  |             (_query, None) => Ok(SearchKind::KeywordOnly), | ||||||
|  |             (None, Some(_vector)) => Ok(SearchKind::semantic( | ||||||
|  |                 index_scheduler, | ||||||
|  |                 index, | ||||||
|  |                 None, | ||||||
|  |                 query.vector.as_ref().map(Vec::len), | ||||||
|  |             )?), | ||||||
|  |             (Some(_), Some(_)) => Err(MeilisearchHttpError::MissingSearchHybrid.into()), | ||||||
|         }, |         }, | ||||||
|                     ) |  | ||||||
|                     .into()); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             Ok(embedder.distribution()) |  | ||||||
|         } |  | ||||||
|         _ => Ok(None), |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator}; | |||||||
| use crate::extractors::authentication::policies::ActionPolicy; | use crate::extractors::authentication::policies::ActionPolicy; | ||||||
| use crate::extractors::authentication::{AuthenticationError, GuardedData}; | use crate::extractors::authentication::{AuthenticationError, GuardedData}; | ||||||
| use crate::extractors::sequential_extractor::SeqHandler; | use crate::extractors::sequential_extractor::SeqHandler; | ||||||
| use crate::routes::indexes::search::embed; | use crate::routes::indexes::search::search_kind; | ||||||
| use crate::search::{ | use crate::search::{ | ||||||
|     add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, |     add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, | ||||||
| }; | }; | ||||||
| @@ -81,12 +81,11 @@ pub async fn multi_search_with_post( | |||||||
|                 }) |                 }) | ||||||
|                 .with_index(query_index)?; |                 .with_index(query_index)?; | ||||||
|  |  | ||||||
|             let distribution = |             let search_kind = search_kind(&query, index_scheduler.get_ref(), &index, features) | ||||||
|                 embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?; |                 .with_index(query_index)?; | ||||||
|  |  | ||||||
|             let search_result = tokio::task::spawn_blocking(move || { |             let search_result = | ||||||
|                 perform_search(&index, query, features, distribution) |                 tokio::task::spawn_blocking(move || perform_search(&index, query, search_kind)) | ||||||
|             }) |  | ||||||
|                     .await |                     .await | ||||||
|                     .with_index(query_index)?; |                     .with_index(query_index)?; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,19 +1,20 @@ | |||||||
| use std::cmp::min; | use std::cmp::min; | ||||||
| use std::collections::{BTreeMap, BTreeSet, HashSet}; | use std::collections::{BTreeMap, BTreeSet, HashSet}; | ||||||
| use std::str::FromStr; | use std::str::FromStr; | ||||||
|  | use std::sync::Arc; | ||||||
| use std::time::{Duration, Instant}; | use std::time::{Duration, Instant}; | ||||||
|  |  | ||||||
| use deserr::Deserr; | use deserr::Deserr; | ||||||
| use either::Either; | use either::Either; | ||||||
| use index_scheduler::RoFeatures; |  | ||||||
| use indexmap::IndexMap; | use indexmap::IndexMap; | ||||||
| use meilisearch_auth::IndexSearchRules; | use meilisearch_auth::IndexSearchRules; | ||||||
| use meilisearch_types::deserr::DeserrJsonError; | use meilisearch_types::deserr::DeserrJsonError; | ||||||
| use meilisearch_types::error::deserr_codes::*; | use meilisearch_types::error::deserr_codes::*; | ||||||
|  | use meilisearch_types::error::ResponseError; | ||||||
| use meilisearch_types::heed::RoTxn; | use meilisearch_types::heed::RoTxn; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
| use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; | use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; | ||||||
| use meilisearch_types::milli::vector::DistributionShift; | use meilisearch_types::milli::vector::Embedder; | ||||||
| use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, TimeBudget}; | use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, TimeBudget}; | ||||||
| use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; | use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; | ||||||
| use meilisearch_types::{milli, Document}; | use meilisearch_types::{milli, Document}; | ||||||
| @@ -90,13 +91,75 @@ pub struct SearchQuery { | |||||||
| #[derive(Debug, Clone, Default, PartialEq, Deserr)] | #[derive(Debug, Clone, Default, PartialEq, Deserr)] | ||||||
| #[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)] | #[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)] | ||||||
| pub struct HybridQuery { | pub struct HybridQuery { | ||||||
|     /// TODO validate that sementic ratio is between 0.0 and 1,0 |  | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)] | ||||||
|     pub semantic_ratio: SemanticRatio, |     pub semantic_ratio: SemanticRatio, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)] |     #[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)] | ||||||
|     pub embedder: Option<String>, |     pub embedder: Option<String>, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | pub enum SearchKind { | ||||||
|  |     KeywordOnly, | ||||||
|  |     SemanticOnly { embedder_name: String, embedder: Arc<Embedder> }, | ||||||
|  |     Hybrid { embedder_name: String, embedder: Arc<Embedder>, semantic_ratio: f32 }, | ||||||
|  | } | ||||||
|  | impl SearchKind { | ||||||
|  |     pub(crate) fn semantic( | ||||||
|  |         index_scheduler: &index_scheduler::IndexScheduler, | ||||||
|  |         index: &Index, | ||||||
|  |         embedder_name: Option<&str>, | ||||||
|  |         vector_len: Option<usize>, | ||||||
|  |     ) -> Result<Self, ResponseError> { | ||||||
|  |         let (embedder_name, embedder) = | ||||||
|  |             Self::embedder(index_scheduler, index, embedder_name, vector_len)?; | ||||||
|  |         Ok(Self::SemanticOnly { embedder_name, embedder }) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn hybrid( | ||||||
|  |         index_scheduler: &index_scheduler::IndexScheduler, | ||||||
|  |         index: &Index, | ||||||
|  |         embedder_name: Option<&str>, | ||||||
|  |         semantic_ratio: f32, | ||||||
|  |         vector_len: Option<usize>, | ||||||
|  |     ) -> Result<Self, ResponseError> { | ||||||
|  |         let (embedder_name, embedder) = | ||||||
|  |             Self::embedder(index_scheduler, index, embedder_name, vector_len)?; | ||||||
|  |         Ok(Self::Hybrid { embedder_name, embedder, semantic_ratio }) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn embedder( | ||||||
|  |         index_scheduler: &index_scheduler::IndexScheduler, | ||||||
|  |         index: &Index, | ||||||
|  |         embedder_name: Option<&str>, | ||||||
|  |         vector_len: Option<usize>, | ||||||
|  |     ) -> Result<(String, Arc<Embedder>), ResponseError> { | ||||||
|  |         let embedder_configs = index.embedding_configs(&index.read_txn()?)?; | ||||||
|  |         let embedders = index_scheduler.embedders(embedder_configs)?; | ||||||
|  |  | ||||||
|  |         let embedder_name = embedder_name.unwrap_or_else(|| embedders.get_default_embedder_name()); | ||||||
|  |  | ||||||
|  |         let embedder = embedders.get(embedder_name); | ||||||
|  |  | ||||||
|  |         let embedder = embedder | ||||||
|  |             .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) | ||||||
|  |             .map_err(milli::Error::from)? | ||||||
|  |             .0; | ||||||
|  |  | ||||||
|  |         if let Some(vector_len) = vector_len { | ||||||
|  |             if vector_len != embedder.dimensions() { | ||||||
|  |                 return Err(meilisearch_types::milli::Error::UserError( | ||||||
|  |                     meilisearch_types::milli::UserError::InvalidVectorDimensions { | ||||||
|  |                         expected: embedder.dimensions(), | ||||||
|  |                         found: vector_len, | ||||||
|  |                     }, | ||||||
|  |                 ) | ||||||
|  |                 .into()); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         Ok((embedder_name.to_owned(), embedder)) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| #[derive(Debug, Clone, Copy, PartialEq, Deserr)] | #[derive(Debug, Clone, Copy, PartialEq, Deserr)] | ||||||
| #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] | #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] | ||||||
| pub struct SemanticRatio(f32); | pub struct SemanticRatio(f32); | ||||||
| @@ -305,8 +368,6 @@ pub struct SearchHit { | |||||||
|     pub ranking_score: Option<f64>, |     pub ranking_score: Option<f64>, | ||||||
|     #[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")] |     #[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")] | ||||||
|     pub ranking_score_details: Option<serde_json::Map<String, serde_json::Value>>, |     pub ranking_score_details: Option<serde_json::Map<String, serde_json::Value>>, | ||||||
|     #[serde(rename = "_semanticScore", skip_serializing_if = "Option::is_none")] |  | ||||||
|     pub semantic_score: Option<f32>, |  | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Serialize, Debug, Clone, PartialEq)] | #[derive(Serialize, Debug, Clone, PartialEq)] | ||||||
| @@ -314,8 +375,6 @@ pub struct SearchHit { | |||||||
| pub struct SearchResult { | pub struct SearchResult { | ||||||
|     pub hits: Vec<SearchHit>, |     pub hits: Vec<SearchHit>, | ||||||
|     pub query: String, |     pub query: String, | ||||||
|     #[serde(skip_serializing_if = "Option::is_none")] |  | ||||||
|     pub vector: Option<Vec<f32>>, |  | ||||||
|     pub processing_time_ms: u128, |     pub processing_time_ms: u128, | ||||||
|     #[serde(flatten)] |     #[serde(flatten)] | ||||||
|     pub hits_info: HitsInfo, |     pub hits_info: HitsInfo, | ||||||
| @@ -324,6 +383,9 @@ pub struct SearchResult { | |||||||
|     #[serde(skip_serializing_if = "Option::is_none")] |     #[serde(skip_serializing_if = "Option::is_none")] | ||||||
|     pub facet_stats: Option<BTreeMap<String, FacetStats>>, |     pub facet_stats: Option<BTreeMap<String, FacetStats>>, | ||||||
|  |  | ||||||
|  |     #[serde(skip_serializing_if = "Option::is_none")] | ||||||
|  |     pub semantic_hit_count: Option<u32>, | ||||||
|  |  | ||||||
|     // These fields are only used for analytics purposes |     // These fields are only used for analytics purposes | ||||||
|     #[serde(skip)] |     #[serde(skip)] | ||||||
|     pub degraded: bool, |     pub degraded: bool, | ||||||
| @@ -386,47 +448,36 @@ fn prepare_search<'t>( | |||||||
|     index: &'t Index, |     index: &'t Index, | ||||||
|     rtxn: &'t RoTxn, |     rtxn: &'t RoTxn, | ||||||
|     query: &'t SearchQuery, |     query: &'t SearchQuery, | ||||||
|     features: RoFeatures, |     search_kind: &SearchKind, | ||||||
|     distribution: Option<DistributionShift>, |  | ||||||
|     time_budget: TimeBudget, |     time_budget: TimeBudget, | ||||||
| ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { | ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { | ||||||
|     let mut search = index.search(rtxn); |     let mut search = index.search(rtxn); | ||||||
|     search.time_budget(time_budget); |     search.time_budget(time_budget); | ||||||
|  |  | ||||||
|     if query.vector.is_some() { |     match search_kind { | ||||||
|         features.check_vector("Passing `vector` as a query parameter")?; |         SearchKind::KeywordOnly => { | ||||||
|     } |             if let Some(q) = &query.q { | ||||||
|  |  | ||||||
|     if query.hybrid.is_some() { |  | ||||||
|         features.check_vector("Passing `hybrid` as a query parameter")?; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() { |  | ||||||
|         return Err(MeilisearchHttpError::MissingSearchHybrid); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     search.distribution_shift(distribution); |  | ||||||
|  |  | ||||||
|     if let Some(ref vector) = query.vector { |  | ||||||
|         match &query.hybrid { |  | ||||||
|             // If semantic ratio is 0.0, only the query search will impact the search results, |  | ||||||
|             // skip the vector |  | ||||||
|             Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), |  | ||||||
|             _otherwise => { |  | ||||||
|                 search.vector(vector.clone()); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if let Some(ref q) = query.q { |  | ||||||
|         match &query.hybrid { |  | ||||||
|             // If semantic ratio is 1.0, only the vector search will impact the search results, |  | ||||||
|             // skip the query |  | ||||||
|             Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (), |  | ||||||
|             _otherwise => { |  | ||||||
|                 search.query(q); |                 search.query(q); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |         SearchKind::SemanticOnly { embedder_name, embedder } => { | ||||||
|  |             let vector = match query.vector.clone() { | ||||||
|  |                 Some(vector) => vector, | ||||||
|  |                 None => embedder | ||||||
|  |                     .embed_one(query.q.clone().unwrap()) | ||||||
|  |                     .map_err(milli::vector::Error::from) | ||||||
|  |                     .map_err(milli::Error::from)?, | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |             search.semantic(embedder_name.clone(), embedder.clone(), Some(vector)); | ||||||
|  |         } | ||||||
|  |         SearchKind::Hybrid { embedder_name, embedder, semantic_ratio: _ } => { | ||||||
|  |             if let Some(q) = &query.q { | ||||||
|  |                 search.query(q); | ||||||
|  |             } | ||||||
|  |             // will be embedded in hybrid search if necessary | ||||||
|  |             search.semantic(embedder_name.clone(), embedder.clone(), query.vector.clone()); | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if let Some(ref searchable) = query.attributes_to_search_on { |     if let Some(ref searchable) = query.attributes_to_search_on { | ||||||
| @@ -449,10 +500,6 @@ fn prepare_search<'t>( | |||||||
|         ScoringStrategy::Skip |         ScoringStrategy::Skip | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
|     if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { |  | ||||||
|         search.embedder_name(embedder); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // compute the offset on the limit depending on the pagination mode. |     // compute the offset on the limit depending on the pagination mode. | ||||||
|     let (offset, limit) = if is_finite_pagination { |     let (offset, limit) = if is_finite_pagination { | ||||||
|         let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); |         let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); | ||||||
| @@ -495,8 +542,7 @@ fn prepare_search<'t>( | |||||||
| pub fn perform_search( | pub fn perform_search( | ||||||
|     index: &Index, |     index: &Index, | ||||||
|     query: SearchQuery, |     query: SearchQuery, | ||||||
|     features: RoFeatures, |     search_kind: SearchKind, | ||||||
|     distribution: Option<DistributionShift>, |  | ||||||
| ) -> Result<SearchResult, MeilisearchHttpError> { | ) -> Result<SearchResult, MeilisearchHttpError> { | ||||||
|     let before_search = Instant::now(); |     let before_search = Instant::now(); | ||||||
|     let rtxn = index.read_txn()?; |     let rtxn = index.read_txn()?; | ||||||
| @@ -506,22 +552,26 @@ pub fn perform_search( | |||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     let (search, is_finite_pagination, max_total_hits, offset) = |     let (search, is_finite_pagination, max_total_hits, offset) = | ||||||
|         prepare_search(index, &rtxn, &query, features, distribution, time_budget)?; |         prepare_search(index, &rtxn, &query, &search_kind, time_budget)?; | ||||||
|  |  | ||||||
|     let milli::SearchResult { |     let ( | ||||||
|  |         milli::SearchResult { | ||||||
|             documents_ids, |             documents_ids, | ||||||
|             matching_words, |             matching_words, | ||||||
|             candidates, |             candidates, | ||||||
|             document_scores, |             document_scores, | ||||||
|             degraded, |             degraded, | ||||||
|             used_negative_operator, |             used_negative_operator, | ||||||
|         .. |  | ||||||
|     } = match &query.hybrid { |  | ||||||
|         Some(hybrid) => match *hybrid.semantic_ratio { |  | ||||||
|             ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?, |  | ||||||
|             ratio => search.execute_hybrid(ratio)?, |  | ||||||
|         }, |         }, | ||||||
|         None => search.execute()?, |         semantic_hit_count, | ||||||
|  |     ) = match &search_kind { | ||||||
|  |         SearchKind::KeywordOnly => (search.execute()?, None), | ||||||
|  |         SearchKind::SemanticOnly { .. } => { | ||||||
|  |             let results = search.execute()?; | ||||||
|  |             let semantic_hit_count = results.document_scores.len() as u32; | ||||||
|  |             (results, Some(semantic_hit_count)) | ||||||
|  |         } | ||||||
|  |         SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?, | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); |     let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); | ||||||
| @@ -631,18 +681,6 @@ pub fn perform_search( | |||||||
|             insert_geo_distance(sort, &mut document); |             insert_geo_distance(sort, &mut document); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         let mut semantic_score = None; |  | ||||||
|         for details in &score { |  | ||||||
|             if let ScoreDetails::Vector(score_details::Vector { |  | ||||||
|                 target_vector: _, |  | ||||||
|                 value_similarity: Some((_matching_vector, similarity)), |  | ||||||
|             }) = details |  | ||||||
|             { |  | ||||||
|                 semantic_score = Some(*similarity); |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let ranking_score = |         let ranking_score = | ||||||
|             query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); |             query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); | ||||||
|         let ranking_score_details = |         let ranking_score_details = | ||||||
| @@ -654,7 +692,6 @@ pub fn perform_search( | |||||||
|             matches_position, |             matches_position, | ||||||
|             ranking_score_details, |             ranking_score_details, | ||||||
|             ranking_score, |             ranking_score, | ||||||
|             semantic_score, |  | ||||||
|         }; |         }; | ||||||
|         documents.push(hit); |         documents.push(hit); | ||||||
|     } |     } | ||||||
| @@ -715,12 +752,12 @@ pub fn perform_search( | |||||||
|         hits: documents, |         hits: documents, | ||||||
|         hits_info, |         hits_info, | ||||||
|         query: query.q.unwrap_or_default(), |         query: query.q.unwrap_or_default(), | ||||||
|         vector: query.vector, |  | ||||||
|         processing_time_ms: before_search.elapsed().as_millis(), |         processing_time_ms: before_search.elapsed().as_millis(), | ||||||
|         facet_distribution, |         facet_distribution, | ||||||
|         facet_stats, |         facet_stats, | ||||||
|         degraded, |         degraded, | ||||||
|         used_negative_operator, |         used_negative_operator, | ||||||
|  |         semantic_hit_count, | ||||||
|     }; |     }; | ||||||
|     Ok(result) |     Ok(result) | ||||||
| } | } | ||||||
| @@ -730,7 +767,7 @@ pub fn perform_facet_search( | |||||||
|     search_query: SearchQuery, |     search_query: SearchQuery, | ||||||
|     facet_query: Option<String>, |     facet_query: Option<String>, | ||||||
|     facet_name: String, |     facet_name: String, | ||||||
|     features: RoFeatures, |     search_kind: SearchKind, | ||||||
| ) -> Result<FacetSearchResult, MeilisearchHttpError> { | ) -> Result<FacetSearchResult, MeilisearchHttpError> { | ||||||
|     let before_search = Instant::now(); |     let before_search = Instant::now(); | ||||||
|     let rtxn = index.read_txn()?; |     let rtxn = index.read_txn()?; | ||||||
| @@ -739,10 +776,12 @@ pub fn perform_facet_search( | |||||||
|         None => TimeBudget::default(), |         None => TimeBudget::default(), | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     let (search, _, _, _) = |     let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, &search_kind, time_budget)?; | ||||||
|         prepare_search(index, &rtxn, &search_query, features, None, time_budget)?; |     let mut facet_search = SearchForFacetValues::new( | ||||||
|     let mut facet_search = |         facet_name, | ||||||
|         SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some()); |         search, | ||||||
|  |         matches!(search_kind, SearchKind::Hybrid { .. }), | ||||||
|  |     ); | ||||||
|     if let Some(facet_query) = &facet_query { |     if let Some(facet_query) = &facet_query { | ||||||
|         facet_search.query(facet_query); |         facet_search.query(facet_query); | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -77,14 +77,25 @@ async fn simple_search() { | |||||||
|         .await; |         .await; | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###); |     snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"0"); | ||||||
|  |  | ||||||
|     let (response, code) = index |     let (response, code) = index | ||||||
|         .search_post( |         .search_post( | ||||||
|             json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}}), |             json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.5}, "showRankingScore": true}), | ||||||
|         ) |         ) | ||||||
|         .await; |         .await; | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); |     snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.996969696969697},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.996969696969697},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.9472135901451112}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"1"); | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post( | ||||||
|  |             json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}, "showRankingScore": true}), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.974341630935669},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.9472135901451112}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"3"); | ||||||
| } | } | ||||||
|  |  | ||||||
| #[actix_rt::test] | #[actix_rt::test] | ||||||
| @@ -95,7 +106,7 @@ async fn distribution_shift() { | |||||||
|     let search = json!({"q": "Captain", "vector": [1.0, 1.0], "showRankingScore": true, "hybrid": {"semanticRatio": 1.0}}); |     let search = json!({"q": "Captain", "vector": [1.0, 1.0], "showRankingScore": true, "hybrid": {"semanticRatio": 1.0}}); | ||||||
|     let (response, code) = index.search_post(search.clone()).await; |     let (response, code) = index.search_post(search.clone()).await; | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.990290343761444,"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.974341630935669,"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.9472135901451112,"_semanticScore":0.9472136}]"###); |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.974341630935669},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.9472135901451112}]"###); | ||||||
|  |  | ||||||
|     let (response, code) = index |     let (response, code) = index | ||||||
|         .update_settings(json!({ |         .update_settings(json!({ | ||||||
| @@ -116,7 +127,7 @@ async fn distribution_shift() { | |||||||
|  |  | ||||||
|     let (response, code) = index.search_post(search).await; |     let (response, code) = index.search_post(search).await; | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.19161224365234375,"_semanticScore":0.19161224},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":1.1920928955078125e-7,"_semanticScore":1.1920929e-7},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.1920928955078125e-7,"_semanticScore":1.1920929e-7}]"###); |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.19161224365234375},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":1.1920928955078125e-7},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.1920928955078125e-7}]"###); | ||||||
| } | } | ||||||
|  |  | ||||||
| #[actix_rt::test] | #[actix_rt::test] | ||||||
| @@ -136,10 +147,12 @@ async fn highlighter() { | |||||||
|         .await; |         .await; | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}}}]"###); |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}}}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"0"); | ||||||
|  |  | ||||||
|     let (response, code) = index |     let (response, code) = index | ||||||
|         .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0], |         .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0], | ||||||
|             "hybrid": {"semanticRatio": 0.8}, |             "hybrid": {"semanticRatio": 0.8}, | ||||||
|  |             "showRankingScore": true, | ||||||
|             "attributesToHighlight": [ |             "attributesToHighlight": [ | ||||||
|                      "desc" |                      "desc" | ||||||
|                    ], |                    ], | ||||||
| @@ -148,12 +161,14 @@ async fn highlighter() { | |||||||
|         })) |         })) | ||||||
|         .await; |         .await; | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}},"_semanticScore":0.9472136}]"###); |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_rankingScore":0.974341630935669},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}},"_rankingScore":0.9472135901451112}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"3"); | ||||||
|  |  | ||||||
|     // no highlighting on full semantic |     // no highlighting on full semantic | ||||||
|     let (response, code) = index |     let (response, code) = index | ||||||
|         .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0], |         .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0], | ||||||
|             "hybrid": {"semanticRatio": 1.0}, |             "hybrid": {"semanticRatio": 1.0}, | ||||||
|  |             "showRankingScore": true, | ||||||
|             "attributesToHighlight": [ |             "attributesToHighlight": [ | ||||||
|                      "desc" |                      "desc" | ||||||
|                    ], |                    ], | ||||||
| @@ -162,7 +177,8 @@ async fn highlighter() { | |||||||
|         })) |         })) | ||||||
|         .await; |         .await; | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}}]"###); |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_rankingScore":0.974341630935669},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}},"_rankingScore":0.9472135901451112}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"3"); | ||||||
| } | } | ||||||
|  |  | ||||||
| #[actix_rt::test] | #[actix_rt::test] | ||||||
| @@ -249,5 +265,115 @@ async fn single_document() { | |||||||
|     .await; |     .await; | ||||||
|  |  | ||||||
|     snapshot!(code, @"200 OK"); |     snapshot!(code, @"200 OK"); | ||||||
|     snapshot!(response["hits"][0], @r###"{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0,"_semanticScore":1.0}"###); |     snapshot!(response["hits"][0], @r###"{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0}"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"1"); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[actix_rt::test] | ||||||
|  | async fn query_combination() { | ||||||
|  |     let server = Server::new().await; | ||||||
|  |     let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await; | ||||||
|  |  | ||||||
|  |     // search without query and vector, but with hybrid => still placeholder | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post(json!({"hybrid": {"semanticRatio": 1.0}, "showRankingScore": true})) | ||||||
|  |         .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":1.0},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":1.0}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"null"); | ||||||
|  |  | ||||||
|  |     // same with a different semantic ratio | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post(json!({"hybrid": {"semanticRatio": 0.76}, "showRankingScore": true})) | ||||||
|  |         .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":1.0},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":1.0}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"null"); | ||||||
|  |  | ||||||
|  |     // wrong vector dimensions | ||||||
|  |     let (response, code) = index | ||||||
|  |     .search_post(json!({"vector": [1.0, 0.0, 1.0], "hybrid": {"semanticRatio": 1.0}, "showRankingScore": true})) | ||||||
|  |     .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"400 Bad Request"); | ||||||
|  |     snapshot!(response, @r###" | ||||||
|  |     { | ||||||
|  |       "message": "Invalid vector dimensions: expected: `2`, found: `3`.", | ||||||
|  |       "code": "invalid_vector_dimensions", | ||||||
|  |       "type": "invalid_request", | ||||||
|  |       "link": "https://docs.meilisearch.com/errors#invalid_vector_dimensions" | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  |  | ||||||
|  |     // full vector | ||||||
|  |     let (response, code) = index | ||||||
|  |     .search_post(json!({"vector": [1.0, 0.0], "hybrid": {"semanticRatio": 1.0}, "showRankingScore": true})) | ||||||
|  |     .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.7773500680923462},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.7236068248748779},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.6581138968467712}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"3"); | ||||||
|  |  | ||||||
|  |     // full keyword, without a query | ||||||
|  |     let (response, code) = index | ||||||
|  |     .search_post(json!({"vector": [1.0, 0.0], "hybrid": {"semanticRatio": 0.0}, "showRankingScore": true})) | ||||||
|  |     .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":1.0},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":1.0}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"null"); | ||||||
|  |  | ||||||
|  |     // query + vector, full keyword => keyword | ||||||
|  |     let (response, code) = index | ||||||
|  |     .search_post(json!({"q": "Captain", "vector": [1.0, 0.0], "hybrid": {"semanticRatio": 0.0}, "showRankingScore": true})) | ||||||
|  |     .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.996969696969697},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.996969696969697},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.8848484848484849}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"null"); | ||||||
|  |  | ||||||
|  |     // query + vector, no hybrid keyword => | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post(json!({"q": "Captain", "vector": [1.0, 0.0], "showRankingScore": true})) | ||||||
|  |         .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"400 Bad Request"); | ||||||
|  |     snapshot!(response, @r###" | ||||||
|  |     { | ||||||
|  |       "message": "Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.", | ||||||
|  |       "code": "missing_search_hybrid", | ||||||
|  |       "type": "invalid_request", | ||||||
|  |       "link": "https://docs.meilisearch.com/errors#missing_search_hybrid" | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  |  | ||||||
|  |     // full vector, without a vector => error | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post( | ||||||
|  |             json!({"q": "Captain", "hybrid": {"semanticRatio": 1.0}, "showRankingScore": true}), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"400 Bad Request"); | ||||||
|  |     snapshot!(response, @r###" | ||||||
|  |     { | ||||||
|  |       "message": "Error while generating embeddings: user error: attempt to embed the following text in a configuration where embeddings must be user provided: \"Captain\"", | ||||||
|  |       "code": "vector_embedding_error", | ||||||
|  |       "type": "invalid_request", | ||||||
|  |       "link": "https://docs.meilisearch.com/errors#vector_embedding_error" | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  |  | ||||||
|  |     // hybrid without a vector => full keyword | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post( | ||||||
|  |             json!({"q": "Planet", "hybrid": {"semanticRatio": 0.99}, "showRankingScore": true}), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |  | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.9848484848484848}]"###); | ||||||
|  |     snapshot!(response["semanticHitCount"], @"0"); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1040,6 +1040,7 @@ async fn experimental_feature_vector_store() { | |||||||
|     let (response, code) = index |     let (response, code) = index | ||||||
|         .search_post(json!({ |         .search_post(json!({ | ||||||
|             "vector": [1.0, 2.0, 3.0], |             "vector": [1.0, 2.0, 3.0], | ||||||
|  |             "showRankingScore": true | ||||||
|         })) |         })) | ||||||
|         .await; |         .await; | ||||||
|     meili_snap::snapshot!(code, @"400 Bad Request"); |     meili_snap::snapshot!(code, @"400 Bad Request"); | ||||||
| @@ -1082,6 +1083,7 @@ async fn experimental_feature_vector_store() { | |||||||
|     let (response, code) = index |     let (response, code) = index | ||||||
|         .search_post(json!({ |         .search_post(json!({ | ||||||
|             "vector": [1.0, 2.0, 3.0], |             "vector": [1.0, 2.0, 3.0], | ||||||
|  |             "showRankingScore": true, | ||||||
|         })) |         })) | ||||||
|         .await; |         .await; | ||||||
|  |  | ||||||
| @@ -1099,7 +1101,7 @@ async fn experimental_feature_vector_store() { | |||||||
|             3 |             3 | ||||||
|           ] |           ] | ||||||
|         }, |         }, | ||||||
|         "_semanticScore": 1.0 |         "_rankingScore": 1.0 | ||||||
|       }, |       }, | ||||||
|       { |       { | ||||||
|         "title": "Captain Marvel", |         "title": "Captain Marvel", | ||||||
| @@ -1111,7 +1113,7 @@ async fn experimental_feature_vector_store() { | |||||||
|             54 |             54 | ||||||
|           ] |           ] | ||||||
|         }, |         }, | ||||||
|         "_semanticScore": 0.9129112 |         "_rankingScore": 0.9129111766815186 | ||||||
|       }, |       }, | ||||||
|       { |       { | ||||||
|         "title": "Gläss", |         "title": "Gläss", | ||||||
| @@ -1123,7 +1125,7 @@ async fn experimental_feature_vector_store() { | |||||||
|             90 |             90 | ||||||
|           ] |           ] | ||||||
|         }, |         }, | ||||||
|         "_semanticScore": 0.8106413 |         "_rankingScore": 0.8106412887573242 | ||||||
|       }, |       }, | ||||||
|       { |       { | ||||||
|         "title": "How to Train Your Dragon: The Hidden World", |         "title": "How to Train Your Dragon: The Hidden World", | ||||||
| @@ -1135,7 +1137,7 @@ async fn experimental_feature_vector_store() { | |||||||
|             32 |             32 | ||||||
|           ] |           ] | ||||||
|         }, |         }, | ||||||
|         "_semanticScore": 0.74120104 |         "_rankingScore": 0.7412010431289673 | ||||||
|       }, |       }, | ||||||
|       { |       { | ||||||
|         "title": "Escape Room", |         "title": "Escape Room", | ||||||
| @@ -1146,7 +1148,8 @@ async fn experimental_feature_vector_store() { | |||||||
|             -23, |             -23, | ||||||
|             32 |             32 | ||||||
|           ] |           ] | ||||||
|         } |         }, | ||||||
|  |         "_rankingScore": 0.6972063183784485 | ||||||
|       } |       } | ||||||
|     ] |     ] | ||||||
|     "###); |     "###); | ||||||
|   | |||||||
| @@ -196,7 +196,7 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco | |||||||
|     InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), |     InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), | ||||||
|     #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] |     #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] | ||||||
|     TooManyEmbedders(usize), |     TooManyEmbedders(usize), | ||||||
|     #[error("Cannot find embedder with name {0}.")] |     #[error("Cannot find embedder with name `{0}`.")] | ||||||
|     InvalidEmbedder(String), |     InvalidEmbedder(String), | ||||||
|     #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] |     #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] | ||||||
|     TooManyVectors(String, usize), |     TooManyVectors(String, usize), | ||||||
|   | |||||||
| @@ -1499,14 +1499,6 @@ impl Index { | |||||||
|             .unwrap_or_default()) |             .unwrap_or_default()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result<String> { |  | ||||||
|         let configs = self.embedding_configs(rtxn)?; |  | ||||||
|         Ok(match configs.as_slice() { |  | ||||||
|             [(ref first_name, _)] => first_name.clone(), |  | ||||||
|             _ => "default".to_owned(), |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> { |     pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> { | ||||||
|         self.main.remap_types::<Str, BEU64>().put(wtxn, main_key::SEARCH_CUTOFF, &cutoff) |         self.main.remap_types::<Str, BEU64>().put(wtxn, main_key::SEARCH_CUTOFF, &cutoff) | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -61,7 +61,7 @@ pub use self::index::Index; | |||||||
| pub use self::search::facet::{FacetValueHit, SearchForFacetValues}; | pub use self::search::facet::{FacetValueHit, SearchForFacetValues}; | ||||||
| pub use self::search::{ | pub use self::search::{ | ||||||
|     FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, |     FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, | ||||||
|     Search, SearchResult, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, |     Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| pub type Result<T> = std::result::Result<T, error::Error>; | pub type Result<T> = std::result::Result<T, error::Error>; | ||||||
|   | |||||||
| @@ -98,9 +98,9 @@ impl ScoreDetails { | |||||||
|             ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()), |             ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()), | ||||||
|             ScoreDetails::Sort(sort) => RankOrValue::Sort(sort), |             ScoreDetails::Sort(sort) => RankOrValue::Sort(sort), | ||||||
|             ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort), |             ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort), | ||||||
|             ScoreDetails::Vector(vector) => RankOrValue::Score( |             ScoreDetails::Vector(vector) => { | ||||||
|                 vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64), |                 RankOrValue::Score(vector.similarity.as_ref().map(|s| *s as f64).unwrap_or(0.0f64)) | ||||||
|             ), |             } | ||||||
|             ScoreDetails::Skipped => RankOrValue::Rank(Rank { rank: 0, max_rank: 1 }), |             ScoreDetails::Skipped => RankOrValue::Rank(Rank { rank: 0, max_rank: 1 }), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -249,16 +249,13 @@ impl ScoreDetails { | |||||||
|                     order += 1; |                     order += 1; | ||||||
|                 } |                 } | ||||||
|                 ScoreDetails::Vector(s) => { |                 ScoreDetails::Vector(s) => { | ||||||
|                     let vector = format!("vectorSort({:?})", s.target_vector); |                     let similarity = s.similarity.as_ref(); | ||||||
|                     let value = s.value_similarity.as_ref().map(|(v, _)| v); |  | ||||||
|                     let similarity = s.value_similarity.as_ref().map(|(_, s)| s); |  | ||||||
|  |  | ||||||
|                     let details = serde_json::json!({ |                     let details = serde_json::json!({ | ||||||
|                         "order": order, |                         "order": order, | ||||||
|                         "value": value, |  | ||||||
|                         "similarity": similarity, |                         "similarity": similarity, | ||||||
|                     }); |                     }); | ||||||
|                     details_map.insert(vector, details); |                     details_map.insert("vectorSort".into(), details); | ||||||
|                     order += 1; |                     order += 1; | ||||||
|                 } |                 } | ||||||
|                 ScoreDetails::Skipped => { |                 ScoreDetails::Skipped => { | ||||||
| @@ -494,8 +491,7 @@ impl PartialOrd for GeoSort { | |||||||
|  |  | ||||||
| #[derive(Debug, Clone, PartialEq, PartialOrd)] | #[derive(Debug, Clone, PartialEq, PartialOrd)] | ||||||
| pub struct Vector { | pub struct Vector { | ||||||
|     pub target_vector: Vec<f32>, |     pub similarity: Option<f32>, | ||||||
|     pub value_similarity: Option<(Vec<f32>, f32)>, |  | ||||||
| } | } | ||||||
|  |  | ||||||
| impl GeoSort { | impl GeoSort { | ||||||
|   | |||||||
| @@ -92,9 +92,15 @@ impl<'a> SearchForFacetValues<'a> { | |||||||
|             None => return Ok(Vec::new()), |             None => return Ok(Vec::new()), | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         let search_candidates = self |         let search_candidates = self.search_query.execute_for_candidates( | ||||||
|  |             self.is_hybrid | ||||||
|  |                 || self | ||||||
|                     .search_query |                     .search_query | ||||||
|             .execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?; |                     .semantic | ||||||
|  |                     .as_ref() | ||||||
|  |                     .and_then(|semantic| semantic.vector.as_ref()) | ||||||
|  |                     .is_some(), | ||||||
|  |         )?; | ||||||
|  |  | ||||||
|         let mut results = match index.sort_facet_values_by(rtxn)?.get(&self.facet) { |         let mut results = match index.sort_facet_values_by(rtxn)?.get(&self.facet) { | ||||||
|             OrderBy::Lexicographic => ValuesCollection::by_lexicographic(self.max_values), |             OrderBy::Lexicographic => ValuesCollection::by_lexicographic(self.max_values), | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ use itertools::Itertools; | |||||||
| use roaring::RoaringBitmap; | use roaring::RoaringBitmap; | ||||||
|  |  | ||||||
| use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; | use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; | ||||||
|  | use crate::search::SemanticSearch; | ||||||
| use crate::{MatchingWords, Result, Search, SearchResult}; | use crate::{MatchingWords, Result, Search, SearchResult}; | ||||||
|  |  | ||||||
| struct ScoreWithRatioResult { | struct ScoreWithRatioResult { | ||||||
| @@ -83,50 +84,77 @@ impl ScoreWithRatioResult { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult { |     fn merge( | ||||||
|         let mut documents_ids = |         vector_results: Self, | ||||||
|             Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); |         keyword_results: Self, | ||||||
|         let mut document_scores = |         from: usize, | ||||||
|             Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); |         length: usize, | ||||||
|  |     ) -> (SearchResult, u32) { | ||||||
|  |         #[derive(Clone, Copy)] | ||||||
|  |         enum ResultSource { | ||||||
|  |             Semantic, | ||||||
|  |             Keyword, | ||||||
|  |         } | ||||||
|  |         let mut semantic_hit_count = 0; | ||||||
|  |  | ||||||
|  |         let mut documents_ids = Vec::with_capacity( | ||||||
|  |             vector_results.document_scores.len() + keyword_results.document_scores.len(), | ||||||
|  |         ); | ||||||
|  |         let mut document_scores = Vec::with_capacity( | ||||||
|  |             vector_results.document_scores.len() + keyword_results.document_scores.len(), | ||||||
|  |         ); | ||||||
|  |  | ||||||
|         let mut documents_seen = RoaringBitmap::new(); |         let mut documents_seen = RoaringBitmap::new(); | ||||||
|         for (docid, (main_score, _sub_score)) in left |         for ((docid, (main_score, _sub_score)), source) in vector_results | ||||||
|             .document_scores |             .document_scores | ||||||
|             .into_iter() |             .into_iter() | ||||||
|             .merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| { |             .zip(std::iter::repeat(ResultSource::Semantic)) | ||||||
|  |             .merge_by( | ||||||
|  |                 keyword_results | ||||||
|  |                     .document_scores | ||||||
|  |                     .into_iter() | ||||||
|  |                     .zip(std::iter::repeat(ResultSource::Keyword)), | ||||||
|  |                 |((_, left), _), ((_, right), _)| { | ||||||
|                     // the first value is the one with the greatest score |                     // the first value is the one with the greatest score | ||||||
|                     compare_scores(left, right).is_ge() |                     compare_scores(left, right).is_ge() | ||||||
|             }) |                 }, | ||||||
|  |             ) | ||||||
|             // remove documents we already saw |             // remove documents we already saw | ||||||
|             .filter(|(docid, _)| documents_seen.insert(*docid)) |             .filter(|((docid, _), _)| documents_seen.insert(*docid)) | ||||||
|             // start skipping **after** the filter |             // start skipping **after** the filter | ||||||
|             .skip(from) |             .skip(from) | ||||||
|             // take **after** skipping |             // take **after** skipping | ||||||
|             .take(length) |             .take(length) | ||||||
|         { |         { | ||||||
|  |             if let ResultSource::Semantic = source { | ||||||
|  |                 semantic_hit_count += 1; | ||||||
|  |             } | ||||||
|             documents_ids.push(docid); |             documents_ids.push(docid); | ||||||
|             // TODO: pass both scores to documents_score in some way? |             // TODO: pass both scores to documents_score in some way? | ||||||
|             document_scores.push(main_score); |             document_scores.push(main_score); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         ( | ||||||
|             SearchResult { |             SearchResult { | ||||||
|             matching_words: right.matching_words, |                 matching_words: keyword_results.matching_words, | ||||||
|             candidates: left.candidates | right.candidates, |                 candidates: vector_results.candidates | keyword_results.candidates, | ||||||
|                 documents_ids, |                 documents_ids, | ||||||
|                 document_scores, |                 document_scores, | ||||||
|             degraded: left.degraded | right.degraded, |                 degraded: vector_results.degraded | keyword_results.degraded, | ||||||
|             used_negative_operator: left.used_negative_operator | right.used_negative_operator, |                 used_negative_operator: vector_results.used_negative_operator | ||||||
|         } |                     | keyword_results.used_negative_operator, | ||||||
|  |             }, | ||||||
|  |             semantic_hit_count, | ||||||
|  |         ) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<'a> Search<'a> { | impl<'a> Search<'a> { | ||||||
|     pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> { |     pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<(SearchResult, Option<u32>)> { | ||||||
|         // TODO: find classier way to achieve that than to reset vector and query params |         // TODO: find classier way to achieve that than to reset vector and query params | ||||||
|         // create separate keyword and semantic searches |         // create separate keyword and semantic searches | ||||||
|         let mut search = Search { |         let mut search = Search { | ||||||
|             query: self.query.clone(), |             query: self.query.clone(), | ||||||
|             vector: self.vector.clone(), |  | ||||||
|             filter: self.filter.clone(), |             filter: self.filter.clone(), | ||||||
|             offset: 0, |             offset: 0, | ||||||
|             limit: self.limit + self.offset, |             limit: self.limit + self.offset, | ||||||
| @@ -139,26 +167,43 @@ impl<'a> Search<'a> { | |||||||
|             exhaustive_number_hits: self.exhaustive_number_hits, |             exhaustive_number_hits: self.exhaustive_number_hits, | ||||||
|             rtxn: self.rtxn, |             rtxn: self.rtxn, | ||||||
|             index: self.index, |             index: self.index, | ||||||
|             distribution_shift: self.distribution_shift, |             semantic: self.semantic.clone(), | ||||||
|             embedder_name: self.embedder_name.clone(), |  | ||||||
|             time_budget: self.time_budget.clone(), |             time_budget: self.time_budget.clone(), | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         let vector_query = search.vector.take(); |         let semantic = search.semantic.take(); | ||||||
|         let keyword_results = search.execute()?; |         let keyword_results = search.execute()?; | ||||||
|  |  | ||||||
|         // skip semantic search if we don't have a vector query (placeholder search) |  | ||||||
|         let Some(vector_query) = vector_query else { |  | ||||||
|             return Ok(keyword_results); |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         // completely skip semantic search if the results of the keyword search are good enough |         // completely skip semantic search if the results of the keyword search are good enough | ||||||
|         if self.results_good_enough(&keyword_results, semantic_ratio) { |         if self.results_good_enough(&keyword_results, semantic_ratio) { | ||||||
|             return Ok(keyword_results); |             return Ok((keyword_results, Some(0))); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         search.vector = Some(vector_query); |         // no vector search against placeholder search | ||||||
|         search.query = None; |         let Some(query) = search.query.take() else { | ||||||
|  |             return Ok((keyword_results, Some(0))); | ||||||
|  |         }; | ||||||
|  |         // no embedder, no semantic search | ||||||
|  |         let Some(SemanticSearch { vector, embedder_name, embedder }) = semantic else { | ||||||
|  |             return Ok((keyword_results, Some(0))); | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         let vector_query = match vector { | ||||||
|  |             Some(vector_query) => vector_query, | ||||||
|  |             None => { | ||||||
|  |                 // attempt to embed the vector | ||||||
|  |                 match embedder.embed_one(query) { | ||||||
|  |                     Ok(embedding) => embedding, | ||||||
|  |                     Err(error) => { | ||||||
|  |                         tracing::error!(error=%error, "Embedding failed"); | ||||||
|  |                         return Ok((keyword_results, Some(0))); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         search.semantic = | ||||||
|  |             Some(SemanticSearch { vector: Some(vector_query), embedder_name, embedder }); | ||||||
|  |  | ||||||
|         // TODO: would be better to have two distinct functions at this point |         // TODO: would be better to have two distinct functions at this point | ||||||
|         let vector_results = search.execute()?; |         let vector_results = search.execute()?; | ||||||
| @@ -166,10 +211,10 @@ impl<'a> Search<'a> { | |||||||
|         let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); |         let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); | ||||||
|         let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); |         let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); | ||||||
|  |  | ||||||
|         let merge_results = |         let (merge_results, semantic_hit_count) = | ||||||
|             ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); |             ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); | ||||||
|         assert!(merge_results.documents_ids.len() <= self.limit); |         assert!(merge_results.documents_ids.len() <= self.limit); | ||||||
|         Ok(merge_results) |         Ok((merge_results, Some(semantic_hit_count))) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool { |     fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool { | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| use std::fmt; | use std::fmt; | ||||||
|  | use std::sync::Arc; | ||||||
|  |  | ||||||
| use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; | use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; | ||||||
| use once_cell::sync::Lazy; | use once_cell::sync::Lazy; | ||||||
| @@ -8,7 +9,7 @@ pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FAC | |||||||
| pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; | pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; | ||||||
| use self::new::{execute_vector_search, PartialSearchResult}; | use self::new::{execute_vector_search, PartialSearchResult}; | ||||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||||
| use crate::vector::DistributionShift; | use crate::vector::Embedder; | ||||||
| use crate::{ | use crate::{ | ||||||
|     execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Index, Result, |     execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Index, Result, | ||||||
|     SearchContext, TimeBudget, |     SearchContext, TimeBudget, | ||||||
| @@ -24,9 +25,15 @@ mod fst_utils; | |||||||
| pub mod hybrid; | pub mod hybrid; | ||||||
| pub mod new; | pub mod new; | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | pub struct SemanticSearch { | ||||||
|  |     vector: Option<Vec<f32>>, | ||||||
|  |     embedder_name: String, | ||||||
|  |     embedder: Arc<Embedder>, | ||||||
|  | } | ||||||
|  |  | ||||||
| pub struct Search<'a> { | pub struct Search<'a> { | ||||||
|     query: Option<String>, |     query: Option<String>, | ||||||
|     vector: Option<Vec<f32>>, |  | ||||||
|     // this should be linked to the String in the query |     // this should be linked to the String in the query | ||||||
|     filter: Option<Filter<'a>>, |     filter: Option<Filter<'a>>, | ||||||
|     offset: usize, |     offset: usize, | ||||||
| @@ -38,12 +45,9 @@ pub struct Search<'a> { | |||||||
|     scoring_strategy: ScoringStrategy, |     scoring_strategy: ScoringStrategy, | ||||||
|     words_limit: usize, |     words_limit: usize, | ||||||
|     exhaustive_number_hits: bool, |     exhaustive_number_hits: bool, | ||||||
|     /// TODO: Add semantic ratio or pass it directly to execute_hybrid() |  | ||||||
|     rtxn: &'a heed::RoTxn<'a>, |     rtxn: &'a heed::RoTxn<'a>, | ||||||
|     index: &'a Index, |     index: &'a Index, | ||||||
|     distribution_shift: Option<DistributionShift>, |     semantic: Option<SemanticSearch>, | ||||||
|     embedder_name: Option<String>, |  | ||||||
|  |  | ||||||
|     time_budget: TimeBudget, |     time_budget: TimeBudget, | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -51,7 +55,6 @@ impl<'a> Search<'a> { | |||||||
|     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { |     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { | ||||||
|         Search { |         Search { | ||||||
|             query: None, |             query: None, | ||||||
|             vector: None, |  | ||||||
|             filter: None, |             filter: None, | ||||||
|             offset: 0, |             offset: 0, | ||||||
|             limit: 20, |             limit: 20, | ||||||
| @@ -64,8 +67,7 @@ impl<'a> Search<'a> { | |||||||
|             words_limit: 10, |             words_limit: 10, | ||||||
|             rtxn, |             rtxn, | ||||||
|             index, |             index, | ||||||
|             distribution_shift: None, |             semantic: None, | ||||||
|             embedder_name: None, |  | ||||||
|             time_budget: TimeBudget::max(), |             time_budget: TimeBudget::max(), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -75,8 +77,13 @@ impl<'a> Search<'a> { | |||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> { |     pub fn semantic( | ||||||
|         self.vector = Some(vector); |         &mut self, | ||||||
|  |         embedder_name: String, | ||||||
|  |         embedder: Arc<Embedder>, | ||||||
|  |         vector: Option<Vec<f32>>, | ||||||
|  |     ) -> &mut Search<'a> { | ||||||
|  |         self.semantic = Some(SemanticSearch { embedder_name, embedder, vector }); | ||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -133,19 +140,6 @@ impl<'a> Search<'a> { | |||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn distribution_shift( |  | ||||||
|         &mut self, |  | ||||||
|         distribution_shift: Option<DistributionShift>, |  | ||||||
|     ) -> &mut Search<'a> { |  | ||||||
|         self.distribution_shift = distribution_shift; |  | ||||||
|         self |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> { |  | ||||||
|         self.embedder_name = Some(embedder_name.into()); |  | ||||||
|         self |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn time_budget(&mut self, time_budget: TimeBudget) -> &mut Search<'a> { |     pub fn time_budget(&mut self, time_budget: TimeBudget) -> &mut Search<'a> { | ||||||
|         self.time_budget = time_budget; |         self.time_budget = time_budget; | ||||||
|         self |         self | ||||||
| @@ -161,15 +155,6 @@ impl<'a> Search<'a> { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn execute(&self) -> Result<SearchResult> { |     pub fn execute(&self) -> Result<SearchResult> { | ||||||
|         let embedder_name; |  | ||||||
|         let embedder_name = match &self.embedder_name { |  | ||||||
|             Some(embedder_name) => embedder_name, |  | ||||||
|             None => { |  | ||||||
|                 embedder_name = self.index.default_embedding_name(self.rtxn)?; |  | ||||||
|                 &embedder_name |  | ||||||
|             } |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         let mut ctx = SearchContext::new(self.index, self.rtxn); |         let mut ctx = SearchContext::new(self.index, self.rtxn); | ||||||
|  |  | ||||||
|         if let Some(searchable_attributes) = self.searchable_attributes { |         if let Some(searchable_attributes) = self.searchable_attributes { | ||||||
| @@ -184,8 +169,9 @@ impl<'a> Search<'a> { | |||||||
|             document_scores, |             document_scores, | ||||||
|             degraded, |             degraded, | ||||||
|             used_negative_operator, |             used_negative_operator, | ||||||
|         } = match self.vector.as_ref() { |         } = match self.semantic.as_ref() { | ||||||
|             Some(vector) => execute_vector_search( |             Some(SemanticSearch { vector: Some(vector), embedder_name, embedder }) => { | ||||||
|  |                 execute_vector_search( | ||||||
|                     &mut ctx, |                     &mut ctx, | ||||||
|                     vector, |                     vector, | ||||||
|                     self.scoring_strategy, |                     self.scoring_strategy, | ||||||
| @@ -194,11 +180,12 @@ impl<'a> Search<'a> { | |||||||
|                     self.geo_strategy, |                     self.geo_strategy, | ||||||
|                     self.offset, |                     self.offset, | ||||||
|                     self.limit, |                     self.limit, | ||||||
|                 self.distribution_shift, |  | ||||||
|                     embedder_name, |                     embedder_name, | ||||||
|  |                     embedder, | ||||||
|                     self.time_budget.clone(), |                     self.time_budget.clone(), | ||||||
|             )?, |                 )? | ||||||
|             None => execute_search( |             } | ||||||
|  |             _ => execute_search( | ||||||
|                 &mut ctx, |                 &mut ctx, | ||||||
|                 self.query.as_deref(), |                 self.query.as_deref(), | ||||||
|                 self.terms_matching_strategy, |                 self.terms_matching_strategy, | ||||||
| @@ -237,7 +224,6 @@ impl fmt::Debug for Search<'_> { | |||||||
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||||
|         let Search { |         let Search { | ||||||
|             query, |             query, | ||||||
|             vector: _, |  | ||||||
|             filter, |             filter, | ||||||
|             offset, |             offset, | ||||||
|             limit, |             limit, | ||||||
| @@ -250,8 +236,7 @@ impl fmt::Debug for Search<'_> { | |||||||
|             exhaustive_number_hits, |             exhaustive_number_hits, | ||||||
|             rtxn: _, |             rtxn: _, | ||||||
|             index: _, |             index: _, | ||||||
|             distribution_shift, |             semantic, | ||||||
|             embedder_name, |  | ||||||
|             time_budget, |             time_budget, | ||||||
|         } = self; |         } = self; | ||||||
|         f.debug_struct("Search") |         f.debug_struct("Search") | ||||||
| @@ -266,8 +251,10 @@ impl fmt::Debug for Search<'_> { | |||||||
|             .field("scoring_strategy", scoring_strategy) |             .field("scoring_strategy", scoring_strategy) | ||||||
|             .field("exhaustive_number_hits", exhaustive_number_hits) |             .field("exhaustive_number_hits", exhaustive_number_hits) | ||||||
|             .field("words_limit", words_limit) |             .field("words_limit", words_limit) | ||||||
|             .field("distribution_shift", distribution_shift) |             .field( | ||||||
|             .field("embedder_name", embedder_name) |                 "semantic.embedder_name", | ||||||
|  |                 &semantic.as_ref().map(|semantic| &semantic.embedder_name), | ||||||
|  |             ) | ||||||
|             .field("time_budget", time_budget) |             .field("time_budget", time_budget) | ||||||
|             .finish() |             .finish() | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -52,7 +52,7 @@ use self::vector_sort::VectorSort; | |||||||
| use crate::error::FieldIdMapMissingEntry; | use crate::error::FieldIdMapMissingEntry; | ||||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||||
| use crate::search::new::distinct::apply_distinct_rule; | use crate::search::new::distinct::apply_distinct_rule; | ||||||
| use crate::vector::DistributionShift; | use crate::vector::Embedder; | ||||||
| use crate::{ | use crate::{ | ||||||
|     AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, TimeBudget, |     AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, TimeBudget, | ||||||
|     UserError, |     UserError, | ||||||
| @@ -298,8 +298,8 @@ fn get_ranking_rules_for_vector<'ctx>( | |||||||
|     geo_strategy: geo_sort::Strategy, |     geo_strategy: geo_sort::Strategy, | ||||||
|     limit_plus_offset: usize, |     limit_plus_offset: usize, | ||||||
|     target: &[f32], |     target: &[f32], | ||||||
|     distribution_shift: Option<DistributionShift>, |  | ||||||
|     embedder_name: &str, |     embedder_name: &str, | ||||||
|  |     embedder: &Embedder, | ||||||
| ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> { | ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> { | ||||||
|     // query graph search |     // query graph search | ||||||
|  |  | ||||||
| @@ -325,8 +325,8 @@ fn get_ranking_rules_for_vector<'ctx>( | |||||||
|                         target.to_vec(), |                         target.to_vec(), | ||||||
|                         vector_candidates, |                         vector_candidates, | ||||||
|                         limit_plus_offset, |                         limit_plus_offset, | ||||||
|                         distribution_shift, |  | ||||||
|                         embedder_name, |                         embedder_name, | ||||||
|  |                         embedder, | ||||||
|                     )?; |                     )?; | ||||||
|                     ranking_rules.push(Box::new(vector_sort)); |                     ranking_rules.push(Box::new(vector_sort)); | ||||||
|                     vector = true; |                     vector = true; | ||||||
| @@ -548,8 +548,8 @@ pub fn execute_vector_search( | |||||||
|     geo_strategy: geo_sort::Strategy, |     geo_strategy: geo_sort::Strategy, | ||||||
|     from: usize, |     from: usize, | ||||||
|     length: usize, |     length: usize, | ||||||
|     distribution_shift: Option<DistributionShift>, |  | ||||||
|     embedder_name: &str, |     embedder_name: &str, | ||||||
|  |     embedder: &Embedder, | ||||||
|     time_budget: TimeBudget, |     time_budget: TimeBudget, | ||||||
| ) -> Result<PartialSearchResult> { | ) -> Result<PartialSearchResult> { | ||||||
|     check_sort_criteria(ctx, sort_criteria.as_ref())?; |     check_sort_criteria(ctx, sort_criteria.as_ref())?; | ||||||
| @@ -562,8 +562,8 @@ pub fn execute_vector_search( | |||||||
|         geo_strategy, |         geo_strategy, | ||||||
|         from + length, |         from + length, | ||||||
|         vector, |         vector, | ||||||
|         distribution_shift, |  | ||||||
|         embedder_name, |         embedder_name, | ||||||
|  |         embedder, | ||||||
|     )?; |     )?; | ||||||
|  |  | ||||||
|     let mut placeholder_search_logger = logger::DefaultSearchLogger; |     let mut placeholder_search_logger = logger::DefaultSearchLogger; | ||||||
|   | |||||||
| @@ -5,14 +5,14 @@ use roaring::RoaringBitmap; | |||||||
|  |  | ||||||
| use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; | use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; | ||||||
| use crate::score_details::{self, ScoreDetails}; | use crate::score_details::{self, ScoreDetails}; | ||||||
| use crate::vector::DistributionShift; | use crate::vector::{DistributionShift, Embedder}; | ||||||
| use crate::{DocumentId, Result, SearchContext, SearchLogger}; | use crate::{DocumentId, Result, SearchContext, SearchLogger}; | ||||||
|  |  | ||||||
| pub struct VectorSort<Q: RankingRuleQueryTrait> { | pub struct VectorSort<Q: RankingRuleQueryTrait> { | ||||||
|     query: Option<Q>, |     query: Option<Q>, | ||||||
|     target: Vec<f32>, |     target: Vec<f32>, | ||||||
|     vector_candidates: RoaringBitmap, |     vector_candidates: RoaringBitmap, | ||||||
|     cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>, |     cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32)>, | ||||||
|     limit: usize, |     limit: usize, | ||||||
|     distribution_shift: Option<DistributionShift>, |     distribution_shift: Option<DistributionShift>, | ||||||
|     embedder_index: u8, |     embedder_index: u8, | ||||||
| @@ -24,8 +24,8 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> { | |||||||
|         target: Vec<f32>, |         target: Vec<f32>, | ||||||
|         vector_candidates: RoaringBitmap, |         vector_candidates: RoaringBitmap, | ||||||
|         limit: usize, |         limit: usize, | ||||||
|         distribution_shift: Option<DistributionShift>, |  | ||||||
|         embedder_name: &str, |         embedder_name: &str, | ||||||
|  |         embedder: &Embedder, | ||||||
|     ) -> Result<Self> { |     ) -> Result<Self> { | ||||||
|         let embedder_index = ctx |         let embedder_index = ctx | ||||||
|             .index |             .index | ||||||
| @@ -39,7 +39,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> { | |||||||
|             vector_candidates, |             vector_candidates, | ||||||
|             cached_sorted_docids: Default::default(), |             cached_sorted_docids: Default::default(), | ||||||
|             limit, |             limit, | ||||||
|             distribution_shift, |             distribution_shift: embedder.distribution(), | ||||||
|             embedder_index, |             embedder_index, | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
| @@ -70,14 +70,9 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> { | |||||||
|         for reader in readers.iter() { |         for reader in readers.iter() { | ||||||
|             let nns_by_vector = |             let nns_by_vector = | ||||||
|                 reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?; |                 reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?; | ||||||
|             let vectors: std::result::Result<Vec<_>, _> = nns_by_vector |             results.extend(nns_by_vector.into_iter()); | ||||||
|                 .iter() |  | ||||||
|                 .map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap()) |  | ||||||
|                 .collect(); |  | ||||||
|             let vectors = vectors?; |  | ||||||
|             results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z))); |  | ||||||
|         } |         } | ||||||
|         results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance)); |         results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance)); | ||||||
|         self.cached_sorted_docids = results.into_iter(); |         self.cached_sorted_docids = results.into_iter(); | ||||||
|  |  | ||||||
|         Ok(()) |         Ok(()) | ||||||
| @@ -118,14 +113,11 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> { | |||||||
|             return Ok(Some(RankingRuleOutput { |             return Ok(Some(RankingRuleOutput { | ||||||
|                 query, |                 query, | ||||||
|                 candidates: universe.clone(), |                 candidates: universe.clone(), | ||||||
|                 score: ScoreDetails::Vector(score_details::Vector { |                 score: ScoreDetails::Vector(score_details::Vector { similarity: None }), | ||||||
|                     target_vector: self.target.clone(), |  | ||||||
|                     value_similarity: None, |  | ||||||
|                 }), |  | ||||||
|             })); |             })); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         for (docid, distance, vector) in self.cached_sorted_docids.by_ref() { |         for (docid, distance) in self.cached_sorted_docids.by_ref() { | ||||||
|             if vector_candidates.contains(docid) { |             if vector_candidates.contains(docid) { | ||||||
|                 let score = 1.0 - distance; |                 let score = 1.0 - distance; | ||||||
|                 let score = self |                 let score = self | ||||||
| @@ -135,10 +127,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> { | |||||||
|                 return Ok(Some(RankingRuleOutput { |                 return Ok(Some(RankingRuleOutput { | ||||||
|                     query, |                     query, | ||||||
|                     candidates: RoaringBitmap::from_iter([docid]), |                     candidates: RoaringBitmap::from_iter([docid]), | ||||||
|                     score: ScoreDetails::Vector(score_details::Vector { |                     score: ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }), | ||||||
|                         target_vector: self.target.clone(), |  | ||||||
|                         value_similarity: Some((vector, score)), |  | ||||||
|                     }), |  | ||||||
|                 })); |                 })); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -154,10 +143,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> { | |||||||
|             return Ok(Some(RankingRuleOutput { |             return Ok(Some(RankingRuleOutput { | ||||||
|                 query, |                 query, | ||||||
|                 candidates: universe.clone(), |                 candidates: universe.clone(), | ||||||
|                 score: ScoreDetails::Vector(score_details::Vector { |                 score: ScoreDetails::Vector(score_details::Vector { similarity: None }), | ||||||
|                     target_vector: self.target.clone(), |  | ||||||
|                     value_similarity: None, |  | ||||||
|                 }), |  | ||||||
|             })); |             })); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2672,7 +2672,16 @@ mod tests { | |||||||
|                .unwrap(); |                .unwrap(); | ||||||
|  |  | ||||||
|         let rtxn = index.read_txn().unwrap(); |         let rtxn = index.read_txn().unwrap(); | ||||||
|         let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap(); |         let mut embedding_configs = index.embedding_configs(&rtxn).unwrap(); | ||||||
|  |         let (embedder_name, embedder) = embedding_configs.pop().unwrap(); | ||||||
|  |         let embedder = | ||||||
|  |             std::sync::Arc::new(crate::vector::Embedder::new(embedder.embedder_options).unwrap()); | ||||||
|  |         assert_eq!("manual", embedder_name); | ||||||
|  |         let res = index | ||||||
|  |             .search(&rtxn) | ||||||
|  |             .semantic(embedder_name, embedder, Some([0.0, 1.0, 2.0].to_vec())) | ||||||
|  |             .execute() | ||||||
|  |             .unwrap(); | ||||||
|         assert_eq!(res.documents_ids.len(), 3); |         assert_eq!(res.documents_ids.len(), 3); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -58,7 +58,7 @@ pub enum EmbedErrorKind { | |||||||
|     RestResponseDeserialization(std::io::Error), |     RestResponseDeserialization(std::io::Error), | ||||||
|     #[error("component `{0}` not found in path `{1}` in response: `{2}`")] |     #[error("component `{0}` not found in path `{1}` in response: `{2}`")] | ||||||
|     RestResponseMissingEmbeddings(String, String, String), |     RestResponseMissingEmbeddings(String, String, String), | ||||||
|     #[error("expected a response parseable as a vector or an array of vectors: {0}")] |     #[error("unexpected format of the embedding response: {0}")] | ||||||
|     RestResponseFormat(serde_json::Error), |     RestResponseFormat(serde_json::Error), | ||||||
|     #[error("expected a response containing {0} embeddings, got only {1}")] |     #[error("expected a response containing {0} embeddings, got only {1}")] | ||||||
|     RestResponseEmbeddingCount(usize, usize), |     RestResponseEmbeddingCount(usize, usize), | ||||||
| @@ -78,6 +78,8 @@ pub enum EmbedErrorKind { | |||||||
|     RestNotAnObject(serde_json::Value, Vec<String>), |     RestNotAnObject(serde_json::Value, Vec<String>), | ||||||
|     #[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")] |     #[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")] | ||||||
|     OpenAiUnexpectedDimension(usize, usize), |     OpenAiUnexpectedDimension(usize, usize), | ||||||
|  |     #[error("no embedding was produced")] | ||||||
|  |     MissingEmbedding, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl EmbedError { | impl EmbedError { | ||||||
| @@ -190,6 +192,9 @@ impl EmbedError { | |||||||
|             fault: FaultSource::Runtime, |             fault: FaultSource::Runtime, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |     pub(crate) fn missing_embedding() -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::MissingEmbedding, fault: FaultSource::Undecided } | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug, thiserror::Error)] | #[derive(Debug, thiserror::Error)] | ||||||
|   | |||||||
| @@ -143,7 +143,7 @@ impl EmbeddingConfigs { | |||||||
|  |  | ||||||
|     /// Get the default embedder configuration, if any. |     /// Get the default embedder configuration, if any. | ||||||
|     pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> { |     pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> { | ||||||
|         self.get_default_embedder_name().and_then(|default| self.get(&default)) |         self.get(self.get_default_embedder_name()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Get the name of the default embedder configuration. |     /// Get the name of the default embedder configuration. | ||||||
| @@ -153,14 +153,14 @@ impl EmbeddingConfigs { | |||||||
|     /// - If there is only one embedder, it is always the default. |     /// - If there is only one embedder, it is always the default. | ||||||
|     /// - If there are multiple embedders and one of them is called `default`, then that one is the default embedder. |     /// - If there are multiple embedders and one of them is called `default`, then that one is the default embedder. | ||||||
|     /// - In all other cases, there is no default embedder. |     /// - In all other cases, there is no default embedder. | ||||||
|     pub fn get_default_embedder_name(&self) -> Option<String> { |     pub fn get_default_embedder_name(&self) -> &str { | ||||||
|         let mut it = self.0.keys(); |         let mut it = self.0.keys(); | ||||||
|         let first_name = it.next(); |         let first_name = it.next(); | ||||||
|         let second_name = it.next(); |         let second_name = it.next(); | ||||||
|         match (first_name, second_name) { |         match (first_name, second_name) { | ||||||
|             (None, _) => None, |             (None, _) => "default", | ||||||
|             (Some(first), None) => Some(first.to_owned()), |             (Some(first), None) => first, | ||||||
|             (Some(_), Some(_)) => Some("default".to_owned()), |             (Some(_), Some(_)) => "default", | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -237,6 +237,17 @@ impl Embedder { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> { | ||||||
|  |         let mut embeddings = self.embed(vec![text])?; | ||||||
|  |         let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?; | ||||||
|  |         Ok(if embeddings.iter().nth(1).is_some() { | ||||||
|  |             tracing::warn!("Ignoring embeddings past the first one in long search query"); | ||||||
|  |             embeddings.iter().next().unwrap().to_vec() | ||||||
|  |         } else { | ||||||
|  |             embeddings.into_inner() | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Embed multiple chunks of texts. |     /// Embed multiple chunks of texts. | ||||||
|     /// |     /// | ||||||
|     /// Each chunk is composed of one or multiple texts. |     /// Each chunk is composed of one or multiple texts. | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user