mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-06 04:05:37 +00:00
Factorize the code a bit more and support reporting errors
This commit is contained in:
parent
fa139ee601
commit
5d8cdb075b
@ -2,13 +2,15 @@ use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
use std::mem;
|
||||
use std::ops::ControlFlow;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::web::{self, Data};
|
||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{self, Event, Sse};
|
||||
use async_openai::config::OpenAIConfig;
|
||||
use async_openai::config::{Config, OpenAIConfig};
|
||||
use async_openai::error::OpenAIError;
|
||||
use async_openai::types::{
|
||||
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
|
||||
@ -40,6 +42,7 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use super::settings::chat::{ChatPrompts, GlobalChatSettings};
|
||||
use crate::error::MeilisearchHttpError;
|
||||
@ -83,11 +86,11 @@ async fn chat(
|
||||
|
||||
if chat_completion.stream.unwrap_or(false) {
|
||||
Either::Right(
|
||||
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
||||
streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
|
||||
)
|
||||
} else {
|
||||
Either::Left(
|
||||
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
||||
non_streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -327,8 +330,8 @@ async fn process_search_request(
|
||||
async fn non_streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
req: HttpRequest,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
let filters = index_scheduler.filters();
|
||||
@ -420,8 +423,8 @@ async fn non_streamed_chat(
|
||||
async fn streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
req: HttpRequest,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> Result<impl Responder, ResponseError> {
|
||||
let filters = index_scheduler.filters();
|
||||
@ -441,17 +444,60 @@ async fn streamed_chat(
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||
let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } =
|
||||
let function_support =
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let tx = SseEventSender(tx);
|
||||
let _join_handle = Handle::current().spawn(async move {
|
||||
let client = Client::with_config(config.clone());
|
||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||
let mut finish_reason = None;
|
||||
|
||||
// Limit the number of internal calls to satisfy the search requests of the LLM
|
||||
'main: for _ in 0..20 {
|
||||
for _ in 0..20 {
|
||||
let output = run_conversation(
|
||||
&index_scheduler,
|
||||
&auth_ctrl,
|
||||
&search_queue,
|
||||
&auth_token,
|
||||
&client,
|
||||
&chat_settings,
|
||||
&mut chat_completion,
|
||||
&tx,
|
||||
&mut global_tool_calls,
|
||||
function_support,
|
||||
);
|
||||
|
||||
match output.await {
|
||||
Ok(ControlFlow::Continue(())) => (),
|
||||
Ok(ControlFlow::Break(_finish_reason)) => break,
|
||||
// If the connection is closed we must stop
|
||||
Err(SendError(_)) => return,
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tx.stop().await;
|
||||
});
|
||||
|
||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||
}
|
||||
|
||||
/// Updates the chat completion with the new messages, streams the LLM tokens,
|
||||
/// and report progress and errors.
|
||||
async fn run_conversation<C: Config>(
|
||||
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: &web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
client: &Client<C>,
|
||||
chat_settings: &GlobalChatSettings,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
tx: &SseEventSender,
|
||||
global_tool_calls: &mut HashMap<u32, Call>,
|
||||
function_support: FunctionSupport,
|
||||
) -> Result<ControlFlow<Option<FinishReason>, ()>, SendError<Event>> {
|
||||
let mut finish_reason = None;
|
||||
|
||||
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||
while let Some(result) = response.next().await {
|
||||
match result {
|
||||
@ -459,8 +505,7 @@ async fn streamed_chat(
|
||||
let choice = &resp.choices[0];
|
||||
finish_reason = choice.finish_reason;
|
||||
|
||||
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
|
||||
&choice.delta;
|
||||
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta;
|
||||
|
||||
match tool_calls {
|
||||
Some(tool_calls) => {
|
||||
@ -471,8 +516,7 @@ async fn streamed_chat(
|
||||
r#type: _,
|
||||
function,
|
||||
} = chunk;
|
||||
let FunctionCallStream { name, arguments } =
|
||||
function.as_ref().unwrap();
|
||||
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
|
||||
|
||||
global_tool_calls
|
||||
.entry(*index)
|
||||
@ -482,21 +526,21 @@ async fn streamed_chat(
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| {
|
||||
if name.as_ref().map_or(false, |n| {
|
||||
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
}) {
|
||||
if name
|
||||
.as_ref()
|
||||
.map_or(false, |n| n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
{
|
||||
Call::Internal {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
}
|
||||
} else {
|
||||
Call::External { _id: id.as_ref().unwrap().clone() }
|
||||
Call::External
|
||||
}
|
||||
});
|
||||
|
||||
if global_tool_calls.get(index).map_or(false, Call::is_external)
|
||||
{
|
||||
if global_tool_calls.get(index).map_or(false, Call::is_external) {
|
||||
todo!("Support forwarding external tool calls");
|
||||
}
|
||||
}
|
||||
@ -504,23 +548,20 @@ async fn streamed_chat(
|
||||
None => {
|
||||
if !global_tool_calls.is_empty() {
|
||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(&mut global_tool_calls)
|
||||
mem::take(global_tool_calls)
|
||||
.into_values()
|
||||
.flat_map(|call| match call {
|
||||
Call::Internal {
|
||||
id,
|
||||
function_name: name,
|
||||
arguments,
|
||||
} => Some(ChatCompletionMessageToolCall {
|
||||
Call::Internal { id, function_name: name, arguments } => {
|
||||
Some(ChatCompletionMessageToolCall {
|
||||
id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name, arguments },
|
||||
}),
|
||||
Call::External { _id: _ } => None,
|
||||
})
|
||||
}
|
||||
Call::External => None,
|
||||
})
|
||||
.partition(|call| {
|
||||
call.function.name
|
||||
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
});
|
||||
|
||||
chat_completion.messages.push(
|
||||
@ -536,60 +577,76 @@ async fn streamed_chat(
|
||||
"We do not support external tool forwarding for now"
|
||||
);
|
||||
|
||||
handle_meili_tools(
|
||||
&index_scheduler,
|
||||
&auth_ctrl,
|
||||
&search_queue,
|
||||
&auth_token,
|
||||
chat_settings,
|
||||
tx,
|
||||
meili_calls,
|
||||
chat_completion,
|
||||
&resp,
|
||||
function_support,
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
tx.forward_response(&resp).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
if function_support.report_errors {
|
||||
tx.report_error(err).await?;
|
||||
}
|
||||
return Ok(ControlFlow::Break(None));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We must stop if the finish reason is not something we can solve with Meilisearch
|
||||
match finish_reason {
|
||||
Some(FinishReason::ToolCalls) => Ok(ControlFlow::Continue(())),
|
||||
otherwise => Ok(ControlFlow::Break(otherwise)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_meili_tools(
|
||||
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: &web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
chat_settings: &GlobalChatSettings,
|
||||
tx: &SseEventSender,
|
||||
meili_calls: Vec<ChatCompletionMessageToolCall>,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
resp: &CreateChatCompletionStreamResponse,
|
||||
FunctionSupport { report_progress, report_sources, append_to_conversation, .. }: FunctionSupport,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
for call in meili_calls {
|
||||
if report_progress {
|
||||
let call = MeiliSearchProgress {
|
||||
call_id: call.id.to_string(),
|
||||
function_name: call.function.name.clone(),
|
||||
function_arguments: call
|
||||
.function
|
||||
.arguments
|
||||
.clone(),
|
||||
};
|
||||
let resp = call.create_response(resp.clone());
|
||||
// Send the event of "we are doing a search"
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
tx.report_search_progress(
|
||||
resp.clone(),
|
||||
&call.id,
|
||||
&call.function.name,
|
||||
&call.function.arguments,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if append_to_conversation {
|
||||
// Ask the front-end user to append this tool *call* to the conversation
|
||||
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
|
||||
ChatCompletionRequestAssistantMessage {
|
||||
content: None,
|
||||
refusal: None,
|
||||
name: None,
|
||||
audio: None,
|
||||
tool_calls: Some(vec![
|
||||
ChatCompletionMessageToolCall {
|
||||
id: call.id.clone(),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall {
|
||||
name: call.function.name.clone(),
|
||||
arguments: call.function.arguments.clone(),
|
||||
},
|
||||
},
|
||||
]),
|
||||
function_call: None,
|
||||
}
|
||||
));
|
||||
let resp = call.create_response(resp.clone());
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
tx.append_tool_call_conversation_message(
|
||||
resp.clone(),
|
||||
call.id.clone(),
|
||||
call.function.name.clone(),
|
||||
call.function.arguments.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let result =
|
||||
match serde_json::from_str(&call.function.arguments) {
|
||||
Ok(SearchInIndexParameters { index_uid, q }) => {
|
||||
process_search_request(
|
||||
let result = match serde_json::from_str(&call.function.arguments) {
|
||||
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
@ -598,197 +655,74 @@ async fn streamed_chat(
|
||||
q,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
.map_err(|e| e.to_string()),
|
||||
Err(err) => Err(err.to_string()),
|
||||
};
|
||||
|
||||
let text = match result {
|
||||
Ok((_index, documents, text)) => {
|
||||
if report_sources {
|
||||
let call = MeiliSearchSources {
|
||||
call_id: call.id.to_string(),
|
||||
sources: documents,
|
||||
};
|
||||
let resp = call.create_response(resp.clone());
|
||||
// Send the event of "we are doing a search"
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
tx.report_sources(resp.clone(), &call.id, &documents).await?;
|
||||
}
|
||||
|
||||
text
|
||||
},
|
||||
}
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(
|
||||
format!(
|
||||
content: ChatCompletionRequestToolMessageContent::Text(format!(
|
||||
"{}\n\n{text}",
|
||||
chat_settings
|
||||
.prompts
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.pre_query
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
),
|
||||
),
|
||||
chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()
|
||||
)),
|
||||
});
|
||||
|
||||
if append_to_conversation {
|
||||
// Ask the front-end user to append this tool *output* to the conversation
|
||||
let tool = MeiliAppendConversationMessage(tool.clone());
|
||||
let resp = tool.create_response(resp.clone());
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
tx.append_conversation_message(resp.clone(), &tool).await?;
|
||||
}
|
||||
|
||||
chat_completion.messages.push(tool);
|
||||
}
|
||||
} else {
|
||||
if let Err(SendError(_)) = tx
|
||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
// tracing::error!("{err:?}");
|
||||
// if let Err(SendError(_)) = tx
|
||||
// .send(Event::Data(
|
||||
// sse::Data::new_json(&json!({
|
||||
// "object": "chat.completion.error",
|
||||
// "tool": err.to_string(),
|
||||
// }))
|
||||
// .unwrap(),
|
||||
// ))
|
||||
// .await
|
||||
// {
|
||||
// return;
|
||||
// }
|
||||
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// We must stop if the finish reason is not something we can solve with Meilisearch
|
||||
if finish_reason.map_or(true, |fr| fr != FinishReason::ToolCalls) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
pub struct SseEventSender(Sender<Event>);
|
||||
|
||||
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
|
||||
impl SseEventSender {
|
||||
/// Ask the front-end user to append this tool *call* to the conversation
|
||||
pub async fn append_tool_call_conversation_message(
|
||||
&self,
|
||||
resp: CreateChatCompletionStreamResponse,
|
||||
call_id: String,
|
||||
function_name: String,
|
||||
function_arguments: String,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
let message =
|
||||
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
||||
content: None,
|
||||
refusal: None,
|
||||
name: None,
|
||||
audio: None,
|
||||
tool_calls: Some(vec![ChatCompletionMessageToolCall {
|
||||
id: call_id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name: function_name, arguments: function_arguments },
|
||||
}]),
|
||||
function_call: None,
|
||||
});
|
||||
|
||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||
self.append_conversation_message(resp, &message).await
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides information about the current Meilisearch search operation.
|
||||
struct MeiliSearchProgress {
|
||||
/// The call ID to track the sources of the search.
|
||||
pub call_id: String,
|
||||
/// The name of the function we are executing.
|
||||
pub function_name: String,
|
||||
/// The arguments of the function we are executing, encoded in JSON.
|
||||
pub function_arguments: String,
|
||||
}
|
||||
|
||||
impl MeiliSearchProgress {
|
||||
fn create_response(
|
||||
/// Ask the front-end user to append this tool to the conversation
|
||||
pub async fn append_conversation_message(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
) -> CreateChatCompletionStreamResponse {
|
||||
let call_text = serde_json::to_string(self).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
resp
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides sources of the search.
|
||||
struct MeiliSearchSources {
|
||||
/// The call ID to track the original search associated to those sources.
|
||||
pub call_id: String,
|
||||
/// The documents associated with the search (call_id).
|
||||
/// Only the displayed attributes of the documents are returned.
|
||||
pub sources: Vec<Document>,
|
||||
}
|
||||
|
||||
impl MeiliSearchSources {
|
||||
fn create_response(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
) -> CreateChatCompletionStreamResponse {
|
||||
let call_text = serde_json::to_string(self).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
resp
|
||||
}
|
||||
}
|
||||
|
||||
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
|
||||
|
||||
impl MeiliAppendConversationMessage {
|
||||
fn create_response(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
) -> CreateChatCompletionStreamResponse {
|
||||
let call_text = serde_json::to_string(&self.0).unwrap();
|
||||
message: &ChatCompletionRequestMessage,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
let call_text = serde_json::to_string(message).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
@ -798,6 +732,7 @@ impl MeiliAppendConversationMessage {
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
@ -810,7 +745,132 @@ impl MeiliAppendConversationMessage {
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
resp
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_search_progress(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
function_name: &str,
|
||||
function_arguments: &str,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides information about the current Meilisearch search operation.
|
||||
struct MeiliSearchProgress<'a> {
|
||||
/// The call ID to track the sources of the search.
|
||||
call_id: &'a str,
|
||||
/// The name of the function we are executing.
|
||||
function_name: &'a str,
|
||||
/// The arguments of the function we are executing, encoded in JSON.
|
||||
function_arguments: &'a str,
|
||||
}
|
||||
|
||||
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
||||
let call_text = serde_json::to_string(&progress).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_sources(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
documents: &[Document],
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides sources of the search.
|
||||
struct MeiliSearchSources<'a> {
|
||||
/// The call ID to track the original search associated to those sources.
|
||||
call_id: &'a str,
|
||||
/// The documents associated with the search (call_id).
|
||||
/// Only the displayed attributes of the documents are returned.
|
||||
sources: &'a [Document],
|
||||
}
|
||||
|
||||
let sources = MeiliSearchSources { call_id, sources: documents };
|
||||
let call_text = serde_json::to_string(&sources).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_error(&self, error: OpenAIError) -> Result<(), SendError<Event>> {
|
||||
tracing::error!("OpenAI Error: {}", error);
|
||||
|
||||
let (error_code, message) = match error {
|
||||
OpenAIError::Reqwest(e) => ("internal_reqwest_error", e.to_string()),
|
||||
OpenAIError::ApiError(api_error) => ("llm_api_issue", api_error.to_string()),
|
||||
OpenAIError::JSONDeserialize(error) => ("internal_json_deserialize", error.to_string()),
|
||||
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
|
||||
OpenAIError::StreamError(error) => ("llm_api_stream_error", error.to_string()),
|
||||
OpenAIError::InvalidArgument(error) => ("internal_invalid_argument", error.to_string()),
|
||||
};
|
||||
|
||||
self.send_json(&json!({
|
||||
"error_code": error_code,
|
||||
"message": message,
|
||||
}))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn forward_response(
|
||||
&self,
|
||||
resp: &CreateChatCompletionStreamResponse,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
self.send_json(resp).await
|
||||
}
|
||||
|
||||
pub async fn stop(self) -> Result<(), SendError<Event>> {
|
||||
self.0.send(Event::Data(sse::Data::new("[DONE]"))).await
|
||||
}
|
||||
|
||||
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
|
||||
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
|
||||
}
|
||||
}
|
||||
|
||||
@ -822,7 +882,7 @@ enum Call {
|
||||
Internal { id: String, function_name: String, arguments: String },
|
||||
/// Tool calls that we track but only to know that its not our functions.
|
||||
/// We return the function calls as-is to the end-user.
|
||||
External { _id: String },
|
||||
External,
|
||||
}
|
||||
|
||||
impl Call {
|
||||
|
Loading…
x
Reference in New Issue
Block a user