mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-12-26 06:17:00 +00:00
Compare commits
13 Commits
prototype-
...
prototype-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5601e87ad5 | ||
|
|
ae604f4b87 | ||
|
|
da1e285d70 | ||
|
|
487e3c2ea3 | ||
|
|
d45647b58d | ||
|
|
045a1b1e75 | ||
|
|
293ac45b7c | ||
|
|
8aa7ae912a | ||
|
|
7492b6e669 | ||
|
|
fcc0d43a62 | ||
|
|
72d4998dce | ||
|
|
fde11573da | ||
|
|
41220f786b |
@@ -305,6 +305,7 @@ pub(crate) mod test {
|
||||
localized_attributes: Setting::NotSet,
|
||||
facet_search: Setting::NotSet,
|
||||
prefix_search: Setting::NotSet,
|
||||
chat: Setting::NotSet,
|
||||
_kind: std::marker::PhantomData,
|
||||
};
|
||||
settings.check()
|
||||
|
||||
@@ -34,6 +34,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String {
|
||||
planned_failures: _,
|
||||
run_loop_iteration: _,
|
||||
embedders: _,
|
||||
chat_settings: _,
|
||||
} = scheduler;
|
||||
|
||||
let rtxn = env.read_txn().unwrap();
|
||||
|
||||
@@ -165,6 +165,7 @@ impl AuthController {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthFilter {
|
||||
search_rules: Option<SearchRules>,
|
||||
key_authorized_indexes: SearchRules,
|
||||
|
||||
@@ -4,9 +4,12 @@ use std::marker::PhantomData;
|
||||
use std::ops::ControlFlow;
|
||||
|
||||
use deserr::errors::{JsonError, QueryParamError};
|
||||
use deserr::{take_cf_content, DeserializeError, IntoValue, MergeWithError, ValuePointerRef};
|
||||
use deserr::{
|
||||
take_cf_content, DeserializeError, Deserr, IntoValue, MergeWithError, ValuePointerRef,
|
||||
};
|
||||
use milli::update::ChatSettings;
|
||||
|
||||
use crate::error::deserr_codes::*;
|
||||
use crate::error::deserr_codes::{self, *};
|
||||
use crate::error::{
|
||||
Code, DeserrParseBoolError, DeserrParseIntError, ErrorCode, InvalidTaskDateError,
|
||||
ParseOffsetDateTimeError,
|
||||
@@ -33,6 +36,7 @@ pub struct DeserrError<Format, C: Default + ErrorCode> {
|
||||
pub code: Code,
|
||||
_phantom: PhantomData<(Format, C)>,
|
||||
}
|
||||
|
||||
impl<Format, C: Default + ErrorCode> DeserrError<Format, C> {
|
||||
pub fn new(msg: String, code: Code) -> Self {
|
||||
Self { msg, code, _phantom: PhantomData }
|
||||
@@ -117,6 +121,16 @@ impl<C: Default + ErrorCode> DeserializeError for DeserrQueryParamError<C> {
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserr<DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> for ChatSettings {
|
||||
fn deserialize_from_value<V: IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: ValuePointerRef,
|
||||
) -> Result<Self, DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> {
|
||||
Deserr::<JsonError>::deserialize_from_value(value, location)
|
||||
.map_err(|e| DeserrError::new(e.to_string(), InvalidSettingsIndexChat.error_code()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn immutable_field_error(field: &str, accepted: &[&str], code: Code) -> DeserrJsonError {
|
||||
let msg = format!(
|
||||
"Immutable field `{field}`: expected one of {}",
|
||||
|
||||
@@ -326,6 +326,9 @@ pub enum Action {
|
||||
#[serde(rename = "chat.get")]
|
||||
#[deserr(rename = "chat.get")]
|
||||
Chat,
|
||||
#[serde(rename = "chatSettings.*")]
|
||||
#[deserr(rename = "chatSettings.*")]
|
||||
ChatSettingsAll,
|
||||
#[serde(rename = "chatSettings.get")]
|
||||
#[deserr(rename = "chatSettings.get")]
|
||||
ChatSettingsGet,
|
||||
@@ -357,6 +360,9 @@ impl Action {
|
||||
SETTINGS_ALL => Some(Self::SettingsAll),
|
||||
SETTINGS_GET => Some(Self::SettingsGet),
|
||||
SETTINGS_UPDATE => Some(Self::SettingsUpdate),
|
||||
CHAT_SETTINGS_ALL => Some(Self::ChatSettingsAll),
|
||||
CHAT_SETTINGS_GET => Some(Self::ChatSettingsGet),
|
||||
CHAT_SETTINGS_UPDATE => Some(Self::ChatSettingsUpdate),
|
||||
STATS_ALL => Some(Self::StatsAll),
|
||||
STATS_GET => Some(Self::StatsGet),
|
||||
METRICS_ALL => Some(Self::MetricsAll),
|
||||
@@ -424,6 +430,7 @@ pub mod actions {
|
||||
pub const NETWORK_UPDATE: u8 = NetworkUpdate.repr();
|
||||
|
||||
pub const CHAT: u8 = Chat.repr();
|
||||
pub const CHAT_SETTINGS_ALL: u8 = ChatSettingsAll.repr();
|
||||
pub const CHAT_SETTINGS_GET: u8 = ChatSettingsGet.repr();
|
||||
pub const CHAT_SETTINGS_UPDATE: u8 = ChatSettingsUpdate.repr();
|
||||
}
|
||||
|
||||
@@ -186,7 +186,7 @@ impl<E: DeserializeError> Deserr<E> for SettingEmbeddingSettings {
|
||||
/// Holds all the settings for an index. `T` can either be `Checked` if they represents settings
|
||||
/// whose validity is guaranteed, or `Unchecked` if they need to be validated. In the later case, a
|
||||
/// call to `check` will return a `Settings<Checked>` from a `Settings<Unchecked>`.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
|
||||
#[serde(
|
||||
deny_unknown_fields,
|
||||
rename_all = "camelCase",
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::error::ResponseError;
|
||||
use crate::settings::{Settings, Unchecked};
|
||||
use crate::tasks::{serialize_duration, Details, IndexSwap, Kind, Status, Task, TaskId};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, ToSchema)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
pub struct TaskView {
|
||||
@@ -67,7 +67,7 @@ impl TaskView {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize, ToSchema)]
|
||||
#[derive(Default, Debug, PartialEq, Clone, Serialize, Deserialize, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[schema(rename_all = "camelCase")]
|
||||
pub struct DetailsView {
|
||||
|
||||
@@ -597,7 +597,7 @@ impl fmt::Display for ParseTaskKindError {
|
||||
}
|
||||
impl std::error::Error for ParseTaskKindError {}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub enum Details {
|
||||
DocumentAdditionOrUpdate {
|
||||
received_documents: u64,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@ use actix_web::HttpResponse;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::update::Setting;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
@@ -23,10 +24,11 @@ async fn get_settings(
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
let settings = match index_scheduler.chat_settings()? {
|
||||
let mut settings = match index_scheduler.chat_settings()? {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
};
|
||||
settings.hide_secrets();
|
||||
Ok(HttpResponse::Ok().json(settings))
|
||||
}
|
||||
|
||||
@@ -35,37 +37,96 @@ async fn patch_settings(
|
||||
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
|
||||
Data<IndexScheduler>,
|
||||
>,
|
||||
web::Json(chat_settings): web::Json<GlobalChatSettings>,
|
||||
web::Json(new): web::Json<GlobalChatSettings>,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
let chat_settings = serde_json::to_value(chat_settings).unwrap();
|
||||
index_scheduler.put_chat_settings(&chat_settings)?;
|
||||
let old = match index_scheduler.chat_settings()? {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
};
|
||||
|
||||
let settings = GlobalChatSettings {
|
||||
source: new.source.or(old.source),
|
||||
base_api: new.base_api.clone().or(old.base_api),
|
||||
api_key: new.api_key.clone().or(old.api_key),
|
||||
prompts: match (new.prompts, old.prompts) {
|
||||
(Setting::NotSet, set) | (set, Setting::NotSet) => set,
|
||||
(Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset,
|
||||
(Setting::Reset, Setting::Set(set)) => Setting::Set(set),
|
||||
// If both are set we must merge the prompts settings
|
||||
(Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts {
|
||||
system: new.system.or(old.system),
|
||||
search_description: new.search_description.or(old.search_description),
|
||||
search_q_param: new.search_q_param.or(old.search_q_param),
|
||||
search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param),
|
||||
pre_query: new.pre_query.or(old.pre_query),
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
let value = serde_json::to_value(settings).unwrap();
|
||||
index_scheduler.put_chat_settings(&value)?;
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub enum ChatSource {
|
||||
OpenAi,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct GlobalChatSettings {
|
||||
pub source: String,
|
||||
pub base_api: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub prompts: ChatPrompts,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub source: Setting<ChatSource>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub base_api: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub prompts: Setting<ChatPrompts>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
impl GlobalChatSettings {
|
||||
pub fn hide_secrets(&mut self) {
|
||||
match &mut self.api_key {
|
||||
Setting::Set(key) => Self::hide_secret(key),
|
||||
Setting::Reset => (),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn hide_secret(secret: &mut String) {
|
||||
match secret.len() {
|
||||
x if x < 10 => {
|
||||
secret.replace_range(.., "XXX...");
|
||||
}
|
||||
x if x < 20 => {
|
||||
secret.replace_range(2.., "XXXX...");
|
||||
}
|
||||
x if x < 30 => {
|
||||
secret.replace_range(3.., "XXXXX...");
|
||||
}
|
||||
_x => {
|
||||
secret.replace_range(5.., "XXXXXX...");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct ChatPrompts {
|
||||
pub system: String,
|
||||
pub search_description: String,
|
||||
pub search_q_param: String,
|
||||
pub search_index_uid_param: String,
|
||||
pub pre_query: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct ChatIndexSettings {
|
||||
pub description: String,
|
||||
pub document_template: String,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub system: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub search_description: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub search_q_param: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub search_index_uid_param: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub pre_query: Setting<String>,
|
||||
}
|
||||
|
||||
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\
|
||||
@@ -91,17 +152,26 @@ Selecting the right index ensures the most relevant results for the user query";
|
||||
impl Default for GlobalChatSettings {
|
||||
fn default() -> Self {
|
||||
GlobalChatSettings {
|
||||
source: "openai".to_string(),
|
||||
base_api: None,
|
||||
api_key: None,
|
||||
prompts: ChatPrompts {
|
||||
system: DEFAULT_SYSTEM_MESSAGE.to_string(),
|
||||
search_description: DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string(),
|
||||
search_q_param: DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||
search_index_uid_param: DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION
|
||||
.to_string(),
|
||||
pre_query: "".to_string(),
|
||||
},
|
||||
source: Setting::NotSet,
|
||||
base_api: Setting::NotSet,
|
||||
api_key: Setting::NotSet,
|
||||
prompts: Setting::Set(ChatPrompts::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChatPrompts {
|
||||
fn default() -> Self {
|
||||
ChatPrompts {
|
||||
system: Setting::Set(DEFAULT_SYSTEM_MESSAGE.to_string()),
|
||||
search_description: Setting::Set(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string()),
|
||||
search_q_param: Setting::Set(
|
||||
DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||
),
|
||||
search_index_uid_param: Setting::Set(
|
||||
DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION.to_string(),
|
||||
),
|
||||
pre_query: Setting::Set(Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,6 +122,7 @@ pub struct SearchQuery {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)]
|
||||
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
|
||||
pub struct RankingScoreThreshold(f64);
|
||||
|
||||
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
|
||||
type Error = InvalidSearchRankingScoreThreshold;
|
||||
|
||||
@@ -279,8 +280,8 @@ impl fmt::Debug for SearchQuery {
|
||||
#[deserr(error = DeserrJsonError<InvalidSearchHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HybridQuery {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
|
||||
#[schema(value_type = f32, default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>)]
|
||||
#[schema(default, value_type = f32)]
|
||||
#[serde(default)]
|
||||
pub semantic_ratio: SemanticRatio,
|
||||
#[deserr(error = DeserrJsonError<InvalidSearchEmbedder>)]
|
||||
|
||||
@@ -1695,7 +1695,7 @@ impl Index {
|
||||
|
||||
pub fn chat_config(&self, txn: &RoTxn<'_>) -> heed::Result<ChatConfig> {
|
||||
self.main
|
||||
.remap_types::<Str, SerdeBincode<_>>()
|
||||
.remap_types::<Str, SerdeJson<_>>()
|
||||
.get(txn, main_key::CHAT)
|
||||
.map(|o| o.unwrap_or_default())
|
||||
}
|
||||
@@ -1705,7 +1705,7 @@ impl Index {
|
||||
txn: &mut RwTxn<'_>,
|
||||
val: &ChatConfig,
|
||||
) -> heed::Result<()> {
|
||||
self.main.remap_types::<Str, SerdeBincode<_>>().put(txn, main_key::CHAT, &val)
|
||||
self.main.remap_types::<Str, SerdeJson<_>>().put(txn, main_key::CHAT, &val)
|
||||
}
|
||||
|
||||
pub(crate) fn delete_chat_config(&self, txn: &mut RwTxn<'_>) -> heed::Result<bool> {
|
||||
@@ -1943,15 +1943,54 @@ pub struct ChatConfig {
|
||||
pub description: String,
|
||||
/// Contains the document template and max template length.
|
||||
pub prompt: PromptData,
|
||||
pub search_parameters: SearchParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SearchParameters {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub hybrid: Option<HybridQuery>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub limit: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sort: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub distinct: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub matching_strategy: Option<MatchingStrategy>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub attributes_to_search_on: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ranking_score_threshold: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HybridQuery {
|
||||
pub semantic_ratio: f32,
|
||||
pub embedder: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PrefixSettings {
|
||||
pub prefix_count_threshold: usize,
|
||||
pub max_prefix_length: usize,
|
||||
pub compute_prefixes: PrefixSearch,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum MatchingStrategy {
|
||||
/// Remove query words from last to first
|
||||
Last,
|
||||
/// All query words are mandatory
|
||||
All,
|
||||
/// Remove query words from the most frequent to the least
|
||||
Frequency,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum PrefixSearch {
|
||||
|
||||
@@ -105,10 +105,10 @@ impl Prompt {
|
||||
max_bytes,
|
||||
};
|
||||
|
||||
// render template with special object that's OK with `doc.*` and `fields.*`
|
||||
this.template
|
||||
.render(&template_checker::TemplateChecker)
|
||||
.map_err(NewPromptError::invalid_fields_in_template)?;
|
||||
// // render template with special object that's OK with `doc.*` and `fields.*`
|
||||
// this.template
|
||||
// .render(&template_checker::TemplateChecker)
|
||||
// .map_err(NewPromptError::invalid_fields_in_template)?;
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
use deserr::errors::JsonError;
|
||||
use deserr::Deserr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::index::ChatConfig;
|
||||
use crate::index::{self, ChatConfig, SearchParameters};
|
||||
use crate::prompt::{default_max_bytes, PromptData};
|
||||
use crate::update::Setting;
|
||||
use crate::TermsMatchingStrategy;
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(deny_unknown_fields, rename_all = camelCase)]
|
||||
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
|
||||
pub struct ChatSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
@@ -29,17 +34,226 @@ pub struct ChatSettings {
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<usize>)]
|
||||
pub document_template_max_bytes: Setting<usize>,
|
||||
|
||||
/// The search parameters to use for the LLM.
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<ChatSearchParams>)]
|
||||
pub search_parameters: Setting<ChatSearchParams>,
|
||||
}
|
||||
|
||||
impl From<ChatConfig> for ChatSettings {
|
||||
fn from(config: ChatConfig) -> Self {
|
||||
let ChatConfig { description, prompt: PromptData { template, max_bytes } } = config;
|
||||
let ChatConfig {
|
||||
description,
|
||||
prompt: PromptData { template, max_bytes },
|
||||
search_parameters,
|
||||
} = config;
|
||||
ChatSettings {
|
||||
description: Setting::Set(description),
|
||||
document_template: Setting::Set(template),
|
||||
document_template_max_bytes: Setting::Set(
|
||||
max_bytes.unwrap_or(default_max_bytes()).get(),
|
||||
),
|
||||
search_parameters: Setting::Set({
|
||||
let SearchParameters {
|
||||
hybrid,
|
||||
limit,
|
||||
sort,
|
||||
distinct,
|
||||
matching_strategy,
|
||||
attributes_to_search_on,
|
||||
ranking_score_threshold,
|
||||
} = search_parameters;
|
||||
|
||||
let hybrid = hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| {
|
||||
HybridQuery { semantic_ratio: SemanticRatio(semantic_ratio), embedder }
|
||||
});
|
||||
|
||||
let matching_strategy = matching_strategy.map(|ms| match ms {
|
||||
index::MatchingStrategy::Last => MatchingStrategy::Last,
|
||||
index::MatchingStrategy::All => MatchingStrategy::All,
|
||||
index::MatchingStrategy::Frequency => MatchingStrategy::Frequency,
|
||||
});
|
||||
|
||||
let ranking_score_threshold = ranking_score_threshold.map(RankingScoreThreshold);
|
||||
|
||||
ChatSearchParams {
|
||||
hybrid: Setting::some_or_not_set(hybrid),
|
||||
limit: Setting::some_or_not_set(limit),
|
||||
sort: Setting::some_or_not_set(sort),
|
||||
distinct: Setting::some_or_not_set(distinct),
|
||||
matching_strategy: Setting::some_or_not_set(matching_strategy),
|
||||
attributes_to_search_on: Setting::some_or_not_set(attributes_to_search_on),
|
||||
ranking_score_threshold: Setting::some_or_not_set(ranking_score_threshold),
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
|
||||
pub struct ChatSearchParams {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<HybridQuery>)]
|
||||
pub hybrid: Setting<HybridQuery>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default = Setting::Set(20))]
|
||||
#[schema(value_type = Option<usize>)]
|
||||
pub limit: Setting<usize>,
|
||||
|
||||
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
// #[deserr(default)]
|
||||
// pub attributes_to_retrieve: Option<BTreeSet<String>>,
|
||||
|
||||
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
// #[deserr(default)]
|
||||
// pub filter: Option<Value>,
|
||||
//
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<Vec<String>>)]
|
||||
pub sort: Setting<Vec<String>>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<String>)]
|
||||
pub distinct: Setting<String>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<MatchingStrategy>)]
|
||||
pub matching_strategy: Setting<MatchingStrategy>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<Vec<String>>)]
|
||||
pub attributes_to_search_on: Setting<Vec<String>>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<RankingScoreThreshold>)]
|
||||
pub ranking_score_threshold: Setting<RankingScoreThreshold>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[deserr(error = JsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||
pub struct HybridQuery {
|
||||
#[deserr(default)]
|
||||
#[serde(default)]
|
||||
#[schema(default, value_type = f32)]
|
||||
pub semantic_ratio: SemanticRatio,
|
||||
#[schema(value_type = String)]
|
||||
pub embedder: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
|
||||
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
|
||||
pub struct SemanticRatio(f32);
|
||||
|
||||
impl Default for SemanticRatio {
|
||||
fn default() -> Self {
|
||||
SemanticRatio(0.5)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::TryFrom<f32> for SemanticRatio {
|
||||
type Error = InvalidSearchSemanticRatio;
|
||||
|
||||
fn try_from(f: f32) -> Result<Self, Self::Error> {
|
||||
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if f > 1.0 || f < 0.0 {
|
||||
Err(InvalidSearchSemanticRatio)
|
||||
} else {
|
||||
Ok(SemanticRatio(f))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error produced when a value cannot be converted into a `SemanticRatio`.
#[derive(Debug)]
pub struct InvalidSearchSemanticRatio;

impl Error for InvalidSearchSemanticRatio {}

impl fmt::Display for InvalidSearchSemanticRatio {
    /// Writes the fixed, user-facing message describing the accepted range.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            "the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
        )
    }
}
|
||||
|
||||
impl std::ops::Deref for SemanticRatio {
|
||||
type Target = f32;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr, ToSchema, Serialize, Deserialize)]
|
||||
#[deserr(rename_all = camelCase)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum MatchingStrategy {
|
||||
/// Remove query words from last to first
|
||||
Last,
|
||||
/// All query words are mandatory
|
||||
All,
|
||||
/// Remove query words from the most frequent to the least
|
||||
Frequency,
|
||||
}
|
||||
|
||||
impl Default for MatchingStrategy {
|
||||
fn default() -> Self {
|
||||
Self::Last
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MatchingStrategy> for TermsMatchingStrategy {
|
||||
fn from(other: MatchingStrategy) -> Self {
|
||||
match other {
|
||||
MatchingStrategy::Last => Self::Last,
|
||||
MatchingStrategy::All => Self::All,
|
||||
MatchingStrategy::Frequency => Self::Frequency,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize, Deserialize)]
|
||||
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
|
||||
pub struct RankingScoreThreshold(pub f64);
|
||||
|
||||
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
|
||||
type Error = InvalidSearchRankingScoreThreshold;
|
||||
|
||||
fn try_from(f: f64) -> Result<Self, Self::Error> {
|
||||
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if f > 1.0 || f < 0.0 {
|
||||
Err(InvalidSearchRankingScoreThreshold)
|
||||
} else {
|
||||
Ok(RankingScoreThreshold(f))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error produced when a value cannot be converted into a `RankingScoreThreshold`.
#[derive(Debug)]
pub struct InvalidSearchRankingScoreThreshold;

impl Error for InvalidSearchRankingScoreThreshold {}

impl fmt::Display for InvalidSearchRankingScoreThreshold {
    /// Writes the fixed, user-facing message describing the accepted range.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            "the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`.",
        )
    }
}
|
||||
|
||||
@@ -11,6 +11,7 @@ use roaring::RoaringBitmap;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use super::chat::{ChatSearchParams, RankingScoreThreshold};
|
||||
use super::del_add::{DelAdd, DelAddOperation};
|
||||
use super::index_documents::{IndexDocumentsConfig, Transform};
|
||||
use super::{ChatSettings, IndexerConfig};
|
||||
@@ -22,8 +23,8 @@ use crate::error::UserError;
|
||||
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
|
||||
use crate::filterable_attributes_rules::match_faceted_field;
|
||||
use crate::index::{
|
||||
ChatConfig, IndexEmbeddingConfig, PrefixSearch, DEFAULT_MIN_WORD_LEN_ONE_TYPO,
|
||||
DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
|
||||
ChatConfig, IndexEmbeddingConfig, MatchingStrategy, PrefixSearch,
|
||||
DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
|
||||
};
|
||||
use crate::order_by_map::OrderByMap;
|
||||
use crate::prompt::{default_max_bytes, PromptData};
|
||||
@@ -123,6 +124,15 @@ impl<T> Setting<T> {
|
||||
*self = new;
|
||||
true
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn unwrap(self) -> T {
|
||||
match self {
|
||||
Setting::Set(value) => value,
|
||||
Setting::Reset => panic!("Setting::Reset unwrapped"),
|
||||
Setting::NotSet => panic!("Setting::NotSet unwrapped"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Serialize> Serialize for Setting<T> {
|
||||
@@ -1255,11 +1265,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
description: new_description,
|
||||
document_template: new_document_template,
|
||||
document_template_max_bytes: new_document_template_max_bytes,
|
||||
search_parameters: new_search_parameters,
|
||||
}) => {
|
||||
let mut old = self.index.chat_config(self.wtxn)?;
|
||||
let ChatConfig {
|
||||
ref mut description,
|
||||
prompt: PromptData { ref mut template, ref mut max_bytes },
|
||||
ref mut search_parameters,
|
||||
} = old;
|
||||
|
||||
match new_description {
|
||||
@@ -1280,6 +1292,85 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match new_search_parameters {
|
||||
Setting::Set(sp) => {
|
||||
let ChatSearchParams {
|
||||
hybrid,
|
||||
limit,
|
||||
sort,
|
||||
distinct,
|
||||
matching_strategy,
|
||||
attributes_to_search_on,
|
||||
ranking_score_threshold,
|
||||
} = sp;
|
||||
|
||||
match hybrid {
|
||||
Setting::Set(hybrid) => {
|
||||
search_parameters.hybrid = Some(crate::index::HybridQuery {
|
||||
semantic_ratio: *hybrid.semantic_ratio,
|
||||
embedder: hybrid.embedder.clone(),
|
||||
})
|
||||
}
|
||||
Setting::Reset => search_parameters.hybrid = None,
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match limit {
|
||||
Setting::Set(limit) => search_parameters.limit = Some(*limit),
|
||||
Setting::Reset => search_parameters.limit = None,
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match sort {
|
||||
Setting::Set(sort) => search_parameters.sort = Some(sort.clone()),
|
||||
Setting::Reset => search_parameters.sort = None,
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match distinct {
|
||||
Setting::Set(distinct) => {
|
||||
search_parameters.distinct = Some(distinct.clone())
|
||||
}
|
||||
Setting::Reset => search_parameters.distinct = None,
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match matching_strategy {
|
||||
Setting::Set(matching_strategy) => {
|
||||
let strategy = match matching_strategy {
|
||||
super::chat::MatchingStrategy::Last => MatchingStrategy::Last,
|
||||
super::chat::MatchingStrategy::All => MatchingStrategy::All,
|
||||
super::chat::MatchingStrategy::Frequency => {
|
||||
MatchingStrategy::Frequency
|
||||
}
|
||||
};
|
||||
search_parameters.matching_strategy = Some(strategy)
|
||||
}
|
||||
Setting::Reset => search_parameters.matching_strategy = None,
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match attributes_to_search_on {
|
||||
Setting::Set(attributes_to_search_on) => {
|
||||
search_parameters.attributes_to_search_on =
|
||||
Some(attributes_to_search_on.clone())
|
||||
}
|
||||
Setting::Reset => search_parameters.attributes_to_search_on = None,
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match ranking_score_threshold {
|
||||
Setting::Set(RankingScoreThreshold(score)) => {
|
||||
search_parameters.ranking_score_threshold = Some(*score)
|
||||
}
|
||||
Setting::Reset => search_parameters.ranking_score_threshold = None,
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
}
|
||||
Setting::Reset => *search_parameters = Default::default(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
self.index.put_chat_config(self.wtxn, &old)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
@@ -897,6 +897,7 @@ fn test_correct_settings_init() {
|
||||
prefix_search,
|
||||
facet_search,
|
||||
disable_on_numbers,
|
||||
chat,
|
||||
} = settings;
|
||||
assert!(matches!(searchable_fields, Setting::NotSet));
|
||||
assert!(matches!(displayed_fields, Setting::NotSet));
|
||||
|
||||
Reference in New Issue
Block a user