From e603e221d57f539e8e7cd18fb08b650963ecfdfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Thu, 15 May 2025 15:39:38 +0200 Subject: [PATCH] Factorise a bit the code --- crates/meilisearch/src/routes/chat.rs | 376 ++++++++++++-------------- 1 file changed, 169 insertions(+), 207 deletions(-) diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 8db6a1dde..e4a9b65e2 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -55,6 +55,14 @@ pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service(web::resource("").route(web::post().to(chat))); } +/// Creates OpenAI client with API key +fn create_openai_client() -> Client { + let api_key = std::env::var("MEILI_OPENAI_API_KEY") + .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); + let config = OpenAIConfig::default().with_api_key(&api_key); + Client::with_config(config) +} + /// Get a chat completion async fn chat( index_scheduler: GuardedData, Data>, @@ -77,16 +85,112 @@ async fn chat( } } -async fn non_streamed_chat( - index_scheduler: GuardedData, Data>, - search_queue: web::Data, - mut chat_completion: CreateChatCompletionRequest, -) -> Result { - let api_key = std::env::var("MEILI_OPENAI_API_KEY") - .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); - let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base - let client = Client::with_config(config); +/// Setup search tool in chat completion request +fn setup_search_tool( + chat_completion: &mut CreateChatCompletionRequest, + search_in_index_description: &str, + search_in_index_q_param_description: &str, + search_in_index_index_description: &str, +) { + let tools = chat_completion.tools.get_or_insert_default(); + tools.push( + ChatCompletionToolArgs::default() + .r#type(ChatCompletionToolType::Function) + .function( + FunctionObjectArgs::default() + .name("searchInIndex") + .description(search_in_index_description) + .parameters(json!({ + "type": "object", + "properties": { + "index_uid": { + "type": "string", + "enum": ["main"], + "description": search_in_index_index_description, + }, + "q": { + "type": ["string", "null"], + "description": search_in_index_q_param_description, + } + }, + "required": ["index_uid", "q"], + "additionalProperties": false, + })) + .strict(true) + .build() + .unwrap(), + ) + .build() + .unwrap(), + ); +} +/// Process search request and return formatted results +async fn process_search_request( + index_scheduler: &GuardedData, Data>, + search_queue: &web::Data, + index_uid: String, + q: Option, +) -> Result<(Index, String), ResponseError> { + let mut query = SearchQuery { + q, + hybrid: Some(HybridQuery { + semantic_ratio: SemanticRatio::default(), + embedder: EMBEDDER_NAME.to_string(), + }), + limit: 20, + ..Default::default() + }; + + // Tenant token search_rules. + if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { + add_search_rules(&mut query.filter, search_rules); + } + + // TBD + // let mut aggregate = SearchAggregator::::from_query(&query); + + let index = index_scheduler.index(&index_uid)?; + let search_kind = + search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; + + let permit = search_queue.try_get_search_permit().await?; + let features = index_scheduler.features(); + let index_cloned = index.clone(); + let search_result = tokio::task::spawn_blocking(move || { + perform_search( + index_uid.to_string(), + &index_cloned, + query, + search_kind, + RetrieveVectors::new(false), + features, + ) + }) + .await; + permit.drop().await; + + let search_result = search_result?; + if let Ok(ref search_result) = search_result { + // aggregate.succeed(search_result); + if search_result.degraded { + MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); + } + } + // analytics.publish(aggregate, &req); + + let search_result = search_result?; + let formatted = + format_documents(&index, search_result.hits.into_iter().map(|doc| doc.document)); + let text = formatted.join("\n"); + + Ok((index, text)) +} + +/// Get prompt descriptions from index scheduler +fn get_prompt_descriptions( + index_scheduler: &GuardedData, Data>, +) -> (String, String, String) { let rtxn = index_scheduler.read_txn().unwrap(); let search_in_index_description = index_scheduler .chat_prompts(&rtxn, "searchInIndex-description") @@ -105,39 +209,35 @@ async fn non_streamed_chat( .to_string(); drop(rtxn); + ( + search_in_index_description, + search_in_index_q_param_description, + search_in_index_index_description, + ) +} + +async fn non_streamed_chat( + index_scheduler: GuardedData, Data>, + search_queue: web::Data, + mut chat_completion: CreateChatCompletionRequest, +) -> Result { + let client = create_openai_client(); + + let ( + search_in_index_description, + search_in_index_q_param_description, + search_in_index_index_description, + ) = get_prompt_descriptions(&index_scheduler); + let mut response; loop { - let tools = chat_completion.tools.get_or_insert_default(); - tools.push( - ChatCompletionToolArgs::default() - .r#type(ChatCompletionToolType::Function) - .function( - FunctionObjectArgs::default() - .name("searchInIndex") - .description(&search_in_index_description) - .parameters(json!({ - "type": "object", - "properties": { - "index_uid": { - "type": "string", - "enum": ["main"], - "description": search_in_index_index_description, - }, - "q": { - "type": ["string", "null"], - "description": search_in_index_q_param_description, - } - }, - "required": ["index_uid", "q"], - "additionalProperties": false, - })) - .strict(true) - .build() - .unwrap(), - ) - .build() - .unwrap(), + setup_search_tool( + &mut chat_completion, + &search_in_index_description, + &search_in_index_q_param_description, + &search_in_index_index_description, ); + response = client.chat().create(chat_completion.clone()).await.unwrap(); let choice = &mut response.choices[0]; @@ -160,65 +260,10 @@ async fn non_streamed_chat( let SearchInIndexParameters { index_uid, q } = serde_json::from_str(&call.function.arguments).unwrap(); - let mut query = SearchQuery { - q, - hybrid: Some(HybridQuery { - semantic_ratio: SemanticRatio::default(), - embedder: EMBEDDER_NAME.to_string(), - }), - limit: 20, - ..Default::default() - }; + let (_, text) = + process_search_request(&index_scheduler, &search_queue, index_uid, q) + .await?; - // Tenant token search_rules. - if let Some(search_rules) = - index_scheduler.filters().get_index_search_rules(&index_uid) - { - add_search_rules(&mut query.filter, search_rules); - } - - // TBD - // let mut aggregate = SearchAggregator::::from_query(&query); - - let index = index_scheduler.index(&index_uid)?; - let search_kind = search_kind( - &query, - index_scheduler.get_ref(), - index_uid.to_string(), - &index, - )?; - - let permit = search_queue.try_get_search_permit().await?; - let features = index_scheduler.features(); - let index_cloned = index.clone(); - let search_result = tokio::task::spawn_blocking(move || { - perform_search( - index_uid.to_string(), - &index_cloned, - query, - search_kind, - RetrieveVectors::new(false), - features, - ) - }) - .await; - permit.drop().await; - - let search_result = search_result?; - if let Ok(ref search_result) = search_result { - // aggregate.succeed(search_result); - if search_result.degraded { - MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); - } - } - // analytics.publish(aggregate, &req); - - let search_result = search_result?; - let formatted = format_documents( - &index, - search_result.hits.into_iter().map(|doc| doc.document), - ); - let text = formatted.join("\n"); chat_completion.messages.push(ChatCompletionRequestMessage::Tool( ChatCompletionRequestToolMessage { tool_call_id: call.id, @@ -245,63 +290,22 @@ async fn streamed_chat( search_queue: web::Data, mut chat_completion: CreateChatCompletionRequest, ) -> impl Responder { - let api_key = std::env::var("MEILI_OPENAI_API_KEY") - .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); + let ( + search_in_index_description, + search_in_index_q_param_description, + search_in_index_index_description, + ) = get_prompt_descriptions(&index_scheduler); - let rtxn = index_scheduler.read_txn().unwrap(); - let search_in_index_description = index_scheduler - .chat_prompts(&rtxn, "searchInIndex-description") - .unwrap() - .unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION) - .to_string(); - let search_in_index_q_param_description = index_scheduler - .chat_prompts(&rtxn, "searchInIndex-q-param-description") - .unwrap() - .unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION) - .to_string(); - let search_in_index_index_description = index_scheduler - .chat_prompts(&rtxn, "searchInIndex-index-param-description") - .unwrap() - .unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION) - .to_string(); - drop(rtxn); - - let tools = chat_completion.tools.get_or_insert_default(); - tools.push( - ChatCompletionToolArgs::default() - .r#type(ChatCompletionToolType::Function) - .function( - FunctionObjectArgs::default() - .name("searchInIndex") - .description(&search_in_index_description) - .parameters(json!({ - "type": "object", - "properties": { - "index_uid": { - "type": "string", - "enum": ["main"], - "description": search_in_index_index_description, - }, - "q": { - "type": ["string", "null"], - "description": search_in_index_q_param_description, - } - }, - "required": ["index_uid", "q"], - "additionalProperties": false, - })) - .strict(true) - .build() - .unwrap(), - ) - .build() - .unwrap(), + setup_search_tool( + &mut chat_completion, + &search_in_index_description, + &search_in_index_q_param_description, + &search_in_index_index_description, ); let (tx, rx) = tokio::sync::mpsc::channel(10); let _join_handle = Handle::current().spawn(async move { - let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base - let client = Client::with_config(config); + let client = create_openai_client(); let mut global_tool_calls = HashMap::::new(); 'main: loop { @@ -313,7 +317,9 @@ async fn streamed_chat( let delta = &resp.choices[0].delta; let ChatCompletionStreamResponseDelta { content, - function_call: _, + // Using deprecated field but keeping for compatibility + #[allow(deprecated)] + function_call: _, ref tool_calls, role: _, refusal: _, @@ -352,7 +358,7 @@ async fn streamed_chat( None if !global_tool_calls.is_empty() => { // dbg!(&global_tool_calls); - let (meili_calls, other_calls): (Vec<_>, Vec<_>) = + let (meili_calls, _other_calls): (Vec<_>, Vec<_>) = mem::take(&mut global_tool_calls) .into_iter() .map(|(_, call)| ChatCompletionMessageToolCall { @@ -387,67 +393,23 @@ async fn streamed_chat( let SearchInIndexParameters { index_uid, q } = serde_json::from_str(&call.function.arguments).unwrap(); - let mut query = SearchQuery { + let result = process_search_request( + &index_scheduler, + &search_queue, + index_uid, q, - hybrid: Some(HybridQuery { - semantic_ratio: SemanticRatio::default(), - embedder: EMBEDDER_NAME.to_string(), - }), - limit: 20, - ..Default::default() - }; - - // Tenant token search_rules. - if let Some(search_rules) = - index_scheduler.filters().get_index_search_rules(&index_uid) - { - add_search_rules(&mut query.filter, search_rules); - } - - // TBD - // let mut aggregate = SearchAggregator::::from_query(&query); - - let index = index_scheduler.index(&index_uid).unwrap(); - let search_kind = search_kind( - &query, - index_scheduler.get_ref(), - index_uid.to_string(), - &index, ) - .unwrap(); - - let permit = - search_queue.try_get_search_permit().await.unwrap(); - let features = index_scheduler.features(); - let index_cloned = index.clone(); - let search_result = tokio::task::spawn_blocking(move || { - perform_search( - index_uid.to_string(), - &index_cloned, - query, - search_kind, - RetrieveVectors::new(false), - features, - ) - }) .await; - permit.drop().await; - let search_result = search_result.unwrap(); - if let Ok(ref search_result) = search_result { - // aggregate.succeed(search_result); - if search_result.degraded { - MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); - } + // Handle potential errors more explicitly + if let Err(err) = &result { + // Log the error or handle it as needed + eprintln!("Error processing search request: {:?}", err); + continue; } - // analytics.publish(aggregate, &req); - let search_result = search_result.unwrap(); - let formatted = format_documents( - &index, - search_result.hits.into_iter().map(|doc| doc.document), - ); - let text = formatted.join("\n"); + let (_, text) = result.unwrap(); + let tool = ChatCompletionRequestMessage::Tool( ChatCompletionRequestToolMessage { tool_call_id: call.id, @@ -515,7 +477,7 @@ fn format_documents(index: &Index, documents: impl Iterator) -> let EmbeddingConfig { embedder_options: _, - prompt: PromptData { template, max_bytes }, + prompt: PromptData { template, max_bytes: _ }, quantized: _, } = config;