mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-06 12:15:45 +00:00
Call specific tools to show progression and results.
This commit is contained in:
parent
045a1b1e75
commit
d45647b58d
@ -10,13 +10,14 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{self, Event, Sse};
|
||||
use async_openai::config::OpenAIConfig;
|
||||
use async_openai::types::{
|
||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
||||
CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream,
|
||||
FunctionObjectArgs,
|
||||
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
|
||||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
||||
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage,
|
||||
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
|
||||
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
|
||||
CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream,
|
||||
FunctionObjectArgs, Role,
|
||||
};
|
||||
use async_openai::Client;
|
||||
use bumpalo::Bump;
|
||||
@ -34,7 +35,7 @@ use meilisearch_types::milli::{
|
||||
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
|
||||
};
|
||||
use meilisearch_types::Index;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
@ -52,7 +53,9 @@ use crate::search::{
|
||||
};
|
||||
use crate::search_queue::SearchQueue;
|
||||
|
||||
const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
|
||||
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
|
||||
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||
|
||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(web::resource("/completions").route(web::post().to(chat)));
|
||||
@ -86,18 +89,45 @@ async fn chat(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy)]
|
||||
pub struct FunctionSupport {
|
||||
/// Defines if we can call the _meiliSearchProgress function
|
||||
/// to inform the front-end about what we are searching for.
|
||||
progress: bool,
|
||||
/// Defines if we can call the _meiliAppendConversationMessage
|
||||
/// function to provide the messages to append into the conversation.
|
||||
append_to_conversation: bool,
|
||||
}
|
||||
|
||||
/// Setup search tool in chat completion request
|
||||
fn setup_search_tool(
|
||||
index_scheduler: &Data<IndexScheduler>,
|
||||
filters: &meilisearch_auth::AuthFilter,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
prompts: &ChatPrompts,
|
||||
) -> Result<(), ResponseError> {
|
||||
) -> Result<FunctionSupport, ResponseError> {
|
||||
let tools = chat_completion.tools.get_or_insert_default();
|
||||
if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
|
||||
panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
|
||||
if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
|
||||
panic!("{MEILI_SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
|
||||
}
|
||||
|
||||
// Remove internal tools used for front-end notifications as they should be hidden from the LLM.
|
||||
let mut progress = false;
|
||||
let mut append_to_conversation = false;
|
||||
tools.retain(|tool| {
|
||||
match tool.function.name.as_str() {
|
||||
MEILI_SEARCH_PROGRESS_NAME => {
|
||||
progress = true;
|
||||
false
|
||||
}
|
||||
MEILI_APPEND_CONVERSATION_MESSAGE_NAME => {
|
||||
append_to_conversation = true;
|
||||
false
|
||||
}
|
||||
_ => true, // keep other tools
|
||||
}
|
||||
});
|
||||
|
||||
let mut index_uids = Vec::new();
|
||||
let mut function_description = prompts.search_description.clone().unwrap();
|
||||
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
|
||||
@ -119,7 +149,7 @@ fn setup_search_tool(
|
||||
.r#type(ChatCompletionToolType::Function)
|
||||
.function(
|
||||
FunctionObjectArgs::default()
|
||||
.name(SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
.name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
.description(&function_description)
|
||||
.parameters(json!({
|
||||
"type": "object",
|
||||
@ -145,7 +175,9 @@ fn setup_search_tool(
|
||||
)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
tools.push(tool);
|
||||
|
||||
chat_completion.messages.insert(
|
||||
0,
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||
@ -156,7 +188,7 @@ fn setup_search_tool(
|
||||
}),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
Ok(FunctionSupport { progress, append_to_conversation })
|
||||
}
|
||||
|
||||
/// Process search request and return formatted results
|
||||
@ -287,7 +319,8 @@ async fn non_streamed_chat(
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
let FunctionSupport { progress, append_to_conversation } =
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
|
||||
let mut response;
|
||||
loop {
|
||||
@ -300,7 +333,7 @@ async fn non_streamed_chat(
|
||||
|
||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
|
||||
.into_iter()
|
||||
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
|
||||
.partition(|call| call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME);
|
||||
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
@ -378,7 +411,8 @@ async fn streamed_chat(
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
let FunctionSupport { progress, append_to_conversation } =
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let _join_handle = Handle::current().spawn(async move {
|
||||
@ -395,21 +429,8 @@ async fn streamed_chat(
|
||||
let choice = &resp.choices[0];
|
||||
finish_reason = choice.finish_reason;
|
||||
|
||||
#[allow(deprecated)]
|
||||
let ChatCompletionStreamResponseDelta {
|
||||
content,
|
||||
// Using deprecated field but keeping for compatibility
|
||||
function_call: _,
|
||||
ref tool_calls,
|
||||
role: _,
|
||||
refusal: _,
|
||||
} = &choice.delta;
|
||||
|
||||
if content.is_some() {
|
||||
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await {
|
||||
return;
|
||||
}
|
||||
}
|
||||
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
|
||||
&choice.delta;
|
||||
|
||||
match tool_calls {
|
||||
Some(tool_calls) => {
|
||||
@ -422,109 +443,195 @@ async fn streamed_chat(
|
||||
} = chunk;
|
||||
let FunctionCallStream { name, arguments } =
|
||||
function.as_ref().unwrap();
|
||||
|
||||
global_tool_calls
|
||||
.entry(*index)
|
||||
.and_modify(|call| call.append(arguments.as_ref().unwrap()))
|
||||
.or_insert_with(|| Call {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
None if !global_tool_calls.is_empty() => {
|
||||
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(&mut global_tool_calls)
|
||||
.into_values()
|
||||
.map(|call| ChatCompletionMessageToolCall {
|
||||
id: call.id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall {
|
||||
name: call.function_name,
|
||||
arguments: call.arguments,
|
||||
},
|
||||
.and_modify(|call| {
|
||||
if call.is_internal() {
|
||||
call.append(arguments.as_ref().unwrap())
|
||||
}
|
||||
})
|
||||
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
|
||||
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(meili_calls.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
.into(),
|
||||
);
|
||||
|
||||
for call in meili_calls {
|
||||
if let Err(SendError(_)) = tx.send(Event::Data(
|
||||
sse::Data::new_json(json!({
|
||||
"object": "chat.completion.tool.call",
|
||||
"tool": call,
|
||||
}))
|
||||
.unwrap(),
|
||||
))
|
||||
.await {
|
||||
return;
|
||||
}
|
||||
|
||||
let result = match serde_json::from_str(&call.function.arguments) {
|
||||
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
&auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
).await.map_err(|e| e.to_string()),
|
||||
Err(err) => Err(err.to_string()),
|
||||
};
|
||||
|
||||
let is_error = result.is_err();
|
||||
let text = match result {
|
||||
Ok((_, text)) => text,
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
let tool = ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(
|
||||
format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()),
|
||||
),
|
||||
};
|
||||
|
||||
if let Err(SendError(_)) = tx.send(Event::Data(
|
||||
sse::Data::new_json(json!({
|
||||
"object": if is_error {
|
||||
"chat.completion.tool.error"
|
||||
.or_insert_with(|| {
|
||||
if name.as_ref().map_or(false, |n| {
|
||||
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
}) {
|
||||
Call::Internal {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
}
|
||||
} else {
|
||||
"chat.completion.tool.output"
|
||||
},
|
||||
"tool": ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id,
|
||||
content: ChatCompletionRequestToolMessageContent::Text(
|
||||
text,
|
||||
),
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
))
|
||||
.await {
|
||||
return;
|
||||
}
|
||||
Call::External { _id: id.as_ref().unwrap().clone() }
|
||||
}
|
||||
});
|
||||
|
||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(tool));
|
||||
if global_tool_calls.get(index).map_or(false, Call::is_external)
|
||||
{
|
||||
todo!("Support forwarding external tool calls");
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
if !global_tool_calls.is_empty() {
|
||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(&mut global_tool_calls)
|
||||
.into_values()
|
||||
.flat_map(|call| match call {
|
||||
Call::Internal {
|
||||
id,
|
||||
function_name: name,
|
||||
arguments,
|
||||
} => Some(ChatCompletionMessageToolCall {
|
||||
id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name, arguments },
|
||||
}),
|
||||
Call::External { _id: _ } => None,
|
||||
})
|
||||
.partition(|call| {
|
||||
call.function.name
|
||||
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
});
|
||||
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(meili_calls.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
.into(),
|
||||
);
|
||||
|
||||
assert!(
|
||||
other_calls.is_empty(),
|
||||
"We do not support external tool forwarding for now"
|
||||
);
|
||||
|
||||
for call in meili_calls {
|
||||
if progress {
|
||||
let call = MeiliSearchProgress {
|
||||
function_name: call.function.name.clone(),
|
||||
function_arguments: call
|
||||
.function
|
||||
.arguments
|
||||
.clone(),
|
||||
};
|
||||
let resp = call.create_response(resp.clone());
|
||||
// Send the event of "we are doing a search"
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if append_to_conversation {
|
||||
// Ask the front-end user to append this tool *call* to the conversation
|
||||
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
|
||||
ChatCompletionRequestAssistantMessage {
|
||||
content: None,
|
||||
refusal: None,
|
||||
name: None,
|
||||
audio: None,
|
||||
tool_calls: Some(vec![
|
||||
ChatCompletionMessageToolCall {
|
||||
id: call.id.clone(),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall {
|
||||
name: call.function.name.clone(),
|
||||
arguments: call.function.arguments.clone(),
|
||||
},
|
||||
},
|
||||
]),
|
||||
function_call: None,
|
||||
}
|
||||
));
|
||||
let resp = call.create_response(resp.clone());
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let result =
|
||||
match serde_json::from_str(&call.function.arguments) {
|
||||
Ok(SearchInIndexParameters { index_uid, q }) => {
|
||||
process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
&auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
Err(err) => Err(err.to_string()),
|
||||
};
|
||||
|
||||
let text = match result {
|
||||
Ok((_, text)) => text,
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(
|
||||
format!(
|
||||
"{}\n\n{text}",
|
||||
chat_settings
|
||||
.prompts
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.pre_query
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
),
|
||||
),
|
||||
});
|
||||
|
||||
if append_to_conversation {
|
||||
// Ask the front-end user to append this tool *output* to the conversation
|
||||
let tool = MeiliAppendConversationMessage(tool.clone());
|
||||
let resp = tool.create_response(resp.clone());
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
chat_completion.messages.push(tool);
|
||||
}
|
||||
} else {
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("{err:?}");
|
||||
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&json!({
|
||||
"object": "chat.completion.error",
|
||||
"tool": err.to_string(),
|
||||
})).unwrap())).await {
|
||||
return;
|
||||
}
|
||||
// tracing::error!("{err:?}");
|
||||
// if let Err(SendError(_)) = tx
|
||||
// .send(Event::Data(
|
||||
// sse::Data::new_json(&json!({
|
||||
// "object": "chat.completion.error",
|
||||
// "tool": err.to_string(),
|
||||
// }))
|
||||
// .unwrap(),
|
||||
// ))
|
||||
// .await
|
||||
// {
|
||||
// return;
|
||||
// }
|
||||
|
||||
break 'main;
|
||||
}
|
||||
@ -543,17 +650,106 @@ async fn streamed_chat(
|
||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Give context about what Meilisearch is doing.
|
||||
struct MeiliSearchProgress {
|
||||
/// The name of the function we are executing.
|
||||
pub function_name: String,
|
||||
/// The arguments of the function we are executing, encoded in JSON.
|
||||
pub function_arguments: String,
|
||||
}
|
||||
|
||||
impl MeiliSearchProgress {
|
||||
fn create_response(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
) -> CreateChatCompletionStreamResponse {
|
||||
let call_text = serde_json::to_string(self).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
resp
|
||||
}
|
||||
}
|
||||
|
||||
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
|
||||
|
||||
impl MeiliAppendConversationMessage {
|
||||
fn create_response(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
) -> CreateChatCompletionStreamResponse {
|
||||
let call_text = serde_json::to_string(&self.0).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
resp
|
||||
}
|
||||
}
|
||||
|
||||
/// The structure used to aggregate the function calls to make.
|
||||
#[derive(Debug)]
|
||||
struct Call {
|
||||
id: String,
|
||||
function_name: String,
|
||||
arguments: String,
|
||||
enum Call {
|
||||
/// Tool calls to tools that must be managed by Meilisearch internally.
|
||||
/// Typically the search functions.
|
||||
Internal { id: String, function_name: String, arguments: String },
|
||||
/// Tool calls that we track but only to know that its not our functions.
|
||||
/// We return the function calls as-is to the end-user.
|
||||
External { _id: String },
|
||||
}
|
||||
|
||||
impl Call {
|
||||
fn append(&mut self, arguments: &str) {
|
||||
self.arguments.push_str(arguments);
|
||||
fn is_internal(&self) -> bool {
|
||||
matches!(self, Call::Internal { .. })
|
||||
}
|
||||
|
||||
fn is_external(&self) -> bool {
|
||||
matches!(self, Call::External { .. })
|
||||
}
|
||||
|
||||
fn append(&mut self, more: &str) {
|
||||
match self {
|
||||
Call::Internal { arguments, .. } => arguments.push_str(more),
|
||||
Call::External { .. } => {
|
||||
panic!("Cannot append argument chunks to an external function")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user