diff --git a/crates/meilisearch-types/src/features.rs b/crates/meilisearch-types/src/features.rs index 9bcd58347..210a0f0f9 100644 --- a/crates/meilisearch-types/src/features.rs +++ b/crates/meilisearch-types/src/features.rs @@ -102,16 +102,24 @@ pub enum ChatCompletionSource { VLlm, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SystemRole { + System, + Developer, +} + impl ChatCompletionSource { - pub fn system_role(&self, model: &str) -> &'static str { + pub fn system_role(&self, model: &str) -> SystemRole { + use ChatCompletionSource::*; + use SystemRole::*; match self { - ChatCompletionSource::OpenAi if Self::old_openai_model(model) => "system", - ChatCompletionSource::OpenAi => "developer", - ChatCompletionSource::AzureOpenAi if Self::old_openai_model(model) => "system", - ChatCompletionSource::AzureOpenAi => "developer", - ChatCompletionSource::Mistral => "system", - ChatCompletionSource::Gemini => "system", - ChatCompletionSource::VLlm => "system", + OpenAi if Self::old_openai_model(model) => System, + OpenAi => Developer, + AzureOpenAi if Self::old_openai_model(model) => System, + AzureOpenAi => Developer, + Mistral => System, + Gemini => System, + VLlm => System, } } diff --git a/crates/meilisearch/src/routes/chats/chat_completions.rs b/crates/meilisearch/src/routes/chats/chat_completions.rs index f588541fa..f4e42cae3 100644 --- a/crates/meilisearch/src/routes/chats/chat_completions.rs +++ b/crates/meilisearch/src/routes/chats/chat_completions.rs @@ -9,7 +9,8 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder}; use actix_web_lab::sse::{Event, Sse}; use async_openai::types::{ ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, - ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, + ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestDeveloperMessage, + ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType, @@ -24,6 +25,7 @@ use meilisearch_auth::AuthController; use meilisearch_types::error::{Code, ResponseError}; use meilisearch_types::features::{ ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings, + SystemRole, }; use meilisearch_types::keys::actions; use meilisearch_types::milli::index::ChatConfig; @@ -117,6 +119,7 @@ fn setup_search_tool( filters: &meilisearch_auth::AuthFilter, chat_completion: &mut CreateChatCompletionRequest, prompts: &DbChatCompletionPrompts, + system_role: SystemRole, ) -> Result { let tools = chat_completion.tools.get_or_insert_default(); if tools.iter().any(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) { @@ -195,13 +198,21 @@ fn setup_search_tool( tools.push(tool); - chat_completion.messages.insert( - 0, - ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { - content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()), - name: None, - }), - ); + let system_message = match system_role { + SystemRole::System => { + ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { + content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()), + name: None, + }) + } + SystemRole::Developer => { + ChatCompletionRequestMessage::Developer(ChatCompletionRequestDeveloperMessage { + content: ChatCompletionRequestDeveloperMessageContent::Text(prompts.system.clone()), + name: None, + }) + } + }; + chat_completion.messages.insert(0, system_message); Ok(FunctionSupport { report_progress, report_sources, append_to_conversation }) } @@ -315,9 +326,15 @@ async fn non_streamed_chat( let config = Config::new(&chat_settings); let client = Client::with_config(config); let auth_token = extract_token_from_request(&req)?.unwrap(); + let system_role = chat_settings.source.system_role(&chat_completion.model); // TODO do function support later - let _function_support = - setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; + let _function_support = setup_search_tool( + &index_scheduler, + filters, + &mut chat_completion, + &chat_settings.prompts, + system_role, + )?; let mut response; loop { @@ -408,8 +425,14 @@ async fn streamed_chat( let config = Config::new(&chat_settings); let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); - let function_support = - setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; + let system_role = chat_settings.source.system_role(&chat_completion.model); + let function_support = setup_search_tool( + &index_scheduler, + filters, + &mut chat_completion, + &chat_settings.prompts, + system_role, + )?; tracing::debug!("Conversation function support: {function_support:?}");