Compare commits

...

19 Commits

Author SHA1 Message Date
Clément Renault  d45647b58d  Call specific tools to show progression and results.  2025-05-23 17:25:09 +02:00
Clément Renault  045a1b1e75  Introduce a lot of search parameters and make Deserr happy  2025-05-22 15:34:49 +02:00
Clément Renault  293ac45b7c  Expose a well defined set of sources  2025-05-22 10:42:36 +02:00
Clément Renault  8aa7ae912a  Add the index descriptions to the function description  2025-05-22 10:40:43 +02:00
Clément Renault  7492b6e669  Redact the chat settings API key  2025-05-21 21:18:18 +02:00
Clément Renault  fcc0d43a62  Better chat settings management  2025-05-21 21:06:11 +02:00
Clément Renault  72d4998dce  Correctly list the chat settings key actions  2025-05-21 16:24:51 +02:00
Clément Renault  fde11573da  Always use the frequency matching strategy  2025-05-21 16:18:37 +02:00
Clément Renault  41220f786b  Remove templating validation  2025-05-21 16:10:31 +02:00
Clément Renault  4d59fdb65d  Correctly support document templates on the chat API  2025-05-21 15:32:34 +02:00
Clément Renault  3e51c0a4c1  Introduce the new index chat settings  2025-05-21 11:07:06 +02:00
Clément Renault  91c6ab8392  Make sure erroneous calls are handled and forwarded to the LLM  2025-05-20 18:01:08 +02:00
Clément Renault  beff6adeb1  Catch invalid argument calls to search function  2025-05-20 17:55:21 +02:00
Clément Renault  18eab165a7  Support multiple indexes and not only main  2025-05-20 17:43:24 +02:00
Clément Renault  5c6b63df65  Limit the number of internal loop calls and change the function name  2025-05-20 16:44:28 +02:00
Clément Renault  7266aed770  Correctly support tenant tokens and filters  2025-05-20 16:15:49 +02:00
Clément Renault  bae6c98aa3  Stream errors  2025-05-20 12:23:22 +02:00
Clément Renault  42c95cf3c4  Stop the stream when the connection stops and change the events  2025-05-20 12:05:51 +02:00
Clément Renault  4f919db344  Generate a new default chat API key  2025-05-20 11:00:19 +02:00
26 changed files with 1349 additions and 337 deletions

Cargo.lock generated
View File

@@ -3782,6 +3782,7 @@ dependencies = [
"brotli 6.0.0",
"bstr",
"build-info",
"bumpalo",
"byte-unit",
"bytes",
"cargo_toml",

View File

@@ -398,6 +398,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> {
search_cutoff_ms: v6::Setting::NotSet,
facet_search: v6::Setting::NotSet,
prefix_search: v6::Setting::NotSet,
chat: v6::Setting::NotSet,
_kind: std::marker::PhantomData,
}
}

View File

@@ -165,6 +165,7 @@ impl AuthController {
}
}
#[derive(Debug)]
pub struct AuthFilter {
search_rules: Option<SearchRules>,
key_authorized_indexes: SearchRules,
@@ -351,6 +352,7 @@ pub struct IndexSearchRules {
fn generate_default_keys(store: &HeedAuthStore) -> Result<()> {
store.put_api_key(Key::default_admin())?;
store.put_api_key(Key::default_search())?;
store.put_api_key(Key::default_chat())?;
Ok(())
}

View File

@@ -4,9 +4,12 @@ use std::marker::PhantomData;
use std::ops::ControlFlow;
use deserr::errors::{JsonError, QueryParamError};
use deserr::{take_cf_content, DeserializeError, IntoValue, MergeWithError, ValuePointerRef};
use deserr::{
take_cf_content, DeserializeError, Deserr, IntoValue, MergeWithError, ValuePointerRef,
};
use milli::update::ChatSettings;
use crate::error::deserr_codes::*;
use crate::error::deserr_codes::{self, *};
use crate::error::{
Code, DeserrParseBoolError, DeserrParseIntError, ErrorCode, InvalidTaskDateError,
ParseOffsetDateTimeError,
@@ -33,6 +36,7 @@ pub struct DeserrError<Format, C: Default + ErrorCode> {
pub code: Code,
_phantom: PhantomData<(Format, C)>,
}
impl<Format, C: Default + ErrorCode> DeserrError<Format, C> {
pub fn new(msg: String, code: Code) -> Self {
Self { msg, code, _phantom: PhantomData }
@@ -117,6 +121,16 @@ impl<C: Default + ErrorCode> DeserializeError for DeserrQueryParamError<C> {
}
}
impl Deserr<DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> for ChatSettings {
fn deserialize_from_value<V: IntoValue>(
value: deserr::Value<V>,
location: ValuePointerRef,
) -> Result<Self, DeserrError<DeserrJson, deserr_codes::InvalidSettingsIndexChat>> {
Deserr::<JsonError>::deserialize_from_value(value, location)
.map_err(|e| DeserrError::new(e.to_string(), InvalidSettingsIndexChat.error_code()))
}
}
pub fn immutable_field_error(field: &str, accepted: &[&str], code: Code) -> DeserrJsonError {
let msg = format!(
"Immutable field `{field}`: expected one of {}",

View File

@@ -387,7 +387,8 @@ VectorEmbeddingError , InvalidRequest , BAD_REQUEST ;
NotFoundSimilarId , InvalidRequest , BAD_REQUEST ;
InvalidDocumentEditionContext , InvalidRequest , BAD_REQUEST ;
InvalidDocumentEditionFunctionFilter , InvalidRequest , BAD_REQUEST ;
EditDocumentsByFunctionError , InvalidRequest , BAD_REQUEST
EditDocumentsByFunctionError , InvalidRequest , BAD_REQUEST ;
InvalidSettingsIndexChat , InvalidRequest , BAD_REQUEST
}
impl ErrorCode for JoinError {

View File

@@ -158,6 +158,21 @@ impl Key {
updated_at: now,
}
}
pub fn default_chat() -> Self {
let now = OffsetDateTime::now_utc();
let uid = Uuid::new_v4();
Self {
name: Some("Default Chat API Key".to_string()),
description: Some("Use it to chat and search from the frontend".to_string()),
uid,
actions: vec![Action::Chat, Action::Search],
indexes: vec![IndexUidPattern::all()],
expires_at: None,
created_at: now,
updated_at: now,
}
}
}
fn parse_expiration_date(
@@ -310,7 +325,10 @@ pub enum Action {
NetworkUpdate,
#[serde(rename = "chat.get")]
#[deserr(rename = "chat.get")]
ChatGet,
Chat,
#[serde(rename = "chatSettings.*")]
#[deserr(rename = "chatSettings.*")]
ChatSettingsAll,
#[serde(rename = "chatSettings.get")]
#[deserr(rename = "chatSettings.get")]
ChatSettingsGet,
@@ -342,6 +360,9 @@ impl Action {
SETTINGS_ALL => Some(Self::SettingsAll),
SETTINGS_GET => Some(Self::SettingsGet),
SETTINGS_UPDATE => Some(Self::SettingsUpdate),
CHAT_SETTINGS_ALL => Some(Self::ChatSettingsAll),
CHAT_SETTINGS_GET => Some(Self::ChatSettingsGet),
CHAT_SETTINGS_UPDATE => Some(Self::ChatSettingsUpdate),
STATS_ALL => Some(Self::StatsAll),
STATS_GET => Some(Self::StatsGet),
METRICS_ALL => Some(Self::MetricsAll),
@@ -358,7 +379,7 @@ impl Action {
EXPERIMENTAL_FEATURES_UPDATE => Some(Self::ExperimentalFeaturesUpdate),
NETWORK_GET => Some(Self::NetworkGet),
NETWORK_UPDATE => Some(Self::NetworkUpdate),
CHAT_GET => Some(Self::ChatGet),
CHAT => Some(Self::Chat),
_otherwise => None,
}
}
@@ -408,7 +429,8 @@ pub mod actions {
pub const NETWORK_GET: u8 = NetworkGet.repr();
pub const NETWORK_UPDATE: u8 = NetworkUpdate.repr();
pub const CHAT_GET: u8 = ChatGet.repr();
pub const CHAT: u8 = Chat.repr();
pub const CHAT_SETTINGS_ALL: u8 = ChatSettingsAll.repr();
pub const CHAT_SETTINGS_GET: u8 = ChatSettingsGet.repr();
pub const CHAT_SETTINGS_UPDATE: u8 = ChatSettingsUpdate.repr();
}

View File

@@ -11,6 +11,7 @@ use fst::IntoStreamer;
use milli::disabled_typos_terms::DisabledTyposTerms;
use milli::index::{IndexEmbeddingConfig, PrefixSearch};
use milli::proximity::ProximityPrecision;
pub use milli::update::ChatSettings;
use milli::update::Setting;
use milli::{Criterion, CriterionError, FilterableAttributesRule, Index, DEFAULT_VALUES_PER_FACET};
use serde::{Deserialize, Serialize, Serializer};
@@ -185,7 +186,7 @@ impl<E: DeserializeError> Deserr<E> for SettingEmbeddingSettings {
/// Holds all the settings for an index. `T` can either be `Checked` if they represent settings
/// whose validity is guaranteed, or `Unchecked` if they need to be validated. In the latter case, a
/// call to `check` will return a `Settings<Checked>` from a `Settings<Unchecked>`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
#[serde(
deny_unknown_fields,
rename_all = "camelCase",
@@ -199,72 +200,86 @@ pub struct Settings<T> {
#[deserr(default, error = DeserrJsonError<InvalidSettingsDisplayedAttributes>)]
#[schema(value_type = Option<Vec<String>>, example = json!(["id", "title", "description", "url"]))]
pub displayed_attributes: WildcardSetting,
/// Fields in which to search for matching query words sorted by order of importance.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsSearchableAttributes>)]
#[schema(value_type = Option<Vec<String>>, example = json!(["title", "description"]))]
pub searchable_attributes: WildcardSetting,
/// Attributes to use for faceting and filtering. See [Filtering and Faceted Search](https://www.meilisearch.com/docs/learn/filtering_and_sorting/search_with_facet_filters).
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsFilterableAttributes>)]
#[schema(value_type = Option<Vec<FilterableAttributesRule>>, example = json!(["release_date", "genre"]))]
pub filterable_attributes: Setting<Vec<FilterableAttributesRule>>,
/// Attributes to use when sorting search results.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsSortableAttributes>)]
#[schema(value_type = Option<Vec<String>>, example = json!(["release_date"]))]
pub sortable_attributes: Setting<BTreeSet<String>>,
/// List of ranking rules sorted by order of importance. The order is customizable.
/// [A list of ordered built-in ranking rules](https://www.meilisearch.com/docs/learn/relevancy/relevancy).
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsRankingRules>)]
#[schema(value_type = Option<Vec<String>>, example = json!([RankingRuleView::Words, RankingRuleView::Typo, RankingRuleView::Proximity, RankingRuleView::Attribute, RankingRuleView::Exactness]))]
pub ranking_rules: Setting<Vec<RankingRuleView>>,
/// List of words ignored when present in search queries.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsStopWords>)]
#[schema(value_type = Option<Vec<String>>, example = json!(["the", "a", "them", "their"]))]
pub stop_words: Setting<BTreeSet<String>>,
/// List of characters not delimiting where one term begins and ends.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsNonSeparatorTokens>)]
#[schema(value_type = Option<Vec<String>>, example = json!([" ", "\n"]))]
pub non_separator_tokens: Setting<BTreeSet<String>>,
/// List of characters delimiting where one term begins and ends.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsSeparatorTokens>)]
#[schema(value_type = Option<Vec<String>>, example = json!(["S"]))]
pub separator_tokens: Setting<BTreeSet<String>>,
/// List of strings Meilisearch should parse as a single term.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsDictionary>)]
#[schema(value_type = Option<Vec<String>>, example = json!(["iPhone pro"]))]
pub dictionary: Setting<BTreeSet<String>>,
/// List of associated words treated similarly. A word associated with an array of words considered as synonyms.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsSynonyms>)]
#[schema(value_type = Option<BTreeMap<String, Vec<String>>>, example = json!({ "he": ["she", "they", "them"], "phone": ["iPhone", "android"]}))]
pub synonyms: Setting<BTreeMap<String, Vec<String>>>,
/// Search returns documents with distinct (different) values of the given field.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsDistinctAttribute>)]
#[schema(value_type = Option<String>, example = json!("sku"))]
pub distinct_attribute: Setting<String>,
/// Precision level when calculating the proximity ranking rule.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsProximityPrecision>)]
#[schema(value_type = Option<String>, example = json!(ProximityPrecisionView::ByAttribute))]
pub proximity_precision: Setting<ProximityPrecisionView>,
/// Customize typo tolerance feature.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsTypoTolerance>)]
#[schema(value_type = Option<TypoSettings>, example = json!({ "enabled": true, "disableOnAttributes": ["title"]}))]
pub typo_tolerance: Setting<TypoSettings>,
/// Faceting settings.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsFaceting>)]
#[schema(value_type = Option<FacetingSettings>, example = json!({ "maxValuesPerFacet": 10, "sortFacetValuesBy": { "genre": FacetValuesSort::Count }}))]
pub faceting: Setting<FacetingSettings>,
/// Pagination settings.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)]
@@ -276,24 +291,34 @@ pub struct Settings<T> {
#[deserr(default, error = DeserrJsonError<InvalidSettingsEmbedders>)]
#[schema(value_type = Option<BTreeMap<String, SettingEmbeddingSettings>>)]
pub embedders: Setting<BTreeMap<String, SettingEmbeddingSettings>>,
/// Maximum duration of a search query.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsSearchCutoffMs>)]
#[schema(value_type = Option<u64>, example = json!(50))]
pub search_cutoff_ms: Setting<u64>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsLocalizedAttributes>)]
#[schema(value_type = Option<Vec<LocalizedAttributesRuleView>>, example = json!(50))]
pub localized_attributes: Setting<Vec<LocalizedAttributesRuleView>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsFacetSearch>)]
#[schema(value_type = Option<bool>, example = json!(true))]
pub facet_search: Setting<bool>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsPrefixSearch>)]
#[schema(value_type = Option<PrefixSearchSettings>, example = json!("Hemlo"))]
pub prefix_search: Setting<PrefixSearchSettings>,
/// Customize the chat prompting.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsIndexChat>)]
#[schema(value_type = Option<ChatSettings>)]
pub chat: Setting<ChatSettings>,
#[serde(skip)]
#[deserr(skip)]
pub _kind: PhantomData<T>,
@@ -359,6 +384,7 @@ impl Settings<Checked> {
localized_attributes: Setting::Reset,
facet_search: Setting::Reset,
prefix_search: Setting::Reset,
chat: Setting::Reset,
_kind: PhantomData,
}
}
@@ -385,6 +411,7 @@ impl Settings<Checked> {
localized_attributes: localized_attributes_rules,
facet_search,
prefix_search,
chat,
_kind,
} = self;
@@ -409,6 +436,7 @@ impl Settings<Checked> {
localized_attributes: localized_attributes_rules,
facet_search,
prefix_search,
chat,
_kind: PhantomData,
}
}
@@ -459,6 +487,7 @@ impl Settings<Unchecked> {
localized_attributes: self.localized_attributes,
facet_search: self.facet_search,
prefix_search: self.prefix_search,
chat: self.chat,
_kind: PhantomData,
}
}
@@ -533,8 +562,9 @@ impl Settings<Unchecked> {
Setting::Set(this)
}
},
prefix_search: other.prefix_search.or(self.prefix_search),
facet_search: other.facet_search.or(self.facet_search),
prefix_search: other.prefix_search.or(self.prefix_search),
chat: other.chat.clone().or(self.chat.clone()),
_kind: PhantomData,
}
}
@@ -573,6 +603,7 @@ pub fn apply_settings_to_builder(
localized_attributes: localized_attributes_rules,
facet_search,
prefix_search,
chat,
_kind,
} = settings;
@@ -783,6 +814,12 @@ pub fn apply_settings_to_builder(
Setting::Reset => builder.reset_facet_search(),
Setting::NotSet => (),
}
match chat {
Setting::Set(chat) => builder.set_chat(chat.clone()),
Setting::Reset => builder.reset_chat(),
Setting::NotSet => (),
}
}
pub enum SecretPolicy {
@@ -880,14 +917,11 @@ pub fn settings(
})
.collect();
let embedders = Setting::Set(embedders);
let search_cutoff_ms = index.search_cutoff(rtxn)?;
let localized_attributes_rules = index.localized_attributes_rules(rtxn)?;
let prefix_search = index.prefix_search(rtxn)?.map(PrefixSearchSettings::from);
let facet_search = index.facet_search(rtxn)?;
let chat = index.chat_config(rtxn).map(ChatSettings::from)?;
let mut settings = Settings {
displayed_attributes: match displayed_attributes {
@@ -925,8 +959,9 @@ pub fn settings(
Some(rules) => Setting::Set(rules.into_iter().map(|r| r.into()).collect()),
None => Setting::Reset,
},
prefix_search: Setting::Set(prefix_search.unwrap_or_default()),
facet_search: Setting::Set(facet_search),
prefix_search: Setting::Set(prefix_search.unwrap_or_default()),
chat: Setting::Set(chat),
_kind: PhantomData,
};
@@ -1154,6 +1189,7 @@ pub(crate) mod test {
search_cutoff_ms: Setting::NotSet,
facet_search: Setting::NotSet,
prefix_search: Setting::NotSet,
chat: Setting::NotSet,
_kind: PhantomData::<Unchecked>,
};
@@ -1185,6 +1221,8 @@ pub(crate) mod test {
search_cutoff_ms: Setting::NotSet,
facet_search: Setting::NotSet,
prefix_search: Setting::NotSet,
chat: Setting::NotSet,
_kind: PhantomData::<Unchecked>,
};

View File

@@ -8,7 +8,7 @@ use crate::error::ResponseError;
use crate::settings::{Settings, Unchecked};
use crate::tasks::{serialize_duration, Details, IndexSwap, Kind, Status, Task, TaskId};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, ToSchema)]
#[derive(Debug, Clone, PartialEq, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
#[schema(rename_all = "camelCase")]
pub struct TaskView {
@@ -67,7 +67,7 @@ impl TaskView {
}
}
#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize, ToSchema)]
#[derive(Default, Debug, PartialEq, Clone, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
#[schema(rename_all = "camelCase")]
pub struct DetailsView {

View File

@@ -597,7 +597,7 @@ impl fmt::Display for ParseTaskKindError {
}
impl std::error::Error for ParseTaskKindError {}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum Details {
DocumentAdditionOrUpdate {
received_documents: u64,

View File

@@ -32,6 +32,7 @@ async-trait = "0.1.85"
bstr = "1.11.3"
byte-unit = { version = "5.1.6", features = ["serde"] }
bytes = "1.9.0"
bumpalo = "3.16.0"
clap = { version = "4.5.24", features = ["derive", "env"] }
crossbeam-channel = "0.5.15"
deserr = { version = "0.6.3", features = ["actix-web"] }

View File

@@ -4,6 +4,7 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use actix_web::http::header::AUTHORIZATION;
use actix_web::web::Data;
use actix_web::FromRequest;
pub use error::AuthenticationError;
@@ -94,36 +95,44 @@ impl<P: Policy + 'static, D: 'static + Clone> FromRequest for GuardedData<P, D>
_payload: &mut actix_web::dev::Payload,
) -> Self::Future {
match req.app_data::<Data<AuthController>>().cloned() {
Some(auth) => match req
.headers()
.get("Authorization")
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
{
Some(mut type_token) => match type_token.next() {
Some("Bearer") => {
// TODO: find a less hardcoded way?
let index = req.match_info().get("index_uid");
match type_token.next() {
Some(token) => Box::pin(Self::auth_bearer(
auth,
token.to_string(),
index.map(String::from),
req.app_data::<D>().cloned(),
)),
None => Box::pin(err(AuthenticationError::InvalidToken.into())),
}
}
_otherwise => {
Box::pin(err(AuthenticationError::MissingAuthorizationHeader.into()))
}
},
None => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
Some(auth) => match extract_token_from_request(req) {
Ok(Some(token)) => {
// TODO: find a less hardcoded way?
let index = req.match_info().get("index_uid");
Box::pin(Self::auth_bearer(
auth,
token.to_string(),
index.map(String::from),
req.app_data::<D>().cloned(),
))
}
Ok(None) => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
Err(e) => Box::pin(err(e.into())),
},
None => Box::pin(err(AuthenticationError::IrretrievableState.into())),
}
}
}
pub fn extract_token_from_request(
req: &actix_web::HttpRequest,
) -> Result<Option<&str>, AuthenticationError> {
match req
.headers()
.get(AUTHORIZATION)
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
{
Some(mut type_token) => match type_token.next() {
Some("Bearer") => match type_token.next() {
Some(token) => Ok(Some(token)),
None => Err(AuthenticationError::InvalidToken),
},
_otherwise => Err(AuthenticationError::MissingAuthorizationHeader),
},
None => Ok(None),
}
}
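
For context, a minimal sketch of how the extracted helper behaves for the three header cases it distinguishes, written as a test-style function using actix-web's TestRequest; the token values are made up and the assertions are illustrative rather than part of this change.

use actix_web::http::header::AUTHORIZATION;
use actix_web::test::TestRequest;

// Illustrative only: exercises the three branches of extract_token_from_request.
fn demo_token_extraction() {
    // "Authorization: Bearer <token>" yields the token itself.
    let req = TestRequest::default()
        .insert_header((AUTHORIZATION, "Bearer s3cr3t"))
        .to_http_request();
    assert_eq!(extract_token_from_request(&req).unwrap(), Some("s3cr3t"));

    // No Authorization header at all falls back to the "no token" path.
    let req = TestRequest::default().to_http_request();
    assert_eq!(extract_token_from_request(&req).unwrap(), None);

    // A non-Bearer scheme is rejected as a missing authorization header.
    let req = TestRequest::default()
        .insert_header((AUTHORIZATION, "Basic abc"))
        .to_http_request();
    assert!(extract_token_from_request(&req).is_err());
}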
pub trait Policy {
fn authenticate(
auth: Data<AuthController>,
@@ -299,8 +308,8 @@ pub mod policies {
auth: &AuthController,
token: &str,
) -> Result<TenantTokenOutcome, AuthError> {
// Only search action can be accessed by a tenant token.
if A != actions::SEARCH {
// Only search and chat actions can be accessed by a tenant token.
if A != actions::SEARCH && A != actions::CHAT {
return Ok(TenantTokenOutcome::NotATenantToken);
}

View File

@@ -1,45 +1,61 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Write as _;
use std::mem;
use std::sync::RwLock;
use std::time::Duration;
use actix_web::web::{self, Data};
use actix_web::{Either, HttpResponse, Responder};
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig;
use async_openai::types::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream,
FunctionObjectArgs,
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage,
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream,
FunctionObjectArgs, Role,
};
use async_openai::Client;
use bumpalo::Bump;
use futures::StreamExt;
use index_scheduler::IndexScheduler;
use meilisearch_auth::AuthController;
use meilisearch_types::error::ResponseError;
use meilisearch_types::heed::RoTxn;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::IndexEmbeddingConfig;
use meilisearch_types::milli::prompt::PromptData;
use meilisearch_types::milli::vector::EmbeddingConfig;
use meilisearch_types::{Document, Index};
use meilisearch_types::milli::index::{self, ChatConfig, SearchParameters};
use meilisearch_types::milli::prompt::{Prompt, PromptData};
use meilisearch_types::milli::update::new::document::DocumentFromDb;
use meilisearch_types::milli::update::Setting;
use meilisearch_types::milli::{
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
};
use meilisearch_types::Index;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::runtime::Handle;
use tracing::error;
use tokio::sync::mpsc::error::SendError;
use super::settings::chat::{ChatPrompts, ChatSettings};
use super::settings::chat::{ChatPrompts, GlobalChatSettings};
use crate::error::MeilisearchHttpError;
use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::GuardedData;
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::routes::indexes::search::search_kind;
use crate::search::{
add_search_rules, perform_search, HybridQuery, RetrieveVectors, SearchQuery, SemanticRatio,
add_search_rules, prepare_search, search_from_kind, HybridQuery, MatchingStrategy,
RankingScoreThreshold, SearchQuery, SemanticRatio, DEFAULT_SEARCH_LIMIT,
DEFAULT_SEMANTIC_RATIO,
};
use crate::search_queue::SearchQueue;
const EMBEDDER_NAME: &str = "openai";
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("/completions").route(web::post().to(chat)));
@@ -47,7 +63,9 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
/// Get a chat completion
async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
) -> impl Responder {
@@ -61,104 +79,204 @@ async fn chat(
);
if chat_completion.stream.unwrap_or(false) {
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
Either::Right(
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
} else {
Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await)
Either::Left(
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
}
}
#[derive(Default, Debug, Clone, Copy)]
pub struct FunctionSupport {
/// Defines if we can call the _meiliSearchProgress function
/// to inform the front-end about what we are searching for.
progress: bool,
/// Defines if we can call the _meiliAppendConversationMessage
/// function to provide the messages to append into the conversation.
append_to_conversation: bool,
}
/// Setup search tool in chat completion request
fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: &ChatPrompts) {
fn setup_search_tool(
index_scheduler: &Data<IndexScheduler>,
filters: &meilisearch_auth::AuthFilter,
chat_completion: &mut CreateChatCompletionRequest,
prompts: &ChatPrompts,
) -> Result<FunctionSupport, ResponseError> {
let tools = chat_completion.tools.get_or_insert_default();
tools.push(
ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function)
.function(
FunctionObjectArgs::default()
.name("searchInIndex")
.description(&prompts.search_description)
.parameters(json!({
"type": "object",
"properties": {
"index_uid": {
"type": "string",
"enum": ["main"],
"description": prompts.search_index_uid_param,
},
"q": {
// Unfortunately, Mistral does not support an array of types, here.
// "type": ["string", "null"],
"type": "string",
"description": prompts.search_q_param,
}
if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
panic!("{MEILI_SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
}
// Remove internal tools used for front-end notifications as they should be hidden from the LLM.
let mut progress = false;
let mut append_to_conversation = false;
tools.retain(|tool| {
match tool.function.name.as_str() {
MEILI_SEARCH_PROGRESS_NAME => {
progress = true;
false
}
MEILI_APPEND_CONVERSATION_MESSAGE_NAME => {
append_to_conversation = true;
false
}
_ => true, // keep other tools
}
});
let mut index_uids = Vec::new();
let mut function_description = prompts.search_description.clone().unwrap();
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
// Make sure to skip unauthorized indexes
if !filters.is_index_authorized(&name) {
return Ok(());
}
let rtxn = index.read_txn()?;
let chat_config = index.chat_config(&rtxn)?;
let index_description = chat_config.description;
let _ = writeln!(&mut function_description, "\n\n - {name}: {index_description}\n");
index_uids.push(name.to_string());
Ok(())
})?;
let tool = ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function)
.function(
FunctionObjectArgs::default()
.name(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
.description(&function_description)
.parameters(json!({
"type": "object",
"properties": {
"index_uid": {
"type": "string",
"enum": index_uids,
"description": prompts.search_index_uid_param.clone().unwrap(),
},
"required": ["index_uid", "q"],
"additionalProperties": false,
}))
.strict(true)
.build()
.unwrap(),
)
.build()
.unwrap(),
);
"q": {
// Unfortunately, Mistral does not support an array of types, here.
// "type": ["string", "null"],
"type": "string",
"description": prompts.search_q_param.clone().unwrap(),
}
},
"required": ["index_uid", "q"],
"additionalProperties": false,
}))
.strict(true)
.build()
.unwrap(),
)
.build()
.unwrap();
tools.push(tool);
chat_completion.messages.insert(
0,
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
content: ChatCompletionRequestSystemMessageContent::Text(
prompts.system.as_ref().unwrap().clone(),
),
name: None,
}),
);
Ok(FunctionSupport { progress, append_to_conversation })
}
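
To make the function-calling contract concrete, here is a sketch of the parameters object assembled above, assuming two authorized indexes named "movies" and "books"; the real descriptions come from the configured ChatPrompts, so placeholders are used here.

use serde_json::json;

// Hypothetical rendering of the _meiliSearchInIndex tool parameters.
fn example_search_tool_parameters() -> serde_json::Value {
    json!({
        "type": "object",
        "properties": {
            "index_uid": {
                "type": "string",
                "enum": ["movies", "books"], // one entry per authorized index
                "description": "<prompts.search_index_uid_param>",
            },
            "q": {
                // Mistral does not accept an array of types here, hence a plain string.
                "type": "string",
                "description": "<prompts.search_q_param>",
            }
        },
        "required": ["index_uid", "q"],
        "additionalProperties": false,
    })
}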
/// Process search request and return formatted results
async fn process_search_request(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
search_queue: &web::Data<SearchQueue>,
auth_token: &str,
index_uid: String,
q: Option<String>,
) -> Result<(Index, String), ResponseError> {
let mut query = SearchQuery {
q,
hybrid: Some(HybridQuery {
semantic_ratio: SemanticRatio::default(),
embedder: EMBEDDER_NAME.to_string(),
}),
limit: 20,
..Default::default()
};
// Tenant token search_rules.
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
add_search_rules(&mut query.filter, search_rules);
}
// TBD
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
let index = index_scheduler.index(&index_uid)?;
let rtxn = index.static_read_txn()?;
let ChatConfig { description: _, prompt: _, search_parameters } = index.chat_config(&rtxn)?;
let SearchParameters {
hybrid,
limit,
sort,
distinct,
matching_strategy,
attributes_to_search_on,
ranking_score_threshold,
} = search_parameters;
let mut query = SearchQuery {
q,
hybrid: hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| HybridQuery {
semantic_ratio: SemanticRatio::try_from(semantic_ratio)
.ok()
.unwrap_or_else(DEFAULT_SEMANTIC_RATIO),
embedder,
}),
limit: limit.unwrap_or_else(DEFAULT_SEARCH_LIMIT),
sort: sort,
distinct: distinct,
matching_strategy: matching_strategy
.map(|ms| match ms {
index::MatchingStrategy::Last => MatchingStrategy::Last,
index::MatchingStrategy::All => MatchingStrategy::All,
index::MatchingStrategy::Frequency => MatchingStrategy::Frequency,
})
.unwrap_or(MatchingStrategy::Frequency),
attributes_to_search_on: attributes_to_search_on,
ranking_score_threshold: ranking_score_threshold
.and_then(|rst| RankingScoreThreshold::try_from(rst).ok()),
..Default::default()
};
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
auth_ctrl,
auth_token,
Some(index_uid.as_str()),
)?;
// Tenant token search_rules.
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
add_search_rules(&mut query.filter, search_rules);
}
let search_kind =
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
let permit = search_queue.try_get_search_permit().await?;
let features = index_scheduler.features();
let index_cloned = index.clone();
let search_result = tokio::task::spawn_blocking(move || {
perform_search(
index_uid.to_string(),
&index_cloned,
query,
search_kind,
RetrieveVectors::new(false),
features,
)
let output = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> {
let time_budget = match index_cloned
.search_cutoff(&rtxn)
.map_err(|e| MeilisearchHttpError::from_milli(e, Some(index_uid.clone())))?
{
Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)),
None => TimeBudget::default(),
};
let (search, _is_finite_pagination, _max_total_hits, _offset) =
prepare_search(&index_cloned, &rtxn, &query, &search_kind, time_budget, features)?;
search_from_kind(index_uid, search_kind, search)
.map(|(search_results, _)| (rtxn, search_results))
.map_err(ResponseError::from)
})
.await;
permit.drop().await;
let search_result = search_result?;
if let Ok(ref search_result) = search_result {
let output = output?;
if let Ok((_, ref search_result)) = output {
// aggregate.succeed(search_result);
if search_result.degraded {
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
@@ -166,34 +284,43 @@ async fn process_search_request(
}
// analytics.publish(aggregate, &req);
let search_result = search_result?;
let formatted =
format_documents(&index, search_result.hits.into_iter().map(|doc| doc.document));
let (rtxn, search_result) = output?;
// let rtxn = index.read_txn()?;
let render_alloc = Bump::new();
let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?;
let text = formatted.join("\n");
drop(rtxn);
Ok((index, text))
}
async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> {
let filters = index_scheduler.filters();
let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(),
None => GlobalChatSettings::default(),
};
let mut config = OpenAIConfig::default();
if let Some(api_key) = chat_settings.api_key.as_ref() {
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
config = config.with_api_key(api_key);
}
if let Some(base_api) = chat_settings.base_api.as_ref() {
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
config = config.with_api_base(base_api);
}
let client = Client::with_config(config);
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let auth_token = extract_token_from_request(&req)?.unwrap();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { progress, append_to_conversation } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let mut response;
loop {
@@ -204,8 +331,9 @@ async fn non_streamed_chat(
Some(FinishReason::ToolCalls) => {
let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default();
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
tool_calls.into_iter().partition(|call| call.function.name == "searchInIndex");
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = tool_calls
.into_iter()
.partition(|call| call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME);
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
@@ -216,17 +344,32 @@ async fn non_streamed_chat(
);
for call in meili_calls {
let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap();
let result = match serde_json::from_str(&call.function.arguments) {
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
&auth_token,
index_uid,
q,
)
.await
.map_err(|e| e.to_string()),
Err(err) => Err(err.to_string()),
};
let (_, text) =
process_search_request(&index_scheduler, &search_queue, index_uid, q)
.await?;
let text = match result {
Ok((_, text)) => text,
Err(err) => err,
};
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
tool_call_id: call.id,
content: ChatCompletionRequestToolMessageContent::Text(text),
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(format!(
"{}\n\n{text}",
chat_settings.prompts.clone().unwrap().pre_query.unwrap()
)),
},
));
}
@@ -245,24 +388,31 @@ async fn non_streamed_chat(
}
async fn streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> impl Responder {
) -> Result<impl Responder, ResponseError> {
let filters = index_scheduler.filters();
let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(),
Some(value) => serde_json::from_value(value.clone()).unwrap(),
None => GlobalChatSettings::default(),
};
let mut config = OpenAIConfig::default();
if let Some(api_key) = chat_settings.api_key.as_ref() {
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
config = config.with_api_key(api_key);
}
if let Some(base_api) = chat_settings.base_api.as_ref() {
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
config = config.with_api_base(base_api);
}
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { progress, append_to_conversation } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10);
let _join_handle = Handle::current().spawn(async move {
@@ -270,7 +420,8 @@ async fn streamed_chat(
let mut global_tool_calls = HashMap::<u32, Call>::new();
let mut finish_reason = None;
'main: while finish_reason.map_or(true, |fr| fr == FinishReason::ToolCalls) {
// Limit the number of internal calls to satisfy the search requests of the LLM
'main: for _ in 0..20 {
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
while let Some(result) = response.next().await {
match result {
@@ -278,19 +429,8 @@ async fn streamed_chat(
let choice = &resp.choices[0];
finish_reason = choice.finish_reason;
#[allow(deprecated)]
let ChatCompletionStreamResponseDelta {
content,
// Using deprecated field but keeping for compatibility
function_call: _,
ref tool_calls,
role: _,
refusal: _,
} = &choice.delta;
if content.is_some() {
tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await.unwrap()
}
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
&choice.delta;
match tool_calls {
Some(tool_calls) => {
@@ -303,128 +443,313 @@ async fn streamed_chat(
} = chunk;
let FunctionCallStream { name, arguments } =
function.as_ref().unwrap();
global_tool_calls
.entry(*index)
.and_modify(|call| {
call.append(arguments.as_ref().unwrap());
if call.is_internal() {
call.append(arguments.as_ref().unwrap())
}
})
.or_insert_with(|| Call {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
.or_insert_with(|| {
if name.as_ref().map_or(false, |n| {
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
}) {
Call::Internal {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
}
} else {
Call::External { _id: id.as_ref().unwrap().clone() }
}
});
if global_tool_calls.get(index).map_or(false, Call::is_external)
{
todo!("Support forwarding external tool calls");
}
}
}
None if !global_tool_calls.is_empty() => {
// dbg!(&global_tool_calls);
None => {
if !global_tool_calls.is_empty() {
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls)
.into_values()
.flat_map(|call| match call {
Call::Internal {
id,
function_name: name,
arguments,
} => Some(ChatCompletionMessageToolCall {
id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { name, arguments },
}),
Call::External { _id: _ } => None,
})
.partition(|call| {
call.function.name
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
});
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls)
.into_values()
.map(|call| ChatCompletionMessageToolCall {
id: call.id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall {
name: call.function_name,
arguments: call.arguments,
},
})
.partition(|call| call.function.name == "searchInIndex");
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
for call in meili_calls {
tx.send(Event::Data(
sse::Data::new_json(json!({
"object": "chat.completion.tool.call",
"tool": call,
}))
.unwrap(),
))
.await
.unwrap();
let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap();
let result = process_search_request(
&index_scheduler,
&search_queue,
index_uid,
q,
)
.await;
let text = match result {
Ok((_, text)) => text,
Err(err) => {
error!("Error processing search request: {err:?}");
continue;
}
};
let tool = ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(
format!("{}\n\n{text}", chat_settings.prompts.pre_query),
),
},
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
tx.send(Event::Data(
sse::Data::new_json(json!({
"object": "chat.completion.tool.output",
"tool": ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
tool_call_id: call.id,
content: ChatCompletionRequestToolMessageContent::Text(
text,
),
},
),
}))
.unwrap(),
))
.await
.unwrap();
assert!(
other_calls.is_empty(),
"We do not support external tool forwarding for now"
);
chat_completion.messages.push(tool);
for call in meili_calls {
if progress {
let call = MeiliSearchProgress {
function_name: call.function.name.clone(),
function_arguments: call
.function
.arguments
.clone(),
};
let resp = call.create_response(resp.clone());
// Send the event of "we are doing a search"
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
if append_to_conversation {
// Ask the front-end user to append this tool *call* to the conversation
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
refusal: None,
name: None,
audio: None,
tool_calls: Some(vec![
ChatCompletionMessageToolCall {
id: call.id.clone(),
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall {
name: call.function.name.clone(),
arguments: call.function.arguments.clone(),
},
},
]),
function_call: None,
}
));
let resp = call.create_response(resp.clone());
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
let result =
match serde_json::from_str(&call.function.arguments) {
Ok(SearchInIndexParameters { index_uid, q }) => {
process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
&auth_token,
index_uid,
q,
)
.await
.map_err(|e| e.to_string())
}
Err(err) => Err(err.to_string()),
};
let text = match result {
Ok((_, text)) => text,
Err(err) => err,
};
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(
format!(
"{}\n\n{text}",
chat_settings
.prompts
.as_ref()
.unwrap()
.pre_query
.as_ref()
.unwrap()
),
),
});
if append_to_conversation {
// Ask the front-end user to append this tool *output* to the conversation
let tool = MeiliAppendConversationMessage(tool.clone());
let resp = tool.create_response(resp.clone());
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
chat_completion.messages.push(tool);
}
} else {
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
}
None => (),
}
}
Err(err) => {
// writeln!(lock, "error: {err}").unwrap();
tracing::error!("{err:?}");
// tracing::error!("{err:?}");
// if let Err(SendError(_)) = tx
// .send(Event::Data(
// sse::Data::new_json(&json!({
// "object": "chat.completion.error",
// "tool": err.to_string(),
// }))
// .unwrap(),
// ))
// .await
// {
// return;
// }
break 'main;
}
}
}
// We must stop if the finish reason is not something we can solve with Meilisearch
if finish_reason.map_or(true, |fr| fr != FinishReason::ToolCalls) {
break;
}
}
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
});
Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
}
#[derive(Debug, Clone, Serialize)]
/// Give context about what Meilisearch is doing.
struct MeiliSearchProgress {
/// The name of the function we are executing.
pub function_name: String,
/// The arguments of the function we are executing, encoded in JSON.
pub function_arguments: String,
}
impl MeiliSearchProgress {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(self).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
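
For a sense of what the front-end receives when progress reporting is enabled, here is a sketch of the payload serialized into the _meiliSearchProgress tool-call chunk; the index name and query are invented.

// Illustrative only: the value placed in the chunk's `arguments` field.
fn demo_progress_payload() -> String {
    let progress = MeiliSearchProgress {
        function_name: MEILI_SEARCH_IN_INDEX_FUNCTION_NAME.to_string(),
        function_arguments: r#"{"index_uid":"movies","q":"space opera"}"#.to_string(),
    };
    // Serializes to something like:
    // {"function_name":"_meiliSearchInIndex","function_arguments":"{\"index_uid\":\"movies\",\"q\":\"space opera\"}"}
    serde_json::to_string(&progress).unwrap()
}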
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
impl MeiliAppendConversationMessage {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(&self.0).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
/// The structure used to aggregate the function calls to make.
#[derive(Debug)]
struct Call {
id: String,
function_name: String,
arguments: String,
enum Call {
/// Tool calls to tools that must be managed by Meilisearch internally.
/// Typically the search functions.
Internal { id: String, function_name: String, arguments: String },
/// Tool calls that we track only to know that they are not our functions.
/// We return the function calls as-is to the end-user.
External { _id: String },
}
impl Call {
fn append(&mut self, arguments: &str) {
self.arguments.push_str(arguments);
fn is_internal(&self) -> bool {
matches!(self, Call::Internal { .. })
}
fn is_external(&self) -> bool {
matches!(self, Call::External { .. })
}
fn append(&mut self, more: &str) {
match self {
Call::Internal { arguments, .. } => arguments.push_str(more),
Call::External { .. } => {
panic!("Cannot append argument chunks to an external function")
}
}
}
}
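
A small sketch of how the streamed argument fragments accumulate into a Call::Internal before being parsed, mirroring the and_modify/or_insert_with logic above; the JSON fragments are made up.

// Illustrative only: chunks arrive piecewise over the stream and are appended in order.
fn demo_call_accumulation() {
    let mut call = Call::Internal {
        id: "call_1".to_string(),
        function_name: MEILI_SEARCH_IN_INDEX_FUNCTION_NAME.to_string(),
        arguments: String::new(),
    };
    for chunk in [r#"{"index_uid""#, r#":"movies","q""#, r#":"dune"}"#] {
        call.append(chunk);
    }
    assert!(call.is_internal());
    // `call` now holds the complete JSON arguments string, ready to be
    // deserialized as SearchInIndexParameters.
}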
@@ -436,31 +761,36 @@ struct SearchInIndexParameters {
q: Option<String>,
}
fn format_documents(index: &Index, documents: impl Iterator<Item = Document>) -> Vec<String> {
let rtxn = index.read_txn().unwrap();
let IndexEmbeddingConfig { name: _, config, user_provided: _ } = index
.embedding_configs(&rtxn)
.unwrap()
fn format_documents<'t, 'doc>(
rtxn: &RoTxn<'t>,
index: &Index,
doc_alloc: &'doc Bump,
internal_docids: Vec<DocumentId>,
) -> Result<Vec<&'doc str>, ResponseError> {
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
let prompt = Prompt::new(template, max_bytes).unwrap();
let fid_map = index.fields_ids_map(rtxn)?;
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
let global = RwLock::new(fid_map_with_meta);
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
let external_ids: Vec<String> = index
.external_id_of(rtxn, internal_docids.iter().copied())?
.into_iter()
.find(|conf| conf.name == EMBEDDER_NAME)
.unwrap();
.collect::<Result<_, _>>()?;
let EmbeddingConfig {
embedder_options: _,
prompt: PromptData { template, max_bytes: _ },
quantized: _,
} = config;
let mut renders = Vec::new();
for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
Some(doc) => doc,
None => continue,
};
#[derive(Serialize)]
struct Doc<T: Serialize> {
doc: T,
let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap();
renders.push(text);
}
let template = liquid::ParserBuilder::with_stdlib().build().unwrap().parse(&template).unwrap();
documents
.map(|doc| {
let object = liquid::to_object(&Doc { doc }).unwrap();
template.render(&object).unwrap()
})
.collect()
Ok(renders)
}
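
The per-document rendering above goes through milli's Prompt type; as a rough equivalent, the following sketch renders one invented document with the liquid crate directly, the way the previous implementation did. The template and document are examples only.

use serde::Serialize;

// Simplified stand-in for the rendering loop above.
fn render_one_document() -> String {
    #[derive(Serialize)]
    struct Doc<T: Serialize> {
        doc: T,
    }

    let template = "{{ doc.title }}: {{ doc.overview }}";
    let doc = serde_json::json!({ "title": "Dune", "overview": "A desert planet adventure" });

    let parser = liquid::ParserBuilder::with_stdlib().build().unwrap();
    let template = parser.parse(template).unwrap();
    let object = liquid::to_object(&Doc { doc }).unwrap();
    template.render(&object).unwrap()
}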

View File

@@ -6,7 +6,7 @@ use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::settings::{
settings, SecretPolicy, SettingEmbeddingSettings, Settings, Unchecked,
settings, ChatSettings, SecretPolicy, SettingEmbeddingSettings, Settings, Unchecked,
};
use meilisearch_types::tasks::KindWithContent;
use tracing::debug;
@@ -508,6 +508,17 @@ make_setting_routes!(
camelcase_attr: "prefixSearch",
analytics: PrefixSearchAnalytics
},
{
route: "/chat",
update_verb: put,
value_type: ChatSettings,
err_type: meilisearch_types::deserr::DeserrJsonError<
meilisearch_types::error::deserr_codes::InvalidSettingsIndexChat,
>,
attr: chat,
camelcase_attr: "chat",
analytics: ChatAnalytics
},
);
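
Since the per-index chat settings surface through the standard settings routes, a hedged example of a JSON payload one might PUT to the new route follows; the field names (description, documentTemplate, searchParameters) are assumptions based on the ChatConfig fields visible elsewhere in this diff, not a confirmed schema.

use serde_json::json;

// Hypothetical payload for PUT /indexes/{index_uid}/settings/chat; field names
// are assumptions and may not match milli's actual ChatSettings exactly.
fn example_index_chat_settings_payload() -> serde_json::Value {
    json!({
        "description": "A movie database with titles, overviews and genres",
        "documentTemplate": "{{ doc.title }}: {{ doc.overview }}",
        "searchParameters": {
            "limit": 20,
            "matchingStrategy": "frequency"
        }
    })
}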
#[utoipa::path(
@@ -597,6 +608,7 @@ pub async fn update_all(
),
facet_search: FacetSearchAnalytics::new(new_settings.facet_search.as_ref().set()),
prefix_search: PrefixSearchAnalytics::new(new_settings.prefix_search.as_ref().set()),
chat: ChatAnalytics::new(new_settings.chat.as_ref().set()),
},
&req,
);

View File

@@ -10,8 +10,8 @@ use meilisearch_types::locales::{Locale, LocalizedAttributesRuleView};
use meilisearch_types::milli::update::Setting;
use meilisearch_types::milli::FilterableAttributesRule;
use meilisearch_types::settings::{
FacetingSettings, PaginationSettings, PrefixSearchSettings, ProximityPrecisionView,
RankingRuleView, SettingEmbeddingSettings, TypoSettings,
ChatSettings, FacetingSettings, PaginationSettings, PrefixSearchSettings,
ProximityPrecisionView, RankingRuleView, SettingEmbeddingSettings, TypoSettings,
};
use serde::Serialize;
@@ -39,6 +39,7 @@ pub struct SettingsAnalytics {
pub non_separator_tokens: NonSeparatorTokensAnalytics,
pub facet_search: FacetSearchAnalytics,
pub prefix_search: PrefixSearchAnalytics,
pub chat: ChatAnalytics,
}
impl Aggregate for SettingsAnalytics {
@@ -198,6 +199,7 @@ impl Aggregate for SettingsAnalytics {
set: new.prefix_search.set | self.prefix_search.set,
value: new.prefix_search.value.or(self.prefix_search.value),
},
chat: ChatAnalytics { set: new.chat.set | self.chat.set },
})
}
@@ -674,3 +676,18 @@ impl PrefixSearchAnalytics {
SettingsAnalytics { prefix_search: self, ..Default::default() }
}
}
#[derive(Serialize, Default)]
pub struct ChatAnalytics {
pub set: bool,
}
impl ChatAnalytics {
pub fn new(settings: Option<&ChatSettings>) -> Self {
Self { set: settings.is_some() }
}
pub fn into_settings(self) -> SettingsAnalytics {
SettingsAnalytics { chat: self, ..Default::default() }
}
}

View File

@@ -1,10 +1,9 @@
use std::collections::BTreeMap;
use actix_web::web::{self, Data};
use actix_web::HttpResponse;
use index_scheduler::IndexScheduler;
use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::update::Setting;
use serde::{Deserialize, Serialize};
use crate::extractors::authentication::policies::ActionPolicy;
@@ -25,10 +24,11 @@ async fn get_settings(
Data<IndexScheduler>,
>,
) -> Result<HttpResponse, ResponseError> {
let settings = match index_scheduler.chat_settings()? {
let mut settings = match index_scheduler.chat_settings()? {
Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(),
None => GlobalChatSettings::default(),
};
settings.hide_secrets();
Ok(HttpResponse::Ok().json(settings))
}
@@ -37,38 +37,96 @@ async fn patch_settings(
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
Data<IndexScheduler>,
>,
web::Json(chat_settings): web::Json<ChatSettings>,
web::Json(new): web::Json<GlobalChatSettings>,
) -> Result<HttpResponse, ResponseError> {
let chat_settings = serde_json::to_value(chat_settings).unwrap();
index_scheduler.put_chat_settings(&chat_settings)?;
let old = match index_scheduler.chat_settings()? {
Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
};
let settings = GlobalChatSettings {
source: new.source.or(old.source),
base_api: new.base_api.clone().or(old.base_api),
api_key: new.api_key.clone().or(old.api_key),
prompts: match (new.prompts, old.prompts) {
(Setting::NotSet, set) | (set, Setting::NotSet) => set,
(Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset,
(Setting::Reset, Setting::Set(set)) => Setting::Set(set),
// If both are set we must merge the prompts settings
(Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts {
system: new.system.or(old.system),
search_description: new.search_description.or(old.search_description),
search_q_param: new.search_q_param.or(old.search_q_param),
search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param),
pre_query: new.pre_query.or(old.pre_query),
}),
},
};
let value = serde_json::to_value(settings).unwrap();
index_scheduler.put_chat_settings(&value)?;
Ok(HttpResponse::Ok().finish())
}
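
The merge above relies on Setting::or keeping a newly provided value and falling back to the stored one when the field is omitted; a simplified sketch with a stand-in type follows (the real Setting lives in milli::update and may differ in detail).

// Stand-in type, only to illustrate the fallback used by the merge above.
#[derive(Debug, PartialEq)]
enum DemoSetting<T> {
    Set(T),
    Reset,
    NotSet,
}

impl<T> DemoSetting<T> {
    // Assumed semantics, analogous to Option::or.
    fn or(self, other: DemoSetting<T>) -> DemoSetting<T> {
        match self {
            DemoSetting::NotSet => other,
            kept => kept,
        }
    }
}

fn demo_merge() {
    // A newly provided value wins over the stored one.
    assert_eq!(
        DemoSetting::Set("new-key").or(DemoSetting::Set("old-key")),
        DemoSetting::Set("new-key")
    );
    // An omitted field keeps whatever was stored before.
    assert_eq!(
        DemoSetting::NotSet.or(DemoSetting::Set("old-key")),
        DemoSetting::Set("old-key")
    );
}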
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct ChatSettings {
pub source: String,
pub base_api: Option<String>,
pub api_key: Option<String>,
pub prompts: ChatPrompts,
pub indexes: BTreeMap<String, ChatIndexSettings>,
pub enum ChatSource {
OpenAi,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct GlobalChatSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub source: Setting<ChatSource>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub base_api: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub prompts: Setting<ChatPrompts>,
}
impl GlobalChatSettings {
pub fn hide_secrets(&mut self) {
match &mut self.api_key {
Setting::Set(key) => Self::hide_secret(key),
Setting::Reset => (),
Setting::NotSet => (),
}
}
fn hide_secret(secret: &mut String) {
match secret.len() {
x if x < 10 => {
secret.replace_range(.., "XXX...");
}
x if x < 20 => {
secret.replace_range(2.., "XXXX...");
}
x if x < 30 => {
secret.replace_range(3.., "XXXXX...");
}
_x => {
secret.replace_range(5.., "XXXXXX...");
}
}
}
}
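
For reference, an illustrative walk-through of the masking above with made-up keys; only a short prefix survives, growing with the key length.

// Illustrative only: one example per length bucket of hide_secret.
fn demo_hide_secret() {
    let mut short = "abc".to_string(); // fewer than 10 chars: fully masked
    let mut medium = "abcdefghijklmno".to_string(); // 10 to 19 chars: keeps 2 chars
    GlobalChatSettings::hide_secret(&mut short);
    GlobalChatSettings::hide_secret(&mut medium);
    assert_eq!(short, "XXX...");
    assert_eq!(medium, "abXXXX...");
}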
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct ChatPrompts {
pub system: String,
pub search_description: String,
pub search_q_param: String,
pub search_index_uid_param: String,
pub pre_query: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct ChatIndexSettings {
pub description: String,
pub document_template: String,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub system: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_description: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_q_param: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub search_index_uid_param: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
pub pre_query: Setting<String>,
}
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a highly capable research assistant with access to powerful search tools. IMPORTANT INSTRUCTIONS:\
@@ -91,21 +149,29 @@ const DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION: &str =
"The name of the index to search within. An index is a collection of documents organized for search. \
Selecting the right index ensures the most relevant results for the user query";
impl Default for ChatSettings {
impl Default for GlobalChatSettings {
fn default() -> Self {
ChatSettings {
source: "openai".to_string(),
base_api: None,
api_key: None,
prompts: ChatPrompts {
system: DEFAULT_SYSTEM_MESSAGE.to_string(),
search_description: DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string(),
search_q_param: DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
search_index_uid_param: DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION
.to_string(),
pre_query: "".to_string(),
},
indexes: BTreeMap::new(),
GlobalChatSettings {
source: Setting::NotSet,
base_api: Setting::NotSet,
api_key: Setting::NotSet,
prompts: Setting::Set(ChatPrompts::default()),
}
}
}
impl Default for ChatPrompts {
fn default() -> Self {
ChatPrompts {
system: Setting::Set(DEFAULT_SYSTEM_MESSAGE.to_string()),
search_description: Setting::Set(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION.to_string()),
search_q_param: Setting::Set(
DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION.to_string(),
),
search_index_uid_param: Setting::Set(
DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION.to_string(),
),
pre_query: Setting::Set(Default::default()),
}
}
}

View File

@@ -122,6 +122,7 @@ pub struct SearchQuery {
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(f64);
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
type Error = InvalidSearchRankingScoreThreshold;
@@ -279,8 +280,8 @@ impl fmt::Debug for SearchQuery {
#[deserr(error = DeserrJsonError<InvalidSearchHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub struct HybridQuery {
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
#[schema(value_type = f32, default)]
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>)]
#[schema(default, value_type = f32)]
#[serde(default)]
pub semantic_ratio: SemanticRatio,
#[deserr(error = DeserrJsonError<InvalidSearchEmbedder>)]
@@ -882,7 +883,7 @@ pub fn add_search_rules(filter: &mut Option<Value>, rules: IndexSearchRules) {
}
}
fn prepare_search<'t>(
pub fn prepare_search<'t>(
index: &'t Index,
rtxn: &'t RoTxn,
query: &'t SearchQuery,

View File

@@ -820,6 +820,22 @@ async fn list_api_keys() {
"createdAt": "[ignored]",
"updatedAt": "[ignored]"
},
{
"name": "Default Chat API Key",
"description": "Use it to chat and search from the frontend",
"key": "[ignored]",
"uid": "[ignored]",
"actions": [
"search",
"chat.get"
],
"indexes": [
"*"
],
"expiresAt": null,
"createdAt": "[ignored]",
"updatedAt": "[ignored]"
},
{
"name": "Default Search API Key",
"description": "Use it to search from the frontend",

View File

@@ -32,13 +32,13 @@ impl ExternalDocumentsIds {
&self,
rtxn: &RoTxn<'_>,
external_id: A,
) -> heed::Result<Option<u32>> {
) -> heed::Result<Option<DocumentId>> {
self.0.get(rtxn, external_id.as_ref())
}
/// A helper function to debug this type; returns a `HashMap` of both
/// soft and hard fst maps, combined.
pub fn to_hash_map(&self, rtxn: &RoTxn<'_>) -> heed::Result<HashMap<String, u32>> {
pub fn to_hash_map(&self, rtxn: &RoTxn<'_>) -> heed::Result<HashMap<String, DocumentId>> {
let mut map = HashMap::default();
for result in self.0.iter(rtxn)? {
let (external, internal) = result?;

View File

@@ -7,6 +7,7 @@ use crate::FieldId;
mod global;
pub mod metadata;
pub use global::GlobalFieldsIdsMap;
pub use metadata::{FieldIdMapWithMetadata, MetadataBuilder};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldsIdsMap {

View File

@@ -23,6 +23,7 @@ use crate::heed_codec::facet::{
use crate::heed_codec::version::VersionCodec;
use crate::heed_codec::{BEU16StrCodec, FstSetCodec, StrBEU16Codec, StrRefCodec};
use crate::order_by_map::OrderByMap;
use crate::prompt::PromptData;
use crate::proximity::ProximityPrecision;
use crate::vector::{ArroyStats, ArroyWrapper, Embedding, EmbeddingConfig};
use crate::{
@@ -79,6 +80,7 @@ pub mod main_key {
pub const PREFIX_SEARCH: &str = "prefix_search";
pub const DOCUMENTS_STATS: &str = "documents_stats";
pub const DISABLED_TYPOS_TERMS: &str = "disabled_typos_terms";
pub const CHAT: &str = "chat";
}
pub mod db_name {
@@ -1691,6 +1693,25 @@ impl Index {
self.main.remap_key_type::<Str>().delete(txn, main_key::FACET_SEARCH)
}
pub fn chat_config(&self, txn: &RoTxn<'_>) -> heed::Result<ChatConfig> {
self.main
.remap_types::<Str, SerdeJson<_>>()
.get(txn, main_key::CHAT)
.map(|o| o.unwrap_or_default())
}
pub(crate) fn put_chat_config(
&self,
txn: &mut RwTxn<'_>,
val: &ChatConfig,
) -> heed::Result<()> {
self.main.remap_types::<Str, SerdeJson<_>>().put(txn, main_key::CHAT, &val)
}
pub(crate) fn delete_chat_config(&self, txn: &mut RwTxn<'_>) -> heed::Result<bool> {
self.main.remap_key_type::<Str>().delete(txn, main_key::CHAT)
}
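A short sketch of reading the stored configuration back, assuming an opened `Index` and its usual `read_txn` helper (writes go through the settings pipeline shown further down, since `put_chat_config` is crate-private); a missing `chat` key falls back to `ChatConfig::default()` thanks to the `unwrap_or_default` above:

use milli::Index;

// Returns the index description exposed to the LLM, or an empty string when
// no chat settings were ever stored for this index.
fn read_chat_description(index: &Index) -> milli::Result<String> {
    let rtxn = index.read_txn()?;
    let config = index.chat_config(&rtxn)?;
    Ok(config.description)
}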
pub fn localized_attributes_rules(
&self,
rtxn: &RoTxn<'_>,
@@ -1917,13 +1938,59 @@ pub struct IndexEmbeddingConfig {
pub user_provided: RoaringBitmap,
}
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct ChatConfig {
pub description: String,
/// Contains the document template and max template length.
pub prompt: PromptData,
pub search_parameters: SearchParameters,
}
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SearchParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub hybrid: Option<HybridQuery>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub distinct: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matching_strategy: Option<MatchingStrategy>,
#[serde(skip_serializing_if = "Option::is_none")]
pub attributes_to_search_on: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ranking_score_threshold: Option<f64>,
}
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HybridQuery {
pub semantic_ratio: f32,
pub embedder: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PrefixSettings {
pub prefix_count_threshold: usize,
pub max_prefix_length: usize,
pub compute_prefixes: PrefixSearch,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum MatchingStrategy {
/// Remove query words from last to first
Last,
/// All query words are mandatory
All,
/// Remove query words from the most frequent to the least
Frequency,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "camelCase")]
pub enum PrefixSearch {

View File

@@ -52,18 +52,19 @@ pub use search::new::{
};
use serde_json::Value;
pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
pub use {charabia as tokenizer, heed, rhai};
pub use {arroy, charabia as tokenizer, heed, rhai};
pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError};
pub use self::attribute_patterns::AttributePatterns;
pub use self::attribute_patterns::PatternMatch;
pub use self::attribute_patterns::{AttributePatterns, PatternMatch};
pub use self::criterion::{default_criteria, Criterion, CriterionError};
pub use self::error::{
Error, FieldIdMapMissingEntry, InternalError, SerializationError, UserError,
};
pub use self::external_documents_ids::ExternalDocumentsIds;
pub use self::fieldids_weights_map::FieldidsWeightsMap;
pub use self::fields_ids_map::{FieldsIdsMap, GlobalFieldsIdsMap};
pub use self::fields_ids_map::{
FieldIdMapWithMetadata, FieldsIdsMap, GlobalFieldsIdsMap, MetadataBuilder,
};
pub use self::filterable_attributes_rules::{
FilterFeatures, FilterableAttributesFeatures, FilterableAttributesPatterns,
FilterableAttributesRule,
@@ -84,8 +85,6 @@ pub use self::search::{
};
pub use self::update::ChannelCongestion;
pub use arroy;
pub type Result<T> = std::result::Result<T, error::Error>;
pub type Attribute = u32;

View File

@@ -105,10 +105,10 @@ impl Prompt {
max_bytes,
};
// render template with special object that's OK with `doc.*` and `fields.*`
this.template
.render(&template_checker::TemplateChecker)
.map_err(NewPromptError::invalid_fields_in_template)?;
// // render template with special object that's OK with `doc.*` and `fields.*`
// this.template
// .render(&template_checker::TemplateChecker)
// .map_err(NewPromptError::invalid_fields_in_template)?;
Ok(this)
}

View File

@@ -0,0 +1,259 @@
use std::error::Error;
use std::fmt;
use deserr::errors::JsonError;
use deserr::Deserr;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::index::{self, ChatConfig, SearchParameters};
use crate::prompt::{default_max_bytes, PromptData};
use crate::update::Setting;
use crate::TermsMatchingStrategy;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
pub struct ChatSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
pub description: Setting<String>,
/// A liquid template used to render documents to a text that can be embedded.
///
/// Meilisearch interpolates the template for each document and sends the resulting text to the embedder.
/// The embedder then generates document vectors based on this text.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
pub document_template: Setting<String>,
/// Rendered texts are truncated to this size. Defaults to 400.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<usize>)]
pub document_template_max_bytes: Setting<usize>,
/// The search parameters to use for the LLM.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<ChatSearchParams>)]
pub search_parameters: Setting<ChatSearchParams>,
}
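A sketch of building this payload from Rust, assuming the re-exported `milli::update::{ChatSettings, Setting}` paths; the description and template strings are purely illustrative, and the liquid snippet is not the built-in default:

use milli::update::{ChatSettings, Setting};

fn movies_chat_settings() -> ChatSettings {
    ChatSettings {
        // Free-form description surfaced to the LLM so it can pick the right index.
        description: Setting::Set(
            "Movie catalog with titles, overviews and release dates".to_string(),
        ),
        // Illustrative liquid template rendering each document to text.
        document_template: Setting::Set("{{ doc.title }}: {{ doc.overview }}".to_string()),
        // Rendered text is truncated to this many bytes (400 is also the default).
        document_template_max_bytes: Setting::Set(400),
        search_parameters: Setting::NotSet,
    }
}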
impl From<ChatConfig> for ChatSettings {
fn from(config: ChatConfig) -> Self {
let ChatConfig {
description,
prompt: PromptData { template, max_bytes },
search_parameters,
} = config;
ChatSettings {
description: Setting::Set(description),
document_template: Setting::Set(template),
document_template_max_bytes: Setting::Set(
max_bytes.unwrap_or(default_max_bytes()).get(),
),
search_parameters: Setting::Set({
let SearchParameters {
hybrid,
limit,
sort,
distinct,
matching_strategy,
attributes_to_search_on,
ranking_score_threshold,
} = search_parameters;
let hybrid = hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| {
HybridQuery { semantic_ratio: SemanticRatio(semantic_ratio), embedder }
});
let matching_strategy = matching_strategy.map(|ms| match ms {
index::MatchingStrategy::Last => MatchingStrategy::Last,
index::MatchingStrategy::All => MatchingStrategy::All,
index::MatchingStrategy::Frequency => MatchingStrategy::Frequency,
});
let ranking_score_threshold = ranking_score_threshold.map(RankingScoreThreshold);
ChatSearchParams {
hybrid: Setting::some_or_not_set(hybrid),
limit: Setting::some_or_not_set(limit),
sort: Setting::some_or_not_set(sort),
distinct: Setting::some_or_not_set(distinct),
matching_strategy: Setting::some_or_not_set(matching_strategy),
attributes_to_search_on: Setting::some_or_not_set(attributes_to_search_on),
ranking_score_threshold: Setting::some_or_not_set(ranking_score_threshold),
}
}),
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)]
pub struct ChatSearchParams {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<HybridQuery>)]
pub hybrid: Setting<HybridQuery>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default = Setting::Set(20))]
#[schema(value_type = Option<usize>)]
pub limit: Setting<usize>,
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
// #[deserr(default)]
// pub attributes_to_retrieve: Option<BTreeSet<String>>,
// #[serde(default, skip_serializing_if = "Setting::is_not_set")]
// #[deserr(default)]
// pub filter: Option<Value>,
//
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<Vec<String>>)]
pub sort: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
pub distinct: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<MatchingStrategy>)]
pub matching_strategy: Setting<MatchingStrategy>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<Vec<String>>)]
pub attributes_to_search_on: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<RankingScoreThreshold>)]
pub ranking_score_threshold: Setting<RankingScoreThreshold>,
}
#[derive(Debug, Clone, Default, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[deserr(error = JsonError, rename_all = camelCase, deny_unknown_fields)]
pub struct HybridQuery {
#[deserr(default)]
#[serde(default)]
#[schema(default, value_type = f32)]
pub semantic_ratio: SemanticRatio,
#[schema(value_type = String)]
pub embedder: String,
}
#[derive(Debug, Clone, Copy, Deserr, ToSchema, PartialEq, Serialize, Deserialize)]
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatio(f32);
impl Default for SemanticRatio {
fn default() -> Self {
SemanticRatio(0.5)
}
}
impl std::convert::TryFrom<f32> for SemanticRatio {
type Error = InvalidSearchSemanticRatio;
fn try_from(f: f32) -> Result<Self, Self::Error> {
// the suggested "fix" is `!(0.0..=1.0).contains(&f)`, which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchSemanticRatio)
} else {
Ok(SemanticRatio(f))
}
}
}
#[derive(Debug)]
pub struct InvalidSearchSemanticRatio;
impl Error for InvalidSearchSemanticRatio {}
impl fmt::Display for InvalidSearchSemanticRatio {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`."
)
}
}
impl std::ops::Deref for SemanticRatio {
type Target = f32;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr, ToSchema, Serialize, Deserialize)]
#[deserr(rename_all = camelCase)]
#[serde(rename_all = "camelCase")]
pub enum MatchingStrategy {
/// Remove query words from last to first
Last,
/// All query words are mandatory
All,
/// Remove query words from the most frequent to the least
Frequency,
}
impl Default for MatchingStrategy {
fn default() -> Self {
Self::Last
}
}
impl From<MatchingStrategy> for TermsMatchingStrategy {
fn from(other: MatchingStrategy) -> Self {
match other {
MatchingStrategy::Last => Self::Last,
MatchingStrategy::All => Self::All,
MatchingStrategy::Frequency => Self::Frequency,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize, Deserialize)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(pub f64);
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
type Error = InvalidSearchRankingScoreThreshold;
fn try_from(f: f64) -> Result<Self, Self::Error> {
// the suggested "fix" is `!(0.0..=1.0).contains(&f)`, which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchRankingScoreThreshold)
} else {
Ok(RankingScoreThreshold(f))
}
}
}
#[derive(Debug)]
pub struct InvalidSearchRankingScoreThreshold;
impl Error for InvalidSearchRankingScoreThreshold {}
impl fmt::Display for InvalidSearchRankingScoreThreshold {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`."
)
}
}
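Both newtypes rely on the same inclusive `0.0..=1.0` bounds check. A stand-alone sketch of that validation with a local stand-in type (the `chat` module itself is not re-exported):

use std::convert::TryFrom;

// Local stand-in mirroring the bounds check used by `SemanticRatio` and
// `RankingScoreThreshold` above.
#[derive(Debug, PartialEq)]
struct UnitInterval(f64);

impl TryFrom<f64> for UnitInterval {
    type Error = &'static str;

    fn try_from(f: f64) -> Result<Self, Self::Error> {
        if f > 1.0 || f < 0.0 {
            Err("expected a float between `0.0` and `1.0`")
        } else {
            Ok(UnitInterval(f))
        }
    }
}

fn main() {
    assert_eq!(UnitInterval::try_from(0.5), Ok(UnitInterval(0.5)));
    assert!(UnitInterval::try_from(1.5).is_err());
    assert!(UnitInterval::try_from(-0.1).is_err());
}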

View File

@@ -1,4 +1,5 @@
pub use self::available_ids::AvailableIds;
pub use self::chat::ChatSettings;
pub use self::clear_documents::ClearDocuments;
pub use self::concurrent_available_ids::ConcurrentAvailableIds;
pub use self::facet::bulk::FacetsUpdateBulk;
@@ -13,6 +14,7 @@ pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids;
pub use self::words_prefixes_fst::WordsPrefixesFst;
mod available_ids;
mod chat;
mod clear_documents;
mod concurrent_available_ids;
pub(crate) mod del_add;

View File

@@ -11,9 +11,10 @@ use roaring::RoaringBitmap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
use super::chat::{ChatSearchParams, RankingScoreThreshold};
use super::del_add::{DelAdd, DelAddOperation};
use super::index_documents::{IndexDocumentsConfig, Transform};
use super::IndexerConfig;
use super::{ChatSettings, IndexerConfig};
use crate::attribute_patterns::PatternMatch;
use crate::constants::RESERVED_GEO_FIELD_NAME;
use crate::criterion::Criterion;
@@ -22,11 +23,11 @@ use crate::error::UserError;
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::filterable_attributes_rules::match_faceted_field;
use crate::index::{
IndexEmbeddingConfig, PrefixSearch, DEFAULT_MIN_WORD_LEN_ONE_TYPO,
DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
ChatConfig, IndexEmbeddingConfig, MatchingStrategy, PrefixSearch,
DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
};
use crate::order_by_map::OrderByMap;
use crate::prompt::default_max_bytes;
use crate::prompt::{default_max_bytes, PromptData};
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
@@ -123,6 +124,15 @@ impl<T> Setting<T> {
*self = new;
true
}
#[track_caller]
pub fn unwrap(self) -> T {
match self {
Setting::Set(value) => value,
Setting::Reset => panic!("Setting::Reset unwrapped"),
Setting::NotSet => panic!("Setting::NotSet unwrapped"),
}
}
}
impl<T: Serialize> Serialize for Setting<T> {
@@ -185,6 +195,7 @@ pub struct Settings<'a, 't, 'i> {
localized_attributes_rules: Setting<Vec<LocalizedAttributesRule>>,
prefix_search: Setting<PrefixSearch>,
facet_search: Setting<bool>,
chat: Setting<ChatSettings>,
}
impl<'a, 't, 'i> Settings<'a, 't, 'i> {
@@ -223,6 +234,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
localized_attributes_rules: Setting::NotSet,
prefix_search: Setting::NotSet,
facet_search: Setting::NotSet,
chat: Setting::NotSet,
indexer_config,
}
}
@@ -453,6 +465,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
self.facet_search = Setting::Reset;
}
pub fn set_chat(&mut self, value: ChatSettings) {
self.chat = Setting::Set(value);
}
pub fn reset_chat(&mut self) {
self.chat = Setting::Reset;
}
#[tracing::instrument(
level = "trace",
skip(self, progress_callback, should_abort, settings_diff),
@@ -1239,6 +1259,126 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Ok(())
}
fn update_chat_config(&mut self) -> heed::Result<bool> {
match &mut self.chat {
Setting::Set(ChatSettings {
description: new_description,
document_template: new_document_template,
document_template_max_bytes: new_document_template_max_bytes,
search_parameters: new_search_parameters,
}) => {
let mut old = self.index.chat_config(self.wtxn)?;
let ChatConfig {
ref mut description,
prompt: PromptData { ref mut template, ref mut max_bytes },
ref mut search_parameters,
} = old;
match new_description {
Setting::Set(d) => *description = d.clone(),
Setting::Reset => *description = Default::default(),
Setting::NotSet => (),
}
match new_document_template {
Setting::Set(dt) => *template = dt.clone(),
Setting::Reset => *template = Default::default(),
Setting::NotSet => (),
}
match new_document_template_max_bytes {
Setting::Set(m) => *max_bytes = NonZeroUsize::new(*m),
Setting::Reset => *max_bytes = Some(default_max_bytes()),
Setting::NotSet => (),
}
match new_search_parameters {
Setting::Set(sp) => {
let ChatSearchParams {
hybrid,
limit,
sort,
distinct,
matching_strategy,
attributes_to_search_on,
ranking_score_threshold,
} = sp;
match hybrid {
Setting::Set(hybrid) => {
search_parameters.hybrid = Some(crate::index::HybridQuery {
semantic_ratio: *hybrid.semantic_ratio,
embedder: hybrid.embedder.clone(),
})
}
Setting::Reset => search_parameters.hybrid = None,
Setting::NotSet => (),
}
match limit {
Setting::Set(limit) => search_parameters.limit = Some(*limit),
Setting::Reset => search_parameters.limit = None,
Setting::NotSet => (),
}
match sort {
Setting::Set(sort) => search_parameters.sort = Some(sort.clone()),
Setting::Reset => search_parameters.sort = None,
Setting::NotSet => (),
}
match distinct {
Setting::Set(distinct) => {
search_parameters.distinct = Some(distinct.clone())
}
Setting::Reset => search_parameters.distinct = None,
Setting::NotSet => (),
}
match matching_strategy {
Setting::Set(matching_strategy) => {
let strategy = match matching_strategy {
super::chat::MatchingStrategy::Last => MatchingStrategy::Last,
super::chat::MatchingStrategy::All => MatchingStrategy::All,
super::chat::MatchingStrategy::Frequency => {
MatchingStrategy::Frequency
}
};
search_parameters.matching_strategy = Some(strategy)
}
Setting::Reset => search_parameters.matching_strategy = None,
Setting::NotSet => (),
}
match attributes_to_search_on {
Setting::Set(attributes_to_search_on) => {
search_parameters.attributes_to_search_on =
Some(attributes_to_search_on.clone())
}
Setting::Reset => search_parameters.attributes_to_search_on = None,
Setting::NotSet => (),
}
match ranking_score_threshold {
Setting::Set(RankingScoreThreshold(score)) => {
search_parameters.ranking_score_threshold = Some(*score)
}
Setting::Reset => search_parameters.ranking_score_threshold = None,
Setting::NotSet => (),
}
}
Setting::Reset => *search_parameters = Default::default(),
Setting::NotSet => (),
}
self.index.put_chat_config(self.wtxn, &old)?;
Ok(true)
}
Setting::Reset => self.index.delete_chat_config(self.wtxn),
Setting::NotSet => Ok(false),
}
}
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
where
FP: Fn(UpdateIndexingStep) + Sync,
@@ -1276,6 +1416,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
self.update_facet_search()?;
self.update_localized_attributes_rules()?;
self.update_disabled_typos_terms()?;
self.update_chat_config()?;
let embedding_config_updates = self.update_embedding_configs()?;

View File

@@ -33,6 +33,7 @@ pub struct EmbeddingSettings {
///
/// - Defaults to `openAi`
pub source: Setting<EmbedderSource>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
@@ -55,6 +56,7 @@ pub struct EmbeddingSettings {
/// - For source `openAi`, defaults to `text-embedding-3-small`
/// - For source `huggingFace`, defaults to `BAAI/bge-base-en-v1.5`
pub model: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
@@ -75,6 +77,7 @@ pub struct EmbeddingSettings {
/// - When `model` is set to default, defaults to `617ca489d9e86b49b8167676d8220688b99db36e`
/// - Otherwise, defaults to `null`
pub revision: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<OverridePooling>)]
@@ -96,6 +99,7 @@ pub struct EmbeddingSettings {
///
/// - Embedders created before this parameter was available default to `forceMean` to preserve the existing behavior.
pub pooling: Setting<OverridePooling>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
@@ -118,6 +122,7 @@ pub struct EmbeddingSettings {
///
/// - This setting is partially hidden when returned by the settings
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
@@ -141,6 +146,7 @@ pub struct EmbeddingSettings {
/// - For source `openAi`, the dimensions is the maximum allowed by the model.
/// - For sources `ollama` and `rest`, the dimensions are inferred by embedding a sample text.
pub dimensions: Setting<usize>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<bool>)]
@@ -167,6 +173,7 @@ pub struct EmbeddingSettings {
/// first enabling it. If you are unsure of whether the performance-relevancy tradeoff is right for you,
/// we recommend to use this parameter on a test index first.
pub binary_quantized: Setting<bool>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<bool>)]
@@ -183,6 +190,7 @@ pub struct EmbeddingSettings {
///
/// - 🏗️ When modified, embeddings are regenerated for documents whose rendering through the template produces a different text.
pub document_template: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<usize>)]
@@ -201,6 +209,7 @@ pub struct EmbeddingSettings {
///
/// - Defaults to 400
pub document_template_max_bytes: Setting<usize>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)]
@@ -219,6 +228,7 @@ pub struct EmbeddingSettings {
/// - 🌱 When modified for source `openAi`, embeddings are never regenerated
/// - 🏗️ When modified for sources `ollama` and `rest`, embeddings are always regenerated
pub url: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<serde_json::Value>)]
@@ -236,6 +246,7 @@ pub struct EmbeddingSettings {
///
/// - 🏗️ Changing the value of this parameter always regenerates embeddings
pub request: Setting<serde_json::Value>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<serde_json::Value>)]
@@ -253,6 +264,7 @@ pub struct EmbeddingSettings {
///
/// - 🏗️ Changing the value of this parameter always regenerates embeddings
pub response: Setting<serde_json::Value>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<BTreeMap<String, String>>)]