feat: refine personalization query by merging with user context

- Merge initial query with user context to create a comprehensive prompt
- Only skip reranking if both query and user_context are None (results are still returned unchanged when no API key is configured)
- Support reranking with query-only, user_context-only, or both
- Use 'let else' pattern for cleaner error handling
- Add comprehensive tests for different parameter combinations
- Improve prompt format for better reranking effectiveness
This commit is contained in:
ManyTheFish
2025-07-23 10:36:11 +02:00
parent 59eaf1875f
commit 0cf7ed9135

View File

@ -29,10 +29,19 @@ impl PersonalizationService {
personalize: Option<&Personalize>, personalize: Option<&Personalize>,
query: Option<&str>, query: Option<&str>,
) -> Result<SearchResult, ResponseError> { ) -> Result<SearchResult, ResponseError> {
// If personalization is not requested, no API key, or no query, return original results // If no API key, return original results
let Some(_personalize) = personalize else { return Ok(search_result) };
let Some(cohere) = &self.cohere else { return Ok(search_result) }; let Some(cohere) = &self.cohere else { return Ok(search_result) };
let Some(query) = query else { return Ok(search_result) };
// Extract user context from personalization
let user_context = personalize.and_then(|p| p.user_context.as_deref());
// Build the prompt by merging query and user context
let prompt = match (query, user_context) {
(Some(q), Some(uc)) => format!("User Context: {}\nQuery: {}", uc, q),
(Some(q), None) => q.to_string(),
(None, Some(uc)) => format!("User Context: {}", uc),
(None, None) => return Ok(search_result),
};
// Extract documents for reranking // Extract documents for reranking
let documents: Vec<String> = search_result let documents: Vec<String> = search_result
@ -50,7 +59,7 @@ impl PersonalizationService {
// Prepare the rerank request // Prepare the rerank request
let rerank_request = ReRankRequest { let rerank_request = ReRankRequest {
query, query: &prompt,
documents: &documents, documents: &documents,
model: ReRankModel::EnglishV3, // Use the default and more recent model model: ReRankModel::EnglishV3, // Use the default and more recent model
top_n: None, top_n: None,
@ -124,7 +133,7 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_personalization_service_disabled() { async fn test_personalization_service_with_user_context_only() {
let service = PersonalizationService::new(Some("fake_key".to_string())); let service = PersonalizationService::new(Some("fake_key".to_string()));
let personalize = Personalize { user_context: Some("test user".to_string()) }; let personalize = Personalize { user_context: Some("test user".to_string()) };
@ -146,12 +155,71 @@ mod tests {
used_negative_operator: false, used_negative_operator: false,
}; };
let result = service let result =
.rerank_search_results(search_result.clone(), Some(&personalize), Some("test")) service.rerank_search_results(search_result.clone(), Some(&personalize), None).await;
.await;
assert!(result.is_ok()); assert!(result.is_ok());
// Should return original results when personalization is disabled // Should attempt reranking with user context only
let reranked_result = result.unwrap();
assert_eq!(reranked_result.hits.len(), search_result.hits.len());
}
#[tokio::test]
async fn test_personalization_service_with_query_only() {
    // A single placeholder hit is enough: the assertions below only look at
    // how many hits come back, not at their content.
    let placeholder_hit = SearchHit {
        document: serde_json::Map::new(),
        formatted: serde_json::Map::new(),
        matches_position: None,
        ranking_score: Some(1.0),
        ranking_score_details: None,
    };
    let search_result = SearchResult {
        hits: vec![placeholder_hit],
        query: String::from("test"),
        processing_time_ms: 10,
        hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 },
        facet_distribution: None,
        facet_stats: None,
        semantic_hit_count: None,
        degraded: false,
        used_negative_operator: false,
    };

    // The service is configured with a key, no `Personalize` is supplied, and
    // only the query is given, so the query is the sole input to reranking.
    // NOTE(review): the key is fake, so presumably the remote call fails and
    // the service falls back to the original hits — confirm.
    let service = PersonalizationService::new(Some(String::from("fake_key")));
    let outcome = service.rerank_search_results(search_result.clone(), None, Some("test")).await;
    assert!(outcome.is_ok());

    // Whatever ordering comes back, no hit may be added or dropped.
    let reranked = outcome.unwrap();
    assert_eq!(reranked.hits.len(), search_result.hits.len());
}
#[tokio::test]
async fn test_personalization_service_both_none() {
// Checks the early-return path: with neither a `Personalize` value nor a
// query, `rerank_search_results` has nothing to build a prompt from.
let service = PersonalizationService::new(Some("fake_key".to_string()));
// Minimal one-hit result; only the number of hits is asserted on below.
let search_result = SearchResult {
hits: vec![SearchHit {
document: serde_json::Map::new(),
formatted: serde_json::Map::new(),
matches_position: None,
ranking_score: Some(1.0),
ranking_score_details: None,
}],
query: "test".to_string(),
processing_time_ms: 10,
hits_info: HitsInfo::OffsetLimit { limit: 1, offset: 0, estimated_total_hits: 1 },
facet_distribution: None,
facet_stats: None,
semantic_hit_count: None,
degraded: false,
used_negative_operator: false,
};
// `None, None` = no personalization settings and no query.
let result = service.rerank_search_results(search_result.clone(), None, None).await;
assert!(result.is_ok());
// Should return original results when both query and user_context are None
// NOTE(review): the three lines below each appear twice because this text is a
// side-by-side diff rendering (old and new column on one line) — confirm
// against the actual source file before compiling.
let reranked_result = result.unwrap(); let reranked_result = result.unwrap();
assert_eq!(reranked_result.hits.len(), search_result.hits.len()); assert_eq!(reranked_result.hits.len(), search_result.hits.len());
} }