diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 5de5a9367..9de15f751 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -2,13 +2,15 @@ use std::cell::RefCell; use std::collections::HashMap; use std::fmt::Write as _; use std::mem; +use std::ops::ControlFlow; use std::sync::RwLock; use std::time::Duration; use actix_web::web::{self, Data}; use actix_web::{Either, HttpRequest, HttpResponse, Responder}; use actix_web_lab::sse::{self, Event, Sse}; -use async_openai::config::OpenAIConfig; +use async_openai::config::{Config, OpenAIConfig}; +use async_openai::error::OpenAIError; use async_openai::types::{ ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs, @@ -40,6 +42,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::runtime::Handle; use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::Sender; use super::settings::chat::{ChatPrompts, GlobalChatSettings}; use crate::error::MeilisearchHttpError; @@ -83,11 +86,11 @@ async fn chat( if chat_completion.stream.unwrap_or(false) { Either::Right( - streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await, + streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await, ) } else { Either::Left( - non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await, + non_streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await, ) } } @@ -327,8 +330,8 @@ async fn process_search_request( async fn non_streamed_chat( index_scheduler: GuardedData, Data>, auth_ctrl: web::Data, - req: HttpRequest, search_queue: web::Data, + req: HttpRequest, mut chat_completion: CreateChatCompletionRequest, ) -> Result { let filters = index_scheduler.filters(); @@ -420,8 +423,8 @@ async fn non_streamed_chat( async fn streamed_chat( index_scheduler: GuardedData, Data>, auth_ctrl: web::Data, - req: HttpRequest, search_queue: web::Data, + req: HttpRequest, mut chat_completion: CreateChatCompletionRequest, ) -> Result { let filters = index_scheduler.filters(); @@ -441,354 +444,285 @@ async fn streamed_chat( let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); - let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } = + let function_support = setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; let (tx, rx) = tokio::sync::mpsc::channel(10); + let tx = SseEventSender(tx); let _join_handle = Handle::current().spawn(async move { let client = Client::with_config(config.clone()); let mut global_tool_calls = HashMap::::new(); - let mut finish_reason = None; // Limit the number of internal calls to satisfy the search requests of the LLM - 'main: for _ in 0..20 { - let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); - while let Some(result) = response.next().await { - match result { - Ok(resp) => { - let choice = &resp.choices[0]; - finish_reason = choice.finish_reason; + for _ in 0..20 { + let output = run_conversation( + &index_scheduler, + &auth_ctrl, + &search_queue, + &auth_token, + &client, + &chat_settings, + &mut chat_completion, + &tx, + &mut global_tool_calls, + function_support, + ); - let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = - &choice.delta; - - match tool_calls { - Some(tool_calls) => { - for chunk in tool_calls { - let ChatCompletionMessageToolCallChunk { - index, - id, - r#type: _, - function, - } = chunk; - let FunctionCallStream { name, arguments } = - function.as_ref().unwrap(); - - global_tool_calls - .entry(*index) - .and_modify(|call| { - if call.is_internal() { - call.append(arguments.as_ref().unwrap()) - } - }) - .or_insert_with(|| { - if name.as_ref().map_or(false, |n| { - n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME - }) { - Call::Internal { - id: id.as_ref().unwrap().clone(), - function_name: name.as_ref().unwrap().clone(), - arguments: arguments.as_ref().unwrap().clone(), - } - } else { - Call::External { _id: id.as_ref().unwrap().clone() } - } - }); - - if global_tool_calls.get(index).map_or(false, Call::is_external) - { - todo!("Support forwarding external tool calls"); - } - } - } - None => { - if !global_tool_calls.is_empty() { - let (meili_calls, other_calls): (Vec<_>, Vec<_>) = - mem::take(&mut global_tool_calls) - .into_values() - .flat_map(|call| match call { - Call::Internal { - id, - function_name: name, - arguments, - } => Some(ChatCompletionMessageToolCall { - id, - r#type: Some(ChatCompletionToolType::Function), - function: FunctionCall { name, arguments }, - }), - Call::External { _id: _ } => None, - }) - .partition(|call| { - call.function.name - == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME - }); - - chat_completion.messages.push( - ChatCompletionRequestAssistantMessageArgs::default() - .tool_calls(meili_calls.clone()) - .build() - .unwrap() - .into(), - ); - - assert!( - other_calls.is_empty(), - "We do not support external tool forwarding for now" - ); - - for call in meili_calls { - if report_progress { - let call = MeiliSearchProgress { - call_id: call.id.to_string(), - function_name: call.function.name.clone(), - function_arguments: call - .function - .arguments - .clone(), - }; - let resp = call.create_response(resp.clone()); - // Send the event of "we are doing a search" - if let Err(SendError(_)) = tx - .send(Event::Data(sse::Data::new_json(&resp).unwrap())) - .await - { - return; - } - } - - if append_to_conversation { - // Ask the front-end user to append this tool *call* to the conversation - let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant( - ChatCompletionRequestAssistantMessage { - content: None, - refusal: None, - name: None, - audio: None, - tool_calls: Some(vec![ - ChatCompletionMessageToolCall { - id: call.id.clone(), - r#type: Some(ChatCompletionToolType::Function), - function: FunctionCall { - name: call.function.name.clone(), - arguments: call.function.arguments.clone(), - }, - }, - ]), - function_call: None, - } - )); - let resp = call.create_response(resp.clone()); - if let Err(SendError(_)) = tx - .send(Event::Data(sse::Data::new_json(&resp).unwrap())) - .await - { - return; - } - } - - let result = - match serde_json::from_str(&call.function.arguments) { - Ok(SearchInIndexParameters { index_uid, q }) => { - process_search_request( - &index_scheduler, - auth_ctrl.clone(), - &search_queue, - &auth_token, - index_uid, - q, - ) - .await - .map_err(|e| e.to_string()) - } - Err(err) => Err(err.to_string()), - }; - - let text = match result { - Ok((_index, documents, text)) => { - if report_sources { - let call = MeiliSearchSources { - call_id: call.id.to_string(), - sources: documents, - }; - let resp = call.create_response(resp.clone()); - // Send the event of "we are doing a search" - if let Err(SendError(_)) = tx - .send(Event::Data(sse::Data::new_json(&resp).unwrap())) - .await - { - return; - } - } - - text - }, - Err(err) => err, - }; - - let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage { - tool_call_id: call.id.clone(), - content: ChatCompletionRequestToolMessageContent::Text( - format!( - "{}\n\n{text}", - chat_settings - .prompts - .as_ref() - .unwrap() - .pre_query - .as_ref() - .unwrap() - ), - ), - }); - - if append_to_conversation { - // Ask the front-end user to append this tool *output* to the conversation - let tool = MeiliAppendConversationMessage(tool.clone()); - let resp = tool.create_response(resp.clone()); - if let Err(SendError(_)) = tx - .send(Event::Data(sse::Data::new_json(&resp).unwrap())) - .await - { - return; - } - } - - chat_completion.messages.push(tool); - } - } else { - if let Err(SendError(_)) = tx - .send(Event::Data(sse::Data::new_json(&resp).unwrap())) - .await - { - return; - } - } - } - } - } - Err(err) => { - // tracing::error!("{err:?}"); - // if let Err(SendError(_)) = tx - // .send(Event::Data( - // sse::Data::new_json(&json!({ - // "object": "chat.completion.error", - // "tool": err.to_string(), - // })) - // .unwrap(), - // )) - // .await - // { - // return; - // } - - break 'main; - } - } - } - - // We must stop if the finish reason is not something we can solve with Meilisearch - if finish_reason.map_or(true, |fr| fr != FinishReason::ToolCalls) { - break; + match output.await { + Ok(ControlFlow::Continue(())) => (), + Ok(ControlFlow::Break(_finish_reason)) => break, + // If the connection is closed we must stop + Err(SendError(_)) => return, } } - let _ = tx.send(Event::Data(sse::Data::new("[DONE]"))); + let _ = tx.stop().await; }); Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))) } -#[derive(Debug, Clone, Serialize)] -/// Provides information about the current Meilisearch search operation. -struct MeiliSearchProgress { - /// The call ID to track the sources of the search. - pub call_id: String, - /// The name of the function we are executing. - pub function_name: String, - /// The arguments of the function we are executing, encoded in JSON. - pub function_arguments: String, -} +/// Updates the chat completion with the new messages, streams the LLM tokens, +/// and report progress and errors. +async fn run_conversation( + index_scheduler: &GuardedData, Data>, + auth_ctrl: &web::Data, + search_queue: &web::Data, + auth_token: &str, + client: &Client, + chat_settings: &GlobalChatSettings, + chat_completion: &mut CreateChatCompletionRequest, + tx: &SseEventSender, + global_tool_calls: &mut HashMap, + function_support: FunctionSupport, +) -> Result, ()>, SendError> { + let mut finish_reason = None; -impl MeiliSearchProgress { - fn create_response( - &self, - mut resp: CreateChatCompletionStreamResponse, - ) -> CreateChatCompletionStreamResponse { - let call_text = serde_json::to_string(self).unwrap(); - let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, - id: Some(uuid::Uuid::new_v4().to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()), - arguments: Some(call_text), - }), - }; - resp.choices[0] = ChatChoiceStream { - index: 0, - delta: ChatCompletionStreamResponseDelta { - content: None, - function_call: None, - tool_calls: Some(vec![tool_call]), - role: Some(Role::Assistant), - refusal: None, - }, - finish_reason: None, - logprobs: None, - }; - resp + let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); + while let Some(result) = response.next().await { + match result { + Ok(resp) => { + let choice = &resp.choices[0]; + finish_reason = choice.finish_reason; + + let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta; + + match tool_calls { + Some(tool_calls) => { + for chunk in tool_calls { + let ChatCompletionMessageToolCallChunk { + index, + id, + r#type: _, + function, + } = chunk; + let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); + + global_tool_calls + .entry(*index) + .and_modify(|call| { + if call.is_internal() { + call.append(arguments.as_ref().unwrap()) + } + }) + .or_insert_with(|| { + if name + .as_ref() + .map_or(false, |n| n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) + { + Call::Internal { + id: id.as_ref().unwrap().clone(), + function_name: name.as_ref().unwrap().clone(), + arguments: arguments.as_ref().unwrap().clone(), + } + } else { + Call::External + } + }); + + if global_tool_calls.get(index).map_or(false, Call::is_external) { + todo!("Support forwarding external tool calls"); + } + } + } + None => { + if !global_tool_calls.is_empty() { + let (meili_calls, other_calls): (Vec<_>, Vec<_>) = + mem::take(global_tool_calls) + .into_values() + .flat_map(|call| match call { + Call::Internal { id, function_name: name, arguments } => { + Some(ChatCompletionMessageToolCall { + id, + r#type: Some(ChatCompletionToolType::Function), + function: FunctionCall { name, arguments }, + }) + } + Call::External => None, + }) + .partition(|call| { + call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME + }); + + chat_completion.messages.push( + ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(meili_calls.clone()) + .build() + .unwrap() + .into(), + ); + + assert!( + other_calls.is_empty(), + "We do not support external tool forwarding for now" + ); + + handle_meili_tools( + &index_scheduler, + &auth_ctrl, + &search_queue, + &auth_token, + chat_settings, + tx, + meili_calls, + chat_completion, + &resp, + function_support, + ) + .await?; + } else { + tx.forward_response(&resp).await?; + } + } + } + } + Err(err) => { + if function_support.report_errors { + tx.report_error(err).await?; + } + return Ok(ControlFlow::Break(None)); + } + } + } + + // We must stop if the finish reason is not something we can solve with Meilisearch + match finish_reason { + Some(FinishReason::ToolCalls) => Ok(ControlFlow::Continue(())), + otherwise => Ok(ControlFlow::Break(otherwise)), } } -#[derive(Debug, Clone, Serialize)] -/// Provides sources of the search. -struct MeiliSearchSources { - /// The call ID to track the original search associated to those sources. - pub call_id: String, - /// The documents associated with the search (call_id). - /// Only the displayed attributes of the documents are returned. - pub sources: Vec, -} +async fn handle_meili_tools( + index_scheduler: &GuardedData, Data>, + auth_ctrl: &web::Data, + search_queue: &web::Data, + auth_token: &str, + chat_settings: &GlobalChatSettings, + tx: &SseEventSender, + meili_calls: Vec, + chat_completion: &mut CreateChatCompletionRequest, + resp: &CreateChatCompletionStreamResponse, + FunctionSupport { report_progress, report_sources, append_to_conversation, .. }: FunctionSupport, +) -> Result<(), SendError> { + for call in meili_calls { + if report_progress { + tx.report_search_progress( + resp.clone(), + &call.id, + &call.function.name, + &call.function.arguments, + ) + .await?; + } -impl MeiliSearchSources { - fn create_response( - &self, - mut resp: CreateChatCompletionStreamResponse, - ) -> CreateChatCompletionStreamResponse { - let call_text = serde_json::to_string(self).unwrap(); - let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, - id: Some(uuid::Uuid::new_v4().to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()), - arguments: Some(call_text), - }), + if append_to_conversation { + tx.append_tool_call_conversation_message( + resp.clone(), + call.id.clone(), + call.function.name.clone(), + call.function.arguments.clone(), + ) + .await?; + } + + let result = match serde_json::from_str(&call.function.arguments) { + Ok(SearchInIndexParameters { index_uid, q }) => process_search_request( + &index_scheduler, + auth_ctrl.clone(), + &search_queue, + &auth_token, + index_uid, + q, + ) + .await + .map_err(|e| e.to_string()), + Err(err) => Err(err.to_string()), }; - resp.choices[0] = ChatChoiceStream { - index: 0, - delta: ChatCompletionStreamResponseDelta { - content: None, - function_call: None, - tool_calls: Some(vec![tool_call]), - role: Some(Role::Assistant), - refusal: None, - }, - finish_reason: None, - logprobs: None, + + let text = match result { + Ok((_index, documents, text)) => { + if report_sources { + tx.report_sources(resp.clone(), &call.id, &documents).await?; + } + + text + } + Err(err) => err, }; - resp + + let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage { + tool_call_id: call.id.clone(), + content: ChatCompletionRequestToolMessageContent::Text(format!( + "{}\n\n{text}", + chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap() + )), + }); + + if append_to_conversation { + tx.append_conversation_message(resp.clone(), &tool).await?; + } + + chat_completion.messages.push(tool); } + + Ok(()) } -struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage); +pub struct SseEventSender(Sender); -impl MeiliAppendConversationMessage { - fn create_response( +impl SseEventSender { + /// Ask the front-end user to append this tool *call* to the conversation + pub async fn append_tool_call_conversation_message( + &self, + resp: CreateChatCompletionStreamResponse, + call_id: String, + function_name: String, + function_arguments: String, + ) -> Result<(), SendError> { + let message = + ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { + content: None, + refusal: None, + name: None, + audio: None, + tool_calls: Some(vec![ChatCompletionMessageToolCall { + id: call_id, + r#type: Some(ChatCompletionToolType::Function), + function: FunctionCall { name: function_name, arguments: function_arguments }, + }]), + function_call: None, + }); + + self.append_conversation_message(resp, &message).await + } + + /// Ask the front-end user to append this tool to the conversation + pub async fn append_conversation_message( &self, mut resp: CreateChatCompletionStreamResponse, - ) -> CreateChatCompletionStreamResponse { - let call_text = serde_json::to_string(&self.0).unwrap(); + message: &ChatCompletionRequestMessage, + ) -> Result<(), SendError> { + let call_text = serde_json::to_string(message).unwrap(); let tool_call = ChatCompletionMessageToolCallChunk { index: 0, id: Some(uuid::Uuid::new_v4().to_string()), @@ -798,6 +732,7 @@ impl MeiliAppendConversationMessage { arguments: Some(call_text), }), }; + resp.choices[0] = ChatChoiceStream { index: 0, delta: ChatCompletionStreamResponseDelta { @@ -810,7 +745,132 @@ impl MeiliAppendConversationMessage { finish_reason: None, logprobs: None, }; - resp + + self.send_json(&resp).await + } + + pub async fn report_search_progress( + &self, + mut resp: CreateChatCompletionStreamResponse, + call_id: &str, + function_name: &str, + function_arguments: &str, + ) -> Result<(), SendError> { + #[derive(Debug, Clone, Serialize)] + /// Provides information about the current Meilisearch search operation. + struct MeiliSearchProgress<'a> { + /// The call ID to track the sources of the search. + call_id: &'a str, + /// The name of the function we are executing. + function_name: &'a str, + /// The arguments of the function we are executing, encoded in JSON. + function_arguments: &'a str, + } + + let progress = MeiliSearchProgress { call_id, function_name, function_arguments }; + let call_text = serde_json::to_string(&progress).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()), + arguments: Some(call_text), + }), + }; + + resp.choices[0] = ChatChoiceStream { + index: 0, + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + + self.send_json(&resp).await + } + + pub async fn report_sources( + &self, + mut resp: CreateChatCompletionStreamResponse, + call_id: &str, + documents: &[Document], + ) -> Result<(), SendError> { + #[derive(Debug, Clone, Serialize)] + /// Provides sources of the search. + struct MeiliSearchSources<'a> { + /// The call ID to track the original search associated to those sources. + call_id: &'a str, + /// The documents associated with the search (call_id). + /// Only the displayed attributes of the documents are returned. + sources: &'a [Document], + } + + let sources = MeiliSearchSources { call_id, sources: documents }; + let call_text = serde_json::to_string(&sources).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()), + arguments: Some(call_text), + }), + }; + + resp.choices[0] = ChatChoiceStream { + index: 0, + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + + self.send_json(&resp).await + } + + pub async fn report_error(&self, error: OpenAIError) -> Result<(), SendError> { + tracing::error!("OpenAI Error: {}", error); + + let (error_code, message) = match error { + OpenAIError::Reqwest(e) => ("internal_reqwest_error", e.to_string()), + OpenAIError::ApiError(api_error) => ("llm_api_issue", api_error.to_string()), + OpenAIError::JSONDeserialize(error) => ("internal_json_deserialize", error.to_string()), + OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(), + OpenAIError::StreamError(error) => ("llm_api_stream_error", error.to_string()), + OpenAIError::InvalidArgument(error) => ("internal_invalid_argument", error.to_string()), + }; + + self.send_json(&json!({ + "error_code": error_code, + "message": message, + })) + .await + } + + pub async fn forward_response( + &self, + resp: &CreateChatCompletionStreamResponse, + ) -> Result<(), SendError> { + self.send_json(resp).await + } + + pub async fn stop(self) -> Result<(), SendError> { + self.0.send(Event::Data(sse::Data::new("[DONE]"))).await + } + + async fn send_json(&self, data: &S) -> Result<(), SendError> { + self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await } } @@ -822,7 +882,7 @@ enum Call { Internal { id: String, function_name: String, arguments: String }, /// Tool calls that we track but only to know that its not our functions. /// We return the function calls as-is to the end-user. - External { _id: String }, + External, } impl Call {