Implement `queryVectors` retrieval for multi-search

This commit is contained in:
Mubelotix
2025-07-25 11:45:51 +02:00
parent 26da478b5b
commit a7fe2abca4
6 changed files with 318 additions and 26 deletions

View File

@ -47,6 +47,7 @@ pub async fn perform_federated_search(
let deadline = before_search + std::time::Duration::from_secs(9);
let required_hit_count = federation.limit + federation.offset;
let retrieve_vectors = queries.iter().any(|q| q.retrieve_vectors);
let network = index_scheduler.network();
@ -92,6 +93,7 @@ pub async fn perform_federated_search(
federation,
mut semantic_hit_count,
mut results_by_index,
mut query_vectors,
previous_query_data: _,
facet_order,
} = search_by_index;
@ -123,7 +125,26 @@ pub async fn perform_federated_search(
.map(|hit| hit.hit())
.collect();
// 3.3. merge facets
// 3.3. merge query vectors
let query_vectors = if retrieve_vectors {
for remote_results in remote_results.iter_mut() {
if let Some(remote_vectors) = remote_results.query_vectors.take() {
for (key, value) in remote_vectors.into_iter() {
debug_assert!(
!query_vectors.contains_key(&key),
"Query vector for query {key} already exists"
);
query_vectors.insert(key, value);
}
}
}
Some(query_vectors)
} else {
None
};
// 3.4. merge facets
let (facet_distribution, facet_stats, facets_by_index) =
facet_order.merge(federation.merge_facets, remote_results, facets);
@ -141,6 +162,7 @@ pub async fn perform_federated_search(
offset: federation.offset,
estimated_total_hits,
},
query_vectors,
semantic_hit_count,
degraded,
used_negative_operator,
@ -409,6 +431,7 @@ fn merge_metadata(
hits: _,
processing_time_ms,
hits_info,
query_vectors: _,
semantic_hit_count: _,
facet_distribution: _,
facet_stats: _,
@ -658,6 +681,7 @@ struct SearchByIndex {
// Then when merging, we'll update its value if there is any semantic hit
semantic_hit_count: Option<u32>,
results_by_index: Vec<SearchResultByIndex>,
query_vectors: BTreeMap<usize, Embedding>,
previous_query_data: Option<(RankingRules, usize, String)>,
// remember the order and name of first index for each facet when merging with index settings
// to detect if the order is inconsistent for a facet.
@ -675,6 +699,7 @@ impl SearchByIndex {
federation,
semantic_hit_count: None,
results_by_index: Vec::with_capacity(index_count),
query_vectors: BTreeMap::new(),
previous_query_data: None,
}
}
@ -842,6 +867,16 @@ impl SearchByIndex {
query_vector,
} = result;
if query.retrieve_vectors {
if let Some(query_vector) = query_vector {
debug_assert!(
!self.query_vectors.contains_key(&query_index),
"Query vector for query {query_index} already exists"
);
self.query_vectors.insert(query_index, query_vector);
}
}
candidates |= query_candidates;
degraded |= query_degraded;
used_negative_operator |= query_used_negative_operator;

View File

@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex};
use crate::milli::vector::Embedding;
pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0;
@ -117,6 +118,9 @@ pub struct FederatedSearchResult {
#[serde(flatten)]
pub hits_info: HitsInfo,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub query_vectors: Option<BTreeMap<usize, Embedding>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub semantic_hit_count: Option<u32>,
@ -144,6 +148,7 @@ impl fmt::Debug for FederatedSearchResult {
hits,
processing_time_ms,
hits_info,
query_vectors,
semantic_hit_count,
degraded,
used_negative_operator,
@ -158,6 +163,10 @@ impl fmt::Debug for FederatedSearchResult {
debug.field("processing_time_ms", &processing_time_ms);
debug.field("hits", &format!("[{} hits returned]", hits.len()));
debug.field("hits_info", &hits_info);
if let Some(query_vectors) = query_vectors {
let known = query_vectors.len();
debug.field("query_vectors", &format!("[{known} known vectors]"));
}
if *used_negative_operator {
debug.field("used_negative_operator", used_negative_operator);
}

View File

@ -1020,7 +1020,7 @@ pub fn prepare_search<'t>(
.map_err(milli::Error::from)?
}
};
search.semantic(
search.semantic_auto_embedded(
embedder_name.clone(),
embedder.clone(),
*quantized,

View File

@ -2,8 +2,9 @@ use std::sync::Arc;
use actix_http::StatusCode;
use meili_snap::{json_string, snapshot};
use wiremock::matchers::AnyMatcher;
use wiremock::{Mock, MockServer, ResponseTemplate};
use wiremock::matchers::method;
use wiremock::matchers::{path, AnyMatcher};
use wiremock::{Mock, MockServer, Request, ResponseTemplate};
use crate::common::{Server, Value, SCORE_DOCUMENTS};
use crate::json;
@ -415,6 +416,209 @@ async fn remote_sharding() {
"###);
}
/// Federated multi-search across three remotes with `retrieveVectors: true`:
/// the merged response must expose the auto-embedded vector of every query
/// under `queryVectors`, keyed by query index.
#[actix_rt::test]
async fn remote_sharding_retrieve_vectors() {
    let ms0 = Server::new().await;
    let ms1 = Server::new().await;
    let ms2 = Server::new().await;
    let index0 = ms0.index("test");
    let index1 = ms1.index("test");
    let index2 = ms2.index("test");

    // enable the network feature on all three instances
    let (response, code) = ms0.set_features(json!({"network": true})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response["network"]), @"true");
    let (response, code) = ms1.set_features(json!({"network": true})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response["network"]), @"true");
    let (response, code) = ms2.set_features(json!({"network": true})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response["network"]), @"true");

    // set self
    let (response, code) = ms0.set_network(json!({"self": "ms0"})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response), @r###"
    {
      "self": "ms0",
      "remotes": {}
    }
    "###);
    let (response, code) = ms1.set_network(json!({"self": "ms1"})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response), @r###"
    {
      "self": "ms1",
      "remotes": {}
    }
    "###);
    let (response, code) = ms2.set_network(json!({"self": "ms2"})).await;
    snapshot!(code, @"200 OK");
    snapshot!(json_string!(response), @r###"
    {
      "self": "ms2",
      "remotes": {}
    }
    "###);

    // setup embedders: the mock returns a deterministic 3-dimensional vector,
    // the sum of per-keyword contributions found in the (lowercased) input text
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/"))
        .respond_with(move |req: &Request| {
            println!("Received request: {:?}", req);
            let text = req.body_json::<String>().unwrap().to_lowercase();
            let patterns = [
                ("batman", [1.0, 0.0, 0.0]),
                ("dark", [0.0, 0.1, 0.0]),
                ("knight", [0.1, 0.1, 0.0]),
                ("returns", [0.0, 0.0, 0.2]),
                ("part", [0.05, 0.1, 0.0]),
                ("1", [0.3, 0.05, 0.0]),
                ("2", [0.2, 0.05, 0.0]),
            ];
            let mut embedding = vec![0.; 3];
            for (pattern, vector) in patterns {
                if text.contains(pattern) {
                    for (i, v) in vector.iter().enumerate() {
                        embedding[i] += v;
                    }
                }
            }
            ResponseTemplate::new(200).set_body_json(json!({ "data": embedding }))
        })
        .mount(&mock_server)
        .await;
    let url = mock_server.uri();

    for (server, index) in [(&ms0, &index0), (&ms1, &index1), (&ms2, &index2)] {
        let (response, code) = index
            .update_settings(json!({
                "embedders": {
                    "rest": {
                        "source": "rest",
                        "url": url,
                        "dimensions": 3,
                        "request": "{{text}}",
                        "response": { "data": "{{embedding}}" },
                        "documentTemplate": "{{doc.name}}",
                    },
                },
            }))
            .await;
        snapshot!(code, @"202 Accepted");
        server.wait_task(response.uid()).await.succeeded();
    }

    // wrap servers
    let ms0 = Arc::new(ms0);
    let ms1 = Arc::new(ms1);
    let ms2 = Arc::new(ms2);
    let rms0 = LocalMeili::new(ms0.clone()).await;
    let rms1 = LocalMeili::new(ms1.clone()).await;
    let rms2 = LocalMeili::new(ms2.clone()).await;

    // set network
    let network = json!({"remotes": {
        "ms0": {
            "url": rms0.url()
        },
        "ms1": {
            "url": rms1.url()
        },
        "ms2": {
            "url": rms2.url()
        }
    }});
    let (_response, status_code) = ms0.set_network(network.clone()).await;
    snapshot!(status_code, @"200 OK");
    let (_response, status_code) = ms1.set_network(network.clone()).await;
    snapshot!(status_code, @"200 OK");
    let (_response, status_code) = ms2.set_network(network.clone()).await;
    snapshot!(status_code, @"200 OK");

    // perform multi-search: only "returns" matches a mock pattern, so every
    // query embeds to [0.0, 0.0, 0.2]
    let query = "badman returns";
    let request = json!({
        "federation": {},
        "queries": [
            {
                "q": query,
                "indexUid": "test",
                "hybrid": {
                    "semanticRatio": 1.0,
                    "embedder": "rest"
                },
                "retrieveVectors": true,
                "federationOptions": {
                    "remote": "ms0"
                }
            },
            {
                "q": query,
                "indexUid": "test",
                "hybrid": {
                    "semanticRatio": 1.0,
                    "embedder": "rest"
                },
                "retrieveVectors": true,
                "federationOptions": {
                    "remote": "ms1"
                }
            },
            {
                "q": query,
                "indexUid": "test",
                "hybrid": {
                    "semanticRatio": 1.0,
                    "embedder": "rest"
                },
                "retrieveVectors": true,
                "federationOptions": {
                    "remote": "ms2"
                }
            },
        ]
    });

    // Bug fix: the status code of the multi-search was previously discarded
    // (`_status_code`) while the snapshot asserted the stale `code` binding
    // from an earlier request; assert the actual multi-search status instead.
    let (response, status_code) = ms0.multi_search(request).await;
    snapshot!(status_code, @"200 OK");
    snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
    {
      "hits": [],
      "processingTimeMs": "[time]",
      "limit": 20,
      "offset": 0,
      "estimatedTotalHits": 0,
      "queryVectors": {
        "0": [
          0.0,
          0.0,
          0.2
        ],
        "1": [
          0.0,
          0.0,
          0.2
        ],
        "2": [
          0.0,
          0.0,
          0.2
        ]
      },
      "semanticHitCount": 0,
      "remoteErrors": {}
    }
    "#);
}
#[actix_rt::test]
async fn error_unregistered_remote() {
let ms0 = Server::new().await;

View File

@ -230,7 +230,14 @@ impl Search<'_> {
}
// no embedder, no semantic search
let Some(SemanticSearch { vector, embedder_name, embedder, quantized, media }) = semantic
let Some(SemanticSearch {
vector,
mut auto_embedded,
embedder_name,
embedder,
quantized,
media,
}) = semantic
else {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
};
@ -253,7 +260,10 @@ impl Search<'_> {
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
match embedder.embed_search(query, Some(deadline)) {
Ok(embedding) => embedding,
Ok(embedding) => {
auto_embedded = true;
embedding
}
Err(error) => {
tracing::error!(error=%error, "Embedding failed");
return Ok(return_keyword_results(
@ -268,6 +278,7 @@ impl Search<'_> {
search.semantic = Some(SemanticSearch {
vector: Some(vector_query.clone()),
auto_embedded,
embedder_name,
embedder,
quantized,
@ -280,7 +291,7 @@ impl Search<'_> {
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
let (mut merge_results, semantic_hit_count) = ScoreWithRatioResult::merge(
let (merge_results, semantic_hit_count) = ScoreWithRatioResult::merge(
vector_results,
keyword_results,
self.offset,
@ -289,7 +300,6 @@ impl Search<'_> {
search.index,
search.rtxn,
)?;
merge_results.query_vector = Some(vector_query);
assert!(merge_results.documents_ids.len() <= self.limit);
Ok((merge_results, Some(semantic_hit_count)))
}

View File

@ -32,6 +32,7 @@ pub mod similar;
#[derive(Debug, Clone)]
pub struct SemanticSearch {
vector: Option<Vec<f32>>,
auto_embedded: bool,
media: Option<serde_json::Value>,
embedder_name: String,
embedder: Arc<Embedder>,
@ -97,7 +98,33 @@ impl<'a> Search<'a> {
vector: Option<Embedding>,
media: Option<serde_json::Value>,
) -> &mut Search<'a> {
self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector, media });
self.semantic = Some(SemanticSearch {
embedder_name,
auto_embedded: false,
embedder,
quantized,
vector,
media,
});
self
}
/// Configures a semantic search whose vector was embedded automatically by
/// Meilisearch (rather than supplied by the caller), so the resulting query
/// vector can be reported back to the caller.
///
/// Same contract as the plain semantic setter, with `auto_embedded` forced
/// to `true`.
pub fn semantic_auto_embedded(
    &mut self,
    embedder_name: String,
    embedder: Arc<Embedder>,
    quantized: bool,
    vector: Option<Embedding>,
    media: Option<serde_json::Value>,
) -> &mut Search<'a> {
    let semantic = SemanticSearch {
        embedder_name,
        auto_embedded: true,
        embedder,
        quantized,
        vector,
        media,
    };
    self.semantic = Some(semantic);
    self
}
@ -225,6 +252,7 @@ impl<'a> Search<'a> {
}
let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?;
let mut query_vector = None;
let PartialSearchResult {
located_query_terms,
candidates,
@ -235,26 +263,32 @@ impl<'a> Search<'a> {
} = match self.semantic.as_ref() {
Some(SemanticSearch {
vector: Some(vector),
auto_embedded,
embedder_name,
embedder,
quantized,
media: _,
}) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
&self.distinct,
self.geo_param,
self.offset,
self.limit,
embedder_name,
embedder,
*quantized,
self.time_budget.clone(),
self.ranking_score_threshold,
)?,
}) => {
if *auto_embedded {
query_vector = Some(vector.clone());
}
execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
&self.distinct,
self.geo_param,
self.offset,
self.limit,
embedder_name,
embedder,
*quantized,
self.time_budget.clone(),
self.ranking_score_threshold,
)?
}
_ => execute_search(
&mut ctx,
self.query.as_deref(),
@ -295,7 +329,7 @@ impl<'a> Search<'a> {
documents_ids,
degraded,
used_negative_operator,
query_vector: None,
query_vector,
})
}
}