mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-06 04:05:37 +00:00
Better chat settings management
This commit is contained in:
parent
f9ecb0ff31
commit
1eb8249a51
@ -28,6 +28,7 @@ use meilisearch_types::keys::actions;
|
|||||||
use meilisearch_types::milli::index::ChatConfig;
|
use meilisearch_types::milli::index::ChatConfig;
|
||||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||||
|
use meilisearch_types::milli::update::Setting;
|
||||||
use meilisearch_types::milli::{
|
use meilisearch_types::milli::{
|
||||||
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
|
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
|
||||||
};
|
};
|
||||||
@ -107,20 +108,20 @@ fn setup_search_tool(
|
|||||||
.function(
|
.function(
|
||||||
FunctionObjectArgs::default()
|
FunctionObjectArgs::default()
|
||||||
.name(SEARCH_IN_INDEX_FUNCTION_NAME)
|
.name(SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||||
.description(&prompts.search_description)
|
.description(&prompts.search_description.clone().unwrap())
|
||||||
.parameters(json!({
|
.parameters(json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"index_uid": {
|
"index_uid": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": index_uids,
|
"enum": index_uids,
|
||||||
"description": prompts.search_index_uid_param,
|
"description": prompts.search_index_uid_param.clone().unwrap(),
|
||||||
},
|
},
|
||||||
"q": {
|
"q": {
|
||||||
// Unfortunately, Mistral does not support an array of types, here.
|
// Unfortunately, Mistral does not support an array of types, here.
|
||||||
// "type": ["string", "null"],
|
// "type": ["string", "null"],
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": prompts.search_q_param,
|
"description": prompts.search_q_param.clone().unwrap(),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["index_uid", "q"],
|
"required": ["index_uid", "q"],
|
||||||
@ -136,7 +137,9 @@ fn setup_search_tool(
|
|||||||
chat_completion.messages.insert(
|
chat_completion.messages.insert(
|
||||||
0,
|
0,
|
||||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
|
content: ChatCompletionRequestSystemMessageContent::Text(
|
||||||
|
prompts.system.as_ref().unwrap().clone(),
|
||||||
|
),
|
||||||
name: None,
|
name: None,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
@ -239,16 +242,17 @@ async fn non_streamed_chat(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut config = OpenAIConfig::default();
|
let mut config = OpenAIConfig::default();
|
||||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
|
||||||
config = config.with_api_key(api_key);
|
config = config.with_api_key(api_key);
|
||||||
}
|
}
|
||||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
|
||||||
config = config.with_api_base(base_api);
|
config = config.with_api_base(base_api);
|
||||||
}
|
}
|
||||||
let client = Client::with_config(config);
|
let client = Client::with_config(config);
|
||||||
|
|
||||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||||
|
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||||
|
|
||||||
let mut response;
|
let mut response;
|
||||||
loop {
|
loop {
|
||||||
@ -296,7 +300,7 @@ async fn non_streamed_chat(
|
|||||||
tool_call_id: call.id.clone(),
|
tool_call_id: call.id.clone(),
|
||||||
content: ChatCompletionRequestToolMessageContent::Text(format!(
|
content: ChatCompletionRequestToolMessageContent::Text(format!(
|
||||||
"{}\n\n{text}",
|
"{}\n\n{text}",
|
||||||
chat_settings.prompts.pre_query
|
chat_settings.prompts.clone().unwrap().pre_query.unwrap()
|
||||||
)),
|
)),
|
||||||
},
|
},
|
||||||
));
|
));
|
||||||
@ -325,20 +329,21 @@ async fn streamed_chat(
|
|||||||
let filters = index_scheduler.filters();
|
let filters = index_scheduler.filters();
|
||||||
|
|
||||||
let chat_settings = match index_scheduler.chat_settings().unwrap() {
|
let chat_settings = match index_scheduler.chat_settings().unwrap() {
|
||||||
Some(value) => serde_json::from_value(value).unwrap(),
|
Some(value) => serde_json::from_value(value.clone()).unwrap(),
|
||||||
None => GlobalChatSettings::default(),
|
None => GlobalChatSettings::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut config = OpenAIConfig::default();
|
let mut config = OpenAIConfig::default();
|
||||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
|
||||||
config = config.with_api_key(api_key);
|
config = config.with_api_key(api_key);
|
||||||
}
|
}
|
||||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
|
||||||
config = config.with_api_base(base_api);
|
config = config.with_api_base(base_api);
|
||||||
}
|
}
|
||||||
|
|
||||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||||
|
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||||
let _join_handle = Handle::current().spawn(async move {
|
let _join_handle = Handle::current().spawn(async move {
|
||||||
@ -447,7 +452,7 @@ async fn streamed_chat(
|
|||||||
let tool = ChatCompletionRequestToolMessage {
|
let tool = ChatCompletionRequestToolMessage {
|
||||||
tool_call_id: call.id.clone(),
|
tool_call_id: call.id.clone(),
|
||||||
content: ChatCompletionRequestToolMessageContent::Text(
|
content: ChatCompletionRequestToolMessageContent::Text(
|
||||||
format!("{}\n\n{text}", chat_settings.prompts.pre_query),
|
format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()),
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ use actix_web::HttpResponse;
|
|||||||
use index_scheduler::IndexScheduler;
|
use index_scheduler::IndexScheduler;
|
||||||
use meilisearch_types::error::ResponseError;
|
use meilisearch_types::error::ResponseError;
|
||||||
use meilisearch_types::keys::actions;
|
use meilisearch_types::keys::actions;
|
||||||
|
use meilisearch_types::milli::update::Setting;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::extractors::authentication::policies::ActionPolicy;
|
use crate::extractors::authentication::policies::ActionPolicy;
|
||||||
@ -35,37 +36,63 @@ async fn patch_settings(
|
|||||||
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
|
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
|
||||||
Data<IndexScheduler>,
|
Data<IndexScheduler>,
|
||||||
>,
|
>,
|
||||||
web::Json(chat_settings): web::Json<GlobalChatSettings>,
|
web::Json(new): web::Json<GlobalChatSettings>,
|
||||||
) -> Result<HttpResponse, ResponseError> {
|
) -> Result<HttpResponse, ResponseError> {
|
||||||
let chat_settings = serde_json::to_value(chat_settings).unwrap();
|
let old = match index_scheduler.chat_settings()? {
|
||||||
index_scheduler.put_chat_settings(&chat_settings)?;
|
Some(value) => serde_json::from_value(value).unwrap(),
|
||||||
|
None => GlobalChatSettings::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let settings = GlobalChatSettings {
|
||||||
|
source: new.source.or(old.source),
|
||||||
|
base_api: new.base_api.clone().or(old.base_api),
|
||||||
|
api_key: new.api_key.clone().or(old.api_key),
|
||||||
|
prompts: match (new.prompts, old.prompts) {
|
||||||
|
(Setting::NotSet, set) | (set, Setting::NotSet) => set,
|
||||||
|
(Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset,
|
||||||
|
(Setting::Reset, Setting::Set(set)) => Setting::Set(set),
|
||||||
|
// If both are set we must merge the prompts settings
|
||||||
|
(Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts {
|
||||||
|
system: new.system.or(old.system),
|
||||||
|
search_description: new.search_description.or(old.search_description),
|
||||||
|
search_q_param: new.search_q_param.or(old.search_q_param),
|
||||||
|
search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param),
|
||||||
|
pre_query: new.pre_query.or(old.pre_query),
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let value = serde_json::to_value(settings).unwrap();
|
||||||
|
index_scheduler.put_chat_settings(&value)?;
|
||||||
Ok(HttpResponse::Ok().finish())
|
Ok(HttpResponse::Ok().finish())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||||
pub struct GlobalChatSettings {
|
pub struct GlobalChatSettings {
|
||||||
pub source: String,
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
pub base_api: Option<String>,
|
pub source: Setting<String>,
|
||||||
pub api_key: Option<String>,
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
pub prompts: ChatPrompts,
|
pub base_api: Setting<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
pub api_key: Setting<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
pub prompts: Setting<ChatPrompts>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||||
pub struct ChatPrompts {
|
pub struct ChatPrompts {
|
||||||
pub system: String,
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
pub search_description: String,
|
pub system: Setting<String>,
|
||||||
pub search_q_param: String,
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
pub search_index_uid_param: String,
|
pub search_description: Setting<String>,
|
||||||
pub pre_query: String,
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
}
|
pub search_q_param: Setting<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
pub search_index_uid_param: Setting<String>,
|
||||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
pub struct ChatIndexSettings {
|
pub pre_query: Setting<String>,
|
||||||
pub description: String,
|
|
||||||
pub document_template: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\
|
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\
|
||||||
@ -91,17 +118,26 @@ Selecting the right index ensures the most relevant results for the user query";
|
|||||||
impl Default for GlobalChatSettings {
|
impl Default for GlobalChatSettings {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
GlobalChatSettings {
|
GlobalChatSettings {
|
||||||
source: "openai".to_string(),
|
source: Setting::Set("openAi".to_string()),
|
||||||
base_api: None,
|
base_api: Setting::NotSet,
|
||||||
api_key: None,
|
api_key: Setting::NotSet,
|
||||||
prompts: ChatPrompts {
|
prompts: Setting::Set(ChatPrompts::default()),
|
||||||
system: DEFAULT_SYSTEM_MESSAGE.to_string(),
|
}
|
||||||
search_description: DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string(),
|
}
|
||||||
search_q_param: DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
}
|
||||||
search_index_uid_param: DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION
|
|
||||||
.to_string(),
|
impl Default for ChatPrompts {
|
||||||
pre_query: "".to_string(),
|
fn default() -> Self {
|
||||||
},
|
ChatPrompts {
|
||||||
|
system: Setting::Set(DEFAULT_SYSTEM_MESSAGE.to_string()),
|
||||||
|
search_description: Setting::Set(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string()),
|
||||||
|
search_q_param: Setting::Set(
|
||||||
|
DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||||
|
),
|
||||||
|
search_index_uid_param: Setting::Set(
|
||||||
|
DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||||
|
),
|
||||||
|
pre_query: Setting::Set(Default::default()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,6 +123,15 @@ impl<T> Setting<T> {
|
|||||||
*self = new;
|
*self = new;
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[track_caller]
|
||||||
|
pub fn unwrap(self) -> T {
|
||||||
|
match self {
|
||||||
|
Setting::Set(value) => value,
|
||||||
|
Setting::Reset => panic!("Setting::Reset unwrapped"),
|
||||||
|
Setting::NotSet => panic!("Setting::NotSet unwrapped"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Serialize> Serialize for Setting<T> {
|
impl<T: Serialize> Serialize for Setting<T> {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user