From 1eb8249a51fe089c90547fe005e5f15c8dec7178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 21 May 2025 21:06:11 +0200 Subject: [PATCH] Better chat settings management --- crates/meilisearch/src/routes/chat.rs | 31 +++--- .../meilisearch/src/routes/settings/chat.rs | 100 ++++++++++++------ crates/milli/src/update/settings.rs | 9 ++ 3 files changed, 95 insertions(+), 45 deletions(-) diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 2f434d706..05512bff3 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -28,6 +28,7 @@ use meilisearch_types::keys::actions; use meilisearch_types::milli::index::ChatConfig; use meilisearch_types::milli::prompt::{Prompt, PromptData}; use meilisearch_types::milli::update::new::document::DocumentFromDb; +use meilisearch_types::milli::update::Setting; use meilisearch_types::milli::{ DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget, }; @@ -107,20 +108,20 @@ fn setup_search_tool( .function( FunctionObjectArgs::default() .name(SEARCH_IN_INDEX_FUNCTION_NAME) - .description(&prompts.search_description) + .description(&prompts.search_description.clone().unwrap()) .parameters(json!({ "type": "object", "properties": { "index_uid": { "type": "string", "enum": index_uids, - "description": prompts.search_index_uid_param, + "description": prompts.search_index_uid_param.clone().unwrap(), }, "q": { // Unfortunately, Mistral does not support an array of types, here. // "type": ["string", "null"], "type": "string", - "description": prompts.search_q_param, + "description": prompts.search_q_param.clone().unwrap(), } }, "required": ["index_uid", "q"], @@ -136,7 +137,9 @@ fn setup_search_tool( chat_completion.messages.insert( 0, ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { - content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()), + content: ChatCompletionRequestSystemMessageContent::Text( + prompts.system.as_ref().unwrap().clone(), + ), name: None, }), ); @@ -239,16 +242,17 @@ async fn non_streamed_chat( }; let mut config = OpenAIConfig::default(); - if let Some(api_key) = chat_settings.api_key.as_ref() { + if let Setting::Set(api_key) = chat_settings.api_key.as_ref() { config = config.with_api_key(api_key); } - if let Some(base_api) = chat_settings.base_api.as_ref() { + if let Setting::Set(base_api) = chat_settings.base_api.as_ref() { config = config.with_api_base(base_api); } let client = Client::with_config(config); let auth_token = extract_token_from_request(&req)?.unwrap(); - setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; + let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); + setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; let mut response; loop { @@ -296,7 +300,7 @@ async fn non_streamed_chat( tool_call_id: call.id.clone(), content: ChatCompletionRequestToolMessageContent::Text(format!( "{}\n\n{text}", - chat_settings.prompts.pre_query + chat_settings.prompts.clone().unwrap().pre_query.unwrap() )), }, )); @@ -325,20 +329,21 @@ async fn streamed_chat( let filters = index_scheduler.filters(); let chat_settings = match index_scheduler.chat_settings().unwrap() { - Some(value) => serde_json::from_value(value).unwrap(), + Some(value) => serde_json::from_value(value.clone()).unwrap(), None => GlobalChatSettings::default(), }; let mut config = OpenAIConfig::default(); - if let Some(api_key) = chat_settings.api_key.as_ref() { + if let Setting::Set(api_key) = chat_settings.api_key.as_ref() { config = config.with_api_key(api_key); } - if let Some(base_api) = chat_settings.base_api.as_ref() { + if let Setting::Set(base_api) = chat_settings.base_api.as_ref() { config = config.with_api_base(base_api); } let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); - setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; + let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap(); + setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?; let (tx, rx) = tokio::sync::mpsc::channel(10); let _join_handle = Handle::current().spawn(async move { @@ -447,7 +452,7 @@ async fn streamed_chat( let tool = ChatCompletionRequestToolMessage { tool_call_id: call.id.clone(), content: ChatCompletionRequestToolMessageContent::Text( - format!("{}\n\n{text}", chat_settings.prompts.pre_query), + format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()), ), }; diff --git a/crates/meilisearch/src/routes/settings/chat.rs b/crates/meilisearch/src/routes/settings/chat.rs index 586fa041e..42fb456b8 100644 --- a/crates/meilisearch/src/routes/settings/chat.rs +++ b/crates/meilisearch/src/routes/settings/chat.rs @@ -3,6 +3,7 @@ use actix_web::HttpResponse; use index_scheduler::IndexScheduler; use meilisearch_types::error::ResponseError; use meilisearch_types::keys::actions; +use meilisearch_types::milli::update::Setting; use serde::{Deserialize, Serialize}; use crate::extractors::authentication::policies::ActionPolicy; @@ -35,37 +36,63 @@ async fn patch_settings( ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>, Data, >, - web::Json(chat_settings): web::Json, + web::Json(new): web::Json, ) -> Result { - let chat_settings = serde_json::to_value(chat_settings).unwrap(); - index_scheduler.put_chat_settings(&chat_settings)?; + let old = match index_scheduler.chat_settings()? { + Some(value) => serde_json::from_value(value).unwrap(), + None => GlobalChatSettings::default(), + }; + + let settings = GlobalChatSettings { + source: new.source.or(old.source), + base_api: new.base_api.clone().or(old.base_api), + api_key: new.api_key.clone().or(old.api_key), + prompts: match (new.prompts, old.prompts) { + (Setting::NotSet, set) | (set, Setting::NotSet) => set, + (Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset, + (Setting::Reset, Setting::Set(set)) => Setting::Set(set), + // If both are set we must merge the prompts settings + (Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts { + system: new.system.or(old.system), + search_description: new.search_description.or(old.search_description), + search_q_param: new.search_q_param.or(old.search_q_param), + search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param), + pre_query: new.pre_query.or(old.pre_query), + }), + }, + }; + + let value = serde_json::to_value(settings).unwrap(); + index_scheduler.put_chat_settings(&value)?; Ok(HttpResponse::Ok().finish()) } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields, rename_all = "camelCase")] pub struct GlobalChatSettings { - pub source: String, - pub base_api: Option, - pub api_key: Option, - pub prompts: ChatPrompts, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub source: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub base_api: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub api_key: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub prompts: Setting, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields, rename_all = "camelCase")] pub struct ChatPrompts { - pub system: String, - pub search_description: String, - pub search_q_param: String, - pub search_index_uid_param: String, - pub pre_query: String, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -pub struct ChatIndexSettings { - pub description: String, - pub document_template: String, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub system: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub search_description: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub search_q_param: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub search_index_uid_param: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + pub pre_query: Setting, } const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\ @@ -91,17 +118,26 @@ Selecting the right index ensures the most relevant results for the user query"; impl Default for GlobalChatSettings { fn default() -> Self { GlobalChatSettings { - source: "openai".to_string(), - base_api: None, - api_key: None, - prompts: ChatPrompts { - system: DEFAULT_SYSTEM_MESSAGE.to_string(), - search_description: DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string(), - search_q_param: DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(), - search_index_uid_param: DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION - .to_string(), - pre_query: "".to_string(), - }, + source: Setting::Set("openAi".to_string()), + base_api: Setting::NotSet, + api_key: Setting::NotSet, + prompts: Setting::Set(ChatPrompts::default()), + } + } +} + +impl Default for ChatPrompts { + fn default() -> Self { + ChatPrompts { + system: Setting::Set(DEFAULT_SYSTEM_MESSAGE.to_string()), + search_description: Setting::Set(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string()), + search_q_param: Setting::Set( + DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(), + ), + search_index_uid_param: Setting::Set( + DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION.to_string(), + ), + pre_query: Setting::Set(Default::default()), } } } diff --git a/crates/milli/src/update/settings.rs b/crates/milli/src/update/settings.rs index 697bf8168..bba8bc758 100644 --- a/crates/milli/src/update/settings.rs +++ b/crates/milli/src/update/settings.rs @@ -123,6 +123,15 @@ impl Setting { *self = new; true } + + #[track_caller] + pub fn unwrap(self) -> T { + match self { + Setting::Set(value) => value, + Setting::Reset => panic!("Setting::Reset unwrapped"), + Setting::NotSet => panic!("Setting::NotSet unwrapped"), + } + } } impl Serialize for Setting {