refactor: split PersonalizationService into enum with CohereService

- Refactor PersonalizationService as enum with Cohere and Uninitialized variants
- Create dedicated CohereService struct with rerank_search_results method
- Split constructor into cohere() and uninitialized() methods
- Move all Cohere logic into CohereService for better separation of concerns
- Update tests and lib.rs to use new API
- Improve code organization and maintainability
This commit is contained in:
ManyTheFish
2025-07-23 11:01:19 +02:00
parent be68dd6785
commit c56be3d820
2 changed files with 46 additions and 24 deletions

View File

@ -678,9 +678,11 @@ pub fn configure_data(
analytics: Data<Analytics>, analytics: Data<Analytics>,
) { ) {
// Create personalization service with API key from options // Create personalization service with API key from options
let personalization_service = personalization::PersonalizationService::new( let personalization_service = index_scheduler
index_scheduler.experimental_personalization_api_key().cloned(), .experimental_personalization_api_key()
); .cloned()
.map(personalization::PersonalizationService::cohere)
.unwrap_or_else(personalization::PersonalizationService::uninitialized);
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize; let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
config config
.app_data(index_scheduler) .app_data(index_scheduler)

View File

@ -6,21 +6,14 @@ use cohere_rust::{
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use tracing::{debug, error, info}; use tracing::{debug, error, info};
pub struct PersonalizationService { pub struct CohereService {
cohere: Option<Cohere>, cohere: Cohere,
} }
impl PersonalizationService { impl CohereService {
pub fn new(api_key: Option<String>) -> Self { pub fn new(api_key: String) -> Self {
let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key));
if cohere.is_some() {
info!("Personalization service initialized with Cohere API"); info!("Personalization service initialized with Cohere API");
} else { Self { cohere: Cohere::new("https://api.cohere.ai", api_key) }
debug!("Personalization service initialized without Cohere API key");
}
Self { cohere }
} }
pub async fn rerank_search_results( pub async fn rerank_search_results(
@ -29,9 +22,6 @@ impl PersonalizationService {
personalize: Option<&Personalize>, personalize: Option<&Personalize>,
query: Option<&str>, query: Option<&str>,
) -> Result<SearchResult, ResponseError> { ) -> Result<SearchResult, ResponseError> {
// If no API key, return original results
let Some(cohere) = &self.cohere else { return Ok(search_result) };
// Extract user context from personalization // Extract user context from personalization
let user_context = personalize.and_then(|p| p.user_context.as_deref()); let user_context = personalize.and_then(|p| p.user_context.as_deref());
@ -67,7 +57,7 @@ impl PersonalizationService {
}; };
// Call Cohere's rerank API // Call Cohere's rerank API
match cohere.rerank(&rerank_request).await { match self.cohere.rerank(&rerank_request).await {
Ok(rerank_response) => { Ok(rerank_response) => {
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len()); debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
@ -94,6 +84,36 @@ impl PersonalizationService {
} }
} }
pub enum PersonalizationService {
Cohere(CohereService),
Uninitialized,
}
impl PersonalizationService {
pub fn cohere(api_key: String) -> Self {
Self::Cohere(CohereService::new(api_key))
}
pub fn uninitialized() -> Self {
debug!("Personalization service uninitialized");
Self::Uninitialized
}
pub async fn rerank_search_results(
&self,
search_result: SearchResult,
personalize: Option<&Personalize>,
query: Option<&str>,
) -> Result<SearchResult, ResponseError> {
match self {
Self::Cohere(cohere_service) => {
cohere_service.rerank_search_results(search_result, personalize, query).await
}
Self::Uninitialized => Ok(search_result),
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -101,7 +121,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_personalization_service_without_api_key() { async fn test_personalization_service_without_api_key() {
let service = PersonalizationService::new(None); let service = PersonalizationService::uninitialized();
let personalize = Personalize { user_context: Some("test user".to_string()) }; let personalize = Personalize { user_context: Some("test user".to_string()) };
let search_result = SearchResult { let search_result = SearchResult {
@ -134,7 +154,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_personalization_service_with_user_context_only() { async fn test_personalization_service_with_user_context_only() {
let service = PersonalizationService::new(Some("fake_key".to_string())); let service = PersonalizationService::cohere("fake_key".to_string());
let personalize = Personalize { user_context: Some("test user".to_string()) }; let personalize = Personalize { user_context: Some("test user".to_string()) };
let search_result = SearchResult { let search_result = SearchResult {
@ -166,7 +186,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_personalization_service_with_query_only() { async fn test_personalization_service_with_query_only() {
let service = PersonalizationService::new(Some("fake_key".to_string())); let service = PersonalizationService::cohere("fake_key".to_string());
let search_result = SearchResult { let search_result = SearchResult {
hits: vec![SearchHit { hits: vec![SearchHit {
@ -196,7 +216,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_personalization_service_both_none() { async fn test_personalization_service_both_none() {
let service = PersonalizationService::new(Some("fake_key".to_string())); let service = PersonalizationService::cohere("fake_key".to_string());
let search_result = SearchResult { let search_result = SearchResult {
hits: vec![SearchHit { hits: vec![SearchHit {