Merge pull request #5775 from meilisearch/experimental-search-personnalization

Experimental search personalization
This commit is contained in:
Louis Dureuil
2025-11-06 18:45:18 +00:00
committed by GitHub
19 changed files with 844 additions and 69 deletions

View File

@@ -258,6 +258,7 @@ InvalidIndexUid , InvalidRequest , BAD_REQU
InvalidMultiSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFacets , InvalidRequest , BAD_REQUEST ;
InvalidMultiSearchFacetsByIndex , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFacetsByIndex , InvalidRequest , BAD_REQUEST ;
InvalidMultiSearchFacetOrder , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFacetOrder , InvalidRequest , BAD_REQUEST ;
InvalidMultiSearchQueryPersonalization , InvalidRequest , BAD_REQUEST ;
InvalidMultiSearchFederated , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFederated , InvalidRequest , BAD_REQUEST ;
InvalidMultiSearchFederationOptions , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchFederationOptions , InvalidRequest , BAD_REQUEST ;
InvalidMultiSearchMaxValuesPerFacet , InvalidRequest , BAD_REQUEST ; InvalidMultiSearchMaxValuesPerFacet , InvalidRequest , BAD_REQUEST ;
@@ -315,6 +316,8 @@ InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQU
InvalidSimilarShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ; InvalidSimilarShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ;
InvalidSearchSort , InvalidRequest , BAD_REQUEST ; InvalidSearchSort , InvalidRequest , BAD_REQUEST ;
InvalidSearchDistinct , InvalidRequest , BAD_REQUEST ; InvalidSearchDistinct , InvalidRequest , BAD_REQUEST ;
InvalidSearchPersonalize , InvalidRequest , BAD_REQUEST ;
InvalidSearchPersonalizeUserContext , InvalidRequest , BAD_REQUEST ;
InvalidSearchMediaAndVector , InvalidRequest , BAD_REQUEST ; InvalidSearchMediaAndVector , InvalidRequest , BAD_REQUEST ;
InvalidSettingsDisplayedAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsDisplayedAttributes , InvalidRequest , BAD_REQUEST ;
InvalidSettingsDistinctAttribute , InvalidRequest , BAD_REQUEST ; InvalidSettingsDistinctAttribute , InvalidRequest , BAD_REQUEST ;
@@ -682,6 +685,18 @@ impl fmt::Display for deserr_codes::InvalidNetworkSearchApiKey {
} }
} }
impl fmt::Display for deserr_codes::InvalidSearchPersonalize {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "the value of `personalize` is invalid, expected a JSON object with `userContext` string.")
}
}
impl fmt::Display for deserr_codes::InvalidSearchPersonalizeUserContext {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "the value of `userContext` is invalid, expected a string.")
}
}
#[macro_export] #[macro_export]
macro_rules! internal_error { macro_rules! internal_error {
($target:ty : $($other:path), *) => { ($target:ty : $($other:path), *) => {

View File

@@ -208,6 +208,7 @@ struct Infos {
experimental_no_edition_2024_for_prefix_post_processing: bool, experimental_no_edition_2024_for_prefix_post_processing: bool,
experimental_no_edition_2024_for_facet_post_processing: bool, experimental_no_edition_2024_for_facet_post_processing: bool,
experimental_vector_store_setting: bool, experimental_vector_store_setting: bool,
experimental_personalization: bool,
gpu_enabled: bool, gpu_enabled: bool,
db_path: bool, db_path: bool,
import_dump: bool, import_dump: bool,
@@ -286,6 +287,7 @@ impl Infos {
indexer_options, indexer_options,
config_file_path, config_file_path,
no_analytics: _, no_analytics: _,
experimental_personalization_api_key,
s3_snapshot_options, s3_snapshot_options,
} = options; } = options;
@@ -374,6 +376,7 @@ impl Infos {
experimental_no_edition_2024_for_settings, experimental_no_edition_2024_for_settings,
experimental_no_edition_2024_for_prefix_post_processing, experimental_no_edition_2024_for_prefix_post_processing,
experimental_no_edition_2024_for_facet_post_processing, experimental_no_edition_2024_for_facet_post_processing,
experimental_personalization: experimental_personalization_api_key.is_some(),
} }
} }
} }

View File

@@ -38,6 +38,8 @@ pub enum MeilisearchHttpError {
PaginationInFederatedQuery(usize, &'static str), PaginationInFederatedQuery(usize, &'static str),
#[error("Inside `.queries[{0}]`: Using facet options is not allowed in federated queries.\n - Hint: remove `facets` from query #{0} or remove `federation` from the request\n - Hint: pass `federation.facetsByIndex.{1}: {2:?}` for facets in federated search")] #[error("Inside `.queries[{0}]`: Using facet options is not allowed in federated queries.\n - Hint: remove `facets` from query #{0} or remove `federation` from the request\n - Hint: pass `federation.facetsByIndex.{1}: {2:?}` for facets in federated search")]
FacetsInFederatedQuery(usize, String, Vec<String>), FacetsInFederatedQuery(usize, String, Vec<String>),
#[error("Inside `.queries[{0}]`: Using `.personalize` is not allowed in federated queries.\n - Hint: remove `personalize` from query #{0} or remove `federation` from the request")]
PersonalizationInFederatedQuery(usize),
#[error("Inconsistent order for values in facet `{facet}`: index `{previous_uid}` orders {previous_facet_order}, but index `{current_uid}` orders {index_facet_order}.\n - Hint: Remove `federation.mergeFacets` or change `faceting.sortFacetValuesBy` to be consistent in settings.")] #[error("Inconsistent order for values in facet `{facet}`: index `{previous_uid}` orders {previous_facet_order}, but index `{current_uid}` orders {index_facet_order}.\n - Hint: Remove `federation.mergeFacets` or change `faceting.sortFacetValuesBy` to be consistent in settings.")]
InconsistentFacetOrder { InconsistentFacetOrder {
facet: String, facet: String,
@@ -137,6 +139,9 @@ impl ErrorCode for MeilisearchHttpError {
MeilisearchHttpError::InconsistentFacetOrder { .. } => { MeilisearchHttpError::InconsistentFacetOrder { .. } => {
Code::InvalidMultiSearchFacetOrder Code::InvalidMultiSearchFacetOrder
} }
MeilisearchHttpError::PersonalizationInFederatedQuery(_) => {
Code::InvalidMultiSearchQueryPersonalization
}
MeilisearchHttpError::InconsistentOriginHeaders { .. } => { MeilisearchHttpError::InconsistentOriginHeaders { .. } => {
Code::InconsistentDocumentChangeHeaders Code::InconsistentDocumentChangeHeaders
} }

View File

@@ -11,6 +11,7 @@ pub mod middleware;
pub mod option; pub mod option;
#[cfg(test)] #[cfg(test)]
mod option_test; mod option_test;
pub mod personalization;
pub mod routes; pub mod routes;
pub mod search; pub mod search;
pub mod search_queue; pub mod search_queue;
@@ -58,6 +59,7 @@ use tracing::{error, info_span};
use tracing_subscriber::filter::Targets; use tracing_subscriber::filter::Targets;
use crate::error::MeilisearchHttpError; use crate::error::MeilisearchHttpError;
use crate::personalization::PersonalizationService;
/// Default number of simultaneously opened indexes. /// Default number of simultaneously opened indexes.
/// ///
@@ -128,12 +130,8 @@ pub type LogStderrType = tracing_subscriber::filter::Filtered<
>; >;
pub fn create_app( pub fn create_app(
index_scheduler: Data<IndexScheduler>, services: ServicesData,
auth_controller: Data<AuthController>,
search_queue: Data<SearchQueue>,
opt: Opt, opt: Opt,
logs: (LogRouteHandle, LogStderrHandle),
analytics: Data<Analytics>,
enable_dashboard: bool, enable_dashboard: bool,
) -> actix_web::App< ) -> actix_web::App<
impl ServiceFactory< impl ServiceFactory<
@@ -145,17 +143,7 @@ pub fn create_app(
>, >,
> { > {
let app = actix_web::App::new() let app = actix_web::App::new()
.configure(|s| { .configure(|s| configure_data(s, services, &opt))
configure_data(
s,
index_scheduler.clone(),
auth_controller.clone(),
search_queue.clone(),
&opt,
logs,
analytics.clone(),
)
})
.configure(routes::configure) .configure(routes::configure)
.configure(|s| dashboard(s, enable_dashboard)); .configure(|s| dashboard(s, enable_dashboard));
@@ -690,23 +678,26 @@ fn import_dump(
Ok(index_scheduler_dump.finish()?) Ok(index_scheduler_dump.finish()?)
} }
pub fn configure_data( pub fn configure_data(config: &mut web::ServiceConfig, services: ServicesData, opt: &Opt) {
config: &mut web::ServiceConfig, let ServicesData {
index_scheduler: Data<IndexScheduler>, index_scheduler,
auth: Data<AuthController>, auth,
search_queue: Data<SearchQueue>, search_queue,
opt: &Opt, personalization_service,
(logs_route, logs_stderr): (LogRouteHandle, LogStderrHandle), logs_route_handle,
analytics: Data<Analytics>, logs_stderr_handle,
) { analytics,
} = services;
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize; let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
config config
.app_data(index_scheduler) .app_data(index_scheduler)
.app_data(auth) .app_data(auth)
.app_data(search_queue) .app_data(search_queue)
.app_data(analytics) .app_data(analytics)
.app_data(web::Data::new(logs_route)) .app_data(personalization_service)
.app_data(web::Data::new(logs_stderr)) .app_data(logs_route_handle)
.app_data(logs_stderr_handle)
.app_data(web::Data::new(opt.clone())) .app_data(web::Data::new(opt.clone()))
.app_data( .app_data(
web::JsonConfig::default() web::JsonConfig::default()
@@ -767,3 +758,14 @@ pub fn dashboard(config: &mut web::ServiceConfig, enable_frontend: bool) {
pub fn dashboard(config: &mut web::ServiceConfig, _enable_frontend: bool) { pub fn dashboard(config: &mut web::ServiceConfig, _enable_frontend: bool) {
config.service(web::resource("/").route(web::get().to(routes::running))); config.service(web::resource("/").route(web::get().to(routes::running)));
} }
#[derive(Clone)]
pub struct ServicesData {
pub index_scheduler: Data<IndexScheduler>,
pub auth: Data<AuthController>,
pub search_queue: Data<SearchQueue>,
pub personalization_service: Data<PersonalizationService>,
pub logs_route_handle: Data<LogRouteHandle>,
pub logs_stderr_handle: Data<LogStderrHandle>,
pub analytics: Data<Analytics>,
}

View File

@@ -14,10 +14,11 @@ use index_scheduler::IndexScheduler;
use is_terminal::IsTerminal; use is_terminal::IsTerminal;
use meilisearch::analytics::Analytics; use meilisearch::analytics::Analytics;
use meilisearch::option::LogMode; use meilisearch::option::LogMode;
use meilisearch::personalization::PersonalizationService;
use meilisearch::search_queue::SearchQueue; use meilisearch::search_queue::SearchQueue;
use meilisearch::{ use meilisearch::{
analytics, create_app, setup_meilisearch, LogRouteHandle, LogRouteType, LogStderrHandle, analytics, create_app, setup_meilisearch, LogRouteHandle, LogRouteType, LogStderrHandle,
LogStderrType, Opt, SubscriberForSecondLayer, LogStderrType, Opt, ServicesData, SubscriberForSecondLayer,
}; };
use meilisearch_auth::{generate_master_key, AuthController, MASTER_KEY_MIN_SIZE}; use meilisearch_auth::{generate_master_key, AuthController, MASTER_KEY_MIN_SIZE};
use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor}; use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor};
@@ -152,8 +153,15 @@ async fn run_http(
let enable_dashboard = &opt.env == "development"; let enable_dashboard = &opt.env == "development";
let opt_clone = opt.clone(); let opt_clone = opt.clone();
let index_scheduler = Data::from(index_scheduler); let index_scheduler = Data::from(index_scheduler);
let auth_controller = Data::from(auth_controller); let auth = Data::from(auth_controller);
let analytics = Data::from(analytics); let analytics = Data::from(analytics);
// Create personalization service with API key from options
let personalization_service = Data::new(
opt.experimental_personalization_api_key
.clone()
.map(PersonalizationService::cohere)
.unwrap_or_else(PersonalizationService::disabled),
);
let search_queue = SearchQueue::new( let search_queue = SearchQueue::new(
opt.experimental_search_queue_size, opt.experimental_search_queue_size,
available_parallelism() available_parallelism()
@@ -165,21 +173,25 @@ async fn run_http(
usize::from(opt.experimental_drop_search_after) as u64 usize::from(opt.experimental_drop_search_after) as u64
)); ));
let search_queue = Data::new(search_queue); let search_queue = Data::new(search_queue);
let (logs_route_handle, logs_stderr_handle) = logs;
let logs_route_handle = Data::new(logs_route_handle);
let logs_stderr_handle = Data::new(logs_stderr_handle);
let http_server = HttpServer::new(move || { let services = ServicesData {
create_app( index_scheduler,
index_scheduler.clone(), auth,
auth_controller.clone(), search_queue,
search_queue.clone(), personalization_service,
opt.clone(), logs_route_handle,
logs.clone(), logs_stderr_handle,
analytics.clone(), analytics,
enable_dashboard, };
)
}) let http_server =
// Disable signals allows the server to terminate immediately when a user enter CTRL-C HttpServer::new(move || create_app(services.clone(), opt.clone(), enable_dashboard))
.disable_signals() // Disable signals allows the server to terminate immediately when a user enter CTRL-C
.keep_alive(KeepAlive::Os); .disable_signals()
.keep_alive(KeepAlive::Os);
if let Some(config) = opt_clone.get_ssl_config()? { if let Some(config) = opt_clone.get_ssl_config()? {
http_server.bind_rustls_0_23(opt_clone.http_addr, config)?.run().await?; http_server.bind_rustls_0_23(opt_clone.http_addr, config)?.run().await?;

View File

@@ -114,4 +114,9 @@ lazy_static! {
"Meilisearch Task Queue Size Until Stop Registering", "Meilisearch Task Queue Size Until Stop Registering",
)) ))
.expect("Can't create a metric"); .expect("Can't create a metric");
pub static ref MEILISEARCH_PERSONALIZED_SEARCH_REQUESTS: IntGauge = register_int_gauge!(opts!(
"meilisearch_personalized_search_requests",
"Meilisearch number of search requests with personalization"
))
.expect("Can't create a metric");
} }

View File

@@ -75,6 +75,8 @@ const MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES: &str =
const MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION: &str = "MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION"; const MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION: &str = "MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION";
const MEILI_EXPERIMENTAL_NO_EDITION_2024_FOR_DUMPS: &str = const MEILI_EXPERIMENTAL_NO_EDITION_2024_FOR_DUMPS: &str =
"MEILI_EXPERIMENTAL_NO_EDITION_2024_FOR_DUMPS"; "MEILI_EXPERIMENTAL_NO_EDITION_2024_FOR_DUMPS";
const MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY: &str =
"MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY";
// Related to S3 snapshots // Related to S3 snapshots
const MEILI_S3_BUCKET_URL: &str = "MEILI_S3_BUCKET_URL"; const MEILI_S3_BUCKET_URL: &str = "MEILI_S3_BUCKET_URL";
@@ -494,6 +496,12 @@ pub struct Opt {
#[serde(default)] #[serde(default)]
pub experimental_no_snapshot_compaction: bool, pub experimental_no_snapshot_compaction: bool,
/// Experimental personalization API key feature.
///
/// Sets the API key for personalization features.
#[clap(long, env = MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY)]
pub experimental_personalization_api_key: Option<String>,
#[serde(flatten)] #[serde(flatten)]
#[clap(flatten)] #[clap(flatten)]
pub indexer_options: IndexerOpts, pub indexer_options: IndexerOpts,
@@ -603,6 +611,7 @@ impl Opt {
experimental_limit_batched_tasks_total_size, experimental_limit_batched_tasks_total_size,
experimental_embedding_cache_entries, experimental_embedding_cache_entries,
experimental_no_snapshot_compaction, experimental_no_snapshot_compaction,
experimental_personalization_api_key,
s3_snapshot_options, s3_snapshot_options,
} = self; } = self;
export_to_env_if_not_present(MEILI_DB_PATH, db_path); export_to_env_if_not_present(MEILI_DB_PATH, db_path);
@@ -704,6 +713,12 @@ impl Opt {
MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION, MEILI_EXPERIMENTAL_NO_SNAPSHOT_COMPACTION,
experimental_no_snapshot_compaction.to_string(), experimental_no_snapshot_compaction.to_string(),
); );
if let Some(experimental_personalization_api_key) = experimental_personalization_api_key {
export_to_env_if_not_present(
MEILI_EXPERIMENTAL_PERSONALIZATION_API_KEY,
experimental_personalization_api_key,
);
}
indexer_options.export_to_env(); indexer_options.export_to_env();
if let Some(s3_snapshot_options) = s3_snapshot_options { if let Some(s3_snapshot_options) = s3_snapshot_options {
#[cfg(not(unix))] #[cfg(not(unix))]

View File

@@ -0,0 +1,366 @@
use crate::search::{Personalize, SearchResult};
use meilisearch_types::{
error::{Code, ErrorCode, ResponseError},
milli::TimeBudget,
};
use rand::Rng;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{debug, info, warn};
const COHERE_API_URL: &str = "https://api.cohere.ai/v1/rerank";
const MAX_RETRIES: u32 = 10;
#[derive(Debug, thiserror::Error)]
enum PersonalizationError {
#[error("Personalization service: HTTP request failed: {0}")]
Request(#[from] reqwest::Error),
#[error("Personalization service: Failed to parse response: {0}")]
Parse(String),
#[error("Personalization service: Cohere API error: {0}")]
Api(String),
#[error("Personalization service: Unauthorized: invalid API key")]
Unauthorized,
#[error("Personalization service: Rate limited: too many requests")]
RateLimited,
#[error("Personalization service: Bad request: {0}")]
BadRequest(String),
#[error("Personalization service: Internal server error: {0}")]
InternalServerError(String),
#[error("Personalization service: Network error: {0}")]
Network(String),
#[error("Personalization service: Deadline exceeded")]
DeadlineExceeded,
#[error(transparent)]
FeatureNotEnabled(#[from] index_scheduler::error::FeatureNotEnabledError),
}
impl ErrorCode for PersonalizationError {
fn error_code(&self) -> Code {
match self {
PersonalizationError::FeatureNotEnabled { .. } => Code::FeatureNotEnabled,
PersonalizationError::Unauthorized => Code::RemoteInvalidApiKey,
PersonalizationError::RateLimited => Code::TooManySearchRequests,
PersonalizationError::BadRequest(_) => Code::RemoteBadRequest,
PersonalizationError::InternalServerError(_) => Code::RemoteRemoteError,
PersonalizationError::Network(_) | PersonalizationError::Request(_) => {
Code::RemoteCouldNotSendRequest
}
PersonalizationError::Parse(_) | PersonalizationError::Api(_) => {
Code::RemoteBadResponse
}
PersonalizationError::DeadlineExceeded => Code::Internal, // should not be returned to the client
}
}
}
pub struct CohereService {
client: Client,
api_key: String,
}
impl CohereService {
pub fn new(api_key: String) -> Self {
info!("Personalization service initialized with Cohere API");
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client");
Self { client, api_key }
}
pub async fn rerank_search_results(
&self,
search_result: SearchResult,
personalize: &Personalize,
query: Option<&str>,
time_budget: TimeBudget,
) -> Result<SearchResult, ResponseError> {
if time_budget.exceeded() {
warn!("Could not rerank due to deadline");
// If the deadline is exceeded, return the original search result instead of an error
return Ok(search_result);
}
// Extract user context from personalization
let user_context = personalize.user_context.as_str();
// Build the prompt by merging query and user context
let prompt = match query {
Some(q) => format!("User Context: {user_context}\nQuery: {q}"),
None => format!("User Context: {user_context}"),
};
// Extract documents for reranking
let documents: Vec<String> = search_result
.hits
.iter()
.map(|hit| {
// Convert the document to a string representation for reranking
serde_json::to_string(&hit.document).unwrap_or_else(|_| "{}".to_string())
})
.collect();
if documents.is_empty() {
return Ok(search_result);
}
// Call Cohere's rerank API with retry logic
let reranked_indices =
match self.call_rerank_with_retry(&prompt, &documents, time_budget).await {
Ok(indices) => indices,
Err(PersonalizationError::DeadlineExceeded) => {
// If the deadline is exceeded, return the original search result instead of an error
return Ok(search_result);
}
Err(e) => return Err(e.into()),
};
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
// Reorder the hits based on Cohere's reranking
let mut reranked_hits = Vec::new();
for index in reranked_indices.iter() {
if let Some(hit) = search_result.hits.get(*index) {
reranked_hits.push(hit.clone());
}
}
Ok(SearchResult { hits: reranked_hits, ..search_result })
}
async fn call_rerank_with_retry(
&self,
query: &str,
documents: &[String],
time_budget: TimeBudget,
) -> Result<Vec<usize>, PersonalizationError> {
let request_body = CohereRerankRequest {
query: query.to_string(),
documents: documents.to_vec(),
model: "rerank-english-v3.0".to_string(),
};
// Retry loop similar to vector extraction
for attempt in 0..MAX_RETRIES {
let response_result = self.send_rerank_request(&request_body).await;
let retry_duration = match self.handle_response(response_result).await {
Ok(indices) => return Ok(indices),
Err(retry) => {
warn!("Cohere rerank attempt #{} failed: {}", attempt, retry.error);
if time_budget.exceeded() {
warn!("Could not rerank due to deadline");
return Err(PersonalizationError::DeadlineExceeded);
} else {
match retry.into_duration(attempt) {
Ok(d) => d,
Err(error) => return Err(error),
}
}
}
};
// randomly up to double the retry duration
let retry_duration = retry_duration
+ rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration);
warn!("Retrying after {}ms", retry_duration.as_millis());
tokio::time::sleep(retry_duration).await;
}
// Final attempt without retry
let response_result = self.send_rerank_request(&request_body).await;
match self.handle_response(response_result).await {
Ok(indices) => Ok(indices),
Err(retry) => Err(retry.into_error()),
}
}
async fn send_rerank_request(
&self,
request_body: &CohereRerankRequest,
) -> Result<reqwest::Response, reqwest::Error> {
self.client
.post(COHERE_API_URL)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(request_body)
.send()
.await
}
async fn handle_response(
&self,
response_result: Result<reqwest::Response, reqwest::Error>,
) -> Result<Vec<usize>, Retry> {
let response = match response_result {
Ok(r) => r,
Err(e) if e.is_timeout() => {
return Err(Retry::retry_later(PersonalizationError::Network(format!(
"Request timeout: {}",
e
))));
}
Err(e) => {
return Err(Retry::retry_later(PersonalizationError::Network(format!(
"Network error: {}",
e
))));
}
};
let status = response.status();
let status_code = status.as_u16();
if status.is_success() {
let rerank_response: CohereRerankResponse = match response.json().await {
Ok(r) => r,
Err(e) => {
return Err(Retry::retry_later(PersonalizationError::Parse(format!(
"Failed to parse response: {}",
e
))));
}
};
// Extract indices from rerank results
let indices: Vec<usize> =
rerank_response.results.iter().map(|result| result.index as usize).collect();
return Ok(indices);
}
// Handle error status codes
let error_body = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
let retry = match status_code {
401 => Retry::give_up(PersonalizationError::Unauthorized),
429 => Retry::rate_limited(PersonalizationError::RateLimited),
400 => Retry::give_up(PersonalizationError::BadRequest(error_body)),
500..=599 => Retry::retry_later(PersonalizationError::InternalServerError(format!(
"Status {}: {}",
status_code, error_body
))),
402..=499 => Retry::give_up(PersonalizationError::Api(format!(
"Status {}: {}",
status_code, error_body
))),
_ => Retry::retry_later(PersonalizationError::Api(format!(
"Unexpected status {}: {}",
status_code, error_body
))),
};
Err(retry)
}
}
#[derive(Serialize)]
struct CohereRerankRequest {
query: String,
documents: Vec<String>,
model: String,
}
#[derive(Deserialize)]
struct CohereRerankResponse {
results: Vec<CohereRerankResult>,
}
#[derive(Deserialize)]
struct CohereRerankResult {
index: u32,
}
// Retry strategy similar to vector extraction
struct Retry {
error: PersonalizationError,
strategy: RetryStrategy,
}
enum RetryStrategy {
GiveUp,
Retry,
RetryAfterRateLimit,
}
impl Retry {
fn give_up(error: PersonalizationError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
fn retry_later(error: PersonalizationError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
fn rate_limited(error: PersonalizationError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
fn into_duration(self, attempt: u32) -> Result<Duration, PersonalizationError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => {
// Exponential backoff: 10^attempt milliseconds
Ok(Duration::from_millis((10u64).pow(attempt)))
}
RetryStrategy::RetryAfterRateLimit => {
// Longer backoff for rate limits: 100ms + exponential
Ok(Duration::from_millis(100 + (10u64).pow(attempt)))
}
}
}
fn into_error(self) -> PersonalizationError {
self.error
}
}
pub enum PersonalizationService {
Cohere(CohereService),
Disabled,
}
impl PersonalizationService {
pub fn cohere(api_key: String) -> Self {
// If the API key is empty, consider the personalization service as disabled
if api_key.trim().is_empty() {
Self::disabled()
} else {
Self::Cohere(CohereService::new(api_key))
}
}
pub fn disabled() -> Self {
debug!("Personalization service disabled");
Self::Disabled
}
pub async fn rerank_search_results(
&self,
search_result: SearchResult,
personalize: &Personalize,
query: Option<&str>,
time_budget: TimeBudget,
) -> Result<SearchResult, ResponseError> {
match self {
Self::Cohere(cohere_service) => {
cohere_service
.rerank_search_results(search_result, personalize, query, time_budget)
.await
}
Self::Disabled => Err(PersonalizationError::FeatureNotEnabled(
index_scheduler::error::FeatureNotEnabledError {
disabled_action: "reranking search results",
feature: "personalization",
issue_link: "https://github.com/orgs/meilisearch/discussions/866",
},
)
.into()),
}
}
}

View File

@@ -343,6 +343,7 @@ impl From<FacetSearchQuery> for SearchQuery {
hybrid, hybrid,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
personalize: None,
} }
} }
} }

View File

@@ -24,9 +24,9 @@ use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::routes::indexes::search_analytics::{SearchAggregator, SearchGET, SearchPOST}; use crate::routes::indexes::search_analytics::{SearchAggregator, SearchGET, SearchPOST};
use crate::routes::parse_include_metadata_header; use crate::routes::parse_include_metadata_header;
use crate::search::{ use crate::search::{
add_search_rules, perform_search, HybridQuery, MatchingStrategy, RankingScoreThreshold, add_search_rules, perform_search, HybridQuery, MatchingStrategy, Personalize,
RetrieveVectors, SearchKind, SearchParams, SearchQuery, SearchResult, SemanticRatio, RankingScoreThreshold, RetrieveVectors, SearchKind, SearchParams, SearchQuery, SearchResult,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
}; };
use crate::search_queue::SearchQueue; use crate::search_queue::SearchQueue;
@@ -134,6 +134,8 @@ pub struct SearchQueryGet {
#[deserr(default, error = DeserrQueryParamError<InvalidSearchLocales>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchLocales>)]
#[param(value_type = Vec<Locale>, explode = false)] #[param(value_type = Vec<Locale>, explode = false)]
pub locales: Option<CS<Locale>>, pub locales: Option<CS<Locale>>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchPersonalizeUserContext>)]
pub personalize_user_context: Option<String>,
} }
#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] #[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)]
@@ -205,6 +207,9 @@ impl TryFrom<SearchQueryGet> for SearchQuery {
)); ));
} }
let personalize =
other.personalize_user_context.map(|user_context| Personalize { user_context });
Ok(Self { Ok(Self {
q: other.q, q: other.q,
// `media` not supported for `GET` // `media` not supported for `GET`
@@ -234,6 +239,7 @@ impl TryFrom<SearchQueryGet> for SearchQuery {
hybrid, hybrid,
ranking_score_threshold: other.ranking_score_threshold.map(|o| o.0), ranking_score_threshold: other.ranking_score_threshold.map(|o| o.0),
locales: other.locales.map(|o| o.into_iter().collect()), locales: other.locales.map(|o| o.into_iter().collect()),
personalize,
}) })
} }
} }
@@ -322,6 +328,7 @@ pub fn fix_sort_query_parameters(sort_query: &str) -> Vec<String> {
pub async fn search_with_url_query( pub async fn search_with_url_query(
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
personalization_service: web::Data<crate::personalization::PersonalizationService>,
index_uid: web::Path<String>, index_uid: web::Path<String>,
params: AwebQueryParameter<SearchQueryGet, DeserrQueryParamError>, params: AwebQueryParameter<SearchQueryGet, DeserrQueryParamError>,
req: HttpRequest, req: HttpRequest,
@@ -342,9 +349,16 @@ pub async fn search_with_url_query(
let index = index_scheduler.index(&index_uid)?; let index = index_scheduler.index(&index_uid)?;
// Extract personalization and query string before moving query
let personalize = query.personalize.take();
let search_kind = let search_kind =
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors); let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors);
// Save the query string for personalization if requested
let personalize_query = personalize.is_some().then(|| query.q.clone()).flatten();
let permit = search_queue.try_get_search_permit().await?; let permit = search_queue.try_get_search_permit().await?;
let include_metadata = parse_include_metadata_header(&req); let include_metadata = parse_include_metadata_header(&req);
@@ -365,12 +379,24 @@ pub async fn search_with_url_query(
.await; .await;
permit.drop().await; permit.drop().await;
let search_result = search_result?; let search_result = search_result?;
if let Ok(ref search_result) = search_result { if let Ok((search_result, _)) = search_result.as_ref() {
aggregate.succeed(search_result); aggregate.succeed(search_result);
} }
analytics.publish(aggregate, &req); analytics.publish(aggregate, &req);
let search_result = search_result?; let (mut search_result, time_budget) = search_result?;
// Apply personalization if requested
if let Some(personalize) = personalize.as_ref() {
search_result = personalization_service
.rerank_search_results(
search_result,
personalize,
personalize_query.as_deref(),
time_budget,
)
.await?;
}
debug!(request_uid = ?request_uid, returns = ?search_result, "Search get"); debug!(request_uid = ?request_uid, returns = ?search_result, "Search get");
Ok(HttpResponse::Ok().json(search_result)) Ok(HttpResponse::Ok().json(search_result))
@@ -435,6 +461,7 @@ pub async fn search_with_url_query(
pub async fn search_with_post( pub async fn search_with_post(
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
personalization_service: web::Data<crate::personalization::PersonalizationService>,
index_uid: web::Path<String>, index_uid: web::Path<String>,
params: AwebJson<SearchQuery, DeserrJsonError>, params: AwebJson<SearchQuery, DeserrJsonError>,
req: HttpRequest, req: HttpRequest,
@@ -455,12 +482,18 @@ pub async fn search_with_post(
let index = index_scheduler.index(&index_uid)?; let index = index_scheduler.index(&index_uid)?;
// Extract personalization and query string before moving query
let personalize = query.personalize.take();
let search_kind = let search_kind =
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors); let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors);
let include_metadata = parse_include_metadata_header(&req); let include_metadata = parse_include_metadata_header(&req);
// Save the query string for personalization if requested
let personalize_query = personalize.is_some().then(|| query.q.clone()).flatten();
let permit = search_queue.try_get_search_permit().await?; let permit = search_queue.try_get_search_permit().await?;
let search_result = tokio::task::spawn_blocking(move || { let search_result = tokio::task::spawn_blocking(move || {
perform_search( perform_search(
@@ -479,7 +512,7 @@ pub async fn search_with_post(
.await; .await;
permit.drop().await; permit.drop().await;
let search_result = search_result?; let search_result = search_result?;
if let Ok(ref search_result) = search_result { if let Ok((ref search_result, _)) = search_result {
aggregate.succeed(search_result); aggregate.succeed(search_result);
if search_result.degraded { if search_result.degraded {
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
@@ -487,7 +520,19 @@ pub async fn search_with_post(
} }
analytics.publish(aggregate, &req); analytics.publish(aggregate, &req);
let search_result = search_result?; let (mut search_result, time_budget) = search_result?;
// Apply personalization if requested
if let Some(personalize) = personalize.as_ref() {
search_result = personalization_service
.rerank_search_results(
search_result,
personalize,
personalize_query.as_deref(),
time_budget,
)
.await?;
}
debug!(request_uid = ?request_uid, returns = ?search_result, "Search post"); debug!(request_uid = ?request_uid, returns = ?search_result, "Search post");
Ok(HttpResponse::Ok().json(search_result)) Ok(HttpResponse::Ok().json(search_result))

View File

@@ -7,6 +7,7 @@ use serde_json::{json, Value};
use crate::aggregate_methods; use crate::aggregate_methods;
use crate::analytics::{Aggregate, AggregateMethod}; use crate::analytics::{Aggregate, AggregateMethod};
use crate::metrics::MEILISEARCH_PERSONALIZED_SEARCH_REQUESTS;
use crate::search::{ use crate::search::{
SearchQuery, SearchResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, SearchQuery, SearchResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER,
DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT,
@@ -95,6 +96,9 @@ pub struct SearchAggregator<Method: AggregateMethod> {
show_ranking_score_details: bool, show_ranking_score_details: bool,
ranking_score_threshold: bool, ranking_score_threshold: bool,
// personalization
total_personalized: usize,
marker: std::marker::PhantomData<Method>, marker: std::marker::PhantomData<Method>,
} }
@@ -129,6 +133,7 @@ impl<Method: AggregateMethod> SearchAggregator<Method> {
hybrid, hybrid,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
personalize,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@@ -204,6 +209,12 @@ impl<Method: AggregateMethod> SearchAggregator<Method> {
ret.locales = locales.iter().copied().collect(); ret.locales = locales.iter().copied().collect();
} }
// personalization
if personalize.is_some() {
ret.total_personalized = 1;
MEILISEARCH_PERSONALIZED_SEARCH_REQUESTS.inc();
}
ret.highlight_pre_tag = *highlight_pre_tag != DEFAULT_HIGHLIGHT_PRE_TAG(); ret.highlight_pre_tag = *highlight_pre_tag != DEFAULT_HIGHLIGHT_PRE_TAG();
ret.highlight_post_tag = *highlight_post_tag != DEFAULT_HIGHLIGHT_POST_TAG(); ret.highlight_post_tag = *highlight_post_tag != DEFAULT_HIGHLIGHT_POST_TAG();
ret.crop_marker = *crop_marker != DEFAULT_CROP_MARKER(); ret.crop_marker = *crop_marker != DEFAULT_CROP_MARKER();
@@ -296,6 +307,7 @@ impl<Method: AggregateMethod> Aggregate for SearchAggregator<Method> {
total_used_negative_operator, total_used_negative_operator,
ranking_score_threshold, ranking_score_threshold,
mut locales, mut locales,
total_personalized,
marker: _, marker: _,
} = *new; } = *new;
@@ -381,6 +393,9 @@ impl<Method: AggregateMethod> Aggregate for SearchAggregator<Method> {
// locales // locales
self.locales.append(&mut locales); self.locales.append(&mut locales);
// personalization
self.total_personalized = self.total_personalized.saturating_add(total_personalized);
self self
} }
@@ -426,6 +441,7 @@ impl<Method: AggregateMethod> Aggregate for SearchAggregator<Method> {
total_used_negative_operator, total_used_negative_operator,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
total_personalized,
marker: _, marker: _,
} = *self; } = *self;
@@ -499,6 +515,9 @@ impl<Method: AggregateMethod> Aggregate for SearchAggregator<Method> {
"show_ranking_score_details": show_ranking_score_details, "show_ranking_score_details": show_ranking_score_details,
"ranking_score_threshold": ranking_score_threshold, "ranking_score_threshold": ranking_score_threshold,
}, },
"personalization": {
"total_personalized": total_personalized,
},
}) })
} }
} }

View File

@@ -146,6 +146,7 @@ pub struct SearchResults {
pub async fn multi_search_with_post( pub async fn multi_search_with_post(
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
search_queue: Data<SearchQueue>, search_queue: Data<SearchQueue>,
personalization_service: web::Data<crate::personalization::PersonalizationService>,
params: AwebJson<FederatedSearch, DeserrJsonError>, params: AwebJson<FederatedSearch, DeserrJsonError>,
req: HttpRequest, req: HttpRequest,
analytics: web::Data<Analytics>, analytics: web::Data<Analytics>,
@@ -236,7 +237,7 @@ pub async fn multi_search_with_post(
// changes. // changes.
let search_results: Result<_, (ResponseError, usize)> = async { let search_results: Result<_, (ResponseError, usize)> = async {
let mut search_results = Vec::with_capacity(queries.len()); let mut search_results = Vec::with_capacity(queries.len());
for (query_index, (index_uid, query, federation_options)) in queries for (query_index, (index_uid, mut query, federation_options)) in queries
.into_iter() .into_iter()
.map(SearchQueryWithIndex::into_index_query_federation) .map(SearchQueryWithIndex::into_index_query_federation)
.enumerate() .enumerate()
@@ -269,6 +270,13 @@ pub async fn multi_search_with_post(
}) })
.with_index(query_index)?; .with_index(query_index)?;
// Extract personalization and query string before moving query
let personalize = query.personalize.take();
// Save the query string for personalization if requested
let personalize_query =
personalize.is_some().then(|| query.q.clone()).flatten();
let index_uid_str = index_uid.to_string(); let index_uid_str = index_uid.to_string();
let search_kind = search_kind( let search_kind = search_kind(
@@ -280,7 +288,7 @@ pub async fn multi_search_with_post(
.with_index(query_index)?; .with_index(query_index)?;
let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors); let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors);
let search_result = tokio::task::spawn_blocking(move || { let (mut search_result, time_budget) = tokio::task::spawn_blocking(move || {
perform_search( perform_search(
SearchParams { SearchParams {
index_uid: index_uid_str.clone(), index_uid: index_uid_str.clone(),
@@ -295,11 +303,25 @@ pub async fn multi_search_with_post(
) )
}) })
.await .await
.with_index(query_index)?
.with_index(query_index)?; .with_index(query_index)?;
// Apply personalization if requested
if let Some(personalize) = personalize.as_ref() {
search_result = personalization_service
.rerank_search_results(
search_result,
personalize,
personalize_query.as_deref(),
time_budget,
)
.await
.with_index(query_index)?;
}
search_results.push(SearchResultWithIndex { search_results.push(SearchResultWithIndex {
index_uid: index_uid.into_inner(), index_uid: index_uid.into_inner(),
result: search_result.with_index(query_index)?, result: search_result,
}); });
} }
Ok(search_results) Ok(search_results)

View File

@@ -67,6 +67,7 @@ impl MultiSearchAggregator {
hybrid: _, hybrid: _,
ranking_score_threshold: _, ranking_score_threshold: _,
locales: _, locales: _,
personalize: _,
} in &federated_search.queries } in &federated_search.queries
{ {
if let Some(federation_options) = federation_options { if let Some(federation_options) = federation_options {

View File

@@ -601,6 +601,10 @@ impl PartitionedQueries {
.into()); .into());
} }
if federated_query.has_personalize() {
return Err(MeilisearchHttpError::PersonalizationInFederatedQuery(query_index).into());
}
let (index_uid, query, federation_options) = federated_query.into_index_query_federation(); let (index_uid, query, federation_options) = federated_query.into_index_query_federation();
let federation_options = federation_options.unwrap_or_default(); let federation_options = federation_options.unwrap_or_default();

View File

@@ -59,6 +59,13 @@ pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5); pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5);
pub const INCLUDE_METADATA_HEADER: &str = "Meili-Include-Metadata"; pub const INCLUDE_METADATA_HEADER: &str = "Meili-Include-Metadata";
#[derive(Clone, Default, PartialEq, Deserr, ToSchema, Debug)]
#[deserr(error = DeserrJsonError<InvalidSearchPersonalize>, rename_all = camelCase, deny_unknown_fields)]
pub struct Personalize {
#[deserr(error = DeserrJsonError<InvalidSearchPersonalizeUserContext>)]
pub user_context: String,
}
#[derive(Clone, Default, PartialEq, Deserr, ToSchema)] #[derive(Clone, Default, PartialEq, Deserr, ToSchema)]
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
pub struct SearchQuery { pub struct SearchQuery {
@@ -122,6 +129,8 @@ pub struct SearchQuery {
pub ranking_score_threshold: Option<RankingScoreThreshold>, pub ranking_score_threshold: Option<RankingScoreThreshold>,
#[deserr(default, error = DeserrJsonError<InvalidSearchLocales>)] #[deserr(default, error = DeserrJsonError<InvalidSearchLocales>)]
pub locales: Option<Vec<Locale>>, pub locales: Option<Vec<Locale>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchPersonalize>, default)]
pub personalize: Option<Personalize>,
} }
impl From<SearchParameters> for SearchQuery { impl From<SearchParameters> for SearchQuery {
@@ -169,6 +178,7 @@ impl From<SearchParameters> for SearchQuery {
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
crop_marker: DEFAULT_CROP_MARKER(), crop_marker: DEFAULT_CROP_MARKER(),
locales: None, locales: None,
personalize: None,
} }
} }
} }
@@ -250,6 +260,7 @@ impl fmt::Debug for SearchQuery {
attributes_to_search_on, attributes_to_search_on,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
personalize,
} = self; } = self;
let mut debug = f.debug_struct("SearchQuery"); let mut debug = f.debug_struct("SearchQuery");
@@ -338,6 +349,10 @@ impl fmt::Debug for SearchQuery {
debug.field("locales", &locales); debug.field("locales", &locales);
} }
if let Some(personalize) = personalize {
debug.field("personalize", &personalize);
}
debug.finish() debug.finish()
} }
} }
@@ -543,6 +558,9 @@ pub struct SearchQueryWithIndex {
pub ranking_score_threshold: Option<RankingScoreThreshold>, pub ranking_score_threshold: Option<RankingScoreThreshold>,
#[deserr(default, error = DeserrJsonError<InvalidSearchLocales>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchLocales>, default)]
pub locales: Option<Vec<Locale>>, pub locales: Option<Vec<Locale>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchPersonalize>, default)]
#[serde(skip)]
pub personalize: Option<Personalize>,
#[deserr(default)] #[deserr(default)]
pub federation_options: Option<FederationOptions>, pub federation_options: Option<FederationOptions>,
@@ -567,6 +585,10 @@ impl SearchQueryWithIndex {
self.facets.as_deref().filter(|v| !v.is_empty()) self.facets.as_deref().filter(|v| !v.is_empty())
} }
pub fn has_personalize(&self) -> bool {
self.personalize.is_some()
}
pub fn from_index_query_federation( pub fn from_index_query_federation(
index_uid: IndexUid, index_uid: IndexUid,
query: SearchQuery, query: SearchQuery,
@@ -600,6 +622,7 @@ impl SearchQueryWithIndex {
attributes_to_search_on, attributes_to_search_on,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
personalize,
} = query; } = query;
SearchQueryWithIndex { SearchQueryWithIndex {
@@ -631,6 +654,7 @@ impl SearchQueryWithIndex {
attributes_to_search_on, attributes_to_search_on,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
personalize,
federation_options, federation_options,
} }
} }
@@ -666,6 +690,7 @@ impl SearchQueryWithIndex {
hybrid, hybrid,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
personalize,
} = self; } = self;
( (
index_uid, index_uid,
@@ -697,6 +722,7 @@ impl SearchQueryWithIndex {
hybrid, hybrid,
ranking_score_threshold, ranking_score_threshold,
locales, locales,
personalize,
// do not use ..Default::default() here, // do not use ..Default::default() here,
// rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex`
}, },
@@ -1149,7 +1175,10 @@ pub struct SearchParams {
pub include_metadata: bool, pub include_metadata: bool,
} }
pub fn perform_search(params: SearchParams, index: &Index) -> Result<SearchResult, ResponseError> { pub fn perform_search(
params: SearchParams,
index: &Index,
) -> Result<(SearchResult, TimeBudget), ResponseError> {
let SearchParams { let SearchParams {
index_uid, index_uid,
query, query,
@@ -1168,7 +1197,7 @@ pub fn perform_search(params: SearchParams, index: &Index) -> Result<SearchResul
}; };
let (search, is_finite_pagination, max_total_hits, offset) = let (search, is_finite_pagination, max_total_hits, offset) =
prepare_search(index, &rtxn, &query, &search_kind, time_budget, features)?; prepare_search(index, &rtxn, &query, &search_kind, time_budget.clone(), features)?;
let ( let (
milli::SearchResult { milli::SearchResult {
@@ -1226,6 +1255,7 @@ pub fn perform_search(params: SearchParams, index: &Index) -> Result<SearchResul
attributes_to_search_on: _, attributes_to_search_on: _,
filter: _, filter: _,
distinct: _, distinct: _,
personalize: _,
} = query; } = query;
let format = AttributesFormat { let format = AttributesFormat {
@@ -1291,7 +1321,7 @@ pub fn perform_search(params: SearchParams, index: &Index) -> Result<SearchResul
request_uid: Some(request_uid), request_uid: Some(request_uid),
metadata, metadata,
}; };
Ok(result) Ok((result, time_budget))
} }
#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)] #[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)]

View File

@@ -10,8 +10,9 @@ use actix_web::test::TestRequest;
use actix_web::web::Data; use actix_web::web::Data;
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use meilisearch::analytics::Analytics; use meilisearch::analytics::Analytics;
use meilisearch::personalization::PersonalizationService;
use meilisearch::search_queue::SearchQueue; use meilisearch::search_queue::SearchQueue;
use meilisearch::{create_app, Opt, SubscriberForSecondLayer}; use meilisearch::{create_app, Opt, ServicesData, SubscriberForSecondLayer};
use meilisearch_auth::AuthController; use meilisearch_auth::AuthController;
use tracing::level_filters::LevelFilter; use tracing::level_filters::LevelFilter;
use tracing_subscriber::Layer; use tracing_subscriber::Layer;
@@ -135,14 +136,24 @@ impl Service {
self.options.experimental_search_queue_size, self.options.experimental_search_queue_size,
NonZeroUsize::new(1).unwrap(), NonZeroUsize::new(1).unwrap(),
); );
let personalization_service = self
.options
.experimental_personalization_api_key
.clone()
.map(PersonalizationService::cohere)
.unwrap_or_else(PersonalizationService::disabled);
actix_web::test::init_service(create_app( actix_web::test::init_service(create_app(
self.index_scheduler.clone().into(), ServicesData {
self.auth.clone().into(), index_scheduler: self.index_scheduler.clone().into(),
Data::new(search_queue), auth: self.auth.clone().into(),
search_queue: Data::new(search_queue),
personalization_service: Data::new(personalization_service),
logs_route_handle: Data::new(route_layer_handle),
logs_stderr_handle: Data::new(stderr_layer_handle),
analytics: Data::new(Analytics::no_analytics()),
},
self.options.clone(), self.options.clone(),
(route_layer_handle, stderr_layer_handle),
Data::new(Analytics::no_analytics()),
true, true,
)) ))
.await .await

View File

@@ -207,3 +207,118 @@ async fn errors() {
} }
"###); "###);
} }
#[actix_rt::test]
async fn search_with_personalization_without_enabling_the_feature() {
let server = Server::new().await;
let index = server.unique_index();
// Create the index and add some documents
let (task, _code) = index.create(None).await;
server.wait_task(task.uid()).await.succeeded();
let (task, _code) = index
.add_documents(
json!([
{"id": 1, "title": "The Dark Knight", "genre": "Action"},
{"id": 2, "title": "Inception", "genre": "Sci-Fi"},
{"id": 3, "title": "The Matrix", "genre": "Sci-Fi"}
]),
None,
)
.await;
server.wait_task(task.uid()).await.succeeded();
// Try to search with personalization - should return feature_not_enabled error
let (response, code) = index
.search_post(json!({
"q": "movie",
"personalize": {
"userContext": "I love science fiction movies"
}
}))
.await;
meili_snap::snapshot!(code, @"400 Bad Request");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"message": "reranking search results requires enabling the `personalization` experimental feature. See https://github.com/orgs/meilisearch/discussions/866",
"code": "feature_not_enabled",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#feature_not_enabled"
}
"###);
}
#[actix_rt::test]
async fn multi_search_with_personalization_without_enabling_the_feature() {
let server = Server::new().await;
let index = server.unique_index();
// Create the index and add some documents
let (task, _code) = index.create(None).await;
server.wait_task(task.uid()).await.succeeded();
let (task, _code) = index
.add_documents(
json!([
{"id": 1, "title": "The Dark Knight", "genre": "Action"},
{"id": 2, "title": "Inception", "genre": "Sci-Fi"},
{"id": 3, "title": "The Matrix", "genre": "Sci-Fi"}
]),
None,
)
.await;
server.wait_task(task.uid()).await.succeeded();
// Try to multi-search with personalization - should return feature_not_enabled error
let (response, code) = server
.multi_search(json!({
"queries": [
{
"indexUid": index.uid,
"q": "movie",
"personalize": {
"userContext": "I love science fiction movies"
}
}
]
}))
.await;
meili_snap::snapshot!(code, @"400 Bad Request");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"message": "Inside `.queries[0]`: reranking search results requires enabling the `personalization` experimental feature. See https://github.com/orgs/meilisearch/discussions/866",
"code": "feature_not_enabled",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#feature_not_enabled"
}
"###);
// Try to federated search with personalization - should return feature_not_enabled error
let (response, code) = server
.multi_search(json!({
"federation": {},
"queries": [
{
"indexUid": index.uid,
"q": "movie",
"personalize": {
"userContext": "I love science fiction movies"
}
}
]
}))
.await;
meili_snap::snapshot!(code, @"400 Bad Request");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"message": "Inside `.queries[0]`: Using `.personalize` is not allowed in federated queries.\n - Hint: remove `personalize` from query #0 or remove `federation` from the request",
"code": "invalid_multi_search_query_personalization",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_multi_search_query_personalization"
}
"###);
}

View File

@@ -8,8 +8,9 @@ use actix_web::http::header::ContentType;
use actix_web::web::Data; use actix_web::web::Data;
use meili_snap::snapshot; use meili_snap::snapshot;
use meilisearch::analytics::Analytics; use meilisearch::analytics::Analytics;
use meilisearch::personalization::PersonalizationService;
use meilisearch::search_queue::SearchQueue; use meilisearch::search_queue::SearchQueue;
use meilisearch::{create_app, Opt, SubscriberForSecondLayer}; use meilisearch::{create_app, Opt, ServicesData, SubscriberForSecondLayer};
use tracing::level_filters::LevelFilter; use tracing::level_filters::LevelFilter;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::Layer; use tracing_subscriber::Layer;
@@ -50,12 +51,16 @@ async fn basic_test_log_stream_route() {
); );
let app = actix_web::test::init_service(create_app( let app = actix_web::test::init_service(create_app(
server.service.index_scheduler.clone().into(), ServicesData {
server.service.auth.clone().into(), index_scheduler: server.service.index_scheduler.clone().into(),
Data::new(search_queue), auth: server.service.auth.clone().into(),
search_queue: Data::new(search_queue),
personalization_service: Data::new(PersonalizationService::disabled()),
logs_route_handle: Data::new(route_layer_handle),
logs_stderr_handle: Data::new(stderr_layer_handle),
analytics: Data::new(Analytics::no_analytics()),
},
server.service.options.clone(), server.service.options.clone(),
(route_layer_handle, stderr_layer_handle),
Data::new(Analytics::no_analytics()),
true, true,
)) ))
.await; .await;

View File

@@ -1,7 +1,11 @@
use meili_snap::*; use meili_snap::*;
use meilisearch::Opt;
use tempfile::TempDir;
use super::test_settings_documents_indexing_swapping_and_search; use super::test_settings_documents_indexing_swapping_and_search;
use crate::common::{shared_does_not_exists_index, Server, DOCUMENTS, NESTED_DOCUMENTS}; use crate::common::{
default_settings, shared_does_not_exists_index, Server, DOCUMENTS, NESTED_DOCUMENTS,
};
use crate::json; use crate::json;
#[actix_rt::test] #[actix_rt::test]
@@ -1320,3 +1324,98 @@ async fn search_with_contains_without_enabling_the_feature() {
} }
"#); "#);
} }
#[actix_rt::test]
#[ignore]
async fn search_with_personalization_invalid_api_key() {
// Create a server with a fake personalization API key
let dir = TempDir::new().unwrap();
let options = Opt {
experimental_personalization_api_key: Some("fake-api-key-12345".to_string()),
..default_settings(dir.path())
};
let server = Server::new_with_options(options).await.unwrap();
let index = server.unique_index();
// Create the index and add some documents
let (task, _code) = index.create(None).await;
server.wait_task(task.uid()).await.succeeded();
let (task, _code) = index
.add_documents(
json!([
{"id": 1, "title": "The Dark Knight", "genre": "Action"},
{"id": 2, "title": "Inception", "genre": "Sci-Fi"},
{"id": 3, "title": "The Matrix", "genre": "Sci-Fi"}
]),
None,
)
.await;
server.wait_task(task.uid()).await.succeeded();
// Try to search with personalization - should return remote_invalid_api_key error
let (response, code) = index
.search_post(json!({
"q": "the",
"personalize": {
"userContext": "I love science fiction movies"
}
}))
.await;
snapshot!(code, @"403 Forbidden");
snapshot!(json_string!(response), @r#"
{
"message": "Personalization service: Unauthorized: invalid API key",
"code": "remote_invalid_api_key",
"type": "auth",
"link": "https://docs.meilisearch.com/errors#remote_invalid_api_key"
}
"#);
}
#[actix_rt::test]
async fn search_with_personalization_no_user_context() {
// Create a server with a fake personalization API key
let dir = TempDir::new().unwrap();
let options = Opt {
experimental_personalization_api_key: Some("fake-api-key-12345".to_string()),
..default_settings(dir.path())
};
let server = Server::new_with_options(options).await.unwrap();
let index = server.unique_index();
// Create the index and add some documents
let (task, _code) = index.create(None).await;
server.wait_task(task.uid()).await.succeeded();
let (task, _code) = index
.add_documents(
json!([
{"id": 1, "title": "The Dark Knight", "genre": "Action"},
{"id": 2, "title": "Inception", "genre": "Sci-Fi"},
{"id": 3, "title": "The Matrix", "genre": "Sci-Fi"}
]),
None,
)
.await;
server.wait_task(task.uid()).await.succeeded();
// Try to search with personalization - should return remote_invalid_api_key error
let (response, code) = index
.search_post(json!({
"q": "the",
"personalize": {}
}))
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(json_string!(response), @r###"
{
"message": "Missing field `userContext` inside `.personalize`",
"code": "invalid_search_personalize",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_personalize"
}
"###);
}