mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-09-04 03:36:30 +00:00
Implement for multi-search
This commit is contained in:
@ -47,6 +47,7 @@ pub async fn perform_federated_search(
|
||||
let deadline = before_search + std::time::Duration::from_secs(9);
|
||||
|
||||
let required_hit_count = federation.limit + federation.offset;
|
||||
let retrieve_vectors = queries.iter().any(|q| q.retrieve_vectors);
|
||||
|
||||
let network = index_scheduler.network();
|
||||
|
||||
@ -92,6 +93,7 @@ pub async fn perform_federated_search(
|
||||
federation,
|
||||
mut semantic_hit_count,
|
||||
mut results_by_index,
|
||||
mut query_vectors,
|
||||
previous_query_data: _,
|
||||
facet_order,
|
||||
} = search_by_index;
|
||||
@ -123,7 +125,26 @@ pub async fn perform_federated_search(
|
||||
.map(|hit| hit.hit())
|
||||
.collect();
|
||||
|
||||
// 3.3. merge facets
|
||||
// 3.3. merge query vectors
|
||||
let query_vectors = if retrieve_vectors {
|
||||
for remote_results in remote_results.iter_mut() {
|
||||
if let Some(remote_vectors) = remote_results.query_vectors.take() {
|
||||
for (key, value) in remote_vectors.into_iter() {
|
||||
debug_assert!(
|
||||
!query_vectors.contains_key(&key),
|
||||
"Query vector for query {key} already exists"
|
||||
);
|
||||
query_vectors.insert(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(query_vectors)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// 3.4. merge facets
|
||||
let (facet_distribution, facet_stats, facets_by_index) =
|
||||
facet_order.merge(federation.merge_facets, remote_results, facets);
|
||||
|
||||
@ -141,6 +162,7 @@ pub async fn perform_federated_search(
|
||||
offset: federation.offset,
|
||||
estimated_total_hits,
|
||||
},
|
||||
query_vectors,
|
||||
semantic_hit_count,
|
||||
degraded,
|
||||
used_negative_operator,
|
||||
@ -409,6 +431,7 @@ fn merge_metadata(
|
||||
hits: _,
|
||||
processing_time_ms,
|
||||
hits_info,
|
||||
query_vectors: _,
|
||||
semantic_hit_count: _,
|
||||
facet_distribution: _,
|
||||
facet_stats: _,
|
||||
@ -658,6 +681,7 @@ struct SearchByIndex {
|
||||
// Then when merging, we'll update its value if there is any semantic hit
|
||||
semantic_hit_count: Option<u32>,
|
||||
results_by_index: Vec<SearchResultByIndex>,
|
||||
query_vectors: BTreeMap<usize, Embedding>,
|
||||
previous_query_data: Option<(RankingRules, usize, String)>,
|
||||
// remember the order and name of first index for each facet when merging with index settings
|
||||
// to detect if the order is inconsistent for a facet.
|
||||
@ -675,6 +699,7 @@ impl SearchByIndex {
|
||||
federation,
|
||||
semantic_hit_count: None,
|
||||
results_by_index: Vec::with_capacity(index_count),
|
||||
query_vectors: BTreeMap::new(),
|
||||
previous_query_data: None,
|
||||
}
|
||||
}
|
||||
@ -842,6 +867,16 @@ impl SearchByIndex {
|
||||
query_vector,
|
||||
} = result;
|
||||
|
||||
if query.retrieve_vectors {
|
||||
if let Some(query_vector) = query_vector {
|
||||
debug_assert!(
|
||||
!self.query_vectors.contains_key(&query_index),
|
||||
"Query vector for query {query_index} already exists"
|
||||
);
|
||||
self.query_vectors.insert(query_index, query_vector);
|
||||
}
|
||||
}
|
||||
|
||||
candidates |= query_candidates;
|
||||
degraded |= query_degraded;
|
||||
used_negative_operator |= query_used_negative_operator;
|
||||
|
@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex};
|
||||
use crate::milli::vector::Embedding;
|
||||
|
||||
pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0;
|
||||
|
||||
@ -117,6 +118,9 @@ pub struct FederatedSearchResult {
|
||||
#[serde(flatten)]
|
||||
pub hits_info: HitsInfo,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub query_vectors: Option<BTreeMap<usize, Embedding>>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub semantic_hit_count: Option<u32>,
|
||||
|
||||
@ -144,6 +148,7 @@ impl fmt::Debug for FederatedSearchResult {
|
||||
hits,
|
||||
processing_time_ms,
|
||||
hits_info,
|
||||
query_vectors,
|
||||
semantic_hit_count,
|
||||
degraded,
|
||||
used_negative_operator,
|
||||
@ -158,6 +163,10 @@ impl fmt::Debug for FederatedSearchResult {
|
||||
debug.field("processing_time_ms", &processing_time_ms);
|
||||
debug.field("hits", &format!("[{} hits returned]", hits.len()));
|
||||
debug.field("hits_info", &hits_info);
|
||||
if let Some(query_vectors) = query_vectors {
|
||||
let known = query_vectors.len();
|
||||
debug.field("query_vectors", &format!("[{known} known vectors]"));
|
||||
}
|
||||
if *used_negative_operator {
|
||||
debug.field("used_negative_operator", used_negative_operator);
|
||||
}
|
||||
|
@ -1020,7 +1020,7 @@ pub fn prepare_search<'t>(
|
||||
.map_err(milli::Error::from)?
|
||||
}
|
||||
};
|
||||
search.semantic(
|
||||
search.semantic_auto_embedded(
|
||||
embedder_name.clone(),
|
||||
embedder.clone(),
|
||||
*quantized,
|
||||
|
@ -2,8 +2,9 @@ use std::sync::Arc;
|
||||
|
||||
use actix_http::StatusCode;
|
||||
use meili_snap::{json_string, snapshot};
|
||||
use wiremock::matchers::AnyMatcher;
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::{path, AnyMatcher};
|
||||
use wiremock::{Mock, MockServer, Request, ResponseTemplate};
|
||||
|
||||
use crate::common::{Server, Value, SCORE_DOCUMENTS};
|
||||
use crate::json;
|
||||
@ -415,6 +416,209 @@ async fn remote_sharding() {
|
||||
"###);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn remote_sharding_retrieve_vectors() {
|
||||
let ms0 = Server::new().await;
|
||||
let ms1 = Server::new().await;
|
||||
let ms2 = Server::new().await;
|
||||
let index0 = ms0.index("test");
|
||||
let index1 = ms1.index("test");
|
||||
let index2 = ms2.index("test");
|
||||
|
||||
// enable feature
|
||||
|
||||
let (response, code) = ms0.set_features(json!({"network": true})).await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response["network"]), @"true");
|
||||
let (response, code) = ms1.set_features(json!({"network": true})).await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response["network"]), @"true");
|
||||
let (response, code) = ms2.set_features(json!({"network": true})).await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response["network"]), @"true");
|
||||
|
||||
// set self
|
||||
|
||||
let (response, code) = ms0.set_network(json!({"self": "ms0"})).await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response), @r###"
|
||||
{
|
||||
"self": "ms0",
|
||||
"remotes": {}
|
||||
}
|
||||
"###);
|
||||
let (response, code) = ms1.set_network(json!({"self": "ms1"})).await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response), @r###"
|
||||
{
|
||||
"self": "ms1",
|
||||
"remotes": {}
|
||||
}
|
||||
"###);
|
||||
let (response, code) = ms2.set_network(json!({"self": "ms2"})).await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response), @r###"
|
||||
{
|
||||
"self": "ms2",
|
||||
"remotes": {}
|
||||
}
|
||||
"###);
|
||||
|
||||
// setup embedders
|
||||
|
||||
let mock_server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/"))
|
||||
.respond_with(move |req: &Request| {
|
||||
println!("Received request: {:?}", req);
|
||||
let text = req.body_json::<String>().unwrap().to_lowercase();
|
||||
let patterns = [
|
||||
("batman", [1.0, 0.0, 0.0]),
|
||||
("dark", [0.0, 0.1, 0.0]),
|
||||
("knight", [0.1, 0.1, 0.0]),
|
||||
("returns", [0.0, 0.0, 0.2]),
|
||||
("part", [0.05, 0.1, 0.0]),
|
||||
("1", [0.3, 0.05, 0.0]),
|
||||
("2", [0.2, 0.05, 0.0]),
|
||||
];
|
||||
let mut embedding = vec![0.; 3];
|
||||
for (pattern, vector) in patterns {
|
||||
if text.contains(pattern) {
|
||||
for (i, v) in vector.iter().enumerate() {
|
||||
embedding[i] += v;
|
||||
}
|
||||
}
|
||||
}
|
||||
ResponseTemplate::new(200).set_body_json(json!({ "data": embedding }))
|
||||
})
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
let url = mock_server.uri();
|
||||
|
||||
for (server, index) in [(&ms0, &index0), (&ms1, &index1), (&ms2, &index2)] {
|
||||
let (response, code) = index
|
||||
.update_settings(json!({
|
||||
"embedders": {
|
||||
"rest": {
|
||||
"source": "rest",
|
||||
"url": url,
|
||||
"dimensions": 3,
|
||||
"request": "{{text}}",
|
||||
"response": { "data": "{{embedding}}" },
|
||||
"documentTemplate": "{{doc.name}}",
|
||||
},
|
||||
},
|
||||
}))
|
||||
.await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
server.wait_task(response.uid()).await.succeeded();
|
||||
}
|
||||
|
||||
// wrap servers
|
||||
let ms0 = Arc::new(ms0);
|
||||
let ms1 = Arc::new(ms1);
|
||||
let ms2 = Arc::new(ms2);
|
||||
|
||||
let rms0 = LocalMeili::new(ms0.clone()).await;
|
||||
let rms1 = LocalMeili::new(ms1.clone()).await;
|
||||
let rms2 = LocalMeili::new(ms2.clone()).await;
|
||||
|
||||
// set network
|
||||
let network = json!({"remotes": {
|
||||
"ms0": {
|
||||
"url": rms0.url()
|
||||
},
|
||||
"ms1": {
|
||||
"url": rms1.url()
|
||||
},
|
||||
"ms2": {
|
||||
"url": rms2.url()
|
||||
}
|
||||
}});
|
||||
|
||||
let (_response, status_code) = ms0.set_network(network.clone()).await;
|
||||
snapshot!(status_code, @"200 OK");
|
||||
let (_response, status_code) = ms1.set_network(network.clone()).await;
|
||||
snapshot!(status_code, @"200 OK");
|
||||
let (_response, status_code) = ms2.set_network(network.clone()).await;
|
||||
snapshot!(status_code, @"200 OK");
|
||||
|
||||
// perform multi-search
|
||||
let query = "badman returns";
|
||||
let request = json!({
|
||||
"federation": {},
|
||||
"queries": [
|
||||
{
|
||||
"q": query,
|
||||
"indexUid": "test",
|
||||
"hybrid": {
|
||||
"semanticRatio": 1.0,
|
||||
"embedder": "rest"
|
||||
},
|
||||
"retrieveVectors": true,
|
||||
"federationOptions": {
|
||||
"remote": "ms0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"q": query,
|
||||
"indexUid": "test",
|
||||
"hybrid": {
|
||||
"semanticRatio": 1.0,
|
||||
"embedder": "rest"
|
||||
},
|
||||
"retrieveVectors": true,
|
||||
"federationOptions": {
|
||||
"remote": "ms1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"q": query,
|
||||
"indexUid": "test",
|
||||
"hybrid": {
|
||||
"semanticRatio": 1.0,
|
||||
"embedder": "rest"
|
||||
},
|
||||
"retrieveVectors": true,
|
||||
"federationOptions": {
|
||||
"remote": "ms2"
|
||||
}
|
||||
},
|
||||
]
|
||||
});
|
||||
|
||||
let (response, _status_code) = ms0.multi_search(request.clone()).await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response, { ".processingTimeMs" => "[time]" }), @r#"
|
||||
{
|
||||
"hits": [],
|
||||
"processingTimeMs": "[time]",
|
||||
"limit": 20,
|
||||
"offset": 0,
|
||||
"estimatedTotalHits": 0,
|
||||
"queryVectors": {
|
||||
"0": [
|
||||
0.0,
|
||||
0.0,
|
||||
0.2
|
||||
],
|
||||
"1": [
|
||||
0.0,
|
||||
0.0,
|
||||
0.2
|
||||
],
|
||||
"2": [
|
||||
0.0,
|
||||
0.0,
|
||||
0.2
|
||||
]
|
||||
},
|
||||
"semanticHitCount": 0,
|
||||
"remoteErrors": {}
|
||||
}
|
||||
"#);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn error_unregistered_remote() {
|
||||
let ms0 = Server::new().await;
|
||||
|
@ -230,7 +230,14 @@ impl Search<'_> {
|
||||
}
|
||||
|
||||
// no embedder, no semantic search
|
||||
let Some(SemanticSearch { vector, embedder_name, embedder, quantized, media }) = semantic
|
||||
let Some(SemanticSearch {
|
||||
vector,
|
||||
mut auto_embedded,
|
||||
embedder_name,
|
||||
embedder,
|
||||
quantized,
|
||||
media,
|
||||
}) = semantic
|
||||
else {
|
||||
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
|
||||
};
|
||||
@ -253,7 +260,10 @@ impl Search<'_> {
|
||||
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
|
||||
|
||||
match embedder.embed_search(query, Some(deadline)) {
|
||||
Ok(embedding) => embedding,
|
||||
Ok(embedding) => {
|
||||
auto_embedded = true;
|
||||
embedding
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::error!(error=%error, "Embedding failed");
|
||||
return Ok(return_keyword_results(
|
||||
@ -268,6 +278,7 @@ impl Search<'_> {
|
||||
|
||||
search.semantic = Some(SemanticSearch {
|
||||
vector: Some(vector_query.clone()),
|
||||
auto_embedded,
|
||||
embedder_name,
|
||||
embedder,
|
||||
quantized,
|
||||
@ -280,7 +291,7 @@ impl Search<'_> {
|
||||
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
|
||||
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
|
||||
|
||||
let (mut merge_results, semantic_hit_count) = ScoreWithRatioResult::merge(
|
||||
let (merge_results, semantic_hit_count) = ScoreWithRatioResult::merge(
|
||||
vector_results,
|
||||
keyword_results,
|
||||
self.offset,
|
||||
@ -289,7 +300,6 @@ impl Search<'_> {
|
||||
search.index,
|
||||
search.rtxn,
|
||||
)?;
|
||||
merge_results.query_vector = Some(vector_query);
|
||||
assert!(merge_results.documents_ids.len() <= self.limit);
|
||||
Ok((merge_results, Some(semantic_hit_count)))
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ pub mod similar;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SemanticSearch {
|
||||
vector: Option<Vec<f32>>,
|
||||
auto_embedded: bool,
|
||||
media: Option<serde_json::Value>,
|
||||
embedder_name: String,
|
||||
embedder: Arc<Embedder>,
|
||||
@ -97,7 +98,33 @@ impl<'a> Search<'a> {
|
||||
vector: Option<Embedding>,
|
||||
media: Option<serde_json::Value>,
|
||||
) -> &mut Search<'a> {
|
||||
self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector, media });
|
||||
self.semantic = Some(SemanticSearch {
|
||||
embedder_name,
|
||||
auto_embedded: false,
|
||||
embedder,
|
||||
quantized,
|
||||
vector,
|
||||
media,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
pub fn semantic_auto_embedded(
|
||||
&mut self,
|
||||
embedder_name: String,
|
||||
embedder: Arc<Embedder>,
|
||||
quantized: bool,
|
||||
vector: Option<Embedding>,
|
||||
media: Option<serde_json::Value>,
|
||||
) -> &mut Search<'a> {
|
||||
self.semantic = Some(SemanticSearch {
|
||||
embedder_name,
|
||||
auto_embedded: true,
|
||||
embedder,
|
||||
quantized,
|
||||
vector,
|
||||
media,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
@ -225,6 +252,7 @@ impl<'a> Search<'a> {
|
||||
}
|
||||
|
||||
let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?;
|
||||
let mut query_vector = None;
|
||||
let PartialSearchResult {
|
||||
located_query_terms,
|
||||
candidates,
|
||||
@ -235,26 +263,32 @@ impl<'a> Search<'a> {
|
||||
} = match self.semantic.as_ref() {
|
||||
Some(SemanticSearch {
|
||||
vector: Some(vector),
|
||||
auto_embedded,
|
||||
embedder_name,
|
||||
embedder,
|
||||
quantized,
|
||||
media: _,
|
||||
}) => execute_vector_search(
|
||||
&mut ctx,
|
||||
vector,
|
||||
self.scoring_strategy,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
&self.distinct,
|
||||
self.geo_param,
|
||||
self.offset,
|
||||
self.limit,
|
||||
embedder_name,
|
||||
embedder,
|
||||
*quantized,
|
||||
self.time_budget.clone(),
|
||||
self.ranking_score_threshold,
|
||||
)?,
|
||||
}) => {
|
||||
if *auto_embedded {
|
||||
query_vector = Some(vector.clone());
|
||||
}
|
||||
execute_vector_search(
|
||||
&mut ctx,
|
||||
vector,
|
||||
self.scoring_strategy,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
&self.distinct,
|
||||
self.geo_param,
|
||||
self.offset,
|
||||
self.limit,
|
||||
embedder_name,
|
||||
embedder,
|
||||
*quantized,
|
||||
self.time_budget.clone(),
|
||||
self.ranking_score_threshold,
|
||||
)?
|
||||
}
|
||||
_ => execute_search(
|
||||
&mut ctx,
|
||||
self.query.as_deref(),
|
||||
@ -295,7 +329,7 @@ impl<'a> Search<'a> {
|
||||
documents_ids,
|
||||
degraded,
|
||||
used_negative_operator,
|
||||
query_vector: None,
|
||||
query_vector,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user