mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-06 20:25:40 +00:00
Call specific tools to show progression and results.
This commit is contained in:
parent
ca5a87a606
commit
807157b8cd
@ -10,13 +10,14 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
|||||||
use actix_web_lab::sse::{self, Event, Sse};
|
use actix_web_lab::sse::{self, Event, Sse};
|
||||||
use async_openai::config::OpenAIConfig;
|
use async_openai::config::OpenAIConfig;
|
||||||
use async_openai::types::{
|
use async_openai::types::{
|
||||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
|
||||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
||||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage,
|
||||||
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
|
||||||
CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream,
|
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
|
||||||
FunctionObjectArgs,
|
CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream,
|
||||||
|
FunctionObjectArgs, Role,
|
||||||
};
|
};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use bumpalo::Bump;
|
use bumpalo::Bump;
|
||||||
@ -34,7 +35,7 @@ use meilisearch_types::milli::{
|
|||||||
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
|
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
|
||||||
};
|
};
|
||||||
use meilisearch_types::Index;
|
use meilisearch_types::Index;
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::runtime::Handle;
|
use tokio::runtime::Handle;
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
@ -52,7 +53,9 @@ use crate::search::{
|
|||||||
};
|
};
|
||||||
use crate::search_queue::SearchQueue;
|
use crate::search_queue::SearchQueue;
|
||||||
|
|
||||||
const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
|
||||||
|
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
|
||||||
|
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||||
|
|
||||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||||
cfg.service(web::resource("/completions").route(web::post().to(chat)));
|
cfg.service(web::resource("/completions").route(web::post().to(chat)));
|
||||||
@ -86,18 +89,45 @@ async fn chat(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Debug, Clone, Copy)]
|
||||||
|
pub struct FunctionSupport {
|
||||||
|
/// Defines if we can call the _meiliSearchProgress function
|
||||||
|
/// to inform the front-end about what we are searching for.
|
||||||
|
progress: bool,
|
||||||
|
/// Defines if we can call the _meiliAppendConversationMessage
|
||||||
|
/// function to provide the messages to append into the conversation.
|
||||||
|
append_to_conversation: bool,
|
||||||
|
}
|
||||||
|
|
||||||
/// Setup search tool in chat completion request
|
/// Setup search tool in chat completion request
|
||||||
fn setup_search_tool(
|
fn setup_search_tool(
|
||||||
index_scheduler: &Data<IndexScheduler>,
|
index_scheduler: &Data<IndexScheduler>,
|
||||||
filters: &meilisearch_auth::AuthFilter,
|
filters: &meilisearch_auth::AuthFilter,
|
||||||
chat_completion: &mut CreateChatCompletionRequest,
|
chat_completion: &mut CreateChatCompletionRequest,
|
||||||
prompts: &ChatPrompts,
|
prompts: &ChatPrompts,
|
||||||
) -> Result<(), ResponseError> {
|
) -> Result<FunctionSupport, ResponseError> {
|
||||||
let tools = chat_completion.tools.get_or_insert_default();
|
let tools = chat_completion.tools.get_or_insert_default();
|
||||||
if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
|
if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
|
||||||
panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
|
panic!("{MEILI_SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove internal tools used for front-end notifications as they should be hidden from the LLM.
|
||||||
|
let mut progress = false;
|
||||||
|
let mut append_to_conversation = false;
|
||||||
|
tools.retain(|tool| {
|
||||||
|
match tool.function.name.as_str() {
|
||||||
|
MEILI_SEARCH_PROGRESS_NAME => {
|
||||||
|
progress = true;
|
||||||
|
false
|
||||||
|
}
|
||||||
|
MEILI_APPEND_CONVERSATION_MESSAGE_NAME => {
|
||||||
|
append_to_conversation = true;
|
||||||
|
false
|
||||||
|
}
|
||||||
|
_ => true, // keep other tools
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let mut index_uids = Vec::new();
|
let mut index_uids = Vec::new();
|
||||||
let mut function_description = prompts.search_description.clone().unwrap();
|
let mut function_description = prompts.search_description.clone().unwrap();
|
||||||
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
|
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
|
||||||
@ -119,7 +149,7 @@ fn setup_search_tool(
|
|||||||
.r#type(ChatCompletionToolType::Function)
|
.r#type(ChatCompletionToolType::Function)
|
||||||
.function(
|
.function(
|
||||||
FunctionObjectArgs::default()
|
FunctionObjectArgs::default()
|
||||||
.name(SEARCH_IN_INDEX_FUNCTION_NAME)
|
.name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||||
.description(&function_description)
|
.description(&function_description)
|
||||||
.parameters(json!({
|
.parameters(json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -145,7 +175,9 @@ fn setup_search_tool(
|
|||||||
)
|
)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
tools.push(tool);
|
tools.push(tool);
|
||||||
|
|
||||||
chat_completion.messages.insert(
|
chat_completion.messages.insert(
|
||||||
0,
|
0,
|
||||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||||
@ -156,7 +188,7 @@ fn setup_search_tool(
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(FunctionSupport { progress, append_to_conversation })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process search request and return formatted results
|
/// Process search request and return formatted results
|
||||||
@ -287,7 +319,8 @@ async fn non_streamed_chat(
|
|||||||
|
|
||||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
let FunctionSupport { progress, append_to_conversation } =
|
||||||
|
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||||
|
|
||||||
let mut response;
|
let mut response;
|
||||||
loop {
|
loop {
|
||||||
@ -300,7 +333,7 @@ async fn non_streamed_chat(
|
|||||||
|
|
||||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
|
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
|
.partition(|call| call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME);
|
||||||
|
|
||||||
chat_completion.messages.push(
|
chat_completion.messages.push(
|
||||||
ChatCompletionRequestAssistantMessageArgs::default()
|
ChatCompletionRequestAssistantMessageArgs::default()
|
||||||
@ -378,7 +411,8 @@ async fn streamed_chat(
|
|||||||
|
|
||||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
let FunctionSupport { progress, append_to_conversation } =
|
||||||
|
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||||
let _join_handle = Handle::current().spawn(async move {
|
let _join_handle = Handle::current().spawn(async move {
|
||||||
@ -395,21 +429,8 @@ async fn streamed_chat(
|
|||||||
let choice = &resp.choices[0];
|
let choice = &resp.choices[0];
|
||||||
finish_reason = choice.finish_reason;
|
finish_reason = choice.finish_reason;
|
||||||
|
|
||||||
#[allow(deprecated)]
|
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
|
||||||
let ChatCompletionStreamResponseDelta {
|
&choice.delta;
|
||||||
content,
|
|
||||||
// Using deprecated field but keeping for compatibility
|
|
||||||
function_call: _,
|
|
||||||
ref tool_calls,
|
|
||||||
role: _,
|
|
||||||
refusal: _,
|
|
||||||
} = &choice.delta;
|
|
||||||
|
|
||||||
if content.is_some() {
|
|
||||||
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match tool_calls {
|
match tool_calls {
|
||||||
Some(tool_calls) => {
|
Some(tool_calls) => {
|
||||||
@ -422,109 +443,195 @@ async fn streamed_chat(
|
|||||||
} = chunk;
|
} = chunk;
|
||||||
let FunctionCallStream { name, arguments } =
|
let FunctionCallStream { name, arguments } =
|
||||||
function.as_ref().unwrap();
|
function.as_ref().unwrap();
|
||||||
|
|
||||||
global_tool_calls
|
global_tool_calls
|
||||||
.entry(*index)
|
.entry(*index)
|
||||||
.and_modify(|call| call.append(arguments.as_ref().unwrap()))
|
.and_modify(|call| {
|
||||||
.or_insert_with(|| Call {
|
if call.is_internal() {
|
||||||
id: id.as_ref().unwrap().clone(),
|
call.append(arguments.as_ref().unwrap())
|
||||||
function_name: name.as_ref().unwrap().clone(),
|
}
|
||||||
arguments: arguments.as_ref().unwrap().clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None if !global_tool_calls.is_empty() => {
|
|
||||||
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
|
|
||||||
mem::take(&mut global_tool_calls)
|
|
||||||
.into_values()
|
|
||||||
.map(|call| ChatCompletionMessageToolCall {
|
|
||||||
id: call.id,
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: FunctionCall {
|
|
||||||
name: call.function_name,
|
|
||||||
arguments: call.arguments,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
.partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME);
|
.or_insert_with(|| {
|
||||||
|
if name.as_ref().map_or(false, |n| {
|
||||||
chat_completion.messages.push(
|
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||||
ChatCompletionRequestAssistantMessageArgs::default()
|
}) {
|
||||||
.tool_calls(meili_calls.clone())
|
Call::Internal {
|
||||||
.build()
|
id: id.as_ref().unwrap().clone(),
|
||||||
.unwrap()
|
function_name: name.as_ref().unwrap().clone(),
|
||||||
.into(),
|
arguments: arguments.as_ref().unwrap().clone(),
|
||||||
);
|
}
|
||||||
|
|
||||||
for call in meili_calls {
|
|
||||||
if let Err(SendError(_)) = tx.send(Event::Data(
|
|
||||||
sse::Data::new_json(json!({
|
|
||||||
"object": "chat.completion.tool.call",
|
|
||||||
"tool": call,
|
|
||||||
}))
|
|
||||||
.unwrap(),
|
|
||||||
))
|
|
||||||
.await {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = match serde_json::from_str(&call.function.arguments) {
|
|
||||||
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
|
|
||||||
&index_scheduler,
|
|
||||||
auth_ctrl.clone(),
|
|
||||||
&search_queue,
|
|
||||||
&auth_token,
|
|
||||||
index_uid,
|
|
||||||
q,
|
|
||||||
).await.map_err(|e| e.to_string()),
|
|
||||||
Err(err) => Err(err.to_string()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let is_error = result.is_err();
|
|
||||||
let text = match result {
|
|
||||||
Ok((_, text)) => text,
|
|
||||||
Err(err) => err,
|
|
||||||
};
|
|
||||||
|
|
||||||
let tool = ChatCompletionRequestToolMessage {
|
|
||||||
tool_call_id: call.id.clone(),
|
|
||||||
content: ChatCompletionRequestToolMessageContent::Text(
|
|
||||||
format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()),
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(SendError(_)) = tx.send(Event::Data(
|
|
||||||
sse::Data::new_json(json!({
|
|
||||||
"object": if is_error {
|
|
||||||
"chat.completion.tool.error"
|
|
||||||
} else {
|
} else {
|
||||||
"chat.completion.tool.output"
|
Call::External { _id: id.as_ref().unwrap().clone() }
|
||||||
},
|
}
|
||||||
"tool": ChatCompletionRequestToolMessage {
|
});
|
||||||
tool_call_id: call.id,
|
|
||||||
content: ChatCompletionRequestToolMessageContent::Text(
|
|
||||||
text,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
.unwrap(),
|
|
||||||
))
|
|
||||||
.await {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(tool));
|
if global_tool_calls.get(index).map_or(false, Call::is_external)
|
||||||
|
{
|
||||||
|
todo!("Support forwarding external tool calls");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
if !global_tool_calls.is_empty() {
|
||||||
|
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
||||||
|
mem::take(&mut global_tool_calls)
|
||||||
|
.into_values()
|
||||||
|
.flat_map(|call| match call {
|
||||||
|
Call::Internal {
|
||||||
|
id,
|
||||||
|
function_name: name,
|
||||||
|
arguments,
|
||||||
|
} => Some(ChatCompletionMessageToolCall {
|
||||||
|
id,
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: FunctionCall { name, arguments },
|
||||||
|
}),
|
||||||
|
Call::External { _id: _ } => None,
|
||||||
|
})
|
||||||
|
.partition(|call| {
|
||||||
|
call.function.name
|
||||||
|
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||||
|
});
|
||||||
|
|
||||||
|
chat_completion.messages.push(
|
||||||
|
ChatCompletionRequestAssistantMessageArgs::default()
|
||||||
|
.tool_calls(meili_calls.clone())
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
other_calls.is_empty(),
|
||||||
|
"We do not support external tool forwarding for now"
|
||||||
|
);
|
||||||
|
|
||||||
|
for call in meili_calls {
|
||||||
|
if progress {
|
||||||
|
let call = MeiliSearchProgress {
|
||||||
|
function_name: call.function.name.clone(),
|
||||||
|
function_arguments: call
|
||||||
|
.function
|
||||||
|
.arguments
|
||||||
|
.clone(),
|
||||||
|
};
|
||||||
|
let resp = call.create_response(resp.clone());
|
||||||
|
// Send the event of "we are doing a search"
|
||||||
|
if let Err(SendError(_)) = tx
|
||||||
|
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if append_to_conversation {
|
||||||
|
// Ask the front-end user to append this tool *call* to the conversation
|
||||||
|
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
|
||||||
|
ChatCompletionRequestAssistantMessage {
|
||||||
|
content: None,
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
audio: None,
|
||||||
|
tool_calls: Some(vec![
|
||||||
|
ChatCompletionMessageToolCall {
|
||||||
|
id: call.id.clone(),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: FunctionCall {
|
||||||
|
name: call.function.name.clone(),
|
||||||
|
arguments: call.function.arguments.clone(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
function_call: None,
|
||||||
|
}
|
||||||
|
));
|
||||||
|
let resp = call.create_response(resp.clone());
|
||||||
|
if let Err(SendError(_)) = tx
|
||||||
|
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let result =
|
||||||
|
match serde_json::from_str(&call.function.arguments) {
|
||||||
|
Ok(SearchInIndexParameters { index_uid, q }) => {
|
||||||
|
process_search_request(
|
||||||
|
&index_scheduler,
|
||||||
|
auth_ctrl.clone(),
|
||||||
|
&search_queue,
|
||||||
|
&auth_token,
|
||||||
|
index_uid,
|
||||||
|
q,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string())
|
||||||
|
}
|
||||||
|
Err(err) => Err(err.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = match result {
|
||||||
|
Ok((_, text)) => text,
|
||||||
|
Err(err) => err,
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
||||||
|
tool_call_id: call.id.clone(),
|
||||||
|
content: ChatCompletionRequestToolMessageContent::Text(
|
||||||
|
format!(
|
||||||
|
"{}\n\n{text}",
|
||||||
|
chat_settings
|
||||||
|
.prompts
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.pre_query
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
),
|
||||||
|
),
|
||||||
|
});
|
||||||
|
|
||||||
|
if append_to_conversation {
|
||||||
|
// Ask the front-end user to append this tool *output* to the conversation
|
||||||
|
let tool = MeiliAppendConversationMessage(tool.clone());
|
||||||
|
let resp = tool.create_response(resp.clone());
|
||||||
|
if let Err(SendError(_)) = tx
|
||||||
|
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chat_completion.messages.push(tool);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if let Err(SendError(_)) = tx
|
||||||
|
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
tracing::error!("{err:?}");
|
// tracing::error!("{err:?}");
|
||||||
if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&json!({
|
// if let Err(SendError(_)) = tx
|
||||||
"object": "chat.completion.error",
|
// .send(Event::Data(
|
||||||
"tool": err.to_string(),
|
// sse::Data::new_json(&json!({
|
||||||
})).unwrap())).await {
|
// "object": "chat.completion.error",
|
||||||
return;
|
// "tool": err.to_string(),
|
||||||
}
|
// }))
|
||||||
|
// .unwrap(),
|
||||||
|
// ))
|
||||||
|
// .await
|
||||||
|
// {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
break 'main;
|
break 'main;
|
||||||
}
|
}
|
||||||
@ -543,17 +650,106 @@ async fn streamed_chat(
|
|||||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
/// Give context about what Meilisearch is doing.
|
||||||
|
struct MeiliSearchProgress {
|
||||||
|
/// The name of the function we are executing.
|
||||||
|
pub function_name: String,
|
||||||
|
/// The arguments of the function we are executing, encoded in JSON.
|
||||||
|
pub function_arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MeiliSearchProgress {
|
||||||
|
fn create_response(
|
||||||
|
&self,
|
||||||
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
|
) -> CreateChatCompletionStreamResponse {
|
||||||
|
let call_text = serde_json::to_string(self).unwrap();
|
||||||
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
|
index: 0,
|
||||||
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: Some(FunctionCallStream {
|
||||||
|
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||||
|
arguments: Some(call_text),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
resp.choices[0] = ChatChoiceStream {
|
||||||
|
index: 0,
|
||||||
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
|
content: None,
|
||||||
|
function_call: None,
|
||||||
|
tool_calls: Some(vec![tool_call]),
|
||||||
|
role: Some(Role::Assistant),
|
||||||
|
refusal: None,
|
||||||
|
},
|
||||||
|
finish_reason: None,
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
|
||||||
|
|
||||||
|
impl MeiliAppendConversationMessage {
|
||||||
|
fn create_response(
|
||||||
|
&self,
|
||||||
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
|
) -> CreateChatCompletionStreamResponse {
|
||||||
|
let call_text = serde_json::to_string(&self.0).unwrap();
|
||||||
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
|
index: 0,
|
||||||
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: Some(FunctionCallStream {
|
||||||
|
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
|
||||||
|
arguments: Some(call_text),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
resp.choices[0] = ChatChoiceStream {
|
||||||
|
index: 0,
|
||||||
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
|
content: None,
|
||||||
|
function_call: None,
|
||||||
|
tool_calls: Some(vec![tool_call]),
|
||||||
|
role: Some(Role::Assistant),
|
||||||
|
refusal: None,
|
||||||
|
},
|
||||||
|
finish_reason: None,
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The structure used to aggregate the function calls to make.
|
/// The structure used to aggregate the function calls to make.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Call {
|
enum Call {
|
||||||
id: String,
|
/// Tool calls to tools that must be managed by Meilisearch internally.
|
||||||
function_name: String,
|
/// Typically the search functions.
|
||||||
arguments: String,
|
Internal { id: String, function_name: String, arguments: String },
|
||||||
|
/// Tool calls that we track but only to know that its not our functions.
|
||||||
|
/// We return the function calls as-is to the end-user.
|
||||||
|
External { _id: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Call {
|
impl Call {
|
||||||
fn append(&mut self, arguments: &str) {
|
fn is_internal(&self) -> bool {
|
||||||
self.arguments.push_str(arguments);
|
matches!(self, Call::Internal { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_external(&self) -> bool {
|
||||||
|
matches!(self, Call::External { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append(&mut self, more: &str) {
|
||||||
|
match self {
|
||||||
|
Call::Internal { arguments, .. } => arguments.push_str(more),
|
||||||
|
Call::External { .. } => {
|
||||||
|
panic!("Cannot append argument chunks to an external function")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user