diff --git a/Cargo.lock b/Cargo.lock index 8413b3d14..787ff0550 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1194,6 +1194,21 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +[[package]] +name = "cohere-rust" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b553b385b0f2562138baea705b5707335314f8e91a58e7d1a03c3a6c332423" +dependencies = [ + "bytes", + "reqwest", + "serde", + "serde_json", + "strum_macros 0.26.4", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "color-spantrace" version = "0.3.0" @@ -3445,7 +3460,7 @@ dependencies = [ "serde_json", "serde_yaml", "strum", - "strum_macros", + "strum_macros 0.27.1", "unicode-blocks", "unicode-normalization", "unicode-segmentation", @@ -3753,6 +3768,7 @@ dependencies = [ "bytes", "cargo_toml", "clap", + "cohere-rust", "crossbeam-channel", "deserr", "dump", @@ -5822,7 +5838,20 @@ version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" dependencies = [ - "strum_macros", + "strum_macros 0.27.1", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.101", ] [[package]] diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index 21f6b58e5..852c7e887 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -95,6 +95,7 @@ uuid = { version = "1.17.0", features = ["serde", "v4"] } serde_urlencoded = "0.7.1" termcolor = "1.4.1" url = { version = "2.5.4", features = ["serde"] } +cohere-rust = "0.6.0" tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["json"] } tracing-trace = { version = "0.1.0", path = "../tracing-trace" } diff --git a/crates/meilisearch/src/lib.rs b/crates/meilisearch/src/lib.rs index 0fb93b65a..e468d878d 100644 --- a/crates/meilisearch/src/lib.rs +++ b/crates/meilisearch/src/lib.rs @@ -9,6 +9,7 @@ pub mod middleware; pub mod option; #[cfg(test)] mod option_test; +pub mod personalization; pub mod routes; pub mod search; pub mod search_queue; @@ -676,12 +677,17 @@ pub fn configure_data( (logs_route, logs_stderr): (LogRouteHandle, LogStderrHandle), analytics: Data, ) { + // Create personalization service with API key from options + let personalization_service = personalization::PersonalizationService::new( + index_scheduler.experimental_personalization_api_key().cloned(), + ); let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize; config .app_data(index_scheduler) .app_data(auth) .app_data(search_queue) .app_data(analytics) + .app_data(web::Data::new(personalization_service)) .app_data(web::Data::new(logs_route)) .app_data(web::Data::new(logs_stderr)) .app_data(web::Data::new(opt.clone())) diff --git a/crates/meilisearch/src/personalization/mod.rs b/crates/meilisearch/src/personalization/mod.rs new file mode 100644 index 000000000..893aa8be4 --- /dev/null +++ b/crates/meilisearch/src/personalization/mod.rs @@ -0,0 +1,162 @@ +use crate::search::{Personalization, SearchResult}; +use cohere_rust::{ + api::rerank::{ReRankModel, ReRankRequest}, + Cohere, +}; +use meilisearch_types::error::ResponseError; +use tracing::{debug, error, info}; + +pub struct PersonalizationService { + cohere: Option, +} + +impl PersonalizationService { + pub fn new(api_key: Option) -> Self { + let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key)); + + if cohere.is_some() { + info!("Personalization service initialized with Cohere API"); + } else { + debug!("Personalization service initialized without Cohere API key"); + } + + Self { cohere } + } + + pub async fn rerank_search_results( + &self, + search_result: SearchResult, + personalization: &Personalization, + query: &str, + ) -> Result { + // If personalization is not enabled or no API key, return original results + if !personalization.personalized || self.cohere.is_none() { + return Ok(search_result); + } + + let cohere = self.cohere.as_ref().unwrap(); + + // Extract documents for reranking + let documents: Vec = search_result + .hits + .iter() + .map(|hit| { + // Convert the document to a string representation for reranking + serde_json::to_string(&hit.document).unwrap_or_else(|_| "{}".to_string()) + }) + .collect(); + + if documents.is_empty() { + return Ok(search_result); + } + + // Prepare the rerank request + let rerank_request = ReRankRequest { + query, + documents: &documents, + model: ReRankModel::EnglishV3, // Use the default and more recent model + top_n: None, + max_chunks_per_doc: None, + }; + + // Call Cohere's rerank API + match cohere.rerank(&rerank_request).await { + Ok(rerank_response) => { + debug!("Cohere rerank successful, reordering {} results", search_result.hits.len()); + + // Create a mapping from original index to new rank + let reranked_indices: Vec = + rerank_response.iter().map(|result| result.index as usize).collect(); + + // Reorder the hits based on Cohere's reranking + let mut reranked_hits = search_result.hits.clone(); + for (new_index, original_index) in reranked_indices.iter().enumerate() { + if *original_index < reranked_hits.len() { + reranked_hits.swap(new_index, *original_index); + } + } + + Ok(SearchResult { hits: reranked_hits, ..search_result }) + } + Err(e) => { + error!("Cohere rerank failed with model EnglishV3: {}", e); + // Return original results on error + Ok(search_result) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::search::{HitsInfo, SearchHit}; + + #[tokio::test] + async fn test_personalization_service_without_api_key() { + let service = PersonalizationService::new(None); + let personalization = + Personalization { personalized: true, user_profile: Some("test user".to_string()) }; + + let search_result = SearchResult { + hits: vec![SearchHit { + document: serde_json::Map::new(), + formatted: serde_json::Map::new(), + matches_position: None, + ranking_score: Some(1.0), + ranking_score_details: None, + }], + query: "test".to_string(), + processing_time_ms: 10, + hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 }, + facet_distribution: None, + facet_stats: None, + semantic_hit_count: None, + degraded: false, + used_negative_operator: false, + }; + + let result = + service.rerank_search_results(search_result.clone(), &personalization, "test").await; + assert!(result.is_ok()); + + // Should return original results when no API key is provided + let reranked_result = result.unwrap(); + assert_eq!(reranked_result.hits.len(), search_result.hits.len()); + } + + #[tokio::test] + async fn test_personalization_service_disabled() { + let service = PersonalizationService::new(Some("fake_key".to_string())); + let personalization = Personalization { + personalized: false, // Personalization disabled + user_profile: Some("test user".to_string()), + }; + + let search_result = SearchResult { + hits: vec![SearchHit { + document: serde_json::Map::new(), + formatted: serde_json::Map::new(), + matches_position: None, + ranking_score: Some(1.0), + ranking_score_details: None, + }], + query: "test".to_string(), + processing_time_ms: 10, + hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 }, + facet_distribution: None, + facet_stats: None, + semantic_hit_count: None, + degraded: false, + used_negative_operator: false, + }; + + let result = + service.rerank_search_results(search_result.clone(), &personalization, "test").await; + assert!(result.is_ok()); + + // Should return original results when personalization is disabled + let reranked_result = result.unwrap(); + assert_eq!(reranked_result.hits.len(), search_result.hits.len()); + } +} diff --git a/crates/meilisearch/src/routes/indexes/search.rs b/crates/meilisearch/src/routes/indexes/search.rs index 3e1b59f4b..84bc9810a 100644 --- a/crates/meilisearch/src/routes/indexes/search.rs +++ b/crates/meilisearch/src/routes/indexes/search.rs @@ -342,6 +342,7 @@ pub fn fix_sort_query_parameters(sort_query: &str) -> Vec { pub async fn search_with_url_query( index_scheduler: GuardedData, Data>, search_queue: web::Data, + personalization_service: web::Data, index_uid: web::Path, params: AwebQueryParameter, req: HttpRequest, @@ -364,6 +365,11 @@ pub async fn search_with_url_query( let search_kind = search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors); + + // Extract personalization and query string before moving query + let personalization = query.personalization.clone(); + let query_str = query.q.clone(); + let permit = search_queue.try_get_search_permit().await?; let search_result = tokio::task::spawn_blocking(move || { perform_search( @@ -383,7 +389,16 @@ pub async fn search_with_url_query( } analytics.publish(aggregate, &req); - let search_result = search_result?; + let mut search_result = search_result?; + + // Apply personalization if requested + if let Some(personalization) = &personalization { + if let Some(query_str) = &query_str { + search_result = personalization_service + .rerank_search_results(search_result, personalization, query_str) + .await?; + } + } debug!(returns = ?search_result, "Search get"); Ok(HttpResponse::Ok().json(search_result)) @@ -448,6 +463,7 @@ pub async fn search_with_url_query( pub async fn search_with_post( index_scheduler: GuardedData, Data>, search_queue: web::Data, + personalization_service: web::Data, index_uid: web::Path, params: AwebJson, req: HttpRequest, @@ -471,6 +487,10 @@ pub async fn search_with_post( search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors); + // Extract personalization and query string before moving query + let personalization = query.personalization.clone(); + let query_str = query.q.clone(); + let permit = search_queue.try_get_search_permit().await?; let search_result = tokio::task::spawn_blocking(move || { perform_search( @@ -493,7 +513,16 @@ pub async fn search_with_post( } analytics.publish(aggregate, &req); - let search_result = search_result?; + let mut search_result = search_result?; + + // Apply personalization if requested + if let Some(personalization) = &personalization { + if let Some(query_str) = &query_str { + search_result = personalization_service + .rerank_search_results(search_result, personalization, query_str) + .await?; + } + } debug!(returns = ?search_result, "Search post"); Ok(HttpResponse::Ok().json(search_result))