diff --git a/crates/meilisearch/src/search/federated/perform.rs b/crates/meilisearch/src/search/federated/perform.rs index b747d4da4..4c66e5f68 100644 --- a/crates/meilisearch/src/search/federated/perform.rs +++ b/crates/meilisearch/src/search/federated/perform.rs @@ -47,6 +47,7 @@ pub async fn perform_federated_search( let deadline = before_search + std::time::Duration::from_secs(9); let required_hit_count = federation.limit + federation.offset; + let retrieve_vectors = queries.iter().any(|q| q.retrieve_vectors); let network = index_scheduler.network(); @@ -92,6 +93,7 @@ pub async fn perform_federated_search( federation, mut semantic_hit_count, mut results_by_index, + mut query_vectors, previous_query_data: _, facet_order, } = search_by_index; @@ -123,7 +125,26 @@ pub async fn perform_federated_search( .map(|hit| hit.hit()) .collect(); - // 3.3. merge facets + // 3.3. merge query vectors + let query_vectors = if retrieve_vectors { + for remote_results in remote_results.iter_mut() { + if let Some(remote_vectors) = remote_results.query_vectors.take() { + for (key, value) in remote_vectors.into_iter() { + debug_assert!( + !query_vectors.contains_key(&key), + "Query vector for query {key} already exists" + ); + query_vectors.insert(key, value); + } + } + } + + Some(query_vectors) + } else { + None + }; + + // 3.4. merge facets let (facet_distribution, facet_stats, facets_by_index) = facet_order.merge(federation.merge_facets, remote_results, facets); @@ -141,6 +162,7 @@ pub async fn perform_federated_search( offset: federation.offset, estimated_total_hits, }, + query_vectors, semantic_hit_count, degraded, used_negative_operator, @@ -409,6 +431,7 @@ fn merge_metadata( hits: _, processing_time_ms, hits_info, + query_vectors: _, semantic_hit_count: _, facet_distribution: _, facet_stats: _, @@ -658,6 +681,7 @@ struct SearchByIndex { // Then when merging, we'll update its value if there is any semantic hit semantic_hit_count: Option, results_by_index: Vec, + query_vectors: BTreeMap, previous_query_data: Option<(RankingRules, usize, String)>, // remember the order and name of first index for each facet when merging with index settings // to detect if the order is inconsistent for a facet. @@ -675,6 +699,7 @@ impl SearchByIndex { federation, semantic_hit_count: None, results_by_index: Vec::with_capacity(index_count), + query_vectors: BTreeMap::new(), previous_query_data: None, } } @@ -842,6 +867,16 @@ impl SearchByIndex { query_vector, } = result; + if query.retrieve_vectors { + if let Some(query_vector) = query_vector { + debug_assert!( + !self.query_vectors.contains_key(&query_index), + "Query vector for query {query_index} already exists" + ); + self.query_vectors.insert(query_index, query_vector); + } + } + candidates |= query_candidates; degraded |= query_degraded; used_negative_operator |= query_used_negative_operator; diff --git a/crates/meilisearch/src/search/federated/types.rs b/crates/meilisearch/src/search/federated/types.rs index 3cf28c815..9c96fe768 100644 --- a/crates/meilisearch/src/search/federated/types.rs +++ b/crates/meilisearch/src/search/federated/types.rs @@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex}; +use crate::milli::vector::Embedding; pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0; @@ -117,6 +118,9 @@ pub struct FederatedSearchResult { #[serde(flatten)] pub hits_info: HitsInfo, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub query_vectors: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] pub semantic_hit_count: Option, @@ -144,6 +148,7 @@ impl fmt::Debug for FederatedSearchResult { hits, processing_time_ms, hits_info, + query_vectors, semantic_hit_count, degraded, used_negative_operator, @@ -158,6 +163,10 @@ impl fmt::Debug for FederatedSearchResult { debug.field("processing_time_ms", &processing_time_ms); debug.field("hits", &format!("[{} hits returned]", hits.len())); debug.field("hits_info", &hits_info); + if let Some(query_vectors) = query_vectors { + let known = query_vectors.len(); + debug.field("query_vectors", &format!("[{known} known vectors]")); + } if *used_negative_operator { debug.field("used_negative_operator", used_negative_operator); } diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs index 1329a6f72..e08a13a7b 100644 --- a/crates/meilisearch/src/search/mod.rs +++ b/crates/meilisearch/src/search/mod.rs @@ -1020,7 +1020,7 @@ pub fn prepare_search<'t>( .map_err(milli::Error::from)? } }; - search.semantic( + search.semantic_auto_embedded( embedder_name.clone(), embedder.clone(), *quantized, diff --git a/crates/meilisearch/tests/search/multi/proxy.rs b/crates/meilisearch/tests/search/multi/proxy.rs index 311f69d9e..8f2741202 100644 --- a/crates/meilisearch/tests/search/multi/proxy.rs +++ b/crates/meilisearch/tests/search/multi/proxy.rs @@ -2,8 +2,9 @@ use std::sync::Arc; use actix_http::StatusCode; use meili_snap::{json_string, snapshot}; -use wiremock::matchers::AnyMatcher; -use wiremock::{Mock, MockServer, ResponseTemplate}; +use wiremock::matchers::method; +use wiremock::matchers::{path, AnyMatcher}; +use wiremock::{Mock, MockServer, Request, ResponseTemplate}; use crate::common::{Server, Value, SCORE_DOCUMENTS}; use crate::json; @@ -415,6 +416,209 @@ async fn remote_sharding() { "###); } +#[actix_rt::test] +async fn remote_sharding_retrieve_vectors() { + let ms0 = Server::new().await; + let ms1 = Server::new().await; + let ms2 = Server::new().await; + let index0 = ms0.index("test"); + let index1 = ms1.index("test"); + let index2 = ms2.index("test"); + + // enable feature + + let (response, code) = ms0.set_features(json!({"network": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["network"]), @"true"); + let (response, code) = ms1.set_features(json!({"network": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["network"]), @"true"); + let (response, code) = ms2.set_features(json!({"network": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["network"]), @"true"); + + // set self + + let (response, code) = ms0.set_network(json!({"self": "ms0"})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response), @r###" + { + "self": "ms0", + "remotes": {} + } + "###); + let (response, code) = ms1.set_network(json!({"self": "ms1"})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response), @r###" + { + "self": "ms1", + "remotes": {} + } + "###); + let (response, code) = ms2.set_network(json!({"self": "ms2"})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response), @r###" + { + "self": "ms2", + "remotes": {} + } + "###); + + // setup embedders + + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/")) + .respond_with(move |req: &Request| { + println!("Received request: {:?}", req); + let text = req.body_json::().unwrap().to_lowercase(); + let patterns = [ + ("batman", [1.0, 0.0, 0.0]), + ("dark", [0.0, 0.1, 0.0]), + ("knight", [0.1, 0.1, 0.0]), + ("returns", [0.0, 0.0, 0.2]), + ("part", [0.05, 0.1, 0.0]), + ("1", [0.3, 0.05, 0.0]), + ("2", [0.2, 0.05, 0.0]), + ]; + let mut embedding = vec![0.; 3]; + for (pattern, vector) in patterns { + if text.contains(pattern) { + for (i, v) in vector.iter().enumerate() { + embedding[i] += v; + } + } + } + ResponseTemplate::new(200).set_body_json(json!({ "data": embedding })) + }) + .mount(&mock_server) + .await; + let url = mock_server.uri(); + + for (server, index) in [(&ms0, &index0), (&ms1, &index1), (&ms2, &index2)] { + let (response, code) = index + .update_settings(json!({ + "embedders": { + "rest": { + "source": "rest", + "url": url, + "dimensions": 3, + "request": "{{text}}", + "response": { "data": "{{embedding}}" }, + "documentTemplate": "{{doc.name}}", + }, + }, + })) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await.succeeded(); + } + + // wrap servers + let ms0 = Arc::new(ms0); + let ms1 = Arc::new(ms1); + let ms2 = Arc::new(ms2); + + let rms0 = LocalMeili::new(ms0.clone()).await; + let rms1 = LocalMeili::new(ms1.clone()).await; + let rms2 = LocalMeili::new(ms2.clone()).await; + + // set network + let network = json!({"remotes": { + "ms0": { + "url": rms0.url() + }, + "ms1": { + "url": rms1.url() + }, + "ms2": { + "url": rms2.url() + } + }}); + + let (_response, status_code) = ms0.set_network(network.clone()).await; + snapshot!(status_code, @"200 OK"); + let (_response, status_code) = ms1.set_network(network.clone()).await; + snapshot!(status_code, @"200 OK"); + let (_response, status_code) = ms2.set_network(network.clone()).await; + snapshot!(status_code, @"200 OK"); + + // perform multi-search + let query = "badman returns"; + let request = json!({ + "federation": {}, + "queries": [ + { + "q": query, + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms0" + } + }, + { + "q": query, + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + { + "q": query, + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms2" + } + }, + ] + }); + + let (response, _status_code) = ms0.multi_search(request.clone()).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#" + { + "hits": [], + "processingTimeMs": "[time]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 0, + "queryVectors": { + "0": [ + 0.0, + 0.0, + 0.2 + ], + "1": [ + 0.0, + 0.0, + 0.2 + ], + "2": [ + 0.0, + 0.0, + 0.2 + ] + }, + "semanticHitCount": 0, + "remoteErrors": {} + } + "#); +} + #[actix_rt::test] async fn error_unregistered_remote() { let ms0 = Server::new().await; diff --git a/crates/milli/src/search/hybrid.rs b/crates/milli/src/search/hybrid.rs index e5b4e6787..547d66409 100644 --- a/crates/milli/src/search/hybrid.rs +++ b/crates/milli/src/search/hybrid.rs @@ -230,7 +230,14 @@ impl Search<'_> { } // no embedder, no semantic search - let Some(SemanticSearch { vector, embedder_name, embedder, quantized, media }) = semantic + let Some(SemanticSearch { + vector, + mut auto_embedded, + embedder_name, + embedder, + quantized, + media, + }) = semantic else { return Ok(return_keyword_results(self.limit, self.offset, keyword_results)); }; @@ -253,7 +260,10 @@ impl Search<'_> { let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3); match embedder.embed_search(query, Some(deadline)) { - Ok(embedding) => embedding, + Ok(embedding) => { + auto_embedded = true; + embedding + } Err(error) => { tracing::error!(error=%error, "Embedding failed"); return Ok(return_keyword_results( @@ -268,6 +278,7 @@ impl Search<'_> { search.semantic = Some(SemanticSearch { vector: Some(vector_query.clone()), + auto_embedded, embedder_name, embedder, quantized, @@ -280,7 +291,7 @@ impl Search<'_> { let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); - let (mut merge_results, semantic_hit_count) = ScoreWithRatioResult::merge( + let (merge_results, semantic_hit_count) = ScoreWithRatioResult::merge( vector_results, keyword_results, self.offset, @@ -289,7 +300,6 @@ impl Search<'_> { search.index, search.rtxn, )?; - merge_results.query_vector = Some(vector_query); assert!(merge_results.documents_ids.len() <= self.limit); Ok((merge_results, Some(semantic_hit_count))) } diff --git a/crates/milli/src/search/mod.rs b/crates/milli/src/search/mod.rs index cd0d5bc9b..de0514b0b 100644 --- a/crates/milli/src/search/mod.rs +++ b/crates/milli/src/search/mod.rs @@ -32,6 +32,7 @@ pub mod similar; #[derive(Debug, Clone)] pub struct SemanticSearch { vector: Option>, + auto_embedded: bool, media: Option, embedder_name: String, embedder: Arc, @@ -97,7 +98,33 @@ impl<'a> Search<'a> { vector: Option, media: Option, ) -> &mut Search<'a> { - self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector, media }); + self.semantic = Some(SemanticSearch { + embedder_name, + auto_embedded: false, + embedder, + quantized, + vector, + media, + }); + self + } + + pub fn semantic_auto_embedded( + &mut self, + embedder_name: String, + embedder: Arc, + quantized: bool, + vector: Option, + media: Option, + ) -> &mut Search<'a> { + self.semantic = Some(SemanticSearch { + embedder_name, + auto_embedded: true, + embedder, + quantized, + vector, + media, + }); self } @@ -225,6 +252,7 @@ impl<'a> Search<'a> { } let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; + let mut query_vector = None; let PartialSearchResult { located_query_terms, candidates, @@ -235,26 +263,32 @@ impl<'a> Search<'a> { } = match self.semantic.as_ref() { Some(SemanticSearch { vector: Some(vector), + auto_embedded, embedder_name, embedder, quantized, media: _, - }) => execute_vector_search( - &mut ctx, - vector, - self.scoring_strategy, - universe, - &self.sort_criteria, - &self.distinct, - self.geo_param, - self.offset, - self.limit, - embedder_name, - embedder, - *quantized, - self.time_budget.clone(), - self.ranking_score_threshold, - )?, + }) => { + if *auto_embedded { + query_vector = Some(vector.clone()); + } + execute_vector_search( + &mut ctx, + vector, + self.scoring_strategy, + universe, + &self.sort_criteria, + &self.distinct, + self.geo_param, + self.offset, + self.limit, + embedder_name, + embedder, + *quantized, + self.time_budget.clone(), + self.ranking_score_threshold, + )? + } _ => execute_search( &mut ctx, self.query.as_deref(), @@ -295,7 +329,7 @@ impl<'a> Search<'a> { documents_ids, degraded, used_negative_operator, - query_vector: None, + query_vector, }) } }