diff --git a/crates/meilisearch-types/src/error.rs b/crates/meilisearch-types/src/error.rs index 2d7185b5b..e7edf7d62 100644 --- a/crates/meilisearch-types/src/error.rs +++ b/crates/meilisearch-types/src/error.rs @@ -258,6 +258,7 @@ InvalidIndexUid , InvalidRequest , BAD_REQU InvalidMultiSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFacetsByIndex , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFacetOrder , InvalidRequest , BAD_REQUEST ; +InvalidMultiSearchQueryPersonalization , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFederated , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFederationOptions , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchMaxValuesPerFacet , InvalidRequest , BAD_REQUEST ; @@ -315,6 +316,8 @@ InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQU InvalidSimilarShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ; InvalidSearchSort , InvalidRequest , BAD_REQUEST ; InvalidSearchDistinct , InvalidRequest , BAD_REQUEST ; +InvalidSearchPersonalize , InvalidRequest , BAD_REQUEST ; +InvalidSearchPersonalizeUserContext , InvalidRequest , BAD_REQUEST ; InvalidSearchMediaAndVector , InvalidRequest , BAD_REQUEST ; InvalidSettingsDisplayedAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsDistinctAttribute , InvalidRequest , BAD_REQUEST ; @@ -682,6 +685,18 @@ impl fmt::Display for deserr_codes::InvalidNetworkSearchApiKey { } } +impl fmt::Display for deserr_codes::InvalidSearchPersonalize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "the value of `personalize` is invalid, expected a JSON object with `userContext` string.") + } +} + +impl fmt::Display for deserr_codes::InvalidSearchPersonalizeUserContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "the value of `userContext` is invalid, expected a string.") + } +} + #[macro_export] macro_rules! 
internal_error { ($target:ty : $($other:path), *) => { diff --git a/crates/meilisearch/src/analytics/segment_analytics.rs b/crates/meilisearch/src/analytics/segment_analytics.rs index aec71d431..527b5684b 100644 --- a/crates/meilisearch/src/analytics/segment_analytics.rs +++ b/crates/meilisearch/src/analytics/segment_analytics.rs @@ -208,6 +208,7 @@ struct Infos { experimental_no_edition_2024_for_prefix_post_processing: bool, experimental_no_edition_2024_for_facet_post_processing: bool, experimental_vector_store_setting: bool, + experimental_personalization: bool, gpu_enabled: bool, db_path: bool, import_dump: bool, @@ -286,6 +287,7 @@ impl Infos { indexer_options, config_file_path, no_analytics: _, + experimental_personalization_api_key, s3_snapshot_options, } = options; @@ -374,6 +376,7 @@ impl Infos { experimental_no_edition_2024_for_settings, experimental_no_edition_2024_for_prefix_post_processing, experimental_no_edition_2024_for_facet_post_processing, + experimental_personalization: experimental_personalization_api_key.is_some(), } } } diff --git a/crates/meilisearch/src/error.rs b/crates/meilisearch/src/error.rs index a9cc8cc7b..371e5c67d 100644 --- a/crates/meilisearch/src/error.rs +++ b/crates/meilisearch/src/error.rs @@ -38,6 +38,8 @@ pub enum MeilisearchHttpError { PaginationInFederatedQuery(usize, &'static str), #[error("Inside `.queries[{0}]`: Using facet options is not allowed in federated queries.\n - Hint: remove `facets` from query #{0} or remove `federation` from the request\n - Hint: pass `federation.facetsByIndex.{1}: {2:?}` for facets in federated search")] FacetsInFederatedQuery(usize, String, Vec), + #[error("Inside `.queries[{0}]`: Using `.personalize` is not allowed in federated queries.\n - Hint: remove `personalize` from query #{0} or remove `federation` from the request")] + PersonalizationInFederatedQuery(usize), #[error("Inconsistent order for values in facet `{facet}`: index `{previous_uid}` orders {previous_facet_order}, but index 
`{current_uid}` orders {index_facet_order}.\n - Hint: Remove `federation.mergeFacets` or change `faceting.sortFacetValuesBy` to be consistent in settings.")] InconsistentFacetOrder { facet: String, @@ -137,6 +139,9 @@ impl ErrorCode for MeilisearchHttpError { MeilisearchHttpError::InconsistentFacetOrder { .. } => { Code::InvalidMultiSearchFacetOrder } + MeilisearchHttpError::PersonalizationInFederatedQuery(_) => { + Code::InvalidMultiSearchQueryPersonalization + } MeilisearchHttpError::InconsistentOriginHeaders { .. } => { Code::InconsistentDocumentChangeHeaders } diff --git a/crates/meilisearch/src/lib.rs b/crates/meilisearch/src/lib.rs index 8a2d2c0a1..9295054e2 100644 --- a/crates/meilisearch/src/lib.rs +++ b/crates/meilisearch/src/lib.rs @@ -11,6 +11,7 @@ pub mod middleware; pub mod option; #[cfg(test)] mod option_test; +pub mod personalization; pub mod routes; pub mod search; pub mod search_queue; @@ -58,6 +59,7 @@ use tracing::{error, info_span}; use tracing_subscriber::filter::Targets; use crate::error::MeilisearchHttpError; +use crate::personalization::PersonalizationService; /// Default number of simultaneously opened indexes. /// @@ -128,12 +130,8 @@ pub type LogStderrType = tracing_subscriber::filter::Filtered< >; pub fn create_app( - index_scheduler: Data, - auth_controller: Data, - search_queue: Data, + services: ServicesData, opt: Opt, - logs: (LogRouteHandle, LogStderrHandle), - analytics: Data, enable_dashboard: bool, ) -> actix_web::App< impl ServiceFactory< @@ -145,17 +143,7 @@ pub fn create_app( >, > { let app = actix_web::App::new() - .configure(|s| { - configure_data( - s, - index_scheduler.clone(), - auth_controller.clone(), - search_queue.clone(), - &opt, - logs, - analytics.clone(), - ) - }) + .configure(|s| configure_data(s, services, &opt)) .configure(routes::configure) .configure(|s| dashboard(s, enable_dashboard)); @@ -690,23 +678,26 @@ fn import_dump( Ok(index_scheduler_dump.finish()?) 
}

-pub fn configure_data(
-    config: &mut web::ServiceConfig,
-    index_scheduler: Data<IndexScheduler>,
-    auth: Data<AuthController>,
-    search_queue: Data<SearchQueue>,
-    opt: &Opt,
-    (logs_route, logs_stderr): (LogRouteHandle, LogStderrHandle),
-    analytics: Data<Analytics>,
-) {
+pub fn configure_data(config: &mut web::ServiceConfig, services: ServicesData, opt: &Opt) {
+    let ServicesData {
+        index_scheduler,
+        auth,
+        search_queue,
+        personalization_service,
+        logs_route_handle,
+        logs_stderr_handle,
+        analytics,
+    } = services;
+
     let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
     config
         .app_data(index_scheduler)
         .app_data(auth)
         .app_data(search_queue)
         .app_data(analytics)
-        .app_data(web::Data::new(logs_route))
-        .app_data(web::Data::new(logs_stderr))
+        .app_data(personalization_service)
+        .app_data(logs_route_handle)
+        .app_data(logs_stderr_handle)
         .app_data(web::Data::new(opt.clone()))
         .app_data(
             web::JsonConfig::default()
@@ -767,3 +758,16 @@ pub fn dashboard(config: &mut web::ServiceConfig, enable_frontend: bool) {
 pub fn dashboard(config: &mut web::ServiceConfig, _enable_frontend: bool) {
     config.service(web::resource("/").route(web::get().to(routes::running)));
 }
+
+/// Shared service handles that `configure_data` registers as application data.
+/// Every field is a `Data` handle, so cloning the struct only clones the handles.
+#[derive(Clone)]
+pub struct ServicesData {
+    pub index_scheduler: Data<IndexScheduler>,
+    pub auth: Data<AuthController>,
+    pub search_queue: Data<SearchQueue>,
+    pub personalization_service: Data<PersonalizationService>,
+    pub logs_route_handle: Data<LogRouteHandle>,
+    pub logs_stderr_handle: Data<LogStderrHandle>,
+    pub analytics: Data<Analytics>,
+}
diff --git a/crates/meilisearch/src/main.rs b/crates/meilisearch/src/main.rs
index be0beb97f..539680cd2 100644
--- a/crates/meilisearch/src/main.rs
+++ b/crates/meilisearch/src/main.rs
@@ -14,10 +14,11 @@ use index_scheduler::IndexScheduler;
 use is_terminal::IsTerminal;
 use meilisearch::analytics::Analytics;
 use meilisearch::option::LogMode;
+use meilisearch::personalization::PersonalizationService;
 use meilisearch::search_queue::SearchQueue;
 use meilisearch::{
     analytics, create_app, setup_meilisearch, LogRouteHandle, LogRouteType, LogStderrHandle,
-    LogStderrType, Opt,
SubscriberForSecondLayer, + LogStderrType, Opt, ServicesData, SubscriberForSecondLayer, }; use meilisearch_auth::{generate_master_key, AuthController, MASTER_KEY_MIN_SIZE}; use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor}; @@ -152,8 +153,15 @@ async fn run_http( let enable_dashboard = &opt.env == "development"; let opt_clone = opt.clone(); let index_scheduler = Data::from(index_scheduler); - let auth_controller = Data::from(auth_controller); + let auth = Data::from(auth_controller); let analytics = Data::from(analytics); + // Create personalization service with API key from options + let personalization_service = Data::new( + opt.experimental_personalization_api_key + .clone() + .map(PersonalizationService::cohere) + .unwrap_or_else(PersonalizationService::disabled), + ); let search_queue = SearchQueue::new( opt.experimental_search_queue_size, available_parallelism() @@ -165,21 +173,25 @@ async fn run_http( usize::from(opt.experimental_drop_search_after) as u64 )); let search_queue = Data::new(search_queue); + let (logs_route_handle, logs_stderr_handle) = logs; + let logs_route_handle = Data::new(logs_route_handle); + let logs_stderr_handle = Data::new(logs_stderr_handle); - let http_server = HttpServer::new(move || { - create_app( - index_scheduler.clone(), - auth_controller.clone(), - search_queue.clone(), - opt.clone(), - logs.clone(), - analytics.clone(), - enable_dashboard, - ) - }) - // Disable signals allows the server to terminate immediately when a user enter CTRL-C - .disable_signals() - .keep_alive(KeepAlive::Os); + let services = ServicesData { + index_scheduler, + auth, + search_queue, + personalization_service, + logs_route_handle, + logs_stderr_handle, + analytics, + }; + + let http_server = + HttpServer::new(move || create_app(services.clone(), opt.clone(), enable_dashboard)) + // Disable signals allows the server to terminate immediately when a user enter CTRL-C + .disable_signals() + .keep_alive(KeepAlive::Os); if let 
Some(config) = opt_clone.get_ssl_config()? { http_server.bind_rustls_0_23(opt_clone.http_addr, config)?.run().await?; diff --git a/crates/meilisearch/src/metrics.rs b/crates/meilisearch/src/metrics.rs index 607bc91eb..7ae12e355 100644 --- a/crates/meilisearch/src/metrics.rs +++ b/crates/meilisearch/src/metrics.rs @@ -114,4 +114,9 @@ lazy_static! { "Meilisearch Task Queue Size Until Stop Registering", )) .expect("Can't create a metric"); + pub static ref MEILISEARCH_PERSONALIZED_SEARCH_REQUESTS: IntGauge = register_int_gauge!(opts!( + "meilisearch_personalized_search_requests", + "Meilisearch number of search requests with personalization" + )) + .expect("Can't create a metric"); } diff --git a/crates/meilisearch/src/option.rs b/crates/meilisearch/src/option.rs index b054cd4cf..6de8082ce 100644 --- a/crates/meilisearch/src/option.rs +++ b/crates/meilisearch/src/option.rs @@ -75,6 +75,8 @@ const MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES: &str = const MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION: &str = "MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION"; const MEILI_EXPERIMENTAL_NO_EDITION_2024_FOR_DUMPS: &str = "MEILI_EXPERIMENTAL_NO_EDITION_2024_FOR_DUMPS"; +const MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY: &str = + "MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY"; // Related to S3 snapshots const MEILI_S3_BUCKET_URL: &str = "MEILI_S3_BUCKET_URL"; @@ -494,6 +496,12 @@ pub struct Opt { #[serde(default)] pub experimental_no_snapshot_compaction: bool, + /// Experimental personalization API key feature. + /// + /// Sets the API key for personalization features. 
+ #[clap(long, env = MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY)] + pub experimental_personalization_api_key: Option, + #[serde(flatten)] #[clap(flatten)] pub indexer_options: IndexerOpts, @@ -603,6 +611,7 @@ impl Opt { experimental_limit_batched_tasks_total_size, experimental_embedding_cache_entries, experimental_no_snapshot_compaction, + experimental_personalization_api_key, s3_snapshot_options, } = self; export_to_env_if_not_present(MEILI_DB_PATH, db_path); @@ -704,6 +713,12 @@ impl Opt { MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION, experimental_no_snapshot_compaction.to_string(), ); + if let Some(experimental_personalization_api_key) = experimental_personalization_api_key { + export_to_env_if_not_present( + MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY, + experimental_personalization_api_key, + ); + } indexer_options.export_to_env(); if let Some(s3_snapshot_options) = s3_snapshot_options { #[cfg(not(unix))] diff --git a/crates/meilisearch/src/personalization/mod.rs b/crates/meilisearch/src/personalization/mod.rs new file mode 100644 index 000000000..289b6a660 --- /dev/null +++ b/crates/meilisearch/src/personalization/mod.rs @@ -0,0 +1,366 @@ +use crate::search::{Personalize, SearchResult}; +use meilisearch_types::{ + error::{Code, ErrorCode, ResponseError}, + milli::TimeBudget, +}; +use rand::Rng; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tracing::{debug, info, warn}; + +const COHERE_API_URL: &str = "https://api.cohere.ai/v1/rerank"; +const MAX_RETRIES: u32 = 10; + +#[derive(Debug, thiserror::Error)] +enum PersonalizationError { + #[error("Personalization service: HTTP request failed: {0}")] + Request(#[from] reqwest::Error), + #[error("Personalization service: Failed to parse response: {0}")] + Parse(String), + #[error("Personalization service: Cohere API error: {0}")] + Api(String), + #[error("Personalization service: Unauthorized: invalid API key")] + Unauthorized, + #[error("Personalization service: Rate limited: 
too many requests")] + RateLimited, + #[error("Personalization service: Bad request: {0}")] + BadRequest(String), + #[error("Personalization service: Internal server error: {0}")] + InternalServerError(String), + #[error("Personalization service: Network error: {0}")] + Network(String), + #[error("Personalization service: Deadline exceeded")] + DeadlineExceeded, + #[error(transparent)] + FeatureNotEnabled(#[from] index_scheduler::error::FeatureNotEnabledError), +} + +impl ErrorCode for PersonalizationError { + fn error_code(&self) -> Code { + match self { + PersonalizationError::FeatureNotEnabled { .. } => Code::FeatureNotEnabled, + PersonalizationError::Unauthorized => Code::RemoteInvalidApiKey, + PersonalizationError::RateLimited => Code::TooManySearchRequests, + PersonalizationError::BadRequest(_) => Code::RemoteBadRequest, + PersonalizationError::InternalServerError(_) => Code::RemoteRemoteError, + PersonalizationError::Network(_) | PersonalizationError::Request(_) => { + Code::RemoteCouldNotSendRequest + } + PersonalizationError::Parse(_) | PersonalizationError::Api(_) => { + Code::RemoteBadResponse + } + PersonalizationError::DeadlineExceeded => Code::Internal, // should not be returned to the client + } + } +} + +pub struct CohereService { + client: Client, + api_key: String, +} + +impl CohereService { + pub fn new(api_key: String) -> Self { + info!("Personalization service initialized with Cohere API"); + let client = Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .expect("Failed to create HTTP client"); + Self { client, api_key } + } + + pub async fn rerank_search_results( + &self, + search_result: SearchResult, + personalize: &Personalize, + query: Option<&str>, + time_budget: TimeBudget, + ) -> Result { + if time_budget.exceeded() { + warn!("Could not rerank due to deadline"); + // If the deadline is exceeded, return the original search result instead of an error + return Ok(search_result); + } + + // Extract user context from 
personalization + let user_context = personalize.user_context.as_str(); + + // Build the prompt by merging query and user context + let prompt = match query { + Some(q) => format!("User Context: {user_context}\nQuery: {q}"), + None => format!("User Context: {user_context}"), + }; + + // Extract documents for reranking + let documents: Vec = search_result + .hits + .iter() + .map(|hit| { + // Convert the document to a string representation for reranking + serde_json::to_string(&hit.document).unwrap_or_else(|_| "{}".to_string()) + }) + .collect(); + + if documents.is_empty() { + return Ok(search_result); + } + + // Call Cohere's rerank API with retry logic + let reranked_indices = + match self.call_rerank_with_retry(&prompt, &documents, time_budget).await { + Ok(indices) => indices, + Err(PersonalizationError::DeadlineExceeded) => { + // If the deadline is exceeded, return the original search result instead of an error + return Ok(search_result); + } + Err(e) => return Err(e.into()), + }; + + debug!("Cohere rerank successful, reordering {} results", search_result.hits.len()); + + // Reorder the hits based on Cohere's reranking + let mut reranked_hits = Vec::new(); + for index in reranked_indices.iter() { + if let Some(hit) = search_result.hits.get(*index) { + reranked_hits.push(hit.clone()); + } + } + + Ok(SearchResult { hits: reranked_hits, ..search_result }) + } + + async fn call_rerank_with_retry( + &self, + query: &str, + documents: &[String], + time_budget: TimeBudget, + ) -> Result, PersonalizationError> { + let request_body = CohereRerankRequest { + query: query.to_string(), + documents: documents.to_vec(), + model: "rerank-english-v3.0".to_string(), + }; + + // Retry loop similar to vector extraction + for attempt in 0..MAX_RETRIES { + let response_result = self.send_rerank_request(&request_body).await; + + let retry_duration = match self.handle_response(response_result).await { + Ok(indices) => return Ok(indices), + Err(retry) => { + warn!("Cohere rerank 
attempt #{} failed: {}", attempt, retry.error); + + if time_budget.exceeded() { + warn!("Could not rerank due to deadline"); + return Err(PersonalizationError::DeadlineExceeded); + } else { + match retry.into_duration(attempt) { + Ok(d) => d, + Err(error) => return Err(error), + } + } + } + }; + + // randomly up to double the retry duration + let retry_duration = retry_duration + + rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration); + + warn!("Retrying after {}ms", retry_duration.as_millis()); + tokio::time::sleep(retry_duration).await; + } + + // Final attempt without retry + let response_result = self.send_rerank_request(&request_body).await; + + match self.handle_response(response_result).await { + Ok(indices) => Ok(indices), + Err(retry) => Err(retry.into_error()), + } + } + + async fn send_rerank_request( + &self, + request_body: &CohereRerankRequest, + ) -> Result { + self.client + .post(COHERE_API_URL) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(request_body) + .send() + .await + } + + async fn handle_response( + &self, + response_result: Result, + ) -> Result, Retry> { + let response = match response_result { + Ok(r) => r, + Err(e) if e.is_timeout() => { + return Err(Retry::retry_later(PersonalizationError::Network(format!( + "Request timeout: {}", + e + )))); + } + Err(e) => { + return Err(Retry::retry_later(PersonalizationError::Network(format!( + "Network error: {}", + e + )))); + } + }; + + let status = response.status(); + let status_code = status.as_u16(); + + if status.is_success() { + let rerank_response: CohereRerankResponse = match response.json().await { + Ok(r) => r, + Err(e) => { + return Err(Retry::retry_later(PersonalizationError::Parse(format!( + "Failed to parse response: {}", + e + )))); + } + }; + + // Extract indices from rerank results + let indices: Vec = + rerank_response.results.iter().map(|result| result.index as usize).collect(); + + 
return Ok(indices); + } + + // Handle error status codes + let error_body = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + + let retry = match status_code { + 401 => Retry::give_up(PersonalizationError::Unauthorized), + 429 => Retry::rate_limited(PersonalizationError::RateLimited), + 400 => Retry::give_up(PersonalizationError::BadRequest(error_body)), + 500..=599 => Retry::retry_later(PersonalizationError::InternalServerError(format!( + "Status {}: {}", + status_code, error_body + ))), + 402..=499 => Retry::give_up(PersonalizationError::Api(format!( + "Status {}: {}", + status_code, error_body + ))), + _ => Retry::retry_later(PersonalizationError::Api(format!( + "Unexpected status {}: {}", + status_code, error_body + ))), + }; + + Err(retry) + } +} + +#[derive(Serialize)] +struct CohereRerankRequest { + query: String, + documents: Vec, + model: String, +} + +#[derive(Deserialize)] +struct CohereRerankResponse { + results: Vec, +} + +#[derive(Deserialize)] +struct CohereRerankResult { + index: u32, +} + +// Retry strategy similar to vector extraction +struct Retry { + error: PersonalizationError, + strategy: RetryStrategy, +} + +enum RetryStrategy { + GiveUp, + Retry, + RetryAfterRateLimit, +} + +impl Retry { + fn give_up(error: PersonalizationError) -> Self { + Self { error, strategy: RetryStrategy::GiveUp } + } + + fn retry_later(error: PersonalizationError) -> Self { + Self { error, strategy: RetryStrategy::Retry } + } + + fn rate_limited(error: PersonalizationError) -> Self { + Self { error, strategy: RetryStrategy::RetryAfterRateLimit } + } + + fn into_duration(self, attempt: u32) -> Result { + match self.strategy { + RetryStrategy::GiveUp => Err(self.error), + RetryStrategy::Retry => { + // Exponential backoff: 10^attempt milliseconds + Ok(Duration::from_millis((10u64).pow(attempt))) + } + RetryStrategy::RetryAfterRateLimit => { + // Longer backoff for rate limits: 100ms + exponential + Ok(Duration::from_millis(100 + 
(10u64).pow(attempt))) + } + } + } + + fn into_error(self) -> PersonalizationError { + self.error + } +} + +pub enum PersonalizationService { + Cohere(CohereService), + Disabled, +} + +impl PersonalizationService { + pub fn cohere(api_key: String) -> Self { + // If the API key is empty, consider the personalization service as disabled + if api_key.trim().is_empty() { + Self::disabled() + } else { + Self::Cohere(CohereService::new(api_key)) + } + } + + pub fn disabled() -> Self { + debug!("Personalization service disabled"); + Self::Disabled + } + + pub async fn rerank_search_results( + &self, + search_result: SearchResult, + personalize: &Personalize, + query: Option<&str>, + time_budget: TimeBudget, + ) -> Result { + match self { + Self::Cohere(cohere_service) => { + cohere_service + .rerank_search_results(search_result, personalize, query, time_budget) + .await + } + Self::Disabled => Err(PersonalizationError::FeatureNotEnabled( + index_scheduler::error::FeatureNotEnabledError { + disabled_action: "reranking search results", + feature: "personalization", + issue_link: "https://github.com/orgs/meilisearch/discussions/866", + }, + ) + .into()), + } + } +} diff --git a/crates/meilisearch/src/routes/indexes/facet_search.rs b/crates/meilisearch/src/routes/indexes/facet_search.rs index 18ad54ccf..ae904bcb2 100644 --- a/crates/meilisearch/src/routes/indexes/facet_search.rs +++ b/crates/meilisearch/src/routes/indexes/facet_search.rs @@ -343,6 +343,7 @@ impl From for SearchQuery { hybrid, ranking_score_threshold, locales, + personalize: None, } } } diff --git a/crates/meilisearch/src/routes/indexes/search.rs b/crates/meilisearch/src/routes/indexes/search.rs index 8012f2302..ebc9f529b 100644 --- a/crates/meilisearch/src/routes/indexes/search.rs +++ b/crates/meilisearch/src/routes/indexes/search.rs @@ -24,9 +24,9 @@ use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; use crate::routes::indexes::search_analytics::{SearchAggregator, SearchGET, SearchPOST}; use 
crate::routes::parse_include_metadata_header; use crate::search::{ - add_search_rules, perform_search, HybridQuery, MatchingStrategy, RankingScoreThreshold, - RetrieveVectors, SearchKind, SearchParams, SearchQuery, SearchResult, SemanticRatio, - DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + add_search_rules, perform_search, HybridQuery, MatchingStrategy, Personalize, + RankingScoreThreshold, RetrieveVectors, SearchKind, SearchParams, SearchQuery, SearchResult, + SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; use crate::search_queue::SearchQueue; @@ -134,6 +134,8 @@ pub struct SearchQueryGet { #[deserr(default, error = DeserrQueryParamError)] #[param(value_type = Vec, explode = false)] pub locales: Option>, + #[deserr(default, error = DeserrQueryParamError)] + pub personalize_user_context: Option, } #[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] @@ -205,6 +207,9 @@ impl TryFrom for SearchQuery { )); } + let personalize = + other.personalize_user_context.map(|user_context| Personalize { user_context }); + Ok(Self { q: other.q, // `media` not supported for `GET` @@ -234,6 +239,7 @@ impl TryFrom for SearchQuery { hybrid, ranking_score_threshold: other.ranking_score_threshold.map(|o| o.0), locales: other.locales.map(|o| o.into_iter().collect()), + personalize, }) } } @@ -322,6 +328,7 @@ pub fn fix_sort_query_parameters(sort_query: &str) -> Vec { pub async fn search_with_url_query( index_scheduler: GuardedData, Data>, search_queue: web::Data, + personalization_service: web::Data, index_uid: web::Path, params: AwebQueryParameter, req: HttpRequest, @@ -342,9 +349,16 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; + // Extract personalization and query string before moving query + let personalize = query.personalize.take(); + let search_kind = search_kind(&query, 
index_scheduler.get_ref(), index_uid.to_string(), &index)?; let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors); + + // Save the query string for personalization if requested + let personalize_query = personalize.is_some().then(|| query.q.clone()).flatten(); + let permit = search_queue.try_get_search_permit().await?; let include_metadata = parse_include_metadata_header(&req); @@ -365,12 +379,24 @@ pub async fn search_with_url_query( .await; permit.drop().await; let search_result = search_result?; - if let Ok(ref search_result) = search_result { + if let Ok((search_result, _)) = search_result.as_ref() { aggregate.succeed(search_result); } analytics.publish(aggregate, &req); - let search_result = search_result?; + let (mut search_result, time_budget) = search_result?; + + // Apply personalization if requested + if let Some(personalize) = personalize.as_ref() { + search_result = personalization_service + .rerank_search_results( + search_result, + personalize, + personalize_query.as_deref(), + time_budget, + ) + .await?; + } debug!(request_uid = ?request_uid, returns = ?search_result, "Search get"); Ok(HttpResponse::Ok().json(search_result)) @@ -435,6 +461,7 @@ pub async fn search_with_url_query( pub async fn search_with_post( index_scheduler: GuardedData, Data>, search_queue: web::Data, + personalization_service: web::Data, index_uid: web::Path, params: AwebJson, req: HttpRequest, @@ -455,12 +482,18 @@ pub async fn search_with_post( let index = index_scheduler.index(&index_uid)?; + // Extract personalization and query string before moving query + let personalize = query.personalize.take(); + let search_kind = search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors); let include_metadata = parse_include_metadata_header(&req); + // Save the query string for personalization if requested + let personalize_query = personalize.is_some().then(|| 
query.q.clone()).flatten(); + let permit = search_queue.try_get_search_permit().await?; let search_result = tokio::task::spawn_blocking(move || { perform_search( @@ -479,7 +512,7 @@ pub async fn search_with_post( .await; permit.drop().await; let search_result = search_result?; - if let Ok(ref search_result) = search_result { + if let Ok((ref search_result, _)) = search_result { aggregate.succeed(search_result); if search_result.degraded { MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); @@ -487,7 +520,19 @@ pub async fn search_with_post( } analytics.publish(aggregate, &req); - let search_result = search_result?; + let (mut search_result, time_budget) = search_result?; + + // Apply personalization if requested + if let Some(personalize) = personalize.as_ref() { + search_result = personalization_service + .rerank_search_results( + search_result, + personalize, + personalize_query.as_deref(), + time_budget, + ) + .await?; + } debug!(request_uid = ?request_uid, returns = ?search_result, "Search post"); Ok(HttpResponse::Ok().json(search_result)) diff --git a/crates/meilisearch/src/routes/indexes/search_analytics.rs b/crates/meilisearch/src/routes/indexes/search_analytics.rs index 09045fc4a..549e8af6a 100644 --- a/crates/meilisearch/src/routes/indexes/search_analytics.rs +++ b/crates/meilisearch/src/routes/indexes/search_analytics.rs @@ -7,6 +7,7 @@ use serde_json::{json, Value}; use crate::aggregate_methods; use crate::analytics::{Aggregate, AggregateMethod}; +use crate::metrics::MEILISEARCH_PERSONALIZED_SEARCH_REQUESTS; use crate::search::{ SearchQuery, SearchResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, @@ -95,6 +96,9 @@ pub struct SearchAggregator { show_ranking_score_details: bool, ranking_score_threshold: bool, + // personalization + total_personalized: usize, + marker: std::marker::PhantomData, } @@ -129,6 +133,7 @@ impl SearchAggregator { hybrid, ranking_score_threshold, locales, + 
personalize, } = query; let mut ret = Self::default(); @@ -204,6 +209,12 @@ impl SearchAggregator { ret.locales = locales.iter().copied().collect(); } + // personalization + if personalize.is_some() { + ret.total_personalized = 1; + MEILISEARCH_PERSONALIZED_SEARCH_REQUESTS.inc(); + } + ret.highlight_pre_tag = *highlight_pre_tag != DEFAULT_HIGHLIGHT_PRE_TAG(); ret.highlight_post_tag = *highlight_post_tag != DEFAULT_HIGHLIGHT_POST_TAG(); ret.crop_marker = *crop_marker != DEFAULT_CROP_MARKER(); @@ -296,6 +307,7 @@ impl Aggregate for SearchAggregator { total_used_negative_operator, ranking_score_threshold, mut locales, + total_personalized, marker: _, } = *new; @@ -381,6 +393,9 @@ impl Aggregate for SearchAggregator { // locales self.locales.append(&mut locales); + // personalization + self.total_personalized = self.total_personalized.saturating_add(total_personalized); + self } @@ -426,6 +441,7 @@ impl Aggregate for SearchAggregator { total_used_negative_operator, ranking_score_threshold, locales, + total_personalized, marker: _, } = *self; @@ -499,6 +515,9 @@ impl Aggregate for SearchAggregator { "show_ranking_score_details": show_ranking_score_details, "ranking_score_threshold": ranking_score_threshold, }, + "personalization": { + "total_personalized": total_personalized, + }, }) } } diff --git a/crates/meilisearch/src/routes/multi_search.rs b/crates/meilisearch/src/routes/multi_search.rs index 4e833072a..e9e4140b2 100644 --- a/crates/meilisearch/src/routes/multi_search.rs +++ b/crates/meilisearch/src/routes/multi_search.rs @@ -146,6 +146,7 @@ pub struct SearchResults { pub async fn multi_search_with_post( index_scheduler: GuardedData, Data>, search_queue: Data, + personalization_service: web::Data, params: AwebJson, req: HttpRequest, analytics: web::Data, @@ -236,7 +237,7 @@ pub async fn multi_search_with_post( // changes. 
let search_results: Result<_, (ResponseError, usize)> = async { let mut search_results = Vec::with_capacity(queries.len()); - for (query_index, (index_uid, query, federation_options)) in queries + for (query_index, (index_uid, mut query, federation_options)) in queries .into_iter() .map(SearchQueryWithIndex::into_index_query_federation) .enumerate() @@ -269,6 +270,13 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; + // Extract personalization and query string before moving query + let personalize = query.personalize.take(); + + // Save the query string for personalization if requested + let personalize_query = + personalize.is_some().then(|| query.q.clone()).flatten(); + let index_uid_str = index_uid.to_string(); let search_kind = search_kind( @@ -280,7 +288,7 @@ pub async fn multi_search_with_post( .with_index(query_index)?; let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors); - let search_result = tokio::task::spawn_blocking(move || { + let (mut search_result, time_budget) = tokio::task::spawn_blocking(move || { perform_search( SearchParams { index_uid: index_uid_str.clone(), @@ -295,11 +303,25 @@ pub async fn multi_search_with_post( ) }) .await + .with_index(query_index)? 
.with_index(query_index)?; + // Apply personalization if requested + if let Some(personalize) = personalize.as_ref() { + search_result = personalization_service + .rerank_search_results( + search_result, + personalize, + personalize_query.as_deref(), + time_budget, + ) + .await + .with_index(query_index)?; + } + search_results.push(SearchResultWithIndex { index_uid: index_uid.into_inner(), - result: search_result.with_index(query_index)?, + result: search_result, }); } Ok(search_results) diff --git a/crates/meilisearch/src/routes/multi_search_analytics.rs b/crates/meilisearch/src/routes/multi_search_analytics.rs index c24875797..830f4a0e5 100644 --- a/crates/meilisearch/src/routes/multi_search_analytics.rs +++ b/crates/meilisearch/src/routes/multi_search_analytics.rs @@ -67,6 +67,7 @@ impl MultiSearchAggregator { hybrid: _, ranking_score_threshold: _, locales: _, + personalize: _, } in &federated_search.queries { if let Some(federation_options) = federation_options { diff --git a/crates/meilisearch/src/search/federated/perform.rs b/crates/meilisearch/src/search/federated/perform.rs index 8cd03cc75..7bc32d463 100644 --- a/crates/meilisearch/src/search/federated/perform.rs +++ b/crates/meilisearch/src/search/federated/perform.rs @@ -601,6 +601,10 @@ impl PartitionedQueries { .into()); } + if federated_query.has_personalize() { + return Err(MeilisearchHttpError::PersonalizationInFederatedQuery(query_index).into()); + } + let (index_uid, query, federation_options) = federated_query.into_index_query_federation(); let federation_options = federation_options.unwrap_or_default(); diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs index 4a09df8fa..1c4f84c42 100644 --- a/crates/meilisearch/src/search/mod.rs +++ b/crates/meilisearch/src/search/mod.rs @@ -59,6 +59,13 @@ pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "".to_string(); pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5); pub const 
INCLUDE_METADATA_HEADER: &str = "Meili-Include-Metadata"; +#[derive(Clone, Default, PartialEq, Deserr, ToSchema, Debug)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct Personalize { + #[deserr(error = DeserrJsonError)] + pub user_context: String, +} + #[derive(Clone, Default, PartialEq, Deserr, ToSchema)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct SearchQuery { @@ -122,6 +129,8 @@ pub struct SearchQuery { pub ranking_score_threshold: Option, #[deserr(default, error = DeserrJsonError)] pub locales: Option>, + #[deserr(default, error = DeserrJsonError, default)] + pub personalize: Option, } impl From for SearchQuery { @@ -169,6 +178,7 @@ impl From for SearchQuery { highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), crop_marker: DEFAULT_CROP_MARKER(), locales: None, + personalize: None, } } } @@ -250,6 +260,7 @@ impl fmt::Debug for SearchQuery { attributes_to_search_on, ranking_score_threshold, locales, + personalize, } = self; let mut debug = f.debug_struct("SearchQuery"); @@ -338,6 +349,10 @@ impl fmt::Debug for SearchQuery { debug.field("locales", &locales); } + if let Some(personalize) = personalize { + debug.field("personalize", &personalize); + } + debug.finish() } } @@ -543,6 +558,9 @@ pub struct SearchQueryWithIndex { pub ranking_score_threshold: Option, #[deserr(default, error = DeserrJsonError, default)] pub locales: Option>, + #[deserr(default, error = DeserrJsonError, default)] + #[serde(skip)] + pub personalize: Option, #[deserr(default)] pub federation_options: Option, @@ -567,6 +585,10 @@ impl SearchQueryWithIndex { self.facets.as_deref().filter(|v| !v.is_empty()) } + pub fn has_personalize(&self) -> bool { + self.personalize.is_some() + } + pub fn from_index_query_federation( index_uid: IndexUid, query: SearchQuery, @@ -600,6 +622,7 @@ impl SearchQueryWithIndex { attributes_to_search_on, ranking_score_threshold, locales, + personalize, } = query; 
SearchQueryWithIndex { @@ -631,6 +654,7 @@ impl SearchQueryWithIndex { attributes_to_search_on, ranking_score_threshold, locales, + personalize, federation_options, } } @@ -666,6 +690,7 @@ impl SearchQueryWithIndex { hybrid, ranking_score_threshold, locales, + personalize, } = self; ( index_uid, @@ -697,6 +722,7 @@ impl SearchQueryWithIndex { hybrid, ranking_score_threshold, locales, + personalize, // do not use ..Default::default() here, // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` }, @@ -1149,7 +1175,10 @@ pub struct SearchParams { pub include_metadata: bool, } -pub fn perform_search(params: SearchParams, index: &Index) -> Result { +pub fn perform_search( + params: SearchParams, + index: &Index, +) -> Result<(SearchResult, TimeBudget), ResponseError> { let SearchParams { index_uid, query, @@ -1168,7 +1197,7 @@ pub fn perform_search(params: SearchParams, index: &Index) -> Result Result Result