Implement query-vector retrieval for multi-search

This commit is contained in:
Mubelotix
2025-07-25 11:45:51 +02:00
parent 26da478b5b
commit a7fe2abca4
6 changed files with 318 additions and 26 deletions

View File

@ -47,6 +47,7 @@ pub async fn perform_federated_search(
let deadline = before_search + std::time::Duration::from_secs(9); let deadline = before_search + std::time::Duration::from_secs(9);
let required_hit_count = federation.limit + federation.offset; let required_hit_count = federation.limit + federation.offset;
let retrieve_vectors = queries.iter().any(|q| q.retrieve_vectors);
let network = index_scheduler.network(); let network = index_scheduler.network();
@ -92,6 +93,7 @@ pub async fn perform_federated_search(
federation, federation,
mut semantic_hit_count, mut semantic_hit_count,
mut results_by_index, mut results_by_index,
mut query_vectors,
previous_query_data: _, previous_query_data: _,
facet_order, facet_order,
} = search_by_index; } = search_by_index;
@ -123,7 +125,26 @@ pub async fn perform_federated_search(
.map(|hit| hit.hit()) .map(|hit| hit.hit())
.collect(); .collect();
// 3.3. merge facets // 3.3. merge query vectors
let query_vectors = if retrieve_vectors {
for remote_results in remote_results.iter_mut() {
if let Some(remote_vectors) = remote_results.query_vectors.take() {
for (key, value) in remote_vectors.into_iter() {
debug_assert!(
!query_vectors.contains_key(&key),
"Query vector for query {key} already exists"
);
query_vectors.insert(key, value);
}
}
}
Some(query_vectors)
} else {
None
};
// 3.4. merge facets
let (facet_distribution, facet_stats, facets_by_index) = let (facet_distribution, facet_stats, facets_by_index) =
facet_order.merge(federation.merge_facets, remote_results, facets); facet_order.merge(federation.merge_facets, remote_results, facets);
@ -141,6 +162,7 @@ pub async fn perform_federated_search(
offset: federation.offset, offset: federation.offset,
estimated_total_hits, estimated_total_hits,
}, },
query_vectors,
semantic_hit_count, semantic_hit_count,
degraded, degraded,
used_negative_operator, used_negative_operator,
@ -409,6 +431,7 @@ fn merge_metadata(
hits: _, hits: _,
processing_time_ms, processing_time_ms,
hits_info, hits_info,
query_vectors: _,
semantic_hit_count: _, semantic_hit_count: _,
facet_distribution: _, facet_distribution: _,
facet_stats: _, facet_stats: _,
@ -658,6 +681,7 @@ struct SearchByIndex {
// Then when merging, we'll update its value if there is any semantic hit // Then when merging, we'll update its value if there is any semantic hit
semantic_hit_count: Option<u32>, semantic_hit_count: Option<u32>,
results_by_index: Vec<SearchResultByIndex>, results_by_index: Vec<SearchResultByIndex>,
query_vectors: BTreeMap<usize, Embedding>,
previous_query_data: Option<(RankingRules, usize, String)>, previous_query_data: Option<(RankingRules, usize, String)>,
// remember the order and name of first index for each facet when merging with index settings // remember the order and name of first index for each facet when merging with index settings
// to detect if the order is inconsistent for a facet. // to detect if the order is inconsistent for a facet.
@ -675,6 +699,7 @@ impl SearchByIndex {
federation, federation,
semantic_hit_count: None, semantic_hit_count: None,
results_by_index: Vec::with_capacity(index_count), results_by_index: Vec::with_capacity(index_count),
query_vectors: BTreeMap::new(),
previous_query_data: None, previous_query_data: None,
} }
} }
@ -842,6 +867,16 @@ impl SearchByIndex {
query_vector, query_vector,
} = result; } = result;
if query.retrieve_vectors {
if let Some(query_vector) = query_vector {
debug_assert!(
!self.query_vectors.contains_key(&query_index),
"Query vector for query {query_index} already exists"
);
self.query_vectors.insert(query_index, query_vector);
}
}
candidates |= query_candidates; candidates |= query_candidates;
degraded |= query_degraded; degraded |= query_degraded;
used_negative_operator |= query_used_negative_operator; used_negative_operator |= query_used_negative_operator;

View File

@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema; use utoipa::ToSchema;
use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex}; use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex};
use crate::milli::vector::Embedding;
pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0; pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0;
@ -117,6 +118,9 @@ pub struct FederatedSearchResult {
#[serde(flatten)] #[serde(flatten)]
pub hits_info: HitsInfo, pub hits_info: HitsInfo,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub query_vectors: Option<BTreeMap<usize, Embedding>>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub semantic_hit_count: Option<u32>, pub semantic_hit_count: Option<u32>,
@ -144,6 +148,7 @@ impl fmt::Debug for FederatedSearchResult {
hits, hits,
processing_time_ms, processing_time_ms,
hits_info, hits_info,
query_vectors,
semantic_hit_count, semantic_hit_count,
degraded, degraded,
used_negative_operator, used_negative_operator,
@ -158,6 +163,10 @@ impl fmt::Debug for FederatedSearchResult {
debug.field("processing_time_ms", &processing_time_ms); debug.field("processing_time_ms", &processing_time_ms);
debug.field("hits", &format!("[{} hits returned]", hits.len())); debug.field("hits", &format!("[{} hits returned]", hits.len()));
debug.field("hits_info", &hits_info); debug.field("hits_info", &hits_info);
if let Some(query_vectors) = query_vectors {
let known = query_vectors.len();
debug.field("query_vectors", &format!("[{known} known vectors]"));
}
if *used_negative_operator { if *used_negative_operator {
debug.field("used_negative_operator", used_negative_operator); debug.field("used_negative_operator", used_negative_operator);
} }

View File

@ -1020,7 +1020,7 @@ pub fn prepare_search<'t>(
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
} }
}; };
search.semantic( search.semantic_auto_embedded(
embedder_name.clone(), embedder_name.clone(),
embedder.clone(), embedder.clone(),
*quantized, *quantized,

View File

@ -2,8 +2,9 @@ use std::sync::Arc;
use actix_http::StatusCode; use actix_http::StatusCode;
use meili_snap::{json_string, snapshot}; use meili_snap::{json_string, snapshot};
use wiremock::matchers::AnyMatcher; use wiremock::matchers::method;
use wiremock::{Mock, MockServer, ResponseTemplate}; use wiremock::matchers::{path, AnyMatcher};
use wiremock::{Mock, MockServer, Request, ResponseTemplate};
use crate::common::{Server, Value, SCORE_DOCUMENTS}; use crate::common::{Server, Value, SCORE_DOCUMENTS};
use crate::json; use crate::json;
@ -415,6 +416,209 @@ async fn remote_sharding() {
"###); "###);
} }
/// Federated multi-search across three remotes with `retrieveVectors: true`:
/// every instance shares one mock REST embedder, and the merged federated
/// response must expose one `queryVectors` entry per query index.
#[actix_rt::test]
async fn remote_sharding_retrieve_vectors() {
    let ms0 = Server::new().await;
    let ms1 = Server::new().await;
    let ms2 = Server::new().await;
    let index0 = ms0.index("test");
    let index1 = ms1.index("test");
    let index2 = ms2.index("test");

    // enable the experimental `network` feature on every instance
    let (response, code) = ms0.set_features(json!({"network": true})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response["network"]), @"true");
    let (response, code) = ms1.set_features(json!({"network": true})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response["network"]), @"true");
    let (response, code) = ms2.set_features(json!({"network": true})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response["network"]), @"true");

    // set self
    let (response, code) = ms0.set_network(json!({"self": "ms0"})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response), @r###"
    {
      "self": "ms0",
      "remotes": {}
    }
    "###);
    let (response, code) = ms1.set_network(json!({"self": "ms1"})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response), @r###"
    {
      "self": "ms1",
      "remotes": {}
    }
    "###);
    let (response, code) = ms2.set_network(json!({"self": "ms2"})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response), @r###"
    {
      "self": "ms2",
      "remotes": {}
    }
    "###);

    // setup embedders: the mock embedding is the sum of the per-keyword
    // vectors whose pattern occurs in the input text
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/"))
        .respond_with(move |req: &Request| {
            println!("Received request: {:?}", req);
            let text = req.body_json::<String>().unwrap().to_lowercase();
            let patterns = [
                ("batman", [1.0, 0.0, 0.0]),
                ("dark", [0.0, 0.1, 0.0]),
                ("knight", [0.1, 0.1, 0.0]),
                ("returns", [0.0, 0.0, 0.2]),
                ("part", [0.05, 0.1, 0.0]),
                ("1", [0.3, 0.05, 0.0]),
                ("2", [0.2, 0.05, 0.0]),
            ];
            let mut embedding = vec![0.; 3];
            for (pattern, vector) in patterns {
                if text.contains(pattern) {
                    for (i, v) in vector.iter().enumerate() {
                        embedding[i] += v;
                    }
                }
            }
            ResponseTemplate::new(200).set_body_json(json!({ "data": embedding }))
        })
        .mount(&mock_server)
        .await;
    let url = mock_server.uri();

    // every index points at the same mock embedder
    for (server, index) in [(&ms0, &index0), (&ms1, &index1), (&ms2, &index2)] {
        let (response, code) = index
            .update_settings(json!({
                "embedders": {
                    "rest": {
                        "source": "rest",
                        "url": url,
                        "dimensions": 3,
                        "request": "{{text}}",
                        "response": { "data": "{{embedding}}" },
                        "documentTemplate": "{{doc.name}}",
                    },
                },
            }))
            .await;
        snapshot!(code, @"202 Accepted");
        server.wait_task(response.uid()).await.succeeded();
    }

    // wrap servers so they can be reached as remotes
    let ms0 = Arc::new(ms0);
    let ms1 = Arc::new(ms1);
    let ms2 = Arc::new(ms2);
    let rms0 = LocalMeili::new(ms0.clone()).await;
    let rms1 = LocalMeili::new(ms1.clone()).await;
    let rms2 = LocalMeili::new(ms2.clone()).await;

    // set network
    let network = json!({"remotes": {
        "ms0": {
            "url": rms0.url()
        },
        "ms1": {
            "url": rms1.url()
        },
        "ms2": {
            "url": rms2.url()
        }
    }});
    let (_response, status_code) = ms0.set_network(network.clone()).await;
    snapshot!(status_code, @"200 OK");
    let (_response, status_code) = ms1.set_network(network.clone()).await;
    snapshot!(status_code, @"200 OK");
    let (_response, status_code) = ms2.set_network(network.clone()).await;
    snapshot!(status_code, @"200 OK");

    // perform multi-search: one purely-semantic query per remote
    let query = "badman returns";
    let request = json!({
        "federation": {},
        "queries": [
            {
                "q": query,
                "indexUid": "test",
                "hybrid": {
                    "semanticRatio": 1.0,
                    "embedder": "rest"
                },
                "retrieveVectors": true,
                "federationOptions": {
                    "remote": "ms0"
                }
            },
            {
                "q": query,
                "indexUid": "test",
                "hybrid": {
                    "semanticRatio": 1.0,
                    "embedder": "rest"
                },
                "retrieveVectors": true,
                "federationOptions": {
                    "remote": "ms1"
                }
            },
            {
                "q": query,
                "indexUid": "test",
                "hybrid": {
                    "semanticRatio": 1.0,
                    "embedder": "rest"
                },
                "retrieveVectors": true,
                "federationOptions": {
                    "remote": "ms2"
                }
            },
        ]
    });

    // NOTE: bind `code` from the multi-search itself — the original snapshotted
    // a stale `code` left over from `set_features`, so the multi-search status
    // was never actually checked.
    let (response, code) = ms0.multi_search(request.clone()).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
    {
      "hits": [],
      "processingTimeMs": "[time]",
      "limit": 20,
      "offset": 0,
      "estimatedTotalHits": 0,
      "queryVectors": {
        "0": [
          0.0,
          0.0,
          0.2
        ],
        "1": [
          0.0,
          0.0,
          0.2
        ],
        "2": [
          0.0,
          0.0,
          0.2
        ]
      },
      "semanticHitCount": 0,
      "remoteErrors": {}
    }
    "#);
}
#[actix_rt::test] #[actix_rt::test]
async fn error_unregistered_remote() { async fn error_unregistered_remote() {
let ms0 = Server::new().await; let ms0 = Server::new().await;

View File

@ -230,7 +230,14 @@ impl Search<'_> {
} }
// no embedder, no semantic search // no embedder, no semantic search
let Some(SemanticSearch { vector, embedder_name, embedder, quantized, media }) = semantic let Some(SemanticSearch {
vector,
mut auto_embedded,
embedder_name,
embedder,
quantized,
media,
}) = semantic
else { else {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results)); return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
}; };
@ -253,7 +260,10 @@ impl Search<'_> {
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3); let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
match embedder.embed_search(query, Some(deadline)) { match embedder.embed_search(query, Some(deadline)) {
Ok(embedding) => embedding, Ok(embedding) => {
auto_embedded = true;
embedding
}
Err(error) => { Err(error) => {
tracing::error!(error=%error, "Embedding failed"); tracing::error!(error=%error, "Embedding failed");
return Ok(return_keyword_results( return Ok(return_keyword_results(
@ -268,6 +278,7 @@ impl Search<'_> {
search.semantic = Some(SemanticSearch { search.semantic = Some(SemanticSearch {
vector: Some(vector_query.clone()), vector: Some(vector_query.clone()),
auto_embedded,
embedder_name, embedder_name,
embedder, embedder,
quantized, quantized,
@ -280,7 +291,7 @@ impl Search<'_> {
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
let (mut merge_results, semantic_hit_count) = ScoreWithRatioResult::merge( let (merge_results, semantic_hit_count) = ScoreWithRatioResult::merge(
vector_results, vector_results,
keyword_results, keyword_results,
self.offset, self.offset,
@ -289,7 +300,6 @@ impl Search<'_> {
search.index, search.index,
search.rtxn, search.rtxn,
)?; )?;
merge_results.query_vector = Some(vector_query);
assert!(merge_results.documents_ids.len() <= self.limit); assert!(merge_results.documents_ids.len() <= self.limit);
Ok((merge_results, Some(semantic_hit_count))) Ok((merge_results, Some(semantic_hit_count)))
} }

View File

@ -32,6 +32,7 @@ pub mod similar;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SemanticSearch { pub struct SemanticSearch {
vector: Option<Vec<f32>>, vector: Option<Vec<f32>>,
auto_embedded: bool,
media: Option<serde_json::Value>, media: Option<serde_json::Value>,
embedder_name: String, embedder_name: String,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
@ -97,7 +98,33 @@ impl<'a> Search<'a> {
vector: Option<Embedding>, vector: Option<Embedding>,
media: Option<serde_json::Value>, media: Option<serde_json::Value>,
) -> &mut Search<'a> { ) -> &mut Search<'a> {
self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector, media }); self.semantic = Some(SemanticSearch {
embedder_name,
auto_embedded: false,
embedder,
quantized,
vector,
media,
});
self
}
pub fn semantic_auto_embedded(
&mut self,
embedder_name: String,
embedder: Arc<Embedder>,
quantized: bool,
vector: Option<Embedding>,
media: Option<serde_json::Value>,
) -> &mut Search<'a> {
self.semantic = Some(SemanticSearch {
embedder_name,
auto_embedded: true,
embedder,
quantized,
vector,
media,
});
self self
} }
@ -225,6 +252,7 @@ impl<'a> Search<'a> {
} }
let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?;
let mut query_vector = None;
let PartialSearchResult { let PartialSearchResult {
located_query_terms, located_query_terms,
candidates, candidates,
@ -235,26 +263,32 @@ impl<'a> Search<'a> {
} = match self.semantic.as_ref() { } = match self.semantic.as_ref() {
Some(SemanticSearch { Some(SemanticSearch {
vector: Some(vector), vector: Some(vector),
auto_embedded,
embedder_name, embedder_name,
embedder, embedder,
quantized, quantized,
media: _, media: _,
}) => execute_vector_search( }) => {
&mut ctx, if *auto_embedded {
vector, query_vector = Some(vector.clone());
self.scoring_strategy, }
universe, execute_vector_search(
&self.sort_criteria, &mut ctx,
&self.distinct, vector,
self.geo_param, self.scoring_strategy,
self.offset, universe,
self.limit, &self.sort_criteria,
embedder_name, &self.distinct,
embedder, self.geo_param,
*quantized, self.offset,
self.time_budget.clone(), self.limit,
self.ranking_score_threshold, embedder_name,
)?, embedder,
*quantized,
self.time_budget.clone(),
self.ranking_score_threshold,
)?
}
_ => execute_search( _ => execute_search(
&mut ctx, &mut ctx,
self.query.as_deref(), self.query.as_deref(),
@ -295,7 +329,7 @@ impl<'a> Search<'a> {
documents_ids, documents_ids,
degraded, degraded,
used_negative_operator, used_negative_operator,
query_vector: None, query_vector,
}) })
} }
} }