refactor: split PersonalizationService into enum with CohereService

- Refactor PersonalizationService as enum with Cohere and Uninitialized variants
- Create dedicated CohereService struct with rerank_search_results method
- Split constructor into cohere() and uninitialized() methods
- Move all Cohere logic into CohereService for better separation of concerns
- Update tests and lib.rs to use new API
- Improve code organization and maintainability
This commit is contained in:
ManyTheFish
2025-07-23 11:01:19 +02:00
parent 0cf7ed9135
commit f3a0969c61
2 changed files with 46 additions and 24 deletions

View File

@ -624,9 +624,11 @@ pub fn configure_data(
analytics: Data<Analytics>,
) {
// Create personalization service with API key from options
let personalization_service = personalization::PersonalizationService::new(
index_scheduler.experimental_personalization_api_key().cloned(),
);
let personalization_service = index_scheduler
.experimental_personalization_api_key()
.cloned()
.map(personalization::PersonalizationService::cohere)
.unwrap_or_else(personalization::PersonalizationService::uninitialized);
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
config
.app_data(index_scheduler)

View File

@ -6,21 +6,14 @@ use cohere_rust::{
use meilisearch_types::error::ResponseError;
use tracing::{debug, error, info};
pub struct PersonalizationService {
cohere: Option<Cohere>,
pub struct CohereService {
cohere: Cohere,
}
impl PersonalizationService {
pub fn new(api_key: Option<String>) -> Self {
let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key));
if cohere.is_some() {
info!("Personalization service initialized with Cohere API");
} else {
debug!("Personalization service initialized without Cohere API key");
}
Self { cohere }
impl CohereService {
pub fn new(api_key: String) -> Self {
info!("Personalization service initialized with Cohere API");
Self { cohere: Cohere::new("https://api.cohere.ai", api_key) }
}
pub async fn rerank_search_results(
@ -29,9 +22,6 @@ impl PersonalizationService {
personalize: Option<&Personalize>,
query: Option<&str>,
) -> Result<SearchResult, ResponseError> {
// If no API key, return original results
let Some(cohere) = &self.cohere else { return Ok(search_result) };
// Extract user context from personalization
let user_context = personalize.and_then(|p| p.user_context.as_deref());
@ -67,7 +57,7 @@ impl PersonalizationService {
};
// Call Cohere's rerank API
match cohere.rerank(&rerank_request).await {
match self.cohere.rerank(&rerank_request).await {
Ok(rerank_response) => {
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
@ -94,6 +84,36 @@ impl PersonalizationService {
}
}
pub enum PersonalizationService {
Cohere(CohereService),
Uninitialized,
}
impl PersonalizationService {
pub fn cohere(api_key: String) -> Self {
Self::Cohere(CohereService::new(api_key))
}
pub fn uninitialized() -> Self {
debug!("Personalization service uninitialized");
Self::Uninitialized
}
pub async fn rerank_search_results(
&self,
search_result: SearchResult,
personalize: Option<&Personalize>,
query: Option<&str>,
) -> Result<SearchResult, ResponseError> {
match self {
Self::Cohere(cohere_service) => {
cohere_service.rerank_search_results(search_result, personalize, query).await
}
Self::Uninitialized => Ok(search_result),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -101,7 +121,7 @@ mod tests {
#[tokio::test]
async fn test_personalization_service_without_api_key() {
let service = PersonalizationService::new(None);
let service = PersonalizationService::uninitialized();
let personalize = Personalize { user_context: Some("test user".to_string()) };
let search_result = SearchResult {
@ -134,7 +154,7 @@ mod tests {
#[tokio::test]
async fn test_personalization_service_with_user_context_only() {
let service = PersonalizationService::new(Some("fake_key".to_string()));
let service = PersonalizationService::cohere("fake_key".to_string());
let personalize = Personalize { user_context: Some("test user".to_string()) };
let search_result = SearchResult {
@ -166,7 +186,7 @@ mod tests {
#[tokio::test]
async fn test_personalization_service_with_query_only() {
let service = PersonalizationService::new(Some("fake_key".to_string()));
let service = PersonalizationService::cohere("fake_key".to_string());
let search_result = SearchResult {
hits: vec![SearchHit {
@ -196,7 +216,7 @@ mod tests {
#[tokio::test]
async fn test_personalization_service_both_none() {
let service = PersonalizationService::new(Some("fake_key".to_string()));
let service = PersonalizationService::cohere("fake_key".to_string());
let search_result = SearchResult {
hits: vec![SearchHit {