From 260c963540860673b9c6923a9884c46c44c8e25b Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Wed, 23 Jul 2025 11:01:19 +0200 Subject: [PATCH] refactor: split PersonalizationService into enum with CohereService - Refactor PersonalizationService as enum with Cohere and Uninitialized variants - Create dedicated CohereService struct with rerank_search_results method - Split constructor into cohere() and uninitialized() methods - Move all Cohere logic into CohereService for better separation of concerns - Update tests and lib.rs to use new API - Improve code organization and maintainability --- crates/meilisearch/src/lib.rs | 8 ++- crates/meilisearch/src/personalization/mod.rs | 62 ++++++++++++------- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/crates/meilisearch/src/lib.rs b/crates/meilisearch/src/lib.rs index c799ad279..ffedb3668 100644 --- a/crates/meilisearch/src/lib.rs +++ b/crates/meilisearch/src/lib.rs @@ -601,9 +601,11 @@ pub fn configure_data( analytics: Data, ) { // Create personalization service with API key from options - let personalization_service = personalization::PersonalizationService::new( - index_scheduler.experimental_personalization_api_key().cloned(), - ); + let personalization_service = index_scheduler + .experimental_personalization_api_key() + .cloned() + .map(personalization::PersonalizationService::cohere) + .unwrap_or_else(personalization::PersonalizationService::uninitialized); let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize; config .app_data(index_scheduler) diff --git a/crates/meilisearch/src/personalization/mod.rs b/crates/meilisearch/src/personalization/mod.rs index 886533702..06f22bee4 100644 --- a/crates/meilisearch/src/personalization/mod.rs +++ b/crates/meilisearch/src/personalization/mod.rs @@ -6,21 +6,14 @@ use cohere_rust::{ use meilisearch_types::error::ResponseError; use tracing::{debug, error, info}; -pub struct PersonalizationService { - cohere: Option, +pub struct CohereService { + cohere: Cohere, } -impl PersonalizationService { - pub fn new(api_key: Option) -> Self { - let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key)); - - if cohere.is_some() { - info!("Personalization service initialized with Cohere API"); - } else { - debug!("Personalization service initialized without Cohere API key"); - } - - Self { cohere } +impl CohereService { + pub fn new(api_key: String) -> Self { + info!("Personalization service initialized with Cohere API"); + Self { cohere: Cohere::new("https://api.cohere.ai", api_key) } } pub async fn rerank_search_results( @@ -29,9 +22,6 @@ impl PersonalizationService { personalize: Option<&Personalize>, query: Option<&str>, ) -> Result { - // If no API key, return original results - let Some(cohere) = &self.cohere else { return Ok(search_result) }; - // Extract user context from personalization let user_context = personalize.and_then(|p| p.user_context.as_deref()); @@ -67,7 +57,7 @@ impl PersonalizationService { }; // Call Cohere's rerank API - match cohere.rerank(&rerank_request).await { + match self.cohere.rerank(&rerank_request).await { Ok(rerank_response) => { debug!("Cohere rerank successful, reordering {} results", search_result.hits.len()); @@ -94,6 +84,36 @@ impl PersonalizationService { } } +pub enum PersonalizationService { + Cohere(CohereService), + Uninitialized, +} + +impl PersonalizationService { + pub fn cohere(api_key: String) -> Self { + Self::Cohere(CohereService::new(api_key)) + } + + pub fn uninitialized() -> Self { + debug!("Personalization service uninitialized"); + Self::Uninitialized + } + + pub async fn rerank_search_results( + &self, + search_result: SearchResult, + personalize: Option<&Personalize>, + query: Option<&str>, + ) -> Result { + match self { + Self::Cohere(cohere_service) => { + cohere_service.rerank_search_results(search_result, personalize, query).await + } + Self::Uninitialized => Ok(search_result), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -101,7 +121,7 @@ mod tests { #[tokio::test] async fn test_personalization_service_without_api_key() { - let service = PersonalizationService::new(None); + let service = PersonalizationService::uninitialized(); let personalize = Personalize { user_context: Some("test user".to_string()) }; let search_result = SearchResult { @@ -134,7 +154,7 @@ mod tests { #[tokio::test] async fn test_personalization_service_with_user_context_only() { - let service = PersonalizationService::new(Some("fake_key".to_string())); + let service = PersonalizationService::cohere("fake_key".to_string()); let personalize = Personalize { user_context: Some("test user".to_string()) }; let search_result = SearchResult { @@ -166,7 +186,7 @@ mod tests { #[tokio::test] async fn test_personalization_service_with_query_only() { - let service = PersonalizationService::new(Some("fake_key".to_string())); + let service = PersonalizationService::cohere("fake_key".to_string()); let search_result = SearchResult { hits: vec![SearchHit { @@ -196,7 +216,7 @@ mod tests { #[tokio::test] async fn test_personalization_service_both_none() { - let service = PersonalizationService::new(Some("fake_key".to_string())); + let service = PersonalizationService::cohere("fake_key".to_string()); let search_result = SearchResult { hits: vec![SearchHit {