diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index b8c3eee29..3db948eb8 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -10,13 +10,14 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder}; use actix_web_lab::sse::{self, Event, Sse}; use async_openai::config::OpenAIConfig; use async_openai::types::{ - ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, - ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, - ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, - ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, - ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType, - CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream, - FunctionObjectArgs, + ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, + ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs, + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage, + ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta, + ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, + CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream, + FunctionObjectArgs, Role, }; use async_openai::Client; use bumpalo::Bump; @@ -34,7 +35,7 @@ use meilisearch_types::milli::{ DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget, }; use meilisearch_types::Index; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::runtime::Handle; use tokio::sync::mpsc::error::SendError; @@ -52,7 +53,9 @@ use crate::search::{ }; use crate::search_queue::SearchQueue; -const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; +const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress"; +const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage"; +const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service(web::resource("/completions").route(web::post().to(chat))); @@ -86,18 +89,45 @@ async fn chat( } } +#[derive(Default, Debug, Clone, Copy)] +pub struct FunctionSupport { + /// Defines if we can call the _meiliSearchProgress function + /// to inform the front-end about what we are searching for. + progress: bool, + /// Defines if we can call the _meiliAppendConversationMessage + /// function to provide the messages to append into the conversation. + append_to_conversation: bool, +} + /// Setup search tool in chat completion request fn setup_search_tool( index_scheduler: &Data, filters: &meilisearch_auth::AuthFilter, chat_completion: &mut CreateChatCompletionRequest, prompts: &ChatPrompts, -) -> Result<(), ResponseError> { +) -> Result { let tools = chat_completion.tools.get_or_insert_default(); - if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() { - panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set"); + if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() { + panic!("{MEILI_SEARCH_IN_INDEX_FUNCTION_NAME} function already set"); } + // Remove internal tools used for front-end notifications as they should be hidden from the LLM. + let mut progress = false; + let mut append_to_conversation = false; + tools.retain(|tool| { + match tool.function.name.as_str() { + MEILI_SEARCH_PROGRESS_NAME => { + progress = true; + false + } + MEILI_APPEND_CONVERSATION_MESSAGE_NAME => { + append_to_conversation = true; + false + } + _ => true, // keep other tools + } + }); + let mut index_uids = Vec::new(); let mut function_description = prompts.search_description.clone().unwrap(); index_scheduler.try_for_each_index::<_, ()>(|name, index| { @@ -119,7 +149,7 @@ fn setup_search_tool( .r#type(ChatCompletionToolType::Function) .function( FunctionObjectArgs::default() - .name(SEARCH_IN_INDEX_FUNCTION_NAME) + .name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) .description(&function_description) .parameters(json!({ "type": "object", @@ -145,7 +175,9 @@ fn setup_search_tool( ) .build() .unwrap(); + tools.push(tool); + chat_completion.messages.insert( 0, ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { @@ -156,7 +188,7 @@ fn setup_search_tool( }), ); - Ok(()) + Ok(FunctionSupport { progress, append_to_conversation }) } /// Process search request and return formatted results @@ -287,7 +319,8 @@ async fn non_streamed_chat( let auth_token = extract_token_from_request(&req)?.unwrap(); let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); - setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; + let FunctionSupport { progress, append_to_conversation } = + setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; let mut response; loop { @@ -300,7 +333,7 @@ async fn non_streamed_chat( let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls .into_iter() - .partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME); + .partition(|call| call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME); chat_completion.messages.push( ChatCompletionRequestAssistantMessageArgs::default() @@ -378,7 +411,8 @@ async fn streamed_chat( let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); - setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; + let FunctionSupport { progress, append_to_conversation } = + setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; let (tx, rx) = tokio::sync::mpsc::channel(10); let _join_handle = Handle::current().spawn(async move { @@ -395,21 +429,8 @@ async fn streamed_chat( let choice = &resp.choices[0]; finish_reason = choice.finish_reason; - #[allow(deprecated)] - let ChatCompletionStreamResponseDelta { - content, - // Using deprecated field but keeping for compatibility - function_call: _, - ref tool_calls, - role: _, - refusal: _, - } = &choice.delta; - - if content.is_some() { - if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await { - return; - } - } + let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = + &choice.delta; match tool_calls { Some(tool_calls) => { @@ -422,109 +443,195 @@ async fn streamed_chat( } = chunk; let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); + global_tool_calls .entry(*index) - .and_modify(|call| call.append(arguments.as_ref().unwrap())) - .or_insert_with(|| Call { - id: id.as_ref().unwrap().clone(), - function_name: name.as_ref().unwrap().clone(), - arguments: arguments.as_ref().unwrap().clone(), - }); - } - } - None if !global_tool_calls.is_empty() => { - let (meili_calls, _other_calls): (Vec<_>, Vec<_>) = - mem::take(&mut global_tool_calls) - .into_values() - .map(|call| ChatCompletionMessageToolCall { - id: call.id, - r#type: Some(ChatCompletionToolType::Function), - function: FunctionCall { - name: call.function_name, - arguments: call.arguments, - }, + .and_modify(|call| { + if call.is_internal() { + call.append(arguments.as_ref().unwrap()) + } }) - .partition(|call| call.function.name == SEARCH_IN_INDEX_FUNCTION_NAME); - - chat_completion.messages.push( - ChatCompletionRequestAssistantMessageArgs::default() - .tool_calls(meili_calls.clone()) - .build() - .unwrap() - .into(), - ); - - for call in meili_calls { - if let Err(SendError(_)) = tx.send(Event::Data( - sse::Data::new_json(json!({ - "object": "chat.completion.tool.call", - "tool": call, - })) - .unwrap(), - )) - .await { - return; - } - - let result = match serde_json::from_str(&call.function.arguments) { - Ok(SearchInIndexParameters { index_uid, q }) => process_search_request( - &index_scheduler, - auth_ctrl.clone(), - &search_queue, - &auth_token, - index_uid, - q, - ).await.map_err(|e| e.to_string()), - Err(err) => Err(err.to_string()), - }; - - let is_error = result.is_err(); - let text = match result { - Ok((_, text)) => text, - Err(err) => err, - }; - - let tool = ChatCompletionRequestToolMessage { - tool_call_id: call.id.clone(), - content: ChatCompletionRequestToolMessageContent::Text( - format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()), - ), - }; - - if let Err(SendError(_)) = tx.send(Event::Data( - sse::Data::new_json(json!({ - "object": if is_error { - "chat.completion.tool.error" + .or_insert_with(|| { + if name.as_ref().map_or(false, |n| { + n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME + }) { + Call::Internal { + id: id.as_ref().unwrap().clone(), + function_name: name.as_ref().unwrap().clone(), + arguments: arguments.as_ref().unwrap().clone(), + } } else { - "chat.completion.tool.output" - }, - "tool": ChatCompletionRequestToolMessage { - tool_call_id: call.id, - content: ChatCompletionRequestToolMessageContent::Text( - text, - ), - }, - })) - .unwrap(), - )) - .await { - return; - } + Call::External { _id: id.as_ref().unwrap().clone() } + } + }); - chat_completion.messages.push(ChatCompletionRequestMessage::Tool(tool)); + if global_tool_calls.get(index).map_or(false, Call::is_external) + { + todo!("Support forwarding external tool calls"); + } + } + } + None => { + if !global_tool_calls.is_empty() { + let (meili_calls, other_calls): (Vec<_>, Vec<_>) = + mem::take(&mut global_tool_calls) + .into_values() + .flat_map(|call| match call { + Call::Internal { + id, + function_name: name, + arguments, + } => Some(ChatCompletionMessageToolCall { + id, + r#type: Some(ChatCompletionToolType::Function), + function: FunctionCall { name, arguments }, + }), + Call::External { _id: _ } => None, + }) + .partition(|call| { + call.function.name + == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME + }); + + chat_completion.messages.push( + ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(meili_calls.clone()) + .build() + .unwrap() + .into(), + ); + + assert!( + other_calls.is_empty(), + "We do not support external tool forwarding for now" + ); + + for call in meili_calls { + if progress { + let call = MeiliSearchProgress { + function_name: call.function.name.clone(), + function_arguments: call + .function + .arguments + .clone(), + }; + let resp = call.create_response(resp.clone()); + // Send the event of "we are doing a search" + if let Err(SendError(_)) = tx + .send(Event::Data(sse::Data::new_json(&resp).unwrap())) + .await + { + return; + } + } + + if append_to_conversation { + // Ask the front-end user to append this tool *call* to the conversation + let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant( + ChatCompletionRequestAssistantMessage { + content: None, + refusal: None, + name: None, + audio: None, + tool_calls: Some(vec![ + ChatCompletionMessageToolCall { + id: call.id.clone(), + r#type: Some(ChatCompletionToolType::Function), + function: FunctionCall { + name: call.function.name.clone(), + arguments: call.function.arguments.clone(), + }, + }, + ]), + function_call: None, + } + )); + let resp = call.create_response(resp.clone()); + if let Err(SendError(_)) = tx + .send(Event::Data(sse::Data::new_json(&resp).unwrap())) + .await + { + return; + } + } + + let result = + match serde_json::from_str(&call.function.arguments) { + Ok(SearchInIndexParameters { index_uid, q }) => { + process_search_request( + &index_scheduler, + auth_ctrl.clone(), + &search_queue, + &auth_token, + index_uid, + q, + ) + .await + .map_err(|e| e.to_string()) + } + Err(err) => Err(err.to_string()), + }; + + let text = match result { + Ok((_, text)) => text, + Err(err) => err, + }; + + let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage { + tool_call_id: call.id.clone(), + content: ChatCompletionRequestToolMessageContent::Text( + format!( + "{}\n\n{text}", + chat_settings + .prompts + .as_ref() + .unwrap() + .pre_query + .as_ref() + .unwrap() + ), + ), + }); + + if append_to_conversation { + // Ask the front-end user to append this tool *output* to the conversation + let tool = MeiliAppendConversationMessage(tool.clone()); + let resp = tool.create_response(resp.clone()); + if let Err(SendError(_)) = tx + .send(Event::Data(sse::Data::new_json(&resp).unwrap())) + .await + { + return; + } + } + + chat_completion.messages.push(tool); + } + } else { + if let Err(SendError(_)) = tx + .send(Event::Data(sse::Data::new_json(&resp).unwrap())) + .await + { + return; + } } } - None => (), } } Err(err) => { - tracing::error!("{err:?}"); - if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&json!({ - "object": "chat.completion.error", - "tool": err.to_string(), - })).unwrap())).await { - return; - } + // tracing::error!("{err:?}"); + // if let Err(SendError(_)) = tx + // .send(Event::Data( + // sse::Data::new_json(&json!({ + // "object": "chat.completion.error", + // "tool": err.to_string(), + // })) + // .unwrap(), + // )) + // .await + // { + // return; + // } break 'main; } @@ -543,17 +650,106 @@ async fn streamed_chat( Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))) } +#[derive(Debug, Clone, Serialize)] +/// Give context about what Meilisearch is doing. +struct MeiliSearchProgress { + /// The name of the function we are executing. + pub function_name: String, + /// The arguments of the function we are executing, encoded in JSON. + pub function_arguments: String, +} + +impl MeiliSearchProgress { + fn create_response( + &self, + mut resp: CreateChatCompletionStreamResponse, + ) -> CreateChatCompletionStreamResponse { + let call_text = serde_json::to_string(self).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()), + arguments: Some(call_text), + }), + }; + resp.choices[0] = ChatChoiceStream { + index: 0, + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + resp + } +} + +struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage); + +impl MeiliAppendConversationMessage { + fn create_response( + &self, + mut resp: CreateChatCompletionStreamResponse, + ) -> CreateChatCompletionStreamResponse { + let call_text = serde_json::to_string(&self.0).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()), + arguments: Some(call_text), + }), + }; + resp.choices[0] = ChatChoiceStream { + index: 0, + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + resp + } +} + /// The structure used to aggregate the function calls to make. #[derive(Debug)] -struct Call { - id: String, - function_name: String, - arguments: String, +enum Call { + /// Tool calls to tools that must be managed by Meilisearch internally. + /// Typically the search functions. + Internal { id: String, function_name: String, arguments: String }, + /// Tool calls that we track but only to know that its not our functions. + /// We return the function calls as-is to the end-user. + External { _id: String }, } impl Call { - fn append(&mut self, arguments: &str) { - self.arguments.push_str(arguments); + fn is_internal(&self) -> bool { + matches!(self, Call::Internal { .. }) + } + + fn is_external(&self) -> bool { + matches!(self, Call::External { .. }) + } + + fn append(&mut self, more: &str) { + match self { + Call::Internal { arguments, .. } => arguments.push_str(more), + Call::External { .. } => { + panic!("Cannot append argument chunks to an external function") + } + } } }