mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-09-05 20:26:31 +00:00
feat: add personalization service with EnglishV3-only reranking
- Add new personalization module with Cohere integration - Implement rerank_search_results method using EnglishV3 model - Remove fallback logic to EnglishV2 for simplified behavior - Add comprehensive error handling and logging - Include unit tests for service behavior - Update search route to support personalization feature
This commit is contained in:
33
Cargo.lock
generated
33
Cargo.lock
generated
@ -1194,6 +1194,21 @@ version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
|
||||
|
||||
[[package]]
|
||||
name = "cohere-rust"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8b553b385b0f2562138baea705b5707335314f8e91a58e7d1a03c3a6c332423"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum_macros 0.26.4",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "color-spantrace"
|
||||
version = "0.3.0"
|
||||
@ -3445,7 +3460,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
"strum_macros 0.27.1",
|
||||
"unicode-blocks",
|
||||
"unicode-normalization",
|
||||
"unicode-segmentation",
|
||||
@ -3753,6 +3768,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"cargo_toml",
|
||||
"clap",
|
||||
"cohere-rust",
|
||||
"crossbeam-channel",
|
||||
"deserr",
|
||||
"dump",
|
||||
@ -5821,7 +5837,20 @@ version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
|
||||
dependencies = [
|
||||
"strum_macros",
|
||||
"strum_macros 0.27.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -94,6 +94,7 @@ uuid = { version = "1.17.0", features = ["serde", "v4"] }
|
||||
serde_urlencoded = "0.7.1"
|
||||
termcolor = "1.4.1"
|
||||
url = { version = "2.5.4", features = ["serde"] }
|
||||
cohere-rust = "0.6.0"
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["json"] }
|
||||
tracing-trace = { version = "0.1.0", path = "../tracing-trace" }
|
||||
|
@ -9,6 +9,7 @@ pub mod middleware;
|
||||
pub mod option;
|
||||
#[cfg(test)]
|
||||
mod option_test;
|
||||
pub mod personalization;
|
||||
pub mod routes;
|
||||
pub mod search;
|
||||
pub mod search_queue;
|
||||
@ -622,12 +623,17 @@ pub fn configure_data(
|
||||
(logs_route, logs_stderr): (LogRouteHandle, LogStderrHandle),
|
||||
analytics: Data<Analytics>,
|
||||
) {
|
||||
// Create personalization service with API key from options
|
||||
let personalization_service = personalization::PersonalizationService::new(
|
||||
index_scheduler.experimental_personalization_api_key().cloned(),
|
||||
);
|
||||
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
|
||||
config
|
||||
.app_data(index_scheduler)
|
||||
.app_data(auth)
|
||||
.app_data(search_queue)
|
||||
.app_data(analytics)
|
||||
.app_data(web::Data::new(personalization_service))
|
||||
.app_data(web::Data::new(logs_route))
|
||||
.app_data(web::Data::new(logs_stderr))
|
||||
.app_data(web::Data::new(opt.clone()))
|
||||
|
162
crates/meilisearch/src/personalization/mod.rs
Normal file
162
crates/meilisearch/src/personalization/mod.rs
Normal file
@ -0,0 +1,162 @@
|
||||
use crate::search::{Personalization, SearchResult};
|
||||
use cohere_rust::{
|
||||
api::rerank::{ReRankModel, ReRankRequest},
|
||||
Cohere,
|
||||
};
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
pub struct PersonalizationService {
|
||||
cohere: Option<Cohere>,
|
||||
}
|
||||
|
||||
impl PersonalizationService {
|
||||
pub fn new(api_key: Option<String>) -> Self {
|
||||
let cohere = api_key.map(|key| Cohere::new("https://api.cohere.ai", key));
|
||||
|
||||
if cohere.is_some() {
|
||||
info!("Personalization service initialized with Cohere API");
|
||||
} else {
|
||||
debug!("Personalization service initialized without Cohere API key");
|
||||
}
|
||||
|
||||
Self { cohere }
|
||||
}
|
||||
|
||||
pub async fn rerank_search_results(
|
||||
&self,
|
||||
search_result: SearchResult,
|
||||
personalization: &Personalization,
|
||||
query: &str,
|
||||
) -> Result<SearchResult, ResponseError> {
|
||||
// If personalization is not enabled or no API key, return original results
|
||||
if !personalization.personalized || self.cohere.is_none() {
|
||||
return Ok(search_result);
|
||||
}
|
||||
|
||||
let cohere = self.cohere.as_ref().unwrap();
|
||||
|
||||
// Extract documents for reranking
|
||||
let documents: Vec<String> = search_result
|
||||
.hits
|
||||
.iter()
|
||||
.map(|hit| {
|
||||
// Convert the document to a string representation for reranking
|
||||
serde_json::to_string(&hit.document).unwrap_or_else(|_| "{}".to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
if documents.is_empty() {
|
||||
return Ok(search_result);
|
||||
}
|
||||
|
||||
// Prepare the rerank request
|
||||
let rerank_request = ReRankRequest {
|
||||
query,
|
||||
documents: &documents,
|
||||
model: ReRankModel::EnglishV3, // Use the default and more recent model
|
||||
top_n: None,
|
||||
max_chunks_per_doc: None,
|
||||
};
|
||||
|
||||
// Call Cohere's rerank API
|
||||
match cohere.rerank(&rerank_request).await {
|
||||
Ok(rerank_response) => {
|
||||
debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
|
||||
|
||||
// Create a mapping from original index to new rank
|
||||
let reranked_indices: Vec<usize> =
|
||||
rerank_response.iter().map(|result| result.index as usize).collect();
|
||||
|
||||
// Reorder the hits based on Cohere's reranking
|
||||
let mut reranked_hits = search_result.hits.clone();
|
||||
for (new_index, original_index) in reranked_indices.iter().enumerate() {
|
||||
if *original_index < reranked_hits.len() {
|
||||
reranked_hits.swap(new_index, *original_index);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(SearchResult { hits: reranked_hits, ..search_result })
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Cohere rerank failed with model EnglishV3: {}", e);
|
||||
// Return original results on error
|
||||
Ok(search_result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::search::{HitsInfo, SearchHit};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_personalization_service_without_api_key() {
|
||||
let service = PersonalizationService::new(None);
|
||||
let personalization =
|
||||
Personalization { personalized: true, user_profile: Some("test user".to_string()) };
|
||||
|
||||
let search_result = SearchResult {
|
||||
hits: vec![SearchHit {
|
||||
document: serde_json::Map::new(),
|
||||
formatted: serde_json::Map::new(),
|
||||
matches_position: None,
|
||||
ranking_score: Some(1.0),
|
||||
ranking_score_details: None,
|
||||
}],
|
||||
query: "test".to_string(),
|
||||
processing_time_ms: 10,
|
||||
hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 },
|
||||
facet_distribution: None,
|
||||
facet_stats: None,
|
||||
semantic_hit_count: None,
|
||||
degraded: false,
|
||||
used_negative_operator: false,
|
||||
};
|
||||
|
||||
let result =
|
||||
service.rerank_search_results(search_result.clone(), &personalization, "test").await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Should return original results when no API key is provided
|
||||
let reranked_result = result.unwrap();
|
||||
assert_eq!(reranked_result.hits.len(), search_result.hits.len());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_personalization_service_disabled() {
|
||||
let service = PersonalizationService::new(Some("fake_key".to_string()));
|
||||
let personalization = Personalization {
|
||||
personalized: false, // Personalization disabled
|
||||
user_profile: Some("test user".to_string()),
|
||||
};
|
||||
|
||||
let search_result = SearchResult {
|
||||
hits: vec![SearchHit {
|
||||
document: serde_json::Map::new(),
|
||||
formatted: serde_json::Map::new(),
|
||||
matches_position: None,
|
||||
ranking_score: Some(1.0),
|
||||
ranking_score_details: None,
|
||||
}],
|
||||
query: "test".to_string(),
|
||||
processing_time_ms: 10,
|
||||
hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 },
|
||||
facet_distribution: None,
|
||||
facet_stats: None,
|
||||
semantic_hit_count: None,
|
||||
degraded: false,
|
||||
used_negative_operator: false,
|
||||
};
|
||||
|
||||
let result =
|
||||
service.rerank_search_results(search_result.clone(), &personalization, "test").await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Should return original results when personalization is disabled
|
||||
let reranked_result = result.unwrap();
|
||||
assert_eq!(reranked_result.hits.len(), search_result.hits.len());
|
||||
}
|
||||
}
|
@ -342,6 +342,7 @@ pub fn fix_sort_query_parameters(sort_query: &str) -> Vec<String> {
|
||||
pub async fn search_with_url_query(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
personalization_service: web::Data<crate::personalization::PersonalizationService>,
|
||||
index_uid: web::Path<String>,
|
||||
params: AwebQueryParameter<SearchQueryGet, DeserrQueryParamError>,
|
||||
req: HttpRequest,
|
||||
@ -364,6 +365,11 @@ pub async fn search_with_url_query(
|
||||
let search_kind =
|
||||
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
||||
let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors);
|
||||
|
||||
// Extract personalization and query string before moving query
|
||||
let personalization = query.personalization.clone();
|
||||
let query_str = query.q.clone();
|
||||
|
||||
let permit = search_queue.try_get_search_permit().await?;
|
||||
let search_result = tokio::task::spawn_blocking(move || {
|
||||
perform_search(
|
||||
@ -383,7 +389,16 @@ pub async fn search_with_url_query(
|
||||
}
|
||||
analytics.publish(aggregate, &req);
|
||||
|
||||
let search_result = search_result?;
|
||||
let mut search_result = search_result?;
|
||||
|
||||
// Apply personalization if requested
|
||||
if let Some(personalization) = &personalization {
|
||||
if let Some(query_str) = &query_str {
|
||||
search_result = personalization_service
|
||||
.rerank_search_results(search_result, personalization, query_str)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
debug!(returns = ?search_result, "Search get");
|
||||
Ok(HttpResponse::Ok().json(search_result))
|
||||
@ -448,6 +463,7 @@ pub async fn search_with_url_query(
|
||||
pub async fn search_with_post(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
personalization_service: web::Data<crate::personalization::PersonalizationService>,
|
||||
index_uid: web::Path<String>,
|
||||
params: AwebJson<SearchQuery, DeserrJsonError>,
|
||||
req: HttpRequest,
|
||||
@ -471,6 +487,10 @@ pub async fn search_with_post(
|
||||
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
||||
let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors);
|
||||
|
||||
// Extract personalization and query string before moving query
|
||||
let personalization = query.personalization.clone();
|
||||
let query_str = query.q.clone();
|
||||
|
||||
let permit = search_queue.try_get_search_permit().await?;
|
||||
let search_result = tokio::task::spawn_blocking(move || {
|
||||
perform_search(
|
||||
@ -493,7 +513,16 @@ pub async fn search_with_post(
|
||||
}
|
||||
analytics.publish(aggregate, &req);
|
||||
|
||||
let search_result = search_result?;
|
||||
let mut search_result = search_result?;
|
||||
|
||||
// Apply personalization if requested
|
||||
if let Some(personalization) = &personalization {
|
||||
if let Some(query_str) = &query_str {
|
||||
search_result = personalization_service
|
||||
.rerank_search_results(search_result, personalization, query_str)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
debug!(returns = ?search_result, "Search post");
|
||||
Ok(HttpResponse::Ok().json(search_result))
|
||||
|
Reference in New Issue
Block a user