Make sure to use the system prompt

Clément Renault 2025-06-06 12:32:40 +02:00
parent 70670c3be4
commit 717a026fdd
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
2 changed files with 51 additions and 20 deletions

View File

@@ -102,16 +102,24 @@ pub enum ChatCompletionSource {
     VLlm,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum SystemRole {
+    System,
+    Developer,
+}
+
 impl ChatCompletionSource {
-    pub fn system_role(&self, model: &str) -> &'static str {
+    pub fn system_role(&self, model: &str) -> SystemRole {
+        use ChatCompletionSource::*;
+        use SystemRole::*;
         match self {
-            ChatCompletionSource::OpenAi if Self::old_openai_model(model) => "system",
-            ChatCompletionSource::OpenAi => "developer",
-            ChatCompletionSource::AzureOpenAi if Self::old_openai_model(model) => "system",
-            ChatCompletionSource::AzureOpenAi => "developer",
-            ChatCompletionSource::Mistral => "system",
-            ChatCompletionSource::Gemini => "system",
-            ChatCompletionSource::VLlm => "system",
+            OpenAi if Self::old_openai_model(model) => System,
+            OpenAi => Developer,
+            AzureOpenAi if Self::old_openai_model(model) => System,
+            AzureOpenAi => Developer,
+            Mistral => System,
+            Gemini => System,
+            VLlm => System,
         }
     }
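Read on its own, this hunk turns the role into a typed value instead of a raw string. Below is a minimal, self-contained sketch of the resulting mapping; the real old_openai_model predicate lives elsewhere in Meilisearch, so the stub here (which treats every model as "new") is an assumption made only to keep the example runnable.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChatCompletionSource {
    OpenAi,
    AzureOpenAi,
    Mistral,
    Gemini,
    VLlm,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SystemRole {
    System,
    Developer,
}

impl ChatCompletionSource {
    // Stubbed predicate: in Meilisearch this decides which OpenAI models still
    // expect the legacy "system" role. Returning false for everything is an
    // assumption for this sketch only.
    fn old_openai_model(_model: &str) -> bool {
        false
    }

    fn system_role(&self, model: &str) -> SystemRole {
        use ChatCompletionSource::*;
        use SystemRole::*;
        match self {
            OpenAi if Self::old_openai_model(model) => System,
            OpenAi => Developer,
            AzureOpenAi if Self::old_openai_model(model) => System,
            AzureOpenAi => Developer,
            Mistral => System,
            Gemini => System,
            VLlm => System,
        }
    }
}

fn main() {
    // With the stub above, OpenAI-backed sources resolve to Developer,
    // while every other source keeps the plain System role.
    assert_eq!(ChatCompletionSource::OpenAi.system_role("gpt-4o"), SystemRole::Developer);
    assert_eq!(ChatCompletionSource::Mistral.system_role("mistral-large"), SystemRole::System);
}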

View File

@@ -9,7 +9,8 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder};
 use actix_web_lab::sse::{Event, Sse};
 use async_openai::types::{
     ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
-    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
+    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestDeveloperMessage,
+    ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage,
     ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
     ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
     ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
@@ -24,6 +25,7 @@ use meilisearch_auth::AuthController;
 use meilisearch_types::error::{Code, ResponseError};
 use meilisearch_types::features::{
     ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
+    SystemRole,
 };
 use meilisearch_types::keys::actions;
 use meilisearch_types::milli::index::ChatConfig;
@@ -117,6 +119,7 @@ fn setup_search_tool(
     filters: &meilisearch_auth::AuthFilter,
     chat_completion: &mut CreateChatCompletionRequest,
     prompts: &DbChatCompletionPrompts,
+    system_role: SystemRole,
 ) -> Result<FunctionSupport, ResponseError> {
     let tools = chat_completion.tools.get_or_insert_default();
     if tools.iter().any(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) {
@@ -195,13 +198,21 @@ fn setup_search_tool(
     tools.push(tool);
 
-    chat_completion.messages.insert(
-        0,
-        ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
-            content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
-            name: None,
-        }),
-    );
+    let system_message = match system_role {
+        SystemRole::System => {
+            ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
+                content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
+                name: None,
+            })
+        }
+        SystemRole::Developer => {
+            ChatCompletionRequestMessage::Developer(ChatCompletionRequestDeveloperMessage {
+                content: ChatCompletionRequestDeveloperMessageContent::Text(prompts.system.clone()),
+                name: None,
+            })
+        }
+    };
+    chat_completion.messages.insert(0, system_message);
 
     Ok(FunctionSupport { report_progress, report_sources, append_to_conversation })
 }
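Taken out of the handler, the message selection amounts to a small constructor over the two async_openai message variants shown in the hunk above. The helper below is a sketch, not the actual Meilisearch code; its name and free-standing signature are illustrative, while the types and struct fields are the ones appearing in the diff.

use async_openai::types::{
    ChatCompletionRequestDeveloperMessage, ChatCompletionRequestDeveloperMessageContent,
    ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
    ChatCompletionRequestSystemMessageContent,
};
// SystemRole comes from meilisearch-types, as added in the import hunk above.
use meilisearch_types::features::SystemRole;

// Hypothetical helper: wrap the stored system prompt in the message variant
// matching the role that the target model expects.
fn instruction_message(role: SystemRole, system_prompt: &str) -> ChatCompletionRequestMessage {
    match role {
        SystemRole::System => {
            ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                content: ChatCompletionRequestSystemMessageContent::Text(system_prompt.to_string()),
                name: None,
            })
        }
        SystemRole::Developer => {
            ChatCompletionRequestMessage::Developer(ChatCompletionRequestDeveloperMessage {
                content: ChatCompletionRequestDeveloperMessageContent::Text(system_prompt.to_string()),
                name: None,
            })
        }
    }
}

Inserting the result at index 0, as the diff does with chat_completion.messages.insert(0, system_message), keeps the prompt ahead of the user's conversation regardless of which role was chosen.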
@@ -315,9 +326,15 @@ async fn non_streamed_chat(
     let config = Config::new(&chat_settings);
     let client = Client::with_config(config);
     let auth_token = extract_token_from_request(&req)?.unwrap();
+    let system_role = chat_settings.source.system_role(&chat_completion.model);
     // TODO do function support later
-    let _function_support =
-        setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
+    let _function_support = setup_search_tool(
+        &index_scheduler,
+        filters,
+        &mut chat_completion,
+        &chat_settings.prompts,
+        system_role,
+    )?;
 
     let mut response;
     loop {
@@ -408,8 +425,14 @@ async fn streamed_chat(
     let config = Config::new(&chat_settings);
     let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
-    let function_support =
-        setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
+    let system_role = chat_settings.source.system_role(&chat_completion.model);
+    let function_support = setup_search_tool(
+        &index_scheduler,
+        filters,
+        &mut chat_completion,
+        &chat_settings.prompts,
+        system_role,
+    )?;
 
     tracing::debug!("Conversation function support: {function_support:?}");
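For both the non-streamed and streamed handlers the effect on the outgoing request is the same: only the role of the leading message changes. The snippet below is a rough illustration of the two wire shapes the Chat Completions API ends up receiving; the JSON is hand-written for illustration, not output captured from Meilisearch, and the prompt text is a placeholder.

use serde_json::json;

fn main() {
    // Per the mapping above: old OpenAI/Azure OpenAI models and the Mistral,
    // Gemini and vLLM sources keep "system"; newer OpenAI/Azure OpenAI models
    // get "developer".
    let legacy = json!({ "role": "system", "content": "<configured system prompt>" });
    let modern = json!({ "role": "developer", "content": "<configured system prompt>" });
    println!("{legacy}");
    println!("{modern}");
}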