Mirror of https://github.com/meilisearch/meilisearch.git
Synced 2025-07-18 04:11:07 +00:00

Compare commits: prototype-... to v1.6.2 (61 commits)
SHA1
1a083d54fc
3a97d30cd9
3bdb2b06be
64afed4dbe
b26ddfcc3d
c59f3f2f95
049bd45849
2491db8746
425bc92ce6
cbd065ed46
b9f365a965
3f21daf2e7
d77df4ecdb
fdac97e3c8
bbdfbd8ea1
da7c796be1
014eaea428
a6fa0b97ec
38abfec611
84a5c304fc
e93d36d5b9
95f8e21533
68f197624e
b79b03d4e2
86270e6878
81b6128b29
5f5a486895
5f4fc6c955
1f5e8fc072
3f3462ab62
93363b0201
97bb1ff9e2
5ee1378856
e27b850b09
f75f22e026
6203f4acef
12edc2c20a
94b9f3b310
da99a04eb3
54ae6951eb
658ec6e0a4
43e822e802
ee54d3171e
a0e713c4e7
d4cb0a885b
f52dee2b3b
0bf879fb88
6ff81de401
2e4c9651df
ec9649c922
9123370e90
14b396d302
393216bf30
e249e4db7b
de2ca7006e
333ce12eb2
fb9db1eba6
b2193e612f
942d49314c
9a846e82bc
9df8cfc013

Cargo.lock (generated, 37 changes)
@@ -383,7 +383,7 @@ dependencies = [
 [[package]]
 name = "arroy"
 version = "0.1.0"
-source = "git+https://github.com/meilisearch/arroy.git#4f193fd534acd357b65bfe9eec4b3fed8ece2007"
+source = "git+https://github.com/meilisearch/arroy.git#d372648212e561a4845077cdb9239423d78655a2"
 dependencies = [
  "bytemuck",
  "byteorder",
@@ -491,7 +491,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"

 [[package]]
 name = "benchmarks"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "bytes",
@@ -1402,7 +1402,7 @@ dependencies = [

 [[package]]
 name = "dump"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "big_s",
@@ -1592,9 +1592,6 @@ name = "esaxx-rs"
 version = "0.1.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
-dependencies = [
- "cc",
-]

 [[package]]
 name = "fancy-regex"
@@ -1637,7 +1634,7 @@ dependencies = [

 [[package]]
 name = "file-store"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "faux",
  "tempfile",
@@ -1659,7 +1656,7 @@ dependencies = [

 [[package]]
 name = "filter-parser"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "insta",
  "nom",
@@ -1690,7 +1687,7 @@ dependencies = [

 [[package]]
 name = "flatten-serde-json"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "criterion",
  "serde_json",
@@ -1808,7 +1805,7 @@ dependencies = [

 [[package]]
 name = "fuzzers"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "arbitrary",
  "clap",
@@ -2766,7 +2763,7 @@ dependencies = [

 [[package]]
 name = "index-scheduler"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "big_s",
@@ -2963,7 +2960,7 @@ dependencies = [

 [[package]]
 name = "json-depth-checker"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "criterion",
  "serde_json",
@@ -3475,7 +3472,7 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"

 [[package]]
 name = "meili-snap"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "insta",
  "md5",
@@ -3484,7 +3481,7 @@ dependencies = [

 [[package]]
 name = "meilisearch"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "actix-cors",
  "actix-http",
@@ -3575,7 +3572,7 @@ dependencies = [

 [[package]]
 name = "meilisearch-auth"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "base64 0.21.5",
  "enum-iterator",
@@ -3594,7 +3591,7 @@ dependencies = [

 [[package]]
 name = "meilisearch-types"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "actix-web",
  "anyhow",
@@ -3624,7 +3621,7 @@ dependencies = [

 [[package]]
 name = "meilitool"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "clap",
@@ -3672,7 +3669,7 @@ dependencies = [

 [[package]]
 name = "milli"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "arroy",
  "big_s",
@@ -4079,7 +4076,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"

 [[package]]
 name = "permissive-json-pointer"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "big_s",
  "serde_json",
@@ -5313,11 +5310,9 @@ version = "0.14.1"
 source = "git+https://github.com/huggingface/tokenizers.git?tag=v0.14.1#6357206cdcce4d78ffb1e0372feb456caea09375"
 dependencies = [
  "aho-corasick",
- "clap",
  "derive_builder",
  "esaxx-rs",
  "getrandom",
- "indicatif",
  "itertools 0.11.0",
  "lazy_static",
  "log",
@@ -19,7 +19,7 @@ members = [
 ]

 [workspace.package]
-version = "1.6.0"
+version = "1.6.2"
 authors = ["Quentin de Quelen <quentin@dequelen.me>", "Clément Renault <clement@meilisearch.com>"]
 description = "Meilisearch HTTP server"
 homepage = "https://meilisearch.com"
@@ -936,8 +936,8 @@ impl IndexScheduler {
         };

         // the index operation can take a long time, so save this handle to make it available to the search for the duration of the tick
-        *self.currently_updating_index.write().unwrap() =
-            Some((index_uid.clone(), index.clone()));
+        self.index_mapper
+            .set_currently_updating_index(Some((index_uid.clone(), index.clone())));

         let mut index_wtxn = index.write_txn()?;
         let tasks = self.apply_index_operation(&mut index_wtxn, &index, op)?;
@@ -1351,9 +1351,6 @@ impl IndexScheduler {

                 for (task, (_, settings)) in tasks.iter_mut().zip(settings) {
                     let checked_settings = settings.clone().check();
-                    if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) {
-                        self.features().check_vector("Passing `embedders` in settings")?
-                    }
                     task.details = Some(Details::SettingsUpdate { settings: Box::new(settings) });
                     apply_settings_to_builder(&checked_settings, &mut builder);

@@ -69,6 +69,10 @@ pub struct IndexMapper {
     /// Whether we open a meilisearch index with the MDB_WRITEMAP option or not.
     enable_mdb_writemap: bool,
     pub indexer_config: Arc<IndexerConfig>,
+
+    /// A few types of long running batches of tasks that act on a single index set this field
+    /// so that a handle to the index is available from other threads (search) in an optimized manner.
+    currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
 }

 /// Whether the index is available for use or is forbidden to be inserted back in the index map
@@ -151,6 +155,7 @@ impl IndexMapper {
             index_growth_amount,
             enable_mdb_writemap,
             indexer_config: Arc::new(indexer_config),
+            currently_updating_index: Default::default(),
         })
     }

@@ -303,6 +308,14 @@ impl IndexMapper {

     /// Return an index, may open it if it wasn't already opened.
     pub fn index(&self, rtxn: &RoTxn, name: &str) -> Result<Index> {
+        if let Some((current_name, current_index)) =
+            self.currently_updating_index.read().unwrap().as_ref()
+        {
+            if current_name == name {
+                return Ok(current_index.clone());
+            }
+        }
+
         let uuid = self
             .index_mapping
             .get(rtxn, name)?
@@ -474,4 +487,8 @@ impl IndexMapper {
     pub fn indexer_config(&self) -> &IndexerConfig {
         &self.indexer_config
     }
+
+    pub fn set_currently_updating_index(&self, index: Option<(String, Index)>) {
+        *self.currently_updating_index.write().unwrap() = index;
+    }
 }
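The four index_mapper.rs hunks above move the `currently_updating_index` slot from the scheduler into `IndexMapper`, so both the scheduler and the search threads go through a single place. A minimal sketch of the underlying pattern, assuming a simplified stand-in for `milli::Index` (everything here except the field and setter names is illustrative):

    use std::sync::{Arc, RwLock};

    #[derive(Clone)]
    struct Index; // stand-in: the real milli::Index is cheaply cloneable

    #[derive(Default)]
    struct Mapper {
        currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
    }

    impl Mapper {
        fn set_currently_updating_index(&self, index: Option<(String, Index)>) {
            *self.currently_updating_index.write().unwrap() = index;
        }

        fn index(&self, name: &str) -> Index {
            // Fast path: reuse the handle of the index currently being updated.
            if let Some((current_name, current_index)) =
                self.currently_updating_index.read().unwrap().as_ref()
            {
                if current_name == name {
                    return current_index.clone();
                }
            }
            // Slow path (elided): look the index up and open it.
            Index
        }
    }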
@@ -42,7 +42,6 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String {
         test_breakpoint_sdr: _,
         planned_failures: _,
         run_loop_iteration: _,
-        currently_updating_index: _,
         embedders: _,
     } = scheduler;

@@ -351,10 +351,6 @@ pub struct IndexScheduler {
     /// The path to the version file of Meilisearch.
     pub(crate) version_file_path: PathBuf,

-    /// A few types of long running batches of tasks that act on a single index set this field
-    /// so that a handle to the index is available from other threads (search) in an optimized manner.
-    currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
-
     embedders: Arc<RwLock<HashMap<EmbedderOptions, Arc<Embedder>>>>,

     // ================= test
@@ -403,7 +399,6 @@ impl IndexScheduler {
             version_file_path: self.version_file_path.clone(),
             webhook_url: self.webhook_url.clone(),
             webhook_authorization_header: self.webhook_authorization_header.clone(),
-            currently_updating_index: self.currently_updating_index.clone(),
             embedders: self.embedders.clone(),
             #[cfg(test)]
             test_breakpoint_sdr: self.test_breakpoint_sdr.clone(),
@@ -504,7 +499,6 @@ impl IndexScheduler {
             version_file_path: options.version_file_path,
             webhook_url: options.webhook_url,
             webhook_authorization_header: options.webhook_authorization_header,
-            currently_updating_index: Arc::new(RwLock::new(None)),
             embedders: Default::default(),

             #[cfg(test)]
@@ -688,13 +682,6 @@ impl IndexScheduler {
     /// If you need to fetch information from or perform an action on all indexes,
     /// see the `try_for_each_index` function.
     pub fn index(&self, name: &str) -> Result<Index> {
-        if let Some((current_name, current_index)) =
-            self.currently_updating_index.read().unwrap().as_ref()
-        {
-            if current_name == name {
-                return Ok(current_index.clone());
-            }
-        }
         let rtxn = self.env.read_txn()?;
         self.index_mapper.index(&rtxn, name)
     }
@@ -1175,7 +1162,7 @@ impl IndexScheduler {
         };

         // Reset the currently updating index to relinquish the index handle
-        *self.currently_updating_index.write().unwrap() = None;
+        self.index_mapper.set_currently_updating_index(None);

         #[cfg(test)]
         self.maybe_fail(tests::FailureLocation::AcquiringWtxn)?;
@@ -344,7 +344,10 @@ impl ErrorCode for milli::Error {
                     Code::InvalidDocumentId
                 }
                 UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
-                UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
+                UserError::InvalidFieldForSource { .. }
+                | UserError::MissingFieldForSource { .. }
+                | UserError::InvalidOpenAiModel { .. }
+                | UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
                 UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
                 UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
                 UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
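A note on the syntax in this hunk: Rust or-patterns let several enum variants share one match arm, which is how the new embedder configuration errors all collapse onto `Code::InvalidSettingsEmbedders`. A tiny sketch with hypothetical enums (not the real Meilisearch types):

    enum UserError { InvalidPrompt, InvalidOpenAiModel, MissingFieldForSource, NoPrimaryKeyCandidateFound }
    enum Code { InvalidSettingsEmbedders, IndexPrimaryKeyNoCandidateFound }

    fn error_code(error: &UserError) -> Code {
        match error {
            // one arm handles every embedder-configuration error
            UserError::InvalidPrompt
            | UserError::InvalidOpenAiModel
            | UserError::MissingFieldForSource => Code::InvalidSettingsEmbedders,
            UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
        }
    }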
@@ -318,6 +318,21 @@ impl Settings<Unchecked> {
             _kind: PhantomData,
         }
     }
+
+    pub fn validate(self) -> Result<Self, milli::Error> {
+        self.validate_embedding_settings()
+    }
+
+    fn validate_embedding_settings(mut self) -> Result<Self, milli::Error> {
+        let Setting::Set(mut configs) = self.embedders else { return Ok(self) };
+        for (name, config) in configs.iter_mut() {
+            let config_to_check = std::mem::take(config);
+            let checked_config = milli::update::validate_embedding_settings(config_to_check, name)?;
+            *config = checked_config
+        }
+        self.embedders = Setting::Set(configs);
+        Ok(self)
+    }
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -585,11 +600,12 @@ pub fn settings(
         ),
     };

-    let embedders = index
+    let embedders: BTreeMap<_, _> = index
         .embedding_configs(rtxn)?
         .into_iter()
         .map(|(name, config)| (name, Setting::Set(config.into())))
         .collect();
+    let embedders = if embedders.is_empty() { Setting::NotSet } else { Setting::Set(embedders) };

     Ok(Settings {
         displayed_attributes: match displayed_attributes {
@@ -611,15 +627,12 @@ pub fn settings(
             Some(field) => Setting::Set(field),
             None => Setting::Reset,
         },
-        proximity_precision: match proximity_precision {
-            Some(precision) => Setting::Set(precision),
-            None => Setting::Reset,
-        },
+        proximity_precision: Setting::Set(proximity_precision.unwrap_or_default()),
         synonyms: Setting::Set(synonyms),
         typo_tolerance: Setting::Set(typo_tolerance),
         faceting: Setting::Set(faceting),
         pagination: Setting::Set(pagination),
-        embedders: Setting::Set(embedders),
+        embedders,
         _kind: PhantomData,
     })
 }
@@ -720,10 +733,11 @@ impl From<RankingRuleView> for Criterion {
     }
 }

-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserr, Serialize, Deserialize)]
+#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Deserr, Serialize, Deserialize)]
 #[serde(deny_unknown_fields, rename_all = "camelCase")]
 #[deserr(error = DeserrJsonError<InvalidSettingsProximityPrecision>, rename_all = camelCase, deny_unknown_fields)]
 pub enum ProximityPrecisionView {
+    #[default]
     ByWord,
     ByAttribute,
 }
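The new `validate_embedding_settings` above uses a take-check-put-back loop: each config is moved out of the map with `std::mem::take` so it can be handed to a by-value validator, then the checked value is written back. A minimal sketch of that pattern under simplified, illustrative types (`Config` and `check` are not the real API):

    use std::collections::BTreeMap;

    #[derive(Default)]
    struct Config(String);

    fn check(config: Config, _name: &str) -> Result<Config, String> {
        Ok(config) // the real validator rejects invalid field combinations
    }

    fn validate(mut configs: BTreeMap<String, Config>) -> Result<BTreeMap<String, Config>, String> {
        for (name, config) in configs.iter_mut() {
            let config_to_check = std::mem::take(config); // leaves Config::default() behind
            *config = check(config_to_check, name)?; // put the validated value back
        }
        Ok(configs)
    }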
@@ -154,5 +154,5 @@ greek = ["meilisearch-types/greek"]
 khmer = ["meilisearch-types/khmer"]

 [package.metadata.mini-dashboard]
-assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.11/build.zip"
-sha1 = "83cd44ed1e5f97ecb581dc9f958a63f4ccc982d9"
+assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.13/build.zip"
+sha1 = "e20cc9b390003c6c844f4b8bcc5c5013191a77ff"
@@ -90,6 +90,11 @@ macro_rules! make_setting_route {
                     ..Default::default()
                 };

+                let new_settings = $crate::routes::indexes::settings::validate_settings(
+                    new_settings,
+                    &index_scheduler,
+                )?;
+
                 let allow_index_creation =
                     index_scheduler.filters().allow_index_creation(&index_uid);

@@ -453,7 +458,7 @@ make_setting_route!(
         json!({
             "proximity_precision": {
                 "set": precision.is_some(),
-                "value": precision,
+                "value": precision.unwrap_or_default(),
             }
         }),
         Some(req),
@@ -582,13 +587,13 @@ fn embedder_analytics(
             for source in s
                 .values()
                 .filter_map(|config| config.clone().set())
-                .filter_map(|config| config.embedder_options.set())
+                .filter_map(|config| config.source.set())
             {
-                use meilisearch_types::milli::vector::settings::EmbedderSettings;
+                use meilisearch_types::milli::vector::settings::EmbedderSource;
                 match source {
-                    EmbedderSettings::OpenAi(_) => sources.insert("openAi"),
-                    EmbedderSettings::HuggingFace(_) => sources.insert("huggingFace"),
-                    EmbedderSettings::UserProvided(_) => sources.insert("userProvided"),
+                    EmbedderSource::OpenAi => sources.insert("openAi"),
+                    EmbedderSource::HuggingFace => sources.insert("huggingFace"),
+                    EmbedderSource::UserProvided => sources.insert("userProvided"),
                 };
             }
         };
@@ -651,6 +656,7 @@ pub async fn update_all(
     let index_uid = IndexUid::try_from(index_uid.into_inner())?;

     let new_settings = body.into_inner();
+    let new_settings = validate_settings(new_settings, &index_scheduler)?;

     analytics.publish(
         "Settings Updated".to_string(),
@@ -684,7 +690,8 @@ pub async fn update_all(
                 "set": new_settings.distinct_attribute.as_ref().set().is_some()
             },
             "proximity_precision": {
-                "set": new_settings.proximity_precision.as_ref().set().is_some()
+                "set": new_settings.proximity_precision.as_ref().set().is_some(),
+                "value": new_settings.proximity_precision.as_ref().set().copied().unwrap_or_default()
             },
             "typo_tolerance": {
                 "enabled": new_settings.typo_tolerance
@@ -800,3 +807,13 @@ pub async fn delete_all(
     debug!("returns: {:?}", task);
     Ok(HttpResponse::Accepted().json(task))
 }
+
+fn validate_settings(
+    settings: Settings<Unchecked>,
+    index_scheduler: &IndexScheduler,
+) -> Result<Settings<Unchecked>, ResponseError> {
+    if matches!(settings.embedders, Setting::Set(_)) {
+        index_scheduler.features().check_vector("Passing `embedders` in settings")?
+    }
+    Ok(settings.validate()?)
+}
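The new `validate_settings` helper above does two things on every settings route: it rejects an `embedders` payload unless the vector-store experimental feature is enabled, then runs the payload through `Settings::validate`. A hedged sketch of that gate with simplified stand-ins for `Features` and `Settings`:

    struct Features { vector_store: bool }

    impl Features {
        fn check_vector(&self, disabled_action: &str) -> Result<(), String> {
            if self.vector_store {
                Ok(())
            } else {
                Err(format!("{disabled_action} requires the `vectorStore` experimental feature."))
            }
        }
    }

    struct Settings { embedders: Option<String> } // stand-in payload

    fn validate_settings(settings: Settings, features: &Features) -> Result<Settings, String> {
        if settings.embedders.is_some() {
            features.check_vector("Passing `embedders` in settings")?;
        }
        Ok(settings) // the real code then calls settings.validate()
    }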
@@ -735,6 +735,9 @@ pub fn perform_facet_search(
     if let Some(facet_query) = &facet_query {
         facet_search.query(facet_query);
     }
+    if let Some(max_facets) = index.max_values_per_facet(&rtxn)? {
+        facet_search.max_values(max_facets as usize);
+    }

     Ok(FacetSearchResult {
         facet_hits: facet_search.execute()?,
@@ -897,6 +900,14 @@ fn format_fields<'a>(
     let mut matches_position = compute_matches.then(BTreeMap::new);
     let mut document = document.clone();

+    // reduce the formatted option list to the attributes that should be formatted,
+    // instead of all the attributes to display.
+    let formatting_fields_options: Vec<_> = formatted_options
+        .iter()
+        .filter(|(_, option)| option.should_format())
+        .map(|(fid, option)| (field_ids_map.name(*fid).unwrap(), option))
+        .collect();
+
     // select the attributes to retrieve
     let displayable_names =
         displayable_ids.iter().map(|&fid| field_ids_map.name(fid).expect("Missing field name"));
@@ -905,13 +916,15 @@ fn format_fields<'a>(
         // to the value and merge them together. eg. If a user said he wanted to highlight `doggo`
         // and crop `doggo.name`. `doggo.name` needs to be highlighted + cropped while `doggo.age` is only
         // highlighted.
-        let format = formatted_options
+        // Warn: The time to compute the format list scales with the number of fields to format;
+        // cumulated with map_leaf_values that iterates over all the nested fields, it gives a quadratic complexity:
+        // d*f where d is the total number of fields to display and f is the total number of fields to format.
+        let format = formatting_fields_options
             .iter()
-            .filter(|(field, _option)| {
-                let name = field_ids_map.name(**field).unwrap();
+            .filter(|(name, _option)| {
                 milli::is_faceted_by(name, key) || milli::is_faceted_by(key, name)
             })
-            .map(|(_, option)| *option)
+            .map(|(_, option)| **option)
             .reduce(|acc, option| acc.merge(option));
         let mut infos = Vec::new();

@@ -1008,7 +1021,7 @@ fn format_value<'a>(
                     let value = matcher.format(format_options);
                     Value::String(value.into_owned())
                 }
-                None => Value::Number(number),
+                None => Value::String(s),
             }
         }
         value => value,
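The `format_fields` hunks above precompute, once per document, the subset of fields that actually need formatting, instead of re-filtering every displayable field inside the per-attribute loop; that is what removes the quadratic d*f cost the new comment warns about. An illustrative sketch of the precomputation (types are simplified stand-ins for the real field-id map and format options):

    // (field id, should_format) pairs stand in for the real FormatOptions map.
    fn formatting_fields<'a>(
        formatted_options: &'a [(u16, bool)],
        name_of: impl Fn(u16) -> &'a str,
    ) -> Vec<(&'a str, bool)> {
        formatted_options
            .iter()
            .filter(|(_, should_format)| *should_format) // keep only fields to format
            .map(|(fid, option)| (name_of(*fid), *option)) // resolve names once
            .collect()
    }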
@@ -64,7 +64,7 @@ impl Display for Value {
         write!(
             f,
             "{}",
-            json_string!(self, { ".enqueuedAt" => "[date]", ".processedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" })
+            json_string!(self, { ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" })
         )
     }
 }
@@ -1760,6 +1760,181 @@ async fn add_documents_invalid_geo_field() {
       "finishedAt": "[date]"
     }
     "###);
+
+    // The three next tests are related to #4333
+
+    // _geo has a lat and lng but set to `null`
+    let documents = json!([
+        {
+            "id": "12",
+            "_geo": { "lng": null, "lat": 67}
+        }
+    ]);
+
+    let (response, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let response = index.wait_task(response.uid()).await;
+    snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
+        @r###"
+    {
+      "uid": 14,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 0
+      },
+      "error": {
+        "message": "Could not parse longitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    // _geo has a lat and lng but set to `null`
+    let documents = json!([
+        {
+            "id": "12",
+            "_geo": { "lng": 35, "lat": null }
+        }
+    ]);
+
+    let (response, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let response = index.wait_task(response.uid()).await;
+    snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
+        @r###"
+    {
+      "uid": 15,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 0
+      },
+      "error": {
+        "message": "Could not parse latitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    // _geo has a lat and lng but set to `null`
+    let documents = json!([
+        {
+            "id": "13",
+            "_geo": { "lng": null, "lat": null }
+        }
+    ]);
+
+    let (response, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let response = index.wait_task(response.uid()).await;
+    snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
+        @r###"
+    {
+      "uid": 16,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 0
+      },
+      "error": {
+        "message": "Could not parse latitude nor longitude in the document with the id: `13`. Was expecting finite numbers but instead got `null` and `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
 }
+
+// Related to #4333
+#[actix_rt::test]
+async fn add_invalid_geo_and_then_settings() {
+    let server = Server::new().await;
+    let index = server.index("test");
+    index.create(Some("id")).await;
+
+    // _geo is not an object
+    let documents = json!([
+        {
+            "id": "11",
+            "_geo": { "lat": null, "lng": null },
+        }
+    ]);
+    let (ret, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let ret = index.wait_task(ret.uid()).await;
+    snapshot!(ret, @r###"
+    {
+      "uid": 1,
+      "indexUid": "test",
+      "status": "succeeded",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 1
+      },
+      "error": null,
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    let (ret, code) = index.update_settings(json!({"sortableAttributes": ["_geo"]})).await;
+    snapshot!(code, @"202 Accepted");
+    let ret = index.wait_task(ret.uid()).await;
+    snapshot!(ret, @r###"
+    {
+      "uid": 2,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "settingsUpdate",
+      "canceledBy": null,
+      "details": {
+        "sortableAttributes": [
+          "_geo"
+        ]
+      },
+      "error": {
+        "message": "Could not parse latitude in the document with the id: `\"11\"`. Was expecting a finite number but instead got `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+}

 #[actix_rt::test]
@@ -59,7 +59,7 @@ async fn import_dump_v1_movie_raw() {
       "dictionary": [],
       "synonyms": {},
      "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -77,8 +77,7 @@ async fn import_dump_v1_movie_raw() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -221,7 +220,7 @@ async fn import_dump_v1_movie_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -239,8 +238,7 @@ async fn import_dump_v1_movie_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -369,7 +367,7 @@ async fn import_dump_v1_rubygems_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -387,8 +385,7 @@ async fn import_dump_v1_rubygems_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -503,7 +500,7 @@ async fn import_dump_v2_movie_raw() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -521,8 +518,7 @@ async fn import_dump_v2_movie_raw() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -649,7 +645,7 @@ async fn import_dump_v2_movie_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -667,8 +663,7 @@ async fn import_dump_v2_movie_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -794,7 +789,7 @@ async fn import_dump_v2_rubygems_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -812,8 +807,7 @@ async fn import_dump_v2_rubygems_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -928,7 +922,7 @@ async fn import_dump_v3_movie_raw() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -946,8 +940,7 @@ async fn import_dump_v3_movie_raw() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -1074,7 +1067,7 @@ async fn import_dump_v3_movie_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -1092,8 +1085,7 @@ async fn import_dump_v3_movie_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -1219,7 +1211,7 @@ async fn import_dump_v3_rubygems_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -1237,8 +1229,7 @@ async fn import_dump_v3_rubygems_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -1353,7 +1344,7 @@ async fn import_dump_v4_movie_raw() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -1371,8 +1362,7 @@ async fn import_dump_v4_movie_raw() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -1499,7 +1489,7 @@ async fn import_dump_v4_movie_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -1517,8 +1507,7 @@ async fn import_dump_v4_movie_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -1644,7 +1633,7 @@ async fn import_dump_v4_rubygems_with_settings() {
       "dictionary": [],
       "synonyms": {},
       "distinctAttribute": null,
-      "proximityPrecision": null,
+      "proximityPrecision": "byWord",
       "typoTolerance": {
         "enabled": true,
         "minWordSizeForTypos": {
@@ -1662,8 +1651,7 @@ async fn import_dump_v4_rubygems_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###
     );
@@ -1907,8 +1895,7 @@ async fn import_dump_v6_containing_experimental_features() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      },
-      "embedders": {}
+      }
     }
     "###);

@@ -105,6 +105,24 @@ async fn more_advanced_facet_search() {
     snapshot!(response["facetHits"].as_array().unwrap().len(), @"1");
 }

+#[actix_rt::test]
+async fn simple_facet_search_with_max_values() {
+    let server = Server::new().await;
+    let index = server.index("test");
+
+    let documents = DOCUMENTS.clone();
+    index.update_settings_faceting(json!({ "maxValuesPerFacet": 1 })).await;
+    index.update_settings_filterable_attributes(json!(["genres"])).await;
+    index.add_documents(documents, None).await;
+    index.wait_task(2).await;
+
+    let (response, code) =
+        index.facet_search(json!({"facetName": "genres", "facetQuery": "a"})).await;
+
+    assert_eq!(code, 200, "{}", response);
+    assert_eq!(dbg!(response)["facetHits"].as_array().unwrap().len(), 1);
+}
+
 #[actix_rt::test]
 async fn non_filterable_facet_search_error() {
     let server = Server::new().await;
@@ -21,9 +21,9 @@ async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Inde
     "###);

     let (response, code) = index
-        .update_settings(
-            json!({ "embedders": {"default": {"source": {"userProvided": {"dimensions": 2}}}} }),
-        )
+        .update_settings(json!({ "embedders": {"default": {
+            "source": "userProvided",
+            "dimensions": 2}}} ))
         .await;
     assert_eq!(202, code, "{:?}", response);
     index.wait_task(response.uid()).await;
@@ -56,6 +56,15 @@ static SIMPLE_SEARCH_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
     }])
 });

+static SINGLE_DOCUMENT: Lazy<Value> = Lazy::new(|| {
+    json!([{
+        "title": "Shazam!",
+        "desc": "a Captain Marvel ersatz",
+        "id": "1",
+        "_vectors": {"default": [1.0, 3.0]},
+    }])
+});
+
 #[actix_rt::test]
 async fn simple_search() {
     let server = Server::new().await;
@@ -78,6 +87,52 @@ async fn simple_search() {
     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
 }

+#[actix_rt::test]
+async fn highlighter() {
+    let server = Server::new().await;
+    let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
+
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+            "hybrid": {"semanticRatio": 0.2},
+            "attributesToHighlight": [
+                "desc"
+            ],
+            "highlightPreTag": "**BEGIN**",
+            "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}}}]"###);
+
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+            "hybrid": {"semanticRatio": 0.8},
+            "attributesToHighlight": [
+                "desc"
+            ],
+            "highlightPreTag": "**BEGIN**",
+            "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}},"_semanticScore":0.9472136}]"###);
+
+    // no highlighting on full semantic
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+            "hybrid": {"semanticRatio": 1.0},
+            "attributesToHighlight": [
+                "desc"
+            ],
+            "highlightPreTag": "**BEGIN**",
+            "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}}]"###);
+}
+
 #[actix_rt::test]
 async fn invalid_semantic_ratio() {
     let server = Server::new().await;
@@ -149,3 +204,18 @@ async fn invalid_semantic_ratio() {
     }
     "###);
 }
+
+#[actix_rt::test]
+async fn single_document() {
+    let server = Server::new().await;
+    let index = index_with_documents(&server, &SINGLE_DOCUMENT).await;
+
+    let (response, code) = index
+        .search_post(
+            json!({"vector": [1.0, 3.0], "hybrid": {"semanticRatio": 1.0}, "showRankingScore": true}),
+        )
+        .await;
+
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"][0], @r###"{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0,"_semanticScore":1.0}"###);
+}
@@ -890,13 +890,21 @@ async fn experimental_feature_vector_store() {
     let (response, code) = index
         .update_settings(json!({"embedders": {
             "manual": {
-                "source": {
-                    "userProvided": {"dimensions": 3}
-                }
+                "source": "userProvided",
+                "dimensions": 3,
             }
         }}))
         .await;

+    meili_snap::snapshot!(response, @r###"
+    {
+      "taskUid": 1,
+      "indexUid": "test",
+      "status": "enqueued",
+      "type": "settingsUpdate",
+      "enqueuedAt": "[date]"
+    }
+    "###);
     meili_snap::snapshot!(code, @"202 Accepted");
     let response = index.wait_task(response.uid()).await;

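This hunk tracks the v1.6 settings-shape change already visible in the hybrid-search tests above: an embedder's `source` is now a plain string with its options as sibling fields, not a nested variant object. A sketch of the two payload shapes with serde_json, reusing the field names from the test (nothing else here is part of the real API):

    use serde_json::json;

    fn main() {
        // old shape: source as a nested variant object
        let _old = json!({"embedders": {"manual": {
            "source": { "userProvided": { "dimensions": 3 } }
        }}});
        // new shape: source as a string, options flattened beside it
        let _new = json!({"embedders": {"manual": {
            "source": "userProvided",
            "dimensions": 3
        }}});
    }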
@@ -54,7 +54,7 @@ async fn get_settings() {
     let (response, code) = index.settings().await;
     assert_eq!(code, 200);
     let settings = response.as_object().unwrap();
-    assert_eq!(settings.keys().len(), 16);
+    assert_eq!(settings.keys().len(), 15);
     assert_eq!(settings["displayedAttributes"], json!(["*"]));
     assert_eq!(settings["searchableAttributes"], json!(["*"]));
     assert_eq!(settings["filterableAttributes"], json!([]));
@@ -83,7 +83,7 @@ async fn get_settings() {
             "maxTotalHits": 1000,
         })
     );
-    assert_eq!(settings["embedders"], json!({}));
+    assert_eq!(settings["proximityPrecision"], json!("byWord"));
 }

 #[actix_rt::test]
@@ -77,7 +77,7 @@ csv = "1.2.1"
 candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
 candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
 candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
-tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" }
+tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1", default_features = false, features = ["onig"] }
 hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
     "online",
 ] }
@@ -192,7 +192,7 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
     MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
     #[error(transparent)]
     InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
-    #[error("Invalid prompt in for embeddings with name '{0}': {1}.")]
+    #[error("`.embedders.{0}.documentTemplate`: Invalid template: {1}.")]
     InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
     #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")]
     TooManyEmbedders(usize),
@@ -200,6 +200,33 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
     InvalidEmbedder(String),
     #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")]
     TooManyVectors(String, usize),
+    #[error("`.embedders.{embedder_name}`: Field `{field}` unavailable for source `{source_}` (only available for sources: {}). Available fields: {}",
+        allowed_sources_for_field
+            .iter()
+            .map(|accepted| format!("`{}`", accepted))
+            .collect::<Vec<String>>()
+            .join(", "),
+        allowed_fields_for_source
+            .iter()
+            .map(|accepted| format!("`{}`", accepted))
+            .collect::<Vec<String>>()
+            .join(", ")
+    )]
+    InvalidFieldForSource {
+        embedder_name: String,
+        source_: crate::vector::settings::EmbedderSource,
+        field: &'static str,
+        allowed_fields_for_source: &'static [&'static str],
+        allowed_sources_for_field: &'static [crate::vector::settings::EmbedderSource],
+    },
+    #[error("`.embedders.{embedder_name}.model`: Invalid model `{model}` for OpenAI. Supported models: {:?}", crate::vector::openai::EmbeddingModel::supported_models())]
+    InvalidOpenAiModel { embedder_name: String, model: String },
+    #[error("`.embedders.{embedder_name}`: Missing field `{field}` (note: this field is mandatory for source {source_})")]
+    MissingFieldForSource {
+        field: &'static str,
+        source_: crate::vector::settings::EmbedderSource,
+        embedder_name: String,
+    },
 }

 impl From<crate::vector::Error> for Error {
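The added variants follow the error style visible in this hunk: `thiserror` messages that name the offending dotted settings path (`.embedders.<name>.<field>`) so API errors point at the exact JSON key the client sent. A reduced sketch of that style (these two variants are illustrative, not the full set):

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum EmbedderError {
        #[error("`.embedders.{embedder_name}`: Missing field `{field}`")]
        MissingField { embedder_name: String, field: &'static str },
        #[error("`.embedders.{0}.documentTemplate`: Invalid template: {1}.")]
        InvalidTemplate(String, String),
    }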
@@ -222,72 +222,3 @@ where
         Ok(ControlFlow::Continue(()))
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use std::ops::ControlFlow;
-
-    use heed::BytesDecode;
-    use roaring::RoaringBitmap;
-
-    use super::lexicographically_iterate_over_facet_distribution;
-    use crate::heed_codec::facet::OrderedF64Codec;
-    use crate::milli_snap;
-    use crate::search::facet::tests::{get_random_looking_index, get_simple_index};
-
-    #[test]
-    fn filter_distribution_all() {
-        let indexes = [get_simple_index(), get_random_looking_index()];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (0..=255).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            lexicographically_iterate_over_facet_distribution(
-                &txn,
-                index.content,
-                0,
-                &candidates,
-                |facet, count, _| {
-                    let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
-                    results.push_str(&format!("{facet}: {count}\n"));
-                    Ok(ControlFlow::Continue(()))
-                },
-            )
-            .unwrap();
-            milli_snap!(results, i);
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_distribution_all_stop_early() {
-        let indexes = [get_simple_index(), get_random_looking_index()];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (0..=255).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            let mut nbr_facets = 0;
-            lexicographically_iterate_over_facet_distribution(
-                &txn,
-                index.content,
-                0,
-                &candidates,
-                |facet, count, _| {
-                    let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
-                    if nbr_facets == 100 {
-                        Ok(ControlFlow::Break(()))
-                    } else {
-                        nbr_facets += 1;
-                        results.push_str(&format!("{facet}: {count}\n"));
-                        Ok(ControlFlow::Continue(()))
-                    }
-                },
-            )
-            .unwrap();
-            milli_snap!(results, i);
-
-            txn.commit().unwrap();
-        }
-    }
-}
@@ -303,347 +303,3 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
         Ok(())
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use std::ops::Bound;
-
-    use roaring::RoaringBitmap;
-
-    use super::find_docids_of_facet_within_bounds;
-    use crate::heed_codec::facet::{FacetGroupKeyCodec, OrderedF64Codec};
-    use crate::milli_snap;
-    use crate::search::facet::tests::{
-        get_random_looking_index, get_random_looking_index_with_multiple_field_ids,
-        get_simple_index, get_simple_index_with_multiple_field_ids,
-    };
-    use crate::snapshot_tests::display_bitmap;
-
-    #[test]
-    fn random_looking_index_snap() {
-        let index = get_random_looking_index();
-        milli_snap!(format!("{index}"), @"3256c76a7c1b768a013e78d5fa6e9ff9");
-    }
-
-    #[test]
-    fn random_looking_index_with_multiple_field_ids_snap() {
-        let index = get_random_looking_index_with_multiple_field_ids();
-        milli_snap!(format!("{index}"), @"c3e5fe06a8f1c404ed4935b32c90a89b");
-    }
-
-    #[test]
-    fn simple_index_snap() {
-        let index = get_simple_index();
-        milli_snap!(format!("{index}"), @"5dbfa134cc44abeb3ab6242fc182e48e");
-    }
-
-    #[test]
-    fn simple_index_with_multiple_field_ids_snap() {
-        let index = get_simple_index_with_multiple_field_ids();
-        milli_snap!(format!("{index}"), @"a4893298218f682bc76357f46777448c");
-    }
-
-    #[test]
-    fn filter_range_increasing() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Included(0.);
-                let end = Bound::Included(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!("0 <= . <= {i} : {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("included_{i}"));
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Excluded(0.);
-                let end = Bound::Excluded(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!("0 < . < {i} : {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("excluded_{i}"));
-            txn.commit().unwrap();
-        }
-    }
-    #[test]
-    fn filter_range_decreasing() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-
-            let mut results = String::new();
-
-            for i in (0..=255).rev() {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Included(255.);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!("{i} <= . <= 255 : {}\n", display_bitmap(&docids)));
-            }
-
-            milli_snap!(results, format!("included_{i}"));
-
-            let mut results = String::new();
-
-            for i in (0..=255).rev() {
-                let i = i as f64;
-                let start = Bound::Excluded(i);
-                let end = Bound::Excluded(255.);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!("{i} < . < 255 : {}\n", display_bitmap(&docids)));
-            }
-
-            milli_snap!(results, format!("excluded_{i}"));
-
-            txn.commit().unwrap();
-        }
-    }
-    #[test]
-    fn filter_range_pinch() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-
-            let mut results = String::new();
-
-            for i in (0..=128).rev() {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Included(255. - i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!(
-                    "{i} <= . <= {r} : {docids}\n",
-                    r = 255. - i,
-                    docids = display_bitmap(&docids)
-                ));
-            }
-
-            milli_snap!(results, format!("included_{i}"));
-
-            let mut results = String::new();
-
-            for i in (0..=128).rev() {
-                let i = i as f64;
-                let start = Bound::Excluded(i);
-                let end = Bound::Excluded(255. - i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!(
-                    "{i} < . < {r} {docids}\n",
-                    r = 255. - i,
-                    docids = display_bitmap(&docids)
-                ));
-            }
-
-            milli_snap!(results, format!("excluded_{i}"));
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_range_unbounded() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Unbounded;
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!(">= {i}: {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("start_from_included_{i}"));
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Unbounded;
-                let end = Bound::Included(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!("<= {i}: {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("end_at_included_{i}"));
-
-            let mut docids = RoaringBitmap::new();
-            find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                &txn,
-                index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                0,
-                &Bound::Unbounded,
-                &Bound::Unbounded,
-                &mut docids,
-            )
-            .unwrap();
-            milli_snap!(
-                &format!("all field_id 0: {}\n", display_bitmap(&docids)),
-                format!("unbounded_field_id_0_{i}")
-            );
-
-            let mut docids = RoaringBitmap::new();
-            find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                &txn,
-                index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                1,
-                &Bound::Unbounded,
-                &Bound::Unbounded,
-                &mut docids,
-            )
-            .unwrap();
-            milli_snap!(
-                &format!("all field_id 1: {}\n", display_bitmap(&docids)),
-                format!("unbounded_field_id_1_{i}")
-            );
-
-            drop(txn);
-        }
-    }
-
-    #[test]
-    fn filter_range_exact() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let mut results_0 = String::new();
-            let mut results_1 = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Included(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results_0.push_str(&format!("{i}: {}\n", display_bitmap(&docids)));
-
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    1,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results_1.push_str(&format!("{i}: {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results_0, format!("field_id_0_exact_{i}"));
-            milli_snap!(results_1, format!("field_id_1_exact_{i}"));
-
-            drop(txn);
-        }
-    }
-}
@@ -112,119 +112,3 @@ impl<'t, 'e> Iterator for AscendingFacetSort<'t, 'e> {
        }
    }
}

#[cfg(test)]
mod tests {
    use roaring::RoaringBitmap;

    use crate::milli_snap;
    use crate::search::facet::facet_sort_ascending::ascending_facet_sort;
    use crate::search::facet::tests::{
        get_random_looking_index, get_random_looking_string_index_with_multiple_field_ids,
        get_simple_index, get_simple_string_index_with_multiple_field_ids,
    };
    use crate::snapshot_tests::display_bitmap;

    #[test]
    fn filter_sort_ascending() {
        let indexes = [get_simple_index(), get_random_looking_index()];
        for (i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = (200..=300).collect::<RoaringBitmap>();
            let mut results = String::new();
            let iter = ascending_facet_sort(&txn, index.content, 0, candidates).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            milli_snap!(results, i);

            txn.commit().unwrap();
        }
    }

    #[test]
    fn filter_sort_ascending_multiple_field_ids() {
        let indexes = [
            get_simple_string_index_with_multiple_field_ids(),
            get_random_looking_string_index_with_multiple_field_ids(),
        ];
        for (i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = (200..=300).collect::<RoaringBitmap>();
            let mut results = String::new();
            let iter = ascending_facet_sort(&txn, index.content, 0, candidates.clone()).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            milli_snap!(results, format!("{i}-0"));

            let mut results = String::new();
            let iter = ascending_facet_sort(&txn, index.content, 1, candidates).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            milli_snap!(results, format!("{i}-1"));

            txn.commit().unwrap();
        }
    }

    #[test]
    fn filter_sort_ascending_with_no_candidates() {
        let indexes = [
            get_simple_string_index_with_multiple_field_ids(),
            get_random_looking_string_index_with_multiple_field_ids(),
        ];
        for (_i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = RoaringBitmap::new();
            let mut results = String::new();
            let iter = ascending_facet_sort(&txn, index.content, 0, candidates.clone()).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            assert!(results.is_empty());

            let mut results = String::new();
            let iter = ascending_facet_sort(&txn, index.content, 1, candidates).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            assert!(results.is_empty());

            txn.commit().unwrap();
        }
    }

    #[test]
    fn filter_sort_ascending_with_inexisting_field_id() {
        let indexes = [
            get_simple_string_index_with_multiple_field_ids(),
            get_random_looking_string_index_with_multiple_field_ids(),
        ];
        for (_i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = RoaringBitmap::new();
            let mut results = String::new();
            let iter = ascending_facet_sort(&txn, index.content, 3, candidates.clone()).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            assert!(results.is_empty());

            txn.commit().unwrap();
        }
    }
}

@@ -117,128 +117,3 @@ impl<'t> Iterator for DescendingFacetSort<'t> {
        }
    }
}

#[cfg(test)]
mod tests {
    use roaring::RoaringBitmap;

    use crate::heed_codec::facet::FacetGroupKeyCodec;
    use crate::heed_codec::BytesRefCodec;
    use crate::milli_snap;
    use crate::search::facet::facet_sort_descending::descending_facet_sort;
    use crate::search::facet::tests::{
        get_random_looking_index, get_random_looking_string_index_with_multiple_field_ids,
        get_simple_index, get_simple_index_with_multiple_field_ids,
        get_simple_string_index_with_multiple_field_ids,
    };
    use crate::snapshot_tests::display_bitmap;

    #[test]
    fn filter_sort_descending() {
        let indexes = [
            get_simple_index(),
            get_random_looking_index(),
            get_simple_index_with_multiple_field_ids(),
        ];
        for (i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = (200..=300).collect::<RoaringBitmap>();
            let mut results = String::new();
            let db = index.content.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
            let iter = descending_facet_sort(&txn, db, 0, candidates).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            milli_snap!(results, i);

            txn.commit().unwrap();
        }
    }

    #[test]
    fn filter_sort_descending_multiple_field_ids() {
        let indexes = [
            get_simple_string_index_with_multiple_field_ids(),
            get_random_looking_string_index_with_multiple_field_ids(),
        ];
        for (i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = (200..=300).collect::<RoaringBitmap>();
            let mut results = String::new();
            let db = index.content.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
            let iter = descending_facet_sort(&txn, db, 0, candidates.clone()).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            milli_snap!(results, format!("{i}-0"));

            let mut results = String::new();

            let iter = descending_facet_sort(&txn, db, 1, candidates).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            milli_snap!(results, format!("{i}-1"));

            txn.commit().unwrap();
        }
    }

    #[test]
    fn filter_sort_ascending_with_no_candidates() {
        let indexes = [
            get_simple_string_index_with_multiple_field_ids(),
            get_random_looking_string_index_with_multiple_field_ids(),
        ];
        for (_i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = RoaringBitmap::new();
            let mut results = String::new();
            let iter = descending_facet_sort(&txn, index.content, 0, candidates.clone()).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            assert!(results.is_empty());

            let mut results = String::new();
            let iter = descending_facet_sort(&txn, index.content, 1, candidates).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            assert!(results.is_empty());

            txn.commit().unwrap();
        }
    }

    #[test]
    fn filter_sort_ascending_with_inexisting_field_id() {
        let indexes = [
            get_simple_string_index_with_multiple_field_ids(),
            get_random_looking_string_index_with_multiple_field_ids(),
        ];
        for (_i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = RoaringBitmap::new();
            let mut results = String::new();
            let iter = descending_facet_sort(&txn, index.content, 3, candidates.clone()).unwrap();
            for el in iter {
                let (docids, _) = el.unwrap();
                results.push_str(&display_bitmap(&docids));
                results.push('\n');
            }
            assert!(results.is_empty());

            txn.commit().unwrap();
        }
    }
}

@@ -116,109 +116,3 @@ pub(crate) fn get_highest_level<'t>(
    })
    .unwrap_or(0))
}

#[cfg(test)]
pub(crate) mod tests {
    use rand::{Rng, SeedableRng};
    use roaring::RoaringBitmap;

    use crate::heed_codec::facet::OrderedF64Codec;
    use crate::heed_codec::StrRefCodec;
    use crate::update::facet::test_helpers::FacetIndex;

    pub fn get_simple_index() -> FacetIndex<OrderedF64Codec> {
        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
        let mut txn = index.env.write_txn().unwrap();
        for i in 0..256u16 {
            let mut bitmap = RoaringBitmap::new();
            bitmap.insert(i as u32);
            index.insert(&mut txn, 0, &(i as f64), &bitmap);
        }
        txn.commit().unwrap();
        index
    }

    pub fn get_random_looking_index() -> FacetIndex<OrderedF64Codec> {
        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
        let mut txn = index.env.write_txn().unwrap();
        let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);

        for (_i, key) in std::iter::from_fn(|| Some(rng.gen_range(0..256))).take(128).enumerate() {
            let mut bitmap = RoaringBitmap::new();
            bitmap.insert(key);
            bitmap.insert(key + 100);
            index.insert(&mut txn, 0, &(key as f64), &bitmap);
        }
        txn.commit().unwrap();
        index
    }

    pub fn get_simple_index_with_multiple_field_ids() -> FacetIndex<OrderedF64Codec> {
        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
        let mut txn = index.env.write_txn().unwrap();
        for fid in 0..2 {
            for i in 0..256u16 {
                let mut bitmap = RoaringBitmap::new();
                bitmap.insert(i as u32);
                index.insert(&mut txn, fid, &(i as f64), &bitmap);
            }
        }
        txn.commit().unwrap();
        index
    }

    pub fn get_random_looking_index_with_multiple_field_ids() -> FacetIndex<OrderedF64Codec> {
        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
        let mut txn = index.env.write_txn().unwrap();

        let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
        let keys =
            std::iter::from_fn(|| Some(rng.gen_range(0..256))).take(128).collect::<Vec<u32>>();
        for fid in 0..2 {
            for (_i, &key) in keys.iter().enumerate() {
                let mut bitmap = RoaringBitmap::new();
                bitmap.insert(key);
                bitmap.insert(key + 100);
                index.insert(&mut txn, fid, &(key as f64), &bitmap);
            }
        }
        txn.commit().unwrap();
        index
    }

    pub fn get_simple_string_index_with_multiple_field_ids() -> FacetIndex<StrRefCodec> {
        let index = FacetIndex::<StrRefCodec>::new(4, 8, 5);
        let mut txn = index.env.write_txn().unwrap();
        for fid in 0..2 {
            for i in 0..256u16 {
                let mut bitmap = RoaringBitmap::new();
                bitmap.insert(i as u32);
                if i % 2 == 0 {
                    index.insert(&mut txn, fid, &format!("{i}").as_str(), &bitmap);
                } else {
                    index.insert(&mut txn, fid, &"", &bitmap);
                }
            }
        }
        txn.commit().unwrap();
        index
    }

    pub fn get_random_looking_string_index_with_multiple_field_ids() -> FacetIndex<StrRefCodec> {
        let index = FacetIndex::<StrRefCodec>::new(4, 8, 5);
        let mut txn = index.env.write_txn().unwrap();

        let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
        let keys =
            std::iter::from_fn(|| Some(rng.gen_range(0..256))).take(128).collect::<Vec<u32>>();
        for fid in 0..2 {
            for (_i, &key) in keys.iter().enumerate() {
                let mut bitmap = RoaringBitmap::new();
                bitmap.insert(key);
                bitmap.insert(key + 100);
                if key % 2 == 0 {
                    index.insert(&mut txn, fid, &format!("{key}").as_str(), &bitmap);
                } else {
                    index.insert(&mut txn, fid, &"", &bitmap);
                }
            }
        }
        txn.commit().unwrap();
        index
    }
}

@@ -102,7 +102,7 @@ impl ScoreWithRatioResult {
        }

        SearchResult {
-           matching_words: left.matching_words,
+           matching_words: right.matching_words,
            candidates: left.candidates | right.candidates,
            documents_ids,
            document_scores,

@@ -27,8 +27,8 @@ static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
static LEVDIST1: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(1, true));
static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true));

-/// The maximum number of facets returned by the facet search route.
-const MAX_NUMBER_OF_FACETS: usize = 100;
+/// The maximum number of values per facet returned by the facet search route.
+const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;

pub mod facet;
mod fst_utils;

@@ -306,6 +306,7 @@ pub struct SearchForFacetValues<'a> {
    query: Option<String>,
    facet: String,
    search_query: Search<'a>,
+   max_values: usize,
    is_hybrid: bool,
}

@@ -315,7 +316,13 @@ impl<'a> SearchForFacetValues<'a> {
        search_query: Search<'a>,
        is_hybrid: bool,
    ) -> SearchForFacetValues<'a> {
-       SearchForFacetValues { query: None, facet, search_query, is_hybrid }
+       SearchForFacetValues {
+           query: None,
+           facet,
+           search_query,
+           max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
+           is_hybrid,
+       }
    }

    pub fn query(&mut self, query: impl Into<String>) -> &mut Self {

@@ -323,6 +330,11 @@ impl<'a> SearchForFacetValues<'a> {
        self
    }

+   pub fn max_values(&mut self, max: usize) -> &mut Self {
+       self.max_values = max;
+       self
+   }
+
    fn one_original_value_of(
        &self,
        field_id: FieldId,

@@ -462,7 +474,7 @@ impl<'a> SearchForFacetValues<'a> {
                    .unwrap_or_else(|| left_bound.to_string());
                results.push(FacetValueHit { value, count });
            }
-           if results.len() >= MAX_NUMBER_OF_FACETS {
+           if results.len() >= self.max_values {
                break;
            }
        }

@@ -507,7 +519,7 @@ impl<'a> SearchForFacetValues<'a> {
                        .unwrap_or_else(|| query.to_string());
                    results.push(FacetValueHit { value, count });
                }
-               if results.len() >= MAX_NUMBER_OF_FACETS {
+               if results.len() >= self.max_values {
                    return Ok(ControlFlow::Break(()));
                }
            }
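A minimal usage sketch of the `max_values` builder added above (not part of this diff; the `search` value and the `"genres"` facet name are assumed for illustration):

// Hypothetical caller: cap the facet-value search at 20 results instead of
// the DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET fallback of 100.
let mut facet_search = SearchForFacetValues::new("genres".to_string(), search, false);
facet_search.query("fic").max_values(20);
let hits = facet_search.execute()?; // at most 20 FacetValueHit entries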
@@ -15,6 +15,7 @@ pub struct BucketSortOutput {

// TODO: would probably be good to regroup some of these inside of a struct?
#[allow(clippy::too_many_arguments)]
+#[logging_timer::time]
pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
    ctx: &mut SearchContext<'ctx>,
    mut ranking_rules: Vec<BoxRankingRule<'ctx, Q>>,

@@ -72,7 +72,7 @@ impl<'m> MatcherBuilder<'m> {
    }
}

-#[derive(Copy, Clone, Default)]
+#[derive(Copy, Clone, Default, Debug)]
pub struct FormatOptions {
    pub highlight: bool,
    pub crop: Option<usize>,

@@ -82,6 +82,10 @@ impl FormatOptions {
    pub fn merge(self, other: Self) -> Self {
        Self { highlight: self.highlight || other.highlight, crop: self.crop.or(other.crop) }
    }
+
+   pub fn should_format(&self) -> bool {
+       self.highlight || self.crop.is_some()
+   }
}

#[derive(Clone, Debug)]

@@ -191,6 +191,7 @@ fn resolve_maximally_reduced_query_graph(
    Ok(docids)
}

+#[logging_timer::time]
fn resolve_universe(
    ctx: &mut SearchContext,
    initial_universe: &RoaringBitmap,

@@ -556,6 +557,7 @@ pub fn execute_vector_search(
}

#[allow(clippy::too_many_arguments)]
+#[logging_timer::time]
pub fn execute_search(
    ctx: &mut SearchContext,
    query: Option<&str>,

@@ -5,6 +5,7 @@ use super::*;
use crate::{Result, SearchContext, MAX_WORD_LENGTH};

/// Convert the tokenised search query into a list of located query terms.
+#[logging_timer::time]
pub fn located_query_terms_from_tokens(
    ctx: &mut SearchContext,
    query: NormalizedTokenIter,
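A small sketch of the `FormatOptions` behaviour shown above: `merge` keeps highlighting if either side requests it and takes the first crop that is set, and the new `should_format` short-circuits formatting when neither is requested.

let a = FormatOptions { highlight: true, crop: None };
let b = FormatOptions { highlight: false, crop: Some(10) };
let merged = a.merge(b);
assert_eq!(merged.crop, Some(10));
assert!(merged.should_format()); // highlight || crop.is_some()
assert!(!FormatOptions::default().should_format());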
@@ -407,54 +407,6 @@ mod tests {
        test("large_group_small_min_level", 16, 2);
        test("odd_group_odd_min_level", 7, 3);
    }
    #[test]
    fn insert_delete_field_insert() {
        let test = |name: &str, group_size: u8, min_level_size: u8| {
            let index =
                FacetIndex::<OrderedF64Codec>::new(group_size, 0 /*NA*/, min_level_size);
            let mut wtxn = index.env.write_txn().unwrap();

            let mut elements = Vec::<((u16, f64), RoaringBitmap)>::new();
            for i in 0..100u32 {
                // field id = 0, left_bound = i, docids = [i]
                elements.push(((0, i as f64), once(i).collect()));
            }
            for i in 0..100u32 {
                // field id = 1, left_bound = i, docids = [i]
                elements.push(((1, i as f64), once(i).collect()));
            }
            index.bulk_insert(&mut wtxn, &[0, 1], elements.iter());

            index.verify_structure_validity(&wtxn, 0);
            index.verify_structure_validity(&wtxn, 1);
            // delete all the elements for the facet id 0
            for i in 0..100u32 {
                index.delete_single_docid(&mut wtxn, 0, &(i as f64), i);
            }
            index.verify_structure_validity(&wtxn, 0);
            index.verify_structure_validity(&wtxn, 1);

            let mut elements = Vec::<((u16, f64), RoaringBitmap)>::new();
            // then add some elements again for the facet id 1
            for i in 0..110u32 {
                // field id = 1, left_bound = i, docids = [i]
                elements.push(((1, i as f64), once(i).collect()));
            }
            index.verify_structure_validity(&wtxn, 0);
            index.verify_structure_validity(&wtxn, 1);
            index.bulk_insert(&mut wtxn, &[0, 1], elements.iter());

            wtxn.commit().unwrap();

            milli_snap!(format!("{index}"), name);
        };

        test("default", 4, 5);
        test("small_group_small_min_level", 2, 2);
        test("small_group_large_min_level", 2, 128);
        test("large_group_small_min_level", 16, 2);
        test("odd_group_odd_min_level", 7, 3);
    }

    #[test]
    fn bug_3165() {

(File diff suppressed because it is too large.)
@@ -72,7 +72,6 @@ two methods.
Related PR: https://github.com/meilisearch/milli/pull/619
*/

pub const FACET_MAX_GROUP_SIZE: u8 = 8;
pub const FACET_GROUP_SIZE: u8 = 4;
pub const FACET_MIN_LEVEL_SIZE: u8 = 5;

@@ -88,17 +87,14 @@ use heed::BytesEncode;
use log::debug;
use time::OffsetDateTime;

-use self::incremental::FacetsUpdateIncremental;
use super::FacetsUpdateBulk;
use crate::facet::FacetType;
-use crate::heed_codec::facet::{FacetGroupKey, FacetGroupKeyCodec, FacetGroupValueCodec};
-use crate::heed_codec::BytesRefCodec;
+use crate::heed_codec::facet::FacetGroupKey;
use crate::update::index_documents::create_sorter;
use crate::update::merge_btreeset_string;
use crate::{BEU16StrCodec, Index, Result, MAX_FACET_VALUE_LENGTH};

pub mod bulk;
pub mod incremental;

/// A builder used to add new elements to the `facet_id_string_docids` or `facet_id_f64_docids` databases.
///

@@ -106,11 +102,9 @@ pub mod incremental;
/// a bulk update method or an incremental update method.
pub struct FacetsUpdate<'i> {
    index: &'i Index,
-   database: heed::Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
    facet_type: FacetType,
    delta_data: grenad::Reader<BufReader<File>>,
    group_size: u8,
-   max_group_size: u8,
    min_level_size: u8,
}
impl<'i> FacetsUpdate<'i> {

@@ -119,19 +113,9 @@ impl<'i> FacetsUpdate<'i> {
        facet_type: FacetType,
        delta_data: grenad::Reader<BufReader<File>>,
    ) -> Self {
-       let database = match facet_type {
-           FacetType::String => {
-               index.facet_id_string_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>()
-           }
-           FacetType::Number => {
-               index.facet_id_f64_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>()
-           }
-       };
        Self {
            index,
-           database,
            group_size: FACET_GROUP_SIZE,
-           max_group_size: FACET_MAX_GROUP_SIZE,
            min_level_size: FACET_MIN_LEVEL_SIZE,
            facet_type,
            delta_data,

@@ -145,30 +129,16 @@ impl<'i> FacetsUpdate<'i> {
        debug!("Computing and writing the facet values levels docids into LMDB on disk...");
        self.index.set_updated_at(wtxn, &OffsetDateTime::now_utc())?;

-       // See self::comparison_bench::benchmark_facet_indexing
-       if self.delta_data.len() >= (self.database.len(wtxn)? / 50) {
-           let field_ids =
-               self.index.faceted_fields_ids(wtxn)?.iter().copied().collect::<Vec<_>>();
-           let bulk_update = FacetsUpdateBulk::new(
-               self.index,
-               field_ids,
-               self.facet_type,
-               self.delta_data,
-               self.group_size,
-               self.min_level_size,
-           );
-           bulk_update.execute(wtxn)?;
-       } else {
-           let incremental_update = FacetsUpdateIncremental::new(
-               self.index,
-               self.facet_type,
-               self.delta_data,
-               self.group_size,
-               self.min_level_size,
-               self.max_group_size,
-           );
-           incremental_update.execute(wtxn)?;
-       }
+       let field_ids = self.index.faceted_fields_ids(wtxn)?.iter().copied().collect::<Vec<_>>();
+       let bulk_update = FacetsUpdateBulk::new(
+           self.index,
+           field_ids,
+           self.facet_type,
+           self.delta_data,
+           self.group_size,
+           self.min_level_size,
+       );
+       bulk_update.execute(wtxn)?;

        // We clear the list of normalized-for-search facets
        // and the previous FSTs to compute everything from scratch
@@ -264,7 +234,6 @@ impl<'i> FacetsUpdate<'i> {
pub(crate) mod test_helpers {
    use std::cell::Cell;
    use std::fmt::Display;
    use std::iter::FromIterator;
    use std::marker::PhantomData;
    use std::rc::Rc;

@@ -280,7 +249,6 @@ pub(crate) mod test_helpers {
    use crate::search::facet::get_highest_level;
    use crate::snapshot_tests::display_bitmap;
    use crate::update::del_add::{DelAdd, KvWriterDelAdd};
    use crate::update::FacetsUpdateIncrementalInner;
    use crate::CboRoaringBitmapCodec;

    /// Utility function to generate a string whose position in a lexicographically

@@ -396,49 +364,6 @@ pub(crate) mod test_helpers {
        self.min_level_size.set(std::cmp::max(1, min_level_size));
    }

    pub fn insert<'a>(
        &self,
        wtxn: &'a mut RwTxn,
        field_id: u16,
        key: &'a <BoundCodec as BytesEncode<'a>>::EItem,
        docids: &RoaringBitmap,
    ) {
        let update = FacetsUpdateIncrementalInner {
            db: self.content,
            group_size: self.group_size.get(),
            min_level_size: self.min_level_size.get(),
            max_group_size: self.max_group_size.get(),
        };
        let key_bytes = BoundCodec::bytes_encode(key).unwrap();
        update.insert(wtxn, field_id, &key_bytes, docids).unwrap();
    }
    pub fn delete_single_docid<'a>(
        &self,
        wtxn: &'a mut RwTxn,
        field_id: u16,
        key: &'a <BoundCodec as BytesEncode<'a>>::EItem,
        docid: u32,
    ) {
        self.delete(wtxn, field_id, key, &RoaringBitmap::from_iter(std::iter::once(docid)))
    }

    pub fn delete<'a>(
        &self,
        wtxn: &'a mut RwTxn,
        field_id: u16,
        key: &'a <BoundCodec as BytesEncode<'a>>::EItem,
        docids: &RoaringBitmap,
    ) {
        let update = FacetsUpdateIncrementalInner {
            db: self.content,
            group_size: self.group_size.get(),
            min_level_size: self.min_level_size.get(),
            max_group_size: self.max_group_size.get(),
        };
        let key_bytes = BoundCodec::bytes_encode(key).unwrap();
        update.delete(wtxn, field_id, &key_bytes, docids).unwrap();
    }

    pub fn bulk_insert<'a, 'b>(
        &self,
        wtxn: &'a mut RwTxn,

@@ -555,63 +480,3 @@ pub(crate) mod test_helpers {
        }
    }
}

#[allow(unused)]
#[cfg(test)]
mod comparison_bench {
    use std::iter::once;

    use rand::Rng;
    use roaring::RoaringBitmap;

    use super::test_helpers::FacetIndex;
    use crate::heed_codec::facet::OrderedF64Codec;

    // This is a simple test to get an intuition on the relative speed
    // of the incremental vs. bulk indexer.
    //
    // The benchmark shows the worst-case scenario for the incremental indexer, since
    // each facet value contains only one document ID.
    //
    // In that scenario, it appears that the incremental indexer is about 50 times slower than the
    // bulk indexer.
    // #[test]
    fn benchmark_facet_indexing() {
        let mut facet_value = 0;

        let mut r = rand::thread_rng();

        for i in 1..=20 {
            let size = 50_000 * i;
            let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);

            let mut txn = index.env.write_txn().unwrap();
            let mut elements = Vec::<((u16, f64), RoaringBitmap)>::new();
            for i in 0..size {
                // field id = 0, left_bound = i, docids = [i]
                elements.push(((0, facet_value as f64), once(i).collect()));
                facet_value += 1;
            }
            let timer = std::time::Instant::now();
            index.bulk_insert(&mut txn, &[0], elements.iter());
            let time_spent = timer.elapsed().as_millis();
            println!("bulk {size} : {time_spent}ms");

            txn.commit().unwrap();

            for nbr_doc in [1, 100, 1000, 10_000] {
                let mut txn = index.env.write_txn().unwrap();
                let timer = std::time::Instant::now();
                //
                // insert one document
                //
                for _ in 0..nbr_doc {
                    index.insert(&mut txn, 0, &r.gen(), &once(1).collect());
                }
                let time_spent = timer.elapsed().as_millis();
                println!("    add {nbr_doc} : {time_spent}ms");
                txn.abort();
            }
        }
    }
}

@@ -34,7 +34,9 @@ pub fn extract_geo_points<R: io::Read + io::Seek>(
    // since we only need the primary key when we throw an error
    // we create this getter to lazily get it when needed
    let document_id = || -> Value {
-       let document_id = obkv.get(primary_key_id).unwrap();
+       let reader = KvReaderDelAdd::new(obkv.get(primary_key_id).unwrap());
+       let document_id =
+           reader.get(DelAdd::Deletion).or(reader.get(DelAdd::Addition)).unwrap();
        serde_json::from_slice(document_id).unwrap()
    };

@@ -339,9 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
    indexer: GrenadParameters,
    embedder: Arc<Embedder>,
) -> Result<grenad::Reader<BufReader<File>>> {
-   let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
-
-   let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
+   let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
    let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk

    // docid, state with embedding

@@ -375,11 +373,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
        current_chunk_ids.push(docid);

        if chunks.len() == chunks.capacity() {
-           let chunked_embeds = rt
-               .block_on(
-                   embedder
-                       .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
-               )
+           let chunked_embeds = embedder
+               .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
                .map_err(crate::vector::Error::from)
                .map_err(crate::Error::from)?;

@@ -396,8 +391,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(

    // send last chunk
    if !chunks.is_empty() {
-       let chunked_embeds = rt
-           .block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
+       let chunked_embeds = embedder
+           .embed_chunks(std::mem::take(&mut chunks))
            .map_err(crate::vector::Error::from)
            .map_err(crate::Error::from)?;
        for (docid, embeddings) in chunks_ids

@@ -410,13 +405,15 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
    }

    if !current_chunk.is_empty() {
-       let embeds = rt
-           .block_on(embedder.embed(std::mem::take(&mut current_chunk)))
+       let embeds = embedder
+           .embed_chunks(vec![std::mem::take(&mut current_chunk)])
            .map_err(crate::vector::Error::from)
            .map_err(crate::Error::from)?;

-       for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
-           state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+       if let Some(embeds) = embeds.first() {
+           for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
+               state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+           }
        }
    }

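In miniature, the chunking pattern used by `extract_embeddings` above: buffer prompts into fixed-capacity chunks and flush each full chunk through the now-synchronous `embed_chunks` (a sketch only; `embedder` and `prompts` are assumed to exist, and the state-writer plumbing is elided):

let n_chunks = embedder.chunk_count_hint();
let mut chunks: Vec<Vec<String>> = Vec::with_capacity(n_chunks);
for prompt_batch in prompts {
    chunks.push(prompt_batch);
    if chunks.len() == chunks.capacity() {
        // swap in an empty buffer and embed the full one
        let embeds = embedder
            .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))?;
        // ...write `embeds` out here...
    }
}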
@@ -16,7 +16,7 @@ pub use merge_functions::{
    keep_first, keep_latest_obkv, merge_btreeset_string, merge_cbo_roaring_bitmaps,
    merge_deladd_cbo_roaring_bitmaps, merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap,
-   merge_roaring_bitmaps, obkvs_keep_last_addition_merge_deletions,
-   obkvs_merge_additions_and_deletions, serialize_roaring_bitmap, MergeFn,
+   obkvs_merge_additions_and_deletions, MergeFn,
};

use crate::MAX_WORD_LENGTH;

@@ -21,7 +21,7 @@ use slice_group_by::GroupBy;
use typed_chunk::{write_typed_chunk_into_index, TypedChunk};

use self::enrich::enrich_documents_batch;
-pub use self::enrich::{extract_finite_float_from_value, validate_geo_from_json, DocumentId};
+pub use self::enrich::{extract_finite_float_from_value, DocumentId};
pub use self::helpers::{
    as_cloneable_grenad, create_sorter, create_writer, fst_stream_into_hashset,
    fst_stream_into_vec, merge_btreeset_string, merge_cbo_roaring_bitmaps,

@@ -2553,7 +2553,7 @@ mod tests {
    /// Vectors must be of the same length.
    #[test]
    fn test_multiple_vectors() {
-       use crate::vector::settings::{EmbedderSettings, EmbeddingSettings};
+       use crate::vector::settings::EmbeddingSettings;
        let index = TempIndex::new();

        index

@@ -2562,9 +2562,11 @@ mod tests {
                embedders.insert(
                    "manual".to_string(),
                    Setting::Set(EmbeddingSettings {
-                       embedder_options: Setting::Set(EmbedderSettings::UserProvided(
-                           crate::vector::settings::UserProvidedSettings { dimensions: 3 },
-                       )),
+                       source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided),
+                       model: Setting::NotSet,
+                       revision: Setting::NotSet,
+                       api_key: Setting::NotSet,
+                       dimensions: Setting::Set(3),
                        document_template: Setting::NotSet,
                    }),
                );

@@ -2579,10 +2581,10 @@ mod tests {
            .unwrap();
        index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap();
        index
-               .add_documents(
-                   documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]),
-               )
-               .unwrap();
+           .add_documents(
+               documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]),
+           )
+           .unwrap();

        let rtxn = index.read_txn().unwrap();
        let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap();

@@ -1,14 +1,13 @@
pub use self::available_documents_ids::AvailableDocumentsIds;
pub use self::clear_documents::ClearDocuments;
pub use self::facet::bulk::FacetsUpdateBulk;
pub use self::facet::incremental::FacetsUpdateIncrementalInner;
pub use self::index_documents::{
    merge_btreeset_string, merge_cbo_roaring_bitmaps, merge_roaring_bitmaps,
    DocumentAdditionResult, DocumentId, IndexDocuments, IndexDocumentsConfig, IndexDocumentsMethod,
    MergeFn,
};
pub use self::indexer_config::IndexerConfig;
-pub use self::settings::{Setting, Settings};
+pub use self::settings::{validate_embedding_settings, Setting, Settings};
pub use self::update_step::UpdateIndexingStep;
pub use self::word_prefix_docids::WordPrefixDocids;
pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids;

@@ -17,7 +17,7 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
-use crate::vector::settings::{EmbeddingSettings, PromptSettings};
+use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{FieldsIdsMap, Index, OrderBy, Result};

@@ -78,11 +78,19 @@ impl<T> Setting<T> {
    }
}

-   pub fn apply(&mut self, new: Self) {
+   /// Returns `true` if applying the new setting changed this setting
+   pub fn apply(&mut self, new: Self) -> bool
+   where
+       T: PartialEq + Eq,
+   {
        if let Setting::NotSet = new {
-           return;
+           return false;
        }
+       if self == &new {
+           return false;
+       }
        *self = new;
+       true
    }
}

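The new `apply` contract in isolation: it now reports whether anything actually changed, so callers can accumulate a dirty flag instead of unconditionally assuming a change (a sketch):

let mut model: Setting<String> = Setting::Set("text-embedding-ada-002".to_string());
let mut changed = false;
changed |= model.apply(Setting::NotSet); // NotSet is ignored: false
changed |= model.apply(Setting::Set("text-embedding-ada-002".to_string())); // equal: false
assert!(!changed);
changed |= model.apply(Setting::Set("other-model".to_string())); // differs: true
assert!(changed);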
@@ -950,17 +958,23 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
                .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right))
            {
                match joined {
                    // updated config
                    EitherOrBoth::Both((name, mut old), (_, new)) => {
-                       old.apply(new);
-                       let new = validate_prompt(&name, old)?;
-                       changed = true;
+                       changed |= old.apply(new);
+                       let new = validate_embedding_settings(old, &name)?;
                        new_configs.insert(name, new);
                    }
                    // unchanged config
                    EitherOrBoth::Left((name, setting)) => {
                        new_configs.insert(name, setting);
                    }
-                   EitherOrBoth::Right((name, setting)) => {
-                       let setting = validate_prompt(&name, setting)?;
+                   // new config
+                   EitherOrBoth::Right((name, mut setting)) => {
+                       // apply the default source in case the source was not set so that it gets validated
+                       crate::vector::settings::EmbeddingSettings::apply_default_source(
+                           &mut setting,
+                       );
+                       let setting = validate_embedding_settings(setting, &name)?;
                        changed = true;
                        new_configs.insert(name, setting);
                    }

@@ -1072,8 +1086,12 @@ fn validate_prompt(
) -> Result<Setting<EmbeddingSettings>> {
    match new {
        Setting::Set(EmbeddingSettings {
-           embedder_options,
-           document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }),
+           source,
+           model,
+           revision,
+           api_key,
+           dimensions,
+           document_template: Setting::Set(template),
        }) => {
            // validate
            let template = crate::prompt::Prompt::new(template)
@@ -1081,16 +1099,71 @@ fn validate_prompt(
                .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?;

            Ok(Setting::Set(EmbeddingSettings {
-               embedder_options,
-               document_template: Setting::Set(PromptSettings {
-                   template: Setting::Set(template),
-               }),
+               source,
+               model,
+               revision,
+               api_key,
+               dimensions,
+               document_template: Setting::Set(template),
            }))
        }
        new => Ok(new),
    }
}

pub fn validate_embedding_settings(
    settings: Setting<EmbeddingSettings>,
    name: &str,
) -> Result<Setting<EmbeddingSettings>> {
    let settings = validate_prompt(name, settings)?;
    let Setting::Set(settings) = settings else { return Ok(settings) };
    let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
        settings;
    let Some(inferred_source) = source.set() else {
        return Ok(Setting::Set(EmbeddingSettings {
            source,
            model,
            revision,
            api_key,
            dimensions,
            document_template,
        }));
    };
    match inferred_source {
        EmbedderSource::OpenAi => {
            check_unset(&revision, "revision", inferred_source, name)?;
            check_unset(&dimensions, "dimensions", inferred_source, name)?;
            if let Setting::Set(model) = &model {
                crate::vector::openai::EmbeddingModel::from_name(model.as_str()).ok_or(
                    crate::error::UserError::InvalidOpenAiModel {
                        embedder_name: name.to_owned(),
                        model: model.clone(),
                    },
                )?;
            }
        }
        EmbedderSource::HuggingFace => {
            check_unset(&api_key, "apiKey", inferred_source, name)?;
            check_unset(&dimensions, "dimensions", inferred_source, name)?;
        }
        EmbedderSource::UserProvided => {
            check_unset(&model, "model", inferred_source, name)?;
            check_unset(&revision, "revision", inferred_source, name)?;
            check_unset(&api_key, "apiKey", inferred_source, name)?;
            check_unset(&document_template, "documentTemplate", inferred_source, name)?;
            check_set(&dimensions, "dimensions", inferred_source, name)?;
        }
    }
    Ok(Setting::Set(EmbeddingSettings {
        source,
        model,
        revision,
        api_key,
        dimensions,
        document_template,
    }))
}

#[cfg(test)]
mod tests {
    use big_s::S;

@@ -67,6 +67,10 @@ pub enum EmbedErrorKind {
    OpenAiUnhandledStatusCode(u16),
    #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
    ManualEmbed(String),
+   #[error("could not initialize asynchronous runtime: {0}")]
+   OpenAiRuntimeInit(std::io::Error),
+   #[error("initializing web client for sending embedding requests failed: {0}")]
+   InitWebClient(reqwest::Error),
}

impl EmbedError {

@@ -117,6 +121,14 @@ impl EmbedError {
    pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
        Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
    }
+
+   pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
+       Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
+   }
+
+   pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
+       Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
+   }
}

#[derive(Debug, thiserror::Error)]

@@ -183,10 +195,6 @@ impl NewEmbedderError {
        }
    }

-   pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
-       Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
-   }
-
    pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
        Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
    }

@@ -237,8 +245,6 @@ pub enum NewEmbedderErrorKind {
    #[error("loading model failed: {0}")]
    LoadModel(candle_core::Error),
    // openai
-   #[error("initializing web client for sending embedding requests failed: {0}")]
-   InitWebClient(reqwest::Error),
    #[error("The API key passed to Authorization error was in an invalid format: {0}")]
    InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
}

@@ -145,7 +145,8 @@ impl Embedder {
        let token_ids = tokens
            .iter()
            .map(|tokens| {
-               let tokens = tokens.get_ids().to_vec();
+               let mut tokens = tokens.get_ids().to_vec();
+               tokens.truncate(512);
                Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
            })
            .collect::<Result<Vec<_>, EmbedError>>()?;

@@ -163,18 +163,24 @@ impl Embedder {
    ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
        match self {
            Embedder::HuggingFace(embedder) => embedder.embed(texts),
-           Embedder::OpenAi(embedder) => embedder.embed(texts).await,
+           Embedder::OpenAi(embedder) => {
+               let client = embedder.new_client()?;
+               embedder.embed(texts, &client).await
+           }
            Embedder::UserProvided(embedder) => embedder.embed(texts),
        }
    }

-   pub async fn embed_chunks(
+   /// # Panics
+   ///
+   /// - if called from an asynchronous context
+   pub fn embed_chunks(
        &self,
        text_chunks: Vec<Vec<String>>,
    ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
        match self {
            Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
-           Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
+           Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
            Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
        }
    }

@@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings};

#[derive(Debug)]
pub struct Embedder {
-   client: reqwest::Client,
+   headers: reqwest::header::HeaderMap,
    tokenizer: tiktoken_rs::CoreBPE,
    options: EmbedderOptions,
}

@@ -34,6 +34,9 @@ pub struct EmbedderOptions {
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
+   // # WARNING
+   //
+   // If ever adding a model, make sure to add it to the list of supported models below.
    #[default]
    #[serde(rename = "text-embedding-ada-002")]
    #[deserr(rename = "text-embedding-ada-002")]

@@ -41,6 +44,10 @@ pub enum EmbeddingModel {
}

impl EmbeddingModel {
+   pub fn supported_models() -> &'static [&'static str] {
+       &["text-embedding-ada-002"]
+   }
+
    pub fn max_token(&self) -> usize {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => 8191,

@@ -59,7 +66,7 @@ impl EmbeddingModel {
        }
    }

-   pub fn from_name(name: &'static str) -> Option<Self> {
+   pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
            _ => None,

@@ -88,6 +95,13 @@ impl EmbedderOptions {
}

impl Embedder {
+   pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
+       reqwest::ClientBuilder::new()
+           .default_headers(self.headers.clone())
+           .build()
+           .map_err(EmbedError::openai_initialize_web_client)
+   }
+
    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
        let mut headers = reqwest::header::HeaderMap::new();
        let mut inferred_api_key = Default::default();
@@ -104,25 +118,25 @@ impl Embedder {
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );
-       let client = reqwest::ClientBuilder::new()
-           .default_headers(headers)
-           .build()
-           .map_err(NewEmbedderError::openai_initialize_web_client)?;

        // looking at the code it is very unclear that this can actually fail.
        let tokenizer = tiktoken_rs::cl100k_base().unwrap();

-       Ok(Self { options, client, tokenizer })
+       Ok(Self { options, headers, tokenizer })
    }

-   pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+   pub async fn embed(
+       &self,
+       texts: Vec<String>,
+       client: &reqwest::Client,
+   ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
        let mut tokenized = false;

        for attempt in 0..7 {
            let result = if tokenized {
-               self.try_embed_tokenized(&texts).await
+               self.try_embed_tokenized(&texts, client).await
            } else {
-               self.try_embed(&texts).await
+               self.try_embed(&texts, client).await
            };

            let retry_duration = match result {

@@ -138,9 +152,9 @@ impl Embedder {
        }

        let result = if tokenized {
-           self.try_embed_tokenized(&texts).await
+           self.try_embed_tokenized(&texts, client).await
        } else {
-           self.try_embed(&texts).await
+           self.try_embed(&texts, client).await
        };

        result.map_err(Retry::into_error)

@@ -218,13 +232,13 @@ impl Embedder {
    async fn try_embed<S: AsRef<str> + serde::Serialize>(
        &self,
        texts: &[S],
+       client: &reqwest::Client,
    ) -> Result<Vec<Embeddings<f32>>, Retry> {
        for text in texts {
            log::trace!("Received prompt: {}", text.as_ref())
        }
        let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
-       let response = self
-           .client
+       let response = client
            .post(OPENAI_EMBEDDINGS_URL)
            .json(&request)
            .send()

@@ -249,7 +263,11 @@ impl Embedder {
        .collect())
    }

-   async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
+   async fn try_embed_tokenized(
+       &self,
+       text: &[String],
+       client: &reqwest::Client,
+   ) -> Result<Vec<Embeddings<f32>>, Retry> {
        pub const OVERLAP_SIZE: usize = 200;
        let mut all_embeddings = Vec::with_capacity(text.len());
        for text in text {

@@ -257,7 +275,7 @@ impl Embedder {
            let encoded = self.tokenizer.encode_ordinary(text.as_str());
            let len = encoded.len();
            if len < max_token_count {
-               all_embeddings.append(&mut self.try_embed(&[text]).await?);
+               all_embeddings.append(&mut self.try_embed(&[text], client).await?);
                continue;
            }

@@ -266,22 +284,26 @@ impl Embedder {
                Embeddings::new(self.options.embedding_model.dimensions());
            while tokens.len() > max_token_count {
                let window = &tokens[..max_token_count];
-               embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
+               embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();

                tokens = &tokens[max_token_count - OVERLAP_SIZE..];
            }

            // end of text
-           embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
+           embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();

            all_embeddings.push(embeddings_for_prompt);
        }
        Ok(all_embeddings)
    }

-   async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+   async fn embed_tokens(
+       &self,
+       tokens: &[usize],
+       client: &reqwest::Client,
+   ) -> Result<Embedding, Retry> {
        for attempt in 0..9 {
-           let duration = match self.try_embed_tokens(tokens).await {
+           let duration = match self.try_embed_tokens(tokens, client).await {
                Ok(embedding) => return Ok(embedding),
                Err(retry) => retry.into_duration(attempt),
            }

@@ -290,14 +312,19 @@ impl Embedder {
            tokio::time::sleep(duration).await;
        }

-       self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
+       self.try_embed_tokens(tokens, client)
+           .await
+           .map_err(|retry| Retry::give_up(retry.into_error()))
    }

-   async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+   async fn try_embed_tokens(
+       &self,
+       tokens: &[usize],
+       client: &reqwest::Client,
+   ) -> Result<Embedding, Retry> {
        let request =
            OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
-       let response = self
-           .client
+       let response = client
            .post(OPENAI_EMBEDDINGS_URL)
            .json(&request)
            .send()

@@ -315,12 +342,19 @@ impl Embedder {
        Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
    }

-   pub async fn embed_chunks(
+   pub fn embed_chunks(
        &self,
        text_chunks: Vec<Vec<String>>,
    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-       futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
-           .await
+       let rt = tokio::runtime::Builder::new_current_thread()
+           .enable_io()
+           .enable_time()
+           .build()
+           .map_err(EmbedError::openai_runtime_init)?;
+       let client = self.new_client()?;
+       rt.block_on(futures::future::try_join_all(
+           text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
+       ))
    }

    pub fn chunk_count_hint(&self) -> usize {
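The synchronous wrapper pattern adopted above, reduced to its core: build a current-thread Tokio runtime on demand and block on the async requests, which is why the `# Panics` note warns against calling it from an async context (a sketch; `do_requests` is a hypothetical stand-in for the real futures):

// Hypothetical helper illustrating the block_on pattern.
fn embed_blocking() -> Result<Vec<Embedding>, EmbedError> {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_io()
        .enable_time()
        .build()
        .map_err(EmbedError::openai_runtime_init)?;
    rt.block_on(async { do_requests().await })
}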
@@ -4,32 +4,189 @@ use serde::{Deserialize, Serialize};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
+use crate::UserError;

#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct EmbeddingSettings {
-   #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")]
-   #[deserr(default, rename = "source")]
-   pub embedder_options: Setting<EmbedderSettings>,
-   #[serde(default, skip_serializing_if = "Setting::is_not_set")]
-   #[deserr(default)]
-   pub document_template: Setting<PromptSettings>,
+   #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+   #[deserr(default)]
+   pub source: Setting<EmbedderSource>,
+   #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+   #[deserr(default)]
+   pub model: Setting<String>,
+   #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+   #[deserr(default)]
+   pub revision: Setting<String>,
+   #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+   #[deserr(default)]
+   pub api_key: Setting<String>,
+   #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+   #[deserr(default)]
+   pub dimensions: Setting<usize>,
+   #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+   #[deserr(default)]
+   pub document_template: Setting<String>,
}

pub fn check_unset<T>(
    key: &Setting<T>,
    field: &'static str,
    source: EmbedderSource,
    embedder_name: &str,
) -> Result<(), UserError> {
    if matches!(key, Setting::NotSet) {
        Ok(())
    } else {
        Err(UserError::InvalidFieldForSource {
            embedder_name: embedder_name.to_owned(),
            source_: source,
            field,
            allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source),
            allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field),
        })
    }
}

pub fn check_set<T>(
    key: &Setting<T>,
    field: &'static str,
    source: EmbedderSource,
    embedder_name: &str,
) -> Result<(), UserError> {
    if matches!(key, Setting::Set(_)) {
        Ok(())
    } else {
        Err(UserError::MissingFieldForSource {
            field,
            source_: source,
            embedder_name: embedder_name.to_owned(),
        })
    }
}

impl EmbeddingSettings {
    pub const SOURCE: &'static str = "source";
    pub const MODEL: &'static str = "model";
    pub const REVISION: &'static str = "revision";
    pub const API_KEY: &'static str = "apiKey";
    pub const DIMENSIONS: &'static str = "dimensions";
    pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";

    pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
        match field {
            Self::SOURCE => {
                &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
            }
            Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
            Self::REVISION => &[EmbedderSource::HuggingFace],
            Self::API_KEY => &[EmbedderSource::OpenAi],
            Self::DIMENSIONS => &[EmbedderSource::UserProvided],
            Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
            _other => unreachable!("unknown field"),
        }
    }

    pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
        match source {
            EmbedderSource::OpenAi => {
                &[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE]
            }
            EmbedderSource::HuggingFace => {
                &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
            }
            EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
        }
    }

    pub(crate) fn apply_default_source(setting: &mut Setting<EmbeddingSettings>) {
        if let Setting::Set(EmbeddingSettings {
            source: source @ (Setting::NotSet | Setting::Reset),
            ..
        }) = setting
        {
            *source = Setting::Set(EmbedderSource::default())
        }
    }
}

#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbedderSource {
    #[default]
    OpenAi,
    HuggingFace,
    UserProvided,
}

impl std::fmt::Display for EmbedderSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            EmbedderSource::OpenAi => "openAi",
            EmbedderSource::HuggingFace => "huggingFace",
            EmbedderSource::UserProvided => "userProvided",
        };
        f.write_str(s)
    }
}

impl EmbeddingSettings {
    pub fn apply(&mut self, new: Self) {
-       let EmbeddingSettings { embedder_options, document_template: prompt } = new;
-       self.embedder_options.apply(embedder_options);
-       self.document_template.apply(prompt);
+       let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
+           new;
+       let old_source = self.source;
+       self.source.apply(source);
+       // Reinitialize the whole setting object on a source change
+       if old_source != self.source {
+           *self = EmbeddingSettings {
+               source,
+               model,
+               revision,
+               api_key,
+               dimensions,
+               document_template,
+           };
+           return;
+       }
+
+       self.model.apply(model);
+       self.revision.apply(revision);
+       self.api_key.apply(api_key);
+       self.dimensions.apply(dimensions);
+       self.document_template.apply(document_template);
    }
}

impl From<EmbeddingConfig> for EmbeddingSettings {
    fn from(value: EmbeddingConfig) -> Self {
-       Self {
-           embedder_options: Setting::Set(value.embedder_options.into()),
-           document_template: Setting::Set(value.prompt.into()),
+       let EmbeddingConfig { embedder_options, prompt } = value;
+       match embedder_options {
+           super::EmbedderOptions::HuggingFace(options) => Self {
+               source: Setting::Set(EmbedderSource::HuggingFace),
+               model: Setting::Set(options.model),
+               revision: options.revision.map(Setting::Set).unwrap_or_default(),
+               api_key: Setting::NotSet,
+               dimensions: Setting::NotSet,
+               document_template: Setting::Set(prompt.template),
+           },
+           super::EmbedderOptions::OpenAi(options) => Self {
+               source: Setting::Set(EmbedderSource::OpenAi),
+               model: Setting::Set(options.embedding_model.name().to_owned()),
+               revision: Setting::NotSet,
+               api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
+               dimensions: Setting::NotSet,
+               document_template: Setting::Set(prompt.template),
+           },
+           super::EmbedderOptions::UserProvided(options) => Self {
+               source: Setting::Set(EmbedderSource::UserProvided),
+               model: Setting::NotSet,
+               revision: Setting::NotSet,
+               api_key: Setting::NotSet,
+               dimensions: Setting::Set(options.dimensions),
+               document_template: Setting::NotSet,
+           },
+       }
    }
}
@ -37,256 +194,51 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
impl From<EmbeddingSettings> for EmbeddingConfig {
    fn from(value: EmbeddingSettings) -> Self {
        let mut this = Self::default();
        let EmbeddingSettings { embedder_options, document_template: prompt } = value;
        if let Some(embedder_options) = embedder_options.set() {
            this.embedder_options = embedder_options.into();
        }
        if let Some(prompt) = prompt.set() {
            this.prompt = prompt.into();
        }
        this
    }
}

impl From<EmbeddingSettings> for EmbeddingConfig {
    fn from(value: EmbeddingSettings) -> Self {
        let mut this = Self::default();
        let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
            value;
        if let Some(source) = source.set() {
            match source {
                EmbedderSource::OpenAi => {
                    let mut options = super::openai::EmbedderOptions::with_default_model(None);
                    if let Some(model) = model.set() {
                        if let Some(model) = super::openai::EmbeddingModel::from_name(&model) {
                            options.embedding_model = model;
                        }
                    }
                    if let Some(api_key) = api_key.set() {
                        options.api_key = Some(api_key);
                    }
                    this.embedder_options = super::EmbedderOptions::OpenAi(options);
                }
                EmbedderSource::HuggingFace => {
                    let mut options = super::hf::EmbedderOptions::default();
                    if let Some(model) = model.set() {
                        options.model = model;
                        // Reset the revision if we are setting the model.
                        // This allows the following:
                        // "huggingFace": {} -> default model with default revision
                        // "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision
                        // "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision
                        options.revision = None;
                    }
                    if let Some(revision) = revision.set() {
                        options.revision = Some(revision);
                    }
                    this.embedder_options = super::EmbedderOptions::HuggingFace(options);
                }
                EmbedderSource::UserProvided => {
                    this.embedder_options =
                        super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
                            dimensions: dimensions.set().unwrap(),
                        });
                }
            }
        }

        if let Setting::Set(template) = document_template {
            this.prompt = PromptData { template }
        }

        this
    }
}
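A usage sketch for the new conversion (the settings value is constructed inline for illustration): a userProvided embedder relies on `dimensions` being set, which the field checks above are expected to guarantee, since the conversion unwraps it.

let settings = EmbeddingSettings {
    source: Setting::Set(EmbedderSource::UserProvided),
    model: Setting::NotSet,
    revision: Setting::NotSet,
    api_key: Setting::NotSet,
    dimensions: Setting::Set(512),
    document_template: Setting::NotSet,
};
// Converts to EmbedderOptions::UserProvided { dimensions: 512 };
// would panic if `dimensions` were NotSet, hence the validation upstream.
let config: EmbeddingConfig = settings.into();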
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct PromptSettings {
    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
    #[deserr(default)]
    pub template: Setting<String>,
}

impl PromptSettings {
    pub fn apply(&mut self, new: Self) {
        let PromptSettings { template } = new;
        self.template.apply(template);
    }
}

impl From<PromptData> for PromptSettings {
    fn from(value: PromptData) -> Self {
        Self { template: Setting::Set(value.template) }
    }
}

impl From<PromptSettings> for PromptData {
    fn from(value: PromptSettings) -> Self {
        let mut this = PromptData::default();
        let PromptSettings { template } = value;
        if let Some(template) = template.set() {
            this.template = template;
        }
        this
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub enum EmbedderSettings {
    HuggingFace(Setting<HfEmbedderSettings>),
    OpenAi(Setting<OpenAiEmbedderSettings>),
    UserProvided(UserProvidedSettings),
}
impl<E> Deserr<E> for EmbedderSettings
where
    E: deserr::DeserializeError,
{
    fn deserialize_from_value<V: deserr::IntoValue>(
        value: deserr::Value<V>,
        location: deserr::ValuePointerRef,
    ) -> Result<Self, E> {
        match value {
            deserr::Value::Map(map) => {
                if deserr::Map::len(&map) != 1 {
                    return Err(deserr::take_cf_content(E::error::<V>(
                        None,
                        deserr::ErrorKind::Unexpected {
                            msg: format!(
                                "Expected a single field, got {} fields",
                                deserr::Map::len(&map)
                            ),
                        },
                        location,
                    )));
                }
                let mut it = deserr::Map::into_iter(map);
                let (k, v) = it.next().unwrap();

                match k.as_str() {
                    "huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set(
                        HfEmbedderSettings::deserialize_from_value(
                            v.into_value(),
                            location.push_key(&k),
                        )?,
                    ))),
                    "openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set(
                        OpenAiEmbedderSettings::deserialize_from_value(
                            v.into_value(),
                            location.push_key(&k),
                        )?,
                    ))),
                    "userProvided" => Ok(EmbedderSettings::UserProvided(
                        UserProvidedSettings::deserialize_from_value(
                            v.into_value(),
                            location.push_key(&k),
                        )?,
                    )),
                    other => Err(deserr::take_cf_content(E::error::<V>(
                        None,
                        deserr::ErrorKind::UnknownKey {
                            key: other,
                            accepted: &["huggingFace", "openAi", "userProvided"],
                        },
                        location,
                    ))),
                }
            }
            _ => Err(deserr::take_cf_content(E::error::<V>(
                None,
                deserr::ErrorKind::IncorrectValueKind {
                    actual: value,
                    accepted: &[deserr::ValueKind::Map],
                },
                location,
            ))),
        }
    }
}
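For reference, a sketch of the one-variant map shape this deserializer enforced, next to the flattened shape accepted by the new EmbeddingSettings; field values are illustrative:

use serde_json::json;

// Old, tagged shape: exactly one of "huggingFace" | "openAi" | "userProvided".
let _old = json!({ "openAi": { "apiKey": "sk-...", "model": "text-embedding-ada-002" } });
// New, flattened shape: the variant moves into an explicit "source" field.
let _new = json!({ "source": "openAi", "apiKey": "sk-...", "model": "text-embedding-ada-002" });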
impl Default for EmbedderSettings {
    fn default() -> Self {
        Self::OpenAi(Default::default())
    }
}
impl From<crate::vector::EmbedderOptions> for EmbedderSettings {
    fn from(value: crate::vector::EmbedderOptions) -> Self {
        match value {
            crate::vector::EmbedderOptions::HuggingFace(hf) => {
                Self::HuggingFace(Setting::Set(hf.into()))
            }
            crate::vector::EmbedderOptions::OpenAi(openai) => {
                Self::OpenAi(Setting::Set(openai.into()))
            }
            crate::vector::EmbedderOptions::UserProvided(user_provided) => {
                Self::UserProvided(user_provided.into())
            }
        }
    }
}
impl From<EmbedderSettings> for crate::vector::EmbedderOptions {
    fn from(value: EmbedderSettings) -> Self {
        match value {
            EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
            EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
            EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
            EmbedderSettings::OpenAi(_setting) => {
                Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None))
            }
            EmbedderSettings::UserProvided(user_provided) => {
                Self::UserProvided(user_provided.into())
            }
        }
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct HfEmbedderSettings {
    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
    #[deserr(default)]
    pub model: Setting<String>,
    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
    #[deserr(default)]
    pub revision: Setting<String>,
}

impl HfEmbedderSettings {
    pub fn apply(&mut self, new: Self) {
        let HfEmbedderSettings { model, revision } = new;
        self.model.apply(model);
        self.revision.apply(revision);
    }
}

impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
    fn from(value: crate::vector::hf::EmbedderOptions) -> Self {
        Self {
            model: Setting::Set(value.model),
            revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
        }
    }
}

impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
    fn from(value: HfEmbedderSettings) -> Self {
        let HfEmbedderSettings { model, revision } = value;
        let mut this = Self::default();
        if let Some(model) = model.set() {
            this.model = model;
        }
        if let Some(revision) = revision.set() {
            this.revision = Some(revision);
        }
        this
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct OpenAiEmbedderSettings {
    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
    #[deserr(default)]
    pub api_key: Setting<String>,
    #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")]
    #[deserr(default, rename = "model")]
    pub embedding_model: Setting<crate::vector::openai::EmbeddingModel>,
}

impl OpenAiEmbedderSettings {
    pub fn apply(&mut self, new: Self) {
        let Self { api_key, embedding_model } = new;
        self.api_key.apply(api_key);
        self.embedding_model.apply(embedding_model);
    }
}

impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
    fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
        Self {
            api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset),
            embedding_model: Setting::Set(value.embedding_model),
        }
    }
}

impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
    fn from(value: OpenAiEmbedderSettings) -> Self {
        let OpenAiEmbedderSettings { api_key, embedding_model } = value;
        Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() }
    }
}
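One detail worth noting in the conversion above: a missing API key maps to `Setting::Reset` rather than `Setting::NotSet`. A small sketch of the consequence, using the constructor that appears earlier in this file:

let options = crate::vector::openai::EmbedderOptions::with_default_model(None);
let settings: OpenAiEmbedderSettings = options.into();
// No key configured: the settings carry an explicit reset, not "not set".
assert_eq!(settings.api_key, Setting::Reset);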
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct UserProvidedSettings {
    pub dimensions: usize,
}

impl From<UserProvidedSettings> for crate::vector::manual::EmbedderOptions {
    fn from(value: UserProvidedSettings) -> Self {
        Self { dimensions: value.dimensions }
    }
}

impl From<crate::vector::manual::EmbedderOptions> for UserProvidedSettings {
    fn from(value: crate::vector::manual::EmbedderOptions) -> Self {
        Self { dimensions: value.dimensions }
    }
}