diff --git a/crates/meilisearch/src/personalization/mod.rs b/crates/meilisearch/src/personalization/mod.rs index 7c8a0caa1..9d65afabf 100644 --- a/crates/meilisearch/src/personalization/mod.rs +++ b/crates/meilisearch/src/personalization/mod.rs @@ -327,7 +327,12 @@ pub enum PersonalizationService { impl PersonalizationService { pub fn cohere(api_key: String) -> Self { - Self::Cohere(CohereService::new(api_key)) + // If the API key is empty, consider the personalization service as disabled + if api_key.trim().is_empty() { + Self::disabled() + } else { + Self::Cohere(CohereService::new(api_key)) + } } pub fn disabled() -> Self { diff --git a/crates/meilisearch/src/routes/indexes/search.rs b/crates/meilisearch/src/routes/indexes/search.rs index b6a03514b..88f8f94e0 100644 --- a/crates/meilisearch/src/routes/indexes/search.rs +++ b/crates/meilisearch/src/routes/indexes/search.rs @@ -366,7 +366,8 @@ pub async fn search_with_url_query( search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors); - let query_str = query.q.clone(); + // Save the query string for personalization if requested + let personalize_query = personalize.is_some().then(|| query.q.clone()).flatten(); let permit = search_queue.try_get_search_permit().await?; let include_metadata = parse_include_metadata_header(&req); @@ -398,7 +399,12 @@ pub async fn search_with_url_query( // Apply personalization if requested if let Some(personalize) = personalize.as_ref() { search_result = personalization_service - .rerank_search_results(search_result, personalize, query_str.as_deref(), deadline) + .rerank_search_results( + search_result, + personalize, + personalize_query.as_deref(), + deadline, + ) .await?; } @@ -505,7 +511,8 @@ pub async fn search_with_post( let include_metadata = parse_include_metadata_header(&req); - let query_str = personalize.is_some().then(|| query.q.clone()).flatten(); + // Save the query string for personalization if requested + let personalize_query = personalize.is_some().then(|| query.q.clone()).flatten(); let permit = search_queue.try_get_search_permit().await?; let search_result = tokio::task::spawn_blocking(move || { @@ -538,7 +545,12 @@ pub async fn search_with_post( // Apply personalization if requested if let Some(personalize) = personalize.as_ref() { search_result = personalization_service - .rerank_search_results(search_result, personalize, query_str.as_deref(), deadline) + .rerank_search_results( + search_result, + personalize, + personalize_query.as_deref(), + deadline, + ) .await?; } diff --git a/crates/meilisearch/tests/common/service.rs b/crates/meilisearch/tests/common/service.rs index c0b07c217..8730b2aa4 100644 --- a/crates/meilisearch/tests/common/service.rs +++ b/crates/meilisearch/tests/common/service.rs @@ -10,6 +10,7 @@ use actix_web::test::TestRequest; use actix_web::web::Data; use index_scheduler::IndexScheduler; use meilisearch::analytics::Analytics; +use meilisearch::personalization::PersonalizationService; use meilisearch::search_queue::SearchQueue; use meilisearch::{create_app, Opt, SubscriberForSecondLayer}; use meilisearch_auth::AuthController; @@ -135,11 +136,18 @@ impl Service { self.options.experimental_search_queue_size, NonZeroUsize::new(1).unwrap(), ); + let personalization_service = self + .options + .experimental_personalization_api_key + .clone() + .map(PersonalizationService::cohere) + .unwrap_or_else(PersonalizationService::disabled); actix_web::test::init_service(create_app( self.index_scheduler.clone().into(), self.auth.clone().into(), Data::new(search_queue), + Data::new(personalization_service), self.options.clone(), (route_layer_handle, stderr_layer_handle), Data::new(Analytics::no_analytics()), diff --git a/crates/meilisearch/tests/logs/mod.rs b/crates/meilisearch/tests/logs/mod.rs index e4dc50a9c..82fa444ab 100644 --- a/crates/meilisearch/tests/logs/mod.rs +++ b/crates/meilisearch/tests/logs/mod.rs @@ -8,6 +8,7 @@ use actix_web::http::header::ContentType; use actix_web::web::Data; use meili_snap::snapshot; use meilisearch::analytics::Analytics; +use meilisearch::personalization::PersonalizationService; use meilisearch::search_queue::SearchQueue; use meilisearch::{create_app, Opt, SubscriberForSecondLayer}; use tracing::level_filters::LevelFilter; @@ -53,6 +54,7 @@ async fn basic_test_log_stream_route() { server.service.index_scheduler.clone().into(), server.service.auth.clone().into(), Data::new(search_queue), + Data::new(PersonalizationService::disabled()), server.service.options.clone(), (route_layer_handle, stderr_layer_handle), Data::new(Analytics::no_analytics()),