Call specific tools to show progression and results.

This commit is contained in:
Clément Renault 2025-05-23 17:25:09 +02:00
parent 045a1b1e75
commit d45647b58d
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -10,13 +10,14 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event, Sse}; use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig; use async_openai::config::OpenAIConfig;
use async_openai::types::{ use async_openai::types::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage,
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType, ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream, ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
FunctionObjectArgs, CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream,
FunctionObjectArgs, Role,
}; };
use async_openai::Client; use async_openai::Client;
use bumpalo::Bump; use bumpalo::Bump;
@ -34,7 +35,7 @@ use meilisearch_types::milli::{
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
}; };
use meilisearch_types::Index; use meilisearch_types::Index;
use serde::Deserialize; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
@ -52,7 +53,9 @@ use crate::search::{
}; };
use crate::search_queue::SearchQueue; use crate::search_queue::SearchQueue;
const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("/completions").route(web::post().to(chat))); cfg.service(web::resource("/completions").route(web::post().to(chat)));
@ -86,18 +89,45 @@ async fn chat(
} }
} }
#[derive(Default, Debug, Clone, Copy)]
pub struct FunctionSupport {
/// Defines if we can call the _meiliSearchProgress function
/// to inform the front-end about what we are searching for.
progress: bool,
/// Defines if we can call the _meiliAppendConversationMessage
/// function to provide the messages to append into the conversation.
append_to_conversation: bool,
}
/// Setup search tool in chat completion request /// Setup search tool in chat completion request
fn setup_search_tool( fn setup_search_tool(
index_scheduler: &Data<IndexScheduler>, index_scheduler: &Data<IndexScheduler>,
filters: &meilisearch_auth::AuthFilter, filters: &meilisearch_auth::AuthFilter,
chat_completion: &mut CreateChatCompletionRequest, chat_completion: &mut CreateChatCompletionRequest,
prompts: &ChatPrompts, prompts: &ChatPrompts,
) -> Result<(), ResponseError> { ) -> Result<FunctionSupport, ResponseError> {
let tools = chat_completion.tools.get_or_insert_default(); let tools = chat_completion.tools.get_or_insert_default();
if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() { if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set"); panic!("{MEILI_SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
} }
// Remove internal tools used for front-end notifications as they should be hidden from the LLM.
let mut progress = false;
let mut append_to_conversation = false;
tools.retain(|tool| {
match tool.function.name.as_str() {
MEILI_SEARCH_PROGRESS_NAME => {
progress = true;
false
}
MEILI_APPEND_CONVERSATION_MESSAGE_NAME => {
append_to_conversation = true;
false
}
_ => true, // keep other tools
}
});
let mut index_uids = Vec::new(); let mut index_uids = Vec::new();
let mut function_description = prompts.search_description.clone().unwrap(); let mut function_description = prompts.search_description.clone().unwrap();
index_scheduler.try_for_each_index::<_, ()>(|name, index| { index_scheduler.try_for_each_index::<_, ()>(|name, index| {
@ -119,7 +149,7 @@ fn setup_search_tool(
.r#type(ChatCompletionToolType::Function) .r#type(ChatCompletionToolType::Function)
.function( .function(
FunctionObjectArgs::default() FunctionObjectArgs::default()
.name(SEARCH_IN_INDEX_FUNCTION_NAME) .name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
.description(&function_description) .description(&function_description)
.parameters(json!({ .parameters(json!({
"type": "object", "type": "object",
@ -145,7 +175,9 @@ fn setup_search_tool(
) )
.build() .build()
.unwrap(); .unwrap();
tools.push(tool); tools.push(tool);
chat_completion.messages.insert( chat_completion.messages.insert(
0, 0,
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
@ -156,7 +188,7 @@ fn setup_search_tool(
}), }),
); );
Ok(()) Ok(FunctionSupport { progress, append_to_conversation })
} }
/// Process search request and return formatted results /// Process search request and return formatted results
@ -287,6 +319,7 @@ async fn non_streamed_chat(
let auth_token = extract_token_from_request(&req)?.unwrap(); let auth_token = extract_token_from_request(&req)?.unwrap();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { progress, append_to_conversation } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let mut response; let mut response;
@ -300,7 +333,7 @@ async fn non_streamed_chat(
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
.into_iter() .into_iter()
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME); .partition(|call| call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME);
chat_completion.messages.push( chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default() ChatCompletionRequestAssistantMessageArgs::default()
@ -378,6 +411,7 @@ async fn streamed_chat(
let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { progress, append_to_conversation } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10); let (tx, rx) = tokio::sync::mpsc::channel(10);
@ -395,21 +429,8 @@ async fn streamed_chat(
let choice = &resp.choices[0]; let choice = &resp.choices[0];
finish_reason = choice.finish_reason; finish_reason = choice.finish_reason;
#[allow(deprecated)] let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
let ChatCompletionStreamResponseDelta { &choice.delta;
content,
// Using deprecated field but keeping for compatibility
function_call: _,
ref tool_calls,
role: _,
refusal: _,
} = &choice.delta;
if content.is_some() {
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await {
return;
}
}
match tool_calls { match tool_calls {
Some(tool_calls) => { Some(tool_calls) => {
@ -422,29 +443,55 @@ async fn streamed_chat(
} = chunk; } = chunk;
let FunctionCallStream { name, arguments } = let FunctionCallStream { name, arguments } =
function.as_ref().unwrap(); function.as_ref().unwrap();
global_tool_calls global_tool_calls
.entry(*index) .entry(*index)
.and_modify(|call| call.append(arguments.as_ref().unwrap())) .and_modify(|call| {
.or_insert_with(|| Call { if call.is_internal() {
call.append(arguments.as_ref().unwrap())
}
})
.or_insert_with(|| {
if name.as_ref().map_or(false, |n| {
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
}) {
Call::Internal {
id: id.as_ref().unwrap().clone(), id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(), function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(), arguments: arguments.as_ref().unwrap().clone(),
}
} else {
Call::External { _id: id.as_ref().unwrap().clone() }
}
}); });
if global_tool_calls.get(index).map_or(false, Call::is_external)
{
todo!("Support forwarding external tool calls");
} }
} }
None if !global_tool_calls.is_empty() => { }
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) = None => {
if !global_tool_calls.is_empty() {
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls) mem::take(&mut global_tool_calls)
.into_values() .into_values()
.map(|call| ChatCompletionMessageToolCall { .flat_map(|call| match call {
id: call.id, Call::Internal {
id,
function_name: name,
arguments,
} => Some(ChatCompletionMessageToolCall {
id,
r#type: Some(ChatCompletionToolType::Function), r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { function: FunctionCall { name, arguments },
name: call.function_name, }),
arguments: call.arguments, Call::External { _id: _ } => None,
},
}) })
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME); .partition(|call| {
call.function.name
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
});
chat_completion.messages.push( chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default() ChatCompletionRequestAssistantMessageArgs::default()
@ -454,77 +501,137 @@ async fn streamed_chat(
.into(), .into(),
); );
assert!(
other_calls.is_empty(),
"We do not support external tool forwarding for now"
);
for call in meili_calls { for call in meili_calls {
if let Err(SendError(_)) = tx.send(Event::Data( if progress {
sse::Data::new_json(json!({ let call = MeiliSearchProgress {
"object": "chat.completion.tool.call", function_name: call.function.name.clone(),
"tool": call, function_arguments: call
})) .function
.unwrap(), .arguments
)) .clone(),
.await { };
let resp = call.create_response(resp.clone());
// Send the event of "we are doing a search"
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return; return;
} }
}
let result = match serde_json::from_str(&call.function.arguments) { if append_to_conversation {
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request( // Ask the front-end user to append this tool *call* to the conversation
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: Some(vec![
ChatCompletionMessageToolCall {
id: call.id.clone(),
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall {
name: call.function.name.clone(),
arguments: call.function.arguments.clone(),
},
},
]),
function_call: None,
}
));
let resp = call.create_response(resp.clone());
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
let result =
match serde_json::from_str(&call.function.arguments) {
Ok(SearchInIndexParameters { index_uid, q }) => {
process_search_request(
&index_scheduler, &index_scheduler,
auth_ctrl.clone(), auth_ctrl.clone(),
&search_queue, &search_queue,
&auth_token, &auth_token,
index_uid, index_uid,
q, q,
).await.map_err(|e| e.to_string()), )
.await
.map_err(|e| e.to_string())
}
Err(err) => Err(err.to_string()), Err(err) => Err(err.to_string()),
}; };
let is_error = result.is_err();
let text = match result { let text = match result {
Ok((_, text)) => text, Ok((_, text)) => text,
Err(err) => err, Err(err) => err,
}; };
let tool = ChatCompletionRequestToolMessage { let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(), tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text( content: ChatCompletionRequestToolMessageContent::Text(
format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()), format!(
"{}\n\n{text}",
chat_settings
.prompts
.as_ref()
.unwrap()
.pre_query
.as_ref()
.unwrap()
), ),
}; ),
});
if let Err(SendError(_)) = tx.send(Event::Data( if append_to_conversation {
sse::Data::new_json(json!({ // Ask the front-end user to append this tool *output* to the conversation
"object": if is_error { let tool = MeiliAppendConversationMessage(tool.clone());
"chat.completion.tool.error" let resp = tool.create_response(resp.clone());
} else { if let Err(SendError(_)) = tx
"chat.completion.tool.output" .send(Event::Data(sse::Data::new_json(&resp).unwrap()))
}, .await
"tool": ChatCompletionRequestToolMessage { {
tool_call_id: call.id,
content: ChatCompletionRequestToolMessageContent::Text(
text,
),
},
}))
.unwrap(),
))
.await {
return; return;
} }
}
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(tool)); chat_completion.messages.push(tool);
}
} else {
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
} }
} }
None => (),
} }
} }
Err(err) => { Err(err) => {
tracing::error!("{err:?}"); // tracing::error!("{err:?}");
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&json!({ // if let Err(SendError(_)) = tx
"object": "chat.completion.error", // .send(Event::Data(
"tool": err.to_string(), // sse::Data::new_json(&json!({
})).unwrap())).await { // "object": "chat.completion.error",
return; // "tool": err.to_string(),
} // }))
// .unwrap(),
// ))
// .await
// {
// return;
// }
break 'main; break 'main;
} }
@ -543,17 +650,106 @@ async fn streamed_chat(
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))) Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
} }
#[derive(Debug, Clone, Serialize)]
/// Give context about what Meilisearch is doing.
struct MeiliSearchProgress {
/// The name of the function we are executing.
pub function_name: String,
/// The arguments of the function we are executing, encoded in JSON.
pub function_arguments: String,
}
impl MeiliSearchProgress {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(self).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
impl MeiliAppendConversationMessage {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(&self.0).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
/// The structure used to aggregate the function calls to make. /// The structure used to aggregate the function calls to make.
#[derive(Debug)] #[derive(Debug)]
struct Call { enum Call {
id: String, /// Tool calls to tools that must be managed by Meilisearch internally.
function_name: String, /// Typically the search functions.
arguments: String, Internal { id: String, function_name: String, arguments: String },
/// Tool calls that we track but only to know that its not our functions.
/// We return the function calls as-is to the end-user.
External { _id: String },
} }
impl Call { impl Call {
fn append(&mut self, arguments: &str) { fn is_internal(&self) -> bool {
self.arguments.push_str(arguments); matches!(self, Call::Internal { .. })
}
fn is_external(&self) -> bool {
matches!(self, Call::External { .. })
}
fn append(&mut self, more: &str) {
match self {
Call::Internal { arguments, .. } => arguments.push_str(more),
Call::External { .. } => {
panic!("Cannot append argument chunks to an external function")
}
}
} }
} }