Limit the number of internal loop calls to 20 and rename the search tool function from `searchInIndex` to `_meiliSearchInIndex`

This commit is contained in:
Clément Renault 2025-05-20 16:44:28 +02:00 committed by Kerollmops
parent 1159af1219
commit 14c8e5cb56
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -41,6 +41,7 @@ use crate::search::{
use crate::search_queue::SearchQueue; use crate::search_queue::SearchQueue;
const EMBEDDER_NAME: &str = "openai"; const EMBEDDER_NAME: &str = "openai";
const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("/completions").route(web::post().to(chat))); cfg.service(web::resource("/completions").route(web::post().to(chat)));
@ -77,12 +78,15 @@ async fn chat(
/// Setup search tool in chat completion request /// Setup search tool in chat completion request
fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: &ChatPrompts) { fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: &ChatPrompts) {
let tools = chat_completion.tools.get_or_insert_default(); let tools = chat_completion.tools.get_or_insert_default();
tools.push( if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
ChatCompletionToolArgs::default() panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
}
let tool = ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function) .r#type(ChatCompletionToolType::Function)
.function( .function(
FunctionObjectArgs::default() FunctionObjectArgs::default()
.name("searchInIndex") .name(SEARCH_IN_INDEX_FUNCTION_NAME)
.description(&prompts.search_description) .description(&prompts.search_description)
.parameters(json!({ .parameters(json!({
"type": "object", "type": "object",
@ -107,9 +111,8 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
.unwrap(), .unwrap(),
) )
.build() .build()
.unwrap(), .unwrap();
); tools.push(tool);
chat_completion.messages.insert( chat_completion.messages.insert(
0, 0,
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
@ -222,8 +225,9 @@ async fn non_streamed_chat(
Some(FinishReason::ToolCalls) => { Some(FinishReason::ToolCalls) => {
let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default(); let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default();
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
tool_calls.into_iter().partition(|call| call.function.name == "searchInIndex"); .into_iter()
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
chat_completion.messages.push( chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default() ChatCompletionRequestAssistantMessageArgs::default()
@ -297,7 +301,8 @@ async fn streamed_chat(
let mut global_tool_calls = HashMap::<u32, Call>::new(); let mut global_tool_calls = HashMap::<u32, Call>::new();
let mut finish_reason = None; let mut finish_reason = None;
'main: while finish_reason.map_or(true, |fr| fr == FinishReason::ToolCalls) { // Limit the number of internal calls to satisfy the search requests of the LLM
'main: for _ in 0..20 {
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
while let Some(result) = response.next().await { while let Some(result) = response.next().await {
match result { match result {
@ -354,7 +359,7 @@ async fn streamed_chat(
arguments: call.arguments, arguments: call.arguments,
}, },
}) })
.partition(|call| call.function.name == "searchInIndex"); .partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
chat_completion.messages.push( chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default() ChatCompletionRequestAssistantMessageArgs::default()
@ -441,6 +446,11 @@ async fn streamed_chat(
} }
} }
} }
// We must stop if the finish reason is not something we can solve with Meilisearch
if finish_reason.map_or(true, |fr| fr != FinishReason::ToolCalls) {
break;
}
} }
let _ = tx.send(Event::Data(sse::Data::new("[DONE]"))); let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));