feat: add personalization service with EnglishV3-only reranking

- Add new personalization module with Cohere integration
- Implement rerank_search_results method using EnglishV3 model
- Remove fallback logic to EnglishV2 for simplified behavior
- Add comprehensive error handling and logging
- Include unit tests for service behavior
- Update search route to support personalization feature
This commit is contained in:
ManyTheFish
2025-07-23 09:17:22 +02:00
parent 2b2bbad3f8
commit decd4df5a8
5 changed files with 231 additions and 4 deletions

33
Cargo.lock generated
View File

@ -1194,6 +1194,21 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cohere-rust"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8b553b385b0f2562138baea705b5707335314f8e91a58e7d1a03c3a6c332423"
dependencies = [
"bytes",
"reqwest",
"serde",
"serde_json",
"strum_macros 0.26.4",
"thiserror 1.0.69",
"tokio",
]
[[package]]
name = "color-spantrace"
version = "0.3.0"
@ -3445,7 +3460,7 @@ dependencies = [
"serde_json",
"serde_yaml",
"strum",
"strum_macros",
"strum_macros 0.27.1",
"unicode-blocks",
"unicode-normalization",
"unicode-segmentation",
@ -3753,6 +3768,7 @@ dependencies = [
"bytes",
"cargo_toml",
"clap",
"cohere-rust",
"crossbeam-channel",
"deserr",
"dump",
@ -5821,7 +5837,20 @@ version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
"strum_macros 0.27.1",
]
[[package]]
name = "strum_macros"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.101",
]
[[package]]

View File

@ -94,6 +94,7 @@ uuid = { version = "1.17.0", features = ["serde", "v4"] }
serde_urlencoded = "0.7.1"
termcolor = "1.4.1"
url = { version = "2.5.4", features = ["serde"] }
cohere-rust = "0.6.0"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["json"] }
tracing-trace = { version = "0.1.0", path = "../tracing-trace" }

View File

@ -9,6 +9,7 @@ pub mod middleware;
pub mod option;
#[cfg(test)]
mod option_test;
pub mod personalization;
pub mod routes;
pub mod search;
pub mod search_queue;
@ -622,12 +623,17 @@ pub fn configure_data(
(logs_route, logs_stderr): (LogRouteHandle, LogStderrHandle),
analytics: Data<Analytics>,
) {
// Create personalization service with API key from options
let personalization_service = personalization::PersonalizationService::new(
index_scheduler.experimental_personalization_api_key().cloned(),
);
let http_payload_size_limit = opt.http_payload_size_limit.as_u64() as usize;
config
.app_data(index_scheduler)
.app_data(auth)
.app_data(search_queue)
.app_data(analytics)
.app_data(web::Data::new(personalization_service))
.app_data(web::Data::new(logs_route))
.app_data(web::Data::new(logs_stderr))
.app_data(web::Data::new(opt.clone()))

View File

@ -0,0 +1,162 @@
use crate::search::{Personalization, SearchResult};
use cohere_rust::{
api::rerank::{ReRankModel, ReRankRequest},
Cohere,
};
use meilisearch_types::error::ResponseError;
use tracing::{debug, error, info};
/// Base URL of the Cohere API that the personalization service talks to.
const COHERE_API_URL: &str = "https://api.cohere.ai";

/// Reranks search results through Cohere's rerank endpoint.
///
/// When built without an API key the service holds no client and
/// [`PersonalizationService::rerank_search_results`] returns its input untouched.
pub struct PersonalizationService {
    /// Cohere client; `None` when no API key was supplied to `new`.
    cohere: Option<Cohere>,
}

impl PersonalizationService {
    /// Creates the service, building a Cohere client only when `api_key` is `Some`.
    pub fn new(api_key: Option<String>) -> Self {
        let cohere = api_key.map(|key| Cohere::new(COHERE_API_URL, key));
        if cohere.is_some() {
            info!("Personalization service initialized with Cohere API");
        } else {
            debug!("Personalization service initialized without Cohere API key");
        }
        Self { cohere }
    }

    /// Reorders the hits of `search_result` according to Cohere's `EnglishV3` rerank model.
    ///
    /// Each hit's `document` is serialized to JSON and sent together with `query`.
    /// Only the `personalized` flag of `personalization` is consulted here.
    ///
    /// The original result is returned unchanged when personalization is not
    /// requested, when the service has no Cohere client, when there are no hits,
    /// or when the Cohere call fails (the failure is logged, never propagated).
    ///
    /// # Errors
    ///
    /// This implementation currently always returns `Ok`; the `Result` is kept
    /// so callers can keep using `?`.
    pub async fn rerank_search_results(
        &self,
        mut search_result: SearchResult,
        personalization: &Personalization,
        query: &str,
    ) -> Result<SearchResult, ResponseError> {
        // Nothing to do unless the caller asked for personalization...
        if !personalization.personalized {
            return Ok(search_result);
        }
        // ...and a Cohere client is actually configured.
        let Some(cohere) = self.cohere.as_ref() else {
            return Ok(search_result);
        };

        // Cohere reranks plain strings, so every document is sent as its JSON text.
        // A document that fails to serialize is sent as an empty object so that
        // positions in `documents` keep matching positions in `hits`.
        let documents: Vec<String> = search_result
            .hits
            .iter()
            .map(|hit| serde_json::to_string(&hit.document).unwrap_or_else(|_| "{}".to_string()))
            .collect();
        if documents.is_empty() {
            return Ok(search_result);
        }

        let rerank_request = ReRankRequest {
            query,
            documents: &documents,
            model: ReRankModel::EnglishV3,
            top_n: None,
            max_chunks_per_doc: None,
        };

        match cohere.rerank(&rerank_request).await {
            Ok(rerank_response) => {
                debug!("Cohere rerank successful, reordering {} results", search_result.hits.len());
                // Each response entry carries the index of the document in the
                // request; the order of the entries is the new ranking.
                let ranked_indices = rerank_response
                    .iter()
                    .filter_map(|result| usize::try_from(result.index).ok());
                // Move the hits out instead of cloning them, then put them back reordered.
                let hits = std::mem::take(&mut search_result.hits);
                search_result.hits = reorder_by_rank(hits, ranked_indices);
                Ok(search_result)
            }
            Err(e) => {
                error!("Cohere rerank failed with model EnglishV3: {}", e);
                // Best effort: fall back to the original ordering.
                Ok(search_result)
            }
        }
    }
}

/// Returns `items` rearranged so that the element at `ranked_indices[k]` ends up at position `k`.
///
/// Out-of-range and repeated indices are ignored. Elements that are never
/// referenced are appended afterwards in their original relative order, so no
/// element is ever dropped or duplicated.
fn reorder_by_rank<T>(items: Vec<T>, ranked_indices: impl IntoIterator<Item = usize>) -> Vec<T> {
    // Each slot is emptied when its element is placed, which makes duplicates harmless.
    let mut slots: Vec<Option<T>> = items.into_iter().map(Some).collect();
    let mut reordered = Vec::with_capacity(slots.len());
    for index in ranked_indices {
        if let Some(item) = slots.get_mut(index).and_then(Option::take) {
            reordered.push(item);
        }
    }
    reordered.extend(slots.into_iter().flatten());
    reordered
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::{HitsInfo, SearchHit};

    /// Builds a search result holding a single empty document for the query "test".
    fn single_hit_result() -> SearchResult {
        let hit = SearchHit {
            document: serde_json::Map::new(),
            formatted: serde_json::Map::new(),
            matches_position: None,
            ranking_score: Some(1.0),
            ranking_score_details: None,
        };
        SearchResult {
            hits: vec![hit],
            query: "test".to_string(),
            processing_time_ms: 10,
            hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 },
            facet_distribution: None,
            facet_stats: None,
            semantic_hit_count: None,
            degraded: false,
            used_negative_operator: false,
        }
    }

    /// Builds personalization parameters for a fixed test user.
    fn personalization_for_test_user(personalized: bool) -> Personalization {
        Personalization { personalized, user_profile: Some("test user".to_string()) }
    }

    /// Without an API key the service must hand the results back as they came in.
    #[tokio::test]
    async fn test_personalization_service_without_api_key() {
        let service = PersonalizationService::new(None);
        let personalization = personalization_for_test_user(true);
        let search_result = single_hit_result();
        let expected_hit_count = search_result.hits.len();

        let outcome =
            service.rerank_search_results(search_result, &personalization, "test").await;

        assert!(outcome.is_ok());
        let reranked_result = outcome.unwrap();
        assert_eq!(reranked_result.hits.len(), expected_hit_count);
    }

    /// With personalization switched off the results must be returned as they came in,
    /// even though a client is configured.
    #[tokio::test]
    async fn test_personalization_service_disabled() {
        let service = PersonalizationService::new(Some("fake_key".to_string()));
        let personalization = personalization_for_test_user(false);
        let search_result = single_hit_result();
        let expected_hit_count = search_result.hits.len();

        let outcome =
            service.rerank_search_results(search_result, &personalization, "test").await;

        assert!(outcome.is_ok());
        let reranked_result = outcome.unwrap();
        assert_eq!(reranked_result.hits.len(), expected_hit_count);
    }
}

View File

@ -342,6 +342,7 @@ pub fn fix_sort_query_parameters(sort_query: &str) -> Vec<String> {
pub async fn search_with_url_query(
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>,
personalization_service: web::Data<crate::personalization::PersonalizationService>,
index_uid: web::Path<String>,
params: AwebQueryParameter<SearchQueryGet, DeserrQueryParamError>,
req: HttpRequest,
@ -364,6 +365,11 @@ pub async fn search_with_url_query(
let search_kind =
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors);
// Extract personalization and query string before moving query
let personalization = query.personalization.clone();
let query_str = query.q.clone();
let permit = search_queue.try_get_search_permit().await?;
let search_result = tokio::task::spawn_blocking(move || {
perform_search(
@ -383,7 +389,16 @@ pub async fn search_with_url_query(
}
analytics.publish(aggregate, &req);
let search_result = search_result?;
let mut search_result = search_result?;
// Apply personalization if requested
if let Some(personalization) = &personalization {
if let Some(query_str) = &query_str {
search_result = personalization_service
.rerank_search_results(search_result, personalization, query_str)
.await?;
}
}
debug!(returns = ?search_result, "Search get");
Ok(HttpResponse::Ok().json(search_result))
@ -448,6 +463,7 @@ pub async fn search_with_url_query(
pub async fn search_with_post(
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>,
personalization_service: web::Data<crate::personalization::PersonalizationService>,
index_uid: web::Path<String>,
params: AwebJson<SearchQuery, DeserrJsonError>,
req: HttpRequest,
@ -471,6 +487,10 @@ pub async fn search_with_post(
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors);
// Extract personalization and query string before moving query
let personalization = query.personalization.clone();
let query_str = query.q.clone();
let permit = search_queue.try_get_search_permit().await?;
let search_result = tokio::task::spawn_blocking(move || {
perform_search(
@ -493,7 +513,16 @@ pub async fn search_with_post(
}
analytics.publish(aggregate, &req);
let search_result = search_result?;
let mut search_result = search_result?;
// Apply personalization if requested
if let Some(personalization) = &personalization {
if let Some(query_str) = &query_str {
search_result = personalization_service
.rerank_search_results(search_result, personalization, query_str)
.await?;
}
}
debug!(returns = ?search_result, "Search post");
Ok(HttpResponse::Ok().json(search_result))