diff --git a/crates/meilisearch/src/routes/indexes/search_analytics.rs b/crates/meilisearch/src/routes/indexes/search_analytics.rs index 07f79eba7..2a8e78059 100644 --- a/crates/meilisearch/src/routes/indexes/search_analytics.rs +++ b/crates/meilisearch/src/routes/indexes/search_analytics.rs @@ -224,6 +224,7 @@ impl SearchAggregator { let SearchResult { hits: _, query: _, + query_vector: _, processing_time_ms, hits_info: _, semantic_hit_count: _, diff --git a/crates/meilisearch/src/search/federated/perform.rs b/crates/meilisearch/src/search/federated/perform.rs index c0fec01e8..3c80c22e3 100644 --- a/crates/meilisearch/src/search/federated/perform.rs +++ b/crates/meilisearch/src/search/federated/perform.rs @@ -13,6 +13,7 @@ use meilisearch_types::error::ResponseError; use meilisearch_types::features::{Network, Remote}; use meilisearch_types::milli::order_by_map::OrderByMap; use meilisearch_types::milli::score_details::{ScoreDetails, WeightedScoreValue}; +use meilisearch_types::milli::vector::Embedding; use meilisearch_types::milli::{self, DocumentId, OrderBy, TimeBudget, DEFAULT_VALUES_PER_FACET}; use roaring::RoaringBitmap; use tokio::task::JoinHandle; @@ -46,6 +47,7 @@ pub async fn perform_federated_search( let deadline = before_search + std::time::Duration::from_secs(9); let required_hit_count = federation.limit + federation.offset; + let retrieve_vectors = queries.iter().any(|q| q.retrieve_vectors); let network = index_scheduler.network(); @@ -91,6 +93,7 @@ pub async fn perform_federated_search( federation, mut semantic_hit_count, mut results_by_index, + mut query_vectors, previous_query_data: _, facet_order, } = search_by_index; @@ -122,7 +125,26 @@ pub async fn perform_federated_search( .map(|hit| hit.hit()) .collect(); - // 3.3. merge facets + // 3.3. merge query vectors + let query_vectors = if retrieve_vectors { + for remote_results in remote_results.iter_mut() { + if let Some(remote_vectors) = remote_results.query_vectors.take() { + for (key, value) in remote_vectors.into_iter() { + debug_assert!( + !query_vectors.contains_key(&key), + "Query vector for query {key} already exists" + ); + query_vectors.insert(key, value); + } + } + } + + Some(query_vectors) + } else { + None + }; + + // 3.4. merge facets let (facet_distribution, facet_stats, facets_by_index) = facet_order.merge(federation.merge_facets, remote_results, facets); @@ -140,6 +162,7 @@ pub async fn perform_federated_search( offset: federation.offset, estimated_total_hits, }, + query_vectors, semantic_hit_count, degraded, used_negative_operator, @@ -408,6 +431,7 @@ fn merge_metadata( hits: _, processing_time_ms, hits_info, + query_vectors: _, semantic_hit_count: _, facet_distribution: _, facet_stats: _, @@ -657,6 +681,7 @@ struct SearchByIndex { // Then when merging, we'll update its value if there is any semantic hit semantic_hit_count: Option, results_by_index: Vec, + query_vectors: BTreeMap, previous_query_data: Option<(RankingRules, usize, String)>, // remember the order and name of first index for each facet when merging with index settings // to detect if the order is inconsistent for a facet. @@ -674,6 +699,7 @@ impl SearchByIndex { federation, semantic_hit_count: None, results_by_index: Vec::with_capacity(index_count), + query_vectors: BTreeMap::new(), previous_query_data: None, } } @@ -837,8 +863,19 @@ impl SearchByIndex { document_scores, degraded: query_degraded, used_negative_operator: query_used_negative_operator, + query_vector, } = result; + if query.retrieve_vectors { + if let Some(query_vector) = query_vector { + debug_assert!( + !self.query_vectors.contains_key(&query_index), + "Query vector for query {query_index} already exists" + ); + self.query_vectors.insert(query_index, query_vector); + } + } + candidates |= query_candidates; degraded |= query_degraded; used_negative_operator |= query_used_negative_operator; diff --git a/crates/meilisearch/src/search/federated/types.rs b/crates/meilisearch/src/search/federated/types.rs index 3cf28c815..9c96fe768 100644 --- a/crates/meilisearch/src/search/federated/types.rs +++ b/crates/meilisearch/src/search/federated/types.rs @@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex}; +use crate::milli::vector::Embedding; pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0; @@ -117,6 +118,9 @@ pub struct FederatedSearchResult { #[serde(flatten)] pub hits_info: HitsInfo, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub query_vectors: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] pub semantic_hit_count: Option, @@ -144,6 +148,7 @@ impl fmt::Debug for FederatedSearchResult { hits, processing_time_ms, hits_info, + query_vectors, semantic_hit_count, degraded, used_negative_operator, @@ -158,6 +163,10 @@ impl fmt::Debug for FederatedSearchResult { debug.field("processing_time_ms", &processing_time_ms); debug.field("hits", &format!("[{} hits returned]", hits.len())); debug.field("hits_info", &hits_info); + if let Some(query_vectors) = query_vectors { + let known = query_vectors.len(); + debug.field("query_vectors", &format!("[{known} known vectors]")); + } if *used_negative_operator { debug.field("used_negative_operator", used_negative_operator); } diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs index 82096e7b4..bb406aed9 100644 --- a/crates/meilisearch/src/search/mod.rs +++ b/crates/meilisearch/src/search/mod.rs @@ -841,6 +841,8 @@ pub struct SearchHit { pub struct SearchResult { pub hits: Vec, pub query: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub query_vector: Option>, pub processing_time_ms: u128, #[serde(flatten)] pub hits_info: HitsInfo, @@ -865,6 +867,7 @@ impl fmt::Debug for SearchResult { let SearchResult { hits, query, + query_vector, processing_time_ms, hits_info, facet_distribution, @@ -879,6 +882,9 @@ impl fmt::Debug for SearchResult { debug.field("processing_time_ms", &processing_time_ms); debug.field("hits", &format!("[{} hits returned]", hits.len())); debug.field("query", &query); + if query_vector.is_some() { + debug.field("query_vector", &"[...]"); + } debug.field("hits_info", &hits_info); if *used_negative_operator { debug.field("used_negative_operator", used_negative_operator); @@ -1050,6 +1056,7 @@ pub fn prepare_search<'t>( .map(|x| x as usize) .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); + search.retrieve_vectors(query.retrieve_vectors); search.exhaustive_number_hits(is_finite_pagination); search.max_total_hits(Some(max_total_hits)); search.scoring_strategy( @@ -1132,6 +1139,7 @@ pub fn perform_search( document_scores, degraded, used_negative_operator, + query_vector, }, semantic_hit_count, ) = search_from_kind(index_uid, search_kind, search)?; @@ -1222,6 +1230,7 @@ pub fn perform_search( hits: documents, hits_info, query: q.unwrap_or_default(), + query_vector, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, @@ -1734,6 +1743,7 @@ pub fn perform_similar( document_scores, degraded: _, used_negative_operator: _, + query_vector: _, } = similar.execute().map_err(|err| match err { milli::Error::UserError(milli::UserError::InvalidFilter(_)) => { ResponseError::from_msg(err.to_string(), Code::InvalidSimilarFilter) diff --git a/crates/meilisearch/tests/search/hybrid.rs b/crates/meilisearch/tests/search/hybrid.rs index d95e6fb64..b2970f233 100644 --- a/crates/meilisearch/tests/search/hybrid.rs +++ b/crates/meilisearch/tests/search/hybrid.rs @@ -148,7 +148,70 @@ async fn simple_search() { ) .await; snapshot!(code, @"200 OK"); - snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":{"embeddings":[[1.0,2.0]],"regenerate":false}}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":{"embeddings":[[2.0,3.0]],"regenerate":false}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":{"embeddings":[[1.0,3.0]],"regenerate":false}}}]"###); + snapshot!(response, @r#" + { + "hits": [ + { + "title": "Captain Planet", + "desc": "He's not part of the Marvel Cinematic Universe", + "id": "2", + "_vectors": { + "default": { + "embeddings": [ + [ + 1.0, + 2.0 + ] + ], + "regenerate": false + } + } + }, + { + "title": "Captain Marvel", + "desc": "a Shazam ersatz", + "id": "3", + "_vectors": { + "default": { + "embeddings": [ + [ + 2.0, + 3.0 + ] + ], + "regenerate": false + } + } + }, + { + "title": "Shazam!", + "desc": "a Captain Marvel ersatz", + "id": "1", + "_vectors": { + "default": { + "embeddings": [ + [ + 1.0, + 3.0 + ] + ], + "regenerate": false + } + } + } + ], + "query": "Captain", + "queryVector": [ + 1.0, + 1.0 + ], + "processingTimeMs": "[duration]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 3, + "semanticHitCount": 0 + } + "#); snapshot!(response["semanticHitCount"], @"0"); let (response, code) = index @@ -157,7 +220,73 @@ async fn simple_search() { ) .await; snapshot!(code, @"200 OK"); - snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":{"embeddings":[[2.0,3.0]],"regenerate":false}},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":{"embeddings":[[1.0,2.0]],"regenerate":false}},"_rankingScore":0.9848484848484848},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":{"embeddings":[[1.0,3.0]],"regenerate":false}},"_rankingScore":0.9472135901451112}]"###); + snapshot!(response, @r#" + { + "hits": [ + { + "title": "Captain Marvel", + "desc": "a Shazam ersatz", + "id": "3", + "_vectors": { + "default": { + "embeddings": [ + [ + 2.0, + 3.0 + ] + ], + "regenerate": false + } + }, + "_rankingScore": 0.990290343761444 + }, + { + "title": "Captain Planet", + "desc": "He's not part of the Marvel Cinematic Universe", + "id": "2", + "_vectors": { + "default": { + "embeddings": [ + [ + 1.0, + 2.0 + ] + ], + "regenerate": false + } + }, + "_rankingScore": 0.9848484848484848 + }, + { + "title": "Shazam!", + "desc": "a Captain Marvel ersatz", + "id": "1", + "_vectors": { + "default": { + "embeddings": [ + [ + 1.0, + 3.0 + ] + ], + "regenerate": false + } + }, + "_rankingScore": 0.9472135901451112 + } + ], + "query": "Captain", + "queryVector": [ + 1.0, + 1.0 + ], + "processingTimeMs": "[duration]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 3, + "semanticHitCount": 2 + } + "#); snapshot!(response["semanticHitCount"], @"2"); let (response, code) = index @@ -166,7 +295,73 @@ async fn simple_search() { ) .await; snapshot!(code, @"200 OK"); - snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":{"embeddings":[[2.0,3.0]],"regenerate":false}},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":{"embeddings":[[1.0,2.0]],"regenerate":false}},"_rankingScore":0.974341630935669},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":{"embeddings":[[1.0,3.0]],"regenerate":false}},"_rankingScore":0.9472135901451112}]"###); + snapshot!(response, @r#" + { + "hits": [ + { + "title": "Captain Marvel", + "desc": "a Shazam ersatz", + "id": "3", + "_vectors": { + "default": { + "embeddings": [ + [ + 2.0, + 3.0 + ] + ], + "regenerate": false + } + }, + "_rankingScore": 0.990290343761444 + }, + { + "title": "Captain Planet", + "desc": "He's not part of the Marvel Cinematic Universe", + "id": "2", + "_vectors": { + "default": { + "embeddings": [ + [ + 1.0, + 2.0 + ] + ], + "regenerate": false + } + }, + "_rankingScore": 0.974341630935669 + }, + { + "title": "Shazam!", + "desc": "a Captain Marvel ersatz", + "id": "1", + "_vectors": { + "default": { + "embeddings": [ + [ + 1.0, + 3.0 + ] + ], + "regenerate": false + } + }, + "_rankingScore": 0.9472135901451112 + } + ], + "query": "Captain", + "queryVector": [ + 1.0, + 1.0 + ], + "processingTimeMs": "[duration]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 3, + "semanticHitCount": 3 + } + "#); snapshot!(response["semanticHitCount"], @"3"); } diff --git a/crates/meilisearch/tests/search/multi/mod.rs b/crates/meilisearch/tests/search/multi/mod.rs index b9eed56da..16ee3906e 100644 --- a/crates/meilisearch/tests/search/multi/mod.rs +++ b/crates/meilisearch/tests/search/multi/mod.rs @@ -3703,7 +3703,7 @@ async fn federation_vector_two_indexes() { ]})) .await; snapshot!(code, @"200 OK"); - snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r###" + snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r#" { "hits": [ { @@ -3911,9 +3911,20 @@ async fn federation_vector_two_indexes() { "limit": 20, "offset": 0, "estimatedTotalHits": 8, + "queryVectors": { + "0": [ + 1.0, + 0.0, + 0.5 + ], + "1": [ + 0.8, + 0.6 + ] + }, "semanticHitCount": 6 } - "###); + "#); // hybrid search, distinct embedder let (response, code) = server @@ -3923,7 +3934,7 @@ async fn federation_vector_two_indexes() { ]})) .await; snapshot!(code, @"200 OK"); - snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r###" + snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r#" { "hits": [ { @@ -4139,9 +4150,20 @@ async fn federation_vector_two_indexes() { "limit": 20, "offset": 0, "estimatedTotalHits": 8, + "queryVectors": { + "0": [ + 1.0, + 0.0, + 0.5 + ], + "1": [ + -1.0, + 0.6 + ] + }, "semanticHitCount": 8 } - "###); + "#); } #[actix_rt::test] diff --git a/crates/meilisearch/tests/search/multi/proxy.rs b/crates/meilisearch/tests/search/multi/proxy.rs index 943295da5..2b1623ff8 100644 --- a/crates/meilisearch/tests/search/multi/proxy.rs +++ b/crates/meilisearch/tests/search/multi/proxy.rs @@ -2,8 +2,9 @@ use std::sync::Arc; use actix_http::StatusCode; use meili_snap::{json_string, snapshot}; -use wiremock::matchers::AnyMatcher; -use wiremock::{Mock, MockServer, ResponseTemplate}; +use wiremock::matchers::method; +use wiremock::matchers::{path, AnyMatcher}; +use wiremock::{Mock, MockServer, Request, ResponseTemplate}; use crate::common::{Server, Value, SCORE_DOCUMENTS}; use crate::json; @@ -415,6 +416,503 @@ async fn remote_sharding() { "###); } +#[actix_rt::test] +async fn remote_sharding_retrieve_vectors() { + let ms0 = Server::new().await; + let ms1 = Server::new().await; + let ms2 = Server::new().await; + let index0 = ms0.index("test"); + let index1 = ms1.index("test"); + let index2 = ms2.index("test"); + + // enable feature + + let (response, code) = ms0.set_features(json!({"network": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["network"]), @"true"); + let (response, code) = ms1.set_features(json!({"network": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["network"]), @"true"); + let (response, code) = ms2.set_features(json!({"network": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["network"]), @"true"); + + // set self + + let (response, code) = ms0.set_network(json!({"self": "ms0"})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response), @r###" + { + "self": "ms0", + "remotes": {} + } + "###); + let (response, code) = ms1.set_network(json!({"self": "ms1"})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response), @r###" + { + "self": "ms1", + "remotes": {} + } + "###); + let (response, code) = ms2.set_network(json!({"self": "ms2"})).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response), @r###" + { + "self": "ms2", + "remotes": {} + } + "###); + + // setup embedders + + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/")) + .respond_with(move |req: &Request| { + println!("Received request: {:?}", req); + let text = req.body_json::().unwrap().to_lowercase(); + let patterns = [ + ("batman", [1.0, 0.0, 0.0]), + ("dark", [0.0, 0.1, 0.0]), + ("knight", [0.1, 0.1, 0.0]), + ("returns", [0.0, 0.0, 0.2]), + ("part", [0.05, 0.1, 0.0]), + ("1", [0.3, 0.05, 0.0]), + ("2", [0.2, 0.05, 0.0]), + ]; + let mut embedding = vec![0.; 3]; + for (pattern, vector) in patterns { + if text.contains(pattern) { + for (i, v) in vector.iter().enumerate() { + embedding[i] += v; + } + } + } + ResponseTemplate::new(200).set_body_json(json!({ "data": embedding })) + }) + .mount(&mock_server) + .await; + let url = mock_server.uri(); + + for (server, index) in [(&ms0, &index0), (&ms1, &index1), (&ms2, &index2)] { + let (response, code) = index + .update_settings(json!({ + "embedders": { + "rest": { + "source": "rest", + "url": url, + "dimensions": 3, + "request": "{{text}}", + "response": { "data": "{{embedding}}" }, + "documentTemplate": "{{doc.name}}", + }, + }, + })) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await.succeeded(); + } + + // wrap servers + let ms0 = Arc::new(ms0); + let ms1 = Arc::new(ms1); + let ms2 = Arc::new(ms2); + + let rms0 = LocalMeili::new(ms0.clone()).await; + let rms1 = LocalMeili::new(ms1.clone()).await; + let rms2 = LocalMeili::new(ms2.clone()).await; + + // set network + let network = json!({"remotes": { + "ms0": { + "url": rms0.url() + }, + "ms1": { + "url": rms1.url() + }, + "ms2": { + "url": rms2.url() + } + }}); + + let (_response, status_code) = ms0.set_network(network.clone()).await; + snapshot!(status_code, @"200 OK"); + let (_response, status_code) = ms1.set_network(network.clone()).await; + snapshot!(status_code, @"200 OK"); + let (_response, status_code) = ms2.set_network(network.clone()).await; + snapshot!(status_code, @"200 OK"); + + // multi vector search: one query per remote + + let request = json!({ + "federation": {}, + "queries": [ + { + "q": "batman", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms0" + } + }, + { + "q": "dark knight", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + { + "q": "returns", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms2" + } + }, + ] + }); + + let (response, _status_code) = ms0.multi_search(request.clone()).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#" + { + "hits": [], + "processingTimeMs": "[time]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 0, + "queryVectors": { + "0": [ + 1.0, + 0.0, + 0.0 + ], + "1": [ + 0.1, + 0.2, + 0.0 + ], + "2": [ + 0.0, + 0.0, + 0.2 + ] + }, + "semanticHitCount": 0, + "remoteErrors": {} + } + "#); + + // multi vector search: two local queries, one remote + + let request = json!({ + "federation": {}, + "queries": [ + { + "q": "batman", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms0" + } + }, + { + "q": "dark knight", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms0" + } + }, + { + "q": "returns", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms2" + } + }, + ] + }); + + let (response, _status_code) = ms0.multi_search(request.clone()).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#" + { + "hits": [], + "processingTimeMs": "[time]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 0, + "queryVectors": { + "0": [ + 1.0, + 0.0, + 0.0 + ], + "1": [ + 0.1, + 0.2, + 0.0 + ], + "2": [ + 0.0, + 0.0, + 0.2 + ] + }, + "semanticHitCount": 0, + "remoteErrors": {} + } + "#); + + // multi vector search: two queries on the same remote + + let request = json!({ + "federation": {}, + "queries": [ + { + "q": "batman", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms0" + } + }, + { + "q": "dark knight", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + { + "q": "returns", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + ] + }); + + let (response, _status_code) = ms0.multi_search(request.clone()).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#" + { + "hits": [], + "processingTimeMs": "[time]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 0, + "queryVectors": { + "0": [ + 1.0, + 0.0, + 0.0 + ], + "1": [ + 0.1, + 0.2, + 0.0 + ], + "2": [ + 0.0, + 0.0, + 0.2 + ] + }, + "semanticHitCount": 0, + "remoteErrors": {} + } + "#); + + // multi search: two vector, one keyword + + let request = json!({ + "federation": {}, + "queries": [ + { + "q": "batman", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms0" + } + }, + { + "q": "dark knight", + "indexUid": "test", + "hybrid": { + "semanticRatio": 0.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + { + "q": "returns", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + ] + }); + + let (response, _status_code) = ms0.multi_search(request.clone()).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#" + { + "hits": [], + "processingTimeMs": "[time]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 0, + "queryVectors": { + "0": [ + 1.0, + 0.0, + 0.0 + ], + "2": [ + 0.0, + 0.0, + 0.2 + ] + }, + "semanticHitCount": 0, + "remoteErrors": {} + } + "#); + + // multi vector search: no local queries, all remote + + let request = json!({ + "federation": {}, + "queries": [ + { + "q": "batman", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + { + "q": "dark knight", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + { + "q": "returns", + "indexUid": "test", + "hybrid": { + "semanticRatio": 1.0, + "embedder": "rest" + }, + "retrieveVectors": true, + "federationOptions": { + "remote": "ms1" + } + }, + ] + }); + + let (response, _status_code) = ms0.multi_search(request.clone()).await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#" + { + "hits": [], + "processingTimeMs": "[time]", + "limit": 20, + "offset": 0, + "estimatedTotalHits": 0, + "queryVectors": { + "0": [ + 1.0, + 0.0, + 0.0 + ], + "1": [ + 0.1, + 0.2, + 0.0 + ], + "2": [ + 0.0, + 0.0, + 0.2 + ] + }, + "remoteErrors": {} + } + "#); +} + #[actix_rt::test] async fn error_unregistered_remote() { let ms0 = Server::new().await; diff --git a/crates/meilisearch/tests/vector/binary_quantized.rs b/crates/meilisearch/tests/vector/binary_quantized.rs index 6fcfa3563..adb0da441 100644 --- a/crates/meilisearch/tests/vector/binary_quantized.rs +++ b/crates/meilisearch/tests/vector/binary_quantized.rs @@ -323,7 +323,7 @@ async fn binary_quantize_clear_documents() { // Make sure the arroy DB has been cleared let (documents, _code) = index.search_post(json!({ "hybrid": { "embedder": "manual" }, "vector": [1, 1, 1] })).await; - snapshot!(documents, @r###" + snapshot!(documents, @r#" { "hits": [], "query": "", @@ -333,5 +333,5 @@ async fn binary_quantize_clear_documents() { "estimatedTotalHits": 0, "semanticHitCount": 0 } - "###); + "#); } diff --git a/crates/meilisearch/tests/vector/mod.rs b/crates/meilisearch/tests/vector/mod.rs index 8538f5f1e..551b82178 100644 --- a/crates/meilisearch/tests/vector/mod.rs +++ b/crates/meilisearch/tests/vector/mod.rs @@ -686,7 +686,7 @@ async fn clear_documents() { // Make sure the arroy DB has been cleared let (documents, _code) = index.search_post(json!({ "vector": [1, 1, 1], "hybrid": {"embedder": "manual"} })).await; - snapshot!(documents, @r###" + snapshot!(documents, @r#" { "hits": [], "query": "", @@ -696,7 +696,7 @@ async fn clear_documents() { "estimatedTotalHits": 0, "semanticHitCount": 0 } - "###); + "#); } #[actix_rt::test] @@ -740,7 +740,7 @@ async fn add_remove_one_vector_4588() { json!({"vector": [1, 1, 1], "hybrid": {"semanticRatio": 1.0, "embedder": "manual"} }), ) .await; - snapshot!(documents, @r###" + snapshot!(documents, @r#" { "hits": [ { @@ -755,7 +755,7 @@ async fn add_remove_one_vector_4588() { "estimatedTotalHits": 1, "semanticHitCount": 1 } - "###); + "#); let (documents, _code) = index .get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() }) diff --git a/crates/milli/src/search/hybrid.rs b/crates/milli/src/search/hybrid.rs index a29b6c4c7..1535c73ba 100644 --- a/crates/milli/src/search/hybrid.rs +++ b/crates/milli/src/search/hybrid.rs @@ -7,7 +7,7 @@ use roaring::RoaringBitmap; use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; use crate::search::new::{distinct_fid, distinct_single_docid}; use crate::search::SemanticSearch; -use crate::vector::SearchQuery; +use crate::vector::{Embedding, SearchQuery}; use crate::{Index, MatchingWords, Result, Search, SearchResult}; struct ScoreWithRatioResult { @@ -16,6 +16,7 @@ struct ScoreWithRatioResult { document_scores: Vec<(u32, ScoreWithRatio)>, degraded: bool, used_negative_operator: bool, + query_vector: Option, } type ScoreWithRatio = (Vec, f32); @@ -85,6 +86,7 @@ impl ScoreWithRatioResult { document_scores, degraded: results.degraded, used_negative_operator: results.used_negative_operator, + query_vector: results.query_vector, } } @@ -186,6 +188,7 @@ impl ScoreWithRatioResult { degraded: vector_results.degraded | keyword_results.degraded, used_negative_operator: vector_results.used_negative_operator | keyword_results.used_negative_operator, + query_vector: vector_results.query_vector, }, semantic_hit_count, )) @@ -209,6 +212,7 @@ impl Search<'_> { terms_matching_strategy: self.terms_matching_strategy, scoring_strategy: ScoringStrategy::Detailed, words_limit: self.words_limit, + retrieve_vectors: self.retrieve_vectors, exhaustive_number_hits: self.exhaustive_number_hits, max_total_hits: self.max_total_hits, rtxn: self.rtxn, @@ -265,7 +269,7 @@ impl Search<'_> { }; search.semantic = Some(SemanticSearch { - vector: Some(vector_query), + vector: Some(vector_query.clone()), embedder_name, embedder, quantized, @@ -322,6 +326,7 @@ fn return_keyword_results( mut document_scores, degraded, used_negative_operator, + query_vector, }: SearchResult, ) -> (SearchResult, Option) { let (documents_ids, document_scores) = if offset >= documents_ids.len() || @@ -348,6 +353,7 @@ fn return_keyword_results( document_scores, degraded, used_negative_operator, + query_vector, }, Some(0), ) diff --git a/crates/milli/src/search/mod.rs b/crates/milli/src/search/mod.rs index 8742db24d..2ae931ff5 100644 --- a/crates/milli/src/search/mod.rs +++ b/crates/milli/src/search/mod.rs @@ -52,6 +52,7 @@ pub struct Search<'a> { terms_matching_strategy: TermsMatchingStrategy, scoring_strategy: ScoringStrategy, words_limit: usize, + retrieve_vectors: bool, exhaustive_number_hits: bool, max_total_hits: Option, rtxn: &'a heed::RoTxn<'a>, @@ -75,6 +76,7 @@ impl<'a> Search<'a> { geo_param: GeoSortParameter::default(), terms_matching_strategy: TermsMatchingStrategy::default(), scoring_strategy: Default::default(), + retrieve_vectors: false, exhaustive_number_hits: false, max_total_hits: None, words_limit: 10, @@ -161,6 +163,11 @@ impl<'a> Search<'a> { self } + pub fn retrieve_vectors(&mut self, retrieve_vectors: bool) -> &mut Search<'a> { + self.retrieve_vectors = retrieve_vectors; + self + } + /// Forces the search to exhaustively compute the number of candidates, /// this will increase the search time but allows finite pagination. pub fn exhaustive_number_hits(&mut self, exhaustive_number_hits: bool) -> &mut Search<'a> { @@ -233,6 +240,7 @@ impl<'a> Search<'a> { } let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; + let mut query_vector = None; let PartialSearchResult { located_query_terms, candidates, @@ -247,24 +255,29 @@ impl<'a> Search<'a> { embedder, quantized, media: _, - }) => execute_vector_search( - &mut ctx, - vector, - self.scoring_strategy, - self.exhaustive_number_hits, - self.max_total_hits, - universe, - &self.sort_criteria, - &self.distinct, - self.geo_param, - self.offset, - self.limit, - embedder_name, - embedder, - *quantized, - self.time_budget.clone(), - self.ranking_score_threshold, - )?, + }) => { + if self.retrieve_vectors { + query_vector = Some(vector.clone()); + } + execute_vector_search( + &mut ctx, + vector, + self.scoring_strategy, + self.exhaustive_number_hits, + self.max_total_hits, + universe, + &self.sort_criteria, + &self.distinct, + self.geo_param, + self.offset, + self.limit, + embedder_name, + embedder, + *quantized, + self.time_budget.clone(), + self.ranking_score_threshold, + )? + } _ => execute_search( &mut ctx, self.query.as_deref(), @@ -306,6 +319,7 @@ impl<'a> Search<'a> { documents_ids, degraded, used_negative_operator, + query_vector, }) } } @@ -324,6 +338,7 @@ impl fmt::Debug for Search<'_> { terms_matching_strategy, scoring_strategy, words_limit, + retrieve_vectors, exhaustive_number_hits, max_total_hits, rtxn: _, @@ -344,6 +359,7 @@ impl fmt::Debug for Search<'_> { .field("searchable_attributes", searchable_attributes) .field("terms_matching_strategy", terms_matching_strategy) .field("scoring_strategy", scoring_strategy) + .field("retrieve_vectors", retrieve_vectors) .field("exhaustive_number_hits", exhaustive_number_hits) .field("max_total_hits", max_total_hits) .field("words_limit", words_limit) @@ -366,6 +382,7 @@ pub struct SearchResult { pub document_scores: Vec>, pub degraded: bool, pub used_negative_operator: bool, + pub query_vector: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/crates/milli/src/search/similar.rs b/crates/milli/src/search/similar.rs index 903b5fcf9..2235f6436 100644 --- a/crates/milli/src/search/similar.rs +++ b/crates/milli/src/search/similar.rs @@ -130,6 +130,7 @@ impl<'a> Similar<'a> { document_scores, degraded: false, used_negative_operator: false, + query_vector: None, }) } } diff --git a/crates/milli/src/test_index.rs b/crates/milli/src/test_index.rs index 6bb6b1345..d174319d0 100644 --- a/crates/milli/src/test_index.rs +++ b/crates/milli/src/test_index.rs @@ -1097,6 +1097,7 @@ fn bug_3021_fourth() { mut documents_ids, degraded: _, used_negative_operator: _, + query_vector: _, } = search.execute().unwrap(); let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap(); documents_ids.sort_unstable();