mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-09-06 04:36:32 +00:00
refactor: split PersonalizationService into enum with CohereService
- Refactor PersonalizationService as enum with Cohere and Uninitialized variants - Create dedicated CohereService struct with rerank_search_results method - Split constructor into cohere() and uninitialized() methods - Move all Cohere logic into CohereService for better separation of concerns - Update tests and lib.rs to use new API - Improve code organization and maintainability
This commit is contained in:
@ -678,9 +678,11 @@ pub fn configure_data(
|
|||||||
analytics: Data<Analytics>,
|
analytics: Data<Analytics>,
|
||||||
) {
|
) {
|
||||||
// Create personalization service with API key from options
|
// Create personalization service with API key from options
|
||||||
let personalization_service = personalization::PersonalizationService::new(
|
let personalization_service = index_scheduler
|
||||||
index_scheduler.experimental_personalization_api_key().cloned(),
|
.experimental_personalization_api_key()
|
||||||
);
|
.cloned()
|
||||||
|
.map(personalization::PersonalizationService::cohere)
|
||||||
|
.unwrap_or_else(personalization::PersonalizationService::uninitialized);
|
||||||
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
|
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
|
||||||
config
|
config
|
||||||
.app_data(index_scheduler)
|
.app_data(index_scheduler)
|
||||||
|
@ -6,21 +6,14 @@ use cohere_rust::{
|
|||||||
use meilisearch_types::error::ResponseError;
|
use meilisearch_types::error::ResponseError;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
pub struct PersonalizationService {
|
pub struct CohereService {
|
||||||
cohere: Option<Cohere>,
|
cohere: Cohere,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PersonalizationService {
|
impl CohereService {
|
||||||
pub fn new(api_key: Option<String>) -> Self {
|
pub fn new(api_key: String) -> Self {
|
||||||
let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key));
|
|
||||||
|
|
||||||
if cohere.is_some() {
|
|
||||||
info!("Personalization service initialized with Cohere API");
|
info!("Personalization service initialized with Cohere API");
|
||||||
} else {
|
Self { cohere: Cohere::new("https://api.cohere.ai", api_key) }
|
||||||
debug!("Personalization service initialized without Cohere API key");
|
|
||||||
}
|
|
||||||
|
|
||||||
Self { cohere }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn rerank_search_results(
|
pub async fn rerank_search_results(
|
||||||
@ -29,9 +22,6 @@ impl PersonalizationService {
|
|||||||
personalize: Option<&Personalize>,
|
personalize: Option<&Personalize>,
|
||||||
query: Option<&str>,
|
query: Option<&str>,
|
||||||
) -> Result<SearchResult, ResponseError> {
|
) -> Result<SearchResult, ResponseError> {
|
||||||
// If no API key, return original results
|
|
||||||
let Some(cohere) = &self.cohere else { return Ok(search_result) };
|
|
||||||
|
|
||||||
// Extract user context from personalization
|
// Extract user context from personalization
|
||||||
let user_context = personalize.and_then(|p| p.user_context.as_deref());
|
let user_context = personalize.and_then(|p| p.user_context.as_deref());
|
||||||
|
|
||||||
@ -67,7 +57,7 @@ impl PersonalizationService {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Call Cohere's rerank API
|
// Call Cohere's rerank API
|
||||||
match cohere.rerank(&rerank_request).await {
|
match self.cohere.rerank(&rerank_request).await {
|
||||||
Ok(rerank_response) => {
|
Ok(rerank_response) => {
|
||||||
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
|
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
|
||||||
|
|
||||||
@ -94,6 +84,36 @@ impl PersonalizationService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum PersonalizationService {
|
||||||
|
Cohere(CohereService),
|
||||||
|
Uninitialized,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PersonalizationService {
|
||||||
|
pub fn cohere(api_key: String) -> Self {
|
||||||
|
Self::Cohere(CohereService::new(api_key))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn uninitialized() -> Self {
|
||||||
|
debug!("Personalization service uninitialized");
|
||||||
|
Self::Uninitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn rerank_search_results(
|
||||||
|
&self,
|
||||||
|
search_result: SearchResult,
|
||||||
|
personalize: Option<&Personalize>,
|
||||||
|
query: Option<&str>,
|
||||||
|
) -> Result<SearchResult, ResponseError> {
|
||||||
|
match self {
|
||||||
|
Self::Cohere(cohere_service) => {
|
||||||
|
cohere_service.rerank_search_results(search_result, personalize, query).await
|
||||||
|
}
|
||||||
|
Self::Uninitialized => Ok(search_result),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@ -101,7 +121,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_personalization_service_without_api_key() {
|
async fn test_personalization_service_without_api_key() {
|
||||||
let service = PersonalizationService::new(None);
|
let service = PersonalizationService::uninitialized();
|
||||||
let personalize = Personalize { user_context: Some("test user".to_string()) };
|
let personalize = Personalize { user_context: Some("test user".to_string()) };
|
||||||
|
|
||||||
let search_result = SearchResult {
|
let search_result = SearchResult {
|
||||||
@ -134,7 +154,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_personalization_service_with_user_context_only() {
|
async fn test_personalization_service_with_user_context_only() {
|
||||||
let service = PersonalizationService::new(Some("fake_key".to_string()));
|
let service = PersonalizationService::cohere("fake_key".to_string());
|
||||||
let personalize = Personalize { user_context: Some("test user".to_string()) };
|
let personalize = Personalize { user_context: Some("test user".to_string()) };
|
||||||
|
|
||||||
let search_result = SearchResult {
|
let search_result = SearchResult {
|
||||||
@ -166,7 +186,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_personalization_service_with_query_only() {
|
async fn test_personalization_service_with_query_only() {
|
||||||
let service = PersonalizationService::new(Some("fake_key".to_string()));
|
let service = PersonalizationService::cohere("fake_key".to_string());
|
||||||
|
|
||||||
let search_result = SearchResult {
|
let search_result = SearchResult {
|
||||||
hits: vec![SearchHit {
|
hits: vec![SearchHit {
|
||||||
@ -196,7 +216,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_personalization_service_both_none() {
|
async fn test_personalization_service_both_none() {
|
||||||
let service = PersonalizationService::new(Some("fake_key".to_string()));
|
let service = PersonalizationService::cohere("fake_key".to_string());
|
||||||
|
|
||||||
let search_result = SearchResult {
|
let search_result = SearchResult {
|
||||||
hits: vec![SearchHit {
|
hits: vec![SearchHit {
|
||||||
|
Reference in New Issue
Block a user