diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 31e089231..2715cff72 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -41,6 +41,7 @@ use crate::search::{ use crate::search_queue::SearchQueue; const EMBEDDER_NAME: &str = "openai"; +const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service(web::resource("/completions").route(web::post().to(chat))); @@ -77,39 +78,41 @@ async fn chat( /// Setup search tool in chat completion request fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: &ChatPrompts) { let tools = chat_completion.tools.get_or_insert_default(); - tools.push( - ChatCompletionToolArgs::default() - .r#type(ChatCompletionToolType::Function) - .function( - FunctionObjectArgs::default() - .name("searchInIndex") - .description(&prompts.search_description) - .parameters(json!({ - "type": "object", - "properties": { - "index_uid": { - "type": "string", - "enum": ["main"], - "description": prompts.search_index_uid_param, - }, - "q": { - // Unfortunately, Mistral does not support an array of types, here. - // "type": ["string", "null"], - "type": "string", - "description": prompts.search_q_param, - } - }, - "required": ["index_uid", "q"], - "additionalProperties": false, - })) - .strict(true) - .build() - .unwrap(), - ) - .build() - .unwrap(), - ); + if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() { + panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set"); + } + let tool = ChatCompletionToolArgs::default() + .r#type(ChatCompletionToolType::Function) + .function( + FunctionObjectArgs::default() + .name(SEARCH_IN_INDEX_FUNCTION_NAME) + .description(&prompts.search_description) + .parameters(json!({ + "type": "object", + "properties": { + "index_uid": { + "type": "string", + "enum": ["main"], + "description": prompts.search_index_uid_param, + }, + "q": { + // Unfortunately, Mistral does not support an array of types, here. + // "type": ["string", "null"], + "type": "string", + "description": prompts.search_q_param, + } + }, + "required": ["index_uid", "q"], + "additionalProperties": false, + })) + .strict(true) + .build() + .unwrap(), + ) + .build() + .unwrap(); + tools.push(tool); chat_completion.messages.insert( 0, ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { @@ -222,8 +225,9 @@ async fn non_streamed_chat( Some(FinishReason::ToolCalls) => { let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default(); - let (meili_calls, other_calls): (Vec<_>, Vec<_>) = - tool_calls.into_iter().partition(|call| call.function.name == "searchInIndex"); + let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls + .into_iter() + .partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME); chat_completion.messages.push( ChatCompletionRequestAssistantMessageArgs::default() @@ -297,7 +301,8 @@ async fn streamed_chat( let mut global_tool_calls = HashMap::::new(); let mut finish_reason = None; - 'main: while finish_reason.map_or(true, |fr| fr == FinishReason::ToolCalls) { + // Limit the number of internal calls to satisfy the search requests of the LLM + 'main: for _ in 0..20 { let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); while let Some(result) = response.next().await { match result { @@ -354,7 +359,7 @@ async fn streamed_chat( arguments: call.arguments, }, }) - .partition(|call| call.function.name == "searchInIndex"); + .partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME); chat_completion.messages.push( ChatCompletionRequestAssistantMessageArgs::default() @@ -441,6 +446,11 @@ async fn streamed_chat( } } } + + // We must stop if the finish reason is not something we can solve with Meilisearch + if finish_reason.map_or(true, |fr| fr != FinishReason::ToolCalls) { + break; + } } let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));