Call specific tools to show progression and results.

This commit is contained in:
Clément Renault 2025-05-23 17:25:09 +02:00
parent 045a1b1e75
commit d45647b58d
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -10,13 +10,14 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig;
use async_openai::types::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream,
FunctionObjectArgs,
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage,
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream,
FunctionObjectArgs, Role,
};
use async_openai::Client;
use bumpalo::Bump;
@ -34,7 +35,7 @@ use meilisearch_types::milli::{
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
};
use meilisearch_types::Index;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError;
@ -52,7 +53,9 @@ use crate::search::{
};
use crate::search_queue::SearchQueue;
const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("/completions").route(web::post().to(chat)));
@ -86,18 +89,45 @@ async fn chat(
}
}
#[derive(Default, Debug, Clone, Copy)]
pub struct FunctionSupport {
/// Defines if we can call the _meiliSearchProgress function
/// to inform the front-end about what we are searching for.
progress: bool,
/// Defines if we can call the _meiliAppendConversationMessage
/// function to provide the messages to append into the conversation.
append_to_conversation: bool,
}
/// Setup search tool in chat completion request
fn setup_search_tool(
index_scheduler: &Data<IndexScheduler>,
filters: &meilisearch_auth::AuthFilter,
chat_completion: &mut CreateChatCompletionRequest,
prompts: &ChatPrompts,
) -> Result<(), ResponseError> {
) -> Result<FunctionSupport, ResponseError> {
let tools = chat_completion.tools.get_or_insert_default();
if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
panic!("{MEILI_SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
}
// Remove internal tools used for front-end notifications as they should be hidden from the LLM.
let mut progress = false;
let mut append_to_conversation = false;
tools.retain(|tool| {
match tool.function.name.as_str() {
MEILI_SEARCH_PROGRESS_NAME => {
progress = true;
false
}
MEILI_APPEND_CONVERSATION_MESSAGE_NAME => {
append_to_conversation = true;
false
}
_ => true, // keep other tools
}
});
let mut index_uids = Vec::new();
let mut function_description = prompts.search_description.clone().unwrap();
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
@ -119,7 +149,7 @@ fn setup_search_tool(
.r#type(ChatCompletionToolType::Function)
.function(
FunctionObjectArgs::default()
.name(SEARCH_IN_INDEX_FUNCTION_NAME)
.name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
.description(&function_description)
.parameters(json!({
"type": "object",
@ -145,7 +175,9 @@ fn setup_search_tool(
)
.build()
.unwrap();
tools.push(tool);
chat_completion.messages.insert(
0,
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
@ -156,7 +188,7 @@ fn setup_search_tool(
}),
);
Ok(())
Ok(FunctionSupport { progress, append_to_conversation })
}
/// Process search request and return formatted results
@ -287,7 +319,8 @@ async fn non_streamed_chat(
let auth_token = extract_token_from_request(&req)?.unwrap();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let FunctionSupport { progress, append_to_conversation } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let mut response;
loop {
@ -300,7 +333,7 @@ async fn non_streamed_chat(
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
.into_iter()
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
.partition(|call| call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME);
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
@ -378,7 +411,8 @@ async fn streamed_chat(
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let FunctionSupport { progress, append_to_conversation } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10);
let _join_handle = Handle::current().spawn(async move {
@ -395,21 +429,8 @@ async fn streamed_chat(
let choice = &resp.choices[0];
finish_reason = choice.finish_reason;
#[allow(deprecated)]
let ChatCompletionStreamResponseDelta {
content,
// Using deprecated field but keeping for compatibility
function_call: _,
ref tool_calls,
role: _,
refusal: _,
} = &choice.delta;
if content.is_some() {
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await {
return;
}
}
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
&choice.delta;
match tool_calls {
Some(tool_calls) => {
@ -422,109 +443,195 @@ async fn streamed_chat(
} = chunk;
let FunctionCallStream { name, arguments } =
function.as_ref().unwrap();
global_tool_calls
.entry(*index)
.and_modify(|call| call.append(arguments.as_ref().unwrap()))
.or_insert_with(|| Call {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
});
}
}
None if !global_tool_calls.is_empty() => {
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls)
.into_values()
.map(|call| ChatCompletionMessageToolCall {
id: call.id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall {
name: call.function_name,
arguments: call.arguments,
},
.and_modify(|call| {
if call.is_internal() {
call.append(arguments.as_ref().unwrap())
}
})
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
for call in meili_calls {
if let Err(SendError(_)) = tx.send(Event::Data(
sse::Data::new_json(json!({
"object": "chat.completion.tool.call",
"tool": call,
}))
.unwrap(),
))
.await {
return;
}
let result = match serde_json::from_str(&call.function.arguments) {
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
&auth_token,
index_uid,
q,
).await.map_err(|e| e.to_string()),
Err(err) => Err(err.to_string()),
};
let is_error = result.is_err();
let text = match result {
Ok((_, text)) => text,
Err(err) => err,
};
let tool = ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(
format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()),
),
};
if let Err(SendError(_)) = tx.send(Event::Data(
sse::Data::new_json(json!({
"object": if is_error {
"chat.completion.tool.error"
.or_insert_with(|| {
if name.as_ref().map_or(false, |n| {
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
}) {
Call::Internal {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
}
} else {
"chat.completion.tool.output"
},
"tool": ChatCompletionRequestToolMessage {
tool_call_id: call.id,
content: ChatCompletionRequestToolMessageContent::Text(
text,
),
},
}))
.unwrap(),
))
.await {
return;
}
Call::External { _id: id.as_ref().unwrap().clone() }
}
});
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(tool));
if global_tool_calls.get(index).map_or(false, Call::is_external)
{
todo!("Support forwarding external tool calls");
}
}
}
None => {
if !global_tool_calls.is_empty() {
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls)
.into_values()
.flat_map(|call| match call {
Call::Internal {
id,
function_name: name,
arguments,
} => Some(ChatCompletionMessageToolCall {
id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { name, arguments },
}),
Call::External { _id: _ } => None,
})
.partition(|call| {
call.function.name
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
});
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
assert!(
other_calls.is_empty(),
"We do not support external tool forwarding for now"
);
for call in meili_calls {
if progress {
let call = MeiliSearchProgress {
function_name: call.function.name.clone(),
function_arguments: call
.function
.arguments
.clone(),
};
let resp = call.create_response(resp.clone());
// Send the event of "we are doing a search"
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
if append_to_conversation {
// Ask the front-end user to append this tool *call* to the conversation
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: Some(vec![
ChatCompletionMessageToolCall {
id: call.id.clone(),
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall {
name: call.function.name.clone(),
arguments: call.function.arguments.clone(),
},
},
]),
function_call: None,
}
));
let resp = call.create_response(resp.clone());
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
let result =
match serde_json::from_str(&call.function.arguments) {
Ok(SearchInIndexParameters { index_uid, q }) => {
process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
&auth_token,
index_uid,
q,
)
.await
.map_err(|e| e.to_string())
}
Err(err) => Err(err.to_string()),
};
let text = match result {
Ok((_, text)) => text,
Err(err) => err,
};
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(
format!(
"{}\n\n{text}",
chat_settings
.prompts
.as_ref()
.unwrap()
.pre_query
.as_ref()
.unwrap()
),
),
});
if append_to_conversation {
// Ask the front-end user to append this tool *output* to the conversation
let tool = MeiliAppendConversationMessage(tool.clone());
let resp = tool.create_response(resp.clone());
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
chat_completion.messages.push(tool);
}
} else {
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
}
None => (),
}
}
Err(err) => {
tracing::error!("{err:?}");
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&json!({
"object": "chat.completion.error",
"tool": err.to_string(),
})).unwrap())).await {
return;
}
// tracing::error!("{err:?}");
// if let Err(SendError(_)) = tx
// .send(Event::Data(
// sse::Data::new_json(&json!({
// "object": "chat.completion.error",
// "tool": err.to_string(),
// }))
// .unwrap(),
// ))
// .await
// {
// return;
// }
break 'main;
}
@ -543,17 +650,106 @@ async fn streamed_chat(
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
}
#[derive(Debug, Clone, Serialize)]
/// Give context about what Meilisearch is doing.
struct MeiliSearchProgress {
/// The name of the function we are executing.
pub function_name: String,
/// The arguments of the function we are executing, encoded in JSON.
pub function_arguments: String,
}
impl MeiliSearchProgress {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(self).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
impl MeiliAppendConversationMessage {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(&self.0).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
/// The structure used to aggregate the function calls to make.
#[derive(Debug)]
struct Call {
id: String,
function_name: String,
arguments: String,
enum Call {
/// Tool calls to tools that must be managed by Meilisearch internally.
/// Typically the search functions.
Internal { id: String, function_name: String, arguments: String },
/// Tool calls that we track but only to know that its not our functions.
/// We return the function calls as-is to the end-user.
External { _id: String },
}
impl Call {
fn append(&mut self, arguments: &str) {
self.arguments.push_str(arguments);
fn is_internal(&self) -> bool {
matches!(self, Call::Internal { .. })
}
fn is_external(&self) -> bool {
matches!(self, Call::External { .. })
}
fn append(&mut self, more: &str) {
match self {
Call::Internal { arguments, .. } => arguments.push_str(more),
Call::External { .. } => {
panic!("Cannot append argument chunks to an external function")
}
}
}
}