mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-09 13:45:43 +00:00
Make sure to use the system prompt
This commit is contained in:
parent
70670c3be4
commit
717a026fdd
@ -102,16 +102,24 @@ pub enum ChatCompletionSource {
|
|||||||
VLlm,
|
VLlm,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum SystemRole {
|
||||||
|
System,
|
||||||
|
Developer,
|
||||||
|
}
|
||||||
|
|
||||||
impl ChatCompletionSource {
|
impl ChatCompletionSource {
|
||||||
pub fn system_role(&self, model: &str) -> &'static str {
|
pub fn system_role(&self, model: &str) -> SystemRole {
|
||||||
|
use ChatCompletionSource::*;
|
||||||
|
use SystemRole::*;
|
||||||
match self {
|
match self {
|
||||||
ChatCompletionSource::OpenAi if Self::old_openai_model(model) => "system",
|
OpenAi if Self::old_openai_model(model) => System,
|
||||||
ChatCompletionSource::OpenAi => "developer",
|
OpenAi => Developer,
|
||||||
ChatCompletionSource::AzureOpenAi if Self::old_openai_model(model) => "system",
|
AzureOpenAi if Self::old_openai_model(model) => System,
|
||||||
ChatCompletionSource::AzureOpenAi => "developer",
|
AzureOpenAi => Developer,
|
||||||
ChatCompletionSource::Mistral => "system",
|
Mistral => System,
|
||||||
ChatCompletionSource::Gemini => "system",
|
Gemini => System,
|
||||||
ChatCompletionSource::VLlm => "system",
|
VLlm => System,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,7 +9,8 @@ use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
|||||||
use actix_web_lab::sse::{Event, Sse};
|
use actix_web_lab::sse::{Event, Sse};
|
||||||
use async_openai::types::{
|
use async_openai::types::{
|
||||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestDeveloperMessage,
|
||||||
|
ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage,
|
||||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
||||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||||
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
||||||
@ -24,6 +25,7 @@ use meilisearch_auth::AuthController;
|
|||||||
use meilisearch_types::error::{Code, ResponseError};
|
use meilisearch_types::error::{Code, ResponseError};
|
||||||
use meilisearch_types::features::{
|
use meilisearch_types::features::{
|
||||||
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
|
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
|
||||||
|
SystemRole,
|
||||||
};
|
};
|
||||||
use meilisearch_types::keys::actions;
|
use meilisearch_types::keys::actions;
|
||||||
use meilisearch_types::milli::index::ChatConfig;
|
use meilisearch_types::milli::index::ChatConfig;
|
||||||
@ -117,6 +119,7 @@ fn setup_search_tool(
|
|||||||
filters: &meilisearch_auth::AuthFilter,
|
filters: &meilisearch_auth::AuthFilter,
|
||||||
chat_completion: &mut CreateChatCompletionRequest,
|
chat_completion: &mut CreateChatCompletionRequest,
|
||||||
prompts: &DbChatCompletionPrompts,
|
prompts: &DbChatCompletionPrompts,
|
||||||
|
system_role: SystemRole,
|
||||||
) -> Result<FunctionSupport, ResponseError> {
|
) -> Result<FunctionSupport, ResponseError> {
|
||||||
let tools = chat_completion.tools.get_or_insert_default();
|
let tools = chat_completion.tools.get_or_insert_default();
|
||||||
if tools.iter().any(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) {
|
if tools.iter().any(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) {
|
||||||
@ -195,13 +198,21 @@ fn setup_search_tool(
|
|||||||
|
|
||||||
tools.push(tool);
|
tools.push(tool);
|
||||||
|
|
||||||
chat_completion.messages.insert(
|
let system_message = match system_role {
|
||||||
0,
|
SystemRole::System => {
|
||||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
|
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
|
||||||
name: None,
|
name: None,
|
||||||
}),
|
})
|
||||||
);
|
}
|
||||||
|
SystemRole::Developer => {
|
||||||
|
ChatCompletionRequestMessage::Developer(ChatCompletionRequestDeveloperMessage {
|
||||||
|
content: ChatCompletionRequestDeveloperMessageContent::Text(prompts.system.clone()),
|
||||||
|
name: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
chat_completion.messages.insert(0, system_message);
|
||||||
|
|
||||||
Ok(FunctionSupport { report_progress, report_sources, append_to_conversation })
|
Ok(FunctionSupport { report_progress, report_sources, append_to_conversation })
|
||||||
}
|
}
|
||||||
@ -315,9 +326,15 @@ async fn non_streamed_chat(
|
|||||||
let config = Config::new(&chat_settings);
|
let config = Config::new(&chat_settings);
|
||||||
let client = Client::with_config(config);
|
let client = Client::with_config(config);
|
||||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||||
|
let system_role = chat_settings.source.system_role(&chat_completion.model);
|
||||||
// TODO do function support later
|
// TODO do function support later
|
||||||
let _function_support =
|
let _function_support = setup_search_tool(
|
||||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
&index_scheduler,
|
||||||
|
filters,
|
||||||
|
&mut chat_completion,
|
||||||
|
&chat_settings.prompts,
|
||||||
|
system_role,
|
||||||
|
)?;
|
||||||
|
|
||||||
let mut response;
|
let mut response;
|
||||||
loop {
|
loop {
|
||||||
@ -408,8 +425,14 @@ async fn streamed_chat(
|
|||||||
|
|
||||||
let config = Config::new(&chat_settings);
|
let config = Config::new(&chat_settings);
|
||||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||||
let function_support =
|
let system_role = chat_settings.source.system_role(&chat_completion.model);
|
||||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
let function_support = setup_search_tool(
|
||||||
|
&index_scheduler,
|
||||||
|
filters,
|
||||||
|
&mut chat_completion,
|
||||||
|
&chat_settings.prompts,
|
||||||
|
system_role,
|
||||||
|
)?;
|
||||||
|
|
||||||
tracing::debug!("Conversation function support: {function_support:?}");
|
tracing::debug!("Conversation function support: {function_support:?}");
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user