mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-08 05:05:42 +00:00
Introduce a lot of search parameters and make Deserr happy
This commit is contained in:
parent
bf3286ba41
commit
ca5a87a606
@ -165,6 +165,7 @@ impl AuthController {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct AuthFilter {
|
pub struct AuthFilter {
|
||||||
search_rules: Option<SearchRules>,
|
search_rules: Option<SearchRules>,
|
||||||
key_authorized_indexes: SearchRules,
|
key_authorized_indexes: SearchRules,
|
||||||
|
@ -4,9 +4,12 @@ use std::marker::PhantomData;
|
|||||||
use std::ops::ControlFlow;
|
use std::ops::ControlFlow;
|
||||||
|
|
||||||
use deserr::errors::{JsonError, QueryParamError};
|
use deserr::errors::{JsonError, QueryParamError};
|
||||||
use deserr::{take_cf_content, DeserializeError, IntoValue, MergeWithError, ValuePointerRef};
|
use deserr::{
|
||||||
|
take_cf_content, DeserializeError, Deserr, IntoValue, MergeWithError, ValuePointerRef,
|
||||||
|
};
|
||||||
|
use milli::update::ChatSettings;
|
||||||
|
|
||||||
use crate::error::deserr_codes::*;
|
use crate::error::deserr_codes::{self, *};
|
||||||
use crate::error::{
|
use crate::error::{
|
||||||
Code, DeserrParseBoolError, DeserrParseIntError, ErrorCode, InvalidTaskDateError,
|
Code, DeserrParseBoolError, DeserrParseIntError, ErrorCode, InvalidTaskDateError,
|
||||||
ParseOffsetDateTimeError,
|
ParseOffsetDateTimeError,
|
||||||
@ -33,6 +36,7 @@ pub struct DeserrError<Format, C: Default + ErrorCode> {
|
|||||||
pub code: Code,
|
pub code: Code,
|
||||||
_phantom: PhantomData<(Format, C)>,
|
_phantom: PhantomData<(Format, C)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Format, C: Default + ErrorCode> DeserrError<Format, C> {
|
impl<Format, C: Default + ErrorCode> DeserrError<Format, C> {
|
||||||
pub fn new(msg: String, code: Code) -> Self {
|
pub fn new(msg: String, code: Code) -> Self {
|
||||||
Self { msg, code, _phantom: PhantomData }
|
Self { msg, code, _phantom: PhantomData }
|
||||||
@ -117,6 +121,16 @@ impl<C: Default + ErrorCode> DeserializeError for DeserrQueryParamError<C> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Deserr<DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> for ChatSettings {
|
||||||
|
fn deserialize_from_value<V: IntoValue>(
|
||||||
|
value: deserr::Value<V>,
|
||||||
|
location: ValuePointerRef,
|
||||||
|
) -> Result<Self, DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> {
|
||||||
|
Deserr::<JsonError>::deserialize_from_value(value, location)
|
||||||
|
.map_err(|e| DeserrError::new(e.to_string(), InvalidSettingsIndexChat.error_code()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn immutable_field_error(field: &str, accepted: &[&str], code: Code) -> DeserrJsonError {
|
pub fn immutable_field_error(field: &str, accepted: &[&str], code: Code) -> DeserrJsonError {
|
||||||
let msg = format!(
|
let msg = format!(
|
||||||
"Immutable field `{field}`: expected one of {}",
|
"Immutable field `{field}`: expected one of {}",
|
||||||
|
@ -186,7 +186,7 @@ impl<E: DeserializeError> Deserr<E> for SettingEmbeddingSettings {
|
|||||||
/// Holds all the settings for an index. `T` can either be `Checked` if they represent settings
|
/// Holds all the settings for an index. `T` can either be `Checked` if they represent settings
|
||||||
/// whose validity is guaranteed, or `Unchecked` if they need to be validated. In the latter case, a
|
/// whose validity is guaranteed, or `Unchecked` if they need to be validated. In the latter case, a
|
||||||
/// call to `check` will return a `Settings<Checked>` from a `Settings<Unchecked>`.
|
/// call to `check` will return a `Settings<Checked>` from a `Settings<Unchecked>`.
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
|
||||||
#[serde(
|
#[serde(
|
||||||
deny_unknown_fields,
|
deny_unknown_fields,
|
||||||
rename_all = "camelCase",
|
rename_all = "camelCase",
|
||||||
|
@ -8,7 +8,7 @@ use crate::error::ResponseError;
|
|||||||
use crate::settings::{Settings, Unchecked};
|
use crate::settings::{Settings, Unchecked};
|
||||||
use crate::tasks::{serialize_duration, Details, IndexSwap, Kind, Status, Task, TaskId};
|
use crate::tasks::{serialize_duration, Details, IndexSwap, Kind, Status, Task, TaskId};
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, ToSchema)]
|
#[derive(Debug, Clone, PartialEq, Serialize, ToSchema)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
#[schema(rename_all = "camelCase")]
|
#[schema(rename_all = "camelCase")]
|
||||||
pub struct TaskView {
|
pub struct TaskView {
|
||||||
@ -67,7 +67,7 @@ impl TaskView {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize, ToSchema)]
|
#[derive(Default, Debug, PartialEq, Clone, Serialize, Deserialize, ToSchema)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
#[schema(rename_all = "camelCase")]
|
#[schema(rename_all = "camelCase")]
|
||||||
pub struct DetailsView {
|
pub struct DetailsView {
|
||||||
|
@ -597,7 +597,7 @@ impl fmt::Display for ParseTaskKindError {
|
|||||||
}
|
}
|
||||||
impl std::error::Error for ParseTaskKindError {}
|
impl std::error::Error for ParseTaskKindError {}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||||
pub enum Details {
|
pub enum Details {
|
||||||
DocumentAdditionOrUpdate {
|
DocumentAdditionOrUpdate {
|
||||||
received_documents: u64,
|
received_documents: u64,
|
||||||
|
@ -26,7 +26,7 @@ use meilisearch_auth::AuthController;
|
|||||||
use meilisearch_types::error::ResponseError;
|
use meilisearch_types::error::ResponseError;
|
||||||
use meilisearch_types::heed::RoTxn;
|
use meilisearch_types::heed::RoTxn;
|
||||||
use meilisearch_types::keys::actions;
|
use meilisearch_types::keys::actions;
|
||||||
use meilisearch_types::milli::index::ChatConfig;
|
use meilisearch_types::milli::index::{self, ChatConfig, SearchParameters};
|
||||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||||
use meilisearch_types::milli::update::Setting;
|
use meilisearch_types::milli::update::Setting;
|
||||||
@ -46,12 +46,12 @@ use crate::extractors::authentication::{extract_token_from_request, GuardedData,
|
|||||||
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
||||||
use crate::routes::indexes::search::search_kind;
|
use crate::routes::indexes::search::search_kind;
|
||||||
use crate::search::{
|
use crate::search::{
|
||||||
add_search_rules, prepare_search, search_from_kind, HybridQuery, MatchingStrategy, SearchQuery,
|
add_search_rules, prepare_search, search_from_kind, HybridQuery, MatchingStrategy,
|
||||||
SemanticRatio,
|
RankingScoreThreshold, SearchQuery, SemanticRatio, DEFAULT_SEARCH_LIMIT,
|
||||||
|
DEFAULT_SEMANTIC_RATIO,
|
||||||
};
|
};
|
||||||
use crate::search_queue::SearchQueue;
|
use crate::search_queue::SearchQueue;
|
||||||
|
|
||||||
const EMBEDDER_NAME: &str = "openai";
|
|
||||||
const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||||
|
|
||||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||||
@ -168,14 +168,43 @@ async fn process_search_request(
|
|||||||
index_uid: String,
|
index_uid: String,
|
||||||
q: Option<String>,
|
q: Option<String>,
|
||||||
) -> Result<(Index, String), ResponseError> {
|
) -> Result<(Index, String), ResponseError> {
|
||||||
|
// TBD
|
||||||
|
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
||||||
|
|
||||||
|
let index = index_scheduler.index(&index_uid)?;
|
||||||
|
let rtxn = index.static_read_txn()?;
|
||||||
|
let ChatConfig { description: _, prompt: _, search_parameters } = index.chat_config(&rtxn)?;
|
||||||
|
let SearchParameters {
|
||||||
|
hybrid,
|
||||||
|
limit,
|
||||||
|
sort,
|
||||||
|
distinct,
|
||||||
|
matching_strategy,
|
||||||
|
attributes_to_search_on,
|
||||||
|
ranking_score_threshold,
|
||||||
|
} = search_parameters;
|
||||||
|
|
||||||
let mut query = SearchQuery {
|
let mut query = SearchQuery {
|
||||||
q,
|
q,
|
||||||
hybrid: Some(HybridQuery {
|
hybrid: hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| HybridQuery {
|
||||||
semantic_ratio: SemanticRatio::default(),
|
semantic_ratio: SemanticRatio::try_from(semantic_ratio)
|
||||||
embedder: EMBEDDER_NAME.to_string(),
|
.ok()
|
||||||
|
.unwrap_or_else(DEFAULT_SEMANTIC_RATIO),
|
||||||
|
embedder,
|
||||||
}),
|
}),
|
||||||
limit: 20,
|
limit: limit.unwrap_or_else(DEFAULT_SEARCH_LIMIT),
|
||||||
matching_strategy: MatchingStrategy::Frequency,
|
sort: sort,
|
||||||
|
distinct: distinct,
|
||||||
|
matching_strategy: matching_strategy
|
||||||
|
.map(|ms| match ms {
|
||||||
|
index::MatchingStrategy::Last => MatchingStrategy::Last,
|
||||||
|
index::MatchingStrategy::All => MatchingStrategy::All,
|
||||||
|
index::MatchingStrategy::Frequency => MatchingStrategy::Frequency,
|
||||||
|
})
|
||||||
|
.unwrap_or(MatchingStrategy::Frequency),
|
||||||
|
attributes_to_search_on: attributes_to_search_on,
|
||||||
|
ranking_score_threshold: ranking_score_threshold
|
||||||
|
.and_then(|rst| RankingScoreThreshold::try_from(rst).ok()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -189,19 +218,13 @@ async fn process_search_request(
|
|||||||
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
|
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
|
||||||
add_search_rules(&mut query.filter, search_rules);
|
add_search_rules(&mut query.filter, search_rules);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TBD
|
|
||||||
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
|
||||||
|
|
||||||
let index = index_scheduler.index(&index_uid)?;
|
|
||||||
let search_kind =
|
let search_kind =
|
||||||
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
||||||
|
|
||||||
let permit = search_queue.try_get_search_permit().await?;
|
let permit = search_queue.try_get_search_permit().await?;
|
||||||
let features = index_scheduler.features();
|
let features = index_scheduler.features();
|
||||||
let index_cloned = index.clone();
|
let index_cloned = index.clone();
|
||||||
let search_result = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> {
|
let output = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> {
|
||||||
let rtxn = index_cloned.read_txn()?;
|
|
||||||
let time_budget = match index_cloned
|
let time_budget = match index_cloned
|
||||||
.search_cutoff(&rtxn)
|
.search_cutoff(&rtxn)
|
||||||
.map_err(|e| MeilisearchHttpError::from_milli(e, Some(index_uid.clone())))?
|
.map_err(|e| MeilisearchHttpError::from_milli(e, Some(index_uid.clone())))?
|
||||||
@ -214,14 +237,14 @@ async fn process_search_request(
|
|||||||
prepare_search(&index_cloned, &rtxn, &query, &search_kind, time_budget, features)?;
|
prepare_search(&index_cloned, &rtxn, &query, &search_kind, time_budget, features)?;
|
||||||
|
|
||||||
search_from_kind(index_uid, search_kind, search)
|
search_from_kind(index_uid, search_kind, search)
|
||||||
.map(|(search_results, _)| search_results)
|
.map(|(search_results, _)| (rtxn, search_results))
|
||||||
.map_err(ResponseError::from)
|
.map_err(ResponseError::from)
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
permit.drop().await;
|
permit.drop().await;
|
||||||
|
|
||||||
let search_result = search_result?;
|
let output = output?;
|
||||||
if let Ok(ref search_result) = search_result {
|
if let Ok((_, ref search_result)) = output {
|
||||||
// aggregate.succeed(search_result);
|
// aggregate.succeed(search_result);
|
||||||
if search_result.degraded {
|
if search_result.degraded {
|
||||||
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
||||||
@ -229,8 +252,8 @@ async fn process_search_request(
|
|||||||
}
|
}
|
||||||
// analytics.publish(aggregate, &req);
|
// analytics.publish(aggregate, &req);
|
||||||
|
|
||||||
let search_result = search_result?;
|
let (rtxn, search_result) = output?;
|
||||||
let rtxn = index.read_txn()?;
|
// let rtxn = index.read_txn()?;
|
||||||
let render_alloc = Bump::new();
|
let render_alloc = Bump::new();
|
||||||
let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?;
|
let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?;
|
||||||
let text = formatted.join("\n");
|
let text = formatted.join("\n");
|
||||||
|
@ -122,6 +122,7 @@ pub struct SearchQuery {
|
|||||||
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)]
|
||||||
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
|
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
|
||||||
pub struct RankingScoreThreshold(f64);
|
pub struct RankingScoreThreshold(f64);
|
||||||
|
|
||||||
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
|
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
|
||||||
type Error = InvalidSearchRankingScoreThreshold;
|
type Error = InvalidSearchRankingScoreThreshold;
|
||||||
|
|
||||||
@ -279,8 +280,8 @@ impl fmt::Debug for SearchQuery {
|
|||||||
#[deserr(error = DeserrJsonError<InvalidSearchHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
|
#[deserr(error = DeserrJsonError<InvalidSearchHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct HybridQuery {
|
pub struct HybridQuery {
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>)]
|
||||||
#[schema(value_type = f32, default)]
|
#[schema(default, value_type = f32)]
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub semantic_ratio: SemanticRatio,
|
pub semantic_ratio: SemanticRatio,
|
||||||
#[deserr(error = DeserrJsonError<InvalidSearchEmbedder>)]
|
#[deserr(error = DeserrJsonError<InvalidSearchEmbedder>)]
|
||||||
|
@ -1695,7 +1695,7 @@ impl Index {
|
|||||||
|
|
||||||
pub fn chat_config(&self, txn: &RoTxn<'_>) -> heed::Result<ChatConfig> {
|
pub fn chat_config(&self, txn: &RoTxn<'_>) -> heed::Result<ChatConfig> {
|
||||||
self.main
|
self.main
|
||||||
.remap_types::<Str, SerdeBincode<_>>()
|
.remap_types::<Str, SerdeJson<_>>()
|
||||||
.get(txn, main_key::CHAT)
|
.get(txn, main_key::CHAT)
|
||||||
.map(|o| o.unwrap_or_default())
|
.map(|o| o.unwrap_or_default())
|
||||||
}
|
}
|
||||||
@ -1705,7 +1705,7 @@ impl Index {
|
|||||||
txn: &mut RwTxn<'_>,
|
txn: &mut RwTxn<'_>,
|
||||||
val: &ChatConfig,
|
val: &ChatConfig,
|
||||||
) -> heed::Result<()> {
|
) -> heed::Result<()> {
|
||||||
self.main.remap_types::<Str, SerdeBincode<_>>().put(txn, main_key::CHAT, &val)
|
self.main.remap_types::<Str, SerdeJson<_>>().put(txn, main_key::CHAT, &val)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn delete_chat_config(&self, txn: &mut RwTxn<'_>) -> heed::Result<bool> {
|
pub(crate) fn delete_chat_config(&self, txn: &mut RwTxn<'_>) -> heed::Result<bool> {
|
||||||
@ -1943,15 +1943,54 @@ pub struct ChatConfig {
|
|||||||
pub description: String,
|
pub description: String,
|
||||||
/// Contains the document template and max template length.
|
/// Contains the document template and max template length.
|
||||||
pub prompt: PromptData,
|
pub prompt: PromptData,
|
||||||
|
pub search_parameters: SearchParameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SearchParameters {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hybrid: Option<HybridQuery>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub limit: Option<usize>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sort: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub distinct: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub matching_strategy: Option<MatchingStrategy>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub attributes_to_search_on: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ranking_score_threshold: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct HybridQuery {
|
||||||
|
pub semantic_ratio: f32,
|
||||||
|
pub embedder: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct PrefixSettings {
|
pub struct PrefixSettings {
|
||||||
pub prefix_count_threshold: usize,
|
pub prefix_count_threshold: usize,
|
||||||
pub max_prefix_length: usize,
|
pub max_prefix_length: usize,
|
||||||
pub compute_prefixes: PrefixSearch,
|
pub compute_prefixes: PrefixSearch,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub enum MatchingStrategy {
|
||||||
|
/// Remove query words from last to first
|
||||||
|
Last,
|
||||||
|
/// All query words are mandatory
|
||||||
|
All,
|
||||||
|
/// Remove query words from the most frequent to the least
|
||||||
|
Frequency,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub enum PrefixSearch {
|
pub enum PrefixSearch {
|
||||||
|
@ -1,14 +1,19 @@
|
|||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
use deserr::errors::JsonError;
|
||||||
use deserr::Deserr;
|
use deserr::Deserr;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use crate::index::ChatConfig;
|
use crate::index::{self, ChatConfig, SearchParameters};
|
||||||
use crate::prompt::{default_max_bytes, PromptData};
|
use crate::prompt::{default_max_bytes, PromptData};
|
||||||
use crate::update::Setting;
|
use crate::update::Setting;
|
||||||
|
use crate::TermsMatchingStrategy;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
|
||||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||||
#[deserr(deny_unknown_fields, rename_all = camelCase)]
|
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
|
||||||
pub struct ChatSettings {
|
pub struct ChatSettings {
|
||||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
#[deserr(default)]
|
#[deserr(default)]
|
||||||
@ -29,17 +34,226 @@ pub struct ChatSettings {
|
|||||||
#[deserr(default)]
|
#[deserr(default)]
|
||||||
#[schema(value_type = Option<usize>)]
|
#[schema(value_type = Option<usize>)]
|
||||||
pub document_template_max_bytes: Setting<usize>,
|
pub document_template_max_bytes: Setting<usize>,
|
||||||
|
|
||||||
|
/// The search parameters to use for the LLM.
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
#[schema(value_type = Option<ChatSearchParams>)]
|
||||||
|
pub search_parameters: Setting<ChatSearchParams>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ChatConfig> for ChatSettings {
|
impl From<ChatConfig> for ChatSettings {
|
||||||
fn from(config: ChatConfig) -> Self {
|
fn from(config: ChatConfig) -> Self {
|
||||||
let ChatConfig { description, prompt: PromptData { template, max_bytes } } = config;
|
let ChatConfig {
|
||||||
|
description,
|
||||||
|
prompt: PromptData { template, max_bytes },
|
||||||
|
search_parameters,
|
||||||
|
} = config;
|
||||||
ChatSettings {
|
ChatSettings {
|
||||||
description: Setting::Set(description),
|
description: Setting::Set(description),
|
||||||
document_template: Setting::Set(template),
|
document_template: Setting::Set(template),
|
||||||
document_template_max_bytes: Setting::Set(
|
document_template_max_bytes: Setting::Set(
|
||||||
max_bytes.unwrap_or(default_max_bytes()).get(),
|
max_bytes.unwrap_or(default_max_bytes()).get(),
|
||||||
),
|
),
|
||||||
|
search_parameters: Setting::Set({
|
||||||
|
let SearchParameters {
|
||||||
|
hybrid,
|
||||||
|
limit,
|
||||||
|
sort,
|
||||||
|
distinct,
|
||||||
|
matching_strategy,
|
||||||
|
attributes_to_search_on,
|
||||||
|
ranking_score_threshold,
|
||||||
|
} = search_parameters;
|
||||||
|
|
||||||
|
let hybrid = hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| {
|
||||||
|
HybridQuery { semantic_ratio: SemanticRatio(semantic_ratio), embedder }
|
||||||
|
});
|
||||||
|
|
||||||
|
let matching_strategy = matching_strategy.map(|ms| match ms {
|
||||||
|
index::MatchingStrategy::Last => MatchingStrategy::Last,
|
||||||
|
index::MatchingStrategy::All => MatchingStrategy::All,
|
||||||
|
index::MatchingStrategy::Frequency => MatchingStrategy::Frequency,
|
||||||
|
});
|
||||||
|
|
||||||
|
let ranking_score_threshold = ranking_score_threshold.map(RankingScoreThreshold);
|
||||||
|
|
||||||
|
ChatSearchParams {
|
||||||
|
hybrid: Setting::some_or_not_set(hybrid),
|
||||||
|
limit: Setting::some_or_not_set(limit),
|
||||||
|
sort: Setting::some_or_not_set(sort),
|
||||||
|
distinct: Setting::some_or_not_set(distinct),
|
||||||
|
matching_strategy: Setting::some_or_not_set(matching_strategy),
|
||||||
|
attributes_to_search_on: Setting::some_or_not_set(attributes_to_search_on),
|
||||||
|
ranking_score_threshold: Setting::some_or_not_set(ranking_score_threshold),
|
||||||
|
}
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
|
||||||
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||||
|
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
|
||||||
|
pub struct ChatSearchParams {
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
#[schema(value_type = Option<HybridQuery>)]
|
||||||
|
pub hybrid: Setting<HybridQuery>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default = Setting::Set(20))]
|
||||||
|
#[schema(value_type = Option<usize>)]
|
||||||
|
pub limit: Setting<usize>,
|
||||||
|
|
||||||
|
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
// #[deserr(default)]
|
||||||
|
// pub attributes_to_retrieve: Option<BTreeSet<String>>,
|
||||||
|
|
||||||
|
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
// #[deserr(default)]
|
||||||
|
// pub filter: Option<Value>,
|
||||||
|
//
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
#[schema(value_type = Option<Vec<String>>)]
|
||||||
|
pub sort: Setting<Vec<String>>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
#[schema(value_type = Option<String>)]
|
||||||
|
pub distinct: Setting<String>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
#[schema(value_type = Option<MatchingStrategy>)]
|
||||||
|
pub matching_strategy: Setting<MatchingStrategy>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
#[schema(value_type = Option<Vec<String>>)]
|
||||||
|
pub attributes_to_search_on: Setting<Vec<String>>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
#[schema(value_type = Option<RankingScoreThreshold>)]
|
||||||
|
pub ranking_score_threshold: Setting<RankingScoreThreshold>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
#[deserr(error = JsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||||
|
pub struct HybridQuery {
|
||||||
|
#[deserr(default)]
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(default, value_type = f32)]
|
||||||
|
pub semantic_ratio: SemanticRatio,
|
||||||
|
#[schema(value_type = String)]
|
||||||
|
pub embedder: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
|
||||||
|
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
|
||||||
|
pub struct SemanticRatio(f32);
|
||||||
|
|
||||||
|
impl Default for SemanticRatio {
|
||||||
|
fn default() -> Self {
|
||||||
|
SemanticRatio(0.5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::convert::TryFrom<f32> for SemanticRatio {
|
||||||
|
type Error = InvalidSearchSemanticRatio;
|
||||||
|
|
||||||
|
fn try_from(f: f32) -> Result<Self, Self::Error> {
|
||||||
|
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
|
||||||
|
#[allow(clippy::manual_range_contains)]
|
||||||
|
if f > 1.0 || f < 0.0 {
|
||||||
|
Err(InvalidSearchSemanticRatio)
|
||||||
|
} else {
|
||||||
|
Ok(SemanticRatio(f))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error reported when a `semanticRatio` value is not a float between `0.0` and `1.0`.
#[derive(Debug)]
pub struct InvalidSearchSemanticRatio;

impl Error for InvalidSearchSemanticRatio {}

impl fmt::Display for InvalidSearchSemanticRatio {
    /// Writes the fixed, user-facing error message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            "the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
        )
    }
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for SemanticRatio {
|
||||||
|
type Target = f32;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr, ToSchema, Serialize, Deserialize)]
|
||||||
|
#[deserr(rename_all = camelCase)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub enum MatchingStrategy {
|
||||||
|
/// Remove query words from last to first
|
||||||
|
Last,
|
||||||
|
/// All query words are mandatory
|
||||||
|
All,
|
||||||
|
/// Remove query words from the most frequent to the least
|
||||||
|
Frequency,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MatchingStrategy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Last
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<MatchingStrategy> for TermsMatchingStrategy {
|
||||||
|
fn from(other: MatchingStrategy) -> Self {
|
||||||
|
match other {
|
||||||
|
MatchingStrategy::Last => Self::Last,
|
||||||
|
MatchingStrategy::All => Self::All,
|
||||||
|
MatchingStrategy::Frequency => Self::Frequency,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize, Deserialize)]
|
||||||
|
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
|
||||||
|
pub struct RankingScoreThreshold(pub f64);
|
||||||
|
|
||||||
|
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
|
||||||
|
type Error = InvalidSearchRankingScoreThreshold;
|
||||||
|
|
||||||
|
fn try_from(f: f64) -> Result<Self, Self::Error> {
|
||||||
|
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
|
||||||
|
#[allow(clippy::manual_range_contains)]
|
||||||
|
if f > 1.0 || f < 0.0 {
|
||||||
|
Err(InvalidSearchRankingScoreThreshold)
|
||||||
|
} else {
|
||||||
|
Ok(RankingScoreThreshold(f))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error reported when a `rankingScoreThreshold` value is not a float between `0.0` and `1.0`.
#[derive(Debug)]
pub struct InvalidSearchRankingScoreThreshold;

impl Error for InvalidSearchRankingScoreThreshold {}

impl fmt::Display for InvalidSearchRankingScoreThreshold {
    /// Writes the fixed, user-facing error message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            "the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`.",
        )
    }
}
|
||||||
|
@ -11,6 +11,7 @@ use roaring::RoaringBitmap;
|
|||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
use time::OffsetDateTime;
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use super::chat::{ChatSearchParams, RankingScoreThreshold};
|
||||||
use super::del_add::{DelAdd, DelAddOperation};
|
use super::del_add::{DelAdd, DelAddOperation};
|
||||||
use super::index_documents::{IndexDocumentsConfig, Transform};
|
use super::index_documents::{IndexDocumentsConfig, Transform};
|
||||||
use super::{ChatSettings, IndexerConfig};
|
use super::{ChatSettings, IndexerConfig};
|
||||||
@ -22,8 +23,8 @@ use crate::error::UserError;
|
|||||||
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
|
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
|
||||||
use crate::filterable_attributes_rules::match_faceted_field;
|
use crate::filterable_attributes_rules::match_faceted_field;
|
||||||
use crate::index::{
|
use crate::index::{
|
||||||
ChatConfig, IndexEmbeddingConfig, PrefixSearch, DEFAULT_MIN_WORD_LEN_ONE_TYPO,
|
ChatConfig, IndexEmbeddingConfig, MatchingStrategy, PrefixSearch,
|
||||||
DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
|
DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
|
||||||
};
|
};
|
||||||
use crate::order_by_map::OrderByMap;
|
use crate::order_by_map::OrderByMap;
|
||||||
use crate::prompt::{default_max_bytes, PromptData};
|
use crate::prompt::{default_max_bytes, PromptData};
|
||||||
@ -1263,11 +1264,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||||||
description: new_description,
|
description: new_description,
|
||||||
document_template: new_document_template,
|
document_template: new_document_template,
|
||||||
document_template_max_bytes: new_document_template_max_bytes,
|
document_template_max_bytes: new_document_template_max_bytes,
|
||||||
|
search_parameters: new_search_parameters,
|
||||||
}) => {
|
}) => {
|
||||||
let mut old = self.index.chat_config(self.wtxn)?;
|
let mut old = self.index.chat_config(self.wtxn)?;
|
||||||
let ChatConfig {
|
let ChatConfig {
|
||||||
ref mut description,
|
ref mut description,
|
||||||
prompt: PromptData { ref mut template, ref mut max_bytes },
|
prompt: PromptData { ref mut template, ref mut max_bytes },
|
||||||
|
ref mut search_parameters,
|
||||||
} = old;
|
} = old;
|
||||||
|
|
||||||
match new_description {
|
match new_description {
|
||||||
@ -1288,6 +1291,85 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
|||||||
Setting::NotSet => (),
|
Setting::NotSet => (),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
match new_search_parameters {
|
||||||
|
Setting::Set(sp) => {
|
||||||
|
let ChatSearchParams {
|
||||||
|
hybrid,
|
||||||
|
limit,
|
||||||
|
sort,
|
||||||
|
distinct,
|
||||||
|
matching_strategy,
|
||||||
|
attributes_to_search_on,
|
||||||
|
ranking_score_threshold,
|
||||||
|
} = sp;
|
||||||
|
|
||||||
|
match hybrid {
|
||||||
|
Setting::Set(hybrid) => {
|
||||||
|
search_parameters.hybrid = Some(crate::index::HybridQuery {
|
||||||
|
semantic_ratio: *hybrid.semantic_ratio,
|
||||||
|
embedder: hybrid.embedder.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Setting::Reset => search_parameters.hybrid = None,
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
match limit {
|
||||||
|
Setting::Set(limit) => search_parameters.limit = Some(*limit),
|
||||||
|
Setting::Reset => search_parameters.limit = None,
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
match sort {
|
||||||
|
Setting::Set(sort) => search_parameters.sort = Some(sort.clone()),
|
||||||
|
Setting::Reset => search_parameters.sort = None,
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
match distinct {
|
||||||
|
Setting::Set(distinct) => {
|
||||||
|
search_parameters.distinct = Some(distinct.clone())
|
||||||
|
}
|
||||||
|
Setting::Reset => search_parameters.distinct = None,
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
match matching_strategy {
|
||||||
|
Setting::Set(matching_strategy) => {
|
||||||
|
let strategy = match matching_strategy {
|
||||||
|
super::chat::MatchingStrategy::Last => MatchingStrategy::Last,
|
||||||
|
super::chat::MatchingStrategy::All => MatchingStrategy::All,
|
||||||
|
super::chat::MatchingStrategy::Frequency => {
|
||||||
|
MatchingStrategy::Frequency
|
||||||
|
}
|
||||||
|
};
|
||||||
|
search_parameters.matching_strategy = Some(strategy)
|
||||||
|
}
|
||||||
|
Setting::Reset => search_parameters.matching_strategy = None,
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
match attributes_to_search_on {
|
||||||
|
Setting::Set(attributes_to_search_on) => {
|
||||||
|
search_parameters.attributes_to_search_on =
|
||||||
|
Some(attributes_to_search_on.clone())
|
||||||
|
}
|
||||||
|
Setting::Reset => search_parameters.attributes_to_search_on = None,
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
match ranking_score_threshold {
|
||||||
|
Setting::Set(RankingScoreThreshold(score)) => {
|
||||||
|
search_parameters.ranking_score_threshold = Some(*score)
|
||||||
|
}
|
||||||
|
Setting::Reset => search_parameters.ranking_score_threshold = None,
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Setting::Reset => *search_parameters = Default::default(),
|
||||||
|
Setting::NotSet => (),
|
||||||
|
}
|
||||||
|
|
||||||
self.index.put_chat_config(self.wtxn, &old)?;
|
self.index.put_chat_config(self.wtxn, &old)?;
|
||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user