mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-04 19:25:32 +00:00
Better chat settings management
This commit is contained in:
parent
f9ecb0ff31
commit
1eb8249a51
@ -28,6 +28,7 @@ use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::index::ChatConfig;
|
||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||
use meilisearch_types::milli::update::Setting;
|
||||
use meilisearch_types::milli::{
|
||||
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
|
||||
};
|
||||
@ -107,20 +108,20 @@ fn setup_search_tool(
|
||||
.function(
|
||||
FunctionObjectArgs::default()
|
||||
.name(SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
.description(&prompts.search_description)
|
||||
.description(&prompts.search_description.clone().unwrap())
|
||||
.parameters(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index_uid": {
|
||||
"type": "string",
|
||||
"enum": index_uids,
|
||||
"description": prompts.search_index_uid_param,
|
||||
"description": prompts.search_index_uid_param.clone().unwrap(),
|
||||
},
|
||||
"q": {
|
||||
// Unfortunately, Mistral does not support an array of types, here.
|
||||
// "type": ["string", "null"],
|
||||
"type": "string",
|
||||
"description": prompts.search_q_param,
|
||||
"description": prompts.search_q_param.clone().unwrap(),
|
||||
}
|
||||
},
|
||||
"required": ["index_uid", "q"],
|
||||
@ -136,7 +137,9 @@ fn setup_search_tool(
|
||||
chat_completion.messages.insert(
|
||||
0,
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
|
||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
||||
prompts.system.as_ref().unwrap().clone(),
|
||||
),
|
||||
name: None,
|
||||
}),
|
||||
);
|
||||
@ -239,16 +242,17 @@ async fn non_streamed_chat(
|
||||
};
|
||||
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
let client = Client::with_config(config);
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
|
||||
let mut response;
|
||||
loop {
|
||||
@ -296,7 +300,7 @@ async fn non_streamed_chat(
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(format!(
|
||||
"{}\n\n{text}",
|
||||
chat_settings.prompts.pre_query
|
||||
chat_settings.prompts.clone().unwrap().pre_query.unwrap()
|
||||
)),
|
||||
},
|
||||
));
|
||||
@ -325,20 +329,21 @@ async fn streamed_chat(
|
||||
let filters = index_scheduler.filters();
|
||||
|
||||
let chat_settings = match index_scheduler.chat_settings().unwrap() {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
Some(value) => serde_json::from_value(value.clone()).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
};
|
||||
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let _join_handle = Handle::current().spawn(async move {
|
||||
@ -447,7 +452,7 @@ async fn streamed_chat(
|
||||
let tool = ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(
|
||||
format!("{}\n\n{text}", chat_settings.prompts.pre_query),
|
||||
format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()),
|
||||
),
|
||||
};
|
||||
|
||||
|
@ -3,6 +3,7 @@ use actix_web::HttpResponse;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::update::Setting;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
@ -35,37 +36,63 @@ async fn patch_settings(
|
||||
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
web::Json(chat_settings): web::Json<GlobalChatSettings>,
|
||||
web::Json(new): web::Json<GlobalChatSettings>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
let chat_settings = serde_json::to_value(chat_settings).unwrap();
|
||||
index_scheduler.put_chat_settings(&chat_settings)?;
|
||||
let old = match index_scheduler.chat_settings()? {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
};
|
||||
|
||||
let settings = GlobalChatSettings {
|
||||
source: new.source.or(old.source),
|
||||
base_api: new.base_api.clone().or(old.base_api),
|
||||
api_key: new.api_key.clone().or(old.api_key),
|
||||
prompts: match (new.prompts, old.prompts) {
|
||||
(Setting::NotSet, set) | (set, Setting::NotSet) => set,
|
||||
(Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset,
|
||||
(Setting::Reset, Setting::Set(set)) => Setting::Set(set),
|
||||
// If both are set we must merge the prompts settings
|
||||
(Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts {
|
||||
system: new.system.or(old.system),
|
||||
search_description: new.search_description.or(old.search_description),
|
||||
search_q_param: new.search_q_param.or(old.search_q_param),
|
||||
search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param),
|
||||
pre_query: new.pre_query.or(old.pre_query),
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
let value = serde_json::to_value(settings).unwrap();
|
||||
index_scheduler.put_chat_settings(&value)?;
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct GlobalChatSettings {
|
||||
pub source: String,
|
||||
pub base_api: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub prompts: ChatPrompts,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub source: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub base_api: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub prompts: Setting<ChatPrompts>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct ChatPrompts {
|
||||
pub system: String,
|
||||
pub search_description: String,
|
||||
pub search_q_param: String,
|
||||
pub search_index_uid_param: String,
|
||||
pub pre_query: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct ChatIndexSettings {
|
||||
pub description: String,
|
||||
pub document_template: String,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub system: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub search_description: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub search_q_param: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub search_index_uid_param: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub pre_query: Setting<String>,
|
||||
}
|
||||
|
||||
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\
|
||||
@ -91,17 +118,26 @@ Selecting the right index ensures the most relevant results for the user query";
|
||||
impl Default for GlobalChatSettings {
|
||||
fn default() -> Self {
|
||||
GlobalChatSettings {
|
||||
source: "openai".to_string(),
|
||||
base_api: None,
|
||||
api_key: None,
|
||||
prompts: ChatPrompts {
|
||||
system: DEFAULT_SYSTEM_MESSAGE.to_string(),
|
||||
search_description: DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string(),
|
||||
search_q_param: DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||
search_index_uid_param: DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION
|
||||
.to_string(),
|
||||
pre_query: "".to_string(),
|
||||
},
|
||||
source: Setting::Set("openAi".to_string()),
|
||||
base_api: Setting::NotSet,
|
||||
api_key: Setting::NotSet,
|
||||
prompts: Setting::Set(ChatPrompts::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChatPrompts {
|
||||
fn default() -> Self {
|
||||
ChatPrompts {
|
||||
system: Setting::Set(DEFAULT_SYSTEM_MESSAGE.to_string()),
|
||||
search_description: Setting::Set(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string()),
|
||||
search_q_param: Setting::Set(
|
||||
DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||
),
|
||||
search_index_uid_param: Setting::Set(
|
||||
DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||
),
|
||||
pre_query: Setting::Set(Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -123,6 +123,15 @@ impl<T> Setting<T> {
|
||||
*self = new;
|
||||
true
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn unwrap(self) -> T {
|
||||
match self {
|
||||
Setting::Set(value) => value,
|
||||
Setting::Reset => panic!("Setting::Reset unwrapped"),
|
||||
Setting::NotSet => panic!("Setting::NotSet unwrapped"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Serialize> Serialize for Setting<T> {
|
||||
|
Loading…
x
Reference in New Issue
Block a user