Better chat settings management

This commit is contained in:
Clément Renault 2025-05-21 21:06:11 +02:00 committed by Kerollmops
parent f9ecb0ff31
commit 1eb8249a51
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
3 changed files with 95 additions and 45 deletions

View File

@ -28,6 +28,7 @@ use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::ChatConfig; use meilisearch_types::milli::index::ChatConfig;
use meilisearch_types::milli::prompt::{Prompt, PromptData}; use meilisearch_types::milli::prompt::{Prompt, PromptData};
use meilisearch_types::milli::update::new::document::DocumentFromDb; use meilisearch_types::milli::update::new::document::DocumentFromDb;
use meilisearch_types::milli::update::Setting;
use meilisearch_types::milli::{ use meilisearch_types::milli::{
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
}; };
@ -107,20 +108,20 @@ fn setup_search_tool(
.function( .function(
FunctionObjectArgs::default() FunctionObjectArgs::default()
.name(SEARCH_IN_INDEX_FUNCTION_NAME) .name(SEARCH_IN_INDEX_FUNCTION_NAME)
.description(&prompts.search_description) .description(&prompts.search_description.clone().unwrap())
.parameters(json!({ .parameters(json!({
"type": "object", "type": "object",
"properties": { "properties": {
"index_uid": { "index_uid": {
"type": "string", "type": "string",
"enum": index_uids, "enum": index_uids,
"description": prompts.search_index_uid_param, "description": prompts.search_index_uid_param.clone().unwrap(),
}, },
"q": { "q": {
// Unfortunately, Mistral does not support an array of types, here. // Unfortunately, Mistral does not support an array of types, here.
// "type": ["string", "null"], // "type": ["string", "null"],
"type": "string", "type": "string",
"description": prompts.search_q_param, "description": prompts.search_q_param.clone().unwrap(),
} }
}, },
"required": ["index_uid", "q"], "required": ["index_uid", "q"],
@ -136,7 +137,9 @@ fn setup_search_tool(
chat_completion.messages.insert( chat_completion.messages.insert(
0, 0,
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()), content: ChatCompletionRequestSystemMessageContent::Text(
prompts.system.as_ref().unwrap().clone(),
),
name: None, name: None,
}), }),
); );
@ -239,16 +242,17 @@ async fn non_streamed_chat(
}; };
let mut config = OpenAIConfig::default(); let mut config = OpenAIConfig::default();
if let Some(api_key) = chat_settings.api_key.as_ref() { if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
config = config.with_api_key(api_key); config = config.with_api_key(api_key);
} }
if let Some(base_api) = chat_settings.base_api.as_ref() { if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
config = config.with_api_base(base_api); config = config.with_api_base(base_api);
} }
let client = Client::with_config(config); let client = Client::with_config(config);
let auth_token = extract_token_from_request(&req)?.unwrap(); let auth_token = extract_token_from_request(&req)?.unwrap();
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let mut response; let mut response;
loop { loop {
@ -296,7 +300,7 @@ async fn non_streamed_chat(
tool_call_id: call.id.clone(), tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(format!( content: ChatCompletionRequestToolMessageContent::Text(format!(
"{}\n\n{text}", "{}\n\n{text}",
chat_settings.prompts.pre_query chat_settings.prompts.clone().unwrap().pre_query.unwrap()
)), )),
}, },
)); ));
@ -325,20 +329,21 @@ async fn streamed_chat(
let filters = index_scheduler.filters(); let filters = index_scheduler.filters();
let chat_settings = match index_scheduler.chat_settings().unwrap() { let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(), Some(value) => serde_json::from_value(value.clone()).unwrap(),
None => GlobalChatSettings::default(), None => GlobalChatSettings::default(),
}; };
let mut config = OpenAIConfig::default(); let mut config = OpenAIConfig::default();
if let Some(api_key) = chat_settings.api_key.as_ref() { if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
config = config.with_api_key(api_key); config = config.with_api_key(api_key);
} }
if let Some(base_api) = chat_settings.base_api.as_ref() { if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
config = config.with_api_base(base_api); config = config.with_api_base(base_api);
} }
let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10); let (tx, rx) = tokio::sync::mpsc::channel(10);
let _join_handle = Handle::current().spawn(async move { let _join_handle = Handle::current().spawn(async move {
@ -447,7 +452,7 @@ async fn streamed_chat(
let tool = ChatCompletionRequestToolMessage { let tool = ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(), tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text( content: ChatCompletionRequestToolMessageContent::Text(
format!("{}\n\n{text}", chat_settings.prompts.pre_query), format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()),
), ),
}; };

View File

@ -3,6 +3,7 @@ use actix_web::HttpResponse;
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions; use meilisearch_types::keys::actions;
use meilisearch_types::milli::update::Setting;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::policies::ActionPolicy;
@ -35,37 +36,63 @@ async fn patch_settings(
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>, ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
Data<IndexScheduler>, Data<IndexScheduler>,
>, >,
web::Json(chat_settings): web::Json<GlobalChatSettings>, web::Json(new): web::Json<GlobalChatSettings>,
) -> Result<HttpResponse, ResponseError> { ) -> Result<HttpResponse, ResponseError> {
let chat_settings = serde_json::to_value(chat_settings).unwrap(); let old = match index_scheduler.chat_settings()? {
index_scheduler.put_chat_settings(&chat_settings)?; Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
};
let settings = GlobalChatSettings {
source: new.source.or(old.source),
base_api: new.base_api.clone().or(old.base_api),
api_key: new.api_key.clone().or(old.api_key),
prompts: match (new.prompts, old.prompts) {
(Setting::NotSet, set) | (set, Setting::NotSet) => set,
(Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset,
(Setting::Reset, Setting::Set(set)) => Setting::Set(set),
// If both are set we must merge the prompts settings
(Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts {
system: new.system.or(old.system),
search_description: new.search_description.or(old.search_description),
search_q_param: new.search_q_param.or(old.search_q_param),
search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param),
pre_query: new.pre_query.or(old.pre_query),
}),
},
};
let value = serde_json::to_value(settings).unwrap();
index_scheduler.put_chat_settings(&value)?;
Ok(HttpResponse::Ok().finish()) Ok(HttpResponse::Ok().finish())
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")] #[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct GlobalChatSettings { pub struct GlobalChatSettings {
pub source: String, #[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub base_api: Option<String>, pub source: Setting<String>,
pub api_key: Option<String>, #[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub prompts: ChatPrompts, pub base_api: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub prompts: Setting<ChatPrompts>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")] #[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct ChatPrompts { pub struct ChatPrompts {
pub system: String, #[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_description: String, pub system: Setting<String>,
pub search_q_param: String, #[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_index_uid_param: String, pub search_description: Setting<String>,
pub pre_query: String, #[serde(default, skip_serializing_if = "Setting::is_not_set")]
} pub search_q_param: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[derive(Debug, Serialize, Deserialize)] pub search_index_uid_param: Setting<String>,
#[serde(deny_unknown_fields, rename_all = "camelCase")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub struct ChatIndexSettings { pub pre_query: Setting<String>,
pub description: String,
pub document_template: String,
} }
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\ const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\
@ -91,17 +118,26 @@ Selecting the right index ensures the most relevant results for the user query";
impl Default for GlobalChatSettings { impl Default for GlobalChatSettings {
fn default() -> Self { fn default() -> Self {
GlobalChatSettings { GlobalChatSettings {
source: "openai".to_string(), source: Setting::Set("openAi".to_string()),
base_api: None, base_api: Setting::NotSet,
api_key: None, api_key: Setting::NotSet,
prompts: ChatPrompts { prompts: Setting::Set(ChatPrompts::default()),
system: DEFAULT_SYSTEM_MESSAGE.to_string(), }
search_description: DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string(), }
search_q_param: DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(), }
search_index_uid_param: DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION
.to_string(), impl Default for ChatPrompts {
pre_query: "".to_string(), fn default() -> Self {
}, ChatPrompts {
system: Setting::Set(DEFAULT_SYSTEM_MESSAGE.to_string()),
search_description: Setting::Set(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string()),
search_q_param: Setting::Set(
DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
),
search_index_uid_param: Setting::Set(
DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION.to_string(),
),
pre_query: Setting::Set(Default::default()),
} }
} }
} }

View File

@ -123,6 +123,15 @@ impl<T> Setting<T> {
*self = new; *self = new;
true true
} }
#[track_caller]
pub fn unwrap(self) -> T {
match self {
Setting::Set(value) => value,
Setting::Reset => panic!("Setting::Reset unwrapped"),
Setting::NotSet => panic!("Setting::NotSet unwrapped"),
}
}
} }
impl<T: Serialize> Serialize for Setting<T> { impl<T: Serialize> Serialize for Setting<T> {