mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-09-06 04:36:32 +00:00
feat: add personalization service with EnglishV3-only reranking
- Add new personalization module with Cohere integration
- Implement rerank_search_results method using EnglishV3 model
- Remove fallback logic to EnglishV2 for simplified behavior
- Add comprehensive error handling and logging
- Include unit tests for service behavior
- Update search route to support personalization feature
This commit is contained in:
33
Cargo.lock
generated
33
Cargo.lock
generated
@ -1194,6 +1194,21 @@ version = "0.7.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
|
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cohere-rust"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d8b553b385b0f2562138baea705b5707335314f8e91a58e7d1a03c3a6c332423"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"reqwest",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"strum_macros 0.26.4",
|
||||||
|
"thiserror 1.0.69",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "color-spantrace"
|
name = "color-spantrace"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
@ -3445,7 +3460,7 @@ dependencies = [
|
|||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
"strum",
|
"strum",
|
||||||
"strum_macros",
|
"strum_macros 0.27.1",
|
||||||
"unicode-blocks",
|
"unicode-blocks",
|
||||||
"unicode-normalization",
|
"unicode-normalization",
|
||||||
"unicode-segmentation",
|
"unicode-segmentation",
|
||||||
@ -3753,6 +3768,7 @@ dependencies = [
|
|||||||
"bytes",
|
"bytes",
|
||||||
"cargo_toml",
|
"cargo_toml",
|
||||||
"clap",
|
"clap",
|
||||||
|
"cohere-rust",
|
||||||
"crossbeam-channel",
|
"crossbeam-channel",
|
||||||
"deserr",
|
"deserr",
|
||||||
"dump",
|
"dump",
|
||||||
@ -5822,7 +5838,20 @@ version = "0.27.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
|
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"strum_macros",
|
"strum_macros 0.27.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strum_macros"
|
||||||
|
version = "0.26.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
|
||||||
|
dependencies = [
|
||||||
|
"heck",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"rustversion",
|
||||||
|
"syn 2.0.101",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -95,6 +95,7 @@ uuid = { version = "1.17.0", features = ["serde", "v4"] }
|
|||||||
serde_urlencoded = "0.7.1"
|
serde_urlencoded = "0.7.1"
|
||||||
termcolor = "1.4.1"
|
termcolor = "1.4.1"
|
||||||
url = { version = "2.5.4", features = ["serde"] }
|
url = { version = "2.5.4", features = ["serde"] }
|
||||||
|
cohere-rust = "0.6.0"
|
||||||
tracing = "0.1.41"
|
tracing = "0.1.41"
|
||||||
tracing-subscriber = { version = "0.3.19", features = ["json"] }
|
tracing-subscriber = { version = "0.3.19", features = ["json"] }
|
||||||
tracing-trace = { version = "0.1.0", path = "../tracing-trace" }
|
tracing-trace = { version = "0.1.0", path = "../tracing-trace" }
|
||||||
|
@ -9,6 +9,7 @@ pub mod middleware;
|
|||||||
pub mod option;
|
pub mod option;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod option_test;
|
mod option_test;
|
||||||
|
pub mod personalization;
|
||||||
pub mod routes;
|
pub mod routes;
|
||||||
pub mod search;
|
pub mod search;
|
||||||
pub mod search_queue;
|
pub mod search_queue;
|
||||||
@ -676,12 +677,17 @@ pub fn configure_data(
|
|||||||
(logs_route, logs_stderr): (LogRouteHandle, LogStderrHandle),
|
(logs_route, logs_stderr): (LogRouteHandle, LogStderrHandle),
|
||||||
analytics: Data<Analytics>,
|
analytics: Data<Analytics>,
|
||||||
) {
|
) {
|
||||||
|
// Create personalization service with API key from options
|
||||||
|
let personalization_service = personalization::PersonalizationService::new(
|
||||||
|
index_scheduler.experimental_personalization_api_key().cloned(),
|
||||||
|
);
|
||||||
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
|
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
|
||||||
config
|
config
|
||||||
.app_data(index_scheduler)
|
.app_data(index_scheduler)
|
||||||
.app_data(auth)
|
.app_data(auth)
|
||||||
.app_data(search_queue)
|
.app_data(search_queue)
|
||||||
.app_data(analytics)
|
.app_data(analytics)
|
||||||
|
.app_data(web::Data::new(personalization_service))
|
||||||
.app_data(web::Data::new(logs_route))
|
.app_data(web::Data::new(logs_route))
|
||||||
.app_data(web::Data::new(logs_stderr))
|
.app_data(web::Data::new(logs_stderr))
|
||||||
.app_data(web::Data::new(opt.clone()))
|
.app_data(web::Data::new(opt.clone()))
|
||||||
|
162
crates/meilisearch/src/personalization/mod.rs
Normal file
162
crates/meilisearch/src/personalization/mod.rs
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
use crate::search::{Personalization, SearchResult};
|
||||||
|
use cohere_rust::{
|
||||||
|
api::rerank::{ReRankModel, ReRankRequest},
|
||||||
|
Cohere,
|
||||||
|
};
|
||||||
|
use meilisearch_types::error::ResponseError;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
|
pub struct PersonalizationService {
|
||||||
|
cohere: Option<Cohere>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PersonalizationService {
|
||||||
|
pub fn new(api_key: Option<String>) -> Self {
|
||||||
|
let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key));
|
||||||
|
|
||||||
|
if cohere.is_some() {
|
||||||
|
info!("Personalization service initialized with Cohere API");
|
||||||
|
} else {
|
||||||
|
debug!("Personalization service initialized without Cohere API key");
|
||||||
|
}
|
||||||
|
|
||||||
|
Self { cohere }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn rerank_search_results(
|
||||||
|
&self,
|
||||||
|
search_result: SearchResult,
|
||||||
|
personalization: &Personalization,
|
||||||
|
query: &str,
|
||||||
|
) -> Result<SearchResult, ResponseError> {
|
||||||
|
// If personalization is not enabled or no API key, return original results
|
||||||
|
if !personalization.personalized || self.cohere.is_none() {
|
||||||
|
return Ok(search_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cohere = self.cohere.as_ref().unwrap();
|
||||||
|
|
||||||
|
// Extract documents for reranking
|
||||||
|
let documents: Vec<String> = search_result
|
||||||
|
.hits
|
||||||
|
.iter()
|
||||||
|
.map(|hit| {
|
||||||
|
// Convert the document to a string representation for reranking
|
||||||
|
serde_json::to_string(&hit.document).unwrap_or_else(|_| "{}".to_string())
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if documents.is_empty() {
|
||||||
|
return Ok(search_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the rerank request
|
||||||
|
let rerank_request = ReRankRequest {
|
||||||
|
query,
|
||||||
|
documents: &documents,
|
||||||
|
model: ReRankModel::EnglishV3, // Use the default and more recent model
|
||||||
|
top_n: None,
|
||||||
|
max_chunks_per_doc: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Call Cohere's rerank API
|
||||||
|
match cohere.rerank(&rerank_request).await {
|
||||||
|
Ok(rerank_response) => {
|
||||||
|
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
|
||||||
|
|
||||||
|
// Create a mapping from original index to new rank
|
||||||
|
let reranked_indices: Vec<usize> =
|
||||||
|
rerank_response.iter().map(|result| result.index as usize).collect();
|
||||||
|
|
||||||
|
// Reorder the hits based on Cohere's reranking
|
||||||
|
let mut reranked_hits = search_result.hits.clone();
|
||||||
|
for (new_index, original_index) in reranked_indices.iter().enumerate() {
|
||||||
|
if *original_index < reranked_hits.len() {
|
||||||
|
reranked_hits.swap(new_index, *original_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(SearchResult { hits: reranked_hits, ..search_result })
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Cohere rerank failed with model EnglishV3: {}", e);
|
||||||
|
// Return original results on error
|
||||||
|
Ok(search_result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::{HitsInfo, SearchHit};

    /// Builds a search result holding a single empty hit, shared by the tests below.
    fn single_hit_result() -> SearchResult {
        let hit = SearchHit {
            document: serde_json::Map::new(),
            formatted: serde_json::Map::new(),
            matches_position: None,
            ranking_score: Some(1.0),
            ranking_score_details: None,
        };

        SearchResult {
            hits: vec![hit],
            query: "test".to_string(),
            processing_time_ms: 10,
            hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 },
            facet_distribution: None,
            facet_stats: None,
            semantic_hit_count: None,
            degraded: false,
            used_negative_operator: false,
        }
    }

    #[tokio::test]
    async fn test_personalization_service_without_api_key() {
        // No key: the service has no client to call.
        let service = PersonalizationService::new(None);
        let personalization =
            Personalization { personalized: true, user_profile: Some("test user".to_string()) };
        let original = single_hit_result();

        let outcome =
            service.rerank_search_results(original.clone(), &personalization, "test").await;
        assert!(outcome.is_ok());

        // The hits must come back with the same count as the input.
        let returned = outcome.unwrap();
        assert_eq!(returned.hits.len(), original.hits.len());
    }

    #[tokio::test]
    async fn test_personalization_service_disabled() {
        // A key is present, but the request opts out of personalization.
        let service = PersonalizationService::new(Some("fake_key".to_string()));
        let personalization =
            Personalization { personalized: false, user_profile: Some("test user".to_string()) };
        let original = single_hit_result();

        let outcome =
            service.rerank_search_results(original.clone(), &personalization, "test").await;
        assert!(outcome.is_ok());

        // The hits must come back with the same count as the input.
        let returned = outcome.unwrap();
        assert_eq!(returned.hits.len(), original.hits.len());
    }
}
|
@ -342,6 +342,7 @@ pub fn fix_sort_query_parameters(sort_query: &str) -> Vec<String> {
|
|||||||
pub async fn search_with_url_query(
|
pub async fn search_with_url_query(
|
||||||
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
||||||
search_queue: web::Data<SearchQueue>,
|
search_queue: web::Data<SearchQueue>,
|
||||||
|
personalization_service: web::Data<crate::personalization::PersonalizationService>,
|
||||||
index_uid: web::Path<String>,
|
index_uid: web::Path<String>,
|
||||||
params: AwebQueryParameter<SearchQueryGet, DeserrQueryParamError>,
|
params: AwebQueryParameter<SearchQueryGet, DeserrQueryParamError>,
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
@ -364,6 +365,11 @@ pub async fn search_with_url_query(
|
|||||||
let search_kind =
|
let search_kind =
|
||||||
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
||||||
let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors);
|
let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors);
|
||||||
|
|
||||||
|
// Extract personalization and query string before moving query
|
||||||
|
let personalization = query.personalization.clone();
|
||||||
|
let query_str = query.q.clone();
|
||||||
|
|
||||||
let permit = search_queue.try_get_search_permit().await?;
|
let permit = search_queue.try_get_search_permit().await?;
|
||||||
let search_result = tokio::task::spawn_blocking(move || {
|
let search_result = tokio::task::spawn_blocking(move || {
|
||||||
perform_search(
|
perform_search(
|
||||||
@ -383,7 +389,16 @@ pub async fn search_with_url_query(
|
|||||||
}
|
}
|
||||||
analytics.publish(aggregate, &req);
|
analytics.publish(aggregate, &req);
|
||||||
|
|
||||||
let search_result = search_result?;
|
let mut search_result = search_result?;
|
||||||
|
|
||||||
|
// Apply personalization if requested
|
||||||
|
if let Some(personalization) = &personalization {
|
||||||
|
if let Some(query_str) = &query_str {
|
||||||
|
search_result = personalization_service
|
||||||
|
.rerank_search_results(search_result, personalization, query_str)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
debug!(returns = ?search_result, "Search get");
|
debug!(returns = ?search_result, "Search get");
|
||||||
Ok(HttpResponse::Ok().json(search_result))
|
Ok(HttpResponse::Ok().json(search_result))
|
||||||
@ -448,6 +463,7 @@ pub async fn search_with_url_query(
|
|||||||
pub async fn search_with_post(
|
pub async fn search_with_post(
|
||||||
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
||||||
search_queue: web::Data<SearchQueue>,
|
search_queue: web::Data<SearchQueue>,
|
||||||
|
personalization_service: web::Data<crate::personalization::PersonalizationService>,
|
||||||
index_uid: web::Path<String>,
|
index_uid: web::Path<String>,
|
||||||
params: AwebJson<SearchQuery, DeserrJsonError>,
|
params: AwebJson<SearchQuery, DeserrJsonError>,
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
@ -471,6 +487,10 @@ pub async fn search_with_post(
|
|||||||
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
||||||
let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors);
|
let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors);
|
||||||
|
|
||||||
|
// Extract personalization and query string before moving query
|
||||||
|
let personalization = query.personalization.clone();
|
||||||
|
let query_str = query.q.clone();
|
||||||
|
|
||||||
let permit = search_queue.try_get_search_permit().await?;
|
let permit = search_queue.try_get_search_permit().await?;
|
||||||
let search_result = tokio::task::spawn_blocking(move || {
|
let search_result = tokio::task::spawn_blocking(move || {
|
||||||
perform_search(
|
perform_search(
|
||||||
@ -493,7 +513,16 @@ pub async fn search_with_post(
|
|||||||
}
|
}
|
||||||
analytics.publish(aggregate, &req);
|
analytics.publish(aggregate, &req);
|
||||||
|
|
||||||
let search_result = search_result?;
|
let mut search_result = search_result?;
|
||||||
|
|
||||||
|
// Apply personalization if requested
|
||||||
|
if let Some(personalization) = &personalization {
|
||||||
|
if let Some(query_str) = &query_str {
|
||||||
|
search_result = personalization_service
|
||||||
|
.rerank_search_results(search_result, personalization, query_str)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
debug!(returns = ?search_result, "Search post");
|
debug!(returns = ?search_result, "Search post");
|
||||||
Ok(HttpResponse::Ok().json(search_result))
|
Ok(HttpResponse::Ok().json(search_result))
|
||||||
|
Reference in New Issue
Block a user