Merge pull request #5778 from meilisearch/retrieve-query-vectors

Return query vector
This commit is contained in:
Mubelotix
2025-08-13 07:39:15 +00:00
committed by GitHub
13 changed files with 833 additions and 36 deletions

View File

@ -224,6 +224,7 @@ impl<Method: AggregateMethod> SearchAggregator<Method> {
let SearchResult { let SearchResult {
hits: _, hits: _,
query: _, query: _,
query_vector: _,
processing_time_ms, processing_time_ms,
hits_info: _, hits_info: _,
semantic_hit_count: _, semantic_hit_count: _,

View File

@ -13,6 +13,7 @@ use meilisearch_types::error::ResponseError;
use meilisearch_types::features::{Network, Remote}; use meilisearch_types::features::{Network, Remote};
use meilisearch_types::milli::order_by_map::OrderByMap; use meilisearch_types::milli::order_by_map::OrderByMap;
use meilisearch_types::milli::score_details::{ScoreDetails, WeightedScoreValue}; use meilisearch_types::milli::score_details::{ScoreDetails, WeightedScoreValue};
use meilisearch_types::milli::vector::Embedding;
use meilisearch_types::milli::{self, DocumentId, OrderBy, TimeBudget, DEFAULT_VALUES_PER_FACET}; use meilisearch_types::milli::{self, DocumentId, OrderBy, TimeBudget, DEFAULT_VALUES_PER_FACET};
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
@ -46,6 +47,7 @@ pub async fn perform_federated_search(
let deadline = before_search + std::time::Duration::from_secs(9); let deadline = before_search + std::time::Duration::from_secs(9);
let required_hit_count = federation.limit + federation.offset; let required_hit_count = federation.limit + federation.offset;
let retrieve_vectors = queries.iter().any(|q| q.retrieve_vectors);
let network = index_scheduler.network(); let network = index_scheduler.network();
@ -91,6 +93,7 @@ pub async fn perform_federated_search(
federation, federation,
mut semantic_hit_count, mut semantic_hit_count,
mut results_by_index, mut results_by_index,
mut query_vectors,
previous_query_data: _, previous_query_data: _,
facet_order, facet_order,
} = search_by_index; } = search_by_index;
@ -122,7 +125,26 @@ pub async fn perform_federated_search(
.map(|hit| hit.hit()) .map(|hit| hit.hit())
.collect(); .collect();
// 3.3. merge facets // 3.3. merge query vectors
let query_vectors = if retrieve_vectors {
for remote_results in remote_results.iter_mut() {
if let Some(remote_vectors) = remote_results.query_vectors.take() {
for (key, value) in remote_vectors.into_iter() {
debug_assert!(
!query_vectors.contains_key(&key),
"Query vector for query {key} already exists"
);
query_vectors.insert(key, value);
}
}
}
Some(query_vectors)
} else {
None
};
// 3.4. merge facets
let (facet_distribution, facet_stats, facets_by_index) = let (facet_distribution, facet_stats, facets_by_index) =
facet_order.merge(federation.merge_facets, remote_results, facets); facet_order.merge(federation.merge_facets, remote_results, facets);
@ -140,6 +162,7 @@ pub async fn perform_federated_search(
offset: federation.offset, offset: federation.offset,
estimated_total_hits, estimated_total_hits,
}, },
query_vectors,
semantic_hit_count, semantic_hit_count,
degraded, degraded,
used_negative_operator, used_negative_operator,
@ -408,6 +431,7 @@ fn merge_metadata(
hits: _, hits: _,
processing_time_ms, processing_time_ms,
hits_info, hits_info,
query_vectors: _,
semantic_hit_count: _, semantic_hit_count: _,
facet_distribution: _, facet_distribution: _,
facet_stats: _, facet_stats: _,
@ -657,6 +681,7 @@ struct SearchByIndex {
// Then when merging, we'll update its value if there is any semantic hit // Then when merging, we'll update its value if there is any semantic hit
semantic_hit_count: Option<u32>, semantic_hit_count: Option<u32>,
results_by_index: Vec<SearchResultByIndex>, results_by_index: Vec<SearchResultByIndex>,
query_vectors: BTreeMap<usize, Embedding>,
previous_query_data: Option<(RankingRules, usize, String)>, previous_query_data: Option<(RankingRules, usize, String)>,
// remember the order and name of first index for each facet when merging with index settings // remember the order and name of first index for each facet when merging with index settings
// to detect if the order is inconsistent for a facet. // to detect if the order is inconsistent for a facet.
@ -674,6 +699,7 @@ impl SearchByIndex {
federation, federation,
semantic_hit_count: None, semantic_hit_count: None,
results_by_index: Vec::with_capacity(index_count), results_by_index: Vec::with_capacity(index_count),
query_vectors: BTreeMap::new(),
previous_query_data: None, previous_query_data: None,
} }
} }
@ -837,8 +863,19 @@ impl SearchByIndex {
document_scores, document_scores,
degraded: query_degraded, degraded: query_degraded,
used_negative_operator: query_used_negative_operator, used_negative_operator: query_used_negative_operator,
query_vector,
} = result; } = result;
if query.retrieve_vectors {
if let Some(query_vector) = query_vector {
debug_assert!(
!self.query_vectors.contains_key(&query_index),
"Query vector for query {query_index} already exists"
);
self.query_vectors.insert(query_index, query_vector);
}
}
candidates |= query_candidates; candidates |= query_candidates;
degraded |= query_degraded; degraded |= query_degraded;
used_negative_operator |= query_used_negative_operator; used_negative_operator |= query_used_negative_operator;

View File

@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema; use utoipa::ToSchema;
use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex}; use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex};
use crate::milli::vector::Embedding;
pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0; pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0;
@ -117,6 +118,9 @@ pub struct FederatedSearchResult {
#[serde(flatten)] #[serde(flatten)]
pub hits_info: HitsInfo, pub hits_info: HitsInfo,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub query_vectors: Option<BTreeMap<usize, Embedding>>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub semantic_hit_count: Option<u32>, pub semantic_hit_count: Option<u32>,
@ -144,6 +148,7 @@ impl fmt::Debug for FederatedSearchResult {
hits, hits,
processing_time_ms, processing_time_ms,
hits_info, hits_info,
query_vectors,
semantic_hit_count, semantic_hit_count,
degraded, degraded,
used_negative_operator, used_negative_operator,
@ -158,6 +163,10 @@ impl fmt::Debug for FederatedSearchResult {
debug.field("processing_time_ms", &processing_time_ms); debug.field("processing_time_ms", &processing_time_ms);
debug.field("hits", &format!("[{} hits returned]", hits.len())); debug.field("hits", &format!("[{} hits returned]", hits.len()));
debug.field("hits_info", &hits_info); debug.field("hits_info", &hits_info);
if let Some(query_vectors) = query_vectors {
let known = query_vectors.len();
debug.field("query_vectors", &format!("[{known} known vectors]"));
}
if *used_negative_operator { if *used_negative_operator {
debug.field("used_negative_operator", used_negative_operator); debug.field("used_negative_operator", used_negative_operator);
} }

View File

@ -841,6 +841,8 @@ pub struct SearchHit {
pub struct SearchResult { pub struct SearchResult {
pub hits: Vec<SearchHit>, pub hits: Vec<SearchHit>,
pub query: String, pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub query_vector: Option<Vec<f32>>,
pub processing_time_ms: u128, pub processing_time_ms: u128,
#[serde(flatten)] #[serde(flatten)]
pub hits_info: HitsInfo, pub hits_info: HitsInfo,
@ -865,6 +867,7 @@ impl fmt::Debug for SearchResult {
let SearchResult { let SearchResult {
hits, hits,
query, query,
query_vector,
processing_time_ms, processing_time_ms,
hits_info, hits_info,
facet_distribution, facet_distribution,
@ -879,6 +882,9 @@ impl fmt::Debug for SearchResult {
debug.field("processing_time_ms", &processing_time_ms); debug.field("processing_time_ms", &processing_time_ms);
debug.field("hits", &format!("[{} hits returned]", hits.len())); debug.field("hits", &format!("[{} hits returned]", hits.len()));
debug.field("query", &query); debug.field("query", &query);
if query_vector.is_some() {
debug.field("query_vector", &"[...]");
}
debug.field("hits_info", &hits_info); debug.field("hits_info", &hits_info);
if *used_negative_operator { if *used_negative_operator {
debug.field("used_negative_operator", used_negative_operator); debug.field("used_negative_operator", used_negative_operator);
@ -1050,6 +1056,7 @@ pub fn prepare_search<'t>(
.map(|x| x as usize) .map(|x| x as usize)
.unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS);
search.retrieve_vectors(query.retrieve_vectors);
search.exhaustive_number_hits(is_finite_pagination); search.exhaustive_number_hits(is_finite_pagination);
search.max_total_hits(Some(max_total_hits)); search.max_total_hits(Some(max_total_hits));
search.scoring_strategy( search.scoring_strategy(
@ -1132,6 +1139,7 @@ pub fn perform_search(
document_scores, document_scores,
degraded, degraded,
used_negative_operator, used_negative_operator,
query_vector,
}, },
semantic_hit_count, semantic_hit_count,
) = search_from_kind(index_uid, search_kind, search)?; ) = search_from_kind(index_uid, search_kind, search)?;
@ -1222,6 +1230,7 @@ pub fn perform_search(
hits: documents, hits: documents,
hits_info, hits_info,
query: q.unwrap_or_default(), query: q.unwrap_or_default(),
query_vector,
processing_time_ms: before_search.elapsed().as_millis(), processing_time_ms: before_search.elapsed().as_millis(),
facet_distribution, facet_distribution,
facet_stats, facet_stats,
@ -1734,6 +1743,7 @@ pub fn perform_similar(
document_scores, document_scores,
degraded: _, degraded: _,
used_negative_operator: _, used_negative_operator: _,
query_vector: _,
} = similar.execute().map_err(|err| match err { } = similar.execute().map_err(|err| match err {
milli::Error::UserError(milli::UserError::InvalidFilter(_)) => { milli::Error::UserError(milli::UserError::InvalidFilter(_)) => {
ResponseError::from_msg(err.to_string(), Code::InvalidSimilarFilter) ResponseError::from_msg(err.to_string(), Code::InvalidSimilarFilter)

View File

@ -148,7 +148,70 @@ async fn simple_search() {
) )
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":{"embeddings":[[1.0,2.0]],"regenerate":false}}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":{"embeddings":[[2.0,3.0]],"regenerate":false}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":{"embeddings":[[1.0,3.0]],"regenerate":false}}}]"###); snapshot!(response, @r#"
{
"hits": [
{
"title": "Captain Planet",
"desc": "He's not part of the Marvel Cinematic Universe",
"id": "2",
"_vectors": {
"default": {
"embeddings": [
[
1.0,
2.0
]
],
"regenerate": false
}
}
},
{
"title": "Captain Marvel",
"desc": "a Shazam ersatz",
"id": "3",
"_vectors": {
"default": {
"embeddings": [
[
2.0,
3.0
]
],
"regenerate": false
}
}
},
{
"title": "Shazam!",
"desc": "a Captain Marvel ersatz",
"id": "1",
"_vectors": {
"default": {
"embeddings": [
[
1.0,
3.0
]
],
"regenerate": false
}
}
}
],
"query": "Captain",
"queryVector": [
1.0,
1.0
],
"processingTimeMs": "[duration]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 3,
"semanticHitCount": 0
}
"#);
snapshot!(response["semanticHitCount"], @"0"); snapshot!(response["semanticHitCount"], @"0");
let (response, code) = index let (response, code) = index
@ -157,7 +220,73 @@ async fn simple_search() {
) )
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":{"embeddings":[[2.0,3.0]],"regenerate":false}},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":{"embeddings":[[1.0,2.0]],"regenerate":false}},"_rankingScore":0.9848484848484848},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":{"embeddings":[[1.0,3.0]],"regenerate":false}},"_rankingScore":0.9472135901451112}]"###); snapshot!(response, @r#"
{
"hits": [
{
"title": "Captain Marvel",
"desc": "a Shazam ersatz",
"id": "3",
"_vectors": {
"default": {
"embeddings": [
[
2.0,
3.0
]
],
"regenerate": false
}
},
"_rankingScore": 0.990290343761444
},
{
"title": "Captain Planet",
"desc": "He's not part of the Marvel Cinematic Universe",
"id": "2",
"_vectors": {
"default": {
"embeddings": [
[
1.0,
2.0
]
],
"regenerate": false
}
},
"_rankingScore": 0.9848484848484848
},
{
"title": "Shazam!",
"desc": "a Captain Marvel ersatz",
"id": "1",
"_vectors": {
"default": {
"embeddings": [
[
1.0,
3.0
]
],
"regenerate": false
}
},
"_rankingScore": 0.9472135901451112
}
],
"query": "Captain",
"queryVector": [
1.0,
1.0
],
"processingTimeMs": "[duration]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 3,
"semanticHitCount": 2
}
"#);
snapshot!(response["semanticHitCount"], @"2"); snapshot!(response["semanticHitCount"], @"2");
let (response, code) = index let (response, code) = index
@ -166,7 +295,73 @@ async fn simple_search() {
) )
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":{"embeddings":[[2.0,3.0]],"regenerate":false}},"_rankingScore":0.990290343761444},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":{"embeddings":[[1.0,2.0]],"regenerate":false}},"_rankingScore":0.974341630935669},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":{"embeddings":[[1.0,3.0]],"regenerate":false}},"_rankingScore":0.9472135901451112}]"###); snapshot!(response, @r#"
{
"hits": [
{
"title": "Captain Marvel",
"desc": "a Shazam ersatz",
"id": "3",
"_vectors": {
"default": {
"embeddings": [
[
2.0,
3.0
]
],
"regenerate": false
}
},
"_rankingScore": 0.990290343761444
},
{
"title": "Captain Planet",
"desc": "He's not part of the Marvel Cinematic Universe",
"id": "2",
"_vectors": {
"default": {
"embeddings": [
[
1.0,
2.0
]
],
"regenerate": false
}
},
"_rankingScore": 0.974341630935669
},
{
"title": "Shazam!",
"desc": "a Captain Marvel ersatz",
"id": "1",
"_vectors": {
"default": {
"embeddings": [
[
1.0,
3.0
]
],
"regenerate": false
}
},
"_rankingScore": 0.9472135901451112
}
],
"query": "Captain",
"queryVector": [
1.0,
1.0
],
"processingTimeMs": "[duration]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 3,
"semanticHitCount": 3
}
"#);
snapshot!(response["semanticHitCount"], @"3"); snapshot!(response["semanticHitCount"], @"3");
} }

View File

@ -3703,7 +3703,7 @@ async fn federation_vector_two_indexes() {
]})) ]}))
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r###" snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r#"
{ {
"hits": [ "hits": [
{ {
@ -3911,9 +3911,20 @@ async fn federation_vector_two_indexes() {
"limit": 20, "limit": 20,
"offset": 0, "offset": 0,
"estimatedTotalHits": 8, "estimatedTotalHits": 8,
"queryVectors": {
"0": [
1.0,
0.0,
0.5
],
"1": [
0.8,
0.6
]
},
"semanticHitCount": 6 "semanticHitCount": 6
} }
"###); "#);
// hybrid search, distinct embedder // hybrid search, distinct embedder
let (response, code) = server let (response, code) = server
@ -3923,7 +3934,7 @@ async fn federation_vector_two_indexes() {
]})) ]}))
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r###" snapshot!(json_string!(response, { ".processingTimeMs" => "[duration]", ".**._rankingScore" => "[score]" }), @r#"
{ {
"hits": [ "hits": [
{ {
@ -4139,9 +4150,20 @@ async fn federation_vector_two_indexes() {
"limit": 20, "limit": 20,
"offset": 0, "offset": 0,
"estimatedTotalHits": 8, "estimatedTotalHits": 8,
"queryVectors": {
"0": [
1.0,
0.0,
0.5
],
"1": [
-1.0,
0.6
]
},
"semanticHitCount": 8 "semanticHitCount": 8
} }
"###); "#);
} }
#[actix_rt::test] #[actix_rt::test]

View File

@ -2,8 +2,9 @@ use std::sync::Arc;
use actix_http::StatusCode; use actix_http::StatusCode;
use meili_snap::{json_string, snapshot}; use meili_snap::{json_string, snapshot};
use wiremock::matchers::AnyMatcher; use wiremock::matchers::method;
use wiremock::{Mock, MockServer, ResponseTemplate}; use wiremock::matchers::{path, AnyMatcher};
use wiremock::{Mock, MockServer, Request, ResponseTemplate};
use crate::common::{Server, Value, SCORE_DOCUMENTS}; use crate::common::{Server, Value, SCORE_DOCUMENTS};
use crate::json; use crate::json;
@ -415,6 +416,503 @@ async fn remote_sharding() {
"###); "###);
} }
#[actix_rt::test]
async fn remote_sharding_retrieve_vectors() {
let ms0 = Server::new().await;
let ms1 = Server::new().await;
let ms2 = Server::new().await;
let index0 = ms0.index("test");
let index1 = ms1.index("test");
let index2 = ms2.index("test");
// enable feature
let (response, code) = ms0.set_features(json!({"network": true})).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["network"]), @"true");
let (response, code) = ms1.set_features(json!({"network": true})).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["network"]), @"true");
let (response, code) = ms2.set_features(json!({"network": true})).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["network"]), @"true");
// set self
let (response, code) = ms0.set_network(json!({"self": "ms0"})).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response), @r###"
{
"self": "ms0",
"remotes": {}
}
"###);
let (response, code) = ms1.set_network(json!({"self": "ms1"})).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response), @r###"
{
"self": "ms1",
"remotes": {}
}
"###);
let (response, code) = ms2.set_network(json!({"self": "ms2"})).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response), @r###"
{
"self": "ms2",
"remotes": {}
}
"###);
// setup embedders
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/"))
.respond_with(move |req: &Request| {
println!("Received request: {:?}", req);
let text = req.body_json::<String>().unwrap().to_lowercase();
let patterns = [
("batman", [1.0, 0.0, 0.0]),
("dark", [0.0, 0.1, 0.0]),
("knight", [0.1, 0.1, 0.0]),
("returns", [0.0, 0.0, 0.2]),
("part", [0.05, 0.1, 0.0]),
("1", [0.3, 0.05, 0.0]),
("2", [0.2, 0.05, 0.0]),
];
let mut embedding = vec![0.; 3];
for (pattern, vector) in patterns {
if text.contains(pattern) {
for (i, v) in vector.iter().enumerate() {
embedding[i] += v;
}
}
}
ResponseTemplate::new(200).set_body_json(json!({ "data": embedding }))
})
.mount(&mock_server)
.await;
let url = mock_server.uri();
for (server, index) in [(&ms0, &index0), (&ms1, &index1), (&ms2, &index2)] {
let (response, code) = index
.update_settings(json!({
"embedders": {
"rest": {
"source": "rest",
"url": url,
"dimensions": 3,
"request": "{{text}}",
"response": { "data": "{{embedding}}" },
"documentTemplate": "{{doc.name}}",
},
},
}))
.await;
snapshot!(code, @"202 Accepted");
server.wait_task(response.uid()).await.succeeded();
}
// wrap servers
let ms0 = Arc::new(ms0);
let ms1 = Arc::new(ms1);
let ms2 = Arc::new(ms2);
let rms0 = LocalMeili::new(ms0.clone()).await;
let rms1 = LocalMeili::new(ms1.clone()).await;
let rms2 = LocalMeili::new(ms2.clone()).await;
// set network
let network = json!({"remotes": {
"ms0": {
"url": rms0.url()
},
"ms1": {
"url": rms1.url()
},
"ms2": {
"url": rms2.url()
}
}});
let (_response, status_code) = ms0.set_network(network.clone()).await;
snapshot!(status_code, @"200 OK");
let (_response, status_code) = ms1.set_network(network.clone()).await;
snapshot!(status_code, @"200 OK");
let (_response, status_code) = ms2.set_network(network.clone()).await;
snapshot!(status_code, @"200 OK");
// multi vector search: one query per remote
let request = json!({
"federation": {},
"queries": [
{
"q": "batman",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms0"
}
},
{
"q": "dark knight",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
{
"q": "returns",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms2"
}
},
]
});
let (response, _status_code) = ms0.multi_search(request.clone()).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
{
"hits": [],
"processingTimeMs": "[time]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 0,
"queryVectors": {
"0": [
1.0,
0.0,
0.0
],
"1": [
0.1,
0.2,
0.0
],
"2": [
0.0,
0.0,
0.2
]
},
"semanticHitCount": 0,
"remoteErrors": {}
}
"#);
// multi vector search: two local queries, one remote
let request = json!({
"federation": {},
"queries": [
{
"q": "batman",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms0"
}
},
{
"q": "dark knight",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms0"
}
},
{
"q": "returns",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms2"
}
},
]
});
let (response, _status_code) = ms0.multi_search(request.clone()).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
{
"hits": [],
"processingTimeMs": "[time]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 0,
"queryVectors": {
"0": [
1.0,
0.0,
0.0
],
"1": [
0.1,
0.2,
0.0
],
"2": [
0.0,
0.0,
0.2
]
},
"semanticHitCount": 0,
"remoteErrors": {}
}
"#);
// multi vector search: two queries on the same remote
let request = json!({
"federation": {},
"queries": [
{
"q": "batman",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms0"
}
},
{
"q": "dark knight",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
{
"q": "returns",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
]
});
let (response, _status_code) = ms0.multi_search(request.clone()).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
{
"hits": [],
"processingTimeMs": "[time]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 0,
"queryVectors": {
"0": [
1.0,
0.0,
0.0
],
"1": [
0.1,
0.2,
0.0
],
"2": [
0.0,
0.0,
0.2
]
},
"semanticHitCount": 0,
"remoteErrors": {}
}
"#);
// multi search: two vector, one keyword
let request = json!({
"federation": {},
"queries": [
{
"q": "batman",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms0"
}
},
{
"q": "dark knight",
"indexUid": "test",
"hybrid": {
"semanticRatio": 0.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
{
"q": "returns",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
]
});
let (response, _status_code) = ms0.multi_search(request.clone()).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
{
"hits": [],
"processingTimeMs": "[time]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 0,
"queryVectors": {
"0": [
1.0,
0.0,
0.0
],
"2": [
0.0,
0.0,
0.2
]
},
"semanticHitCount": 0,
"remoteErrors": {}
}
"#);
// multi vector search: no local queries, all remote
let request = json!({
"federation": {},
"queries": [
{
"q": "batman",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
{
"q": "dark knight",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
{
"q": "returns",
"indexUid": "test",
"hybrid": {
"semanticRatio": 1.0,
"embedder": "rest"
},
"retrieveVectors": true,
"federationOptions": {
"remote": "ms1"
}
},
]
});
let (response, _status_code) = ms0.multi_search(request.clone()).await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
{
"hits": [],
"processingTimeMs": "[time]",
"limit": 20,
"offset": 0,
"estimatedTotalHits": 0,
"queryVectors": {
"0": [
1.0,
0.0,
0.0
],
"1": [
0.1,
0.2,
0.0
],
"2": [
0.0,
0.0,
0.2
]
},
"remoteErrors": {}
}
"#);
}
#[actix_rt::test] #[actix_rt::test]
async fn error_unregistered_remote() { async fn error_unregistered_remote() {
let ms0 = Server::new().await; let ms0 = Server::new().await;

View File

@ -323,7 +323,7 @@ async fn binary_quantize_clear_documents() {
// Make sure the arroy DB has been cleared // Make sure the arroy DB has been cleared
let (documents, _code) = let (documents, _code) =
index.search_post(json!({ "hybrid": { "embedder": "manual" }, "vector": [1, 1, 1] })).await; index.search_post(json!({ "hybrid": { "embedder": "manual" }, "vector": [1, 1, 1] })).await;
snapshot!(documents, @r###" snapshot!(documents, @r#"
{ {
"hits": [], "hits": [],
"query": "", "query": "",
@ -333,5 +333,5 @@ async fn binary_quantize_clear_documents() {
"estimatedTotalHits": 0, "estimatedTotalHits": 0,
"semanticHitCount": 0 "semanticHitCount": 0
} }
"###); "#);
} }

View File

@ -686,7 +686,7 @@ async fn clear_documents() {
// Make sure the arroy DB has been cleared // Make sure the arroy DB has been cleared
let (documents, _code) = let (documents, _code) =
index.search_post(json!({ "vector": [1, 1, 1], "hybrid": {"embedder": "manual"} })).await; index.search_post(json!({ "vector": [1, 1, 1], "hybrid": {"embedder": "manual"} })).await;
snapshot!(documents, @r###" snapshot!(documents, @r#"
{ {
"hits": [], "hits": [],
"query": "", "query": "",
@ -696,7 +696,7 @@ async fn clear_documents() {
"estimatedTotalHits": 0, "estimatedTotalHits": 0,
"semanticHitCount": 0 "semanticHitCount": 0
} }
"###); "#);
} }
#[actix_rt::test] #[actix_rt::test]
@ -740,7 +740,7 @@ async fn add_remove_one_vector_4588() {
json!({"vector": [1, 1, 1], "hybrid": {"semanticRatio": 1.0, "embedder": "manual"} }), json!({"vector": [1, 1, 1], "hybrid": {"semanticRatio": 1.0, "embedder": "manual"} }),
) )
.await; .await;
snapshot!(documents, @r###" snapshot!(documents, @r#"
{ {
"hits": [ "hits": [
{ {
@ -755,7 +755,7 @@ async fn add_remove_one_vector_4588() {
"estimatedTotalHits": 1, "estimatedTotalHits": 1,
"semanticHitCount": 1 "semanticHitCount": 1
} }
"###); "#);
let (documents, _code) = index let (documents, _code) = index
.get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() }) .get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() })

View File

@ -7,7 +7,7 @@ use roaring::RoaringBitmap;
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::search::new::{distinct_fid, distinct_single_docid}; use crate::search::new::{distinct_fid, distinct_single_docid};
use crate::search::SemanticSearch; use crate::search::SemanticSearch;
use crate::vector::SearchQuery; use crate::vector::{Embedding, SearchQuery};
use crate::{Index, MatchingWords, Result, Search, SearchResult}; use crate::{Index, MatchingWords, Result, Search, SearchResult};
struct ScoreWithRatioResult { struct ScoreWithRatioResult {
@ -16,6 +16,7 @@ struct ScoreWithRatioResult {
document_scores: Vec<(u32, ScoreWithRatio)>, document_scores: Vec<(u32, ScoreWithRatio)>,
degraded: bool, degraded: bool,
used_negative_operator: bool, used_negative_operator: bool,
query_vector: Option<Embedding>,
} }
type ScoreWithRatio = (Vec<ScoreDetails>, f32); type ScoreWithRatio = (Vec<ScoreDetails>, f32);
@ -85,6 +86,7 @@ impl ScoreWithRatioResult {
document_scores, document_scores,
degraded: results.degraded, degraded: results.degraded,
used_negative_operator: results.used_negative_operator, used_negative_operator: results.used_negative_operator,
query_vector: results.query_vector,
} }
} }
@ -186,6 +188,7 @@ impl ScoreWithRatioResult {
degraded: vector_results.degraded | keyword_results.degraded, degraded: vector_results.degraded | keyword_results.degraded,
used_negative_operator: vector_results.used_negative_operator used_negative_operator: vector_results.used_negative_operator
| keyword_results.used_negative_operator, | keyword_results.used_negative_operator,
query_vector: vector_results.query_vector,
}, },
semantic_hit_count, semantic_hit_count,
)) ))
@ -209,6 +212,7 @@ impl Search<'_> {
terms_matching_strategy: self.terms_matching_strategy, terms_matching_strategy: self.terms_matching_strategy,
scoring_strategy: ScoringStrategy::Detailed, scoring_strategy: ScoringStrategy::Detailed,
words_limit: self.words_limit, words_limit: self.words_limit,
retrieve_vectors: self.retrieve_vectors,
exhaustive_number_hits: self.exhaustive_number_hits, exhaustive_number_hits: self.exhaustive_number_hits,
max_total_hits: self.max_total_hits, max_total_hits: self.max_total_hits,
rtxn: self.rtxn, rtxn: self.rtxn,
@ -265,7 +269,7 @@ impl Search<'_> {
}; };
search.semantic = Some(SemanticSearch { search.semantic = Some(SemanticSearch {
vector: Some(vector_query), vector: Some(vector_query.clone()),
embedder_name, embedder_name,
embedder, embedder,
quantized, quantized,
@ -322,6 +326,7 @@ fn return_keyword_results(
mut document_scores, mut document_scores,
degraded, degraded,
used_negative_operator, used_negative_operator,
query_vector,
}: SearchResult, }: SearchResult,
) -> (SearchResult, Option<u32>) { ) -> (SearchResult, Option<u32>) {
let (documents_ids, document_scores) = if offset >= documents_ids.len() || let (documents_ids, document_scores) = if offset >= documents_ids.len() ||
@ -348,6 +353,7 @@ fn return_keyword_results(
document_scores, document_scores,
degraded, degraded,
used_negative_operator, used_negative_operator,
query_vector,
}, },
Some(0), Some(0),
) )

View File

@ -52,6 +52,7 @@ pub struct Search<'a> {
terms_matching_strategy: TermsMatchingStrategy, terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy, scoring_strategy: ScoringStrategy,
words_limit: usize, words_limit: usize,
retrieve_vectors: bool,
exhaustive_number_hits: bool, exhaustive_number_hits: bool,
max_total_hits: Option<usize>, max_total_hits: Option<usize>,
rtxn: &'a heed::RoTxn<'a>, rtxn: &'a heed::RoTxn<'a>,
@ -75,6 +76,7 @@ impl<'a> Search<'a> {
geo_param: GeoSortParameter::default(), geo_param: GeoSortParameter::default(),
terms_matching_strategy: TermsMatchingStrategy::default(), terms_matching_strategy: TermsMatchingStrategy::default(),
scoring_strategy: Default::default(), scoring_strategy: Default::default(),
retrieve_vectors: false,
exhaustive_number_hits: false, exhaustive_number_hits: false,
max_total_hits: None, max_total_hits: None,
words_limit: 10, words_limit: 10,
@ -161,6 +163,11 @@ impl<'a> Search<'a> {
self self
} }
/// Sets whether the query vector should be returned with the results.
///
/// When enabled, a search that goes through the vector search path keeps a
/// copy of the query vector and exposes it as `SearchResult::query_vector`;
/// otherwise that field stays `None`.
pub fn retrieve_vectors(&mut self, retrieve_vectors: bool) -> &mut Search<'a> {
self.retrieve_vectors = retrieve_vectors;
self
}
/// Forces the search to exhaustively compute the number of candidates, /// Forces the search to exhaustively compute the number of candidates,
/// this will increase the search time but allows finite pagination. /// this will increase the search time but allows finite pagination.
pub fn exhaustive_number_hits(&mut self, exhaustive_number_hits: bool) -> &mut Search<'a> { pub fn exhaustive_number_hits(&mut self, exhaustive_number_hits: bool) -> &mut Search<'a> {
@ -233,6 +240,7 @@ impl<'a> Search<'a> {
} }
let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?;
let mut query_vector = None;
let PartialSearchResult { let PartialSearchResult {
located_query_terms, located_query_terms,
candidates, candidates,
@ -247,24 +255,29 @@ impl<'a> Search<'a> {
embedder, embedder,
quantized, quantized,
media: _, media: _,
}) => execute_vector_search( }) => {
&mut ctx, if self.retrieve_vectors {
vector, query_vector = Some(vector.clone());
self.scoring_strategy, }
self.exhaustive_number_hits, execute_vector_search(
self.max_total_hits, &mut ctx,
universe, vector,
&self.sort_criteria, self.scoring_strategy,
&self.distinct, self.exhaustive_number_hits,
self.geo_param, self.max_total_hits,
self.offset, universe,
self.limit, &self.sort_criteria,
embedder_name, &self.distinct,
embedder, self.geo_param,
*quantized, self.offset,
self.time_budget.clone(), self.limit,
self.ranking_score_threshold, embedder_name,
)?, embedder,
*quantized,
self.time_budget.clone(),
self.ranking_score_threshold,
)?
}
_ => execute_search( _ => execute_search(
&mut ctx, &mut ctx,
self.query.as_deref(), self.query.as_deref(),
@ -306,6 +319,7 @@ impl<'a> Search<'a> {
documents_ids, documents_ids,
degraded, degraded,
used_negative_operator, used_negative_operator,
query_vector,
}) })
} }
} }
@ -324,6 +338,7 @@ impl fmt::Debug for Search<'_> {
terms_matching_strategy, terms_matching_strategy,
scoring_strategy, scoring_strategy,
words_limit, words_limit,
retrieve_vectors,
exhaustive_number_hits, exhaustive_number_hits,
max_total_hits, max_total_hits,
rtxn: _, rtxn: _,
@ -344,6 +359,7 @@ impl fmt::Debug for Search<'_> {
.field("searchable_attributes", searchable_attributes) .field("searchable_attributes", searchable_attributes)
.field("terms_matching_strategy", terms_matching_strategy) .field("terms_matching_strategy", terms_matching_strategy)
.field("scoring_strategy", scoring_strategy) .field("scoring_strategy", scoring_strategy)
.field("retrieve_vectors", retrieve_vectors)
.field("exhaustive_number_hits", exhaustive_number_hits) .field("exhaustive_number_hits", exhaustive_number_hits)
.field("max_total_hits", max_total_hits) .field("max_total_hits", max_total_hits)
.field("words_limit", words_limit) .field("words_limit", words_limit)
@ -366,6 +382,7 @@ pub struct SearchResult {
pub document_scores: Vec<Vec<ScoreDetails>>, pub document_scores: Vec<Vec<ScoreDetails>>,
pub degraded: bool, pub degraded: bool,
pub used_negative_operator: bool, pub used_negative_operator: bool,
pub query_vector: Option<Embedding>,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]

View File

@ -130,6 +130,7 @@ impl<'a> Similar<'a> {
document_scores, document_scores,
degraded: false, degraded: false,
used_negative_operator: false, used_negative_operator: false,
query_vector: None,
}) })
} }
} }

View File

@ -1097,6 +1097,7 @@ fn bug_3021_fourth() {
mut documents_ids, mut documents_ids,
degraded: _, degraded: _,
used_negative_operator: _, used_negative_operator: _,
query_vector: _,
} = search.execute().unwrap(); } = search.execute().unwrap();
let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap(); let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap();
documents_ids.sort_unstable(); documents_ids.sort_unstable();