Merge branch 'main' into change-proximity-precision-settings

This commit is contained in:
Many the fish
2023-12-18 09:08:47 +01:00
committed by GitHub
55 changed files with 5801 additions and 723 deletions

1118
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -276,6 +276,7 @@ pub(crate) mod test {
), ),
}), }),
pagination: Setting::NotSet, pagination: Setting::NotSet,
embedders: Setting::NotSet,
_kind: std::marker::PhantomData, _kind: std::marker::PhantomData,
}; };
settings.check() settings.check()

View File

@@ -378,6 +378,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> {
v5::Setting::Reset => v6::Setting::Reset, v5::Setting::Reset => v6::Setting::Reset,
v5::Setting::NotSet => v6::Setting::NotSet, v5::Setting::NotSet => v6::Setting::NotSet,
}, },
embedders: v6::Setting::NotSet,
_kind: std::marker::PhantomData, _kind: std::marker::PhantomData,
} }
} }

View File

@@ -1202,6 +1202,10 @@ impl IndexScheduler {
let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; let config = IndexDocumentsConfig { update_method: method, ..Default::default() };
let embedder_configs = index.embedding_configs(index_wtxn)?;
// TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense)
let embedders = self.embedders(embedder_configs)?;
let mut builder = milli::update::IndexDocuments::new( let mut builder = milli::update::IndexDocuments::new(
index_wtxn, index_wtxn,
index, index,
@@ -1220,6 +1224,8 @@ impl IndexScheduler {
let (new_builder, user_result) = builder.add_documents(reader)?; let (new_builder, user_result) = builder.add_documents(reader)?;
builder = new_builder; builder = new_builder;
builder = builder.with_embedders(embedders.clone());
let received_documents = let received_documents =
if let Some(Details::DocumentAdditionOrUpdate { if let Some(Details::DocumentAdditionOrUpdate {
received_documents, received_documents,
@@ -1345,6 +1351,9 @@ impl IndexScheduler {
for (task, (_, settings)) in tasks.iter_mut().zip(settings) { for (task, (_, settings)) in tasks.iter_mut().zip(settings) {
let checked_settings = settings.clone().check(); let checked_settings = settings.clone().check();
if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) {
self.features().check_vector("Passing `embedders` in settings")?
}
task.details = Some(Details::SettingsUpdate { settings: Box::new(settings) }); task.details = Some(Details::SettingsUpdate { settings: Box::new(settings) });
apply_settings_to_builder(&checked_settings, &mut builder); apply_settings_to_builder(&checked_settings, &mut builder);

View File

@@ -56,12 +56,12 @@ impl RoFeatures {
} }
} }
pub fn check_vector(&self) -> Result<()> { pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> {
if self.runtime.vector_store { if self.runtime.vector_store {
Ok(()) Ok(())
} else { } else {
Err(FeatureNotEnabledError { Err(FeatureNotEnabledError {
disabled_action: "Passing `vector` as a query parameter", disabled_action,
feature: "vector store", feature: "vector store",
issue_link: "https://github.com/meilisearch/product/discussions/677", issue_link: "https://github.com/meilisearch/product/discussions/677",
} }

View File

@@ -41,6 +41,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String {
planned_failures: _, planned_failures: _,
run_loop_iteration: _, run_loop_iteration: _,
currently_updating_index: _, currently_updating_index: _,
embedders: _,
} = scheduler; } = scheduler;
let rtxn = env.read_txn().unwrap(); let rtxn = env.read_txn().unwrap();

View File

@@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
use meilisearch_types::milli::documents::DocumentsBatchBuilder; use meilisearch_types::milli::documents::DocumentsBatchBuilder;
use meilisearch_types::milli::update::IndexerConfig; use meilisearch_types::milli::update::IndexerConfig;
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs};
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task};
use puffin::FrameView; use puffin::FrameView;
@@ -341,6 +342,8 @@ pub struct IndexScheduler {
/// so that a handle to the index is available from other threads (search) in an optimized manner. /// so that a handle to the index is available from other threads (search) in an optimized manner.
currently_updating_index: Arc<RwLock<Option<(String, Index)>>>, currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
embedders: Arc<RwLock<HashMap<EmbedderOptions, Arc<Embedder>>>>,
// ================= test // ================= test
// The next entry is dedicated to the tests. // The next entry is dedicated to the tests.
/// Provide a way to set a breakpoint in multiple part of the scheduler. /// Provide a way to set a breakpoint in multiple part of the scheduler.
@@ -386,6 +389,7 @@ impl IndexScheduler {
auth_path: self.auth_path.clone(), auth_path: self.auth_path.clone(),
version_file_path: self.version_file_path.clone(), version_file_path: self.version_file_path.clone(),
currently_updating_index: self.currently_updating_index.clone(), currently_updating_index: self.currently_updating_index.clone(),
embedders: self.embedders.clone(),
#[cfg(test)] #[cfg(test)]
test_breakpoint_sdr: self.test_breakpoint_sdr.clone(), test_breakpoint_sdr: self.test_breakpoint_sdr.clone(),
#[cfg(test)] #[cfg(test)]
@@ -484,6 +488,7 @@ impl IndexScheduler {
auth_path: options.auth_path, auth_path: options.auth_path,
version_file_path: options.version_file_path, version_file_path: options.version_file_path,
currently_updating_index: Arc::new(RwLock::new(None)), currently_updating_index: Arc::new(RwLock::new(None)),
embedders: Default::default(),
#[cfg(test)] #[cfg(test)]
test_breakpoint_sdr, test_breakpoint_sdr,
@@ -1333,6 +1338,40 @@ impl IndexScheduler {
} }
} }
// TODO: consider using a type alias or a struct embedder/template
pub fn embedders(
&self,
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>,
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| {
let prompt =
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
// optimistically return existing embedder
{
let embedders = self.embedders.read().unwrap();
if let Some(embedder) = embedders.get(&embedder_options) {
return Ok((name, (embedder.clone(), prompt)));
}
}
// add missing embedder
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(meilisearch_types::milli::vector::Error::from)
.map_err(meilisearch_types::milli::Error::from)?,
);
{
let mut embedders = self.embedders.write().unwrap();
embedders.insert(embedder_options, embedder.clone());
}
Ok((name, (embedder, prompt)))
})
.collect();
res.map(EmbeddingConfigs::new)
}
/// Blocks the thread until the test handle asks to progress to/through this breakpoint. /// Blocks the thread until the test handle asks to progress to/through this breakpoint.
/// ///
/// Two messages are sent through the channel for each breakpoint. /// Two messages are sent through the channel for each breakpoint.

View File

@@ -188,3 +188,4 @@ merge_with_error_impl_take_error_message!(ParseOffsetDateTimeError);
merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskKindError);
merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(ParseTaskStatusError);
merge_with_error_impl_take_error_message!(IndexUidFormatError); merge_with_error_impl_take_error_message!(IndexUidFormatError);
merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio);

View File

@@ -222,6 +222,8 @@ InvalidVectorsType , InvalidRequest , BAD_REQUEST ;
InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
InvalidEmbedder , InvalidRequest , BAD_REQUEST ;
InvalidHybridQuery , InvalidRequest , BAD_REQUEST ;
InvalidIndexLimit , InvalidRequest , BAD_REQUEST ; InvalidIndexLimit , InvalidRequest , BAD_REQUEST ;
InvalidIndexOffset , InvalidRequest , BAD_REQUEST ; InvalidIndexOffset , InvalidRequest , BAD_REQUEST ;
InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ; InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ;
@@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
@@ -256,6 +259,7 @@ InvalidSettingsProximityPrecision , InvalidRequest , BAD_REQUEST ;
InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ; InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ;
InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ;
InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ; InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ;
InvalidSettingsEmbedders , InvalidRequest , BAD_REQUEST ;
InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ; InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ;
InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ;
InvalidSettingsSortableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsSortableAttributes , InvalidRequest , BAD_REQUEST ;
@@ -295,15 +299,18 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
MissingIndexUid , InvalidRequest , BAD_REQUEST ; MissingIndexUid , InvalidRequest , BAD_REQUEST ;
MissingMasterKey , Auth , UNAUTHORIZED ; MissingMasterKey , Auth , UNAUTHORIZED ;
MissingPayload , InvalidRequest , BAD_REQUEST ; MissingPayload , InvalidRequest , BAD_REQUEST ;
MissingSearchHybrid , InvalidRequest , BAD_REQUEST ;
MissingSwapIndexes , InvalidRequest , BAD_REQUEST ; MissingSwapIndexes , InvalidRequest , BAD_REQUEST ;
MissingTaskFilters , InvalidRequest , BAD_REQUEST ; MissingTaskFilters , InvalidRequest , BAD_REQUEST ;
NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY; NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY;
PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ; PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ;
TaskNotFound , InvalidRequest , NOT_FOUND ; TaskNotFound , InvalidRequest , NOT_FOUND ;
TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ; TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ;
TooManyVectors , InvalidRequest , BAD_REQUEST ;
UnretrievableDocument , Internal , BAD_REQUEST ; UnretrievableDocument , Internal , BAD_REQUEST ;
UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ; UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ;
UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ;
VectorEmbeddingError , InvalidRequest , BAD_REQUEST
} }
impl ErrorCode for JoinError { impl ErrorCode for JoinError {
@@ -336,6 +343,10 @@ impl ErrorCode for milli::Error {
UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => { UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => {
Code::InvalidDocumentId Code::InvalidDocumentId
} }
UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
UserError::MultiplePrimaryKeyCandidatesFound { .. } => { UserError::MultiplePrimaryKeyCandidatesFound { .. } => {
Code::IndexPrimaryKeyMultipleCandidatesFound Code::IndexPrimaryKeyMultipleCandidatesFound
@@ -353,11 +364,15 @@ impl ErrorCode for milli::Error {
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType,
UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType,
UserError::TooManyVectors(_, _) => Code::TooManyVectors,
UserError::SortError(_) => Code::InvalidSearchSort, UserError::SortError(_) => Code::InvalidSearchSort,
UserError::InvalidMinTypoWordLenSetting(_, _) => { UserError::InvalidMinTypoWordLenSetting(_, _) => {
Code::InvalidSettingsTypoTolerance Code::InvalidSettingsTypoTolerance
} }
UserError::InvalidEmbedder(_) => Code::InvalidEmbedder,
UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError,
} }
} }
} }
@@ -445,6 +460,15 @@ impl fmt::Display for DeserrParseIntError {
} }
} }
impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`."
)
}
}
#[macro_export] #[macro_export]
macro_rules! internal_error { macro_rules! internal_error {
($target:ty : $($other:path), *) => { ($target:ty : $($other:path), *) => {

View File

@@ -199,6 +199,10 @@ pub struct Settings<T> {
#[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)] #[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)]
pub pagination: Setting<PaginationSettings>, pub pagination: Setting<PaginationSettings>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsEmbedders>)]
pub embedders: Setting<BTreeMap<String, Setting<milli::vector::settings::EmbeddingSettings>>>,
#[serde(skip)] #[serde(skip)]
#[deserr(skip)] #[deserr(skip)]
pub _kind: PhantomData<T>, pub _kind: PhantomData<T>,
@@ -222,6 +226,7 @@ impl Settings<Checked> {
typo_tolerance: Setting::Reset, typo_tolerance: Setting::Reset,
faceting: Setting::Reset, faceting: Setting::Reset,
pagination: Setting::Reset, pagination: Setting::Reset,
embedders: Setting::Reset,
_kind: PhantomData, _kind: PhantomData,
} }
} }
@@ -243,6 +248,7 @@ impl Settings<Checked> {
typo_tolerance, typo_tolerance,
faceting, faceting,
pagination, pagination,
embedders,
.. ..
} = self; } = self;
@@ -262,6 +268,7 @@ impl Settings<Checked> {
typo_tolerance, typo_tolerance,
faceting, faceting,
pagination, pagination,
embedders,
_kind: PhantomData, _kind: PhantomData,
} }
} }
@@ -307,6 +314,7 @@ impl Settings<Unchecked> {
typo_tolerance: self.typo_tolerance, typo_tolerance: self.typo_tolerance,
faceting: self.faceting, faceting: self.faceting,
pagination: self.pagination, pagination: self.pagination,
embedders: self.embedders,
_kind: PhantomData, _kind: PhantomData,
} }
} }
@@ -490,6 +498,12 @@ pub fn apply_settings_to_builder(
Setting::Reset => builder.reset_pagination_max_total_hits(), Setting::Reset => builder.reset_pagination_max_total_hits(),
Setting::NotSet => (), Setting::NotSet => (),
} }
match settings.embedders.clone() {
Setting::Set(value) => builder.set_embedder_settings(value),
Setting::Reset => builder.reset_embedder_settings(),
Setting::NotSet => (),
}
} }
pub fn settings( pub fn settings(
@@ -571,6 +585,12 @@ pub fn settings(
), ),
}; };
let embedders = index
.embedding_configs(rtxn)?
.into_iter()
.map(|(name, config)| (name, Setting::Set(config.into())))
.collect();
Ok(Settings { Ok(Settings {
displayed_attributes: match displayed_attributes { displayed_attributes: match displayed_attributes {
Some(attrs) => Setting::Set(attrs), Some(attrs) => Setting::Set(attrs),
@@ -599,6 +619,7 @@ pub fn settings(
typo_tolerance: Setting::Set(typo_tolerance), typo_tolerance: Setting::Set(typo_tolerance),
faceting: Setting::Set(faceting), faceting: Setting::Set(faceting),
pagination: Setting::Set(pagination), pagination: Setting::Set(pagination),
embedders: Setting::Set(embedders),
_kind: PhantomData, _kind: PhantomData,
}) })
} }
@@ -747,6 +768,7 @@ pub(crate) mod test {
typo_tolerance: Setting::NotSet, typo_tolerance: Setting::NotSet,
faceting: Setting::NotSet, faceting: Setting::NotSet,
pagination: Setting::NotSet, pagination: Setting::NotSet,
embedders: Setting::NotSet,
_kind: PhantomData::<Unchecked>, _kind: PhantomData::<Unchecked>,
}; };
@@ -772,6 +794,7 @@ pub(crate) mod test {
typo_tolerance: Setting::NotSet, typo_tolerance: Setting::NotSet,
faceting: Setting::NotSet, faceting: Setting::NotSet,
pagination: Setting::NotSet, pagination: Setting::NotSet,
embedders: Setting::NotSet,
_kind: PhantomData::<Unchecked>, _kind: PhantomData::<Unchecked>,
}; };

View File

@@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats};
use crate::search::{ use crate::search::{
FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO,
}; };
use crate::Opt; use crate::Opt;
@@ -586,6 +586,11 @@ pub struct SearchAggregator {
// vector // vector
// The maximum number of floats in a vector request // The maximum number of floats in a vector request
max_vector_size: usize, max_vector_size: usize,
// Whether the semantic ratio passed to a hybrid search equals the default ratio.
semantic_ratio: bool,
// Whether a non-default embedder was specified
embedder: bool,
hybrid: bool,
// every time a search is done, we increment the counter linked to the used settings // every time a search is done, we increment the counter linked to the used settings
matching_strategy: HashMap<String, usize>, matching_strategy: HashMap<String, usize>,
@@ -639,6 +644,7 @@ impl SearchAggregator {
crop_marker, crop_marker,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@@ -712,6 +718,12 @@ impl SearchAggregator {
ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score = *show_ranking_score;
ret.show_ranking_score_details = *show_ranking_score_details; ret.show_ranking_score_details = *show_ranking_score_details;
if let Some(hybrid) = hybrid {
ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO();
ret.embedder = hybrid.embedder.is_some();
ret.hybrid = true;
}
ret ret
} }
@@ -765,6 +777,9 @@ impl SearchAggregator {
facets_total_number_of_facets, facets_total_number_of_facets,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
semantic_ratio,
embedder,
hybrid,
} = other; } = other;
if self.timestamp.is_none() { if self.timestamp.is_none() {
@@ -810,6 +825,9 @@ impl SearchAggregator {
// vector // vector
self.max_vector_size = self.max_vector_size.max(max_vector_size); self.max_vector_size = self.max_vector_size.max(max_vector_size);
self.semantic_ratio |= semantic_ratio;
self.hybrid |= hybrid;
self.embedder |= embedder;
// pagination // pagination
self.max_limit = self.max_limit.max(max_limit); self.max_limit = self.max_limit.max(max_limit);
@@ -878,6 +896,9 @@ impl SearchAggregator {
facets_total_number_of_facets, facets_total_number_of_facets,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
semantic_ratio,
embedder,
hybrid,
} = self; } = self;
if total_received == 0 { if total_received == 0 {
@@ -917,6 +938,11 @@ impl SearchAggregator {
"vector": { "vector": {
"max_vector_size": max_vector_size, "max_vector_size": max_vector_size,
}, },
"hybrid": {
"enabled": hybrid,
"semantic_ratio": semantic_ratio,
"embedder": embedder,
},
"pagination": { "pagination": {
"max_limit": max_limit, "max_limit": max_limit,
"max_offset": max_offset, "max_offset": max_offset,
@@ -1012,6 +1038,7 @@ impl MultiSearchAggregator {
crop_marker: _, crop_marker: _,
matching_strategy: _, matching_strategy: _,
attributes_to_search_on: _, attributes_to_search_on: _,
hybrid: _,
} = query; } = query;
index_uid.as_str() index_uid.as_str()
@@ -1158,6 +1185,7 @@ impl FacetSearchAggregator {
filter, filter,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@@ -1171,7 +1199,8 @@ impl FacetSearchAggregator {
|| vector.is_some() || vector.is_some()
|| filter.is_some() || filter.is_some()
|| *matching_strategy != MatchingStrategy::default() || *matching_strategy != MatchingStrategy::default()
|| attributes_to_search_on.is_some(); || attributes_to_search_on.is_some()
|| hybrid.is_some();
ret ret
} }

View File

@@ -51,6 +51,8 @@ pub enum MeilisearchHttpError {
DocumentFormat(#[from] DocumentFormatError), DocumentFormat(#[from] DocumentFormatError),
#[error(transparent)] #[error(transparent)]
Join(#[from] JoinError), Join(#[from] JoinError),
#[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")]
MissingSearchHybrid,
} }
impl ErrorCode for MeilisearchHttpError { impl ErrorCode for MeilisearchHttpError {
@@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError {
MeilisearchHttpError::FileStore(_) => Code::Internal, MeilisearchHttpError::FileStore(_) => Code::Internal,
MeilisearchHttpError::DocumentFormat(e) => e.error_code(), MeilisearchHttpError::DocumentFormat(e) => e.error_code(),
MeilisearchHttpError::Join(_) => Code::Internal, MeilisearchHttpError::Join(_) => Code::Internal,
MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid,
} }
} }
} }

View File

@@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// does all the setup before meilisearch is launched /// does all the setup before meilisearch is launched
fn setup(opt: &Opt) -> anyhow::Result<()> { fn setup(opt: &Opt) -> anyhow::Result<()> {
let mut log_builder = env_logger::Builder::new(); let mut log_builder = env_logger::Builder::new();
log_builder.parse_filters(&opt.log_level.to_string()); let log_filters = format!(
"{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn",
opt.log_level
);
log_builder.parse_filters(&log_filters);
log_builder.init(); log_builder.init();

View File

@@ -13,9 +13,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator};
use crate::extractors::authentication::policies::*; use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
use crate::search::{ use crate::search::{
add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery,
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
}; };
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
@@ -36,6 +36,8 @@ pub struct FacetSearchQuery {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)] #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
pub filter: Option<Value>, pub filter: Option<Value>,
#[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)]
@@ -95,6 +97,7 @@ impl From<FacetSearchQuery> for SearchQuery {
filter, filter,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = value; } = value;
SearchQuery { SearchQuery {
@@ -119,6 +122,7 @@ impl From<FacetSearchQuery> for SearchQuery {
matching_strategy, matching_strategy,
vector, vector,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} }
} }
} }

View File

@@ -2,12 +2,14 @@ use actix_web::web::Data;
use actix_web::{web, HttpRequest, HttpResponse}; use actix_web::{web, HttpRequest, HttpResponse};
use deserr::actix_web::{AwebJson, AwebQueryParameter}; use deserr::actix_web::{AwebJson, AwebQueryParameter};
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use log::debug; use log::{debug, warn};
use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::query_params::Param;
use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli;
use meilisearch_types::milli::vector::DistributionShift;
use meilisearch_types::serde_cs::vec::CS; use meilisearch_types::serde_cs::vec::CS;
use serde_json::Value; use serde_json::Value;
@@ -16,9 +18,9 @@ use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler; use crate::extractors::sequential_extractor::SeqHandler;
use crate::search::{ use crate::search::{
add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio,
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
}; };
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
@@ -74,6 +76,31 @@ pub struct SearchQueryGet {
matching_strategy: MatchingStrategy, matching_strategy: MatchingStrategy,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)]
pub attributes_to_search_on: Option<CS<String>>, pub attributes_to_search_on: Option<CS<String>>,
#[deserr(default, error = DeserrQueryParamError<InvalidEmbedder>)]
pub hybrid_embedder: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchSemanticRatio>)]
pub hybrid_semantic_ratio: Option<SemanticRatioGet>,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)]
#[deserr(try_from(String) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatioGet(SemanticRatio);
impl std::convert::TryFrom<String> for SemanticRatioGet {
type Error = InvalidSearchSemanticRatio;
fn try_from(s: String) -> Result<Self, Self::Error> {
let f: f32 = s.parse().map_err(|_| InvalidSearchSemanticRatio)?;
Ok(SemanticRatioGet(SemanticRatio::try_from(f)?))
}
}
impl std::ops::Deref for SemanticRatioGet {
type Target = SemanticRatio;
fn deref(&self) -> &Self::Target {
&self.0
}
} }
impl From<SearchQueryGet> for SearchQuery { impl From<SearchQueryGet> for SearchQuery {
@@ -86,6 +113,20 @@ impl From<SearchQueryGet> for SearchQuery {
None => None, None => None,
}; };
let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) {
(None, None) => None,
(None, Some(semantic_ratio)) => {
Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: None })
}
(Some(embedder), None) => Some(HybridQuery {
semantic_ratio: DEFAULT_SEMANTIC_RATIO(),
embedder: Some(embedder),
}),
(Some(embedder), Some(semantic_ratio)) => {
Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: Some(embedder) })
}
};
Self { Self {
q: other.q, q: other.q,
vector: other.vector.map(CS::into_inner), vector: other.vector.map(CS::into_inner),
@@ -108,6 +149,7 @@ impl From<SearchQueryGet> for SearchQuery {
crop_marker: other.crop_marker, crop_marker: other.crop_marker,
matching_strategy: other.matching_strategy, matching_strategy: other.matching_strategy,
attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()),
hybrid,
} }
} }
} }
@@ -158,8 +200,12 @@ pub async fn search_with_url_query(
let index = index_scheduler.index(&index_uid)?; let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features(); let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
let search_result = let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
.await?;
if let Ok(ref search_result) = search_result { if let Ok(ref search_result) = search_result {
aggregate.succeed(search_result); aggregate.succeed(search_result);
} }
@@ -193,8 +239,12 @@ pub async fn search_with_post(
let index = index_scheduler.index(&index_uid)?; let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features(); let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
let search_result = let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
.await?;
if let Ok(ref search_result) = search_result { if let Ok(ref search_result) = search_result {
aggregate.succeed(search_result); aggregate.succeed(search_result);
} }
@@ -206,6 +256,80 @@ pub async fn search_with_post(
Ok(HttpResponse::Ok().json(search_result)) Ok(HttpResponse::Ok().json(search_result))
} }
pub async fn embed(
query: &mut SearchQuery,
index_scheduler: &IndexScheduler,
index: &milli::Index,
) -> Result<Option<DistributionShift>, ResponseError> {
match (&query.hybrid, &query.vector, &query.q) {
(Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q))
if !q.trim().is_empty() =>
{
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder = if let Some(embedder_name) = embedder {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
let distribution = embedder.distribution();
let embeddings = embedder
.embed(vec![q.to_owned()])
.await
.map_err(milli::vector::Error::from)
.map_err(milli::Error::from)?
.pop()
.expect("No vector returned from embedding");
if embeddings.iter().nth(1).is_some() {
warn!("Ignoring embeddings past the first one in long search query");
query.vector = Some(embeddings.iter().next().unwrap().to_vec());
} else {
query.vector = Some(embeddings.into_inner());
}
Ok(distribution)
}
(Some(hybrid), vector, _) => {
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder = if let Some(embedder_name) = &hybrid.embedder {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
if let Some(vector) = vector {
if vector.len() != embedder.dimensions() {
return Err(meilisearch_types::milli::Error::UserError(
meilisearch_types::milli::UserError::InvalidVectorDimensions {
expected: embedder.dimensions(),
found: vector.len(),
},
)
.into());
}
}
Ok(embedder.distribution())
}
_ => Ok(None),
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;

View File

@@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::facet_values_sort::FacetValuesSort; use meilisearch_types::facet_values_sort::FacetValuesSort;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::update::Setting;
use meilisearch_types::settings::{settings, RankingRuleView, Settings, Unchecked}; use meilisearch_types::settings::{settings, RankingRuleView, Settings, Unchecked};
use meilisearch_types::tasks::KindWithContent; use meilisearch_types::tasks::KindWithContent;
use serde_json::json; use serde_json::json;
@@ -546,6 +547,67 @@ make_setting_route!(
} }
); );
make_setting_route!(
"/embedders",
patch,
std::collections::BTreeMap<String, Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>>,
meilisearch_types::deserr::DeserrJsonError<
meilisearch_types::error::deserr_codes::InvalidSettingsEmbedders,
>,
embedders,
"embedders",
analytics,
|setting: &Option<std::collections::BTreeMap<String, Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>>>, req: &HttpRequest| {
analytics.publish(
"Embedders Updated".to_string(),
serde_json::json!({"embedders": crate::routes::indexes::settings::embedder_analytics(setting.as_ref())}),
Some(req),
);
}
);
fn embedder_analytics(
setting: Option<
&std::collections::BTreeMap<
String,
Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>,
>,
>,
) -> serde_json::Value {
let mut sources = std::collections::HashSet::new();
if let Some(s) = &setting {
for source in s
.values()
.filter_map(|config| config.clone().set())
.filter_map(|config| config.embedder_options.set())
{
use meilisearch_types::milli::vector::settings::EmbedderSettings;
match source {
EmbedderSettings::OpenAi(_) => sources.insert("openAi"),
EmbedderSettings::HuggingFace(_) => sources.insert("huggingFace"),
EmbedderSettings::UserProvided(_) => sources.insert("userProvided"),
};
}
};
let document_template_used = setting.as_ref().map(|map| {
map.values()
.filter_map(|config| config.clone().set())
.any(|config| config.document_template.set().is_some())
});
json!(
{
"total": setting.as_ref().map(|s| s.len()),
"sources": sources,
"document_template_used": document_template_used,
}
)
}
macro_rules! generate_configure { macro_rules! generate_configure {
($($mod:ident),*) => { ($($mod:ident),*) => {
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
@@ -575,7 +637,8 @@ generate_configure!(
ranking_rules, ranking_rules,
typo_tolerance, typo_tolerance,
pagination, pagination,
faceting faceting,
embedders
); );
pub async fn update_all( pub async fn update_all(
@@ -682,6 +745,7 @@ pub async fn update_all(
"synonyms": { "synonyms": {
"total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()), "total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()),
}, },
"embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set())
}), }),
Some(&req), Some(&req),
); );

View File

@@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator};
use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::{AuthenticationError, GuardedData}; use crate::extractors::authentication::{AuthenticationError, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler; use crate::extractors::sequential_extractor::SeqHandler;
use crate::routes::indexes::search::embed;
use crate::search::{ use crate::search::{
add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
}; };
@@ -74,8 +75,13 @@ pub async fn multi_search_with_post(
}) })
.with_index(query_index)?; .with_index(query_index)?;
let search_result = let distribution = embed(&mut query, index_scheduler.get_ref(), &index)
tokio::task::spawn_blocking(move || perform_search(&index, query, features)) .await
.with_index(query_index)?;
let search_result = tokio::task::spawn_blocking(move || {
perform_search(&index, query, features, distribution)
})
.await .await
.with_index(query_index)?; .with_index(query_index)?;

View File

@@ -7,24 +7,21 @@ use deserr::Deserr;
use either::Either; use either::Either;
use index_scheduler::RoFeatures; use index_scheduler::RoFeatures;
use indexmap::IndexMap; use indexmap::IndexMap;
use log::warn;
use meilisearch_auth::IndexSearchRules; use meilisearch_auth::IndexSearchRules;
use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::heed::RoTxn; use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::{ use meilisearch_types::milli::vector::DistributionShift;
dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document}; use meilisearch_types::{milli, Document};
use milli::tokenizer::TokenizerBuilder; use milli::tokenizer::TokenizerBuilder;
use milli::{ use milli::{
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
}; };
use ordered_float::OrderedFloat;
use regex::Regex; use regex::Regex;
use serde::Serialize; use serde::Serialize;
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -39,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10;
pub const DEFAULT_CROP_MARKER: fn() -> String = || "".to_string(); pub const DEFAULT_CROP_MARKER: fn() -> String = || "".to_string();
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5);
#[derive(Debug, Clone, Default, PartialEq, Deserr)] #[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
@@ -47,6 +45,8 @@ pub struct SearchQuery {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
pub offset: usize, pub offset: usize,
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@@ -87,6 +87,48 @@ pub struct SearchQuery {
pub attributes_to_search_on: Option<Vec<String>>, pub attributes_to_search_on: Option<Vec<String>>,
} }
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
pub struct HybridQuery {
/// TODO validate that sementic ratio is between 0.0 and 1,0
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
pub semantic_ratio: SemanticRatio,
#[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
pub embedder: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatio(f32);
impl Default for SemanticRatio {
fn default() -> Self {
DEFAULT_SEMANTIC_RATIO()
}
}
impl std::convert::TryFrom<f32> for SemanticRatio {
type Error = InvalidSearchSemanticRatio;
fn try_from(f: f32) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchSemanticRatio)
} else {
Ok(SemanticRatio(f))
}
}
}
impl std::ops::Deref for SemanticRatio {
type Target = f32;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl SearchQuery { impl SearchQuery {
pub fn is_finite_pagination(&self) -> bool { pub fn is_finite_pagination(&self) -> bool {
self.page.or(self.hits_per_page).is_some() self.page.or(self.hits_per_page).is_some()
@@ -106,6 +148,8 @@ pub struct SearchQueryWithIndex {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub vector: Option<Vec<f32>>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
pub offset: usize, pub offset: usize,
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@@ -171,6 +215,7 @@ impl SearchQueryWithIndex {
crop_marker, crop_marker,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = self; } = self;
( (
index_uid, index_uid,
@@ -196,6 +241,7 @@ impl SearchQueryWithIndex {
crop_marker, crop_marker,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
// do not use ..Default::default() here, // do not use ..Default::default() here,
// rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex`
}, },
@@ -335,19 +381,44 @@ fn prepare_search<'t>(
rtxn: &'t RoTxn, rtxn: &'t RoTxn,
query: &'t SearchQuery, query: &'t SearchQuery,
features: RoFeatures, features: RoFeatures,
distribution: Option<DistributionShift>,
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
let mut search = index.search(rtxn); let mut search = index.search(rtxn);
if query.vector.is_some() && query.q.is_some() { if query.vector.is_some() {
warn!("Ignoring the query string `q` when used with the `vector` parameter."); features.check_vector("Passing `vector` as a query parameter")?;
} }
if query.hybrid.is_some() {
features.check_vector("Passing `hybrid` as a query parameter")?;
}
if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
return Err(MeilisearchHttpError::MissingSearchHybrid);
}
search.distribution_shift(distribution);
if let Some(ref vector) = query.vector { if let Some(ref vector) = query.vector {
match &query.hybrid {
// If semantic ratio is 0.0, only the query search will impact the search results,
// skip the vector
Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
_otherwise => {
search.vector(vector.clone()); search.vector(vector.clone());
} }
}
}
if let Some(ref query) = query.q { if let Some(ref q) = query.q {
search.query(query); match &query.hybrid {
// If semantic ratio is 1.0, only the vector search will impact the search results,
// skip the query
Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (),
_otherwise => {
search.query(q);
}
}
} }
if let Some(ref searchable) = query.attributes_to_search_on { if let Some(ref searchable) = query.attributes_to_search_on {
@@ -374,8 +445,8 @@ fn prepare_search<'t>(
features.check_score_details()?; features.check_score_details()?;
} }
if query.vector.is_some() { if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
features.check_vector()?; search.embedder_name(embedder);
} }
// compute the offset on the limit depending on the pagination mode. // compute the offset on the limit depending on the pagination mode.
@@ -421,15 +492,22 @@ pub fn perform_search(
index: &Index, index: &Index,
query: SearchQuery, query: SearchQuery,
features: RoFeatures, features: RoFeatures,
distribution: Option<DistributionShift>,
) -> Result<SearchResult, MeilisearchHttpError> { ) -> Result<SearchResult, MeilisearchHttpError> {
let before_search = Instant::now(); let before_search = Instant::now();
let rtxn = index.read_txn()?; let rtxn = index.read_txn()?;
let (search, is_finite_pagination, max_total_hits, offset) = let (search, is_finite_pagination, max_total_hits, offset) =
prepare_search(index, &rtxn, &query, features)?; prepare_search(index, &rtxn, &query, features, distribution)?;
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
search.execute()?; match &query.hybrid {
Some(hybrid) => match *hybrid.semantic_ratio {
ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
ratio => search.execute_hybrid(ratio)?,
},
None => search.execute()?,
};
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
@@ -538,13 +616,17 @@ pub fn perform_search(
insert_geo_distance(sort, &mut document); insert_geo_distance(sort, &mut document);
} }
let semantic_score = match query.vector.as_ref() { let mut semantic_score = None;
Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { for details in &score {
Some(vectors) => compute_semantic_score(vector, vectors)?, if let ScoreDetails::Vector(score_details::Vector {
None => None, target_vector: _,
}, value_similarity: Some((_matching_vector, similarity)),
None => None, }) = details
}; {
semantic_score = Some(*similarity);
break;
}
}
let ranking_score = let ranking_score =
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
@@ -647,8 +729,9 @@ pub fn perform_facet_search(
let before_search = Instant::now(); let before_search = Instant::now();
let rtxn = index.read_txn()?; let rtxn = index.read_txn()?;
let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features)?; let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?;
let mut facet_search = SearchForFacetValues::new(facet_name, search); let mut facet_search =
SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some());
if let Some(facet_query) = &facet_query { if let Some(facet_query) = &facet_query {
facet_search.query(facet_query); facet_search.query(facet_query);
} }
@@ -676,18 +759,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
} }
} }
fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result<Option<f32>> {
let vectors = serde_json::from_value(vectors)
.map(VectorOrArrayOfVectors::into_array_of_vectors)
.map_err(InternalError::SerdeJson)?;
Ok(vectors
.into_iter()
.flatten()
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
.max()
.map(OrderedFloat::into_inner))
}
fn compute_formatted_options( fn compute_formatted_options(
attr_to_highlight: &HashSet<String>, attr_to_highlight: &HashSet<String>,
attr_to_crop: &[String], attr_to_crop: &[String],
@@ -815,22 +886,6 @@ fn make_document(
Ok(document) Ok(document)
} }
/// Extract the JSON value under the field name specified
/// but doesn't support nested objects.
fn extract_field(
field_name: &str,
field_ids_map: &FieldsIdsMap,
obkv: obkv::KvReaderU16,
) -> Result<Option<serde_json::Value>, MeilisearchHttpError> {
match field_ids_map.id(field_name) {
Some(fid) => match obkv.get(fid) {
Some(value) => Ok(serde_json::from_slice(value).map(Some)?),
None => Ok(None),
},
None => Ok(None),
}
}
fn format_fields<'a>( fn format_fields<'a>(
document: &Document, document: &Document,
field_ids_map: &FieldsIdsMap, field_ids_map: &FieldsIdsMap,

View File

@@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -807,7 +812,8 @@ async fn import_dump_v2_rubygems_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"### "###
); );
@@ -1895,7 +1907,8 @@ async fn import_dump_v6_containing_experimental_features() {
}, },
"pagination": { "pagination": {
"maxTotalHits": 1000 "maxTotalHits": 1000
} },
"embedders": {}
} }
"###); "###);

View File

@@ -0,0 +1,152 @@
use meili_snap::snapshot;
use once_cell::sync::Lazy;
use crate::common::index::Index;
use crate::common::{Server, Value};
use crate::json;
async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Index<'a> {
let index = server.index("test");
let (response, code) = server.set_features(json!({"vectorStore": true})).await;
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"scoreDetails": false,
"vectorStore": true,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": false
}
"###);
let (response, code) = index
.update_settings(
json!({ "embedders": {"default": {"source": {"userProvided": {"dimensions": 2}}}} }),
)
.await;
assert_eq!(202, code, "{:?}", response);
index.wait_task(response.uid()).await;
let (response, code) = index.add_documents(documents.clone(), None).await;
assert_eq!(202, code, "{:?}", response);
index.wait_task(response.uid()).await;
index
}
static SIMPLE_SEARCH_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
json!([
{
"title": "Shazam!",
"desc": "a Captain Marvel ersatz",
"id": "1",
"_vectors": {"default": [1.0, 3.0]},
},
{
"title": "Captain Planet",
"desc": "He's not part of the Marvel Cinematic Universe",
"id": "2",
"_vectors": {"default": [1.0, 2.0]},
},
{
"title": "Captain Marvel",
"desc": "a Shazam ersatz",
"id": "3",
"_vectors": {"default": [2.0, 3.0]},
}])
});
#[actix_rt::test]
async fn simple_search() {
let server = Server::new().await;
let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}}),
)
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###);
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}}),
)
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
}
#[actix_rt::test]
async fn invalid_semantic_ratio() {
let server = Server::new().await;
let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 1.2}}),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": -0.8}}),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
let (response, code) = index
.search_get(
&yaup::to_string(
&json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": 1.2}),
)
.unwrap(),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
let (response, code) = index
.search_get(
&yaup::to_string(
&json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": -0.2}),
)
.unwrap(),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
}

View File

@@ -6,6 +6,7 @@ mod errors;
mod facet_search; mod facet_search;
mod formatted; mod formatted;
mod geo; mod geo;
mod hybrid;
mod multi; mod multi;
mod pagination; mod pagination;
mod restrict_searchable; mod restrict_searchable;
@@ -20,22 +21,27 @@ static DOCUMENTS: Lazy<Value> = Lazy::new(|| {
{ {
"title": "Shazam!", "title": "Shazam!",
"id": "287947", "id": "287947",
"_vectors": { "manual": [1, 2, 3]},
}, },
{ {
"title": "Captain Marvel", "title": "Captain Marvel",
"id": "299537", "id": "299537",
"_vectors": { "manual": [1, 2, 54] },
}, },
{ {
"title": "Escape Room", "title": "Escape Room",
"id": "522681", "id": "522681",
"_vectors": { "manual": [10, -23, 32] },
}, },
{ {
"title": "How to Train Your Dragon: The Hidden World", "title": "How to Train Your Dragon: The Hidden World",
"id": "166428", "id": "166428",
"_vectors": { "manual": [-100, 231, 32] },
}, },
{ {
"title": "Gläss", "title": "Gläss",
"id": "450465", "id": "450465",
"_vectors": { "manual": [-100, 340, 90] },
} }
]) ])
}); });
@@ -57,6 +63,7 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
}, },
], ],
"cattos": "pésti", "cattos": "pésti",
"_vectors": { "manual": [1, 2, 3]},
}, },
{ {
"id": 654, "id": 654,
@@ -69,12 +76,14 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
}, },
], ],
"cattos": ["simba", "pestiféré"], "cattos": ["simba", "pestiféré"],
"_vectors": { "manual": [1, 2, 54] },
}, },
{ {
"id": 750, "id": 750,
"father": "romain", "father": "romain",
"mother": "michelle", "mother": "michelle",
"cattos": ["enigma"], "cattos": ["enigma"],
"_vectors": { "manual": [10, 23, 32] },
}, },
{ {
"id": 951, "id": 951,
@@ -91,6 +100,7 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
}, },
], ],
"cattos": ["moumoute", "gomez"], "cattos": ["moumoute", "gomez"],
"_vectors": { "manual": [10, 23, 32] },
}, },
]) ])
}); });
@@ -802,6 +812,13 @@ async fn experimental_feature_score_details() {
{ {
"title": "How to Train Your Dragon: The Hidden World", "title": "How to Train Your Dragon: The Hidden World",
"id": "166428", "id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
},
"_rankingScoreDetails": { "_rankingScoreDetails": {
"words": { "words": {
"order": 0, "order": 0,
@@ -823,7 +840,7 @@ async fn experimental_feature_score_details() {
"order": 3, "order": 3,
"attributeRankingOrderScore": 1.0, "attributeRankingOrderScore": 1.0,
"queryWordDistanceScore": 0.8095238095238095, "queryWordDistanceScore": 0.8095238095238095,
"score": 0.9365079365079364 "score": 0.9727891156462584
}, },
"exactness": { "exactness": {
"order": 4, "order": 4,
@@ -870,13 +887,92 @@ async fn experimental_feature_vector_store() {
meili_snap::snapshot!(code, @"200 OK"); meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(response["vectorStore"], @"true"); meili_snap::snapshot!(response["vectorStore"], @"true");
let (response, code) = index
.update_settings(json!({"embedders": {
"manual": {
"source": {
"userProvided": {"dimensions": 3}
}
}
}}))
.await;
meili_snap::snapshot!(code, @"202 Accepted");
let response = index.wait_task(response.uid()).await;
meili_snap::snapshot!(meili_snap::json_string!(response["status"]), @"\"succeeded\"");
let (response, code) = index let (response, code) = index
.search_post(json!({ .search_post(json!({
"vector": [1.0, 2.0, 3.0], "vector": [1.0, 2.0, 3.0],
})) }))
.await; .await;
meili_snap::snapshot!(code, @"200 OK"); meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]"); // vector search returns all documents that don't have vectors in the last bucket, like all sorts
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
[
{
"title": "Shazam!",
"id": "287947",
"_vectors": {
"manual": [
1,
2,
3
]
},
"_semanticScore": 1.0
},
{
"title": "Captain Marvel",
"id": "299537",
"_vectors": {
"manual": [
1,
2,
54
]
},
"_semanticScore": 0.9129112
},
{
"title": "Gläss",
"id": "450465",
"_vectors": {
"manual": [
-100,
340,
90
]
},
"_semanticScore": 0.8106413
},
{
"title": "How to Train Your Dragon: The Hidden World",
"id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
},
"_semanticScore": 0.74120104
},
{
"title": "Escape Room",
"id": "522681",
"_vectors": {
"manual": [
10,
-23,
32
]
}
}
]
"###);
} }
#[cfg(feature = "default")] #[cfg(feature = "default")]
@@ -1126,7 +1222,14 @@ async fn simple_search_with_strange_synonyms() {
[ [
{ {
"title": "How to Train Your Dragon: The Hidden World", "title": "How to Train Your Dragon: The Hidden World",
"id": "166428" "id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
}
} }
] ]
"###); "###);
@@ -1140,7 +1243,14 @@ async fn simple_search_with_strange_synonyms() {
[ [
{ {
"title": "How to Train Your Dragon: The Hidden World", "title": "How to Train Your Dragon: The Hidden World",
"id": "166428" "id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
}
} }
] ]
"###); "###);
@@ -1154,7 +1264,14 @@ async fn simple_search_with_strange_synonyms() {
[ [
{ {
"title": "How to Train Your Dragon: The Hidden World", "title": "How to Train Your Dragon: The Hidden World",
"id": "166428" "id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
}
} }
] ]
"###); "###);

View File

@@ -72,7 +72,14 @@ async fn simple_search_single_index() {
"hits": [ "hits": [
{ {
"title": "Gläss", "title": "Gläss",
"id": "450465" "id": "450465",
"_vectors": {
"manual": [
-100,
340,
90
]
}
} }
], ],
"query": "glass", "query": "glass",
@@ -86,7 +93,14 @@ async fn simple_search_single_index() {
"hits": [ "hits": [
{ {
"title": "Captain Marvel", "title": "Captain Marvel",
"id": "299537" "id": "299537",
"_vectors": {
"manual": [
1,
2,
54
]
}
} }
], ],
"query": "captain", "query": "captain",
@@ -177,7 +191,14 @@ async fn simple_search_two_indexes() {
"hits": [ "hits": [
{ {
"title": "Gläss", "title": "Gläss",
"id": "450465" "id": "450465",
"_vectors": {
"manual": [
-100,
340,
90
]
}
} }
], ],
"query": "glass", "query": "glass",
@@ -203,7 +224,14 @@ async fn simple_search_two_indexes() {
"age": 4 "age": 4
} }
], ],
"cattos": "pésti" "cattos": "pésti",
"_vectors": {
"manual": [
1,
2,
3
]
}
}, },
{ {
"id": 654, "id": 654,
@@ -218,8 +246,15 @@ async fn simple_search_two_indexes() {
"cattos": [ "cattos": [
"simba", "simba",
"pestiféré" "pestiféré"
],
"_vectors": {
"manual": [
1,
2,
54
] ]
} }
}
], ],
"query": "pésti", "query": "pésti",
"processingTimeMs": "[time]", "processingTimeMs": "[time]",

View File

@@ -54,7 +54,7 @@ async fn get_settings() {
let (response, code) = index.settings().await; let (response, code) = index.settings().await;
assert_eq!(code, 200); assert_eq!(code, 200);
let settings = response.as_object().unwrap(); let settings = response.as_object().unwrap();
assert_eq!(settings.keys().len(), 15); assert_eq!(settings.keys().len(), 16);
assert_eq!(settings["displayedAttributes"], json!(["*"])); assert_eq!(settings["displayedAttributes"], json!(["*"]));
assert_eq!(settings["searchableAttributes"], json!(["*"])); assert_eq!(settings["searchableAttributes"], json!(["*"]));
assert_eq!(settings["filterableAttributes"], json!([])); assert_eq!(settings["filterableAttributes"], json!([]));
@@ -83,6 +83,7 @@ async fn get_settings() {
"maxTotalHits": 1000, "maxTotalHits": 1000,
}) })
); );
assert_eq!(settings["embedders"], json!({}));
} }
#[actix_rt::test] #[actix_rt::test]

View File

@@ -27,13 +27,15 @@ fst = "0.4.7"
fxhash = "0.2.1" fxhash = "0.2.1"
geoutils = "0.5.1" geoutils = "0.5.1"
grenad = { version = "0.4.5", default-features = false, features = [ grenad = { version = "0.4.5", default-features = false, features = [
"rayon", "tempfile" "rayon",
"tempfile",
] } ] }
heed = { version = "0.20.0-alpha.9", default-features = false, features = [ heed = { version = "0.20.0-alpha.9", default-features = false, features = [
"serde-json", "serde-bincode", "read-txn-no-tls" "serde-json",
"serde-bincode",
"read-txn-no-tls",
] } ] }
indexmap = { version = "2.0.0", features = ["serde"] } indexmap = { version = "2.0.0", features = ["serde"] }
instant-distance = { version = "0.6.1", features = ["with-serde"] }
json-depth-checker = { path = "../json-depth-checker" } json-depth-checker = { path = "../json-depth-checker" }
levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
memmap2 = "0.7.1" memmap2 = "0.7.1"
@@ -72,6 +74,23 @@ puffin = "0.16.0"
log = "0.4.17" log = "0.4.17"
logging_timer = "1.1.0" logging_timer = "1.1.0"
csv = "1.2.1" csv = "1.2.1"
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" }
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
"online",
] }
tokio = { version = "1.34.0", features = ["rt"] }
futures = "0.3.29"
reqwest = { version = "0.11.16", features = [
"rustls-tls",
"json",
], default-features = false }
tiktoken-rs = "0.5.7"
liquid = "0.26.4"
arroy = { git = "https://github.com/meilisearch/arroy.git", version = "0.1.0" }
rand = "0.8.5"
[dev-dependencies] [dev-dependencies]
mimalloc = { version = "0.1.37", default-features = false } mimalloc = { version = "0.1.37", default-features = false }
@@ -83,7 +102,15 @@ meili-snap = { path = "../meili-snap" }
rand = { version = "0.8.5", features = ["small_rng"] } rand = { version = "0.8.5", features = ["small_rng"] }
[features] [features]
all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] all-tokenizations = [
"charabia/chinese",
"charabia/hebrew",
"charabia/japanese",
"charabia/thai",
"charabia/korean",
"charabia/greek",
"charabia/khmer",
]
# Use POSIX semaphores instead of SysV semaphores in LMDB # Use POSIX semaphores instead of SysV semaphores in LMDB
# For more information on this feature, see heed's Cargo.toml # For more information on this feature, see heed's Cargo.toml

View File

@@ -5,8 +5,8 @@ use std::time::Instant;
use heed::EnvOpenOptions; use heed::EnvOpenOptions;
use milli::{ use milli::{
execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger, execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext,
TermsMatchingStrategy, SearchLogger, TermsMatchingStrategy,
}; };
#[global_allocator] #[global_allocator]
@@ -49,14 +49,15 @@ fn main() -> Result<(), Box<dyn Error>> {
let start = Instant::now(); let start = Instant::now();
let mut ctx = SearchContext::new(&index, &txn); let mut ctx = SearchContext::new(&index, &txn);
let universe = filtered_universe(&ctx, &None)?;
let docs = execute_search( let docs = execute_search(
&mut ctx, &mut ctx,
&(!query.trim().is_empty()).then(|| query.trim().to_owned()), (!query.trim().is_empty()).then(|| query.trim()),
&None,
TermsMatchingStrategy::Last, TermsMatchingStrategy::Last,
milli::score_details::ScoringStrategy::Skip, milli::score_details::ScoringStrategy::Skip,
false, false,
&None, universe,
&None, &None,
GeoSortStrategy::default(), GeoSortStrategy::default(),
0, 0,

View File

@@ -1,41 +0,0 @@
use std::ops;
use instant_distance::Point;
use serde::{Deserialize, Serialize};
use crate::normalize_vector;
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct NDotProductPoint(Vec<f32>);
impl NDotProductPoint {
pub fn new(point: Vec<f32>) -> Self {
NDotProductPoint(normalize_vector(point))
}
pub fn into_inner(self) -> Vec<f32> {
self.0
}
}
impl ops::Deref for NDotProductPoint {
type Target = [f32];
fn deref(&self) -> &Self::Target {
self.0.as_slice()
}
}
impl Point for NDotProductPoint {
fn distance(&self, other: &Self) -> f32 {
let dist = 1.0 - dot_product_similarity(&self.0, &other.0);
debug_assert!(!dist.is_nan());
dist
}
}
/// Returns the dot product similarity score that will between 0.0 and 1.0
/// if both vectors are normalized. The higher the more similar the vectors are.
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
}

View File

@@ -61,6 +61,10 @@ pub enum InternalError {
AbortedIndexation, AbortedIndexation,
#[error("The matching words list contains at least one invalid member.")] #[error("The matching words list contains at least one invalid member.")]
InvalidMatchingWords, InvalidMatchingWords,
#[error(transparent)]
ArroyError(#[from] arroy::Error),
#[error(transparent)]
VectorEmbeddingError(#[from] crate::vector::Error),
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
@@ -110,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
InvalidGeoField(#[from] GeoError), InvalidGeoField(#[from] GeoError),
#[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)]
InvalidVectorDimensions { expected: usize, found: usize }, InvalidVectorDimensions { expected: usize, found: usize },
#[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
InvalidVectorsType { document_id: Value, value: Value }, InvalidVectorsType { document_id: Value, value: Value, subfield: String },
#[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
InvalidVectorsMapType { document_id: Value, value: Value },
#[error("{0}")] #[error("{0}")]
InvalidFilter(String), InvalidFilter(String),
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
@@ -180,6 +186,49 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
UnknownInternalDocumentId { document_id: DocumentId }, UnknownInternalDocumentId { document_id: DocumentId },
#[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")]
InvalidMinTypoWordLenSetting(u8, u8), InvalidMinTypoWordLenSetting(u8, u8),
#[error(transparent)]
VectorEmbeddingError(#[from] crate::vector::Error),
#[error(transparent)]
MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
#[error(transparent)]
InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
#[error("Invalid prompt in for embeddings with name '{0}': {1}.")]
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
#[error("Too many embedders in the configuration. Found {0}, but limited to 256.")]
TooManyEmbedders(usize),
#[error("Cannot find embedder with name {0}.")]
InvalidEmbedder(String),
#[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")]
TooManyVectors(String, usize),
}
impl From<crate::vector::Error> for Error {
fn from(value: crate::vector::Error) -> Self {
match value.fault() {
FaultSource::User => Error::UserError(value.into()),
FaultSource::Runtime => Error::InternalError(value.into()),
FaultSource::Bug => Error::InternalError(value.into()),
FaultSource::Undecided => Error::InternalError(value.into()),
}
}
}
impl From<arroy::Error> for Error {
fn from(value: arroy::Error) -> Self {
match value {
arroy::Error::Heed(heed) => heed.into(),
arroy::Error::Io(io) => io.into(),
arroy::Error::InvalidVecDimension { expected, received } => {
Error::UserError(UserError::InvalidVectorDimensions { expected, found: received })
}
arroy::Error::DatabaseFull
| arroy::Error::InvalidItemAppend
| arroy::Error::UnmatchingDistance { .. }
| arroy::Error::MissingMetadata => {
Error::InternalError(InternalError::ArroyError(value))
}
}
}
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
@@ -336,6 +385,26 @@ impl From<HeedError> for Error {
} }
} }
#[derive(Debug, Clone, Copy)]
pub enum FaultSource {
User,
Runtime,
Bug,
Undecided,
}
impl std::fmt::Display for FaultSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
FaultSource::User => "user error",
FaultSource::Runtime => "runtime error",
FaultSource::Bug => "coding error",
FaultSource::Undecided => "error",
};
f.write_str(s)
}
}
#[test] #[test]
fn conditionally_lookup_for_error_message() { fn conditionally_lookup_for_error_message() {
let prefix = "Attribute `name` is not sortable."; let prefix = "Attribute `name` is not sortable.";

View File

@@ -10,7 +10,6 @@ use roaring::RoaringBitmap;
use rstar::RTree; use rstar::RTree;
use time::OffsetDateTime; use time::OffsetDateTime;
use crate::distance::NDotProductPoint;
use crate::documents::PrimaryKey; use crate::documents::PrimaryKey;
use crate::error::{InternalError, UserError}; use crate::error::{InternalError, UserError};
use crate::fields_ids_map::FieldsIdsMap; use crate::fields_ids_map::FieldsIdsMap;
@@ -22,7 +21,7 @@ use crate::heed_codec::{
BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec, BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec,
}; };
use crate::proximity::ProximityPrecision; use crate::proximity::ProximityPrecision;
use crate::readable_slices::ReadableSlices; use crate::vector::EmbeddingConfig;
use crate::{ use crate::{
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
@@ -30,9 +29,6 @@ use crate::{
BEU32, BEU64, BEU32, BEU64,
}; };
/// The HNSW data-structure that we serialize, fill and search in.
pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>;
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
@@ -48,10 +44,6 @@ pub mod main_key {
pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
pub const GEO_RTREE_KEY: &str = "geo-rtree"; pub const GEO_RTREE_KEY: &str = "geo-rtree";
/// The prefix of the key that is used to store the, potential big, HNSW structure.
/// It is concatenated with a big-endian encoded number (non-human readable).
/// e.g. vector-hnsw0x0032.
pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw";
pub const PRIMARY_KEY_KEY: &str = "primary-key"; pub const PRIMARY_KEY_KEY: &str = "primary-key";
pub const SEARCHABLE_FIELDS_KEY: &str = "searchable-fields"; pub const SEARCHABLE_FIELDS_KEY: &str = "searchable-fields";
pub const USER_DEFINED_SEARCHABLE_FIELDS_KEY: &str = "user-defined-searchable-fields"; pub const USER_DEFINED_SEARCHABLE_FIELDS_KEY: &str = "user-defined-searchable-fields";
@@ -74,6 +66,7 @@ pub mod main_key {
pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by";
pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits";
pub const PROXIMITY_PRECISION: &str = "proximity-precision"; pub const PROXIMITY_PRECISION: &str = "proximity-precision";
pub const EMBEDDING_CONFIGS: &str = "embedding_configs";
} }
pub mod db_name { pub mod db_name {
@@ -99,7 +92,8 @@ pub mod db_name {
pub const FACET_ID_STRING_FST: &str = "facet-id-string-fst"; pub const FACET_ID_STRING_FST: &str = "facet-id-string-fst";
pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
pub const VECTOR_ID_DOCID: &str = "vector-id-docids"; pub const VECTOR_EMBEDDER_CATEGORY_ID: &str = "vector-embedder-category-id";
pub const VECTOR_ARROY: &str = "vector-arroy";
pub const DOCUMENTS: &str = "documents"; pub const DOCUMENTS: &str = "documents";
pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids"; pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
} }
@@ -166,8 +160,10 @@ pub struct Index {
/// Maps the document id, the facet field id and the strings. /// Maps the document id, the facet field id and the strings.
pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>, pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,
/// Maps a vector id to the document id that have it. /// Maps an embedder name to its id in the arroy store.
pub vector_id_docid: Database<BEU32, BEU32>, pub embedder_category_id: Database<Str, U8>,
/// Vector store based on arroy™.
pub vector_arroy: arroy::Database<arroy::distances::Angular>,
/// Maps the document id to the document as an obkv store. /// Maps the document id to the document as an obkv store.
pub(crate) documents: Database<BEU32, ObkvCodec>, pub(crate) documents: Database<BEU32, ObkvCodec>,
@@ -182,7 +178,7 @@ impl Index {
) -> Result<Index> { ) -> Result<Index> {
use db_name::*; use db_name::*;
options.max_dbs(24); options.max_dbs(25);
let env = options.open(path)?; let env = options.open(path)?;
let mut wtxn = env.write_txn()?; let mut wtxn = env.write_txn()?;
@@ -222,7 +218,11 @@ impl Index {
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?; env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
let field_id_docid_facet_strings = let field_id_docid_facet_strings =
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?; // vector stuff
let embedder_category_id =
env.create_database(&mut wtxn, Some(VECTOR_EMBEDDER_CATEGORY_ID))?;
let vector_arroy = env.create_database(&mut wtxn, Some(VECTOR_ARROY))?;
let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?; let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;
wtxn.commit()?; wtxn.commit()?;
@@ -252,7 +252,8 @@ impl Index {
facet_id_is_empty_docids, facet_id_is_empty_docids,
field_id_docid_facet_f64s, field_id_docid_facet_f64s,
field_id_docid_facet_strings, field_id_docid_facet_strings,
vector_id_docid, vector_arroy,
embedder_category_id,
documents, documents,
}) })
} }
@@ -475,63 +476,6 @@ impl Index {
None => Ok(RoaringBitmap::new()), None => Ok(RoaringBitmap::new()),
} }
} }
/* vector HNSW */
/// Writes the provided `hnsw`.
pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> {
// We must delete all the chunks before we write the new HNSW chunks.
self.delete_vector_hnsw(wtxn)?;
let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB
let bytes = bincode::serialize(hnsw).map_err(Into::into).map_err(heed::Error::Encoding)?;
for (i, chunk) in bytes.chunks(chunk_size).enumerate() {
let i = i as u32;
let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec();
key.extend_from_slice(&i.to_be_bytes());
self.main.remap_types::<Bytes, Bytes>().put(wtxn, &key, chunk)?;
}
Ok(())
}
/// Delete the `hnsw`.
pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
let mut iter = self
.main
.remap_types::<Bytes, DecodeIgnore>()
.prefix_iter_mut(wtxn, main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes())?;
let mut deleted = false;
while iter.next().transpose()?.is_some() {
// We do not keep a reference to the key or the value.
unsafe { deleted |= iter.del_current()? };
}
Ok(deleted)
}
/// Returns the `hnsw`.
pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> {
let mut slices = Vec::new();
for result in self
.main
.remap_types::<Str, Bytes>()
.prefix_iter(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)?
{
let (_, slice) = result?;
slices.push(slice);
}
if slices.is_empty() {
Ok(None)
} else {
let readable_slices: ReadableSlices<_> = slices.into_iter().collect();
Ok(Some(
bincode::deserialize_from(readable_slices)
.map_err(Into::into)
.map_err(heed::Error::Decoding)?,
))
}
}
/* field distribution */ /* field distribution */
/// Writes the field distribution which associates every field name with /// Writes the field distribution which associates every field name with
@@ -1528,6 +1472,41 @@ impl Index {
Ok(script_language) Ok(script_language)
} }
pub(crate) fn put_embedding_configs(
&self,
wtxn: &mut RwTxn<'_>,
configs: Vec<(String, EmbeddingConfig)>,
) -> heed::Result<()> {
self.main.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>().put(
wtxn,
main_key::EMBEDDING_CONFIGS,
&configs,
)
}
pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> {
self.main.remap_key_type::<Str>().delete(wtxn, main_key::EMBEDDING_CONFIGS)
}
pub fn embedding_configs(
&self,
rtxn: &RoTxn<'_>,
) -> Result<Vec<(String, crate::vector::EmbeddingConfig)>> {
Ok(self
.main
.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>()
.get(rtxn, main_key::EMBEDDING_CONFIGS)?
.unwrap_or_default())
}
pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result<String> {
let configs = self.embedding_configs(rtxn)?;
Ok(match configs.as_slice() {
[(ref first_name, _)] => first_name.clone(),
_ => "default".to_owned(),
})
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -10,18 +10,18 @@ pub mod documents;
mod asc_desc; mod asc_desc;
mod criterion; mod criterion;
pub mod distance;
mod error; mod error;
mod external_documents_ids; mod external_documents_ids;
pub mod facet; pub mod facet;
mod fields_ids_map; mod fields_ids_map;
pub mod heed_codec; pub mod heed_codec;
pub mod index; pub mod index;
pub mod prompt;
pub mod proximity; pub mod proximity;
mod readable_slices;
pub mod score_details; pub mod score_details;
mod search; mod search;
pub mod update; pub mod update;
pub mod vector;
#[cfg(test)] #[cfg(test)]
#[macro_use] #[macro_use]
@@ -32,13 +32,12 @@ use std::convert::{TryFrom, TryInto};
use std::hash::BuildHasherDefault; use std::hash::BuildHasherDefault;
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
pub use distance::dot_product_similarity;
pub use filter_parser::{Condition, FilterCondition, Span, Token}; pub use filter_parser::{Condition, FilterCondition, Span, Token};
use fxhash::{FxHasher32, FxHasher64}; use fxhash::{FxHasher32, FxHasher64};
pub use grenad::CompressionType; pub use grenad::CompressionType;
pub use search::new::{ pub use search::new::{
execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger, execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext,
VisualSearchLogger, SearchLogger, VisualSearchLogger,
}; };
use serde_json::Value; use serde_json::Value;
pub use {charabia as tokenizer, heed}; pub use {charabia as tokenizer, heed};

View File

@@ -0,0 +1,97 @@
use liquid::model::{
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{ObjectView, ValueView};
use super::document::Document;
use super::fields::Fields;
use crate::FieldsIdsMap;
#[derive(Debug, Clone)]
pub struct Context<'a> {
document: &'a Document<'a>,
fields: Fields<'a>,
}
impl<'a> Context<'a> {
pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self {
Self { document, fields: Fields::new(document, field_id_map) }
}
}
impl<'a> ObjectView for Context<'a> {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(
std::iter::once(self.document.as_value())
.chain(std::iter::once(self.fields.as_value())),
)
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "doc" || index == "fields"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
match index {
"doc" => Some(self.document.as_value()),
"fields" => Some(self.fields.as_value()),
_ => None,
}
}
}
impl<'a> ValueView for Context<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => false,
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(
self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(),
)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}

View File

@@ -0,0 +1,131 @@
use std::cell::OnceCell;
use std::collections::BTreeMap;
use liquid::model::{
DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{ObjectView, ValueView};
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
use crate::FieldsIdsMap;
#[derive(Debug, Clone)]
pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>);
#[derive(Debug, Clone)]
struct ParsedValue(std::cell::OnceCell<LiquidValue>);
impl ParsedValue {
fn empty() -> ParsedValue {
ParsedValue(OnceCell::new())
}
fn get(&self, raw: &[u8]) -> &LiquidValue {
self.0.get_or_init(|| {
let value: serde_json::Value = serde_json::from_slice(raw).unwrap();
liquid::model::to_value(&value).unwrap()
})
}
}
impl<'a> Document<'a> {
pub fn new(
data: obkv::KvReaderU16<'a>,
side: DelAdd,
inverted_field_map: &'a FieldsIdsMap,
) -> Self {
let mut out_data = BTreeMap::new();
for (fid, raw) in data {
let obkv = KvReaderDelAdd::new(raw);
let Some(raw) = obkv.get(side) else {
continue;
};
let Some(name) = inverted_field_map.name(fid) else {
continue;
};
out_data.insert(name, (raw, ParsedValue::empty()));
}
Self(out_data)
}
fn is_empty(&self) -> bool {
self.0.is_empty()
}
fn len(&self) -> usize {
self.0.len()
}
fn iter(&self) -> impl Iterator<Item = (KString, LiquidValue)> + '_ {
self.0.iter().map(|(&k, (raw, data))| (k.to_owned().into(), data.get(raw).to_owned()))
}
}
impl<'a> ObjectView for Document<'a> {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
self.len() as i64
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
let keys = BTreeMap::keys(&self.0).map(|&s| s.into());
Box::new(keys)
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(self.0.values().map(|(raw, v)| v.get(raw) as &dyn ValueView))
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.0.iter().map(|(&k, (raw, data))| (k.into(), data.get(raw) as &dyn ValueView)))
}
fn contains_key(&self, index: &str) -> bool {
self.0.contains_key(index)
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
self.0.get(index).map(|(raw, v)| v.get(raw) as &dyn ValueView)
}
}
impl<'a> ValueView for Document<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => self.is_empty(),
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(self.iter().collect())
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}

56
milli/src/prompt/error.rs Normal file
View File

@@ -0,0 +1,56 @@
use crate::error::FaultSource;
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct NewPromptError {
pub kind: NewPromptErrorKind,
pub fault: FaultSource,
}
impl From<NewPromptError> for crate::Error {
fn from(value: NewPromptError) -> Self {
crate::Error::UserError(crate::UserError::InvalidPrompt(value))
}
}
impl NewPromptError {
pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError {
Self { kind: NewPromptErrorKind::CannotParseTemplate(inner), fault: FaultSource::User }
}
pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError {
Self { kind: NewPromptErrorKind::InvalidFieldsInTemplate(inner), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
pub enum NewPromptErrorKind {
#[error("cannot parse template: {0}")]
CannotParseTemplate(liquid::Error),
#[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")]
InvalidFieldsInTemplate(liquid::Error),
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct RenderPromptError {
pub kind: RenderPromptErrorKind,
pub fault: FaultSource,
}
impl RenderPromptError {
pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError {
Self { kind: RenderPromptErrorKind::MissingContext(inner), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
pub enum RenderPromptErrorKind {
#[error("missing field in document: {0}")]
MissingContext(liquid::Error),
}
impl From<RenderPromptError> for crate::Error {
fn from(value: RenderPromptError) -> Self {
crate::Error::UserError(crate::UserError::MissingDocumentField(value))
}
}

172
milli/src/prompt/fields.rs Normal file
View File

@@ -0,0 +1,172 @@
use liquid::model::{
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{ObjectView, ValueView};
use super::document::Document;
use crate::FieldsIdsMap;
#[derive(Debug, Clone)]
pub struct Fields<'a>(Vec<FieldValue<'a>>);
impl<'a> Fields<'a> {
pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self {
Self(
std::iter::repeat(document)
.zip(field_id_map.iter())
.map(|(document, (_fid, name))| FieldValue { document, name })
.collect(),
)
}
}
#[derive(Debug, Clone, Copy)]
pub struct FieldValue<'a> {
name: &'a str,
document: &'a Document<'a>,
}
impl<'a> ValueView for FieldValue<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => self.is_empty(),
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(
self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(),
)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}
impl<'a> FieldValue<'a> {
pub fn name(&self) -> &&'a str {
&self.name
}
pub fn value(&self) -> &dyn ValueView {
self.document.get(self.name).unwrap_or(&LiquidValue::Nil)
}
pub fn is_empty(&self) -> bool {
self.size() == 0
}
}
impl<'a> ObjectView for FieldValue<'a> {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(
std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())),
)
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "name" || index == "value"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
match index {
"name" => Some(self.name()),
"value" => Some(self.value()),
_ => None,
}
}
}
impl<'a> ArrayView for Fields<'a> {
fn as_value(&self) -> &dyn ValueView {
self.0.as_value()
}
fn size(&self) -> i64 {
self.0.len() as i64
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
self.0.values()
}
fn contains_key(&self, index: i64) -> bool {
self.0.contains_key(index)
}
fn get(&self, index: i64) -> Option<&dyn ValueView> {
ArrayView::get(&self.0, index)
}
}
impl<'a> ValueView for Fields<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
self.0.render()
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
self.0.source()
}
fn type_name(&self) -> &'static str {
self.0.type_name()
}
fn query_state(&self, state: liquid::model::State) -> bool {
self.0.query_state(state)
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
self.0.to_kstr()
}
fn to_value(&self) -> LiquidValue {
self.0.to_value()
}
fn as_array(&self) -> Option<&dyn ArrayView> {
Some(self)
}
}

176
milli/src/prompt/mod.rs Normal file
View File

@@ -0,0 +1,176 @@
mod context;
mod document;
pub(crate) mod error;
mod fields;
mod template_checker;
use std::convert::TryFrom;
use error::{NewPromptError, RenderPromptError};
use self::context::Context;
use self::document::Document;
use crate::update::del_add::DelAdd;
use crate::FieldsIdsMap;
pub struct Prompt {
template: liquid::Template,
template_text: String,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PromptData {
pub template: String,
}
impl From<Prompt> for PromptData {
fn from(value: Prompt) -> Self {
Self { template: value.template_text }
}
}
impl TryFrom<PromptData> for Prompt {
type Error = NewPromptError;
fn try_from(value: PromptData) -> Result<Self, Self::Error> {
Prompt::new(value.template)
}
}
impl Clone for Prompt {
fn clone(&self) -> Self {
let template_text = self.template_text.clone();
Self { template: new_template(&template_text).unwrap(), template_text }
}
}
fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> {
liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text)
}
fn default_template() -> liquid::Template {
new_template(default_template_text()).unwrap()
}
fn default_template_text() -> &'static str {
"{% for field in fields %} \
{{ field.name }}: {{ field.value }}\n\
{% endfor %}"
}
impl Default for Prompt {
fn default() -> Self {
Self { template: default_template(), template_text: default_template_text().into() }
}
}
impl Default for PromptData {
fn default() -> Self {
Self { template: default_template_text().into() }
}
}
impl Prompt {
pub fn new(template: String) -> Result<Self, NewPromptError> {
let this = Self {
template: liquid::ParserBuilder::with_stdlib()
.build()
.unwrap()
.parse(&template)
.map_err(NewPromptError::cannot_parse_template)?,
template_text: template,
};
// render template with special object that's OK with `doc.*` and `fields.*`
this.template
.render(&template_checker::TemplateChecker)
.map_err(NewPromptError::invalid_fields_in_template)?;
Ok(this)
}
pub fn render(
&self,
document: obkv::KvReaderU16<'_>,
side: DelAdd,
field_id_map: &FieldsIdsMap,
) -> Result<String, RenderPromptError> {
let document = Document::new(document, side, field_id_map);
let context = Context::new(&document, field_id_map);
self.template.render(&context).map_err(RenderPromptError::missing_context)
}
}
#[cfg(test)]
mod test {
use super::Prompt;
use crate::error::FaultSource;
use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
#[test]
fn default_template() {
// does not panic
Prompt::default();
}
#[test]
fn empty_template() {
Prompt::new("".into()).unwrap();
}
#[test]
fn template_ok() {
Prompt::new("{{doc.title}}: {{doc.overview}}".into()).unwrap();
}
#[test]
fn template_syntax() {
assert!(matches!(
Prompt::new("{{doc.title: {{doc.overview}}".into()),
Err(NewPromptError {
kind: NewPromptErrorKind::CannotParseTemplate(_),
fault: FaultSource::User
})
));
}
#[test]
fn template_missing_doc() {
assert!(matches!(
Prompt::new("{{title}}: {{overview}}".into()),
Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User
})
));
}
#[test]
fn template_nested_doc() {
Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into()).unwrap();
}
#[test]
fn template_fields() {
Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into()).unwrap();
}
#[test]
fn template_fields_ok() {
Prompt::new("{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into())
.unwrap();
}
#[test]
fn template_fields_invalid() {
assert!(matches!(
// intentionally garbled field
Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into()),
Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User
})
));
}
}

View File

@@ -0,0 +1,301 @@
use liquid::model::{
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{Object, ObjectView, ValueView};
#[derive(Debug)]
pub struct TemplateChecker;
#[derive(Debug)]
pub struct DummyDoc;
#[derive(Debug)]
pub struct DummyFields;
#[derive(Debug)]
pub struct DummyField;
const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil;
impl ObjectView for DummyField {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["name", "value"].iter().map(|s| KStringCow::from_static(s)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(vec![DUMMY_VALUE.as_view(), DUMMY_VALUE.as_view()].into_iter())
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "name" || index == "value"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
if self.contains_key(index) {
Some(DUMMY_VALUE.as_view())
} else {
None
}
}
}
impl ValueView for DummyField {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> DisplayCow<'_> {
DUMMY_VALUE.render()
}
fn source(&self) -> DisplayCow<'_> {
DUMMY_VALUE.source()
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
}
fn to_kstr(&self) -> KStringCow<'_> {
DUMMY_VALUE.to_kstr()
}
fn to_value(&self) -> LiquidValue {
let mut this = Object::new();
this.insert("name".into(), LiquidValue::Nil);
this.insert("value".into(), LiquidValue::Nil);
LiquidValue::Object(this)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}
impl ValueView for DummyFields {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> DisplayCow<'_> {
DUMMY_VALUE.render()
}
fn source(&self) -> DisplayCow<'_> {
DUMMY_VALUE.source()
}
fn type_name(&self) -> &'static str {
"array"
}
fn query_state(&self, state: State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
}
fn to_kstr(&self) -> KStringCow<'_> {
DUMMY_VALUE.to_kstr()
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Array(vec![DummyField.to_value()])
}
fn as_array(&self) -> Option<&dyn ArrayView> {
Some(self)
}
}
impl ArrayView for DummyFields {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
u16::MAX as i64
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(std::iter::once(DummyField.as_value()))
}
fn contains_key(&self, index: i64) -> bool {
index < self.size()
}
fn get(&self, _index: i64) -> Option<&dyn ValueView> {
Some(DummyField.as_value())
}
}
impl ObjectView for DummyDoc {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
1000
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(std::iter::empty())
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(std::iter::empty())
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(std::iter::empty())
}
fn contains_key(&self, _index: &str) -> bool {
true
}
fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> {
// Recursively sends itself
Some(self)
}
}
impl ValueView for DummyDoc {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> DisplayCow<'_> {
DUMMY_VALUE.render()
}
fn source(&self) -> DisplayCow<'_> {
DUMMY_VALUE.source()
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
}
fn to_kstr(&self) -> KStringCow<'_> {
DUMMY_VALUE.to_kstr()
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Nil
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}
impl ObjectView for TemplateChecker {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(
std::iter::once(DummyDoc.as_value()).chain(std::iter::once(DummyFields.as_value())),
)
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "doc" || index == "fields"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
match index {
"doc" => Some(DummyDoc.as_value()),
"fields" => Some(DummyFields.as_value()),
_ => None,
}
}
}
impl ValueView for TemplateChecker {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => false,
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(
self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(),
)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}

View File

@@ -1,85 +0,0 @@
use std::io::{self, Read};
use std::iter::FromIterator;
pub struct ReadableSlices<A> {
inner: Vec<A>,
pos: u64,
}
impl<A> FromIterator<A> for ReadableSlices<A> {
fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
ReadableSlices { inner: iter.into_iter().collect(), pos: 0 }
}
}
impl<A: AsRef<[u8]>> Read for ReadableSlices<A> {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
let original_buf_len = buf.len();
// We explore the list of slices to find the one where we must start reading.
let mut pos = self.pos;
let index = match self
.inner
.iter()
.map(|s| s.as_ref().len() as u64)
.position(|size| pos.checked_sub(size).map(|p| pos = p).is_none())
{
Some(index) => index,
None => return Ok(0),
};
let mut inner_pos = pos as usize;
for slice in &self.inner[index..] {
let slice = &slice.as_ref()[inner_pos..];
if buf.len() > slice.len() {
// We must exhaust the current slice and go to the next one there is not enough here.
buf[..slice.len()].copy_from_slice(slice);
buf = &mut buf[slice.len()..];
inner_pos = 0;
} else {
// There is enough in this slice to fill the remaining bytes of the buffer.
// Let's break just after filling it.
buf.copy_from_slice(&slice[..buf.len()]);
buf = &mut [];
break;
}
}
let written = original_buf_len - buf.len();
self.pos += written as u64;
Ok(written)
}
}
#[cfg(test)]
mod test {
use std::io::Read;
use super::ReadableSlices;
#[test]
fn basic() {
let data: Vec<_> = (0..100).collect();
let splits: Vec<_> = data.chunks(3).collect();
let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
let mut output = Vec::new();
let length = rdslices.read_to_end(&mut output).unwrap();
assert_eq!(length, data.len());
assert_eq!(output, data);
}
#[test]
fn small_reads() {
let data: Vec<_> = (0..u8::MAX).collect();
let splits: Vec<_> = data.chunks(27).collect();
let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
let buffer = &mut [0; 45];
let length = rdslices.read(buffer).unwrap();
let expected: Vec<_> = (0..buffer.len() as u8).collect();
assert_eq!(length, buffer.len());
assert_eq!(buffer, &expected[..]);
}
}

View File

@@ -1,3 +1,6 @@
use std::cmp::Ordering;
use itertools::Itertools;
use serde::Serialize; use serde::Serialize;
use crate::distance_between_two_points; use crate::distance_between_two_points;
@@ -12,9 +15,24 @@ pub enum ScoreDetails {
ExactAttribute(ExactAttribute), ExactAttribute(ExactAttribute),
ExactWords(ExactWords), ExactWords(ExactWords),
Sort(Sort), Sort(Sort),
Vector(Vector),
GeoSort(GeoSort), GeoSort(GeoSort),
} }
#[derive(Clone, Copy)]
pub enum ScoreValue<'a> {
Score(f64),
Sort(&'a Sort),
GeoSort(&'a GeoSort),
}
enum RankOrValue<'a> {
Rank(Rank),
Sort(&'a Sort),
GeoSort(&'a GeoSort),
Score(f64),
}
impl ScoreDetails { impl ScoreDetails {
pub fn local_score(&self) -> Option<f64> { pub fn local_score(&self) -> Option<f64> {
self.rank().map(Rank::local_score) self.rank().map(Rank::local_score)
@@ -31,11 +49,55 @@ impl ScoreDetails {
ScoreDetails::ExactWords(details) => Some(details.rank()), ScoreDetails::ExactWords(details) => Some(details.rank()),
ScoreDetails::Sort(_) => None, ScoreDetails::Sort(_) => None,
ScoreDetails::GeoSort(_) => None, ScoreDetails::GeoSort(_) => None,
ScoreDetails::Vector(_) => None,
} }
} }
pub fn global_score<'a>(details: impl Iterator<Item = &'a Self>) -> f64 { pub fn global_score<'a>(details: impl Iterator<Item = &'a Self> + 'a) -> f64 {
Rank::global_score(details.filter_map(Self::rank)) Self::score_values(details)
.find_map(|x| {
let ScoreValue::Score(score) = x else {
return None;
};
Some(score)
})
.unwrap_or(1.0f64)
}
pub fn score_values<'a>(
details: impl Iterator<Item = &'a Self> + 'a,
) -> impl Iterator<Item = ScoreValue<'a>> + 'a {
details
.map(ScoreDetails::rank_or_value)
.coalesce(|left, right| match (left, right) {
(RankOrValue::Rank(left), RankOrValue::Rank(right)) => {
Ok(RankOrValue::Rank(Rank::merge(left, right)))
}
(left, right) => Err((left, right)),
})
.map(|rank_or_value| match rank_or_value {
RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()),
RankOrValue::Sort(s) => ScoreValue::Sort(s),
RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g),
RankOrValue::Score(s) => ScoreValue::Score(s),
})
}
fn rank_or_value(&self) -> RankOrValue<'_> {
match self {
ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()),
ScoreDetails::Typo(t) => RankOrValue::Rank(t.rank()),
ScoreDetails::Proximity(p) => RankOrValue::Rank(*p),
ScoreDetails::Fid(f) => RankOrValue::Rank(*f),
ScoreDetails::Position(p) => RankOrValue::Rank(*p),
ScoreDetails::ExactAttribute(e) => RankOrValue::Rank(e.rank()),
ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()),
ScoreDetails::Sort(sort) => RankOrValue::Sort(sort),
ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort),
ScoreDetails::Vector(vector) => RankOrValue::Score(
vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64),
),
}
} }
/// Panics /// Panics
@@ -181,6 +243,19 @@ impl ScoreDetails {
details_map.insert(sort, sort_details); details_map.insert(sort, sort_details);
order += 1; order += 1;
} }
ScoreDetails::Vector(s) => {
let vector = format!("vectorSort({:?})", s.target_vector);
let value = s.value_similarity.as_ref().map(|(v, _)| v);
let similarity = s.value_similarity.as_ref().map(|(_, s)| s);
let details = serde_json::json!({
"order": order,
"value": value,
"similarity": similarity,
});
details_map.insert(vector, details);
order += 1;
}
} }
} }
details_map details_map
@@ -297,15 +372,21 @@ impl Rank {
pub fn global_score(details: impl Iterator<Item = Self>) -> f64 { pub fn global_score(details: impl Iterator<Item = Self>) -> f64 {
let mut rank = Rank { rank: 1, max_rank: 1 }; let mut rank = Rank { rank: 1, max_rank: 1 };
for inner_rank in details { for inner_rank in details {
rank.rank -= 1; rank = Rank::merge(rank, inner_rank);
rank.rank *= inner_rank.max_rank;
rank.max_rank *= inner_rank.max_rank;
rank.rank += inner_rank.rank;
} }
rank.local_score() rank.local_score()
} }
pub fn merge(mut outer: Rank, inner: Rank) -> Rank {
outer.rank = outer.rank.saturating_sub(1);
outer.rank *= inner.max_rank;
outer.max_rank *= inner.max_rank;
outer.rank += inner.rank;
outer
}
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
@@ -335,13 +416,78 @@ pub struct Sort {
pub value: serde_json::Value, pub value: serde_json::Value,
} }
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] impl PartialOrd for Sort {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.field_name != other.field_name {
return None;
}
if self.ascending != other.ascending {
return None;
}
match (&self.value, &other.value) {
(serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal),
(serde_json::Value::Null, _) => Some(Ordering::Less),
(_, serde_json::Value::Null) => Some(Ordering::Greater),
// numbers are always before strings
(serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater),
(serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less),
(serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
// FIXME: unwrap permitted here?
let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?;
// 12 < 42, and when ascending, we want to see 12 first, so the smallest.
// Hence, when ascending, smaller is better
Some(if self.ascending { order.reverse() } else { order })
}
(serde_json::Value::String(left), serde_json::Value::String(right)) => {
let order = left.cmp(right);
// Taking e.g. "a" and "z"
// "a" < "z", and when ascending, we want to see "a" first, so the smallest.
// Hence, when ascending, smaller is better
Some(if self.ascending { order.reverse() } else { order })
}
_ => None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoSort { pub struct GeoSort {
pub target_point: [f64; 2], pub target_point: [f64; 2],
pub ascending: bool, pub ascending: bool,
pub value: Option<[f64; 2]>, pub value: Option<[f64; 2]>,
} }
impl PartialOrd for GeoSort {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.target_point != other.target_point {
return None;
}
if self.ascending != other.ascending {
return None;
}
Some(match (self.distance(), other.distance()) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(Some(left), Some(right)) => {
let order = left.partial_cmp(&right)?;
if self.ascending {
// when ascending, the one with the smallest distance has the best score
order.reverse()
} else {
order
}
}
})
}
}
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct Vector {
pub target_vector: Vec<f32>,
pub value_similarity: Option<(Vec<f32>, f32)>,
}
impl GeoSort { impl GeoSort {
pub fn distance(&self) -> Option<f64> { pub fn distance(&self) -> Option<f64> {
self.value.map(|value| distance_between_two_points(&self.target_point, &value)) self.value.map(|value| distance_between_two_points(&self.target_point, &value))

183
milli/src/search/hybrid.rs Normal file
View File

@@ -0,0 +1,183 @@
use std::cmp::Ordering;
use itertools::Itertools;
use roaring::RoaringBitmap;
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::{MatchingWords, Result, Search, SearchResult};
struct ScoreWithRatioResult {
matching_words: MatchingWords,
candidates: RoaringBitmap,
document_scores: Vec<(u32, ScoreWithRatio)>,
}
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
fn compare_scores(
&(ref left_scores, left_ratio): &ScoreWithRatio,
&(ref right_scores, right_ratio): &ScoreWithRatio,
) -> Ordering {
let mut left_it = ScoreDetails::score_values(left_scores.iter());
let mut right_it = ScoreDetails::score_values(right_scores.iter());
loop {
let left = left_it.next();
let right = right_it.next();
match (left, right) {
(None, None) => return Ordering::Equal,
(None, Some(_)) => return Ordering::Less,
(Some(_), None) => return Ordering::Greater,
(Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
let left = left * left_ratio as f64;
let right = right * right_ratio as f64;
if (left - right).abs() <= f64::EPSILON {
continue;
}
return left.partial_cmp(&right).unwrap();
}
(Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => {
match left.partial_cmp(right).unwrap() {
Ordering::Equal => continue,
order => return order,
}
}
(Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => {
match left.partial_cmp(right).unwrap() {
Ordering::Equal => continue,
order => return order,
}
}
(Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater,
(Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less,
// if we have this, we're bad
(Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
| (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
unreachable!("Unexpected geo and sort comparison")
}
}
}
}
impl ScoreWithRatioResult {
fn new(results: SearchResult, ratio: f32) -> Self {
let document_scores = results
.documents_ids
.into_iter()
.zip(results.document_scores.into_iter().map(|scores| (scores, ratio)))
.collect();
Self {
matching_words: results.matching_words,
candidates: results.candidates,
document_scores,
}
}
fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult {
let mut documents_ids =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut document_scores =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut documents_seen = RoaringBitmap::new();
for (docid, (main_score, _sub_score)) in left
.document_scores
.into_iter()
.merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| {
// the first value is the one with the greatest score
compare_scores(left, right).is_ge()
})
// remove documents we already saw
.filter(|(docid, _)| documents_seen.insert(*docid))
// start skipping **after** the filter
.skip(from)
// take **after** skipping
.take(length)
{
documents_ids.push(docid);
// TODO: pass both scores to documents_score in some way?
document_scores.push(main_score);
}
SearchResult {
matching_words: left.matching_words,
candidates: left.candidates | right.candidates,
documents_ids,
document_scores,
}
}
}
impl<'a> Search<'a> {
pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> {
// TODO: find classier way to achieve that than to reset vector and query params
// create separate keyword and semantic searches
let mut search = Search {
query: self.query.clone(),
vector: self.vector.clone(),
filter: self.filter.clone(),
offset: 0,
limit: self.limit + self.offset,
sort_criteria: self.sort_criteria.clone(),
searchable_attributes: self.searchable_attributes,
geo_strategy: self.geo_strategy,
terms_matching_strategy: self.terms_matching_strategy,
scoring_strategy: ScoringStrategy::Detailed,
words_limit: self.words_limit,
exhaustive_number_hits: self.exhaustive_number_hits,
rtxn: self.rtxn,
index: self.index,
distribution_shift: self.distribution_shift,
embedder_name: self.embedder_name.clone(),
};
let vector_query = search.vector.take();
let keyword_results = search.execute()?;
// skip semantic search if we don't have a vector query (placeholder search)
let Some(vector_query) = vector_query else {
return Ok(keyword_results);
};
// completely skip semantic search if the results of the keyword search are good enough
if self.results_good_enough(&keyword_results, semantic_ratio) {
return Ok(keyword_results);
}
search.vector = Some(vector_query);
search.query = None;
// TODO: would be better to have two distinct functions at this point
let vector_results = search.execute()?;
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
let merge_results =
ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit);
assert!(merge_results.documents_ids.len() <= self.limit);
Ok(merge_results)
}
fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool {
// A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5
const GOOD_ENOUGH_SCORE: f64 = 0.45;
// 1. we check that we got a sufficient number of results
if keyword_results.document_scores.len() < self.limit + self.offset {
return false;
}
// 2. and that all results have a good enough score.
// we need to check all results because due to sort like rules, they're not necessarily in relevancy order
for score in &keyword_results.document_scores {
let score = ScoreDetails::global_score(score.iter());
if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE {
return false;
}
}
true
}
}

View File

@@ -12,12 +12,14 @@ use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::PartialSearchResult; use self::new::{execute_vector_search, PartialSearchResult};
use crate::error::UserError; use crate::error::UserError;
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::DistributionShift;
use crate::{ use crate::{
execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext, execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
Result, SearchContext,
}; };
// Building these factories is not free. // Building these factories is not free.
@@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100;
pub mod facet; pub mod facet;
mod fst_utils; mod fst_utils;
pub mod hybrid;
pub mod new; pub mod new;
pub struct Search<'a> { pub struct Search<'a> {
@@ -46,8 +49,11 @@ pub struct Search<'a> {
scoring_strategy: ScoringStrategy, scoring_strategy: ScoringStrategy,
words_limit: usize, words_limit: usize,
exhaustive_number_hits: bool, exhaustive_number_hits: bool,
/// TODO: Add semantic ratio or pass it directly to execute_hybrid()
rtxn: &'a heed::RoTxn<'a>, rtxn: &'a heed::RoTxn<'a>,
index: &'a Index, index: &'a Index,
distribution_shift: Option<DistributionShift>,
embedder_name: Option<String>,
} }
impl<'a> Search<'a> { impl<'a> Search<'a> {
@@ -67,6 +73,8 @@ impl<'a> Search<'a> {
words_limit: 10, words_limit: 10,
rtxn, rtxn,
index, index,
distribution_shift: None,
embedder_name: None,
} }
} }
@@ -75,8 +83,8 @@ impl<'a> Search<'a> {
self self
} }
pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> { pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> {
self.vector = Some(vector.into()); self.vector = Some(vector);
self self
} }
@@ -133,22 +141,66 @@ impl<'a> Search<'a> {
self self
} }
pub fn distribution_shift(
&mut self,
distribution_shift: Option<DistributionShift>,
) -> &mut Search<'a> {
self.distribution_shift = distribution_shift;
self
}
pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
self.embedder_name = Some(embedder_name.into());
self
}
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
if has_vector_search {
let ctx = SearchContext::new(self.index, self.rtxn);
filtered_universe(&ctx, &self.filter)
} else {
Ok(self.execute()?.candidates)
}
}
pub fn execute(&self) -> Result<SearchResult> { pub fn execute(&self) -> Result<SearchResult> {
let embedder_name;
let embedder_name = match &self.embedder_name {
Some(embedder_name) => embedder_name,
None => {
embedder_name = self.index.default_embedding_name(self.rtxn)?;
&embedder_name
}
};
let mut ctx = SearchContext::new(self.index, self.rtxn); let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes { if let Some(searchable_attributes) = self.searchable_attributes {
ctx.searchable_attributes(searchable_attributes)?; ctx.searchable_attributes(searchable_attributes)?;
} }
let universe = filtered_universe(&ctx, &self.filter)?;
let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } = let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
execute_search( match self.vector.as_ref() {
Some(vector) => execute_vector_search(
&mut ctx, &mut ctx,
&self.query, vector,
&self.vector, self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
self.distribution_shift,
embedder_name,
)?,
None => execute_search(
&mut ctx,
self.query.as_deref(),
self.terms_matching_strategy, self.terms_matching_strategy,
self.scoring_strategy, self.scoring_strategy,
self.exhaustive_number_hits, self.exhaustive_number_hits,
&self.filter, universe,
&self.sort_criteria, &self.sort_criteria,
self.geo_strategy, self.geo_strategy,
self.offset, self.offset,
@@ -156,7 +208,8 @@ impl<'a> Search<'a> {
Some(self.words_limit), Some(self.words_limit),
&mut DefaultSearchLogger, &mut DefaultSearchLogger,
&mut DefaultSearchLogger, &mut DefaultSearchLogger,
)?; )?,
};
// consume context and located_query_terms to build MatchingWords. // consume context and located_query_terms to build MatchingWords.
let matching_words = match located_query_terms { let matching_words = match located_query_terms {
@@ -185,6 +238,8 @@ impl fmt::Debug for Search<'_> {
exhaustive_number_hits, exhaustive_number_hits,
rtxn: _, rtxn: _,
index: _, index: _,
distribution_shift,
embedder_name,
} = self; } = self;
f.debug_struct("Search") f.debug_struct("Search")
.field("query", query) .field("query", query)
@@ -198,6 +253,8 @@ impl fmt::Debug for Search<'_> {
.field("scoring_strategy", scoring_strategy) .field("scoring_strategy", scoring_strategy)
.field("exhaustive_number_hits", exhaustive_number_hits) .field("exhaustive_number_hits", exhaustive_number_hits)
.field("words_limit", words_limit) .field("words_limit", words_limit)
.field("distribution_shift", distribution_shift)
.field("embedder_name", embedder_name)
.finish() .finish()
} }
} }
@@ -249,11 +306,16 @@ pub struct SearchForFacetValues<'a> {
query: Option<String>, query: Option<String>,
facet: String, facet: String,
search_query: Search<'a>, search_query: Search<'a>,
is_hybrid: bool,
} }
impl<'a> SearchForFacetValues<'a> { impl<'a> SearchForFacetValues<'a> {
pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> { pub fn new(
SearchForFacetValues { query: None, facet, search_query } facet: String,
search_query: Search<'a>,
is_hybrid: bool,
) -> SearchForFacetValues<'a> {
SearchForFacetValues { query: None, facet, search_query, is_hybrid }
} }
pub fn query(&mut self, query: impl Into<String>) -> &mut Self { pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
@@ -303,7 +365,9 @@ impl<'a> SearchForFacetValues<'a> {
None => return Ok(vec![]), None => return Ok(vec![]),
}; };
let search_candidates = self.search_query.execute()?.candidates; let search_candidates = self
.search_query
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
match self.query.as_ref() { match self.query.as_ref() {
Some(query) => { Some(query) => {

View File

@@ -107,12 +107,16 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
/// Refill the internal buffer of cached docids based on the strategy. /// Refill the internal buffer of cached docids based on the strategy.
/// Drop the rtree if we don't need it anymore. /// Drop the rtree if we don't need it anymore.
fn fill_buffer(&mut self, ctx: &mut SearchContext) -> Result<()> { fn fill_buffer(
&mut self,
ctx: &mut SearchContext,
geo_candidates: &RoaringBitmap,
) -> Result<()> {
debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng"); debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng");
debug_assert!(self.cached_sorted_docids.is_empty()); debug_assert!(self.cached_sorted_docids.is_empty());
// lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree` // lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree`
let rtree = if self.strategy.use_rtree(self.geo_candidates.len() as usize) { let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) {
if let Some(rtree) = self.rtree.as_ref() { if let Some(rtree) = self.rtree.as_ref() {
// get rtree from cache // get rtree from cache
Some(rtree) Some(rtree)
@@ -131,7 +135,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
if self.ascending { if self.ascending {
let point = lat_lng_to_xyz(&self.point); let point = lat_lng_to_xyz(&self.point);
for point in rtree.nearest_neighbor_iter(&point) { for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) { if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_back(point.data); self.cached_sorted_docids.push_back(point.data);
if self.cached_sorted_docids.len() >= cache_size { if self.cached_sorted_docids.len() >= cache_size {
break; break;
@@ -143,7 +147,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// and we insert the points in reverse order they get reversed when emptying the cache later on // and we insert the points in reverse order they get reversed when emptying the cache later on
let point = lat_lng_to_xyz(&opposite_of(self.point)); let point = lat_lng_to_xyz(&opposite_of(self.point));
for point in rtree.nearest_neighbor_iter(&point) { for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) { if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_front(point.data); self.cached_sorted_docids.push_front(point.data);
if self.cached_sorted_docids.len() >= cache_size { if self.cached_sorted_docids.len() >= cache_size {
break; break;
@@ -155,8 +159,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// the iterative version // the iterative version
let [lat, lng] = self.field_ids.unwrap(); let [lat, lng] = self.field_ids.unwrap();
let mut documents = self let mut documents = geo_candidates
.geo_candidates
.iter() .iter()
.map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) }) .map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) })
.collect::<Result<Vec<(u32, [f64; 2])>>>()?; .collect::<Result<Vec<(u32, [f64; 2])>>>()?;
@@ -216,9 +219,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
assert!(self.query.is_none()); assert!(self.query.is_none());
self.query = Some(query.clone()); self.query = Some(query.clone());
self.geo_candidates &= universe;
if self.geo_candidates.is_empty() { let geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(()); return Ok(());
} }
@@ -226,7 +230,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat"); let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat");
let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng"); let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng");
self.field_ids = Some([lat, lng]); self.field_ids = Some([lat, lng]);
self.fill_buffer(ctx)?; self.fill_buffer(ctx, &geo_candidates)?;
Ok(()) Ok(())
} }
@@ -238,9 +242,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
universe: &RoaringBitmap, universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> { ) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone(); let query = self.query.as_ref().unwrap().clone();
self.geo_candidates &= universe;
if self.geo_candidates.is_empty() { let geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput { return Ok(Some(RankingRuleOutput {
query, query,
candidates: universe.clone(), candidates: universe.clone(),
@@ -261,7 +266,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
} }
}; };
while let Some((id, point)) = next(&mut self.cached_sorted_docids) { while let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if self.geo_candidates.contains(id) { if geo_candidates.contains(id) {
return Ok(Some(RankingRuleOutput { return Ok(Some(RankingRuleOutput {
query, query,
candidates: RoaringBitmap::from_iter([id]), candidates: RoaringBitmap::from_iter([id]),
@@ -276,7 +281,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
// if we got out of this loop it means we've exhausted our cache. // if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again. // we need to refill it and run the function again.
self.fill_buffer(ctx)?; self.fill_buffer(ctx, &geo_candidates)?;
self.next_bucket(ctx, logger, universe) self.next_bucket(ctx, logger, universe)
} }

View File

@@ -498,19 +498,19 @@ mod tests {
use super::*; use super::*;
use crate::index::tests::TempIndex; use crate::index::tests::TempIndex;
use crate::{execute_search, SearchContext}; use crate::{execute_search, filtered_universe, SearchContext};
impl<'a> MatcherBuilder<'a> { impl<'a> MatcherBuilder<'a> {
fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self {
let mut ctx = SearchContext::new(index, rtxn); let mut ctx = SearchContext::new(index, rtxn);
let universe = filtered_universe(&ctx, &None).unwrap();
let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
&mut ctx, &mut ctx,
&Some(query.to_string()), Some(query),
&None,
crate::TermsMatchingStrategy::default(), crate::TermsMatchingStrategy::default(),
crate::score_details::ScoringStrategy::Skip, crate::score_details::ScoringStrategy::Skip,
false, false,
&None, universe,
&None, &None,
crate::search::new::GeoSortStrategy::default(), crate::search::new::GeoSortStrategy::default(),
0, 0,

View File

@@ -16,6 +16,7 @@ mod small_bitmap;
mod exact_attribute; mod exact_attribute;
mod sort; mod sort;
mod vector_sort;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
@@ -28,7 +29,6 @@ use db_cache::DatabaseCache;
use exact_attribute::ExactAttribute; use exact_attribute::ExactAttribute;
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
use heed::RoTxn; use heed::RoTxn;
use instant_distance::Search;
use interner::{DedupInterner, Interner}; use interner::{DedupInterner, Interner};
pub use logger::visual::VisualSearchLogger; pub use logger::visual::VisualSearchLogger;
pub use logger::{DefaultSearchLogger, SearchLogger}; pub use logger::{DefaultSearchLogger, SearchLogger};
@@ -46,10 +46,11 @@ use self::geo_sort::GeoSort;
pub use self::geo_sort::Strategy as GeoSortStrategy; pub use self::geo_sort::Strategy as GeoSortStrategy;
use self::graph_based_ranking_rule::Words; use self::graph_based_ranking_rule::Words;
use self::interner::Interned; use self::interner::Interned;
use crate::distance::NDotProductPoint; use self::vector_sort::VectorSort;
use crate::error::FieldIdMapMissingEntry; use crate::error::FieldIdMapMissingEntry;
use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule; use crate::search::new::distinct::apply_distinct_rule;
use crate::vector::DistributionShift;
use crate::{ use crate::{
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
}; };
@@ -258,6 +259,80 @@ fn get_ranking_rules_for_placeholder_search<'ctx>(
Ok(ranking_rules) Ok(ranking_rules)
} }
fn get_ranking_rules_for_vector<'ctx>(
ctx: &SearchContext<'ctx>,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
limit_plus_offset: usize,
target: &[f32],
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search
let mut sort = false;
let mut sorted_fields = HashSet::new();
let mut geo_sorted = false;
let mut vector = false;
let mut ranking_rules: Vec<BoxRankingRule<PlaceholderQuery>> = vec![];
let settings_ranking_rules = ctx.index.criteria(ctx.txn)?;
for rr in settings_ranking_rules {
match rr {
crate::Criterion::Words
| crate::Criterion::Typo
| crate::Criterion::Proximity
| crate::Criterion::Attribute
| crate::Criterion::Exactness => {
if !vector {
let vector_candidates = ctx.index.documents_ids(ctx.txn)?;
let vector_sort = VectorSort::new(
ctx,
target.to_vec(),
vector_candidates,
limit_plus_offset,
distribution_shift,
embedder_name,
)?;
ranking_rules.push(Box::new(vector_sort));
vector = true;
}
}
crate::Criterion::Sort => {
if sort {
continue;
}
resolve_sort_criteria(
sort_criteria,
ctx,
&mut ranking_rules,
&mut sorted_fields,
&mut geo_sorted,
geo_strategy,
)?;
sort = true;
}
crate::Criterion::Asc(field_name) => {
if sorted_fields.contains(&field_name) {
continue;
}
sorted_fields.insert(field_name.clone());
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?));
}
crate::Criterion::Desc(field_name) => {
if sorted_fields.contains(&field_name) {
continue;
}
sorted_fields.insert(field_name.clone());
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?));
}
}
}
Ok(ranking_rules)
}
/// Return the list of initialised ranking rules to be used for a query graph search. /// Return the list of initialised ranking rules to be used for a query graph search.
fn get_ranking_rules_for_query_graph_search<'ctx>( fn get_ranking_rules_for_query_graph_search<'ctx>(
ctx: &SearchContext<'ctx>, ctx: &SearchContext<'ctx>,
@@ -422,15 +497,72 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
Ok(()) Ok(())
} }
pub fn filtered_universe(ctx: &SearchContext, filters: &Option<Filter>) -> Result<RoaringBitmap> {
Ok(if let Some(filters) = filters {
filters.evaluate(ctx.txn, ctx.index)?
} else {
ctx.index.documents_ids(ctx.txn)?
})
}
#[allow(clippy::too_many_arguments)]
pub fn execute_vector_search(
ctx: &mut SearchContext,
vector: &[f32],
scoring_strategy: ScoringStrategy,
universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
from: usize,
length: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules = get_ranking_rules_for_vector(
ctx,
sort_criteria,
geo_strategy,
from + length,
vector,
distribution_shift,
embedder_name,
)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
&mut placeholder_search_logger;
let BucketSortOutput { docids, scores, all_candidates } = bucket_sort(
ctx,
ranking_rules,
&PlaceholderQuery,
&universe,
from,
length,
scoring_strategy,
placeholder_search_logger,
)?;
Ok(PartialSearchResult {
candidates: all_candidates,
document_scores: scores,
documents_ids: docids,
located_query_terms: None,
})
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn execute_search( pub fn execute_search(
ctx: &mut SearchContext, ctx: &mut SearchContext,
query: &Option<String>, query: Option<&str>,
vector: &Option<Vec<f32>>,
terms_matching_strategy: TermsMatchingStrategy, terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy, scoring_strategy: ScoringStrategy,
exhaustive_number_hits: bool, exhaustive_number_hits: bool,
filters: &Option<Filter>, mut universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>, sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy, geo_strategy: geo_sort::Strategy,
from: usize, from: usize,
@@ -439,60 +571,8 @@ pub fn execute_search(
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>, placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>, query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
) -> Result<PartialSearchResult> { ) -> Result<PartialSearchResult> {
let mut universe = if let Some(filters) = filters {
filters.evaluate(ctx.txn, ctx.index)?
} else {
ctx.index.documents_ids(ctx.txn)?
};
check_sort_criteria(ctx, sort_criteria.as_ref())?; check_sort_criteria(ctx, sort_criteria.as_ref())?;
if let Some(vector) = vector {
let mut search = Search::default();
let docids = match ctx.index.vector_hnsw(ctx.txn)? {
Some(hnsw) => {
if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() {
if vector.len() != expected_size {
return Err(UserError::InvalidVectorDimensions {
expected: expected_size,
found: vector.len(),
}
.into());
}
}
let vector = NDotProductPoint::new(vector.clone());
let neighbors = hnsw.search(&vector, &mut search);
let mut docids = Vec::new();
let mut uniq_docids = RoaringBitmap::new();
for instant_distance::Item { distance: _, pid, point: _ } in neighbors {
let index = pid.into_inner();
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();
if universe.contains(docid) && uniq_docids.insert(docid) {
docids.push(docid);
if docids.len() == (from + length) {
break;
}
}
}
// return the nearest documents that are also part of the candidates
// along with a dummy list of scores that are useless in this context.
docids.into_iter().skip(from).take(length).collect()
}
None => Vec::new(),
};
return Ok(PartialSearchResult {
candidates: universe,
document_scores: vec![Vec::new(); docids.len()],
documents_ids: docids,
located_query_terms: None,
});
}
let mut located_query_terms = None; let mut located_query_terms = None;
let query_terms = if let Some(query) = query { let query_terms = if let Some(query) = query {
// We make sure that the analyzer is aware of the stop words // We make sure that the analyzer is aware of the stop words
@@ -546,7 +626,7 @@ pub fn execute_search(
terms_matching_strategy, terms_matching_strategy,
)?; )?;
universe = universe &=
resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?; resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?;
bucket_sort( bucket_sort(

View File

@@ -0,0 +1,170 @@
use std::iter::FromIterator;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::score_details::{self, ScoreDetails};
use crate::vector::DistributionShift;
use crate::{DocumentId, Result, SearchContext, SearchLogger};
pub struct VectorSort<Q: RankingRuleQueryTrait> {
query: Option<Q>,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_index: u8,
}
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
pub fn new(
ctx: &SearchContext,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Self> {
let embedder_index = ctx
.index
.embedder_category_id
.get(ctx.txn, embedder_name)?
.ok_or_else(|| crate::UserError::InvalidEmbedder(embedder_name.to_owned()))?;
Ok(Self {
query: None,
target,
vector_candidates,
cached_sorted_docids: Default::default(),
limit,
distribution_shift,
embedder_index,
})
}
fn fill_buffer(
&mut self,
ctx: &mut SearchContext<'_>,
vector_candidates: &RoaringBitmap,
) -> Result<()> {
let writer_index = (self.embedder_index as u16) << 8;
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map_while(|k| {
arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy)
.map(Some)
.or_else(|e| match e {
arroy::Error::MissingMetadata => Ok(None),
e => Err(e),
})
.transpose()
})
.collect();
let readers = readers?;
let target = &self.target;
let mut results = Vec::new();
for reader in readers.iter() {
let nns_by_vector =
reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?;
let vectors: std::result::Result<Vec<_>, _> = nns_by_vector
.iter()
.map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap())
.collect();
let vectors = vectors?;
results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z)));
}
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
self.cached_sorted_docids = results.into_iter();
Ok(())
}
}
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
fn id(&self) -> String {
"vector_sort".to_owned()
}
fn start_iteration(
&mut self,
ctx: &mut SearchContext<'ctx>,
_logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
query: &Q,
) -> Result<()> {
assert!(self.query.is_none());
self.query = Some(query.clone());
let vector_candidates = &self.vector_candidates & universe;
self.fill_buffer(ctx, &vector_candidates)?;
Ok(())
}
#[allow(clippy::only_used_in_recursion)]
fn next_bucket(
&mut self,
ctx: &mut SearchContext<'ctx>,
_logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone();
let vector_candidates = &self.vector_candidates & universe;
if vector_candidates.is_empty() {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: None,
}),
}));
}
for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
if vector_candidates.contains(docid) {
let score = 1.0 - distance;
let score = self
.distribution_shift
.map(|distribution| distribution.shift(score))
.unwrap_or(score);
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([docid]),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: Some((vector, score)),
}),
}));
}
}
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx, &vector_candidates)?;
// we tried filling the buffer, but it remained empty 😢
// it means we don't actually have any document remaining in the universe with a vector.
// => exit
if self.cached_sorted_docids.len() == 0 {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: None,
}),
}));
}
self.next_bucket(ctx, _logger, universe)
}
fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) {
self.query = None;
}
}

View File

@@ -42,7 +42,8 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
facet_id_is_empty_docids, facet_id_is_empty_docids,
field_id_docid_facet_f64s, field_id_docid_facet_f64s,
field_id_docid_facet_strings, field_id_docid_facet_strings,
vector_id_docid, vector_arroy,
embedder_category_id: _,
documents, documents,
} = self.index; } = self.index;
@@ -58,7 +59,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
self.index.delete_geo_rtree(self.wtxn)?; self.index.delete_geo_rtree(self.wtxn)?;
self.index.delete_geo_faceted_documents_ids(self.wtxn)?; self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
self.index.delete_vector_hnsw(self.wtxn)?;
// Clear the other databases. // Clear the other databases.
external_documents_ids.clear(self.wtxn)?; external_documents_ids.clear(self.wtxn)?;
@@ -82,7 +82,9 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
facet_id_string_docids.clear(self.wtxn)?; facet_id_string_docids.clear(self.wtxn)?;
field_id_docid_facet_f64s.clear(self.wtxn)?; field_id_docid_facet_f64s.clear(self.wtxn)?;
field_id_docid_facet_strings.clear(self.wtxn)?; field_id_docid_facet_strings.clear(self.wtxn)?;
vector_id_docid.clear(self.wtxn)?; // vector
vector_arroy.clear(self.wtxn)?;
documents.clear(self.wtxn)?; documents.clear(self.wtxn)?;
Ok(number_of_documents) Ok(number_of_documents)

View File

@@ -1,9 +1,10 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use std::convert::TryFrom; use std::convert::{TryFrom, TryInto};
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader, BufWriter}; use std::io::{self, BufReader, BufWriter};
use std::mem::size_of; use std::mem::size_of;
use std::str::from_utf8; use std::str::from_utf8;
use std::sync::Arc;
use bytemuck::cast_slice; use bytemuck::cast_slice;
use grenad::Writer; use grenad::Writer;
@@ -13,13 +14,56 @@ use serde_json::{from_slice, Value};
use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::error::UserError; use crate::error::UserError;
use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::index_documents::helpers::try_split_at; use crate::update::index_documents::helpers::try_split_at;
use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; use crate::vector::Embedder;
use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors};
/// The length of the elements that are always in the buffer when inserting new values. /// The length of the elements that are always in the buffer when inserting new values.
const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
pub struct ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
pub manual_vectors: grenad::Reader<BufReader<File>>,
// docid -> ()
pub remove_vectors: grenad::Reader<BufReader<File>>,
// docid -> prompt
pub prompts: grenad::Reader<BufReader<File>>,
}
enum VectorStateDelta {
NoChange,
// Remove all vectors, generated or manual, from this document
NowRemoved,
// Add the manually specified vectors, passed in the other grenad
// Remove any previously generated vectors
// Note: changing the value of the manually specified vector **should not record** this delta
WasGeneratedNowManual(Vec<Vec<f32>>),
ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>),
// Add the vector computed from the specified prompt
// Remove any previous vector
// Note: changing the value of the prompt **does require** recording this delta
NowGenerated(String),
}
impl VectorStateDelta {
fn into_values(self) -> (bool, String, (Vec<Vec<f32>>, Vec<Vec<f32>>)) {
match self {
VectorStateDelta::NoChange => Default::default(),
VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()),
VectorStateDelta::WasGeneratedNowManual(add) => {
(true, Default::default(), (Default::default(), add))
}
VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)),
VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()),
}
}
}
/// Extracts the embedding vector contained in each document under the `_vectors` field. /// Extracts the embedding vector contained in each document under the `_vectors` field.
/// ///
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32> /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
@@ -27,16 +71,35 @@ const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
pub fn extract_vector_points<R: io::Read + io::Seek>( pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>, obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters, indexer: GrenadParameters,
vectors_fid: FieldId, field_id_map: &FieldsIdsMap,
) -> Result<grenad::Reader<BufReader<File>>> { prompt: &Prompt,
embedder_name: &str,
) -> Result<ExtractedVectorPoints> {
puffin::profile_function!(); puffin::profile_function!();
let mut writer = create_writer( // (docid, _index) -> KvWriterDelAdd -> Vector
let mut manual_vectors_writer = create_writer(
indexer.chunk_compression_type, indexer.chunk_compression_type,
indexer.chunk_compression_level, indexer.chunk_compression_level,
tempfile::tempfile()?, tempfile::tempfile()?,
); );
// (docid) -> (prompt)
let mut prompts_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
// (docid) -> ()
let mut remove_vectors_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let vectors_fid = field_id_map.id("_vectors");
let mut key_buffer = Vec::new(); let mut key_buffer = Vec::new();
let mut cursor = obkv_documents.into_cursor()?; let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? { while let Some((key, value)) = cursor.move_on_next()? {
@@ -53,43 +116,157 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
// lazily get it when needed // lazily get it when needed
let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
// first we retrieve the _vectors field let vectors_field = vectors_fid
if let Some(value) = obkv.get(vectors_fid) { .and_then(|vectors_fid| obkv.get(vectors_fid))
let vectors_obkv = KvReaderDelAdd::new(value); .map(KvReaderDelAdd::new)
.map(|obkv| to_vector_maps(obkv, document_id))
.transpose()?;
// then we extract the values let (del_map, add_map) = vectors_field.unzip();
let del_vectors = vectors_obkv let del_map = del_map.flatten();
.get(DelAdd::Deletion) let add_map = add_map.flatten();
.map(|vectors| extract_vectors(vectors, document_id))
.transpose()? let del_value = del_map.and_then(|mut map| map.remove(embedder_name));
.flatten(); let add_value = add_map.and_then(|mut map| map.remove(embedder_name));
let add_vectors = vectors_obkv
.get(DelAdd::Addition) let delta = match (del_value, add_value) {
.map(|vectors| extract_vectors(vectors, document_id)) (Some(old), Some(new)) => {
.transpose()? // no autogeneration
.flatten(); let del_vectors = extract_vectors(old, document_id, embedder_name)?;
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > u8::MAX.into() {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::ManualDelta(del_vectors, add_vectors)
}
(Some(_old), None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
field_id_map,
)?)
} else {
VectorStateDelta::NowRemoved
}
}
(None, Some(new)) => {
// was possibly autogenerated, remove all vectors for that document
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > u8::MAX.into() {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::WasGeneratedNowManual(add_vectors)
}
(None, None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
}
}
};
// and we finally push the unique vectors into the writer // and we finally push the unique vectors into the writer
push_vectors_diff( push_vectors_diff(
&mut writer, &mut remove_vectors_writer,
&mut prompts_writer,
&mut manual_vectors_writer,
&mut key_buffer, &mut key_buffer,
del_vectors.unwrap_or_default(), delta,
add_vectors.unwrap_or_default(),
)?; )?;
} }
Ok(ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
manual_vectors: writer_into_reader(manual_vectors_writer)?,
// docid -> ()
remove_vectors: writer_into_reader(remove_vectors_writer)?,
// docid -> prompt
prompts: writer_into_reader(prompts_writer)?,
})
} }
writer_into_reader(writer) fn to_vector_maps(
obkv: KvReaderDelAdd,
document_id: impl Fn() -> Value,
) -> Result<(Option<serde_json::Map<String, Value>>, Option<serde_json::Map<String, Value>>)> {
let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?;
let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?;
Ok((del, add))
}
fn to_vector_map(
obkv: KvReaderDelAdd,
side: DelAdd,
document_id: &impl Fn() -> Value,
) -> Result<Option<serde_json::Map<String, Value>>> {
Ok(if let Some(value) = obkv.get(side) {
let Ok(value) = from_slice(value) else {
let value = from_slice(value).map_err(InternalError::SerdeJson)?;
return Err(crate::Error::UserError(UserError::InvalidVectorsMapType {
document_id: document_id(),
value,
}));
};
Some(value)
} else {
None
})
} }
/// Computes the diff between both Del and Add numbers and /// Computes the diff between both Del and Add numbers and
/// only inserts the parts that differ in the sorter. /// only inserts the parts that differ in the sorter.
fn push_vectors_diff( fn push_vectors_diff(
writer: &mut Writer<BufWriter<File>>, remove_vectors_writer: &mut Writer<BufWriter<File>>,
prompts_writer: &mut Writer<BufWriter<File>>,
manual_vectors_writer: &mut Writer<BufWriter<File>>,
key_buffer: &mut Vec<u8>, key_buffer: &mut Vec<u8>,
mut del_vectors: Vec<Vec<f32>>, delta: VectorStateDelta,
mut add_vectors: Vec<Vec<f32>>,
) -> Result<()> { ) -> Result<()> {
let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values();
if must_remove {
key_buffer.truncate(TRUNCATE_SIZE);
remove_vectors_writer.insert(&key_buffer, [])?;
}
if !prompt.is_empty() {
key_buffer.truncate(TRUNCATE_SIZE);
prompts_writer.insert(&key_buffer, prompt.as_bytes())?;
}
// We sort and dedup the vectors // We sort and dedup the vectors
del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
@@ -114,7 +291,7 @@ fn push_vectors_diff(
let mut obkv = KvWriterDelAdd::memory(); let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; obkv.insert(DelAdd::Deletion, cast_slice(&vector))?;
let bytes = obkv.into_inner()?; let bytes = obkv.into_inner()?;
writer.insert(&key_buffer, bytes)?; manual_vectors_writer.insert(&key_buffer, bytes)?;
} }
EitherOrBoth::Right(vector) => { EitherOrBoth::Right(vector) => {
// We insert only the Add part of the Obkv to inform // We insert only the Add part of the Obkv to inform
@@ -122,7 +299,7 @@ fn push_vectors_diff(
let mut obkv = KvWriterDelAdd::memory(); let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Addition, cast_slice(&vector))?; obkv.insert(DelAdd::Addition, cast_slice(&vector))?;
let bytes = obkv.into_inner()?; let bytes = obkv.into_inner()?;
writer.insert(&key_buffer, bytes)?; manual_vectors_writer.insert(&key_buffer, bytes)?;
} }
} }
} }
@@ -136,13 +313,112 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
} }
/// Extracts the vectors from a JSON value. /// Extracts the vectors from a JSON value.
fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> { fn extract_vectors(
match from_slice(value) { value: Value,
Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)), document_id: impl Fn() -> Value,
name: &str,
) -> Result<Vec<Vec<f32>>> {
// FIXME: ugly clone of the vectors here
match serde_json::from_value(value.clone()) {
Ok(vectors) => {
Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default())
}
Err(_) => Err(UserError::InvalidVectorsType { Err(_) => Err(UserError::InvalidVectorsType {
document_id: document_id(), document_id: document_id(),
value: from_slice(value).map_err(InternalError::SerdeJson)?, value,
subfield: name.to_owned(),
} }
.into()), .into()),
} }
} }
#[logging_timer::time]
pub fn extract_embeddings<R: io::Read + io::Seek>(
// docid, prompt
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
) -> Result<grenad::Reader<BufReader<File>>> {
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
// docid, state with embedding
let mut state_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let mut chunks = Vec::with_capacity(n_chunks);
let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk);
let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk);
let mut chunks_ids = Vec::with_capacity(n_chunks);
let mut cursor = prompt_reader.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
// SAFETY: precondition, the grenad value was saved from a string
let prompt = unsafe { std::str::from_utf8_unchecked(value) };
if current_chunk.len() == current_chunk.capacity() {
chunks.push(std::mem::replace(
&mut current_chunk,
Vec::with_capacity(n_vectors_per_chunk),
));
chunks_ids.push(std::mem::replace(
&mut current_chunk_ids,
Vec::with_capacity(n_vectors_per_chunk),
));
};
current_chunk.push(prompt.to_owned());
current_chunk_ids.push(docid);
if chunks.len() == chunks.capacity() {
let chunked_embeds = rt
.block_on(
embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
chunks_ids.clear();
}
}
// send last chunk
if !chunks.is_empty() {
let chunked_embeds = rt
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
if !current_chunk.is_empty() {
let embeds = rt
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
writer_into_reader(state_writer)
}

View File

@@ -23,7 +23,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
use self::extract_fid_word_count_docids::extract_fid_word_count_docids; use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
use self::extract_geo_points::extract_geo_points; use self::extract_geo_points::extract_geo_points;
use self::extract_vector_points::extract_vector_points; use self::extract_vector_points::{
extract_embeddings, extract_vector_points, ExtractedVectorPoints,
};
use self::extract_word_docids::extract_word_docids; use self::extract_word_docids::extract_word_docids;
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
use self::extract_word_position_docids::extract_word_position_docids; use self::extract_word_position_docids::extract_word_position_docids;
@@ -33,7 +35,8 @@ use super::helpers::{
}; };
use super::{helpers, TypedChunk}; use super::{helpers, TypedChunk};
use crate::proximity::ProximityPrecision; use crate::proximity::ProximityPrecision;
use crate::{FieldId, Result}; use crate::vector::EmbeddingConfigs;
use crate::{FieldId, FieldsIdsMap, Result};
/// Extract data for each databases from obkv documents in parallel. /// Extract data for each databases from obkv documents in parallel.
/// Send data in grenad file over provided Sender. /// Send data in grenad file over provided Sender.
@@ -47,13 +50,14 @@ pub(crate) fn data_from_obkv_documents(
faceted_fields: HashSet<FieldId>, faceted_fields: HashSet<FieldId>,
primary_key_id: FieldId, primary_key_id: FieldId,
geo_fields_ids: Option<(FieldId, FieldId)>, geo_fields_ids: Option<(FieldId, FieldId)>,
vectors_field_id: Option<FieldId>, field_id_map: FieldsIdsMap,
stop_words: Option<fst::Set<&[u8]>>, stop_words: Option<fst::Set<&[u8]>>,
allowed_separators: Option<&[&str]>, allowed_separators: Option<&[&str]>,
dictionary: Option<&[&str]>, dictionary: Option<&[&str]>,
max_positions_per_attributes: Option<u32>, max_positions_per_attributes: Option<u32>,
exact_attributes: HashSet<FieldId>, exact_attributes: HashSet<FieldId>,
proximity_precision: ProximityPrecision, proximity_precision: ProximityPrecision,
embedders: EmbeddingConfigs,
) -> Result<()> { ) -> Result<()> {
puffin::profile_function!(); puffin::profile_function!();
@@ -64,7 +68,8 @@ pub(crate) fn data_from_obkv_documents(
original_documents_chunk, original_documents_chunk,
indexer, indexer,
lmdb_writer_sx.clone(), lmdb_writer_sx.clone(),
vectors_field_id, field_id_map.clone(),
embedders.clone(),
) )
}) })
.collect::<Result<()>>()?; .collect::<Result<()>>()?;
@@ -276,24 +281,53 @@ fn send_original_documents_data(
original_documents_chunk: Result<grenad::Reader<BufReader<File>>>, original_documents_chunk: Result<grenad::Reader<BufReader<File>>>,
indexer: GrenadParameters, indexer: GrenadParameters,
lmdb_writer_sx: Sender<Result<TypedChunk>>, lmdb_writer_sx: Sender<Result<TypedChunk>>,
vectors_field_id: Option<FieldId>, field_id_map: FieldsIdsMap,
embedders: EmbeddingConfigs,
) -> Result<()> { ) -> Result<()> {
let original_documents_chunk = let original_documents_chunk =
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
if let Some(vectors_field_id) = vectors_field_id {
let documents_chunk_cloned = original_documents_chunk.clone(); let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || { rayon::spawn(move || {
let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); for (name, (embedder, prompt)) in embedders {
let _ = match result { let result = extract_vector_points(
Ok(vector_points) => { documents_chunk_cloned.clone(),
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) indexer,
&field_id_map,
&prompt,
&name,
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
None
} }
Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
}; };
});
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name: name,
}));
} }
}
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
}
}
});
// TODO: create a custom internal error // TODO: create a custom internal error
lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap(); lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap();

View File

@@ -4,7 +4,7 @@ mod helpers;
mod transform; mod transform;
mod typed_chunk; mod typed_chunk;
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::io::{Cursor, Read, Seek}; use std::io::{Cursor, Read, Seek};
use std::iter::FromIterator; use std::iter::FromIterator;
use std::num::NonZeroU32; use std::num::NonZeroU32;
@@ -14,6 +14,7 @@ use crossbeam_channel::{Receiver, Sender};
use heed::types::Str; use heed::types::Str;
use heed::Database; use heed::Database;
use log::debug; use log::debug;
use rand::SeedableRng;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use slice_group_by::GroupBy; use slice_group_by::GroupBy;
@@ -36,6 +37,7 @@ pub use crate::update::index_documents::helpers::CursorClonableMmap;
use crate::update::{ use crate::update::{
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
}; };
use crate::vector::EmbeddingConfigs;
use crate::{CboRoaringBitmapCodec, Index, Result}; use crate::{CboRoaringBitmapCodec, Index, Result};
static MERGED_DATABASE_COUNT: usize = 7; static MERGED_DATABASE_COUNT: usize = 7;
@@ -78,6 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
should_abort: FA, should_abort: FA,
added_documents: u64, added_documents: u64,
deleted_documents: u64, deleted_documents: u64,
embedders: EmbeddingConfigs,
} }
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone)]
@@ -121,6 +124,7 @@ where
index, index,
added_documents: 0, added_documents: 0,
deleted_documents: 0, deleted_documents: 0,
embedders: Default::default(),
}) })
} }
@@ -167,6 +171,11 @@ where
Ok((self, Ok(indexed_documents))) Ok((self, Ok(indexed_documents)))
} }
pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self {
self.embedders = embedders;
self
}
/// Remove a batch of documents from the current builder. /// Remove a batch of documents from the current builder.
/// ///
/// Returns the number of documents deleted from the builder. /// Returns the number of documents deleted from the builder.
@@ -322,17 +331,18 @@ where
// get filterable fields for facet databases // get filterable fields for facet databases
let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?; let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?;
// get the fid of the `_geo.lat` and `_geo.lng` fields. // get the fid of the `_geo.lat` and `_geo.lng` fields.
let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") { let mut field_id_map = self.index.fields_ids_map(self.wtxn)?;
// self.index.fields_ids_map($a)? ==>> field_id_map
let geo_fields_ids = match field_id_map.id("_geo") {
Some(gfid) => { Some(gfid) => {
let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid); let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid);
let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid); let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid);
// if `_geo` is faceted then we get the `lat` and `lng` // if `_geo` is faceted then we get the `lat` and `lng`
if is_sortable || is_filterable { if is_sortable || is_filterable {
let field_ids = self let field_ids = field_id_map
.index
.fields_ids_map(self.wtxn)?
.insert("_geo.lat") .insert("_geo.lat")
.zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng")) .zip(field_id_map.insert("_geo.lng"))
.ok_or(UserError::AttributeLimitReached)?; .ok_or(UserError::AttributeLimitReached)?;
Some(field_ids) Some(field_ids)
} else { } else {
@@ -341,8 +351,6 @@ where
} }
None => None, None => None,
}; };
// get the fid of the `_vectors` field.
let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors");
let stop_words = self.index.stop_words(self.wtxn)?; let stop_words = self.index.stop_words(self.wtxn)?;
let separators = self.index.allowed_separators(self.wtxn)?; let separators = self.index.allowed_separators(self.wtxn)?;
@@ -364,6 +372,8 @@ where
self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB
let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes; let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes;
let cloned_embedder = self.embedders.clone();
// Run extraction pipeline in parallel. // Run extraction pipeline in parallel.
pool.install(|| { pool.install(|| {
puffin::profile_scope!("extract_and_send_grenad_chunks"); puffin::profile_scope!("extract_and_send_grenad_chunks");
@@ -387,13 +397,14 @@ where
faceted_fields, faceted_fields,
primary_key_id, primary_key_id,
geo_fields_ids, geo_fields_ids,
vectors_field_id, field_id_map,
stop_words, stop_words,
separators.as_deref(), separators.as_deref(),
dictionary.as_deref(), dictionary.as_deref(),
max_positions_per_attributes, max_positions_per_attributes,
exact_attributes, exact_attributes,
proximity_precision, proximity_precision,
cloned_embedder,
) )
}); });
@@ -402,7 +413,7 @@ where
} }
// needs to be dropped to avoid channel waiting lock. // needs to be dropped to avoid channel waiting lock.
drop(lmdb_writer_sx) drop(lmdb_writer_sx);
}); });
let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0; let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0;
@@ -419,6 +430,8 @@ where
let mut word_docids = None; let mut word_docids = None;
let mut exact_word_docids = None; let mut exact_word_docids = None;
let mut dimension = HashMap::new();
for result in lmdb_writer_rx { for result in lmdb_writer_rx {
if (self.should_abort)() { if (self.should_abort)() {
return Err(Error::InternalError(InternalError::AbortedIndexation)); return Err(Error::InternalError(InternalError::AbortedIndexation));
@@ -448,6 +461,22 @@ where
word_position_docids = Some(cloneable_chunk); word_position_docids = Some(cloneable_chunk);
TypedChunk::WordPositionDocids(chunk) TypedChunk::WordPositionDocids(chunk)
} }
TypedChunk::VectorPoints {
expected_dimension,
remove_vectors,
embeddings,
manual_vectors,
embedder_name,
} => {
dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension,
manual_vectors,
embedder_name,
}
}
otherwise => otherwise, otherwise => otherwise,
}; };
@@ -480,6 +509,33 @@ where
// We write the primary key field id into the main database // We write the primary key field id into the main database
self.index.put_primary_key(self.wtxn, &primary_key)?; self.index.put_primary_key(self.wtxn, &primary_key)?;
let number_of_documents = self.index.number_of_documents(self.wtxn)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
for (embedder_name, dimension) in dimension {
let wtxn = &mut *self.wtxn;
let vector_arroy = self.index.vector_arroy;
let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
pool.install(|| {
let writer_index = (embedder_index as u16) << 8;
for k in 0..=u8::MAX {
let writer = arroy::Writer::prepare(
wtxn,
vector_arroy,
writer_index | (k as u16),
dimension,
)?;
if writer.is_empty(wtxn)? {
break;
}
writer.build(wtxn, &mut rng, None)?;
}
Result::Ok(())
})?;
}
self.execute_prefix_databases( self.execute_prefix_databases(
word_docids, word_docids,
@@ -694,6 +750,8 @@ fn execute_word_prefix_docids(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::BTreeMap;
use big_s::S; use big_s::S;
use fst::IntoStreamer; use fst::IntoStreamer;
use heed::RwTxn; use heed::RwTxn;
@@ -703,6 +761,7 @@ mod tests {
use crate::documents::documents_batch_reader_from_objects; use crate::documents::documents_batch_reader_from_objects;
use crate::index::tests::TempIndex; use crate::index::tests::TempIndex;
use crate::search::TermsMatchingStrategy; use crate::search::TermsMatchingStrategy;
use crate::update::Setting;
use crate::{db_snap, Filter, Search}; use crate::{db_snap, Filter, Search};
#[test] #[test]
@@ -2494,18 +2553,39 @@ mod tests {
/// Vectors must be of the same length. /// Vectors must be of the same length.
#[test] #[test]
fn test_multiple_vectors() { fn test_multiple_vectors() {
use crate::vector::settings::{EmbedderSettings, EmbeddingSettings};
let index = TempIndex::new(); let index = TempIndex::new();
index.add_documents(documents!([{"id": 0, "_vectors": [[0, 1, 2], [3, 4, 5]] }])).unwrap(); index
index.add_documents(documents!([{"id": 1, "_vectors": [6, 7, 8] }])).unwrap(); .update_settings(|settings| {
let mut embedders = BTreeMap::default();
embedders.insert(
"manual".to_string(),
Setting::Set(EmbeddingSettings {
embedder_options: Setting::Set(EmbedderSettings::UserProvided(
crate::vector::settings::UserProvidedSettings { dimensions: 3 },
)),
document_template: Setting::NotSet,
}),
);
settings.set_embedder_settings(embedders);
})
.unwrap();
index index
.add_documents( .add_documents(
documents!([{"id": 2, "_vectors": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }]), documents!([{"id": 0, "_vectors": { "manual": [[0, 1, 2], [3, 4, 5]] } }]),
)
.unwrap();
index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap();
index
.add_documents(
documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]),
) )
.unwrap(); .unwrap();
let rtxn = index.read_txn().unwrap(); let rtxn = index.read_txn().unwrap();
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap(); let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap();
assert_eq!(res.documents_ids.len(), 3); assert_eq!(res.documents_ids.len(), 3);
} }

View File

@@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::convert::TryInto; use std::convert::TryInto;
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader}; use std::io::{self, BufReader};
@@ -8,9 +8,7 @@ use charabia::{Language, Script};
use grenad::MergerBuilder; use grenad::MergerBuilder;
use heed::types::Bytes; use heed::types::Bytes;
use heed::{PutFlags, RwTxn}; use heed::{PutFlags, RwTxn};
use log::error;
use obkv::{KvReader, KvWriter}; use obkv::{KvReader, KvWriter};
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use super::helpers::{ use super::helpers::{
@@ -18,16 +16,15 @@ use super::helpers::{
valid_lmdb_key, CursorClonableMmap, valid_lmdb_key, CursorClonableMmap,
}; };
use super::{ClonableMmap, MergeFn}; use super::{ClonableMmap, MergeFn};
use crate::distance::NDotProductPoint;
use crate::error::UserError;
use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind}; use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind};
use crate::facet::FacetType; use crate::facet::FacetType;
use crate::index::db_name::DOCUMENTS; use crate::index::db_name::DOCUMENTS;
use crate::index::Hnsw;
use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd};
use crate::update::facet::FacetsUpdate; use crate::update::facet::FacetsUpdate;
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError}; use crate::{
lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError,
};
pub(crate) enum TypedChunk { pub(crate) enum TypedChunk {
FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>), FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@@ -47,7 +44,13 @@ pub(crate) enum TypedChunk {
FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>), FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>),
FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>), FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>),
GeoPoints(grenad::Reader<BufReader<File>>), GeoPoints(grenad::Reader<BufReader<File>>),
VectorPoints(grenad::Reader<BufReader<File>>), VectorPoints {
remove_vectors: grenad::Reader<BufReader<File>>,
embeddings: Option<grenad::Reader<BufReader<File>>>,
expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
},
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
} }
@@ -100,8 +103,8 @@ impl TypedChunk {
TypedChunk::GeoPoints(grenad) => { TypedChunk::GeoPoints(grenad) => {
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
} }
TypedChunk::VectorPoints(grenad) => { TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => {
format!("VectorPoints {{ number_of_entries: {} }}", grenad.len()) format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name)
} }
TypedChunk::ScriptLanguageDocids(sl_map) => { TypedChunk::ScriptLanguageDocids(sl_map) => {
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
@@ -355,19 +358,77 @@ pub(crate) fn write_typed_chunk_into_index(
index.put_geo_rtree(wtxn, &rtree)?; index.put_geo_rtree(wtxn, &rtree)?;
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
} }
TypedChunk::VectorPoints(vector_points) => { TypedChunk::VectorPoints {
let mut vectors_set = HashSet::new(); remove_vectors,
// We extract and store the previous vectors manual_vectors,
if let Some(hnsw) = index.vector_hnsw(wtxn)? { embeddings,
for (pid, point) in hnsw.iter() { expected_dimension,
let pid_key = pid.into_inner(); embedder_name,
let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap(); } => {
let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
vectors_set.insert((docid, vector)); InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
let writer_index = (embedder_index as u16) << 8;
// FIXME: allow customizing distance
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map(|k| {
arroy::Writer::prepare(
wtxn,
index.vector_arroy,
writer_index | (k as u16),
expected_dimension,
)
})
.collect();
let writers = writers?;
// remove vectors for docids we want them removed
let mut cursor = remove_vectors.into_cursor()?;
while let Some((key, _)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
for writer in &writers {
// Uses invariant: vectors are packed in the first writers.
if !writer.del_item(wtxn, docid)? {
break;
}
} }
} }
let mut cursor = vector_points.into_cursor()?; // add generated embeddings
if let Some(embeddings) = embeddings {
let mut cursor = embeddings.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
let data = pod_collect_to_vec(value);
// it is a code error to have embeddings and not expected_dimension
let embeddings =
crate::vector::Embeddings::from_inner(data, expected_dimension)
// code error if we somehow got the wrong dimension
.unwrap();
if embeddings.embedding_count() > u8::MAX.into() {
let external_docid = if let Ok(Some(Ok(index))) = index
.external_id_of(wtxn, std::iter::once(docid))
.map(|it| it.into_iter().next())
{
index
} else {
format!("internal docid={docid}")
};
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
external_docid,
embeddings.embedding_count(),
)));
}
for (embedding, writer) in embeddings.iter().zip(&writers) {
writer.add_item(wtxn, docid, embedding)?;
}
}
}
// perform the manual diff
let mut cursor = manual_vectors.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? { while let Some((key, value)) = cursor.move_on_next()? {
// convert the key back to a u32 (4 bytes) // convert the key back to a u32 (4 bytes)
let (left, _index) = try_split_array_at(key).unwrap(); let (left, _index) = try_split_array_at(key).unwrap();
@@ -375,58 +436,52 @@ pub(crate) fn write_typed_chunk_into_index(
let vector_deladd_obkv = KvReaderDelAdd::new(value); let vector_deladd_obkv = KvReaderDelAdd::new(value);
if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) {
// convert the vector back to a Vec<f32> let vector: Vec<f32> = pod_collect_to_vec(value);
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
let key = (docid, vector);
if !vectors_set.remove(&key) {
error!("Unable to delete the vector: {:?}", key.1);
}
}
if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
// convert the vector back to a Vec<f32>
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
vectors_set.insert((docid, vector));
}
}
// Extract the most common vector dimension let mut deleted_index = None;
let expected_dimension_size = { for (index, writer) in writers.iter().enumerate() {
let mut dims = HashMap::new(); let Some(candidate) = writer.item_vector(wtxn, docid)? else {
vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1); // uses invariant: vectors are packed in the first writers.
dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) break;
}; };
if candidate == vector {
// Ensure that the vector lengths are correct and writer.del_item(wtxn, docid)?;
// prepare the vectors before inserting them in the HNSW. deleted_index = Some(index);
let mut points = Vec::new();
let mut docids = Vec::new();
for (docid, vector) in vectors_set {
if expected_dimension_size.map_or(false, |expected| expected != vector.len()) {
return Err(UserError::InvalidVectorDimensions {
expected: expected_dimension_size.unwrap_or(vector.len()),
found: vector.len(),
}
.into());
} else {
let vector = vector.into_iter().map(OrderedFloat::into_inner).collect();
points.push(NDotProductPoint::new(vector));
docids.push(docid);
} }
} }
let hnsw_length = points.len(); // 🥲 enforce invariant: vectors are packed in the first writers.
let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points); if let Some(deleted_index) = deleted_index {
let mut last_index_with_a_vector = None;
assert_eq!(docids.len(), pids.len()); for (index, writer) in writers.iter().enumerate().skip(deleted_index) {
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
// Store the vectors in the point-docid relation database break;
index.vector_id_docid.clear(wtxn)?; };
for (docid, pid) in docids.into_iter().zip(pids) { last_index_with_a_vector = Some((index, candidate));
index.vector_id_docid.put(wtxn, &pid.into_inner(), &docid)?; }
if let Some((last_index, vector)) = last_index_with_a_vector {
// unwrap: computed the index from the list of writers
let writer = writers.get(last_index).unwrap();
writer.del_item(wtxn, docid)?;
writers.get(deleted_index).unwrap().add_item(wtxn, docid, &vector)?;
}
}
} }
log::debug!("There are {} entries in the HNSW so far", hnsw_length); if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
index.put_vector_hnsw(wtxn, &new_hnsw)?; let vector = pod_collect_to_vec(value);
// overflow was detected during vector extraction.
for writer in &writers {
if !writer.contains_item(wtxn, docid)? {
writer.add_item(wtxn, docid, &vector)?;
break;
}
}
}
}
log::debug!("Finished vector chunk for {}", embedder_name);
} }
TypedChunk::ScriptLanguageDocids(sl_map) => { TypedChunk::ScriptLanguageDocids(sl_map) => {
for (key, (deletion, addition)) in sl_map { for (key, (deletion, addition)) in sl_map {

View File

@@ -1,9 +1,11 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::convert::TryInto;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::sync::Arc;
use charabia::{Normalize, Tokenizer, TokenizerBuilder}; use charabia::{Normalize, Tokenizer, TokenizerBuilder};
use deserr::{DeserializeError, Deserr}; use deserr::{DeserializeError, Deserr};
use itertools::Itertools; use itertools::{EitherOrBoth, Itertools};
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime; use time::OffsetDateTime;
@@ -15,6 +17,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS
use crate::proximity::ProximityPrecision; use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod; use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep}; use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::settings::{EmbeddingSettings, PromptSettings};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{FieldsIdsMap, Index, OrderBy, Result}; use crate::{FieldsIdsMap, Index, OrderBy, Result};
#[derive(Debug, Clone, PartialEq, Eq, Copy)] #[derive(Debug, Clone, PartialEq, Eq, Copy)]
@@ -73,6 +77,13 @@ impl<T> Setting<T> {
otherwise => otherwise, otherwise => otherwise,
} }
} }
pub fn apply(&mut self, new: Self) {
if let Setting::NotSet = new {
return;
}
*self = new;
}
} }
impl<T: Serialize> Serialize for Setting<T> { impl<T: Serialize> Serialize for Setting<T> {
@@ -129,6 +140,7 @@ pub struct Settings<'a, 't, 'i> {
sort_facet_values_by: Setting<HashMap<String, OrderBy>>, sort_facet_values_by: Setting<HashMap<String, OrderBy>>,
pagination_max_total_hits: Setting<usize>, pagination_max_total_hits: Setting<usize>,
proximity_precision: Setting<ProximityPrecision>, proximity_precision: Setting<ProximityPrecision>,
embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
} }
impl<'a, 't, 'i> Settings<'a, 't, 'i> { impl<'a, 't, 'i> Settings<'a, 't, 'i> {
@@ -161,6 +173,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
sort_facet_values_by: Setting::NotSet, sort_facet_values_by: Setting::NotSet,
pagination_max_total_hits: Setting::NotSet, pagination_max_total_hits: Setting::NotSet,
proximity_precision: Setting::NotSet, proximity_precision: Setting::NotSet,
embedder_settings: Setting::NotSet,
indexer_config, indexer_config,
} }
} }
@@ -343,6 +356,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
self.proximity_precision = Setting::Reset; self.proximity_precision = Setting::Reset;
} }
pub fn set_embedder_settings(&mut self, value: BTreeMap<String, Setting<EmbeddingSettings>>) {
self.embedder_settings = Setting::Set(value);
}
pub fn reset_embedder_settings(&mut self) {
self.embedder_settings = Setting::Reset;
}
fn reindex<FP, FA>( fn reindex<FP, FA>(
&mut self, &mut self,
progress_callback: &FP, progress_callback: &FP,
@@ -377,6 +398,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
fields_ids_map, fields_ids_map,
)?; )?;
let embedder_configs = self.index.embedding_configs(self.wtxn)?;
let embedders = self.embedders(embedder_configs)?;
// We index the generated `TransformOutput` which must contain // We index the generated `TransformOutput` which must contain
// all the documents with fields in the newly defined searchable order. // all the documents with fields in the newly defined searchable order.
let indexing_builder = IndexDocuments::new( let indexing_builder = IndexDocuments::new(
@@ -387,11 +411,33 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
&progress_callback, &progress_callback,
&should_abort, &should_abort,
)?; )?;
let indexing_builder = indexing_builder.with_embedders(embedders);
indexing_builder.execute_raw(output)?; indexing_builder.execute_raw(output)?;
Ok(()) Ok(())
} }
fn embedders(
&self,
embedding_configs: Vec<(String, EmbeddingConfig)>,
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?);
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?,
);
Ok((name, (embedder, prompt)))
})
.collect();
res.map(EmbeddingConfigs::new)
}
fn update_displayed(&mut self) -> Result<bool> { fn update_displayed(&mut self) -> Result<bool> {
match self.displayed_fields { match self.displayed_fields {
Setting::Set(ref fields) => { Setting::Set(ref fields) => {
@@ -890,6 +936,73 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Ok(changed) Ok(changed)
} }
fn update_embedding_configs(&mut self) -> Result<bool> {
let update = match std::mem::take(&mut self.embedder_settings) {
Setting::Set(configs) => {
let mut changed = false;
let old_configs = self.index.embedding_configs(self.wtxn)?;
let old_configs: BTreeMap<String, Setting<EmbeddingSettings>> =
old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect();
let mut new_configs = BTreeMap::new();
for joined in old_configs
.into_iter()
.merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right))
{
match joined {
EitherOrBoth::Both((name, mut old), (_, new)) => {
old.apply(new);
let new = validate_prompt(&name, old)?;
changed = true;
new_configs.insert(name, new);
}
EitherOrBoth::Left((name, setting)) => {
new_configs.insert(name, setting);
}
EitherOrBoth::Right((name, setting)) => {
let setting = validate_prompt(&name, setting)?;
changed = true;
new_configs.insert(name, setting);
}
}
}
let new_configs: Vec<(String, EmbeddingConfig)> = new_configs
.into_iter()
.filter_map(|(name, setting)| match setting {
Setting::Set(value) => Some((name, value.into())),
Setting::Reset => None,
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
})
.collect();
self.index.embedder_category_id.clear(self.wtxn)?;
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
self.index.embedder_category_id.put_with_flags(
self.wtxn,
heed::PutFlags::APPEND,
embedder_name,
&index
.try_into()
.map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?,
)?;
}
if new_configs.is_empty() {
self.index.delete_embedding_configs(self.wtxn)?;
} else {
self.index.put_embedding_configs(self.wtxn, new_configs)?;
}
changed
}
Setting::Reset => {
self.index.delete_embedding_configs(self.wtxn)?;
true
}
Setting::NotSet => false,
};
Ok(update)
}
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()> pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
where where
FP: Fn(UpdateIndexingStep) + Sync, FP: Fn(UpdateIndexingStep) + Sync,
@@ -927,6 +1040,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
let searchable_updated = self.update_searchable()?; let searchable_updated = self.update_searchable()?;
let exact_attributes_updated = self.update_exact_attributes()?; let exact_attributes_updated = self.update_exact_attributes()?;
let proximity_precision = self.update_proximity_precision()?; let proximity_precision = self.update_proximity_precision()?;
// TODO: very rough approximation of the needs for reindexing where any change will result in
// a full reindexing.
// What can be done instead:
// 1. Only change the distance on a distance change
// 2. Only change the name -> embedder mapping on a name change
// 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage
let embedding_configs_updated = self.update_embedding_configs()?;
if stop_words_updated if stop_words_updated
|| non_separator_tokens_updated || non_separator_tokens_updated
@@ -937,6 +1057,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|| searchable_updated || searchable_updated
|| exact_attributes_updated || exact_attributes_updated
|| proximity_precision || proximity_precision
|| embedding_configs_updated
{ {
self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?; self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?;
} }
@@ -945,6 +1066,31 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
} }
} }
fn validate_prompt(
name: &str,
new: Setting<EmbeddingSettings>,
) -> Result<Setting<EmbeddingSettings>> {
match new {
Setting::Set(EmbeddingSettings {
embedder_options,
document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }),
}) => {
// validate
let template = crate::prompt::Prompt::new(template)
.map(|prompt| crate::prompt::PromptData::from(prompt).template)
.map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?;
Ok(Setting::Set(EmbeddingSettings {
embedder_options,
document_template: Setting::Set(PromptSettings {
template: Setting::Set(template),
}),
}))
}
new => Ok(new),
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use big_s::S; use big_s::S;
@@ -1763,6 +1909,7 @@ mod tests {
sort_facet_values_by, sort_facet_values_by,
pagination_max_total_hits, pagination_max_total_hits,
proximity_precision, proximity_precision,
embedder_settings,
} = settings; } = settings;
assert!(matches!(searchable_fields, Setting::NotSet)); assert!(matches!(searchable_fields, Setting::NotSet));
assert!(matches!(displayed_fields, Setting::NotSet)); assert!(matches!(displayed_fields, Setting::NotSet));
@@ -1785,6 +1932,7 @@ mod tests {
assert!(matches!(sort_facet_values_by, Setting::NotSet)); assert!(matches!(sort_facet_values_by, Setting::NotSet));
assert!(matches!(pagination_max_total_hits, Setting::NotSet)); assert!(matches!(pagination_max_total_hits, Setting::NotSet));
assert!(matches!(proximity_precision, Setting::NotSet)); assert!(matches!(proximity_precision, Setting::NotSet));
assert!(matches!(embedder_settings, Setting::NotSet));
}) })
.unwrap(); .unwrap();
} }

244
milli/src/vector/error.rs Normal file
View File

@@ -0,0 +1,244 @@
use std::path::PathBuf;
use hf_hub::api::sync::ApiError;
use crate::error::FaultSource;
use crate::vector::openai::OpenAiError;
#[derive(Debug, thiserror::Error)]
#[error("Error while generating embeddings: {inner}")]
pub struct Error {
pub inner: Box<ErrorKind>,
}
impl<I: Into<ErrorKind>> From<I> for Error {
fn from(value: I) -> Self {
Self { inner: Box::new(value.into()) }
}
}
impl Error {
pub fn fault(&self) -> FaultSource {
match &*self.inner {
ErrorKind::NewEmbedderError(inner) => inner.fault,
ErrorKind::EmbedError(inner) => inner.fault,
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ErrorKind {
#[error(transparent)]
NewEmbedderError(#[from] NewEmbedderError),
#[error(transparent)]
EmbedError(#[from] EmbedError),
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct EmbedError {
pub kind: EmbedErrorKind,
pub fault: FaultSource,
}
#[derive(Debug, thiserror::Error)]
pub enum EmbedErrorKind {
#[error("could not tokenize: {0}")]
Tokenize(Box<dyn std::error::Error + Send + Sync>),
#[error("unexpected tensor shape: {0}")]
TensorShape(candle_core::Error),
#[error("unexpected tensor value: {0}")]
TensorValue(candle_core::Error),
#[error("could not run model: {0}")]
ModelForward(candle_core::Error),
#[error("could not reach OpenAI: {0}")]
OpenAiNetwork(reqwest::Error),
#[error("unexpected response from OpenAI: {0}")]
OpenAiUnexpected(reqwest::Error),
#[error("could not authenticate against OpenAI: {0}")]
OpenAiAuth(OpenAiError),
#[error("sent too many requests to OpenAI: {0}")]
OpenAiTooManyRequests(OpenAiError),
#[error("received internal error from OpenAI: {0}")]
OpenAiInternalServerError(OpenAiError),
#[error("sent too many tokens in a request to OpenAI: {0}")]
OpenAiTooManyTokens(OpenAiError),
#[error("received unhandled HTTP status code {0} from OpenAI")]
OpenAiUnhandledStatusCode(u16),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String),
}
impl EmbedError {
pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
}
pub fn tensor_shape(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
}
pub fn tensor_value(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
}
pub fn model_forward(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
}
pub fn openai_network(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
}
pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
}
pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
}
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct NewEmbedderError {
pub kind: NewEmbedderErrorKind,
pub fault: FaultSource,
}
impl NewEmbedderError {
pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
let open_config = OpenConfig { filename: config_filename, inner };
Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
}
pub fn deserialize_config(
config: String,
config_filename: PathBuf,
inner: serde_json::Error,
) -> NewEmbedderError {
let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
Self {
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
fault: FaultSource::Runtime,
}
}
pub fn open_tokenizer(
tokenizer_filename: PathBuf,
inner: Box<dyn std::error::Error + Send + Sync>,
) -> NewEmbedderError {
let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
Self {
kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
fault: FaultSource::Runtime,
}
}
pub fn new_api_fail(inner: ApiError) -> Self {
Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
}
pub fn api_get(inner: ApiError) -> Self {
Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
}
pub fn pytorch_weight(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
}
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
}
pub fn load_model(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
}
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::Runtime,
}
}
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
#[error("could not open config at {filename:?}: {inner}")]
pub struct OpenConfig {
pub filename: PathBuf,
pub inner: std::io::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
pub struct DeserializeConfig {
pub config: String,
pub filename: PathBuf,
pub inner: serde_json::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("could not open tokenizer at {filename}: {inner}")]
pub struct OpenTokenizer {
pub filename: PathBuf,
#[source]
pub inner: Box<dyn std::error::Error + Send + Sync>,
}
#[derive(Debug, thiserror::Error)]
pub enum NewEmbedderErrorKind {
// hf
#[error(transparent)]
OpenConfig(OpenConfig),
#[error(transparent)]
DeserializeConfig(DeserializeConfig),
#[error(transparent)]
OpenTokenizer(OpenTokenizer),
#[error("could not build weights from Pytorch weights: {0}")]
PytorchWeight(candle_core::Error),
#[error("could not build weights from Safetensor weights: {0}")]
SafetensorWeight(candle_core::Error),
#[error("could not spawn HG_HUB API client: {0}")]
NewApiFail(ApiError),
#[error("fetching file from HG_HUB failed: {0}")]
ApiGet(ApiError),
#[error("could not determine model dimensions: test embedding failed with {0}")]
CouldNotDetermineDimension(EmbedError),
#[error("loading model failed: {0}")]
LoadModel(candle_core::Error),
// openai
#[error("initializing web client for sending embedding requests failed: {0}")]
InitWebClient(reqwest::Error),
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
}

195
milli/src/vector/hf.rs Normal file
View File

@@ -0,0 +1,195 @@
use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
pub use super::error::{EmbedError, Error, NewEmbedderError};
use super::{DistributionShift, Embedding, Embeddings};
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Deserialize,
serde::Serialize,
deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
enum WeightSource {
#[default]
Safetensors,
Pytorch,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub model: String,
pub revision: Option<String>,
}
impl EmbedderOptions {
pub fn new() -> Self {
Self {
model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
}
}
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::new()
}
}
/// Perform embedding of documents and queries
pub struct Embedder {
model: BertModel,
tokenizer: Tokenizer,
options: EmbedderOptions,
dimensions: usize,
}
impl std::fmt::Debug for Embedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Embedder")
.field("model", &self.options.model)
.field("tokenizer", &self.tokenizer)
.field("options", &self.options)
.finish()
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
let device = candle_core::Device::Cpu;
let repo = match options.revision.clone() {
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
None => Repo::model(options.model.clone()),
};
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
let api = api.repo(repo);
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
let (weights, source) = {
api.get("pytorch_model.bin")
.map(|filename| (filename, WeightSource::Pytorch))
.or_else(|_| {
api.get("model.safetensors")
.map(|filename| (filename, WeightSource::Safetensors))
})
.map_err(NewEmbedderError::api_get)?
};
(config, tokenizer, weights, source)
};
let config = std::fs::read_to_string(&config_filename)
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
let config: Config = serde_json::from_str(&config).map_err(|inner| {
NewEmbedderError::deserialize_config(config, config_filename, inner)
})?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
let vb = match weight_source {
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
.map_err(NewEmbedderError::pytorch_weight)?,
WeightSource::Safetensors => unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
.map_err(NewEmbedderError::safetensor_weight)?
},
};
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let mut this = Self { model, tokenizer, options, dimensions: 0 };
let embeddings = this
.embed(vec!["test".into()])
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension();
Ok(this)
}
pub fn embed(
&self,
mut texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
let tokens = match texts.len() {
1 => vec![self
.tokenizer
.encode(texts.pop().unwrap(), true)
.map_err(EmbedError::tokenize)?],
_ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
};
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
})
.collect::<Result<Vec<_>, EmbedError>>()?;
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings =
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)?;
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
pub fn chunk_count_hint(&self) -> usize {
1
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn distribution(&self) -> Option<DistributionShift> {
if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 })
} else {
None
}
}
}

View File

@@ -0,0 +1,34 @@
use super::error::EmbedError;
use super::Embeddings;
#[derive(Debug, Clone, Copy)]
pub struct Embedder {
dimensions: usize,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub dimensions: usize,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Self {
Self { dimensions: options.dimensions }
}
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let Some(text) = texts.pop() else { return Ok(Default::default()) };
Err(EmbedError::embed_on_manual_embedder(text))
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
}

257
milli/src/vector/mod.rs Normal file
View File

@@ -0,0 +1,257 @@
use std::collections::HashMap;
use std::sync::Arc;
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData};
pub mod error;
pub mod hf;
pub mod manual;
pub mod openai;
pub mod settings;
pub use self::error::Error;
pub type Embedding = Vec<f32>;
pub struct Embeddings<F> {
data: Vec<F>,
dimension: usize,
}
impl<F> Embeddings<F> {
pub fn new(dimension: usize) -> Self {
Self { data: Default::default(), dimension }
}
pub fn from_single_embedding(embedding: Vec<F>) -> Self {
Self { dimension: embedding.len(), data: embedding }
}
pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
let mut this = Self::new(dimension);
this.append(data)?;
Ok(this)
}
pub fn embedding_count(&self) -> usize {
self.data.len() / self.dimension
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn into_inner(self) -> Vec<F> {
self.data
}
pub fn as_inner(&self) -> &[F] {
&self.data
}
pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
self.data.as_slice().chunks_exact(self.dimension)
}
pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
if embedding.len() != self.dimension {
return Err(embedding);
}
self.data.append(&mut embedding);
Ok(())
}
pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
if embeddings.len() % self.dimension != 0 {
return Err(embeddings);
}
self.data.append(&mut embeddings);
Ok(())
}
}
#[derive(Debug)]
pub enum Embedder {
HuggingFace(hf::Embedder),
OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
}
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig {
pub embedder_options: EmbedderOptions,
pub prompt: PromptData,
// TODO: add metrics and anything needed
}
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
impl EmbeddingConfigs {
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
Self(data)
}
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.0.get(name).cloned()
}
pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.get_default_embedder_name().and_then(|default| self.get(&default))
}
pub fn get_default_embedder_name(&self) -> Option<String> {
let mut it = self.0.keys();
let first_name = it.next();
let second_name = it.next();
match (first_name, second_name) {
(None, _) => None,
(Some(first), None) => Some(first.to_owned()),
(Some(_), Some(_)) => Some("default".to_owned()),
}
}
}
impl IntoIterator for EmbeddingConfigs {
type Item = (String, (Arc<Embedder>, Arc<Prompt>));
type IntoIter = std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>)>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::HuggingFace(Default::default())
}
}
impl EmbedderOptions {
pub fn huggingface() -> Self {
Self::HuggingFace(hf::EmbedderOptions::new())
}
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
})
}
pub async fn embed(
&self,
texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
Embedder::UserProvided(embedder) => embedder.embed(texts),
}
}
pub async fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
}
}
pub fn chunk_count_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
}
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
}
}
pub fn distribution(&self) -> Option<DistributionShift> {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct DistributionShift {
pub current_mean: f32,
pub current_sigma: f32,
}
impl DistributionShift {
/// `None` if sigma <= 0.
pub fn new(mean: f32, sigma: f32) -> Option<Self> {
if sigma <= 0.0 {
None
} else {
Some(Self { current_mean: mean, current_sigma: sigma })
}
}
pub fn shift(&self, score: f32) -> f32 {
// <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a gaussian.
// The parameters we're given is the mean and sigma of the native result distribution.
// We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
let target_mean = 0.5;
let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
let factor = target_sigma / self.current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1
let offset = target_mean - (factor * self.current_mean);
let mut score = factor * score + offset;
// clamp the final score in the ]0, 1] interval.
if score <= 0.0 {
score = f32::EPSILON;
}
if score > 1.0 {
score = 1.0;
}
score
}
}

445
milli/src/vector/openai.rs Normal file
View File

@@ -0,0 +1,445 @@
use std::fmt::Display;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use super::error::{EmbedError, NewEmbedderError};
use super::{DistributionShift, Embedding, Embeddings};
#[derive(Debug)]
pub struct Embedder {
client: reqwest::Client,
tokenizer: tiktoken_rs::CoreBPE,
options: EmbedderOptions,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub api_key: Option<String>,
pub embedding_model: EmbeddingModel,
}
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Serialize,
serde::Deserialize,
deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
#[default]
#[serde(rename = "text-embedding-ada-002")]
#[deserr(rename = "text-embedding-ada-002")]
TextEmbeddingAda002,
}
impl EmbeddingModel {
pub fn max_token(&self) -> usize {
match self {
EmbeddingModel::TextEmbeddingAda002 => 8191,
}
}
pub fn dimensions(&self) -> usize {
match self {
EmbeddingModel::TextEmbeddingAda002 => 1536,
}
}
pub fn name(&self) -> &'static str {
match self {
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
}
}
pub fn from_name(name: &'static str) -> Option<Self> {
match name {
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
_ => None,
}
}
fn distribution(&self) -> Option<DistributionShift> {
match self {
EmbeddingModel::TextEmbeddingAda002 => {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
}
}
}
}
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default() }
}
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model }
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key();
&inferred_api_key
});
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
);
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_static("application/json"),
);
let client = reqwest::ClientBuilder::new()
.default_headers(headers)
.build()
.map_err(NewEmbedderError::openai_initialize_web_client)?;
// looking at the code it is very unclear that this can actually fail.
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
Ok(Self { options, client, tokenizer })
}
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let mut tokenized = false;
for attempt in 0..7 {
let result = if tokenized {
self.try_embed_tokenized(&texts).await
} else {
self.try_embed(&texts).await
};
let retry_duration = match result {
Ok(embeddings) => return Ok(embeddings),
Err(retry) => {
log::warn!("Failed: {}", retry.error);
tokenized |= retry.must_tokenize();
retry.into_duration(attempt)
}
}?;
log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
tokio::time::sleep(retry_duration).await;
}
let result = if tokenized {
self.try_embed_tokenized(&texts).await
} else {
self.try_embed(&texts).await
};
result.map_err(Retry::into_error)
}
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
if !response.status().is_success() {
match response.status() {
StatusCode::UNAUTHORIZED => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::give_up(EmbedError::openai_auth_error(
error_response.error,
)));
}
StatusCode::TOO_MANY_REQUESTS => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
error_response.error,
)));
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
error_response.error,
)));
}
StatusCode::SERVICE_UNAVAILABLE => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
error_response.error,
)));
}
StatusCode::BAD_REQUEST => {
// Most probably, one text contained too many tokens
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
log::warn!("OpenAI: input was too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
error_response.error,
)));
}
code => {
return Err(Retry::give_up(EmbedError::openai_unhandled_status_code(
code.as_u16(),
)));
}
}
}
Ok(response)
}
async fn try_embed<S: AsRef<str> + serde::Serialize>(
&self,
texts: &[S],
) -> Result<Vec<Embeddings<f32>>, Retry> {
for text in texts {
log::trace!("Received prompt: {}", text.as_ref())
}
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
let response = self
.client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
log::trace!("response: {:?}", response.data);
Ok(response
.data
.into_iter()
.map(|data| Embeddings::from_single_embedding(data.embedding))
.collect())
}
async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
pub const OVERLAP_SIZE: usize = 200;
let mut all_embeddings = Vec::with_capacity(text.len());
for text in text {
let max_token_count = self.options.embedding_model.max_token();
let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len();
if len < max_token_count {
all_embeddings.append(&mut self.try_embed(&[text]).await?);
continue;
}
let mut tokens = encoded.as_slice();
let mut embeddings_for_prompt =
Embeddings::new(self.options.embedding_model.dimensions());
while tokens.len() > max_token_count {
let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
}
// end of text
embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
all_embeddings.push(embeddings_for_prompt);
}
Ok(all_embeddings)
}
async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
for attempt in 0..9 {
let duration = match self.try_embed_tokens(tokens).await {
Ok(embedding) => return Ok(embedding),
Err(retry) => retry.into_duration(attempt),
}
.map_err(Retry::retry_later)?;
tokio::time::sleep(duration).await;
}
self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
}
async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
let request =
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
let response = self
.client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let mut response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
}
pub async fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
.await
}
pub fn chunk_count_hint(&self) -> usize {
10
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
10
}
pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions()
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution()
}
}
// retrying in case of failure
struct Retry {
error: EmbedError,
strategy: RetryStrategy,
}
enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
RetryAfterRateLimit,
}
impl Retry {
fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
RetryStrategy::RetryAfterRateLimit => {
Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
}
}
}
fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
fn into_error(self) -> EmbedError {
self.error
}
}
// openai api structs
#[derive(Debug, Serialize)]
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str,
input: &'a [S],
}
#[derive(Debug, Serialize)]
struct OpenAiTokensRequest<'a> {
model: &'a str,
input: &'a [usize],
}
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
data: Vec<OpenAiEmbedding>,
}
#[derive(Debug, Deserialize)]
struct OpenAiErrorResponse {
error: OpenAiError,
}
#[derive(Debug, Deserialize)]
pub struct OpenAiError {
message: String,
// type: String,
code: Option<String>,
}
impl Display for OpenAiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.code {
Some(code) => write!(f, "{} ({})", self.message, code),
None => write!(f, "{}", self.message),
}
}
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
embedding: Embedding,
// object: String,
// index: usize,
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}

View File

@@ -0,0 +1,292 @@
use deserr::Deserr;
use serde::{Deserialize, Serialize};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")]
#[deserr(default, rename = "source")]
pub embedder_options: Setting<EmbedderSettings>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub document_template: Setting<PromptSettings>,
}
impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) {
let EmbeddingSettings { embedder_options, document_template: prompt } = new;
self.embedder_options.apply(embedder_options);
self.document_template.apply(prompt);
}
}
impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self {
Self {
embedder_options: Setting::Set(value.embedder_options.into()),
document_template: Setting::Set(value.prompt.into()),
}
}
}
impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default();
let EmbeddingSettings { embedder_options, document_template: prompt } = value;
if let Some(embedder_options) = embedder_options.set() {
this.embedder_options = embedder_options.into();
}
if let Some(prompt) = prompt.set() {
this.prompt = prompt.into();
}
this
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct PromptSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub template: Setting<String>,
}
impl PromptSettings {
pub fn apply(&mut self, new: Self) {
let PromptSettings { template } = new;
self.template.apply(template);
}
}
impl From<PromptData> for PromptSettings {
fn from(value: PromptData) -> Self {
Self { template: Setting::Set(value.template) }
}
}
impl From<PromptSettings> for PromptData {
fn from(value: PromptSettings) -> Self {
let mut this = PromptData::default();
let PromptSettings { template } = value;
if let Some(template) = template.set() {
this.template = template;
}
this
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub enum EmbedderSettings {
HuggingFace(Setting<HfEmbedderSettings>),
OpenAi(Setting<OpenAiEmbedderSettings>),
UserProvided(UserProvidedSettings),
}
impl<E> Deserr<E> for EmbedderSettings
where
E: deserr::DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
match value {
deserr::Value::Map(map) => {
if deserr::Map::len(&map) != 1 {
return Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"Expected a single field, got {} fields",
deserr::Map::len(&map)
),
},
location,
)));
}
let mut it = deserr::Map::into_iter(map);
let (k, v) = it.next().unwrap();
match k.as_str() {
"huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set(
HfEmbedderSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
))),
"openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set(
OpenAiEmbedderSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
))),
"userProvided" => Ok(EmbedderSettings::UserProvided(
UserProvidedSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
)),
other => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::UnknownKey {
key: other,
accepted: &["huggingFace", "openAi", "userProvided"],
},
location,
))),
}
}
_ => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::Map],
},
location,
))),
}
}
}
impl Default for EmbedderSettings {
fn default() -> Self {
Self::OpenAi(Default::default())
}
}
impl From<crate::vector::EmbedderOptions> for EmbedderSettings {
fn from(value: crate::vector::EmbedderOptions) -> Self {
match value {
crate::vector::EmbedderOptions::HuggingFace(hf) => {
Self::HuggingFace(Setting::Set(hf.into()))
}
crate::vector::EmbedderOptions::OpenAi(openai) => {
Self::OpenAi(Setting::Set(openai.into()))
}
crate::vector::EmbedderOptions::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
}
}
}
impl From<EmbedderSettings> for crate::vector::EmbedderOptions {
fn from(value: EmbedderSettings) -> Self {
match value {
EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
EmbedderSettings::OpenAi(_setting) => {
Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None))
}
EmbedderSettings::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct HfEmbedderSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub model: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub revision: Setting<String>,
}
impl HfEmbedderSettings {
pub fn apply(&mut self, new: Self) {
let HfEmbedderSettings { model, revision } = new;
self.model.apply(model);
self.revision.apply(revision);
}
}
impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
fn from(value: crate::vector::hf::EmbedderOptions) -> Self {
Self {
model: Setting::Set(value.model),
revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
}
}
}
impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
fn from(value: HfEmbedderSettings) -> Self {
let HfEmbedderSettings { model, revision } = value;
let mut this = Self::default();
if let Some(model) = model.set() {
this.model = model;
}
if let Some(revision) = revision.set() {
this.revision = Some(revision);
}
this
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct OpenAiEmbedderSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")]
#[deserr(default, rename = "model")]
pub embedding_model: Setting<crate::vector::openai::EmbeddingModel>,
}
impl OpenAiEmbedderSettings {
pub fn apply(&mut self, new: Self) {
let Self { api_key, embedding_model: embedding_mode } = new;
self.api_key.apply(api_key);
self.embedding_model.apply(embedding_mode);
}
}
impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
Self {
api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset),
embedding_model: Setting::Set(value.embedding_model),
}
}
}
impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
fn from(value: OpenAiEmbedderSettings) -> Self {
let OpenAiEmbedderSettings { api_key, embedding_model } = value;
Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() }
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct UserProvidedSettings {
pub dimensions: usize,
}
impl From<UserProvidedSettings> for crate::vector::manual::EmbedderOptions {
fn from(value: UserProvidedSettings) -> Self {
Self { dimensions: value.dimensions }
}
}
impl From<crate::vector::manual::EmbedderOptions> for UserProvidedSettings {
fn from(value: crate::vector::manual::EmbedderOptions) -> Self {
Self { dimensions: value.dimensions }
}
}