Factorize the code a bit more and support reporting errors

This commit is contained in:
Clément Renault 2025-05-27 18:07:29 +02:00 committed by Kerollmops
parent fa139ee601
commit 5d8cdb075b
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -2,13 +2,15 @@ use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Write as _;
use std::mem;
use std::ops::ControlFlow;
use std::sync::RwLock;
use std::time::Duration;
use actix_web::web::{self, Data};
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig;
use async_openai::config::{Config, OpenAIConfig};
use async_openai::error::OpenAIError;
use async_openai::types::{
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
@ -40,6 +42,7 @@ use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::Sender;
use super::settings::chat::{ChatPrompts, GlobalChatSettings};
use crate::error::MeilisearchHttpError;
@ -83,11 +86,11 @@ async fn chat(
if chat_completion.stream.unwrap_or(false) {
Either::Right(
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
)
} else {
Either::Left(
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
non_streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
)
}
}
@ -327,8 +330,8 @@ async fn process_search_request(
async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
req: HttpRequest,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> {
let filters = index_scheduler.filters();
@ -420,8 +423,8 @@ async fn non_streamed_chat(
async fn streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
req: HttpRequest,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<impl Responder, ResponseError> {
let filters = index_scheduler.filters();
@ -441,17 +444,60 @@ async fn streamed_chat(
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } =
let function_support =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10);
let tx = SseEventSender(tx);
let _join_handle = Handle::current().spawn(async move {
let client = Client::with_config(config.clone());
let mut global_tool_calls = HashMap::<u32, Call>::new();
let mut finish_reason = None;
// Limit the number of internal calls to satisfy the search requests of the LLM
'main: for _ in 0..20 {
for _ in 0..20 {
let output = run_conversation(
&index_scheduler,
&auth_ctrl,
&search_queue,
&auth_token,
&client,
&chat_settings,
&mut chat_completion,
&tx,
&mut global_tool_calls,
function_support,
);
match output.await {
Ok(ControlFlow::Continue(())) => (),
Ok(ControlFlow::Break(_finish_reason)) => break,
// If the connection is closed we must stop
Err(SendError(_)) => return,
}
}
let _ = tx.stop().await;
});
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
}
/// Updates the chat completion with the new messages, streams the LLM tokens,
/// and report progress and errors.
async fn run_conversation<C: Config>(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: &web::Data<AuthController>,
search_queue: &web::Data<SearchQueue>,
auth_token: &str,
client: &Client<C>,
chat_settings: &GlobalChatSettings,
chat_completion: &mut CreateChatCompletionRequest,
tx: &SseEventSender,
global_tool_calls: &mut HashMap<u32, Call>,
function_support: FunctionSupport,
) -> Result<ControlFlow<Option<FinishReason>, ()>, SendError<Event>> {
let mut finish_reason = None;
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
while let Some(result) = response.next().await {
match result {
@ -459,8 +505,7 @@ async fn streamed_chat(
let choice = &resp.choices[0];
finish_reason = choice.finish_reason;
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
&choice.delta;
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta;
match tool_calls {
Some(tool_calls) => {
@ -471,8 +516,7 @@ async fn streamed_chat(
r#type: _,
function,
} = chunk;
let FunctionCallStream { name, arguments } =
function.as_ref().unwrap();
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
global_tool_calls
.entry(*index)
@ -482,21 +526,21 @@ async fn streamed_chat(
}
})
.or_insert_with(|| {
if name.as_ref().map_or(false, |n| {
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
}) {
if name
.as_ref()
.map_or(false, |n| n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
{
Call::Internal {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
}
} else {
Call::External { _id: id.as_ref().unwrap().clone() }
Call::External
}
});
if global_tool_calls.get(index).map_or(false, Call::is_external)
{
if global_tool_calls.get(index).map_or(false, Call::is_external) {
todo!("Support forwarding external tool calls");
}
}
@ -504,23 +548,20 @@ async fn streamed_chat(
None => {
if !global_tool_calls.is_empty() {
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls)
mem::take(global_tool_calls)
.into_values()
.flat_map(|call| match call {
Call::Internal {
id,
function_name: name,
arguments,
} => Some(ChatCompletionMessageToolCall {
Call::Internal { id, function_name: name, arguments } => {
Some(ChatCompletionMessageToolCall {
id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { name, arguments },
}),
Call::External { _id: _ } => None,
})
}
Call::External => None,
})
.partition(|call| {
call.function.name
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
});
chat_completion.messages.push(
@ -536,60 +577,76 @@ async fn streamed_chat(
"We do not support external tool forwarding for now"
);
handle_meili_tools(
&index_scheduler,
&auth_ctrl,
&search_queue,
&auth_token,
chat_settings,
tx,
meili_calls,
chat_completion,
&resp,
function_support,
)
.await?;
} else {
tx.forward_response(&resp).await?;
}
}
}
}
Err(err) => {
if function_support.report_errors {
tx.report_error(err).await?;
}
return Ok(ControlFlow::Break(None));
}
}
}
// We must stop if the finish reason is not something we can solve with Meilisearch
match finish_reason {
Some(FinishReason::ToolCalls) => Ok(ControlFlow::Continue(())),
otherwise => Ok(ControlFlow::Break(otherwise)),
}
}
async fn handle_meili_tools(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: &web::Data<AuthController>,
search_queue: &web::Data<SearchQueue>,
auth_token: &str,
chat_settings: &GlobalChatSettings,
tx: &SseEventSender,
meili_calls: Vec<ChatCompletionMessageToolCall>,
chat_completion: &mut CreateChatCompletionRequest,
resp: &CreateChatCompletionStreamResponse,
FunctionSupport { report_progress, report_sources, append_to_conversation, .. }: FunctionSupport,
) -> Result<(), SendError<Event>> {
for call in meili_calls {
if report_progress {
let call = MeiliSearchProgress {
call_id: call.id.to_string(),
function_name: call.function.name.clone(),
function_arguments: call
.function
.arguments
.clone(),
};
let resp = call.create_response(resp.clone());
// Send the event of "we are doing a search"
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
tx.report_search_progress(
resp.clone(),
&call.id,
&call.function.name,
&call.function.arguments,
)
.await?;
}
if append_to_conversation {
// Ask the front-end user to append this tool *call* to the conversation
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: Some(vec![
ChatCompletionMessageToolCall {
id: call.id.clone(),
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall {
name: call.function.name.clone(),
arguments: call.function.arguments.clone(),
},
},
]),
function_call: None,
}
));
let resp = call.create_response(resp.clone());
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
tx.append_tool_call_conversation_message(
resp.clone(),
call.id.clone(),
call.function.name.clone(),
call.function.arguments.clone(),
)
.await?;
}
let result =
match serde_json::from_str(&call.function.arguments) {
Ok(SearchInIndexParameters { index_uid, q }) => {
process_search_request(
let result = match serde_json::from_str(&call.function.arguments) {
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
@ -598,197 +655,74 @@ async fn streamed_chat(
q,
)
.await
.map_err(|e| e.to_string())
}
.map_err(|e| e.to_string()),
Err(err) => Err(err.to_string()),
};
let text = match result {
Ok((_index, documents, text)) => {
if report_sources {
let call = MeiliSearchSources {
call_id: call.id.to_string(),
sources: documents,
};
let resp = call.create_response(resp.clone());
// Send the event of "we are doing a search"
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
tx.report_sources(resp.clone(), &call.id, &documents).await?;
}
text
},
}
Err(err) => err,
};
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(
format!(
content: ChatCompletionRequestToolMessageContent::Text(format!(
"{}\n\n{text}",
chat_settings
.prompts
.as_ref()
.unwrap()
.pre_query
.as_ref()
.unwrap()
),
),
chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()
)),
});
if append_to_conversation {
// Ask the front-end user to append this tool *output* to the conversation
let tool = MeiliAppendConversationMessage(tool.clone());
let resp = tool.create_response(resp.clone());
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
tx.append_conversation_message(resp.clone(), &tool).await?;
}
chat_completion.messages.push(tool);
}
} else {
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
}
}
}
Err(err) => {
// tracing::error!("{err:?}");
// if let Err(SendError(_)) = tx
// .send(Event::Data(
// sse::Data::new_json(&json!({
// "object": "chat.completion.error",
// "tool": err.to_string(),
// }))
// .unwrap(),
// ))
// .await
// {
// return;
// }
break 'main;
}
}
Ok(())
}
// We must stop if the finish reason is not something we can solve with Meilisearch
if finish_reason.map_or(true, |fr| fr != FinishReason::ToolCalls) {
break;
}
}
pub struct SseEventSender(Sender<Event>);
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
impl SseEventSender {
/// Ask the front-end user to append this tool *call* to the conversation
pub async fn append_tool_call_conversation_message(
&self,
resp: CreateChatCompletionStreamResponse,
call_id: String,
function_name: String,
function_arguments: String,
) -> Result<(), SendError<Event>> {
let message =
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: Some(vec![ChatCompletionMessageToolCall {
id: call_id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { name: function_name, arguments: function_arguments },
}]),
function_call: None,
});
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
self.append_conversation_message(resp, &message).await
}
#[derive(Debug, Clone, Serialize)]
/// Provides information about the current Meilisearch search operation.
struct MeiliSearchProgress {
/// The call ID to track the sources of the search.
pub call_id: String,
/// The name of the function we are executing.
pub function_name: String,
/// The arguments of the function we are executing, encoded in JSON.
pub function_arguments: String,
}
impl MeiliSearchProgress {
fn create_response(
/// Ask the front-end user to append this tool to the conversation
pub async fn append_conversation_message(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(self).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
#[derive(Debug, Clone, Serialize)]
/// Provides sources of the search.
struct MeiliSearchSources {
/// The call ID to track the original search associated to those sources.
pub call_id: String,
/// The documents associated with the search (call_id).
/// Only the displayed attributes of the documents are returned.
pub sources: Vec<Document>,
}
impl MeiliSearchSources {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(self).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
impl MeiliAppendConversationMessage {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(&self.0).unwrap();
message: &ChatCompletionRequestMessage,
) -> Result<(), SendError<Event>> {
let call_text = serde_json::to_string(message).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
@ -798,6 +732,7 @@ impl MeiliAppendConversationMessage {
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
@ -810,7 +745,132 @@ impl MeiliAppendConversationMessage {
finish_reason: None,
logprobs: None,
};
resp
self.send_json(&resp).await
}
pub async fn report_search_progress(
&self,
mut resp: CreateChatCompletionStreamResponse,
call_id: &str,
function_name: &str,
function_arguments: &str,
) -> Result<(), SendError<Event>> {
#[derive(Debug, Clone, Serialize)]
/// Provides information about the current Meilisearch search operation.
struct MeiliSearchProgress<'a> {
/// The call ID to track the sources of the search.
call_id: &'a str,
/// The name of the function we are executing.
function_name: &'a str,
/// The arguments of the function we are executing, encoded in JSON.
function_arguments: &'a str,
}
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
let call_text = serde_json::to_string(&progress).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
self.send_json(&resp).await
}
pub async fn report_sources(
&self,
mut resp: CreateChatCompletionStreamResponse,
call_id: &str,
documents: &[Document],
) -> Result<(), SendError<Event>> {
#[derive(Debug, Clone, Serialize)]
/// Provides sources of the search.
struct MeiliSearchSources<'a> {
/// The call ID to track the original search associated to those sources.
call_id: &'a str,
/// The documents associated with the search (call_id).
/// Only the displayed attributes of the documents are returned.
sources: &'a [Document],
}
let sources = MeiliSearchSources { call_id, sources: documents };
let call_text = serde_json::to_string(&sources).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
self.send_json(&resp).await
}
pub async fn report_error(&self, error: OpenAIError) -> Result<(), SendError<Event>> {
tracing::error!("OpenAI Error: {}", error);
let (error_code, message) = match error {
OpenAIError::Reqwest(e) => ("internal_reqwest_error", e.to_string()),
OpenAIError::ApiError(api_error) => ("llm_api_issue", api_error.to_string()),
OpenAIError::JSONDeserialize(error) => ("internal_json_deserialize", error.to_string()),
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
OpenAIError::StreamError(error) => ("llm_api_stream_error", error.to_string()),
OpenAIError::InvalidArgument(error) => ("internal_invalid_argument", error.to_string()),
};
self.send_json(&json!({
"error_code": error_code,
"message": message,
}))
.await
}
pub async fn forward_response(
&self,
resp: &CreateChatCompletionStreamResponse,
) -> Result<(), SendError<Event>> {
self.send_json(resp).await
}
pub async fn stop(self) -> Result<(), SendError<Event>> {
self.0.send(Event::Data(sse::Data::new("[DONE]"))).await
}
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
}
}
@ -822,7 +882,7 @@ enum Call {
Internal { id: String, function_name: String, arguments: String },
/// Tool calls that we track but only to know that its not our functions.
/// We return the function calls as-is to the end-user.
External { _id: String },
External,
}
impl Call {