Compare commits

..

13 Commits

Author SHA1 Message Date
Clément Renault
5601e87ad5 Clean up the code a bit 2025-05-28 14:22:53 +02:00
Clément Renault
ae604f4b87 Factorize the code a bit more and support reporting errors 2025-05-27 18:07:29 +02:00
Clément Renault
da1e285d70 Report the sources 2025-05-27 11:48:12 +02:00
Kerollmops
487e3c2ea3 Fix compilation error in test 2025-05-26 14:11:30 +02:00
Clément Renault
d45647b58d Call specific tools to show progression and results. 2025-05-23 17:25:09 +02:00
Clément Renault
045a1b1e75 Introduce a lot of search parameters and make Deserr happy 2025-05-22 15:34:49 +02:00
Clément Renault
293ac45b7c Expose a well defined set of sources 2025-05-22 10:42:36 +02:00
Clément Renault
8aa7ae912a Add the index descriptions to the function description 2025-05-22 10:40:43 +02:00
Clément Renault
7492b6e669 redact the chat settings API key 2025-05-21 21:18:18 +02:00
Clément Renault
fcc0d43a62 Better chat settings management 2025-05-21 21:06:11 +02:00
Clément Renault
72d4998dce Correctly list the chat settings key actions 2025-05-21 16:24:51 +02:00
Clément Renault
fde11573da Always use the frequency matching strategy 2025-05-21 16:18:37 +02:00
Clément Renault
41220f786b Remove templating validation 2025-05-21 16:10:31 +02:00
16 changed files with 1097 additions and 268 deletions

View File

@@ -305,6 +305,7 @@ pub(crate) mod test {
localized_attributes: Setting::NotSet,
facet_search: Setting::NotSet,
prefix_search: Setting::NotSet,
chat: Setting::NotSet,
_kind: std::marker::PhantomData,
};
settings.check()

View File

@@ -34,6 +34,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String {
planned_failures: _,
run_loop_iteration: _,
embedders: _,
chat_settings: _,
} = scheduler;
let rtxn = env.read_txn().unwrap();

View File

@@ -165,6 +165,7 @@ impl AuthController {
}
}
#[derive(Debug)]
pub struct AuthFilter {
search_rules: Option<SearchRules>,
key_authorized_indexes: SearchRules,

View File

@@ -4,9 +4,12 @@ use std::marker::PhantomData;
use std::ops::ControlFlow;
use deserr::errors::{JsonError, QueryParamError};
use deserr::{take_cf_content, DeserializeError, IntoValue, MergeWithError, ValuePointerRef};
use deserr::{
take_cf_content, DeserializeError, Deserr, IntoValue, MergeWithError, ValuePointerRef,
};
use milli::update::ChatSettings;
use crate::error::deserr_codes::*;
use crate::error::deserr_codes::{self, *};
use crate::error::{
Code, DeserrParseBoolError, DeserrParseIntError, ErrorCode, InvalidTaskDateError,
ParseOffsetDateTimeError,
@@ -33,6 +36,7 @@ pub struct DeserrError<Format, C: Default + ErrorCode> {
pub code: Code,
_phantom: PhantomData<(Format, C)>,
}
impl<Format, C: Default + ErrorCode> DeserrError<Format, C> {
pub fn new(msg: String, code: Code) -> Self {
Self { msg, code, _phantom: PhantomData }
@@ -117,6 +121,16 @@ impl<C: Default + ErrorCode> DeserializeError for DeserrQueryParamError<C> {
}
}
impl Deserr<DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> for ChatSettings {
fn deserialize_from_value<V: IntoValue>(
value: deserr::Value<V>,
location: ValuePointerRef,
) -> Result<Self, DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> {
Deserr::<JsonError>::deserialize_from_value(value, location)
.map_err(|e| DeserrError::new(e.to_string(), InvalidSettingsIndexChat.error_code()))
}
}
pub fn immutable_field_error(field: &str, accepted: &[&str], code: Code) -> DeserrJsonError {
let msg = format!(
"Immutable field `{field}`: expected one of {}",

View File

@@ -326,6 +326,9 @@ pub enum Action {
#[serde(rename = "chat.get")]
#[deserr(rename = "chat.get")]
Chat,
#[serde(rename = "chatSettings.*")]
#[deserr(rename = "chatSettings.*")]
ChatSettingsAll,
#[serde(rename = "chatSettings.get")]
#[deserr(rename = "chatSettings.get")]
ChatSettingsGet,
@@ -357,6 +360,9 @@ impl Action {
SETTINGS_ALL => Some(Self::SettingsAll),
SETTINGS_GET => Some(Self::SettingsGet),
SETTINGS_UPDATE => Some(Self::SettingsUpdate),
CHAT_SETTINGS_ALL => Some(Self::ChatSettingsAll),
CHAT_SETTINGS_GET => Some(Self::ChatSettingsGet),
CHAT_SETTINGS_UPDATE => Some(Self::ChatSettingsUpdate),
STATS_ALL => Some(Self::StatsAll),
STATS_GET => Some(Self::StatsGet),
METRICS_ALL => Some(Self::MetricsAll),
@@ -424,6 +430,7 @@ pub mod actions {
pub const NETWORK_UPDATE: u8 = NetworkUpdate.repr();
pub const CHAT: u8 = Chat.repr();
pub const CHAT_SETTINGS_ALL: u8 = ChatSettingsAll.repr();
pub const CHAT_SETTINGS_GET: u8 = ChatSettingsGet.repr();
pub const CHAT_SETTINGS_UPDATE: u8 = ChatSettingsUpdate.repr();
}

View File

@@ -186,7 +186,7 @@ impl<E: DeserializeError> Deserr<E> for SettingEmbeddingSettings {
/// Holds all the settings for an index. `T` can either be `Checked` if they represents settings
/// whose validity is guaranteed, or `Unchecked` if they need to be validated. In the later case, a
/// call to `check` will return a `Settings<Checked>` from a `Settings<Unchecked>`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
#[serde(
deny_unknown_fields,
rename_all = "camelCase",

View File

@@ -8,7 +8,7 @@ use crate::error::ResponseError;
use crate::settings::{Settings, Unchecked};
use crate::tasks::{serialize_duration, Details, IndexSwap, Kind, Status, Task, TaskId};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, ToSchema)]
#[derive(Debug, Clone, PartialEq, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
#[schema(rename_all = "camelCase")]
pub struct TaskView {
@@ -67,7 +67,7 @@ impl TaskView {
}
}
#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize, ToSchema)]
#[derive(Default, Debug, PartialEq, Clone, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
#[schema(rename_all = "camelCase")]
pub struct DetailsView {

View File

@@ -597,7 +597,7 @@ impl fmt::Display for ParseTaskKindError {
}
impl std::error::Error for ParseTaskKindError {}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum Details {
DocumentAdditionOrUpdate {
received_documents: u64,

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ use actix_web::HttpResponse;
use index_scheduler::IndexScheduler;
use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::update::Setting;
use serde::{Deserialize, Serialize};
use crate::extractors::authentication::policies::ActionPolicy;
@@ -23,10 +24,11 @@ async fn get_settings(
Data<IndexScheduler>,
>,
) -> Result<HttpResponse, ResponseError> {
let settings = match index_scheduler.chat_settings()? {
let mut settings = match index_scheduler.chat_settings()? {
Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
};
settings.hide_secrets();
Ok(HttpResponse::Ok().json(settings))
}
@@ -35,37 +37,96 @@ async fn patch_settings(
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
Data<IndexScheduler>,
>,
web::Json(chat_settings): web::Json<GlobalChatSettings>,
web::Json(new): web::Json<GlobalChatSettings>,
) -> Result<HttpResponse, ResponseError> {
let chat_settings = serde_json::to_value(chat_settings).unwrap();
index_scheduler.put_chat_settings(&chat_settings)?;
let old = match index_scheduler.chat_settings()? {
Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
};
let settings = GlobalChatSettings {
source: new.source.or(old.source),
base_api: new.base_api.clone().or(old.base_api),
api_key: new.api_key.clone().or(old.api_key),
prompts: match (new.prompts, old.prompts) {
(Setting::NotSet, set) | (set, Setting::NotSet) => set,
(Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset,
(Setting::Reset, Setting::Set(set)) => Setting::Set(set),
// If both are set we must merge the prompts settings
(Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts {
system: new.system.or(old.system),
search_description: new.search_description.or(old.search_description),
search_q_param: new.search_q_param.or(old.search_q_param),
search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param),
pre_query: new.pre_query.or(old.pre_query),
}),
},
};
let value = serde_json::to_value(settings).unwrap();
index_scheduler.put_chat_settings(&value)?;
Ok(HttpResponse::Ok().finish())
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub enum ChatSource {
OpenAi,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct GlobalChatSettings {
pub source: String,
pub base_api: Option<String>,
pub api_key: Option<String>,
pub prompts: ChatPrompts,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub source: Setting<ChatSource>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub base_api: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub prompts: Setting<ChatPrompts>,
}
#[derive(Debug, Serialize, Deserialize)]
impl GlobalChatSettings {
pub fn hide_secrets(&mut self) {
match &mut self.api_key {
Setting::Set(key) => Self::hide_secret(key),
Setting::Reset => (),
Setting::NotSet => (),
}
}
fn hide_secret(secret: &mut String) {
match secret.len() {
x if x < 10 => {
secret.replace_range(.., "XXX...");
}
x if x < 20 => {
secret.replace_range(2.., "XXXX...");
}
x if x < 30 => {
secret.replace_range(3.., "XXXXX...");
}
_x => {
secret.replace_range(5.., "XXXXXX...");
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct ChatPrompts {
pub system: String,
pub search_description: String,
pub search_q_param: String,
pub search_index_uid_param: String,
pub pre_query: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct ChatIndexSettings {
pub description: String,
pub document_template: String,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub system: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_description: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_q_param: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_index_uid_param: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub pre_query: Setting<String>,
}
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\
@@ -91,17 +152,26 @@ Selecting the right index ensures the most relevant results for the user query";
impl Default for GlobalChatSettings {
fn default() -> Self {
GlobalChatSettings {
source: "openai".to_string(),
base_api: None,
api_key: None,
prompts: ChatPrompts {
system: DEFAULT_SYSTEM_MESSAGE.to_string(),
search_description: DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string(),
search_q_param: DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
search_index_uid_param: DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION
.to_string(),
pre_query: "".to_string(),
},
source: Setting::NotSet,
base_api: Setting::NotSet,
api_key: Setting::NotSet,
prompts: Setting::Set(ChatPrompts::default()),
}
}
}
impl Default for ChatPrompts {
fn default() -> Self {
ChatPrompts {
system: Setting::Set(DEFAULT_SYSTEM_MESSAGE.to_string()),
search_description: Setting::Set(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string()),
search_q_param: Setting::Set(
DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
),
search_index_uid_param: Setting::Set(
DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION.to_string(),
),
pre_query: Setting::Set(Default::default()),
}
}
}

View File

@@ -122,6 +122,7 @@ pub struct SearchQuery {
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(f64);
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
type Error = InvalidSearchRankingScoreThreshold;
@@ -279,8 +280,8 @@ impl fmt::Debug for SearchQuery {
#[deserr(error = DeserrJsonError<InvalidSearchHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub struct HybridQuery {
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
#[schema(value_type = f32, default)]
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>)]
#[schema(default, value_type = f32)]
#[serde(default)]
pub semantic_ratio: SemanticRatio,
#[deserr(error = DeserrJsonError<InvalidSearchEmbedder>)]

View File

@@ -1695,7 +1695,7 @@ impl Index {
pub fn chat_config(&self, txn: &RoTxn<'_>) -> heed::Result<ChatConfig> {
self.main
.remap_types::<Str, SerdeBincode<_>>()
.remap_types::<Str, SerdeJson<_>>()
.get(txn, main_key::CHAT)
.map(|o| o.unwrap_or_default())
}
@@ -1705,7 +1705,7 @@ impl Index {
txn: &mut RwTxn<'_>,
val: &ChatConfig,
) -> heed::Result<()> {
self.main.remap_types::<Str, SerdeBincode<_>>().put(txn, main_key::CHAT, &val)
self.main.remap_types::<Str, SerdeJson<_>>().put(txn, main_key::CHAT, &val)
}
pub(crate) fn delete_chat_config(&self, txn: &mut RwTxn<'_>) -> heed::Result<bool> {
@@ -1943,15 +1943,54 @@ pub struct ChatConfig {
pub description: String,
/// Contains the document template and max template length.
pub prompt: PromptData,
pub search_parameters: SearchParameters,
}
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SearchParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub hybrid: Option<HybridQuery>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub distinct: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matching_strategy: Option<MatchingStrategy>,
#[serde(skip_serializing_if = "Option::is_none")]
pub attributes_to_search_on: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ranking_score_threshold: Option<f64>,
}
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HybridQuery {
pub semantic_ratio: f32,
pub embedder: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PrefixSettings {
pub prefix_count_threshold: usize,
pub max_prefix_length: usize,
pub compute_prefixes: PrefixSearch,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum MatchingStrategy {
/// Remove query words from last to first
Last,
/// All query words are mandatory
All,
/// Remove query words from the most frequent to the least
Frequency,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "camelCase")]
pub enum PrefixSearch {

View File

@@ -105,10 +105,10 @@ impl Prompt {
max_bytes,
};
// render template with special object that's OK with `doc.*` and `fields.*`
this.template
.render(&template_checker::TemplateChecker)
.map_err(NewPromptError::invalid_fields_in_template)?;
// // render template with special object that's OK with `doc.*` and `fields.*`
// this.template
// .render(&template_checker::TemplateChecker)
// .map_err(NewPromptError::invalid_fields_in_template)?;
Ok(this)
}

View File

@@ -1,14 +1,19 @@
use std::error::Error;
use std::fmt;
use deserr::errors::JsonError;
use deserr::Deserr;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::index::ChatConfig;
use crate::index::{self, ChatConfig, SearchParameters};
use crate::prompt::{default_max_bytes, PromptData};
use crate::update::Setting;
use crate::TermsMatchingStrategy;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(deny_unknown_fields, rename_all = camelCase)]
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
pub struct ChatSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
@@ -29,17 +34,226 @@ pub struct ChatSettings {
#[deserr(default)]
#[schema(value_type = Option<usize>)]
pub document_template_max_bytes: Setting<usize>,
/// The search parameters to use for the LLM.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<ChatSearchParams>)]
pub search_parameters: Setting<ChatSearchParams>,
}
impl From<ChatConfig> for ChatSettings {
fn from(config: ChatConfig) -> Self {
let ChatConfig { description, prompt: PromptData { template, max_bytes } } = config;
let ChatConfig {
description,
prompt: PromptData { template, max_bytes },
search_parameters,
} = config;
ChatSettings {
description: Setting::Set(description),
document_template: Setting::Set(template),
document_template_max_bytes: Setting::Set(
max_bytes.unwrap_or(default_max_bytes()).get(),
),
search_parameters: Setting::Set({
let SearchParameters {
hybrid,
limit,
sort,
distinct,
matching_strategy,
attributes_to_search_on,
ranking_score_threshold,
} = search_parameters;
let hybrid = hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| {
HybridQuery { semantic_ratio: SemanticRatio(semantic_ratio), embedder }
});
let matching_strategy = matching_strategy.map(|ms| match ms {
index::MatchingStrategy::Last => MatchingStrategy::Last,
index::MatchingStrategy::All => MatchingStrategy::All,
index::MatchingStrategy::Frequency => MatchingStrategy::Frequency,
});
let ranking_score_threshold = ranking_score_threshold.map(RankingScoreThreshold);
ChatSearchParams {
hybrid: Setting::some_or_not_set(hybrid),
limit: Setting::some_or_not_set(limit),
sort: Setting::some_or_not_set(sort),
distinct: Setting::some_or_not_set(distinct),
matching_strategy: Setting::some_or_not_set(matching_strategy),
attributes_to_search_on: Setting::some_or_not_set(attributes_to_search_on),
ranking_score_threshold: Setting::some_or_not_set(ranking_score_threshold),
}
}),
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
pub struct ChatSearchParams {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<HybridQuery>)]
pub hybrid: Setting<HybridQuery>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default = Setting::Set(20))]
#[schema(value_type = Option<usize>)]
pub limit: Setting<usize>,
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
// #[deserr(default)]
// pub attributes_to_retrieve: Option<BTreeSet<String>>,
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
// #[deserr(default)]
// pub filter: Option<Value>,
//
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<Vec<String>>)]
pub sort: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
pub distinct: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<MatchingStrategy>)]
pub matching_strategy: Setting<MatchingStrategy>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<Vec<String>>)]
pub attributes_to_search_on: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<RankingScoreThreshold>)]
pub ranking_score_threshold: Setting<RankingScoreThreshold>,
}
#[derive(Debug, Clone, Default, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[deserr(error = JsonError, rename_all = camelCase, deny_unknown_fields)]
pub struct HybridQuery {
#[deserr(default)]
#[serde(default)]
#[schema(default, value_type = f32)]
pub semantic_ratio: SemanticRatio,
#[schema(value_type = String)]
pub embedder: String,
}
#[derive(Debug, Clone, Copy, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatio(f32);
impl Default for SemanticRatio {
fn default() -> Self {
SemanticRatio(0.5)
}
}
impl std::convert::TryFrom<f32> for SemanticRatio {
type Error = InvalidSearchSemanticRatio;
fn try_from(f: f32) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchSemanticRatio)
} else {
Ok(SemanticRatio(f))
}
}
}
#[derive(Debug)]
pub struct InvalidSearchSemanticRatio;
impl Error for InvalidSearchSemanticRatio {}
impl fmt::Display for InvalidSearchSemanticRatio {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`."
)
}
}
impl std::ops::Deref for SemanticRatio {
type Target = f32;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr, ToSchema, Serialize, Deserialize)]
#[deserr(rename_all = camelCase)]
#[serde(rename_all = "camelCase")]
pub enum MatchingStrategy {
/// Remove query words from last to first
Last,
/// All query words are mandatory
All,
/// Remove query words from the most frequent to the least
Frequency,
}
impl Default for MatchingStrategy {
fn default() -> Self {
Self::Last
}
}
impl From<MatchingStrategy> for TermsMatchingStrategy {
fn from(other: MatchingStrategy) -> Self {
match other {
MatchingStrategy::Last => Self::Last,
MatchingStrategy::All => Self::All,
MatchingStrategy::Frequency => Self::Frequency,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize, Deserialize)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(pub f64);
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
type Error = InvalidSearchRankingScoreThreshold;
fn try_from(f: f64) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchRankingScoreThreshold)
} else {
Ok(RankingScoreThreshold(f))
}
}
}
#[derive(Debug)]
pub struct InvalidSearchRankingScoreThreshold;
impl Error for InvalidSearchRankingScoreThreshold {}
impl fmt::Display for InvalidSearchRankingScoreThreshold {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`."
)
}
}

View File

@@ -11,6 +11,7 @@ use roaring::RoaringBitmap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
use super::chat::{ChatSearchParams, RankingScoreThreshold};
use super::del_add::{DelAdd, DelAddOperation};
use super::index_documents::{IndexDocumentsConfig, Transform};
use super::{ChatSettings, IndexerConfig};
@@ -22,8 +23,8 @@ use crate::error::UserError;
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::filterable_attributes_rules::match_faceted_field;
use crate::index::{
ChatConfig, IndexEmbeddingConfig, PrefixSearch, DEFAULT_MIN_WORD_LEN_ONE_TYPO,
DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
ChatConfig, IndexEmbeddingConfig, MatchingStrategy, PrefixSearch,
DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
};
use crate::order_by_map::OrderByMap;
use crate::prompt::{default_max_bytes, PromptData};
@@ -123,6 +124,15 @@ impl<T> Setting<T> {
*self = new;
true
}
#[track_caller]
pub fn unwrap(self) -> T {
match self {
Setting::Set(value) => value,
Setting::Reset => panic!("Setting::Reset unwrapped"),
Setting::NotSet => panic!("Setting::NotSet unwrapped"),
}
}
}
impl<T: Serialize> Serialize for Setting<T> {
@@ -1255,11 +1265,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
description: new_description,
document_template: new_document_template,
document_template_max_bytes: new_document_template_max_bytes,
search_parameters: new_search_parameters,
}) => {
let mut old = self.index.chat_config(self.wtxn)?;
let ChatConfig {
ref mut description,
prompt: PromptData { ref mut template, ref mut max_bytes },
ref mut search_parameters,
} = old;
match new_description {
@@ -1280,6 +1292,85 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Setting::NotSet => (),
}
match new_search_parameters {
Setting::Set(sp) => {
let ChatSearchParams {
hybrid,
limit,
sort,
distinct,
matching_strategy,
attributes_to_search_on,
ranking_score_threshold,
} = sp;
match hybrid {
Setting::Set(hybrid) => {
search_parameters.hybrid = Some(crate::index::HybridQuery {
semantic_ratio: *hybrid.semantic_ratio,
embedder: hybrid.embedder.clone(),
})
}
Setting::Reset => search_parameters.hybrid = None,
Setting::NotSet => (),
}
match limit {
Setting::Set(limit) => search_parameters.limit = Some(*limit),
Setting::Reset => search_parameters.limit = None,
Setting::NotSet => (),
}
match sort {
Setting::Set(sort) => search_parameters.sort = Some(sort.clone()),
Setting::Reset => search_parameters.sort = None,
Setting::NotSet => (),
}
match distinct {
Setting::Set(distinct) => {
search_parameters.distinct = Some(distinct.clone())
}
Setting::Reset => search_parameters.distinct = None,
Setting::NotSet => (),
}
match matching_strategy {
Setting::Set(matching_strategy) => {
let strategy = match matching_strategy {
super::chat::MatchingStrategy::Last => MatchingStrategy::Last,
super::chat::MatchingStrategy::All => MatchingStrategy::All,
super::chat::MatchingStrategy::Frequency => {
MatchingStrategy::Frequency
}
};
search_parameters.matching_strategy = Some(strategy)
}
Setting::Reset => search_parameters.matching_strategy = None,
Setting::NotSet => (),
}
match attributes_to_search_on {
Setting::Set(attributes_to_search_on) => {
search_parameters.attributes_to_search_on =
Some(attributes_to_search_on.clone())
}
Setting::Reset => search_parameters.attributes_to_search_on = None,
Setting::NotSet => (),
}
match ranking_score_threshold {
Setting::Set(RankingScoreThreshold(score)) => {
search_parameters.ranking_score_threshold = Some(*score)
}
Setting::Reset => search_parameters.ranking_score_threshold = None,
Setting::NotSet => (),
}
}
Setting::Reset => *search_parameters = Default::default(),
Setting::NotSet => (),
}
self.index.put_chat_config(self.wtxn, &old)?;
Ok(true)
}

View File

@@ -897,6 +897,7 @@ fn test_correct_settings_init() {
prefix_search,
facet_search,
disable_on_numbers,
chat,
} = settings;
assert!(matches!(searchable_fields, Setting::NotSet));
assert!(matches!(displayed_fields, Setting::NotSet));