mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-09-06 12:46:31 +00:00
refactor: split PersonalizationService into enum with CohereService
- Refactor PersonalizationService as enum with Cohere and Uninitialized variants - Create dedicated CohereService struct with rerank_search_results method - Split constructor into cohere() and uninitialized() methods - Move all Cohere logic into CohereService for better separation of concerns - Update tests and lib.rs to use new API - Improve code organization and maintainability
This commit is contained in:
@ -678,9 +678,11 @@ pub fn configure_data(
|
||||
analytics: Data<Analytics>,
|
||||
) {
|
||||
// Create personalization service with API key from options
|
||||
let personalization_service = personalization::PersonalizationService::new(
|
||||
index_scheduler.experimental_personalization_api_key().cloned(),
|
||||
);
|
||||
let personalization_service = index_scheduler
|
||||
.experimental_personalization_api_key()
|
||||
.cloned()
|
||||
.map(personalization::PersonalizationService::cohere)
|
||||
.unwrap_or_else(personalization::PersonalizationService::uninitialized);
|
||||
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
|
||||
config
|
||||
.app_data(index_scheduler)
|
||||
|
@ -6,21 +6,14 @@ use cohere_rust::{
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
pub struct PersonalizationService {
|
||||
cohere: Option<Cohere>,
|
||||
pub struct CohereService {
|
||||
cohere: Cohere,
|
||||
}
|
||||
|
||||
impl PersonalizationService {
|
||||
pub fn new(api_key: Option<String>) -> Self {
|
||||
let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key));
|
||||
|
||||
if cohere.is_some() {
|
||||
info!("Personalization service initialized with Cohere API");
|
||||
} else {
|
||||
debug!("Personalization service initialized without Cohere API key");
|
||||
}
|
||||
|
||||
Self { cohere }
|
||||
impl CohereService {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
info!("Personalization service initialized with Cohere API");
|
||||
Self { cohere: Cohere::new("https://api.cohere.ai", api_key) }
|
||||
}
|
||||
|
||||
pub async fn rerank_search_results(
|
||||
@ -29,9 +22,6 @@ impl PersonalizationService {
|
||||
personalize: Option<&Personalize>,
|
||||
query: Option<&str>,
|
||||
) -> Result<SearchResult, ResponseError> {
|
||||
// If no API key, return original results
|
||||
let Some(cohere) = &self.cohere else { return Ok(search_result) };
|
||||
|
||||
// Extract user context from personalization
|
||||
let user_context = personalize.and_then(|p| p.user_context.as_deref());
|
||||
|
||||
@ -67,7 +57,7 @@ impl PersonalizationService {
|
||||
};
|
||||
|
||||
// Call Cohere's rerank API
|
||||
match cohere.rerank(&rerank_request).await {
|
||||
match self.cohere.rerank(&rerank_request).await {
|
||||
Ok(rerank_response) => {
|
||||
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
|
||||
|
||||
@ -94,6 +84,36 @@ impl PersonalizationService {
|
||||
}
|
||||
}
|
||||
|
||||
pub enum PersonalizationService {
|
||||
Cohere(CohereService),
|
||||
Uninitialized,
|
||||
}
|
||||
|
||||
impl PersonalizationService {
|
||||
pub fn cohere(api_key: String) -> Self {
|
||||
Self::Cohere(CohereService::new(api_key))
|
||||
}
|
||||
|
||||
pub fn uninitialized() -> Self {
|
||||
debug!("Personalization service uninitialized");
|
||||
Self::Uninitialized
|
||||
}
|
||||
|
||||
pub async fn rerank_search_results(
|
||||
&self,
|
||||
search_result: SearchResult,
|
||||
personalize: Option<&Personalize>,
|
||||
query: Option<&str>,
|
||||
) -> Result<SearchResult, ResponseError> {
|
||||
match self {
|
||||
Self::Cohere(cohere_service) => {
|
||||
cohere_service.rerank_search_results(search_result, personalize, query).await
|
||||
}
|
||||
Self::Uninitialized => Ok(search_result),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -101,7 +121,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_personalization_service_without_api_key() {
|
||||
let service = PersonalizationService::new(None);
|
||||
let service = PersonalizationService::uninitialized();
|
||||
let personalize = Personalize { user_context: Some("test user".to_string()) };
|
||||
|
||||
let search_result = SearchResult {
|
||||
@ -134,7 +154,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_personalization_service_with_user_context_only() {
|
||||
let service = PersonalizationService::new(Some("fake_key".to_string()));
|
||||
let service = PersonalizationService::cohere("fake_key".to_string());
|
||||
let personalize = Personalize { user_context: Some("test user".to_string()) };
|
||||
|
||||
let search_result = SearchResult {
|
||||
@ -166,7 +186,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_personalization_service_with_query_only() {
|
||||
let service = PersonalizationService::new(Some("fake_key".to_string()));
|
||||
let service = PersonalizationService::cohere("fake_key".to_string());
|
||||
|
||||
let search_result = SearchResult {
|
||||
hits: vec![SearchHit {
|
||||
@ -196,7 +216,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_personalization_service_both_none() {
|
||||
let service = PersonalizationService::new(Some("fake_key".to_string()));
|
||||
let service = PersonalizationService::cohere("fake_key".to_string());
|
||||
|
||||
let search_result = SearchResult {
|
||||
hits: vec![SearchHit {
|
||||
|
Reference in New Issue
Block a user