mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 13:06:27 +00:00 
			
		
		
		
	Merge branch 'main' into change-proximity-precision-settings
This commit is contained in:
		
							
								
								
									
										1118
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1118
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -276,6 +276,7 @@ pub(crate) mod test { | |||||||
|                 ), |                 ), | ||||||
|             }), |             }), | ||||||
|             pagination: Setting::NotSet, |             pagination: Setting::NotSet, | ||||||
|  |             embedders: Setting::NotSet, | ||||||
|             _kind: std::marker::PhantomData, |             _kind: std::marker::PhantomData, | ||||||
|         }; |         }; | ||||||
|         settings.check() |         settings.check() | ||||||
|   | |||||||
| @@ -378,6 +378,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> { | |||||||
|                 v5::Setting::Reset => v6::Setting::Reset, |                 v5::Setting::Reset => v6::Setting::Reset, | ||||||
|                 v5::Setting::NotSet => v6::Setting::NotSet, |                 v5::Setting::NotSet => v6::Setting::NotSet, | ||||||
|             }, |             }, | ||||||
|  |             embedders: v6::Setting::NotSet, | ||||||
|             _kind: std::marker::PhantomData, |             _kind: std::marker::PhantomData, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -1202,6 +1202,10 @@ impl IndexScheduler { | |||||||
|  |  | ||||||
|                 let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; |                 let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; | ||||||
|  |  | ||||||
|  |                 let embedder_configs = index.embedding_configs(index_wtxn)?; | ||||||
|  |                 // TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense) | ||||||
|  |                 let embedders = self.embedders(embedder_configs)?; | ||||||
|  |  | ||||||
|                 let mut builder = milli::update::IndexDocuments::new( |                 let mut builder = milli::update::IndexDocuments::new( | ||||||
|                     index_wtxn, |                     index_wtxn, | ||||||
|                     index, |                     index, | ||||||
| @@ -1220,6 +1224,8 @@ impl IndexScheduler { | |||||||
|                             let (new_builder, user_result) = builder.add_documents(reader)?; |                             let (new_builder, user_result) = builder.add_documents(reader)?; | ||||||
|                             builder = new_builder; |                             builder = new_builder; | ||||||
|  |  | ||||||
|  |                             builder = builder.with_embedders(embedders.clone()); | ||||||
|  |  | ||||||
|                             let received_documents = |                             let received_documents = | ||||||
|                                 if let Some(Details::DocumentAdditionOrUpdate { |                                 if let Some(Details::DocumentAdditionOrUpdate { | ||||||
|                                     received_documents, |                                     received_documents, | ||||||
| @@ -1345,6 +1351,9 @@ impl IndexScheduler { | |||||||
|  |  | ||||||
|                 for (task, (_, settings)) in tasks.iter_mut().zip(settings) { |                 for (task, (_, settings)) in tasks.iter_mut().zip(settings) { | ||||||
|                     let checked_settings = settings.clone().check(); |                     let checked_settings = settings.clone().check(); | ||||||
|  |                     if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) { | ||||||
|  |                         self.features().check_vector("Passing `embedders` in settings")? | ||||||
|  |                     } | ||||||
|                     task.details = Some(Details::SettingsUpdate { settings: Box::new(settings) }); |                     task.details = Some(Details::SettingsUpdate { settings: Box::new(settings) }); | ||||||
|                     apply_settings_to_builder(&checked_settings, &mut builder); |                     apply_settings_to_builder(&checked_settings, &mut builder); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -56,12 +56,12 @@ impl RoFeatures { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn check_vector(&self) -> Result<()> { |     pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> { | ||||||
|         if self.runtime.vector_store { |         if self.runtime.vector_store { | ||||||
|             Ok(()) |             Ok(()) | ||||||
|         } else { |         } else { | ||||||
|             Err(FeatureNotEnabledError { |             Err(FeatureNotEnabledError { | ||||||
|                 disabled_action: "Passing `vector` as a query parameter", |                 disabled_action, | ||||||
|                 feature: "vector store", |                 feature: "vector store", | ||||||
|                 issue_link: "https://github.com/meilisearch/product/discussions/677", |                 issue_link: "https://github.com/meilisearch/product/discussions/677", | ||||||
|             } |             } | ||||||
|   | |||||||
| @@ -41,6 +41,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String { | |||||||
|         planned_failures: _, |         planned_failures: _, | ||||||
|         run_loop_iteration: _, |         run_loop_iteration: _, | ||||||
|         currently_updating_index: _, |         currently_updating_index: _, | ||||||
|  |         embedders: _, | ||||||
|     } = scheduler; |     } = scheduler; | ||||||
|  |  | ||||||
|     let rtxn = env.read_txn().unwrap(); |     let rtxn = env.read_txn().unwrap(); | ||||||
|   | |||||||
| @@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128}; | |||||||
| use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; | use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; | ||||||
| use meilisearch_types::milli::documents::DocumentsBatchBuilder; | use meilisearch_types::milli::documents::DocumentsBatchBuilder; | ||||||
| use meilisearch_types::milli::update::IndexerConfig; | use meilisearch_types::milli::update::IndexerConfig; | ||||||
|  | use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs}; | ||||||
| use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; | use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; | ||||||
| use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; | use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; | ||||||
| use puffin::FrameView; | use puffin::FrameView; | ||||||
| @@ -341,6 +342,8 @@ pub struct IndexScheduler { | |||||||
|     /// so that a handle to the index is available from other threads (search) in an optimized manner. |     /// so that a handle to the index is available from other threads (search) in an optimized manner. | ||||||
|     currently_updating_index: Arc<RwLock<Option<(String, Index)>>>, |     currently_updating_index: Arc<RwLock<Option<(String, Index)>>>, | ||||||
|  |  | ||||||
|  |     embedders: Arc<RwLock<HashMap<EmbedderOptions, Arc<Embedder>>>>, | ||||||
|  |  | ||||||
|     // ================= test |     // ================= test | ||||||
|     // The next entry is dedicated to the tests. |     // The next entry is dedicated to the tests. | ||||||
|     /// Provide a way to set a breakpoint in multiple part of the scheduler. |     /// Provide a way to set a breakpoint in multiple part of the scheduler. | ||||||
| @@ -386,6 +389,7 @@ impl IndexScheduler { | |||||||
|             auth_path: self.auth_path.clone(), |             auth_path: self.auth_path.clone(), | ||||||
|             version_file_path: self.version_file_path.clone(), |             version_file_path: self.version_file_path.clone(), | ||||||
|             currently_updating_index: self.currently_updating_index.clone(), |             currently_updating_index: self.currently_updating_index.clone(), | ||||||
|  |             embedders: self.embedders.clone(), | ||||||
|             #[cfg(test)] |             #[cfg(test)] | ||||||
|             test_breakpoint_sdr: self.test_breakpoint_sdr.clone(), |             test_breakpoint_sdr: self.test_breakpoint_sdr.clone(), | ||||||
|             #[cfg(test)] |             #[cfg(test)] | ||||||
| @@ -484,6 +488,7 @@ impl IndexScheduler { | |||||||
|             auth_path: options.auth_path, |             auth_path: options.auth_path, | ||||||
|             version_file_path: options.version_file_path, |             version_file_path: options.version_file_path, | ||||||
|             currently_updating_index: Arc::new(RwLock::new(None)), |             currently_updating_index: Arc::new(RwLock::new(None)), | ||||||
|  |             embedders: Default::default(), | ||||||
|  |  | ||||||
|             #[cfg(test)] |             #[cfg(test)] | ||||||
|             test_breakpoint_sdr, |             test_breakpoint_sdr, | ||||||
| @@ -1333,6 +1338,40 @@ impl IndexScheduler { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // TODO: consider using a type alias or a struct embedder/template | ||||||
|  |     pub fn embedders( | ||||||
|  |         &self, | ||||||
|  |         embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>, | ||||||
|  |     ) -> Result<EmbeddingConfigs> { | ||||||
|  |         let res: Result<_> = embedding_configs | ||||||
|  |             .into_iter() | ||||||
|  |             .map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| { | ||||||
|  |                 let prompt = | ||||||
|  |                     Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?); | ||||||
|  |                 // optimistically return existing embedder | ||||||
|  |                 { | ||||||
|  |                     let embedders = self.embedders.read().unwrap(); | ||||||
|  |                     if let Some(embedder) = embedders.get(&embedder_options) { | ||||||
|  |                         return Ok((name, (embedder.clone(), prompt))); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 // add missing embedder | ||||||
|  |                 let embedder = Arc::new( | ||||||
|  |                     Embedder::new(embedder_options.clone()) | ||||||
|  |                         .map_err(meilisearch_types::milli::vector::Error::from) | ||||||
|  |                         .map_err(meilisearch_types::milli::Error::from)?, | ||||||
|  |                 ); | ||||||
|  |                 { | ||||||
|  |                     let mut embedders = self.embedders.write().unwrap(); | ||||||
|  |                     embedders.insert(embedder_options, embedder.clone()); | ||||||
|  |                 } | ||||||
|  |                 Ok((name, (embedder, prompt))) | ||||||
|  |             }) | ||||||
|  |             .collect(); | ||||||
|  |         res.map(EmbeddingConfigs::new) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Blocks the thread until the test handle asks to progress to/through this breakpoint. |     /// Blocks the thread until the test handle asks to progress to/through this breakpoint. | ||||||
|     /// |     /// | ||||||
|     /// Two messages are sent through the channel for each breakpoint. |     /// Two messages are sent through the channel for each breakpoint. | ||||||
|   | |||||||
| @@ -188,3 +188,4 @@ merge_with_error_impl_take_error_message!(ParseOffsetDateTimeError); | |||||||
| merge_with_error_impl_take_error_message!(ParseTaskKindError); | merge_with_error_impl_take_error_message!(ParseTaskKindError); | ||||||
| merge_with_error_impl_take_error_message!(ParseTaskStatusError); | merge_with_error_impl_take_error_message!(ParseTaskStatusError); | ||||||
| merge_with_error_impl_take_error_message!(IndexUidFormatError); | merge_with_error_impl_take_error_message!(IndexUidFormatError); | ||||||
|  | merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); | ||||||
|   | |||||||
| @@ -222,6 +222,8 @@ InvalidVectorsType                    , InvalidRequest       , BAD_REQUEST ; | |||||||
| InvalidDocumentId                     , InvalidRequest       , BAD_REQUEST ; | InvalidDocumentId                     , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidDocumentLimit                  , InvalidRequest       , BAD_REQUEST ; | InvalidDocumentLimit                  , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidDocumentOffset                 , InvalidRequest       , BAD_REQUEST ; | InvalidDocumentOffset                 , InvalidRequest       , BAD_REQUEST ; | ||||||
|  | InvalidEmbedder                       , InvalidRequest       , BAD_REQUEST ; | ||||||
|  | InvalidHybridQuery                    , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidIndexLimit                     , InvalidRequest       , BAD_REQUEST ; | InvalidIndexLimit                     , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidIndexOffset                    , InvalidRequest       , BAD_REQUEST ; | InvalidIndexOffset                    , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidIndexPrimaryKey                , InvalidRequest       , BAD_REQUEST ; | InvalidIndexPrimaryKey                , InvalidRequest       , BAD_REQUEST ; | ||||||
| @@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve     , InvalidRequest       , BAD_REQUEST ; | |||||||
| InvalidSearchCropLength               , InvalidRequest       , BAD_REQUEST ; | InvalidSearchCropLength               , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSearchCropMarker               , InvalidRequest       , BAD_REQUEST ; | InvalidSearchCropMarker               , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSearchFacets                   , InvalidRequest       , BAD_REQUEST ; | InvalidSearchFacets                   , InvalidRequest       , BAD_REQUEST ; | ||||||
|  | InvalidSearchSemanticRatio            , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidFacetSearchFacetName           , InvalidRequest       , BAD_REQUEST ; | InvalidFacetSearchFacetName           , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSearchFilter                   , InvalidRequest       , BAD_REQUEST ; | InvalidSearchFilter                   , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSearchHighlightPostTag         , InvalidRequest       , BAD_REQUEST ; | InvalidSearchHighlightPostTag         , InvalidRequest       , BAD_REQUEST ; | ||||||
| @@ -256,6 +259,7 @@ InvalidSettingsProximityPrecision     , InvalidRequest       , BAD_REQUEST ; | |||||||
| InvalidSettingsFaceting               , InvalidRequest       , BAD_REQUEST ; | InvalidSettingsFaceting               , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSettingsFilterableAttributes   , InvalidRequest       , BAD_REQUEST ; | InvalidSettingsFilterableAttributes   , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSettingsPagination             , InvalidRequest       , BAD_REQUEST ; | InvalidSettingsPagination             , InvalidRequest       , BAD_REQUEST ; | ||||||
|  | InvalidSettingsEmbedders              , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSettingsRankingRules           , InvalidRequest       , BAD_REQUEST ; | InvalidSettingsRankingRules           , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSettingsSearchableAttributes   , InvalidRequest       , BAD_REQUEST ; | InvalidSettingsSearchableAttributes   , InvalidRequest       , BAD_REQUEST ; | ||||||
| InvalidSettingsSortableAttributes     , InvalidRequest       , BAD_REQUEST ; | InvalidSettingsSortableAttributes     , InvalidRequest       , BAD_REQUEST ; | ||||||
| @@ -295,15 +299,18 @@ MissingFacetSearchFacetName           , InvalidRequest       , BAD_REQUEST ; | |||||||
| MissingIndexUid                       , InvalidRequest       , BAD_REQUEST ; | MissingIndexUid                       , InvalidRequest       , BAD_REQUEST ; | ||||||
| MissingMasterKey                      , Auth                 , UNAUTHORIZED ; | MissingMasterKey                      , Auth                 , UNAUTHORIZED ; | ||||||
| MissingPayload                        , InvalidRequest       , BAD_REQUEST ; | MissingPayload                        , InvalidRequest       , BAD_REQUEST ; | ||||||
|  | MissingSearchHybrid                   , InvalidRequest       , BAD_REQUEST ; | ||||||
| MissingSwapIndexes                    , InvalidRequest       , BAD_REQUEST ; | MissingSwapIndexes                    , InvalidRequest       , BAD_REQUEST ; | ||||||
| MissingTaskFilters                    , InvalidRequest       , BAD_REQUEST ; | MissingTaskFilters                    , InvalidRequest       , BAD_REQUEST ; | ||||||
| NoSpaceLeftOnDevice                   , System               , UNPROCESSABLE_ENTITY; | NoSpaceLeftOnDevice                   , System               , UNPROCESSABLE_ENTITY; | ||||||
| PayloadTooLarge                       , InvalidRequest       , PAYLOAD_TOO_LARGE ; | PayloadTooLarge                       , InvalidRequest       , PAYLOAD_TOO_LARGE ; | ||||||
| TaskNotFound                          , InvalidRequest       , NOT_FOUND ; | TaskNotFound                          , InvalidRequest       , NOT_FOUND ; | ||||||
| TooManyOpenFiles                      , System               , UNPROCESSABLE_ENTITY ; | TooManyOpenFiles                      , System               , UNPROCESSABLE_ENTITY ; | ||||||
|  | TooManyVectors                        , InvalidRequest       , BAD_REQUEST ; | ||||||
| UnretrievableDocument                 , Internal             , BAD_REQUEST ; | UnretrievableDocument                 , Internal             , BAD_REQUEST ; | ||||||
| UnretrievableErrorCode                , InvalidRequest       , BAD_REQUEST ; | UnretrievableErrorCode                , InvalidRequest       , BAD_REQUEST ; | ||||||
| UnsupportedMediaType                  , InvalidRequest       , UNSUPPORTED_MEDIA_TYPE | UnsupportedMediaType                  , InvalidRequest       , UNSUPPORTED_MEDIA_TYPE ; | ||||||
|  | VectorEmbeddingError                  , InvalidRequest       , BAD_REQUEST | ||||||
| } | } | ||||||
|  |  | ||||||
| impl ErrorCode for JoinError { | impl ErrorCode for JoinError { | ||||||
| @@ -336,6 +343,10 @@ impl ErrorCode for milli::Error { | |||||||
|                     UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => { |                     UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => { | ||||||
|                         Code::InvalidDocumentId |                         Code::InvalidDocumentId | ||||||
|                     } |                     } | ||||||
|  |                     UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, | ||||||
|  |                     UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, | ||||||
|  |                     UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, | ||||||
|  |                     UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, | ||||||
|                     UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, |                     UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, | ||||||
|                     UserError::MultiplePrimaryKeyCandidatesFound { .. } => { |                     UserError::MultiplePrimaryKeyCandidatesFound { .. } => { | ||||||
|                         Code::IndexPrimaryKeyMultipleCandidatesFound |                         Code::IndexPrimaryKeyMultipleCandidatesFound | ||||||
| @@ -353,11 +364,15 @@ impl ErrorCode for milli::Error { | |||||||
|                     UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, |                     UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, | ||||||
|                     UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, |                     UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, | ||||||
|                     UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, |                     UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, | ||||||
|  |                     UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType, | ||||||
|                     UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, |                     UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, | ||||||
|  |                     UserError::TooManyVectors(_, _) => Code::TooManyVectors, | ||||||
|                     UserError::SortError(_) => Code::InvalidSearchSort, |                     UserError::SortError(_) => Code::InvalidSearchSort, | ||||||
|                     UserError::InvalidMinTypoWordLenSetting(_, _) => { |                     UserError::InvalidMinTypoWordLenSetting(_, _) => { | ||||||
|                         Code::InvalidSettingsTypoTolerance |                         Code::InvalidSettingsTypoTolerance | ||||||
|                     } |                     } | ||||||
|  |                     UserError::InvalidEmbedder(_) => Code::InvalidEmbedder, | ||||||
|  |                     UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -445,6 +460,15 @@ impl fmt::Display for DeserrParseIntError { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio { | ||||||
|  |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||||
|  |         write!( | ||||||
|  |             f, | ||||||
|  |             "the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`." | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| #[macro_export] | #[macro_export] | ||||||
| macro_rules! internal_error { | macro_rules! internal_error { | ||||||
|     ($target:ty : $($other:path), *) => { |     ($target:ty : $($other:path), *) => { | ||||||
|   | |||||||
| @@ -199,6 +199,10 @@ pub struct Settings<T> { | |||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)] |     #[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)] | ||||||
|     pub pagination: Setting<PaginationSettings>, |     pub pagination: Setting<PaginationSettings>, | ||||||
|  |  | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||||
|  |     #[deserr(default, error = DeserrJsonError<InvalidSettingsEmbedders>)] | ||||||
|  |     pub embedders: Setting<BTreeMap<String, Setting<milli::vector::settings::EmbeddingSettings>>>, | ||||||
|  |  | ||||||
|     #[serde(skip)] |     #[serde(skip)] | ||||||
|     #[deserr(skip)] |     #[deserr(skip)] | ||||||
|     pub _kind: PhantomData<T>, |     pub _kind: PhantomData<T>, | ||||||
| @@ -222,6 +226,7 @@ impl Settings<Checked> { | |||||||
|             typo_tolerance: Setting::Reset, |             typo_tolerance: Setting::Reset, | ||||||
|             faceting: Setting::Reset, |             faceting: Setting::Reset, | ||||||
|             pagination: Setting::Reset, |             pagination: Setting::Reset, | ||||||
|  |             embedders: Setting::Reset, | ||||||
|             _kind: PhantomData, |             _kind: PhantomData, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -243,6 +248,7 @@ impl Settings<Checked> { | |||||||
|             typo_tolerance, |             typo_tolerance, | ||||||
|             faceting, |             faceting, | ||||||
|             pagination, |             pagination, | ||||||
|  |             embedders, | ||||||
|             .. |             .. | ||||||
|         } = self; |         } = self; | ||||||
|  |  | ||||||
| @@ -262,6 +268,7 @@ impl Settings<Checked> { | |||||||
|             typo_tolerance, |             typo_tolerance, | ||||||
|             faceting, |             faceting, | ||||||
|             pagination, |             pagination, | ||||||
|  |             embedders, | ||||||
|             _kind: PhantomData, |             _kind: PhantomData, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -307,6 +314,7 @@ impl Settings<Unchecked> { | |||||||
|             typo_tolerance: self.typo_tolerance, |             typo_tolerance: self.typo_tolerance, | ||||||
|             faceting: self.faceting, |             faceting: self.faceting, | ||||||
|             pagination: self.pagination, |             pagination: self.pagination, | ||||||
|  |             embedders: self.embedders, | ||||||
|             _kind: PhantomData, |             _kind: PhantomData, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -490,6 +498,12 @@ pub fn apply_settings_to_builder( | |||||||
|         Setting::Reset => builder.reset_pagination_max_total_hits(), |         Setting::Reset => builder.reset_pagination_max_total_hits(), | ||||||
|         Setting::NotSet => (), |         Setting::NotSet => (), | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     match settings.embedders.clone() { | ||||||
|  |         Setting::Set(value) => builder.set_embedder_settings(value), | ||||||
|  |         Setting::Reset => builder.reset_embedder_settings(), | ||||||
|  |         Setting::NotSet => (), | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| pub fn settings( | pub fn settings( | ||||||
| @@ -571,6 +585,12 @@ pub fn settings( | |||||||
|         ), |         ), | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|  |     let embedders = index | ||||||
|  |         .embedding_configs(rtxn)? | ||||||
|  |         .into_iter() | ||||||
|  |         .map(|(name, config)| (name, Setting::Set(config.into()))) | ||||||
|  |         .collect(); | ||||||
|  |  | ||||||
|     Ok(Settings { |     Ok(Settings { | ||||||
|         displayed_attributes: match displayed_attributes { |         displayed_attributes: match displayed_attributes { | ||||||
|             Some(attrs) => Setting::Set(attrs), |             Some(attrs) => Setting::Set(attrs), | ||||||
| @@ -599,6 +619,7 @@ pub fn settings( | |||||||
|         typo_tolerance: Setting::Set(typo_tolerance), |         typo_tolerance: Setting::Set(typo_tolerance), | ||||||
|         faceting: Setting::Set(faceting), |         faceting: Setting::Set(faceting), | ||||||
|         pagination: Setting::Set(pagination), |         pagination: Setting::Set(pagination), | ||||||
|  |         embedders: Setting::Set(embedders), | ||||||
|         _kind: PhantomData, |         _kind: PhantomData, | ||||||
|     }) |     }) | ||||||
| } | } | ||||||
| @@ -747,6 +768,7 @@ pub(crate) mod test { | |||||||
|             typo_tolerance: Setting::NotSet, |             typo_tolerance: Setting::NotSet, | ||||||
|             faceting: Setting::NotSet, |             faceting: Setting::NotSet, | ||||||
|             pagination: Setting::NotSet, |             pagination: Setting::NotSet, | ||||||
|  |             embedders: Setting::NotSet, | ||||||
|             _kind: PhantomData::<Unchecked>, |             _kind: PhantomData::<Unchecked>, | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
| @@ -772,6 +794,7 @@ pub(crate) mod test { | |||||||
|             typo_tolerance: Setting::NotSet, |             typo_tolerance: Setting::NotSet, | ||||||
|             faceting: Setting::NotSet, |             faceting: Setting::NotSet, | ||||||
|             pagination: Setting::NotSet, |             pagination: Setting::NotSet, | ||||||
|  |             embedders: Setting::NotSet, | ||||||
|             _kind: PhantomData::<Unchecked>, |             _kind: PhantomData::<Unchecked>, | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats}; | |||||||
| use crate::search::{ | use crate::search::{ | ||||||
|     FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, |     FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, | ||||||
|     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, |     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, | ||||||
|     DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, |     DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO, | ||||||
| }; | }; | ||||||
| use crate::Opt; | use crate::Opt; | ||||||
|  |  | ||||||
| @@ -586,6 +586,11 @@ pub struct SearchAggregator { | |||||||
|     // vector |     // vector | ||||||
|     // The maximum number of floats in a vector request |     // The maximum number of floats in a vector request | ||||||
|     max_vector_size: usize, |     max_vector_size: usize, | ||||||
|  |     // Whether the semantic ratio passed to a hybrid search equals the default ratio. | ||||||
|  |     semantic_ratio: bool, | ||||||
|  |     // Whether a non-default embedder was specified | ||||||
|  |     embedder: bool, | ||||||
|  |     hybrid: bool, | ||||||
|  |  | ||||||
|     // every time a search is done, we increment the counter linked to the used settings |     // every time a search is done, we increment the counter linked to the used settings | ||||||
|     matching_strategy: HashMap<String, usize>, |     matching_strategy: HashMap<String, usize>, | ||||||
| @@ -639,6 +644,7 @@ impl SearchAggregator { | |||||||
|             crop_marker, |             crop_marker, | ||||||
|             matching_strategy, |             matching_strategy, | ||||||
|             attributes_to_search_on, |             attributes_to_search_on, | ||||||
|  |             hybrid, | ||||||
|         } = query; |         } = query; | ||||||
|  |  | ||||||
|         let mut ret = Self::default(); |         let mut ret = Self::default(); | ||||||
| @@ -712,6 +718,12 @@ impl SearchAggregator { | |||||||
|         ret.show_ranking_score = *show_ranking_score; |         ret.show_ranking_score = *show_ranking_score; | ||||||
|         ret.show_ranking_score_details = *show_ranking_score_details; |         ret.show_ranking_score_details = *show_ranking_score_details; | ||||||
|  |  | ||||||
|  |         if let Some(hybrid) = hybrid { | ||||||
|  |             ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO(); | ||||||
|  |             ret.embedder = hybrid.embedder.is_some(); | ||||||
|  |             ret.hybrid = true; | ||||||
|  |         } | ||||||
|  |  | ||||||
|         ret |         ret | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -765,6 +777,9 @@ impl SearchAggregator { | |||||||
|             facets_total_number_of_facets, |             facets_total_number_of_facets, | ||||||
|             show_ranking_score, |             show_ranking_score, | ||||||
|             show_ranking_score_details, |             show_ranking_score_details, | ||||||
|  |             semantic_ratio, | ||||||
|  |             embedder, | ||||||
|  |             hybrid, | ||||||
|         } = other; |         } = other; | ||||||
|  |  | ||||||
|         if self.timestamp.is_none() { |         if self.timestamp.is_none() { | ||||||
| @@ -810,6 +825,9 @@ impl SearchAggregator { | |||||||
|  |  | ||||||
|         // vector |         // vector | ||||||
|         self.max_vector_size = self.max_vector_size.max(max_vector_size); |         self.max_vector_size = self.max_vector_size.max(max_vector_size); | ||||||
|  |         self.semantic_ratio |= semantic_ratio; | ||||||
|  |         self.hybrid |= hybrid; | ||||||
|  |         self.embedder |= embedder; | ||||||
|  |  | ||||||
|         // pagination |         // pagination | ||||||
|         self.max_limit = self.max_limit.max(max_limit); |         self.max_limit = self.max_limit.max(max_limit); | ||||||
| @@ -878,6 +896,9 @@ impl SearchAggregator { | |||||||
|             facets_total_number_of_facets, |             facets_total_number_of_facets, | ||||||
|             show_ranking_score, |             show_ranking_score, | ||||||
|             show_ranking_score_details, |             show_ranking_score_details, | ||||||
|  |             semantic_ratio, | ||||||
|  |             embedder, | ||||||
|  |             hybrid, | ||||||
|         } = self; |         } = self; | ||||||
|  |  | ||||||
|         if total_received == 0 { |         if total_received == 0 { | ||||||
| @@ -917,6 +938,11 @@ impl SearchAggregator { | |||||||
|                 "vector": { |                 "vector": { | ||||||
|                     "max_vector_size": max_vector_size, |                     "max_vector_size": max_vector_size, | ||||||
|                 }, |                 }, | ||||||
|  |                 "hybrid": { | ||||||
|  |                     "enabled": hybrid, | ||||||
|  |                     "semantic_ratio": semantic_ratio, | ||||||
|  |                     "embedder": embedder, | ||||||
|  |                 }, | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                    "max_limit": max_limit, |                    "max_limit": max_limit, | ||||||
|                    "max_offset": max_offset, |                    "max_offset": max_offset, | ||||||
| @@ -1012,6 +1038,7 @@ impl MultiSearchAggregator { | |||||||
|                     crop_marker: _, |                     crop_marker: _, | ||||||
|                     matching_strategy: _, |                     matching_strategy: _, | ||||||
|                     attributes_to_search_on: _, |                     attributes_to_search_on: _, | ||||||
|  |                     hybrid: _, | ||||||
|                 } = query; |                 } = query; | ||||||
|  |  | ||||||
|                 index_uid.as_str() |                 index_uid.as_str() | ||||||
| @@ -1158,6 +1185,7 @@ impl FacetSearchAggregator { | |||||||
|             filter, |             filter, | ||||||
|             matching_strategy, |             matching_strategy, | ||||||
|             attributes_to_search_on, |             attributes_to_search_on, | ||||||
|  |             hybrid, | ||||||
|         } = query; |         } = query; | ||||||
|  |  | ||||||
|         let mut ret = Self::default(); |         let mut ret = Self::default(); | ||||||
| @@ -1171,7 +1199,8 @@ impl FacetSearchAggregator { | |||||||
|             || vector.is_some() |             || vector.is_some() | ||||||
|             || filter.is_some() |             || filter.is_some() | ||||||
|             || *matching_strategy != MatchingStrategy::default() |             || *matching_strategy != MatchingStrategy::default() | ||||||
|             || attributes_to_search_on.is_some(); |             || attributes_to_search_on.is_some() | ||||||
|  |             || hybrid.is_some(); | ||||||
|  |  | ||||||
|         ret |         ret | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -51,6 +51,8 @@ pub enum MeilisearchHttpError { | |||||||
|     DocumentFormat(#[from] DocumentFormatError), |     DocumentFormat(#[from] DocumentFormatError), | ||||||
|     #[error(transparent)] |     #[error(transparent)] | ||||||
|     Join(#[from] JoinError), |     Join(#[from] JoinError), | ||||||
|  |     #[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")] | ||||||
|  |     MissingSearchHybrid, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl ErrorCode for MeilisearchHttpError { | impl ErrorCode for MeilisearchHttpError { | ||||||
| @@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError { | |||||||
|             MeilisearchHttpError::FileStore(_) => Code::Internal, |             MeilisearchHttpError::FileStore(_) => Code::Internal, | ||||||
|             MeilisearchHttpError::DocumentFormat(e) => e.error_code(), |             MeilisearchHttpError::DocumentFormat(e) => e.error_code(), | ||||||
|             MeilisearchHttpError::Join(_) => Code::Internal, |             MeilisearchHttpError::Join(_) => Code::Internal, | ||||||
|  |             MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; | |||||||
| /// does all the setup before meilisearch is launched | /// does all the setup before meilisearch is launched | ||||||
| fn setup(opt: &Opt) -> anyhow::Result<()> { | fn setup(opt: &Opt) -> anyhow::Result<()> { | ||||||
|     let mut log_builder = env_logger::Builder::new(); |     let mut log_builder = env_logger::Builder::new(); | ||||||
|     log_builder.parse_filters(&opt.log_level.to_string()); |     let log_filters = format!( | ||||||
|  |         "{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn", | ||||||
|  |         opt.log_level | ||||||
|  |     ); | ||||||
|  |     log_builder.parse_filters(&log_filters); | ||||||
|  |  | ||||||
|     log_builder.init(); |     log_builder.init(); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -13,9 +13,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator}; | |||||||
| use crate::extractors::authentication::policies::*; | use crate::extractors::authentication::policies::*; | ||||||
| use crate::extractors::authentication::GuardedData; | use crate::extractors::authentication::GuardedData; | ||||||
| use crate::search::{ | use crate::search::{ | ||||||
|     add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, |     add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, | ||||||
|     DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, |     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, | ||||||
|     DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, |     DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| pub fn configure(cfg: &mut web::ServiceConfig) { | pub fn configure(cfg: &mut web::ServiceConfig) { | ||||||
| @@ -36,6 +36,8 @@ pub struct FacetSearchQuery { | |||||||
|     pub q: Option<String>, |     pub q: Option<String>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] | ||||||
|     pub vector: Option<Vec<f32>>, |     pub vector: Option<Vec<f32>>, | ||||||
|  |     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] | ||||||
|  |     pub hybrid: Option<HybridQuery>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)] | ||||||
|     pub filter: Option<Value>, |     pub filter: Option<Value>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)] | ||||||
| @@ -95,6 +97,7 @@ impl From<FacetSearchQuery> for SearchQuery { | |||||||
|             filter, |             filter, | ||||||
|             matching_strategy, |             matching_strategy, | ||||||
|             attributes_to_search_on, |             attributes_to_search_on, | ||||||
|  |             hybrid, | ||||||
|         } = value; |         } = value; | ||||||
|  |  | ||||||
|         SearchQuery { |         SearchQuery { | ||||||
| @@ -119,6 +122,7 @@ impl From<FacetSearchQuery> for SearchQuery { | |||||||
|             matching_strategy, |             matching_strategy, | ||||||
|             vector, |             vector, | ||||||
|             attributes_to_search_on, |             attributes_to_search_on, | ||||||
|  |             hybrid, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,12 +2,14 @@ use actix_web::web::Data; | |||||||
| use actix_web::{web, HttpRequest, HttpResponse}; | use actix_web::{web, HttpRequest, HttpResponse}; | ||||||
| use deserr::actix_web::{AwebJson, AwebQueryParameter}; | use deserr::actix_web::{AwebJson, AwebQueryParameter}; | ||||||
| use index_scheduler::IndexScheduler; | use index_scheduler::IndexScheduler; | ||||||
| use log::debug; | use log::{debug, warn}; | ||||||
| use meilisearch_types::deserr::query_params::Param; | use meilisearch_types::deserr::query_params::Param; | ||||||
| use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; | use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; | ||||||
| use meilisearch_types::error::deserr_codes::*; | use meilisearch_types::error::deserr_codes::*; | ||||||
| use meilisearch_types::error::ResponseError; | use meilisearch_types::error::ResponseError; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
|  | use meilisearch_types::milli; | ||||||
|  | use meilisearch_types::milli::vector::DistributionShift; | ||||||
| use meilisearch_types::serde_cs::vec::CS; | use meilisearch_types::serde_cs::vec::CS; | ||||||
| use serde_json::Value; | use serde_json::Value; | ||||||
|  |  | ||||||
| @@ -16,9 +18,9 @@ use crate::extractors::authentication::policies::*; | |||||||
| use crate::extractors::authentication::GuardedData; | use crate::extractors::authentication::GuardedData; | ||||||
| use crate::extractors::sequential_extractor::SeqHandler; | use crate::extractors::sequential_extractor::SeqHandler; | ||||||
| use crate::search::{ | use crate::search::{ | ||||||
|     add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, |     add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio, | ||||||
|     DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, |     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, | ||||||
|     DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, |     DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| pub fn configure(cfg: &mut web::ServiceConfig) { | pub fn configure(cfg: &mut web::ServiceConfig) { | ||||||
| @@ -74,6 +76,31 @@ pub struct SearchQueryGet { | |||||||
|     matching_strategy: MatchingStrategy, |     matching_strategy: MatchingStrategy, | ||||||
|     #[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)] |     #[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)] | ||||||
|     pub attributes_to_search_on: Option<CS<String>>, |     pub attributes_to_search_on: Option<CS<String>>, | ||||||
|  |     #[deserr(default, error = DeserrQueryParamError<InvalidEmbedder>)] | ||||||
|  |     pub hybrid_embedder: Option<String>, | ||||||
|  |     #[deserr(default, error = DeserrQueryParamError<InvalidSearchSemanticRatio>)] | ||||||
|  |     pub hybrid_semantic_ratio: Option<SemanticRatioGet>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)] | ||||||
|  | #[deserr(try_from(String) = TryFrom::try_from -> InvalidSearchSemanticRatio)] | ||||||
|  | pub struct SemanticRatioGet(SemanticRatio); | ||||||
|  |  | ||||||
|  | impl std::convert::TryFrom<String> for SemanticRatioGet { | ||||||
|  |     type Error = InvalidSearchSemanticRatio; | ||||||
|  |  | ||||||
|  |     fn try_from(s: String) -> Result<Self, Self::Error> { | ||||||
|  |         let f: f32 = s.parse().map_err(|_| InvalidSearchSemanticRatio)?; | ||||||
|  |         Ok(SemanticRatioGet(SemanticRatio::try_from(f)?)) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl std::ops::Deref for SemanticRatioGet { | ||||||
|  |     type Target = SemanticRatio; | ||||||
|  |  | ||||||
|  |     fn deref(&self) -> &Self::Target { | ||||||
|  |         &self.0 | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl From<SearchQueryGet> for SearchQuery { | impl From<SearchQueryGet> for SearchQuery { | ||||||
| @@ -86,6 +113,20 @@ impl From<SearchQueryGet> for SearchQuery { | |||||||
|             None => None, |             None => None, | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|  |         let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) { | ||||||
|  |             (None, None) => None, | ||||||
|  |             (None, Some(semantic_ratio)) => { | ||||||
|  |                 Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: None }) | ||||||
|  |             } | ||||||
|  |             (Some(embedder), None) => Some(HybridQuery { | ||||||
|  |                 semantic_ratio: DEFAULT_SEMANTIC_RATIO(), | ||||||
|  |                 embedder: Some(embedder), | ||||||
|  |             }), | ||||||
|  |             (Some(embedder), Some(semantic_ratio)) => { | ||||||
|  |                 Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: Some(embedder) }) | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |  | ||||||
|         Self { |         Self { | ||||||
|             q: other.q, |             q: other.q, | ||||||
|             vector: other.vector.map(CS::into_inner), |             vector: other.vector.map(CS::into_inner), | ||||||
| @@ -108,6 +149,7 @@ impl From<SearchQueryGet> for SearchQuery { | |||||||
|             crop_marker: other.crop_marker, |             crop_marker: other.crop_marker, | ||||||
|             matching_strategy: other.matching_strategy, |             matching_strategy: other.matching_strategy, | ||||||
|             attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), |             attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), | ||||||
|  |             hybrid, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -158,8 +200,12 @@ pub async fn search_with_url_query( | |||||||
|  |  | ||||||
|     let index = index_scheduler.index(&index_uid)?; |     let index = index_scheduler.index(&index_uid)?; | ||||||
|     let features = index_scheduler.features(); |     let features = index_scheduler.features(); | ||||||
|  |  | ||||||
|  |     let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; | ||||||
|  |  | ||||||
|     let search_result = |     let search_result = | ||||||
|         tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; |         tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) | ||||||
|  |             .await?; | ||||||
|     if let Ok(ref search_result) = search_result { |     if let Ok(ref search_result) = search_result { | ||||||
|         aggregate.succeed(search_result); |         aggregate.succeed(search_result); | ||||||
|     } |     } | ||||||
| @@ -193,8 +239,12 @@ pub async fn search_with_post( | |||||||
|     let index = index_scheduler.index(&index_uid)?; |     let index = index_scheduler.index(&index_uid)?; | ||||||
|  |  | ||||||
|     let features = index_scheduler.features(); |     let features = index_scheduler.features(); | ||||||
|  |  | ||||||
|  |     let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; | ||||||
|  |  | ||||||
|     let search_result = |     let search_result = | ||||||
|         tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; |         tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) | ||||||
|  |             .await?; | ||||||
|     if let Ok(ref search_result) = search_result { |     if let Ok(ref search_result) = search_result { | ||||||
|         aggregate.succeed(search_result); |         aggregate.succeed(search_result); | ||||||
|     } |     } | ||||||
| @@ -206,6 +256,80 @@ pub async fn search_with_post( | |||||||
|     Ok(HttpResponse::Ok().json(search_result)) |     Ok(HttpResponse::Ok().json(search_result)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | pub async fn embed( | ||||||
|  |     query: &mut SearchQuery, | ||||||
|  |     index_scheduler: &IndexScheduler, | ||||||
|  |     index: &milli::Index, | ||||||
|  | ) -> Result<Option<DistributionShift>, ResponseError> { | ||||||
|  |     match (&query.hybrid, &query.vector, &query.q) { | ||||||
|  |         (Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q)) | ||||||
|  |             if !q.trim().is_empty() => | ||||||
|  |         { | ||||||
|  |             let embedder_configs = index.embedding_configs(&index.read_txn()?)?; | ||||||
|  |             let embedders = index_scheduler.embedders(embedder_configs)?; | ||||||
|  |  | ||||||
|  |             let embedder = if let Some(embedder_name) = embedder { | ||||||
|  |                 embedders.get(embedder_name) | ||||||
|  |             } else { | ||||||
|  |                 embedders.get_default() | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |             let embedder = embedder | ||||||
|  |                 .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) | ||||||
|  |                 .map_err(milli::Error::from)? | ||||||
|  |                 .0; | ||||||
|  |  | ||||||
|  |             let distribution = embedder.distribution(); | ||||||
|  |  | ||||||
|  |             let embeddings = embedder | ||||||
|  |                 .embed(vec![q.to_owned()]) | ||||||
|  |                 .await | ||||||
|  |                 .map_err(milli::vector::Error::from) | ||||||
|  |                 .map_err(milli::Error::from)? | ||||||
|  |                 .pop() | ||||||
|  |                 .expect("No vector returned from embedding"); | ||||||
|  |  | ||||||
|  |             if embeddings.iter().nth(1).is_some() { | ||||||
|  |                 warn!("Ignoring embeddings past the first one in long search query"); | ||||||
|  |                 query.vector = Some(embeddings.iter().next().unwrap().to_vec()); | ||||||
|  |             } else { | ||||||
|  |                 query.vector = Some(embeddings.into_inner()); | ||||||
|  |             } | ||||||
|  |             Ok(distribution) | ||||||
|  |         } | ||||||
|  |         (Some(hybrid), vector, _) => { | ||||||
|  |             let embedder_configs = index.embedding_configs(&index.read_txn()?)?; | ||||||
|  |             let embedders = index_scheduler.embedders(embedder_configs)?; | ||||||
|  |  | ||||||
|  |             let embedder = if let Some(embedder_name) = &hybrid.embedder { | ||||||
|  |                 embedders.get(embedder_name) | ||||||
|  |             } else { | ||||||
|  |                 embedders.get_default() | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |             let embedder = embedder | ||||||
|  |                 .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) | ||||||
|  |                 .map_err(milli::Error::from)? | ||||||
|  |                 .0; | ||||||
|  |  | ||||||
|  |             if let Some(vector) = vector { | ||||||
|  |                 if vector.len() != embedder.dimensions() { | ||||||
|  |                     return Err(meilisearch_types::milli::Error::UserError( | ||||||
|  |                         meilisearch_types::milli::UserError::InvalidVectorDimensions { | ||||||
|  |                             expected: embedder.dimensions(), | ||||||
|  |                             found: vector.len(), | ||||||
|  |                         }, | ||||||
|  |                     ) | ||||||
|  |                     .into()); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             Ok(embedder.distribution()) | ||||||
|  |         } | ||||||
|  |         _ => Ok(None), | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod test { | mod test { | ||||||
|     use super::*; |     use super::*; | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError; | |||||||
| use meilisearch_types::error::ResponseError; | use meilisearch_types::error::ResponseError; | ||||||
| use meilisearch_types::facet_values_sort::FacetValuesSort; | use meilisearch_types::facet_values_sort::FacetValuesSort; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
|  | use meilisearch_types::milli::update::Setting; | ||||||
| use meilisearch_types::settings::{settings, RankingRuleView, Settings, Unchecked}; | use meilisearch_types::settings::{settings, RankingRuleView, Settings, Unchecked}; | ||||||
| use meilisearch_types::tasks::KindWithContent; | use meilisearch_types::tasks::KindWithContent; | ||||||
| use serde_json::json; | use serde_json::json; | ||||||
| @@ -546,6 +547,67 @@ make_setting_route!( | |||||||
|     } |     } | ||||||
| ); | ); | ||||||
|  |  | ||||||
|  | make_setting_route!( | ||||||
|  |     "/embedders", | ||||||
|  |     patch, | ||||||
|  |     std::collections::BTreeMap<String, Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>>, | ||||||
|  |     meilisearch_types::deserr::DeserrJsonError< | ||||||
|  |         meilisearch_types::error::deserr_codes::InvalidSettingsEmbedders, | ||||||
|  |     >, | ||||||
|  |     embedders, | ||||||
|  |     "embedders", | ||||||
|  |     analytics, | ||||||
|  |     |setting: &Option<std::collections::BTreeMap<String, Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>>>, req: &HttpRequest| { | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         analytics.publish( | ||||||
|  |             "Embedders Updated".to_string(), | ||||||
|  |             serde_json::json!({"embedders": crate::routes::indexes::settings::embedder_analytics(setting.as_ref())}), | ||||||
|  |             Some(req), | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|  | ); | ||||||
|  |  | ||||||
|  | fn embedder_analytics( | ||||||
|  |     setting: Option< | ||||||
|  |         &std::collections::BTreeMap< | ||||||
|  |             String, | ||||||
|  |             Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>, | ||||||
|  |         >, | ||||||
|  |     >, | ||||||
|  | ) -> serde_json::Value { | ||||||
|  |     let mut sources = std::collections::HashSet::new(); | ||||||
|  |  | ||||||
|  |     if let Some(s) = &setting { | ||||||
|  |         for source in s | ||||||
|  |             .values() | ||||||
|  |             .filter_map(|config| config.clone().set()) | ||||||
|  |             .filter_map(|config| config.embedder_options.set()) | ||||||
|  |         { | ||||||
|  |             use meilisearch_types::milli::vector::settings::EmbedderSettings; | ||||||
|  |             match source { | ||||||
|  |                 EmbedderSettings::OpenAi(_) => sources.insert("openAi"), | ||||||
|  |                 EmbedderSettings::HuggingFace(_) => sources.insert("huggingFace"), | ||||||
|  |                 EmbedderSettings::UserProvided(_) => sources.insert("userProvided"), | ||||||
|  |             }; | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let document_template_used = setting.as_ref().map(|map| { | ||||||
|  |         map.values() | ||||||
|  |             .filter_map(|config| config.clone().set()) | ||||||
|  |             .any(|config| config.document_template.set().is_some()) | ||||||
|  |     }); | ||||||
|  |  | ||||||
|  |     json!( | ||||||
|  |         { | ||||||
|  |             "total": setting.as_ref().map(|s| s.len()), | ||||||
|  |             "sources": sources, | ||||||
|  |             "document_template_used": document_template_used, | ||||||
|  |         } | ||||||
|  |     ) | ||||||
|  | } | ||||||
|  |  | ||||||
| macro_rules! generate_configure { | macro_rules! generate_configure { | ||||||
|     ($($mod:ident),*) => { |     ($($mod:ident),*) => { | ||||||
|         pub fn configure(cfg: &mut web::ServiceConfig) { |         pub fn configure(cfg: &mut web::ServiceConfig) { | ||||||
| @@ -575,7 +637,8 @@ generate_configure!( | |||||||
|     ranking_rules, |     ranking_rules, | ||||||
|     typo_tolerance, |     typo_tolerance, | ||||||
|     pagination, |     pagination, | ||||||
|     faceting |     faceting, | ||||||
|  |     embedders | ||||||
| ); | ); | ||||||
|  |  | ||||||
| pub async fn update_all( | pub async fn update_all( | ||||||
| @@ -682,6 +745,7 @@ pub async fn update_all( | |||||||
|             "synonyms": { |             "synonyms": { | ||||||
|                 "total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()), |                 "total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()), | ||||||
|             }, |             }, | ||||||
|  |             "embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set()) | ||||||
|         }), |         }), | ||||||
|         Some(&req), |         Some(&req), | ||||||
|     ); |     ); | ||||||
|   | |||||||
| @@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator}; | |||||||
| use crate::extractors::authentication::policies::ActionPolicy; | use crate::extractors::authentication::policies::ActionPolicy; | ||||||
| use crate::extractors::authentication::{AuthenticationError, GuardedData}; | use crate::extractors::authentication::{AuthenticationError, GuardedData}; | ||||||
| use crate::extractors::sequential_extractor::SeqHandler; | use crate::extractors::sequential_extractor::SeqHandler; | ||||||
|  | use crate::routes::indexes::search::embed; | ||||||
| use crate::search::{ | use crate::search::{ | ||||||
|     add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, |     add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, | ||||||
| }; | }; | ||||||
| @@ -74,8 +75,13 @@ pub async fn multi_search_with_post( | |||||||
|                 }) |                 }) | ||||||
|                 .with_index(query_index)?; |                 .with_index(query_index)?; | ||||||
|  |  | ||||||
|             let search_result = |             let distribution = embed(&mut query, index_scheduler.get_ref(), &index) | ||||||
|                 tokio::task::spawn_blocking(move || perform_search(&index, query, features)) |                 .await | ||||||
|  |                 .with_index(query_index)?; | ||||||
|  |  | ||||||
|  |             let search_result = tokio::task::spawn_blocking(move || { | ||||||
|  |                 perform_search(&index, query, features, distribution) | ||||||
|  |             }) | ||||||
|             .await |             .await | ||||||
|             .with_index(query_index)?; |             .with_index(query_index)?; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,24 +7,21 @@ use deserr::Deserr; | |||||||
| use either::Either; | use either::Either; | ||||||
| use index_scheduler::RoFeatures; | use index_scheduler::RoFeatures; | ||||||
| use indexmap::IndexMap; | use indexmap::IndexMap; | ||||||
| use log::warn; |  | ||||||
| use meilisearch_auth::IndexSearchRules; | use meilisearch_auth::IndexSearchRules; | ||||||
| use meilisearch_types::deserr::DeserrJsonError; | use meilisearch_types::deserr::DeserrJsonError; | ||||||
| use meilisearch_types::error::deserr_codes::*; | use meilisearch_types::error::deserr_codes::*; | ||||||
| use meilisearch_types::heed::RoTxn; | use meilisearch_types::heed::RoTxn; | ||||||
| use meilisearch_types::index_uid::IndexUid; | use meilisearch_types::index_uid::IndexUid; | ||||||
| use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; | use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; | ||||||
| use meilisearch_types::milli::{ | use meilisearch_types::milli::vector::DistributionShift; | ||||||
|     dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, | use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues}; | ||||||
| }; |  | ||||||
| use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; | use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; | ||||||
| use meilisearch_types::{milli, Document}; | use meilisearch_types::{milli, Document}; | ||||||
| use milli::tokenizer::TokenizerBuilder; | use milli::tokenizer::TokenizerBuilder; | ||||||
| use milli::{ | use milli::{ | ||||||
|     AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, |     AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, | ||||||
|     SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, |     SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, | ||||||
| }; | }; | ||||||
| use ordered_float::OrderedFloat; |  | ||||||
| use regex::Regex; | use regex::Regex; | ||||||
| use serde::Serialize; | use serde::Serialize; | ||||||
| use serde_json::{json, Value}; | use serde_json::{json, Value}; | ||||||
| @@ -39,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10; | |||||||
| pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); | pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); | ||||||
| pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string(); | pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string(); | ||||||
| pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string(); | pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string(); | ||||||
|  | pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5); | ||||||
|  |  | ||||||
| #[derive(Debug, Clone, Default, PartialEq, Deserr)] | #[derive(Debug, Clone, Default, PartialEq, Deserr)] | ||||||
| #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] | #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] | ||||||
| @@ -47,6 +45,8 @@ pub struct SearchQuery { | |||||||
|     pub q: Option<String>, |     pub q: Option<String>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] | ||||||
|     pub vector: Option<Vec<f32>>, |     pub vector: Option<Vec<f32>>, | ||||||
|  |     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] | ||||||
|  |     pub hybrid: Option<HybridQuery>, | ||||||
|     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] |     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] | ||||||
|     pub offset: usize, |     pub offset: usize, | ||||||
|     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] |     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] | ||||||
| @@ -87,6 +87,48 @@ pub struct SearchQuery { | |||||||
|     pub attributes_to_search_on: Option<Vec<String>>, |     pub attributes_to_search_on: Option<Vec<String>>, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Default, PartialEq, Deserr)] | ||||||
|  | #[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | pub struct HybridQuery { | ||||||
|  |     /// TODO validate that semantic ratio is between 0.0 and 1.0 | ||||||
|  |     #[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)] | ||||||
|  |     pub semantic_ratio: SemanticRatio, | ||||||
|  |     #[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)] | ||||||
|  |     pub embedder: Option<String>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Copy, PartialEq, Deserr)] | ||||||
|  | #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] | ||||||
|  | pub struct SemanticRatio(f32); | ||||||
|  |  | ||||||
|  | impl Default for SemanticRatio { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         DEFAULT_SEMANTIC_RATIO() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl std::convert::TryFrom<f32> for SemanticRatio { | ||||||
|  |     type Error = InvalidSearchSemanticRatio; | ||||||
|  |  | ||||||
|  |     fn try_from(f: f32) -> Result<Self, Self::Error> { | ||||||
|  |         // the suggested "fix" is: `!(0.0..=1.0).contains(&f)` which is allegedly less readable | ||||||
|  |         #[allow(clippy::manual_range_contains)] | ||||||
|  |         if f > 1.0 || f < 0.0 { | ||||||
|  |             Err(InvalidSearchSemanticRatio) | ||||||
|  |         } else { | ||||||
|  |             Ok(SemanticRatio(f)) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl std::ops::Deref for SemanticRatio { | ||||||
|  |     type Target = f32; | ||||||
|  |  | ||||||
|  |     fn deref(&self) -> &Self::Target { | ||||||
|  |         &self.0 | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| impl SearchQuery { | impl SearchQuery { | ||||||
|     pub fn is_finite_pagination(&self) -> bool { |     pub fn is_finite_pagination(&self) -> bool { | ||||||
|         self.page.or(self.hits_per_page).is_some() |         self.page.or(self.hits_per_page).is_some() | ||||||
| @@ -106,6 +148,8 @@ pub struct SearchQueryWithIndex { | |||||||
|     pub q: Option<String>, |     pub q: Option<String>, | ||||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] |     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] | ||||||
|     pub vector: Option<Vec<f32>>, |     pub vector: Option<Vec<f32>>, | ||||||
|  |     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)] | ||||||
|  |     pub hybrid: Option<HybridQuery>, | ||||||
|     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] |     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] | ||||||
|     pub offset: usize, |     pub offset: usize, | ||||||
|     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] |     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] | ||||||
| @@ -171,6 +215,7 @@ impl SearchQueryWithIndex { | |||||||
|             crop_marker, |             crop_marker, | ||||||
|             matching_strategy, |             matching_strategy, | ||||||
|             attributes_to_search_on, |             attributes_to_search_on, | ||||||
|  |             hybrid, | ||||||
|         } = self; |         } = self; | ||||||
|         ( |         ( | ||||||
|             index_uid, |             index_uid, | ||||||
| @@ -196,6 +241,7 @@ impl SearchQueryWithIndex { | |||||||
|                 crop_marker, |                 crop_marker, | ||||||
|                 matching_strategy, |                 matching_strategy, | ||||||
|                 attributes_to_search_on, |                 attributes_to_search_on, | ||||||
|  |                 hybrid, | ||||||
|                 // do not use ..Default::default() here, |                 // do not use ..Default::default() here, | ||||||
|                 // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` |                 // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` | ||||||
|             }, |             }, | ||||||
| @@ -335,19 +381,44 @@ fn prepare_search<'t>( | |||||||
|     rtxn: &'t RoTxn, |     rtxn: &'t RoTxn, | ||||||
|     query: &'t SearchQuery, |     query: &'t SearchQuery, | ||||||
|     features: RoFeatures, |     features: RoFeatures, | ||||||
|  |     distribution: Option<DistributionShift>, | ||||||
| ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { | ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { | ||||||
|     let mut search = index.search(rtxn); |     let mut search = index.search(rtxn); | ||||||
|  |  | ||||||
|     if query.vector.is_some() && query.q.is_some() { |     if query.vector.is_some() { | ||||||
|         warn!("Ignoring the query string `q` when used with the `vector` parameter."); |         features.check_vector("Passing `vector` as a query parameter")?; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     if query.hybrid.is_some() { | ||||||
|  |         features.check_vector("Passing `hybrid` as a query parameter")?; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() { | ||||||
|  |         return Err(MeilisearchHttpError::MissingSearchHybrid); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     search.distribution_shift(distribution); | ||||||
|  |  | ||||||
|     if let Some(ref vector) = query.vector { |     if let Some(ref vector) = query.vector { | ||||||
|  |         match &query.hybrid { | ||||||
|  |             // If semantic ratio is 0.0, only the query search will impact the search results, | ||||||
|  |             // skip the vector | ||||||
|  |             Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), | ||||||
|  |             _otherwise => { | ||||||
|                 search.vector(vector.clone()); |                 search.vector(vector.clone()); | ||||||
|             } |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     if let Some(ref query) = query.q { |     if let Some(ref q) = query.q { | ||||||
|         search.query(query); |         match &query.hybrid { | ||||||
|  |             // If semantic ratio is 1.0, only the vector search will impact the search results, | ||||||
|  |             // skip the query | ||||||
|  |             Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (), | ||||||
|  |             _otherwise => { | ||||||
|  |                 search.query(q); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if let Some(ref searchable) = query.attributes_to_search_on { |     if let Some(ref searchable) = query.attributes_to_search_on { | ||||||
| @@ -374,8 +445,8 @@ fn prepare_search<'t>( | |||||||
|         features.check_score_details()?; |         features.check_score_details()?; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if query.vector.is_some() { |     if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { | ||||||
|         features.check_vector()?; |         search.embedder_name(embedder); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // compute the offset on the limit depending on the pagination mode. |     // compute the offset on the limit depending on the pagination mode. | ||||||
| @@ -421,15 +492,22 @@ pub fn perform_search( | |||||||
|     index: &Index, |     index: &Index, | ||||||
|     query: SearchQuery, |     query: SearchQuery, | ||||||
|     features: RoFeatures, |     features: RoFeatures, | ||||||
|  |     distribution: Option<DistributionShift>, | ||||||
| ) -> Result<SearchResult, MeilisearchHttpError> { | ) -> Result<SearchResult, MeilisearchHttpError> { | ||||||
|     let before_search = Instant::now(); |     let before_search = Instant::now(); | ||||||
|     let rtxn = index.read_txn()?; |     let rtxn = index.read_txn()?; | ||||||
|  |  | ||||||
|     let (search, is_finite_pagination, max_total_hits, offset) = |     let (search, is_finite_pagination, max_total_hits, offset) = | ||||||
|         prepare_search(index, &rtxn, &query, features)?; |         prepare_search(index, &rtxn, &query, features, distribution)?; | ||||||
|  |  | ||||||
|     let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = |     let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = | ||||||
|         search.execute()?; |         match &query.hybrid { | ||||||
|  |             Some(hybrid) => match *hybrid.semantic_ratio { | ||||||
|  |                 ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?, | ||||||
|  |                 ratio => search.execute_hybrid(ratio)?, | ||||||
|  |             }, | ||||||
|  |             None => search.execute()?, | ||||||
|  |         }; | ||||||
|  |  | ||||||
|     let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); |     let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); | ||||||
|  |  | ||||||
| @@ -538,13 +616,17 @@ pub fn perform_search( | |||||||
|             insert_geo_distance(sort, &mut document); |             insert_geo_distance(sort, &mut document); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         let semantic_score = match query.vector.as_ref() { |         let mut semantic_score = None; | ||||||
|             Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { |         for details in &score { | ||||||
|                 Some(vectors) => compute_semantic_score(vector, vectors)?, |             if let ScoreDetails::Vector(score_details::Vector { | ||||||
|                 None => None, |                 target_vector: _, | ||||||
|             }, |                 value_similarity: Some((_matching_vector, similarity)), | ||||||
|             None => None, |             }) = details | ||||||
|         }; |             { | ||||||
|  |                 semantic_score = Some(*similarity); | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|         let ranking_score = |         let ranking_score = | ||||||
|             query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); |             query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); | ||||||
| @@ -647,8 +729,9 @@ pub fn perform_facet_search( | |||||||
|     let before_search = Instant::now(); |     let before_search = Instant::now(); | ||||||
|     let rtxn = index.read_txn()?; |     let rtxn = index.read_txn()?; | ||||||
|  |  | ||||||
|     let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features)?; |     let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?; | ||||||
|     let mut facet_search = SearchForFacetValues::new(facet_name, search); |     let mut facet_search = | ||||||
|  |         SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some()); | ||||||
|     if let Some(facet_query) = &facet_query { |     if let Some(facet_query) = &facet_query { | ||||||
|         facet_search.query(facet_query); |         facet_search.query(facet_query); | ||||||
|     } |     } | ||||||
| @@ -676,18 +759,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result<Option<f32>> { |  | ||||||
|     let vectors = serde_json::from_value(vectors) |  | ||||||
|         .map(VectorOrArrayOfVectors::into_array_of_vectors) |  | ||||||
|         .map_err(InternalError::SerdeJson)?; |  | ||||||
|     Ok(vectors |  | ||||||
|         .into_iter() |  | ||||||
|         .flatten() |  | ||||||
|         .map(|v| OrderedFloat(dot_product_similarity(query, &v))) |  | ||||||
|         .max() |  | ||||||
|         .map(OrderedFloat::into_inner)) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| fn compute_formatted_options( | fn compute_formatted_options( | ||||||
|     attr_to_highlight: &HashSet<String>, |     attr_to_highlight: &HashSet<String>, | ||||||
|     attr_to_crop: &[String], |     attr_to_crop: &[String], | ||||||
| @@ -815,22 +886,6 @@ fn make_document( | |||||||
|     Ok(document) |     Ok(document) | ||||||
| } | } | ||||||
|  |  | ||||||
| /// Extract the JSON value under the field name specified |  | ||||||
| /// but doesn't support nested objects. |  | ||||||
| fn extract_field( |  | ||||||
|     field_name: &str, |  | ||||||
|     field_ids_map: &FieldsIdsMap, |  | ||||||
|     obkv: obkv::KvReaderU16, |  | ||||||
| ) -> Result<Option<serde_json::Value>, MeilisearchHttpError> { |  | ||||||
|     match field_ids_map.id(field_name) { |  | ||||||
|         Some(fid) => match obkv.get(fid) { |  | ||||||
|             Some(value) => Ok(serde_json::from_slice(value).map(Some)?), |  | ||||||
|             None => Ok(None), |  | ||||||
|         }, |  | ||||||
|         None => Ok(None), |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| fn format_fields<'a>( | fn format_fields<'a>( | ||||||
|     document: &Document, |     document: &Document, | ||||||
|     field_ids_map: &FieldsIdsMap, |     field_ids_map: &FieldsIdsMap, | ||||||
|   | |||||||
| @@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -807,7 +812,8 @@ async fn import_dump_v2_rubygems_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "### |     "### | ||||||
|     ); |     ); | ||||||
| @@ -1895,7 +1907,8 @@ async fn import_dump_v6_containing_experimental_features() { | |||||||
|       }, |       }, | ||||||
|       "pagination": { |       "pagination": { | ||||||
|         "maxTotalHits": 1000 |         "maxTotalHits": 1000 | ||||||
|       } |       }, | ||||||
|  |       "embedders": {} | ||||||
|     } |     } | ||||||
|     "###); |     "###); | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										152
									
								
								meilisearch/tests/search/hybrid.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								meilisearch/tests/search/hybrid.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,152 @@ | |||||||
|  | use meili_snap::snapshot; | ||||||
|  | use once_cell::sync::Lazy; | ||||||
|  |  | ||||||
|  | use crate::common::index::Index; | ||||||
|  | use crate::common::{Server, Value}; | ||||||
|  | use crate::json; | ||||||
|  |  | ||||||
|  | async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Index<'a> { | ||||||
|  |     let index = server.index("test"); | ||||||
|  |  | ||||||
|  |     let (response, code) = server.set_features(json!({"vectorStore": true})).await; | ||||||
|  |  | ||||||
|  |     meili_snap::snapshot!(code, @"200 OK"); | ||||||
|  |     meili_snap::snapshot!(meili_snap::json_string!(response), @r###" | ||||||
|  |     { | ||||||
|  |       "scoreDetails": false, | ||||||
|  |       "vectorStore": true, | ||||||
|  |       "metrics": false, | ||||||
|  |       "exportPuffinReports": false, | ||||||
|  |       "proximityPrecision": false | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .update_settings( | ||||||
|  |             json!({ "embedders": {"default": {"source": {"userProvided": {"dimensions": 2}}}} }), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     assert_eq!(202, code, "{:?}", response); | ||||||
|  |     index.wait_task(response.uid()).await; | ||||||
|  |  | ||||||
|  |     let (response, code) = index.add_documents(documents.clone(), None).await; | ||||||
|  |     assert_eq!(202, code, "{:?}", response); | ||||||
|  |     index.wait_task(response.uid()).await; | ||||||
|  |     index | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static SIMPLE_SEARCH_DOCUMENTS: Lazy<Value> = Lazy::new(|| { | ||||||
|  |     json!([ | ||||||
|  |     { | ||||||
|  |         "title": "Shazam!", | ||||||
|  |         "desc": "a Captain Marvel ersatz", | ||||||
|  |         "id": "1", | ||||||
|  |         "_vectors": {"default": [1.0, 3.0]}, | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "title": "Captain Planet", | ||||||
|  |         "desc": "He's not part of the Marvel Cinematic Universe", | ||||||
|  |         "id": "2", | ||||||
|  |         "_vectors": {"default": [1.0, 2.0]}, | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "title": "Captain Marvel", | ||||||
|  |         "desc": "a Shazam ersatz", | ||||||
|  |         "id": "3", | ||||||
|  |         "_vectors": {"default": [2.0, 3.0]}, | ||||||
|  |     }]) | ||||||
|  | }); | ||||||
|  |  | ||||||
|  | #[actix_rt::test] | ||||||
|  | async fn simple_search() { | ||||||
|  |     let server = Server::new().await; | ||||||
|  |     let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await; | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post( | ||||||
|  |             json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}}), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###); | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post( | ||||||
|  |             json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}}), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     snapshot!(code, @"200 OK"); | ||||||
|  |     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[actix_rt::test] | ||||||
|  | async fn invalid_semantic_ratio() { | ||||||
|  |     let server = Server::new().await; | ||||||
|  |     let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await; | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post( | ||||||
|  |             json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 1.2}}), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     snapshot!(code, @"400 Bad Request"); | ||||||
|  |     snapshot!(response, @r###" | ||||||
|  |     { | ||||||
|  |       "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", | ||||||
|  |       "code": "invalid_search_semantic_ratio", | ||||||
|  |       "type": "invalid_request", | ||||||
|  |       "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_post( | ||||||
|  |             json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": -0.8}}), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     snapshot!(code, @"400 Bad Request"); | ||||||
|  |     snapshot!(response, @r###" | ||||||
|  |     { | ||||||
|  |       "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", | ||||||
|  |       "code": "invalid_search_semantic_ratio", | ||||||
|  |       "type": "invalid_request", | ||||||
|  |       "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_get( | ||||||
|  |             &yaup::to_string( | ||||||
|  |                 &json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": 1.2}), | ||||||
|  |             ) | ||||||
|  |             .unwrap(), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     snapshot!(code, @"400 Bad Request"); | ||||||
|  |     snapshot!(response, @r###" | ||||||
|  |     { | ||||||
|  |       "message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", | ||||||
|  |       "code": "invalid_search_semantic_ratio", | ||||||
|  |       "type": "invalid_request", | ||||||
|  |       "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .search_get( | ||||||
|  |             &yaup::to_string( | ||||||
|  |                 &json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": -0.2}), | ||||||
|  |             ) | ||||||
|  |             .unwrap(), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |     snapshot!(code, @"400 Bad Request"); | ||||||
|  |     snapshot!(response, @r###" | ||||||
|  |     { | ||||||
|  |       "message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", | ||||||
|  |       "code": "invalid_search_semantic_ratio", | ||||||
|  |       "type": "invalid_request", | ||||||
|  |       "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" | ||||||
|  |     } | ||||||
|  |     "###); | ||||||
|  | } | ||||||
| @@ -6,6 +6,7 @@ mod errors; | |||||||
| mod facet_search; | mod facet_search; | ||||||
| mod formatted; | mod formatted; | ||||||
| mod geo; | mod geo; | ||||||
|  | mod hybrid; | ||||||
| mod multi; | mod multi; | ||||||
| mod pagination; | mod pagination; | ||||||
| mod restrict_searchable; | mod restrict_searchable; | ||||||
| @@ -20,22 +21,27 @@ static DOCUMENTS: Lazy<Value> = Lazy::new(|| { | |||||||
|         { |         { | ||||||
|             "title": "Shazam!", |             "title": "Shazam!", | ||||||
|             "id": "287947", |             "id": "287947", | ||||||
|  |             "_vectors": { "manual": [1, 2, 3]}, | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "title": "Captain Marvel", |             "title": "Captain Marvel", | ||||||
|             "id": "299537", |             "id": "299537", | ||||||
|  |             "_vectors": { "manual": [1, 2, 54] }, | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "title": "Escape Room", |             "title": "Escape Room", | ||||||
|             "id": "522681", |             "id": "522681", | ||||||
|  |             "_vectors": { "manual": [10, -23, 32] }, | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "title": "How to Train Your Dragon: The Hidden World", |             "title": "How to Train Your Dragon: The Hidden World", | ||||||
|             "id": "166428", |             "id": "166428", | ||||||
|  |             "_vectors": { "manual": [-100, 231, 32] }, | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "title": "Gläss", |             "title": "Gläss", | ||||||
|             "id": "450465", |             "id": "450465", | ||||||
|  |             "_vectors": { "manual": [-100, 340, 90] }, | ||||||
|         } |         } | ||||||
|     ]) |     ]) | ||||||
| }); | }); | ||||||
| @@ -57,6 +63,7 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| { | |||||||
|                 }, |                 }, | ||||||
|             ], |             ], | ||||||
|             "cattos": "pésti", |             "cattos": "pésti", | ||||||
|  |             "_vectors": { "manual": [1, 2, 3]}, | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "id": 654, |             "id": 654, | ||||||
| @@ -69,12 +76,14 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| { | |||||||
|                 }, |                 }, | ||||||
|             ], |             ], | ||||||
|             "cattos": ["simba", "pestiféré"], |             "cattos": ["simba", "pestiféré"], | ||||||
|  |             "_vectors": { "manual": [1, 2, 54] }, | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "id": 750, |             "id": 750, | ||||||
|             "father": "romain", |             "father": "romain", | ||||||
|             "mother": "michelle", |             "mother": "michelle", | ||||||
|             "cattos": ["enigma"], |             "cattos": ["enigma"], | ||||||
|  |             "_vectors": { "manual": [10, 23, 32] }, | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "id": 951, |             "id": 951, | ||||||
| @@ -91,6 +100,7 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| { | |||||||
|                 }, |                 }, | ||||||
|             ], |             ], | ||||||
|             "cattos": ["moumoute", "gomez"], |             "cattos": ["moumoute", "gomez"], | ||||||
|  |             "_vectors": { "manual": [10, 23, 32] }, | ||||||
|         }, |         }, | ||||||
|     ]) |     ]) | ||||||
| }); | }); | ||||||
| @@ -802,6 +812,13 @@ async fn experimental_feature_score_details() { | |||||||
|                   { |                   { | ||||||
|                     "title": "How to Train Your Dragon: The Hidden World", |                     "title": "How to Train Your Dragon: The Hidden World", | ||||||
|                     "id": "166428", |                     "id": "166428", | ||||||
|  |                     "_vectors": { | ||||||
|  |                       "manual": [ | ||||||
|  |                         -100, | ||||||
|  |                         231, | ||||||
|  |                         32 | ||||||
|  |                       ] | ||||||
|  |                     }, | ||||||
|                     "_rankingScoreDetails": { |                     "_rankingScoreDetails": { | ||||||
|                       "words": { |                       "words": { | ||||||
|                         "order": 0, |                         "order": 0, | ||||||
| @@ -823,7 +840,7 @@ async fn experimental_feature_score_details() { | |||||||
|                         "order": 3, |                         "order": 3, | ||||||
|                         "attributeRankingOrderScore": 1.0, |                         "attributeRankingOrderScore": 1.0, | ||||||
|                         "queryWordDistanceScore": 0.8095238095238095, |                         "queryWordDistanceScore": 0.8095238095238095, | ||||||
|                         "score": 0.9365079365079364 |                         "score": 0.9727891156462584 | ||||||
|                       }, |                       }, | ||||||
|                       "exactness": { |                       "exactness": { | ||||||
|                         "order": 4, |                         "order": 4, | ||||||
| @@ -870,13 +887,92 @@ async fn experimental_feature_vector_store() { | |||||||
|     meili_snap::snapshot!(code, @"200 OK"); |     meili_snap::snapshot!(code, @"200 OK"); | ||||||
|     meili_snap::snapshot!(response["vectorStore"], @"true"); |     meili_snap::snapshot!(response["vectorStore"], @"true"); | ||||||
|  |  | ||||||
|  |     let (response, code) = index | ||||||
|  |         .update_settings(json!({"embedders": { | ||||||
|  |             "manual": { | ||||||
|  |                 "source": { | ||||||
|  |                     "userProvided": {"dimensions": 3} | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }})) | ||||||
|  |         .await; | ||||||
|  |  | ||||||
|  |     meili_snap::snapshot!(code, @"202 Accepted"); | ||||||
|  |     let response = index.wait_task(response.uid()).await; | ||||||
|  |  | ||||||
|  |     meili_snap::snapshot!(meili_snap::json_string!(response["status"]), @"\"succeeded\""); | ||||||
|  |  | ||||||
|     let (response, code) = index |     let (response, code) = index | ||||||
|         .search_post(json!({ |         .search_post(json!({ | ||||||
|             "vector": [1.0, 2.0, 3.0], |             "vector": [1.0, 2.0, 3.0], | ||||||
|         })) |         })) | ||||||
|         .await; |         .await; | ||||||
|  |  | ||||||
|     meili_snap::snapshot!(code, @"200 OK"); |     meili_snap::snapshot!(code, @"200 OK"); | ||||||
|     meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]"); |     // vector search returns all documents that don't have vectors in the last bucket, like all sorts | ||||||
|  |     meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" | ||||||
|  |     [ | ||||||
|  |       { | ||||||
|  |         "title": "Shazam!", | ||||||
|  |         "id": "287947", | ||||||
|  |         "_vectors": { | ||||||
|  |           "manual": [ | ||||||
|  |             1, | ||||||
|  |             2, | ||||||
|  |             3 | ||||||
|  |           ] | ||||||
|  |         }, | ||||||
|  |         "_semanticScore": 1.0 | ||||||
|  |       }, | ||||||
|  |       { | ||||||
|  |         "title": "Captain Marvel", | ||||||
|  |         "id": "299537", | ||||||
|  |         "_vectors": { | ||||||
|  |           "manual": [ | ||||||
|  |             1, | ||||||
|  |             2, | ||||||
|  |             54 | ||||||
|  |           ] | ||||||
|  |         }, | ||||||
|  |         "_semanticScore": 0.9129112 | ||||||
|  |       }, | ||||||
|  |       { | ||||||
|  |         "title": "Gläss", | ||||||
|  |         "id": "450465", | ||||||
|  |         "_vectors": { | ||||||
|  |           "manual": [ | ||||||
|  |             -100, | ||||||
|  |             340, | ||||||
|  |             90 | ||||||
|  |           ] | ||||||
|  |         }, | ||||||
|  |         "_semanticScore": 0.8106413 | ||||||
|  |       }, | ||||||
|  |       { | ||||||
|  |         "title": "How to Train Your Dragon: The Hidden World", | ||||||
|  |         "id": "166428", | ||||||
|  |         "_vectors": { | ||||||
|  |           "manual": [ | ||||||
|  |             -100, | ||||||
|  |             231, | ||||||
|  |             32 | ||||||
|  |           ] | ||||||
|  |         }, | ||||||
|  |         "_semanticScore": 0.74120104 | ||||||
|  |       }, | ||||||
|  |       { | ||||||
|  |         "title": "Escape Room", | ||||||
|  |         "id": "522681", | ||||||
|  |         "_vectors": { | ||||||
|  |           "manual": [ | ||||||
|  |             10, | ||||||
|  |             -23, | ||||||
|  |             32 | ||||||
|  |           ] | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     ] | ||||||
|  |     "###); | ||||||
| } | } | ||||||
|  |  | ||||||
| #[cfg(feature = "default")] | #[cfg(feature = "default")] | ||||||
| @@ -1126,7 +1222,14 @@ async fn simple_search_with_strange_synonyms() { | |||||||
|             [ |             [ | ||||||
|               { |               { | ||||||
|                 "title": "How to Train Your Dragon: The Hidden World", |                 "title": "How to Train Your Dragon: The Hidden World", | ||||||
|                 "id": "166428" |                 "id": "166428", | ||||||
|  |                 "_vectors": { | ||||||
|  |                   "manual": [ | ||||||
|  |                     -100, | ||||||
|  |                     231, | ||||||
|  |                     32 | ||||||
|  |                   ] | ||||||
|  |                 } | ||||||
|               } |               } | ||||||
|             ] |             ] | ||||||
|             "###); |             "###); | ||||||
| @@ -1140,7 +1243,14 @@ async fn simple_search_with_strange_synonyms() { | |||||||
|             [ |             [ | ||||||
|               { |               { | ||||||
|                 "title": "How to Train Your Dragon: The Hidden World", |                 "title": "How to Train Your Dragon: The Hidden World", | ||||||
|                 "id": "166428" |                 "id": "166428", | ||||||
|  |                 "_vectors": { | ||||||
|  |                   "manual": [ | ||||||
|  |                     -100, | ||||||
|  |                     231, | ||||||
|  |                     32 | ||||||
|  |                   ] | ||||||
|  |                 } | ||||||
|               } |               } | ||||||
|             ] |             ] | ||||||
|             "###); |             "###); | ||||||
| @@ -1154,7 +1264,14 @@ async fn simple_search_with_strange_synonyms() { | |||||||
|             [ |             [ | ||||||
|               { |               { | ||||||
|                 "title": "How to Train Your Dragon: The Hidden World", |                 "title": "How to Train Your Dragon: The Hidden World", | ||||||
|                 "id": "166428" |                 "id": "166428", | ||||||
|  |                 "_vectors": { | ||||||
|  |                   "manual": [ | ||||||
|  |                     -100, | ||||||
|  |                     231, | ||||||
|  |                     32 | ||||||
|  |                   ] | ||||||
|  |                 } | ||||||
|               } |               } | ||||||
|             ] |             ] | ||||||
|             "###); |             "###); | ||||||
|   | |||||||
| @@ -72,7 +72,14 @@ async fn simple_search_single_index() { | |||||||
|         "hits": [ |         "hits": [ | ||||||
|           { |           { | ||||||
|             "title": "Gläss", |             "title": "Gläss", | ||||||
|             "id": "450465" |             "id": "450465", | ||||||
|  |             "_vectors": { | ||||||
|  |               "manual": [ | ||||||
|  |                 -100, | ||||||
|  |                 340, | ||||||
|  |                 90 | ||||||
|  |               ] | ||||||
|  |             } | ||||||
|           } |           } | ||||||
|         ], |         ], | ||||||
|         "query": "glass", |         "query": "glass", | ||||||
| @@ -86,7 +93,14 @@ async fn simple_search_single_index() { | |||||||
|         "hits": [ |         "hits": [ | ||||||
|           { |           { | ||||||
|             "title": "Captain Marvel", |             "title": "Captain Marvel", | ||||||
|             "id": "299537" |             "id": "299537", | ||||||
|  |             "_vectors": { | ||||||
|  |               "manual": [ | ||||||
|  |                 1, | ||||||
|  |                 2, | ||||||
|  |                 54 | ||||||
|  |               ] | ||||||
|  |             } | ||||||
|           } |           } | ||||||
|         ], |         ], | ||||||
|         "query": "captain", |         "query": "captain", | ||||||
| @@ -177,7 +191,14 @@ async fn simple_search_two_indexes() { | |||||||
|         "hits": [ |         "hits": [ | ||||||
|           { |           { | ||||||
|             "title": "Gläss", |             "title": "Gläss", | ||||||
|             "id": "450465" |             "id": "450465", | ||||||
|  |             "_vectors": { | ||||||
|  |               "manual": [ | ||||||
|  |                 -100, | ||||||
|  |                 340, | ||||||
|  |                 90 | ||||||
|  |               ] | ||||||
|  |             } | ||||||
|           } |           } | ||||||
|         ], |         ], | ||||||
|         "query": "glass", |         "query": "glass", | ||||||
| @@ -203,7 +224,14 @@ async fn simple_search_two_indexes() { | |||||||
|                 "age": 4 |                 "age": 4 | ||||||
|               } |               } | ||||||
|             ], |             ], | ||||||
|             "cattos": "pésti" |             "cattos": "pésti", | ||||||
|  |             "_vectors": { | ||||||
|  |               "manual": [ | ||||||
|  |                 1, | ||||||
|  |                 2, | ||||||
|  |                 3 | ||||||
|  |               ] | ||||||
|  |             } | ||||||
|           }, |           }, | ||||||
|           { |           { | ||||||
|             "id": 654, |             "id": 654, | ||||||
| @@ -218,8 +246,15 @@ async fn simple_search_two_indexes() { | |||||||
|             "cattos": [ |             "cattos": [ | ||||||
|               "simba", |               "simba", | ||||||
|               "pestiféré" |               "pestiféré" | ||||||
|  |             ], | ||||||
|  |             "_vectors": { | ||||||
|  |               "manual": [ | ||||||
|  |                 1, | ||||||
|  |                 2, | ||||||
|  |                 54 | ||||||
|               ] |               ] | ||||||
|             } |             } | ||||||
|  |           } | ||||||
|         ], |         ], | ||||||
|         "query": "pésti", |         "query": "pésti", | ||||||
|         "processingTimeMs": "[time]", |         "processingTimeMs": "[time]", | ||||||
|   | |||||||
| @@ -54,7 +54,7 @@ async fn get_settings() { | |||||||
|     let (response, code) = index.settings().await; |     let (response, code) = index.settings().await; | ||||||
|     assert_eq!(code, 200); |     assert_eq!(code, 200); | ||||||
|     let settings = response.as_object().unwrap(); |     let settings = response.as_object().unwrap(); | ||||||
|     assert_eq!(settings.keys().len(), 15); |     assert_eq!(settings.keys().len(), 16); | ||||||
|     assert_eq!(settings["displayedAttributes"], json!(["*"])); |     assert_eq!(settings["displayedAttributes"], json!(["*"])); | ||||||
|     assert_eq!(settings["searchableAttributes"], json!(["*"])); |     assert_eq!(settings["searchableAttributes"], json!(["*"])); | ||||||
|     assert_eq!(settings["filterableAttributes"], json!([])); |     assert_eq!(settings["filterableAttributes"], json!([])); | ||||||
| @@ -83,6 +83,7 @@ async fn get_settings() { | |||||||
|             "maxTotalHits": 1000, |             "maxTotalHits": 1000, | ||||||
|         }) |         }) | ||||||
|     ); |     ); | ||||||
|  |     assert_eq!(settings["embedders"], json!({})); | ||||||
| } | } | ||||||
|  |  | ||||||
| #[actix_rt::test] | #[actix_rt::test] | ||||||
|   | |||||||
| @@ -27,13 +27,15 @@ fst = "0.4.7" | |||||||
| fxhash = "0.2.1" | fxhash = "0.2.1" | ||||||
| geoutils = "0.5.1" | geoutils = "0.5.1" | ||||||
| grenad = { version = "0.4.5", default-features = false, features = [ | grenad = { version = "0.4.5", default-features = false, features = [ | ||||||
|     "rayon", "tempfile" |     "rayon", | ||||||
|  |     "tempfile", | ||||||
| ] } | ] } | ||||||
| heed = { version = "0.20.0-alpha.9", default-features = false, features = [ | heed = { version = "0.20.0-alpha.9", default-features = false, features = [ | ||||||
|     "serde-json", "serde-bincode", "read-txn-no-tls" |     "serde-json", | ||||||
|  |     "serde-bincode", | ||||||
|  |     "read-txn-no-tls", | ||||||
| ] } | ] } | ||||||
| indexmap = { version = "2.0.0", features = ["serde"] } | indexmap = { version = "2.0.0", features = ["serde"] } | ||||||
| instant-distance = { version = "0.6.1", features = ["with-serde"] } |  | ||||||
| json-depth-checker = { path = "../json-depth-checker" } | json-depth-checker = { path = "../json-depth-checker" } | ||||||
| levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } | levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } | ||||||
| memmap2 = "0.7.1" | memmap2 = "0.7.1" | ||||||
| @@ -72,6 +74,23 @@ puffin = "0.16.0" | |||||||
| log = "0.4.17" | log = "0.4.17" | ||||||
| logging_timer = "1.1.0" | logging_timer = "1.1.0" | ||||||
| csv = "1.2.1" | csv = "1.2.1" | ||||||
|  | candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } | ||||||
|  | candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } | ||||||
|  | candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } | ||||||
|  | tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } | ||||||
|  | hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ | ||||||
|  |     "online", | ||||||
|  | ] } | ||||||
|  | tokio = { version = "1.34.0", features = ["rt"] } | ||||||
|  | futures = "0.3.29" | ||||||
|  | reqwest = { version = "0.11.16", features = [ | ||||||
|  |     "rustls-tls", | ||||||
|  |     "json", | ||||||
|  | ], default-features = false } | ||||||
|  | tiktoken-rs = "0.5.7" | ||||||
|  | liquid = "0.26.4" | ||||||
|  | arroy = { git = "https://github.com/meilisearch/arroy.git", version = "0.1.0" } | ||||||
|  | rand = "0.8.5" | ||||||
|  |  | ||||||
| [dev-dependencies] | [dev-dependencies] | ||||||
| mimalloc = { version = "0.1.37", default-features = false } | mimalloc = { version = "0.1.37", default-features = false } | ||||||
| @@ -83,7 +102,15 @@ meili-snap = { path = "../meili-snap" } | |||||||
| rand = { version = "0.8.5", features = ["small_rng"] } | rand = { version = "0.8.5", features = ["small_rng"] } | ||||||
|  |  | ||||||
| [features] | [features] | ||||||
| all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] | all-tokenizations = [ | ||||||
|  |     "charabia/chinese", | ||||||
|  |     "charabia/hebrew", | ||||||
|  |     "charabia/japanese", | ||||||
|  |     "charabia/thai", | ||||||
|  |     "charabia/korean", | ||||||
|  |     "charabia/greek", | ||||||
|  |     "charabia/khmer", | ||||||
|  | ] | ||||||
|  |  | ||||||
| # Use POSIX semaphores instead of SysV semaphores in LMDB | # Use POSIX semaphores instead of SysV semaphores in LMDB | ||||||
| # For more information on this feature, see heed's Cargo.toml | # For more information on this feature, see heed's Cargo.toml | ||||||
|   | |||||||
| @@ -5,8 +5,8 @@ use std::time::Instant; | |||||||
|  |  | ||||||
| use heed::EnvOpenOptions; | use heed::EnvOpenOptions; | ||||||
| use milli::{ | use milli::{ | ||||||
|     execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger, |     execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, | ||||||
|     TermsMatchingStrategy, |     SearchLogger, TermsMatchingStrategy, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| #[global_allocator] | #[global_allocator] | ||||||
| @@ -49,14 +49,15 @@ fn main() -> Result<(), Box<dyn Error>> { | |||||||
|             let start = Instant::now(); |             let start = Instant::now(); | ||||||
|  |  | ||||||
|             let mut ctx = SearchContext::new(&index, &txn); |             let mut ctx = SearchContext::new(&index, &txn); | ||||||
|  |             let universe = filtered_universe(&ctx, &None)?; | ||||||
|  |  | ||||||
|             let docs = execute_search( |             let docs = execute_search( | ||||||
|                 &mut ctx, |                 &mut ctx, | ||||||
|                 &(!query.trim().is_empty()).then(|| query.trim().to_owned()), |                 (!query.trim().is_empty()).then(|| query.trim()), | ||||||
|                 &None, |  | ||||||
|                 TermsMatchingStrategy::Last, |                 TermsMatchingStrategy::Last, | ||||||
|                 milli::score_details::ScoringStrategy::Skip, |                 milli::score_details::ScoringStrategy::Skip, | ||||||
|                 false, |                 false, | ||||||
|                 &None, |                 universe, | ||||||
|                 &None, |                 &None, | ||||||
|                 GeoSortStrategy::default(), |                 GeoSortStrategy::default(), | ||||||
|                 0, |                 0, | ||||||
|   | |||||||
| @@ -1,41 +0,0 @@ | |||||||
| use std::ops; |  | ||||||
|  |  | ||||||
| use instant_distance::Point; |  | ||||||
| use serde::{Deserialize, Serialize}; |  | ||||||
|  |  | ||||||
| use crate::normalize_vector; |  | ||||||
|  |  | ||||||
| #[derive(Debug, Default, Clone, Serialize, Deserialize)] |  | ||||||
| pub struct NDotProductPoint(Vec<f32>); |  | ||||||
|  |  | ||||||
| impl NDotProductPoint { |  | ||||||
|     pub fn new(point: Vec<f32>) -> Self { |  | ||||||
|         NDotProductPoint(normalize_vector(point)) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn into_inner(self) -> Vec<f32> { |  | ||||||
|         self.0 |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl ops::Deref for NDotProductPoint { |  | ||||||
|     type Target = [f32]; |  | ||||||
|  |  | ||||||
|     fn deref(&self) -> &Self::Target { |  | ||||||
|         self.0.as_slice() |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Point for NDotProductPoint { |  | ||||||
|     fn distance(&self, other: &Self) -> f32 { |  | ||||||
|         let dist = 1.0 - dot_product_similarity(&self.0, &other.0); |  | ||||||
|         debug_assert!(!dist.is_nan()); |  | ||||||
|         dist |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| /// Returns the dot product similarity score that will between 0.0 and 1.0 |  | ||||||
| /// if both vectors are normalized. The higher the more similar the vectors are. |  | ||||||
| pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 { |  | ||||||
|     a.iter().zip(b).map(|(a, b)| a * b).sum() |  | ||||||
| } |  | ||||||
| @@ -61,6 +61,10 @@ pub enum InternalError { | |||||||
|     AbortedIndexation, |     AbortedIndexation, | ||||||
|     #[error("The matching words list contains at least one invalid member.")] |     #[error("The matching words list contains at least one invalid member.")] | ||||||
|     InvalidMatchingWords, |     InvalidMatchingWords, | ||||||
|  |     #[error(transparent)] | ||||||
|  |     ArroyError(#[from] arroy::Error), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     VectorEmbeddingError(#[from] crate::vector::Error), | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Error, Debug)] | #[derive(Error, Debug)] | ||||||
| @@ -110,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco | |||||||
|     InvalidGeoField(#[from] GeoError), |     InvalidGeoField(#[from] GeoError), | ||||||
|     #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] |     #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] | ||||||
|     InvalidVectorDimensions { expected: usize, found: usize }, |     InvalidVectorDimensions { expected: usize, found: usize }, | ||||||
|     #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] |     #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] | ||||||
|     InvalidVectorsType { document_id: Value, value: Value }, |     InvalidVectorsType { document_id: Value, value: Value, subfield: String }, | ||||||
|  |     #[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")] | ||||||
|  |     InvalidVectorsMapType { document_id: Value, value: Value }, | ||||||
|     #[error("{0}")] |     #[error("{0}")] | ||||||
|     InvalidFilter(String), |     InvalidFilter(String), | ||||||
|     #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] |     #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] | ||||||
| @@ -180,6 +186,49 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco | |||||||
|     UnknownInternalDocumentId { document_id: DocumentId }, |     UnknownInternalDocumentId { document_id: DocumentId }, | ||||||
|     #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] |     #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] | ||||||
|     InvalidMinTypoWordLenSetting(u8, u8), |     InvalidMinTypoWordLenSetting(u8, u8), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     VectorEmbeddingError(#[from] crate::vector::Error), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     InvalidPrompt(#[from] crate::prompt::error::NewPromptError), | ||||||
|  |     #[error("Invalid prompt in for embeddings with name '{0}': {1}.")] | ||||||
|  |     InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), | ||||||
|  |     #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] | ||||||
|  |     TooManyEmbedders(usize), | ||||||
|  |     #[error("Cannot find embedder with name {0}.")] | ||||||
|  |     InvalidEmbedder(String), | ||||||
|  |     #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] | ||||||
|  |     TooManyVectors(String, usize), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<crate::vector::Error> for Error { | ||||||
|  |     fn from(value: crate::vector::Error) -> Self { | ||||||
|  |         match value.fault() { | ||||||
|  |             FaultSource::User => Error::UserError(value.into()), | ||||||
|  |             FaultSource::Runtime => Error::InternalError(value.into()), | ||||||
|  |             FaultSource::Bug => Error::InternalError(value.into()), | ||||||
|  |             FaultSource::Undecided => Error::InternalError(value.into()), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<arroy::Error> for Error { | ||||||
|  |     fn from(value: arroy::Error) -> Self { | ||||||
|  |         match value { | ||||||
|  |             arroy::Error::Heed(heed) => heed.into(), | ||||||
|  |             arroy::Error::Io(io) => io.into(), | ||||||
|  |             arroy::Error::InvalidVecDimension { expected, received } => { | ||||||
|  |                 Error::UserError(UserError::InvalidVectorDimensions { expected, found: received }) | ||||||
|  |             } | ||||||
|  |             arroy::Error::DatabaseFull | ||||||
|  |             | arroy::Error::InvalidItemAppend | ||||||
|  |             | arroy::Error::UnmatchingDistance { .. } | ||||||
|  |             | arroy::Error::MissingMetadata => { | ||||||
|  |                 Error::InternalError(InternalError::ArroyError(value)) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Error, Debug)] | #[derive(Error, Debug)] | ||||||
| @@ -336,6 +385,26 @@ impl From<HeedError> for Error { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Copy)] | ||||||
|  | pub enum FaultSource { | ||||||
|  |     User, | ||||||
|  |     Runtime, | ||||||
|  |     Bug, | ||||||
|  |     Undecided, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl std::fmt::Display for FaultSource { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         let s = match self { | ||||||
|  |             FaultSource::User => "user error", | ||||||
|  |             FaultSource::Runtime => "runtime error", | ||||||
|  |             FaultSource::Bug => "coding error", | ||||||
|  |             FaultSource::Undecided => "error", | ||||||
|  |         }; | ||||||
|  |         f.write_str(s) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| #[test] | #[test] | ||||||
| fn conditionally_lookup_for_error_message() { | fn conditionally_lookup_for_error_message() { | ||||||
|     let prefix = "Attribute `name` is not sortable."; |     let prefix = "Attribute `name` is not sortable."; | ||||||
|   | |||||||
| @@ -10,7 +10,6 @@ use roaring::RoaringBitmap; | |||||||
| use rstar::RTree; | use rstar::RTree; | ||||||
| use time::OffsetDateTime; | use time::OffsetDateTime; | ||||||
|  |  | ||||||
| use crate::distance::NDotProductPoint; |  | ||||||
| use crate::documents::PrimaryKey; | use crate::documents::PrimaryKey; | ||||||
| use crate::error::{InternalError, UserError}; | use crate::error::{InternalError, UserError}; | ||||||
| use crate::fields_ids_map::FieldsIdsMap; | use crate::fields_ids_map::FieldsIdsMap; | ||||||
| @@ -22,7 +21,7 @@ use crate::heed_codec::{ | |||||||
|     BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec, |     BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec, | ||||||
| }; | }; | ||||||
| use crate::proximity::ProximityPrecision; | use crate::proximity::ProximityPrecision; | ||||||
| use crate::readable_slices::ReadableSlices; | use crate::vector::EmbeddingConfig; | ||||||
| use crate::{ | use crate::{ | ||||||
|     default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, |     default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, | ||||||
|     FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, |     FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, | ||||||
| @@ -30,9 +29,6 @@ use crate::{ | |||||||
|     BEU32, BEU64, |     BEU32, BEU64, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| /// The HNSW data-structure that we serialize, fill and search in. |  | ||||||
| pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>; |  | ||||||
|  |  | ||||||
| pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; | pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; | ||||||
| pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; | pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; | ||||||
|  |  | ||||||
| @@ -48,10 +44,6 @@ pub mod main_key { | |||||||
|     pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; |     pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; | ||||||
|     pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; |     pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; | ||||||
|     pub const GEO_RTREE_KEY: &str = "geo-rtree"; |     pub const GEO_RTREE_KEY: &str = "geo-rtree"; | ||||||
|     /// The prefix of the key that is used to store the, potential big, HNSW structure. |  | ||||||
|     /// It is concatenated with a big-endian encoded number (non-human readable). |  | ||||||
|     /// e.g. vector-hnsw0x0032. |  | ||||||
|     pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw"; |  | ||||||
|     pub const PRIMARY_KEY_KEY: &str = "primary-key"; |     pub const PRIMARY_KEY_KEY: &str = "primary-key"; | ||||||
|     pub const SEARCHABLE_FIELDS_KEY: &str = "searchable-fields"; |     pub const SEARCHABLE_FIELDS_KEY: &str = "searchable-fields"; | ||||||
|     pub const USER_DEFINED_SEARCHABLE_FIELDS_KEY: &str = "user-defined-searchable-fields"; |     pub const USER_DEFINED_SEARCHABLE_FIELDS_KEY: &str = "user-defined-searchable-fields"; | ||||||
| @@ -74,6 +66,7 @@ pub mod main_key { | |||||||
|     pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; |     pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; | ||||||
|     pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; |     pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; | ||||||
|     pub const PROXIMITY_PRECISION: &str = "proximity-precision"; |     pub const PROXIMITY_PRECISION: &str = "proximity-precision"; | ||||||
|  |     pub const EMBEDDING_CONFIGS: &str = "embedding_configs"; | ||||||
| } | } | ||||||
|  |  | ||||||
| pub mod db_name { | pub mod db_name { | ||||||
| @@ -99,7 +92,8 @@ pub mod db_name { | |||||||
|     pub const FACET_ID_STRING_FST: &str = "facet-id-string-fst"; |     pub const FACET_ID_STRING_FST: &str = "facet-id-string-fst"; | ||||||
|     pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; |     pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; | ||||||
|     pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; |     pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; | ||||||
|     pub const VECTOR_ID_DOCID: &str = "vector-id-docids"; |     pub const VECTOR_EMBEDDER_CATEGORY_ID: &str = "vector-embedder-category-id"; | ||||||
|  |     pub const VECTOR_ARROY: &str = "vector-arroy"; | ||||||
|     pub const DOCUMENTS: &str = "documents"; |     pub const DOCUMENTS: &str = "documents"; | ||||||
|     pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids"; |     pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids"; | ||||||
| } | } | ||||||
| @@ -166,8 +160,10 @@ pub struct Index { | |||||||
|     /// Maps the document id, the facet field id and the strings. |     /// Maps the document id, the facet field id and the strings. | ||||||
|     pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>, |     pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>, | ||||||
|  |  | ||||||
|     /// Maps a vector id to the document id that have it. |     /// Maps an embedder name to its id in the arroy store. | ||||||
|     pub vector_id_docid: Database<BEU32, BEU32>, |     pub embedder_category_id: Database<Str, U8>, | ||||||
|  |     /// Vector store based on arroy™. | ||||||
|  |     pub vector_arroy: arroy::Database<arroy::distances::Angular>, | ||||||
|  |  | ||||||
|     /// Maps the document id to the document as an obkv store. |     /// Maps the document id to the document as an obkv store. | ||||||
|     pub(crate) documents: Database<BEU32, ObkvCodec>, |     pub(crate) documents: Database<BEU32, ObkvCodec>, | ||||||
| @@ -182,7 +178,7 @@ impl Index { | |||||||
|     ) -> Result<Index> { |     ) -> Result<Index> { | ||||||
|         use db_name::*; |         use db_name::*; | ||||||
|  |  | ||||||
|         options.max_dbs(24); |         options.max_dbs(25); | ||||||
|  |  | ||||||
|         let env = options.open(path)?; |         let env = options.open(path)?; | ||||||
|         let mut wtxn = env.write_txn()?; |         let mut wtxn = env.write_txn()?; | ||||||
| @@ -222,7 +218,11 @@ impl Index { | |||||||
|             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?; |             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?; | ||||||
|         let field_id_docid_facet_strings = |         let field_id_docid_facet_strings = | ||||||
|             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; |             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; | ||||||
|         let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?; |         // vector stuff | ||||||
|  |         let embedder_category_id = | ||||||
|  |             env.create_database(&mut wtxn, Some(VECTOR_EMBEDDER_CATEGORY_ID))?; | ||||||
|  |         let vector_arroy = env.create_database(&mut wtxn, Some(VECTOR_ARROY))?; | ||||||
|  |  | ||||||
|         let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?; |         let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?; | ||||||
|         wtxn.commit()?; |         wtxn.commit()?; | ||||||
|  |  | ||||||
| @@ -252,7 +252,8 @@ impl Index { | |||||||
|             facet_id_is_empty_docids, |             facet_id_is_empty_docids, | ||||||
|             field_id_docid_facet_f64s, |             field_id_docid_facet_f64s, | ||||||
|             field_id_docid_facet_strings, |             field_id_docid_facet_strings, | ||||||
|             vector_id_docid, |             vector_arroy, | ||||||
|  |             embedder_category_id, | ||||||
|             documents, |             documents, | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
| @@ -475,63 +476,6 @@ impl Index { | |||||||
|             None => Ok(RoaringBitmap::new()), |             None => Ok(RoaringBitmap::new()), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /* vector HNSW */ |  | ||||||
|  |  | ||||||
|     /// Writes the provided `hnsw`. |  | ||||||
|     pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> { |  | ||||||
|         // We must delete all the chunks before we write the new HNSW chunks. |  | ||||||
|         self.delete_vector_hnsw(wtxn)?; |  | ||||||
|  |  | ||||||
|         let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB |  | ||||||
|         let bytes = bincode::serialize(hnsw).map_err(Into::into).map_err(heed::Error::Encoding)?; |  | ||||||
|         for (i, chunk) in bytes.chunks(chunk_size).enumerate() { |  | ||||||
|             let i = i as u32; |  | ||||||
|             let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec(); |  | ||||||
|             key.extend_from_slice(&i.to_be_bytes()); |  | ||||||
|             self.main.remap_types::<Bytes, Bytes>().put(wtxn, &key, chunk)?; |  | ||||||
|         } |  | ||||||
|         Ok(()) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     /// Delete the `hnsw`. |  | ||||||
|     pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> { |  | ||||||
|         let mut iter = self |  | ||||||
|             .main |  | ||||||
|             .remap_types::<Bytes, DecodeIgnore>() |  | ||||||
|             .prefix_iter_mut(wtxn, main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes())?; |  | ||||||
|         let mut deleted = false; |  | ||||||
|         while iter.next().transpose()?.is_some() { |  | ||||||
|             // We do not keep a reference to the key or the value. |  | ||||||
|             unsafe { deleted |= iter.del_current()? }; |  | ||||||
|         } |  | ||||||
|         Ok(deleted) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     /// Returns the `hnsw`. |  | ||||||
|     pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> { |  | ||||||
|         let mut slices = Vec::new(); |  | ||||||
|         for result in self |  | ||||||
|             .main |  | ||||||
|             .remap_types::<Str, Bytes>() |  | ||||||
|             .prefix_iter(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)? |  | ||||||
|         { |  | ||||||
|             let (_, slice) = result?; |  | ||||||
|             slices.push(slice); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         if slices.is_empty() { |  | ||||||
|             Ok(None) |  | ||||||
|         } else { |  | ||||||
|             let readable_slices: ReadableSlices<_> = slices.into_iter().collect(); |  | ||||||
|             Ok(Some( |  | ||||||
|                 bincode::deserialize_from(readable_slices) |  | ||||||
|                     .map_err(Into::into) |  | ||||||
|                     .map_err(heed::Error::Decoding)?, |  | ||||||
|             )) |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     /* field distribution */ |     /* field distribution */ | ||||||
|  |  | ||||||
|     /// Writes the field distribution which associates every field name with |     /// Writes the field distribution which associates every field name with | ||||||
| @@ -1528,6 +1472,41 @@ impl Index { | |||||||
|  |  | ||||||
|         Ok(script_language) |         Ok(script_language) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn put_embedding_configs( | ||||||
|  |         &self, | ||||||
|  |         wtxn: &mut RwTxn<'_>, | ||||||
|  |         configs: Vec<(String, EmbeddingConfig)>, | ||||||
|  |     ) -> heed::Result<()> { | ||||||
|  |         self.main.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>().put( | ||||||
|  |             wtxn, | ||||||
|  |             main_key::EMBEDDING_CONFIGS, | ||||||
|  |             &configs, | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> { | ||||||
|  |         self.main.remap_key_type::<Str>().delete(wtxn, main_key::EMBEDDING_CONFIGS) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn embedding_configs( | ||||||
|  |         &self, | ||||||
|  |         rtxn: &RoTxn<'_>, | ||||||
|  |     ) -> Result<Vec<(String, crate::vector::EmbeddingConfig)>> { | ||||||
|  |         Ok(self | ||||||
|  |             .main | ||||||
|  |             .remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>() | ||||||
|  |             .get(rtxn, main_key::EMBEDDING_CONFIGS)? | ||||||
|  |             .unwrap_or_default()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result<String> { | ||||||
|  |         let configs = self.embedding_configs(rtxn)?; | ||||||
|  |         Ok(match configs.as_slice() { | ||||||
|  |             [(ref first_name, _)] => first_name.clone(), | ||||||
|  |             _ => "default".to_owned(), | ||||||
|  |         }) | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
|   | |||||||
| @@ -10,18 +10,18 @@ pub mod documents; | |||||||
|  |  | ||||||
| mod asc_desc; | mod asc_desc; | ||||||
| mod criterion; | mod criterion; | ||||||
| pub mod distance; |  | ||||||
| mod error; | mod error; | ||||||
| mod external_documents_ids; | mod external_documents_ids; | ||||||
| pub mod facet; | pub mod facet; | ||||||
| mod fields_ids_map; | mod fields_ids_map; | ||||||
| pub mod heed_codec; | pub mod heed_codec; | ||||||
| pub mod index; | pub mod index; | ||||||
|  | pub mod prompt; | ||||||
| pub mod proximity; | pub mod proximity; | ||||||
| mod readable_slices; |  | ||||||
| pub mod score_details; | pub mod score_details; | ||||||
| mod search; | mod search; | ||||||
| pub mod update; | pub mod update; | ||||||
|  | pub mod vector; | ||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| #[macro_use] | #[macro_use] | ||||||
| @@ -32,13 +32,12 @@ use std::convert::{TryFrom, TryInto}; | |||||||
| use std::hash::BuildHasherDefault; | use std::hash::BuildHasherDefault; | ||||||
|  |  | ||||||
| use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; | use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; | ||||||
| pub use distance::dot_product_similarity; |  | ||||||
| pub use filter_parser::{Condition, FilterCondition, Span, Token}; | pub use filter_parser::{Condition, FilterCondition, Span, Token}; | ||||||
| use fxhash::{FxHasher32, FxHasher64}; | use fxhash::{FxHasher32, FxHasher64}; | ||||||
| pub use grenad::CompressionType; | pub use grenad::CompressionType; | ||||||
| pub use search::new::{ | pub use search::new::{ | ||||||
|     execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger, |     execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext, | ||||||
|     VisualSearchLogger, |     SearchLogger, VisualSearchLogger, | ||||||
| }; | }; | ||||||
| use serde_json::Value; | use serde_json::Value; | ||||||
| pub use {charabia as tokenizer, heed}; | pub use {charabia as tokenizer, heed}; | ||||||
|   | |||||||
							
								
								
									
										97
									
								
								milli/src/prompt/context.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								milli/src/prompt/context.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | |||||||
|  | use liquid::model::{ | ||||||
|  |     ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||||
|  | }; | ||||||
|  | use liquid::{ObjectView, ValueView}; | ||||||
|  |  | ||||||
|  | use super::document::Document; | ||||||
|  | use super::fields::Fields; | ||||||
|  | use crate::FieldsIdsMap; | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | pub struct Context<'a> { | ||||||
|  |     document: &'a Document<'a>, | ||||||
|  |     fields: Fields<'a>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> Context<'a> { | ||||||
|  |     pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { | ||||||
|  |         Self { document, fields: Fields::new(document, field_id_map) } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ObjectView for Context<'a> { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         2 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||||
|  |         Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         Box::new( | ||||||
|  |             std::iter::once(self.document.as_value()) | ||||||
|  |                 .chain(std::iter::once(self.fields.as_value())), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||||
|  |         Box::new(self.keys().zip(self.values())) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, index: &str) -> bool { | ||||||
|  |         index == "doc" || index == "fields" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||||
|  |         match index { | ||||||
|  |             "doc" => Some(self.document.as_value()), | ||||||
|  |             "fields" => Some(self.fields.as_value()), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ValueView for Context<'a> { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         DisplayCow::Owned(Box::new(ObjectRender::new(self))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         DisplayCow::Owned(Box::new(ObjectSource::new(self))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         "object" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: liquid::model::State) -> bool { | ||||||
|  |         match state { | ||||||
|  |             State::Truthy => true, | ||||||
|  |             State::DefaultValue | State::Empty | State::Blank => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||||
|  |         let s = ObjectRender::new(self).to_string(); | ||||||
|  |         KStringCow::from_string(s) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         LiquidValue::Object( | ||||||
|  |             self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										131
									
								
								milli/src/prompt/document.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								milli/src/prompt/document.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,131 @@ | |||||||
|  | use std::cell::OnceCell; | ||||||
|  | use std::collections::BTreeMap; | ||||||
|  |  | ||||||
|  | use liquid::model::{ | ||||||
|  |     DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||||
|  | }; | ||||||
|  | use liquid::{ObjectView, ValueView}; | ||||||
|  |  | ||||||
|  | use crate::update::del_add::{DelAdd, KvReaderDelAdd}; | ||||||
|  | use crate::FieldsIdsMap; | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>); | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | struct ParsedValue(std::cell::OnceCell<LiquidValue>); | ||||||
|  |  | ||||||
|  | impl ParsedValue { | ||||||
|  |     fn empty() -> ParsedValue { | ||||||
|  |         ParsedValue(OnceCell::new()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get(&self, raw: &[u8]) -> &LiquidValue { | ||||||
|  |         self.0.get_or_init(|| { | ||||||
|  |             let value: serde_json::Value = serde_json::from_slice(raw).unwrap(); | ||||||
|  |             liquid::model::to_value(&value).unwrap() | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> Document<'a> { | ||||||
|  |     pub fn new( | ||||||
|  |         data: obkv::KvReaderU16<'a>, | ||||||
|  |         side: DelAdd, | ||||||
|  |         inverted_field_map: &'a FieldsIdsMap, | ||||||
|  |     ) -> Self { | ||||||
|  |         let mut out_data = BTreeMap::new(); | ||||||
|  |         for (fid, raw) in data { | ||||||
|  |             let obkv = KvReaderDelAdd::new(raw); | ||||||
|  |             let Some(raw) = obkv.get(side) else { | ||||||
|  |                 continue; | ||||||
|  |             }; | ||||||
|  |             let Some(name) = inverted_field_map.name(fid) else { | ||||||
|  |                 continue; | ||||||
|  |             }; | ||||||
|  |             out_data.insert(name, (raw, ParsedValue::empty())); | ||||||
|  |         } | ||||||
|  |         Self(out_data) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn is_empty(&self) -> bool { | ||||||
|  |         self.0.is_empty() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn len(&self) -> usize { | ||||||
|  |         self.0.len() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn iter(&self) -> impl Iterator<Item = (KString, LiquidValue)> + '_ { | ||||||
|  |         self.0.iter().map(|(&k, (raw, data))| (k.to_owned().into(), data.get(raw).to_owned())) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ObjectView for Document<'a> { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         self.len() as i64 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||||
|  |         let keys = BTreeMap::keys(&self.0).map(|&s| s.into()); | ||||||
|  |         Box::new(keys) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         Box::new(self.0.values().map(|(raw, v)| v.get(raw) as &dyn ValueView)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||||
|  |         Box::new(self.0.iter().map(|(&k, (raw, data))| (k.into(), data.get(raw) as &dyn ValueView))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, index: &str) -> bool { | ||||||
|  |         self.0.contains_key(index) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||||
|  |         self.0.get(index).map(|(raw, v)| v.get(raw) as &dyn ValueView) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ValueView for Document<'a> { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         DisplayCow::Owned(Box::new(ObjectRender::new(self))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         DisplayCow::Owned(Box::new(ObjectSource::new(self))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         "object" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: liquid::model::State) -> bool { | ||||||
|  |         match state { | ||||||
|  |             State::Truthy => true, | ||||||
|  |             State::DefaultValue | State::Empty | State::Blank => self.is_empty(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||||
|  |         let s = ObjectRender::new(self).to_string(); | ||||||
|  |         KStringCow::from_string(s) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         LiquidValue::Object(self.iter().collect()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										56
									
								
								milli/src/prompt/error.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								milli/src/prompt/error.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | |||||||
|  | use crate::error::FaultSource; | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("{fault}: {kind}")] | ||||||
|  | pub struct NewPromptError { | ||||||
|  |     pub kind: NewPromptErrorKind, | ||||||
|  |     pub fault: FaultSource, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<NewPromptError> for crate::Error { | ||||||
|  |     fn from(value: NewPromptError) -> Self { | ||||||
|  |         crate::Error::UserError(crate::UserError::InvalidPrompt(value)) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl NewPromptError { | ||||||
|  |     pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError { | ||||||
|  |         Self { kind: NewPromptErrorKind::CannotParseTemplate(inner), fault: FaultSource::User } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError { | ||||||
|  |         Self { kind: NewPromptErrorKind::InvalidFieldsInTemplate(inner), fault: FaultSource::User } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | pub enum NewPromptErrorKind { | ||||||
|  |     #[error("cannot parse template: {0}")] | ||||||
|  |     CannotParseTemplate(liquid::Error), | ||||||
|  |     #[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")] | ||||||
|  |     InvalidFieldsInTemplate(liquid::Error), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("{fault}: {kind}")] | ||||||
|  | pub struct RenderPromptError { | ||||||
|  |     pub kind: RenderPromptErrorKind, | ||||||
|  |     pub fault: FaultSource, | ||||||
|  | } | ||||||
|  | impl RenderPromptError { | ||||||
|  |     pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError { | ||||||
|  |         Self { kind: RenderPromptErrorKind::MissingContext(inner), fault: FaultSource::User } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | pub enum RenderPromptErrorKind { | ||||||
|  |     #[error("missing field in document: {0}")] | ||||||
|  |     MissingContext(liquid::Error), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<RenderPromptError> for crate::Error { | ||||||
|  |     fn from(value: RenderPromptError) -> Self { | ||||||
|  |         crate::Error::UserError(crate::UserError::MissingDocumentField(value)) | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										172
									
								
								milli/src/prompt/fields.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								milli/src/prompt/fields.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,172 @@ | |||||||
|  | use liquid::model::{ | ||||||
|  |     ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||||
|  | }; | ||||||
|  | use liquid::{ObjectView, ValueView}; | ||||||
|  |  | ||||||
|  | use super::document::Document; | ||||||
|  | use crate::FieldsIdsMap; | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | pub struct Fields<'a>(Vec<FieldValue<'a>>); | ||||||
|  |  | ||||||
|  | impl<'a> Fields<'a> { | ||||||
|  |     pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { | ||||||
|  |         Self( | ||||||
|  |             std::iter::repeat(document) | ||||||
|  |                 .zip(field_id_map.iter()) | ||||||
|  |                 .map(|(document, (_fid, name))| FieldValue { document, name }) | ||||||
|  |                 .collect(), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Copy)] | ||||||
|  | pub struct FieldValue<'a> { | ||||||
|  |     name: &'a str, | ||||||
|  |     document: &'a Document<'a>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ValueView for FieldValue<'a> { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         DisplayCow::Owned(Box::new(ObjectRender::new(self))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         DisplayCow::Owned(Box::new(ObjectSource::new(self))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         "object" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: liquid::model::State) -> bool { | ||||||
|  |         match state { | ||||||
|  |             State::Truthy => true, | ||||||
|  |             State::DefaultValue | State::Empty | State::Blank => self.is_empty(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||||
|  |         let s = ObjectRender::new(self).to_string(); | ||||||
|  |         KStringCow::from_string(s) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         LiquidValue::Object( | ||||||
|  |             self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> FieldValue<'a> { | ||||||
|  |     pub fn name(&self) -> &&'a str { | ||||||
|  |         &self.name | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn value(&self) -> &dyn ValueView { | ||||||
|  |         self.document.get(self.name).unwrap_or(&LiquidValue::Nil) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn is_empty(&self) -> bool { | ||||||
|  |         self.size() == 0 | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ObjectView for FieldValue<'a> { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         2 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||||
|  |         Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         Box::new( | ||||||
|  |             std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||||
|  |         Box::new(self.keys().zip(self.values())) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, index: &str) -> bool { | ||||||
|  |         index == "name" || index == "value" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||||
|  |         match index { | ||||||
|  |             "name" => Some(self.name()), | ||||||
|  |             "value" => Some(self.value()), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ArrayView for Fields<'a> { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self.0.as_value() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         self.0.len() as i64 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         self.0.values() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, index: i64) -> bool { | ||||||
|  |         self.0.contains_key(index) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get(&self, index: i64) -> Option<&dyn ValueView> { | ||||||
|  |         ArrayView::get(&self.0, index) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> ValueView for Fields<'a> { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         self.0.render() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||||
|  |         self.0.source() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         self.0.type_name() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: liquid::model::State) -> bool { | ||||||
|  |         self.0.query_state(state) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||||
|  |         self.0.to_kstr() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         self.0.to_value() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_array(&self) -> Option<&dyn ArrayView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										176
									
								
								milli/src/prompt/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								milli/src/prompt/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,176 @@ | |||||||
|  | mod context; | ||||||
|  | mod document; | ||||||
|  | pub(crate) mod error; | ||||||
|  | mod fields; | ||||||
|  | mod template_checker; | ||||||
|  |  | ||||||
|  | use std::convert::TryFrom; | ||||||
|  |  | ||||||
|  | use error::{NewPromptError, RenderPromptError}; | ||||||
|  |  | ||||||
|  | use self::context::Context; | ||||||
|  | use self::document::Document; | ||||||
|  | use crate::update::del_add::DelAdd; | ||||||
|  | use crate::FieldsIdsMap; | ||||||
|  |  | ||||||
|  | pub struct Prompt { | ||||||
|  |     template: liquid::Template, | ||||||
|  |     template_text: String, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] | ||||||
|  | pub struct PromptData { | ||||||
|  |     pub template: String, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<Prompt> for PromptData { | ||||||
|  |     fn from(value: Prompt) -> Self { | ||||||
|  |         Self { template: value.template_text } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl TryFrom<PromptData> for Prompt { | ||||||
|  |     type Error = NewPromptError; | ||||||
|  |  | ||||||
|  |     fn try_from(value: PromptData) -> Result<Self, Self::Error> { | ||||||
|  |         Prompt::new(value.template) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Clone for Prompt { | ||||||
|  |     fn clone(&self) -> Self { | ||||||
|  |         let template_text = self.template_text.clone(); | ||||||
|  |         Self { template: new_template(&template_text).unwrap(), template_text } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> { | ||||||
|  |     liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn default_template() -> liquid::Template { | ||||||
|  |     new_template(default_template_text()).unwrap() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn default_template_text() -> &'static str { | ||||||
|  |     "{% for field in fields %} \ | ||||||
|  |     {{ field.name }}: {{ field.value }}\n\ | ||||||
|  |     {% endfor %}" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Default for Prompt { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         Self { template: default_template(), template_text: default_template_text().into() } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Default for PromptData { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         Self { template: default_template_text().into() } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Prompt { | ||||||
|  |     /// Parses `template` and checks the fields it refers to. | ||||||
|  |     /// | ||||||
|  |     /// # Errors | ||||||
|  |     /// | ||||||
|  |     /// - A parsing failure is wrapped with `NewPromptError::cannot_parse_template`. | ||||||
|  |     /// - A failure to render the parsed template against `TemplateChecker` is wrapped | ||||||
|  |     ///   with `NewPromptError::invalid_fields_in_template`. | ||||||
|  |     pub fn new(template: String) -> Result<Self, NewPromptError> { | ||||||
|  |         // Parse through the shared helper so that `new`, `clone` and `default` all use | ||||||
|  |         // the same parser setup. | ||||||
|  |         let parsed = new_template(&template).map_err(NewPromptError::cannot_parse_template)?; | ||||||
|  |         let this = Self { template: parsed, template_text: template }; | ||||||
|  |  | ||||||
|  |         // render template with special object that's OK with `doc.*` and `fields.*` | ||||||
|  |         this.template | ||||||
|  |             .render(&template_checker::TemplateChecker) | ||||||
|  |             .map_err(NewPromptError::invalid_fields_in_template)?; | ||||||
|  |  | ||||||
|  |         Ok(this) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Renders the template for one document. | ||||||
|  |     /// | ||||||
|  |     /// The rendering context is built from `document` (read on the given `side`) and | ||||||
|  |     /// from `field_id_map`. | ||||||
|  |     /// | ||||||
|  |     /// # Errors | ||||||
|  |     /// | ||||||
|  |     /// A rendering failure is wrapped with `RenderPromptError::missing_context`. | ||||||
|  |     pub fn render( | ||||||
|  |         &self, | ||||||
|  |         document: obkv::KvReaderU16<'_>, | ||||||
|  |         side: DelAdd, | ||||||
|  |         field_id_map: &FieldsIdsMap, | ||||||
|  |     ) -> Result<String, RenderPromptError> { | ||||||
|  |         let document = Document::new(document, side, field_id_map); | ||||||
|  |         let context = Context::new(&document, field_id_map); | ||||||
|  |  | ||||||
|  |         self.template.render(&context).map_err(RenderPromptError::missing_context) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[cfg(test)] | ||||||
|  | mod test { | ||||||
|  |     use super::Prompt; | ||||||
|  |     use crate::error::FaultSource; | ||||||
|  |     use crate::prompt::error::{NewPromptError, NewPromptErrorKind}; | ||||||
|  |  | ||||||
|  |     /// Building the default prompt must not panic. | ||||||
|  |     #[test] | ||||||
|  |     fn default_template() { | ||||||
|  |         let _prompt = Prompt::default(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// An empty template is accepted. | ||||||
|  |     #[test] | ||||||
|  |     fn empty_template() { | ||||||
|  |         let template = "".into(); | ||||||
|  |         Prompt::new(template).unwrap(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Fields reached through `doc` are accepted. | ||||||
|  |     #[test] | ||||||
|  |     fn template_ok() { | ||||||
|  |         let template = "{{doc.title}}: {{doc.overview}}".into(); | ||||||
|  |         Prompt::new(template).unwrap(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// A template with broken syntax is reported as a user parsing error. | ||||||
|  |     #[test] | ||||||
|  |     fn template_syntax() { | ||||||
|  |         let result = Prompt::new("{{doc.title: {{doc.overview}}".into()); | ||||||
|  |         assert!(matches!( | ||||||
|  |             result, | ||||||
|  |             Err(NewPromptError { | ||||||
|  |                 kind: NewPromptErrorKind::CannotParseTemplate(_), | ||||||
|  |                 fault: FaultSource::User | ||||||
|  |             }) | ||||||
|  |         )); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Fields that are not reached through `doc` are rejected. | ||||||
|  |     #[test] | ||||||
|  |     fn template_missing_doc() { | ||||||
|  |         let result = Prompt::new("{{title}}: {{overview}}".into()); | ||||||
|  |         assert!(matches!( | ||||||
|  |             result, | ||||||
|  |             Err(NewPromptError { | ||||||
|  |                 kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), | ||||||
|  |                 fault: FaultSource::User | ||||||
|  |             }) | ||||||
|  |         )); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Nested fields under `doc` are accepted. | ||||||
|  |     #[test] | ||||||
|  |     fn template_nested_doc() { | ||||||
|  |         let template = "{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(); | ||||||
|  |         Prompt::new(template).unwrap(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Iterating over `fields` and printing each entry is accepted. | ||||||
|  |     #[test] | ||||||
|  |     fn template_fields() { | ||||||
|  |         let template = "{% for field in fields %}{{field}}{% endfor %}".into(); | ||||||
|  |         Prompt::new(template).unwrap(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// The `name` and `value` keys of a field are accepted. | ||||||
|  |     #[test] | ||||||
|  |     fn template_fields_ok() { | ||||||
|  |         let template = | ||||||
|  |             "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(); | ||||||
|  |         Prompt::new(template).unwrap(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Any other key on a field is rejected. | ||||||
|  |     #[test] | ||||||
|  |     fn template_fields_invalid() { | ||||||
|  |         // intentionally garbled field | ||||||
|  |         let result = Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into()); | ||||||
|  |         assert!(matches!( | ||||||
|  |             result, | ||||||
|  |             Err(NewPromptError { | ||||||
|  |                 kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), | ||||||
|  |                 fault: FaultSource::User | ||||||
|  |             }) | ||||||
|  |         )); | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										301
									
								
								milli/src/prompt/template_checker.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										301
									
								
								milli/src/prompt/template_checker.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,301 @@ | |||||||
|  | use liquid::model::{ | ||||||
|  |     ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||||
|  | }; | ||||||
|  | use liquid::{Object, ObjectView, ValueView}; | ||||||
|  |  | ||||||
|  | /// Root object used to check templates: it exposes exactly the `doc` and `fields` keys. | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct TemplateChecker; | ||||||
|  |  | ||||||
|  | /// Stand-in for `doc`: looking up any key succeeds and returns the stand-in itself. | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct DummyDoc; | ||||||
|  |  | ||||||
|  | /// Stand-in for `fields`: an array whose every index yields a `DummyField`. | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct DummyFields; | ||||||
|  |  | ||||||
|  | /// Stand-in for one entry of `fields`: an object with the `name` and `value` keys only. | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct DummyField; | ||||||
|  |  | ||||||
|  | /// Placeholder returned wherever a concrete value is needed. | ||||||
|  | const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil; | ||||||
|  |  | ||||||
|  | /// Object view of a field stand-in: two keys, `name` and `value`, both mapped to the | ||||||
|  | /// placeholder value. | ||||||
|  | impl ObjectView for DummyField { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         2 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||||
|  |         let names = ["name", "value"].iter(); | ||||||
|  |         Box::new(names.map(|name| KStringCow::from_static(name))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         // One placeholder per key. | ||||||
|  |         let name = std::iter::once(DUMMY_VALUE.as_view()); | ||||||
|  |         let value = std::iter::once(DUMMY_VALUE.as_view()); | ||||||
|  |         Box::new(name.chain(value)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||||
|  |         let pairs = self.keys().zip(self.values()); | ||||||
|  |         Box::new(pairs) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, index: &str) -> bool { | ||||||
|  |         matches!(index, "name" | "value") | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||||
|  |         match index { | ||||||
|  |             "name" | "value" => Some(DUMMY_VALUE.as_view()), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Value view of a field stand-in: renders like the placeholder value, reports itself | ||||||
|  | /// as a truthy `object`. | ||||||
|  | impl ValueView for DummyField { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> DisplayCow<'_> { | ||||||
|  |         DUMMY_VALUE.render() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> DisplayCow<'_> { | ||||||
|  |         DUMMY_VALUE.source() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         "object" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: State) -> bool { | ||||||
|  |         match state { | ||||||
|  |             State::Truthy => true, | ||||||
|  |             State::DefaultValue | State::Empty | State::Blank => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> KStringCow<'_> { | ||||||
|  |         DUMMY_VALUE.to_kstr() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         let mut object = Object::new(); | ||||||
|  |         for key in ["name", "value"] { | ||||||
|  |             object.insert(key.into(), LiquidValue::Nil); | ||||||
|  |         } | ||||||
|  |         LiquidValue::Object(object) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Value view of the `fields` stand-in: renders like the placeholder value, reports | ||||||
|  | /// itself as a truthy `array`. | ||||||
|  | impl ValueView for DummyFields { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> DisplayCow<'_> { | ||||||
|  |         DUMMY_VALUE.render() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> DisplayCow<'_> { | ||||||
|  |         DUMMY_VALUE.source() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         "array" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: State) -> bool { | ||||||
|  |         match state { | ||||||
|  |             State::Truthy => true, | ||||||
|  |             State::DefaultValue | State::Empty | State::Blank => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> KStringCow<'_> { | ||||||
|  |         DUMMY_VALUE.to_kstr() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         // A single field stand-in represents the whole array. | ||||||
|  |         let only_field = DummyField.to_value(); | ||||||
|  |         LiquidValue::Array(vec![only_field]) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_array(&self) -> Option<&dyn ArrayView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Array view of the `fields` stand-in: reports `u16::MAX` entries, iterates over a | ||||||
|  | /// single `DummyField`, and answers every `get` with a `DummyField`. | ||||||
|  | impl ArrayView for DummyFields { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         i64::from(u16::MAX) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         let field = DummyField.as_value(); | ||||||
|  |         Box::new(std::iter::once(field)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, index: i64) -> bool { | ||||||
|  |         self.size() > index | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get(&self, _index: i64) -> Option<&dyn ValueView> { | ||||||
|  |         let field = DummyField.as_value(); | ||||||
|  |         Some(field) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Object view of the `doc` stand-in: it enumerates no entries, yet claims to contain | ||||||
|  | /// every key and answers every lookup. | ||||||
|  | impl ObjectView for DummyDoc { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         1000 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||||
|  |         let no_keys = std::iter::empty(); | ||||||
|  |         Box::new(no_keys) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         let no_values = std::iter::empty(); | ||||||
|  |         Box::new(no_values) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||||
|  |         let no_entries = std::iter::empty(); | ||||||
|  |         Box::new(no_entries) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, _index: &str) -> bool { | ||||||
|  |         true | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { | ||||||
|  |         // Returning itself lets nested lookups of any depth resolve. | ||||||
|  |         let nested: &'s dyn ValueView = self; | ||||||
|  |         Some(nested) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Value view of the `doc` stand-in: renders like the placeholder value, reports itself | ||||||
|  | /// as a truthy `object`, and converts to `Nil`. | ||||||
|  | impl ValueView for DummyDoc { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> DisplayCow<'_> { | ||||||
|  |         DUMMY_VALUE.render() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> DisplayCow<'_> { | ||||||
|  |         DUMMY_VALUE.source() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         "object" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: State) -> bool { | ||||||
|  |         match state { | ||||||
|  |             State::Truthy => true, | ||||||
|  |             State::DefaultValue | State::Empty | State::Blank => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> KStringCow<'_> { | ||||||
|  |         DUMMY_VALUE.to_kstr() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         LiquidValue::Nil | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Object view of the checker root: `doc` maps to a `DummyDoc` and `fields` to a | ||||||
|  | /// `DummyFields`; no other key exists. | ||||||
|  | impl ObjectView for TemplateChecker { | ||||||
|  |     fn as_value(&self) -> &dyn ValueView { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn size(&self) -> i64 { | ||||||
|  |         2 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||||
|  |         let names = ["doc", "fields"].iter(); | ||||||
|  |         Box::new(names.map(|name| KStringCow::from_static(name))) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||||
|  |         // Same order as `keys`: `doc` first, then `fields`. | ||||||
|  |         let doc = std::iter::once(DummyDoc.as_value()); | ||||||
|  |         let fields = std::iter::once(DummyFields.as_value()); | ||||||
|  |         Box::new(doc.chain(fields)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||||
|  |         let pairs = self.keys().zip(self.values()); | ||||||
|  |         Box::new(pairs) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn contains_key(&self, index: &str) -> bool { | ||||||
|  |         matches!(index, "doc" | "fields") | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||||
|  |         if index == "doc" { | ||||||
|  |             Some(DummyDoc.as_value()) | ||||||
|  |         } else if index == "fields" { | ||||||
|  |             Some(DummyFields.as_value()) | ||||||
|  |         } else { | ||||||
|  |             None | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Value view of the checker root: rendered and converted through its object view. | ||||||
|  | impl ValueView for TemplateChecker { | ||||||
|  |     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn render(&self) -> DisplayCow<'_> { | ||||||
|  |         let rendered = ObjectRender::new(self); | ||||||
|  |         DisplayCow::Owned(Box::new(rendered)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn source(&self) -> DisplayCow<'_> { | ||||||
|  |         let source = ObjectSource::new(self); | ||||||
|  |         DisplayCow::Owned(Box::new(source)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn type_name(&self) -> &'static str { | ||||||
|  |         "object" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn query_state(&self, state: State) -> bool { | ||||||
|  |         match state { | ||||||
|  |             State::Truthy => true, | ||||||
|  |             State::DefaultValue => false, | ||||||
|  |             State::Empty => false, | ||||||
|  |             State::Blank => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_kstr(&self) -> KStringCow<'_> { | ||||||
|  |         KStringCow::from_string(ObjectRender::new(self).to_string()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn to_value(&self) -> LiquidValue { | ||||||
|  |         let object: Object = self | ||||||
|  |             .iter() | ||||||
|  |             .map(|(key, view)| (key.to_string().into(), view.to_value())) | ||||||
|  |             .collect(); | ||||||
|  |         LiquidValue::Object(object) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||||
|  |         Some(self) | ||||||
|  |     } | ||||||
|  | } | ||||||
| @@ -1,85 +0,0 @@ | |||||||
| use std::io::{self, Read}; |  | ||||||
| use std::iter::FromIterator; |  | ||||||
|  |  | ||||||
| pub struct ReadableSlices<A> { |  | ||||||
|     inner: Vec<A>, |  | ||||||
|     pos: u64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<A> FromIterator<A> for ReadableSlices<A> { |  | ||||||
|     fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self { |  | ||||||
|         ReadableSlices { inner: iter.into_iter().collect(), pos: 0 } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<A: AsRef<[u8]>> Read for ReadableSlices<A> { |  | ||||||
|     fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> { |  | ||||||
|         let original_buf_len = buf.len(); |  | ||||||
|  |  | ||||||
|         // We explore the list of slices to find the one where we must start reading. |  | ||||||
|         let mut pos = self.pos; |  | ||||||
|         let index = match self |  | ||||||
|             .inner |  | ||||||
|             .iter() |  | ||||||
|             .map(|s| s.as_ref().len() as u64) |  | ||||||
|             .position(|size| pos.checked_sub(size).map(|p| pos = p).is_none()) |  | ||||||
|         { |  | ||||||
|             Some(index) => index, |  | ||||||
|             None => return Ok(0), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         let mut inner_pos = pos as usize; |  | ||||||
|         for slice in &self.inner[index..] { |  | ||||||
|             let slice = &slice.as_ref()[inner_pos..]; |  | ||||||
|  |  | ||||||
|             if buf.len() > slice.len() { |  | ||||||
|                 // We must exhaust the current slice and go to the next one there is not enough here. |  | ||||||
|                 buf[..slice.len()].copy_from_slice(slice); |  | ||||||
|                 buf = &mut buf[slice.len()..]; |  | ||||||
|                 inner_pos = 0; |  | ||||||
|             } else { |  | ||||||
|                 // There is enough in this slice to fill the remaining bytes of the buffer. |  | ||||||
|                 // Let's break just after filling it. |  | ||||||
|                 buf.copy_from_slice(&slice[..buf.len()]); |  | ||||||
|                 buf = &mut []; |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let written = original_buf_len - buf.len(); |  | ||||||
|         self.pos += written as u64; |  | ||||||
|         Ok(written) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod test { |  | ||||||
|     use std::io::Read; |  | ||||||
|  |  | ||||||
|     use super::ReadableSlices; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn basic() { |  | ||||||
|         let data: Vec<_> = (0..100).collect(); |  | ||||||
|         let splits: Vec<_> = data.chunks(3).collect(); |  | ||||||
|         let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); |  | ||||||
|  |  | ||||||
|         let mut output = Vec::new(); |  | ||||||
|         let length = rdslices.read_to_end(&mut output).unwrap(); |  | ||||||
|         assert_eq!(length, data.len()); |  | ||||||
|         assert_eq!(output, data); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn small_reads() { |  | ||||||
|         let data: Vec<_> = (0..u8::MAX).collect(); |  | ||||||
|         let splits: Vec<_> = data.chunks(27).collect(); |  | ||||||
|         let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); |  | ||||||
|  |  | ||||||
|         let buffer = &mut [0; 45]; |  | ||||||
|         let length = rdslices.read(buffer).unwrap(); |  | ||||||
|         let expected: Vec<_> = (0..buffer.len() as u8).collect(); |  | ||||||
|         assert_eq!(length, buffer.len()); |  | ||||||
|         assert_eq!(buffer, &expected[..]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | use std::cmp::Ordering; | ||||||
|  |  | ||||||
|  | use itertools::Itertools; | ||||||
| use serde::Serialize; | use serde::Serialize; | ||||||
|  |  | ||||||
| use crate::distance_between_two_points; | use crate::distance_between_two_points; | ||||||
| @@ -12,9 +15,24 @@ pub enum ScoreDetails { | |||||||
|     ExactAttribute(ExactAttribute), |     ExactAttribute(ExactAttribute), | ||||||
|     ExactWords(ExactWords), |     ExactWords(ExactWords), | ||||||
|     Sort(Sort), |     Sort(Sort), | ||||||
|  |     Vector(Vector), | ||||||
|     GeoSort(GeoSort), |     GeoSort(GeoSort), | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /// One item produced by `ScoreDetails::score_values`. | ||||||
|  | #[derive(Clone, Copy)] | ||||||
|  | pub enum ScoreValue<'a> { | ||||||
|  |     /// A numeric score: the local score of merged ranks, or a vector similarity. | ||||||
|  |     Score(f64), | ||||||
|  |     /// The details of a sort, passed through by reference. | ||||||
|  |     Sort(&'a Sort), | ||||||
|  |     /// The details of a geo sort, passed through by reference. | ||||||
|  |     GeoSort(&'a GeoSort), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Intermediate form of a score detail, used while adjacent ranks are being merged. | ||||||
|  | enum RankOrValue<'a> { | ||||||
|  |     /// A rank that can still be merged with a neighbouring rank. | ||||||
|  |     Rank(Rank), | ||||||
|  |     /// The details of a sort, passed through by reference. | ||||||
|  |     Sort(&'a Sort), | ||||||
|  |     /// The details of a geo sort, passed through by reference. | ||||||
|  |     GeoSort(&'a GeoSort), | ||||||
|  |     /// A score that is already final. | ||||||
|  |     Score(f64), | ||||||
|  | } | ||||||
|  |  | ||||||
| impl ScoreDetails { | impl ScoreDetails { | ||||||
|     pub fn local_score(&self) -> Option<f64> { |     pub fn local_score(&self) -> Option<f64> { | ||||||
|         self.rank().map(Rank::local_score) |         self.rank().map(Rank::local_score) | ||||||
| @@ -31,11 +49,55 @@ impl ScoreDetails { | |||||||
|             ScoreDetails::ExactWords(details) => Some(details.rank()), |             ScoreDetails::ExactWords(details) => Some(details.rank()), | ||||||
|             ScoreDetails::Sort(_) => None, |             ScoreDetails::Sort(_) => None, | ||||||
|             ScoreDetails::GeoSort(_) => None, |             ScoreDetails::GeoSort(_) => None, | ||||||
|  |             ScoreDetails::Vector(_) => None, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn global_score<'a>(details: impl Iterator<Item = &'a Self>) -> f64 { |     pub fn global_score<'a>(details: impl Iterator<Item = &'a Self> + 'a) -> f64 { | ||||||
|         Rank::global_score(details.filter_map(Self::rank)) |         Self::score_values(details) | ||||||
|  |             .find_map(|x| { | ||||||
|  |                 let ScoreValue::Score(score) = x else { | ||||||
|  |                     return None; | ||||||
|  |                 }; | ||||||
|  |                 Some(score) | ||||||
|  |             }) | ||||||
|  |             .unwrap_or(1.0f64) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Turns score details into comparable values. | ||||||
|  |     /// | ||||||
|  |     /// Consecutive rank-based details are merged with `Rank::merge` and emitted as one | ||||||
|  |     /// `ScoreValue::Score`; sorts, geo sorts and plain scores are emitted one by one. | ||||||
|  |     pub fn score_values<'a>( | ||||||
|  |         details: impl Iterator<Item = &'a Self> + 'a, | ||||||
|  |     ) -> impl Iterator<Item = ScoreValue<'a>> + 'a { | ||||||
|  |         let merged = details.map(Self::rank_or_value).coalesce(|previous, current| { | ||||||
|  |             if let (RankOrValue::Rank(outer), RankOrValue::Rank(inner)) = (&previous, &current) { | ||||||
|  |                 Ok(RankOrValue::Rank(Rank::merge(*outer, *inner))) | ||||||
|  |             } else { | ||||||
|  |                 // At least one side is not a rank: keep both items apart. | ||||||
|  |                 Err((previous, current)) | ||||||
|  |             } | ||||||
|  |         }); | ||||||
|  |  | ||||||
|  |         merged.map(|item| match item { | ||||||
|  |             RankOrValue::Rank(rank) => ScoreValue::Score(rank.local_score()), | ||||||
|  |             RankOrValue::Score(score) => ScoreValue::Score(score), | ||||||
|  |             RankOrValue::Sort(sort) => ScoreValue::Sort(sort), | ||||||
|  |             RankOrValue::GeoSort(geosort) => ScoreValue::GeoSort(geosort), | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Classifies this detail as a mergeable rank, a sort, a geo sort or a final score. | ||||||
|  |     /// | ||||||
|  |     /// A vector detail without a similarity counts as a score of `0.0`. | ||||||
|  |     fn rank_or_value(&self) -> RankOrValue<'_> { | ||||||
|  |         let rank = match self { | ||||||
|  |             ScoreDetails::Sort(sort) => return RankOrValue::Sort(sort), | ||||||
|  |             ScoreDetails::GeoSort(geosort) => return RankOrValue::GeoSort(geosort), | ||||||
|  |             ScoreDetails::Vector(vector) => { | ||||||
|  |                 let similarity = | ||||||
|  |                     vector.value_similarity.as_ref().map_or(0.0f64, |(_, s)| *s as f64); | ||||||
|  |                 return RankOrValue::Score(similarity); | ||||||
|  |             } | ||||||
|  |             ScoreDetails::Words(words) => words.rank(), | ||||||
|  |             ScoreDetails::Typo(typo) => typo.rank(), | ||||||
|  |             ScoreDetails::ExactAttribute(exact) => exact.rank(), | ||||||
|  |             ScoreDetails::ExactWords(exact) => exact.rank(), | ||||||
|  |             ScoreDetails::Proximity(rank) | ||||||
|  |             | ScoreDetails::Fid(rank) | ||||||
|  |             | ScoreDetails::Position(rank) => *rank, | ||||||
|  |         }; | ||||||
|  |         RankOrValue::Rank(rank) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Panics |     /// Panics | ||||||
| @@ -181,6 +243,19 @@ impl ScoreDetails { | |||||||
|                     details_map.insert(sort, sort_details); |                     details_map.insert(sort, sort_details); | ||||||
|                     order += 1; |                     order += 1; | ||||||
|                 } |                 } | ||||||
|  |                 ScoreDetails::Vector(s) => { | ||||||
|  |                     let vector = format!("vectorSort({:?})", s.target_vector); | ||||||
|  |                     let value = s.value_similarity.as_ref().map(|(v, _)| v); | ||||||
|  |                     let similarity = s.value_similarity.as_ref().map(|(_, s)| s); | ||||||
|  |  | ||||||
|  |                     let details = serde_json::json!({ | ||||||
|  |                         "order": order, | ||||||
|  |                         "value": value, | ||||||
|  |                         "similarity": similarity, | ||||||
|  |                     }); | ||||||
|  |                     details_map.insert(vector, details); | ||||||
|  |                     order += 1; | ||||||
|  |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         details_map |         details_map | ||||||
| @@ -297,15 +372,21 @@ impl Rank { | |||||||
|     pub fn global_score(details: impl Iterator<Item = Self>) -> f64 { |     pub fn global_score(details: impl Iterator<Item = Self>) -> f64 { | ||||||
|         let mut rank = Rank { rank: 1, max_rank: 1 }; |         let mut rank = Rank { rank: 1, max_rank: 1 }; | ||||||
|         for inner_rank in details { |         for inner_rank in details { | ||||||
|             rank.rank -= 1; |             rank = Rank::merge(rank, inner_rank); | ||||||
|  |  | ||||||
|             rank.rank *= inner_rank.max_rank; |  | ||||||
|             rank.max_rank *= inner_rank.max_rank; |  | ||||||
|  |  | ||||||
|             rank.rank += inner_rank.rank; |  | ||||||
|         } |         } | ||||||
|         rank.local_score() |         rank.local_score() | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Combines `outer` with the finer-grained `inner` rank. | ||||||
|  |     /// | ||||||
|  |     /// The outer rank, lowered by one (saturating at zero), is scaled by | ||||||
|  |     /// `inner.max_rank` before `inner.rank` is added; the maximum ranks are multiplied. | ||||||
|  |     pub fn merge(outer: Rank, inner: Rank) -> Rank { | ||||||
|  |         let rank = outer.rank.saturating_sub(1) * inner.max_rank + inner.rank; | ||||||
|  |         let max_rank = outer.max_rank * inner.max_rank; | ||||||
|  |  | ||||||
|  |         Rank { rank, max_rank } | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] | #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] | ||||||
| @@ -335,13 +416,78 @@ pub struct Sort { | |||||||
|     pub value: serde_json::Value, |     pub value: serde_json::Value, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] | impl PartialOrd for Sort { | ||||||
|  |     /// Compares two sort details, where `Greater` means "ranked better". | ||||||
|  |     /// | ||||||
|  |     /// Returns `None` when the two details are not comparable: different field name, | ||||||
|  |     /// different direction, a number that cannot be read as `f64`, or a combination | ||||||
|  |     /// of value types that is not handled below. | ||||||
|  |     fn partial_cmp(&self, other: &Self) -> Option<Ordering> { | ||||||
|  |         if self.field_name != other.field_name { | ||||||
|  |             return None; | ||||||
|  |         } | ||||||
|  |         if self.ascending != other.ascending { | ||||||
|  |             return None; | ||||||
|  |         } | ||||||
|  |         match (&self.value, &other.value) { | ||||||
|  |             (serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal), | ||||||
|  |             (serde_json::Value::Null, _) => Some(Ordering::Less), | ||||||
|  |             (_, serde_json::Value::Null) => Some(Ordering::Greater), | ||||||
|  |             // numbers are always before strings | ||||||
|  |             (serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater), | ||||||
|  |             (serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less), | ||||||
|  |             (serde_json::Value::Number(left), serde_json::Value::Number(right)) => { | ||||||
|  |                 // `as_f64` returns an `Option`: a number without an `f64` form makes the | ||||||
|  |                 // pair incomparable instead of panicking in the middle of a comparison. | ||||||
|  |                 let order = left.as_f64()?.partial_cmp(&right.as_f64()?)?; | ||||||
|  |                 // 12 < 42, and when ascending, we want to see 12 first, so the smallest. | ||||||
|  |                 // Hence, when ascending, smaller is better | ||||||
|  |                 Some(if self.ascending { order.reverse() } else { order }) | ||||||
|  |             } | ||||||
|  |             (serde_json::Value::String(left), serde_json::Value::String(right)) => { | ||||||
|  |                 let order = left.cmp(right); | ||||||
|  |                 // Taking e.g. "a" and "z" | ||||||
|  |                 // "a" < "z", and when ascending, we want to see "a" first, so the smallest. | ||||||
|  |                 // Hence, when ascending, smaller is better | ||||||
|  |                 Some(if self.ascending { order.reverse() } else { order }) | ||||||
|  |             } | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Copy, PartialEq)] | ||||||
| pub struct GeoSort { | pub struct GeoSort { | ||||||
|     pub target_point: [f64; 2], |     pub target_point: [f64; 2], | ||||||
|     pub ascending: bool, |     pub ascending: bool, | ||||||
|     pub value: Option<[f64; 2]>, |     pub value: Option<[f64; 2]>, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /// Orders geo sort details so that `Greater` means "ranked better". | ||||||
|  | /// | ||||||
|  | /// Details with different target points or directions are not comparable. A detail | ||||||
|  | /// without a distance ranks below any detail that has one. | ||||||
|  | impl PartialOrd for GeoSort { | ||||||
|  |     fn partial_cmp(&self, other: &Self) -> Option<Ordering> { | ||||||
|  |         let comparable = | ||||||
|  |             self.target_point == other.target_point && self.ascending == other.ascending; | ||||||
|  |         if !comparable { | ||||||
|  |             return None; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         let order = match (self.distance(), other.distance()) { | ||||||
|  |             (None, None) => Ordering::Equal, | ||||||
|  |             (None, Some(_)) => Ordering::Less, | ||||||
|  |             (Some(_), None) => Ordering::Greater, | ||||||
|  |             (Some(left), Some(right)) => { | ||||||
|  |                 let by_distance = left.partial_cmp(&right)?; | ||||||
|  |                 // when ascending, the one with the smallest distance has the best score | ||||||
|  |                 if self.ascending { | ||||||
|  |                     by_distance.reverse() | ||||||
|  |                 } else { | ||||||
|  |                     by_distance | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |         Some(order) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Score details of a vector comparison. | ||||||
|  | #[derive(Debug, Clone, PartialEq, PartialOrd)] | ||||||
|  | pub struct Vector { | ||||||
|  |     /// The vector that documents are compared against. | ||||||
|  |     pub target_vector: Vec<f32>, | ||||||
|  |     /// A vector value paired with its similarity, when there is one. | ||||||
|  |     pub value_similarity: Option<(Vec<f32>, f32)>, | ||||||
|  | } | ||||||
|  |  | ||||||
| impl GeoSort { | impl GeoSort { | ||||||
|     pub fn distance(&self) -> Option<f64> { |     pub fn distance(&self) -> Option<f64> { | ||||||
|         self.value.map(|value| distance_between_two_points(&self.target_point, &value)) |         self.value.map(|value| distance_between_two_points(&self.target_point, &value)) | ||||||
|   | |||||||
							
								
								
									
										183
									
								
								milli/src/search/hybrid.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										183
									
								
								milli/src/search/hybrid.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,183 @@ | |||||||
|  | use std::cmp::Ordering; | ||||||
|  |  | ||||||
|  | use itertools::Itertools; | ||||||
|  | use roaring::RoaringBitmap; | ||||||
|  |  | ||||||
|  | use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; | ||||||
|  | use crate::{MatchingWords, Result, Search, SearchResult}; | ||||||
|  |  | ||||||
/// A `SearchResult` whose per-document score details have each been paired with a ratio.
///
/// Built with `ScoreWithRatioResult::new` and combined with another instance through
/// `ScoreWithRatioResult::merge`.
struct ScoreWithRatioResult {
    /// Matching words, taken as-is from the originating `SearchResult`.
    matching_words: MatchingWords,
    /// Candidates bitmap, taken as-is from the originating `SearchResult`.
    candidates: RoaringBitmap,
    /// One `(document id, (score details, ratio))` entry per returned document,
    /// in the order of the originating `SearchResult`.
    document_scores: Vec<(u32, ScoreWithRatio)>,
}
|  |  | ||||||
/// The score details of one document, paired with the ratio that `compare_scores`
/// multiplies its `ScoreValue::Score` values by before comparing them.
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
|  |  | ||||||
|  | fn compare_scores( | ||||||
|  |     &(ref left_scores, left_ratio): &ScoreWithRatio, | ||||||
|  |     &(ref right_scores, right_ratio): &ScoreWithRatio, | ||||||
|  | ) -> Ordering { | ||||||
|  |     let mut left_it = ScoreDetails::score_values(left_scores.iter()); | ||||||
|  |     let mut right_it = ScoreDetails::score_values(right_scores.iter()); | ||||||
|  |  | ||||||
|  |     loop { | ||||||
|  |         let left = left_it.next(); | ||||||
|  |         let right = right_it.next(); | ||||||
|  |  | ||||||
|  |         match (left, right) { | ||||||
|  |             (None, None) => return Ordering::Equal, | ||||||
|  |             (None, Some(_)) => return Ordering::Less, | ||||||
|  |             (Some(_), None) => return Ordering::Greater, | ||||||
|  |             (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { | ||||||
|  |                 let left = left * left_ratio as f64; | ||||||
|  |                 let right = right * right_ratio as f64; | ||||||
|  |                 if (left - right).abs() <= f64::EPSILON { | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |                 return left.partial_cmp(&right).unwrap(); | ||||||
|  |             } | ||||||
|  |             (Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => { | ||||||
|  |                 match left.partial_cmp(right).unwrap() { | ||||||
|  |                     Ordering::Equal => continue, | ||||||
|  |                     order => return order, | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             (Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => { | ||||||
|  |                 match left.partial_cmp(right).unwrap() { | ||||||
|  |                     Ordering::Equal => continue, | ||||||
|  |                     order => return order, | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             (Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater, | ||||||
|  |             (Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less, | ||||||
|  |             // if we have this, we're bad | ||||||
|  |             (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_))) | ||||||
|  |             | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => { | ||||||
|  |                 unreachable!("Unexpected geo and sort comparison") | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl ScoreWithRatioResult { | ||||||
|  |     fn new(results: SearchResult, ratio: f32) -> Self { | ||||||
|  |         let document_scores = results | ||||||
|  |             .documents_ids | ||||||
|  |             .into_iter() | ||||||
|  |             .zip(results.document_scores.into_iter().map(|scores| (scores, ratio))) | ||||||
|  |             .collect(); | ||||||
|  |  | ||||||
|  |         Self { | ||||||
|  |             matching_words: results.matching_words, | ||||||
|  |             candidates: results.candidates, | ||||||
|  |             document_scores, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult { | ||||||
|  |         let mut documents_ids = | ||||||
|  |             Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); | ||||||
|  |         let mut document_scores = | ||||||
|  |             Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); | ||||||
|  |  | ||||||
|  |         let mut documents_seen = RoaringBitmap::new(); | ||||||
|  |         for (docid, (main_score, _sub_score)) in left | ||||||
|  |             .document_scores | ||||||
|  |             .into_iter() | ||||||
|  |             .merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| { | ||||||
|  |                 // the first value is the one with the greatest score | ||||||
|  |                 compare_scores(left, right).is_ge() | ||||||
|  |             }) | ||||||
|  |             // remove documents we already saw | ||||||
|  |             .filter(|(docid, _)| documents_seen.insert(*docid)) | ||||||
|  |             // start skipping **after** the filter | ||||||
|  |             .skip(from) | ||||||
|  |             // take **after** skipping | ||||||
|  |             .take(length) | ||||||
|  |         { | ||||||
|  |             documents_ids.push(docid); | ||||||
|  |             // TODO: pass both scores to documents_score in some way? | ||||||
|  |             document_scores.push(main_score); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         SearchResult { | ||||||
|  |             matching_words: left.matching_words, | ||||||
|  |             candidates: left.candidates | right.candidates, | ||||||
|  |             documents_ids, | ||||||
|  |             document_scores, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> Search<'a> { | ||||||
|  |     pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> { | ||||||
|  |         // TODO: find classier way to achieve that than to reset vector and query params | ||||||
|  |         // create separate keyword and semantic searches | ||||||
|  |         let mut search = Search { | ||||||
|  |             query: self.query.clone(), | ||||||
|  |             vector: self.vector.clone(), | ||||||
|  |             filter: self.filter.clone(), | ||||||
|  |             offset: 0, | ||||||
|  |             limit: self.limit + self.offset, | ||||||
|  |             sort_criteria: self.sort_criteria.clone(), | ||||||
|  |             searchable_attributes: self.searchable_attributes, | ||||||
|  |             geo_strategy: self.geo_strategy, | ||||||
|  |             terms_matching_strategy: self.terms_matching_strategy, | ||||||
|  |             scoring_strategy: ScoringStrategy::Detailed, | ||||||
|  |             words_limit: self.words_limit, | ||||||
|  |             exhaustive_number_hits: self.exhaustive_number_hits, | ||||||
|  |             rtxn: self.rtxn, | ||||||
|  |             index: self.index, | ||||||
|  |             distribution_shift: self.distribution_shift, | ||||||
|  |             embedder_name: self.embedder_name.clone(), | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         let vector_query = search.vector.take(); | ||||||
|  |         let keyword_results = search.execute()?; | ||||||
|  |  | ||||||
|  |         // skip semantic search if we don't have a vector query (placeholder search) | ||||||
|  |         let Some(vector_query) = vector_query else { | ||||||
|  |             return Ok(keyword_results); | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         // completely skip semantic search if the results of the keyword search are good enough | ||||||
|  |         if self.results_good_enough(&keyword_results, semantic_ratio) { | ||||||
|  |             return Ok(keyword_results); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         search.vector = Some(vector_query); | ||||||
|  |         search.query = None; | ||||||
|  |  | ||||||
|  |         // TODO: would be better to have two distinct functions at this point | ||||||
|  |         let vector_results = search.execute()?; | ||||||
|  |  | ||||||
|  |         let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); | ||||||
|  |         let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); | ||||||
|  |  | ||||||
|  |         let merge_results = | ||||||
|  |             ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); | ||||||
|  |         assert!(merge_results.documents_ids.len() <= self.limit); | ||||||
|  |         Ok(merge_results) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool { | ||||||
|  |         // A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5 | ||||||
|  |         const GOOD_ENOUGH_SCORE: f64 = 0.45; | ||||||
|  |  | ||||||
|  |         // 1. we check that we got a sufficient number of results | ||||||
|  |         if keyword_results.document_scores.len() < self.limit + self.offset { | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         // 2. and that all results have a good enough score. | ||||||
|  |         // we need to check all results because due to sort like rules, they're not necessarily in relevancy order | ||||||
|  |         for score in &keyword_results.document_scores { | ||||||
|  |             let score = ScoreDetails::global_score(score.iter()); | ||||||
|  |             if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE { | ||||||
|  |                 return false; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         true | ||||||
|  |     } | ||||||
|  | } | ||||||
| @@ -12,12 +12,14 @@ use roaring::bitmap::RoaringBitmap; | |||||||
|  |  | ||||||
| pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; | pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; | ||||||
| pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; | pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; | ||||||
| use self::new::PartialSearchResult; | use self::new::{execute_vector_search, PartialSearchResult}; | ||||||
| use crate::error::UserError; | use crate::error::UserError; | ||||||
| use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; | use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; | ||||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||||
|  | use crate::vector::DistributionShift; | ||||||
| use crate::{ | use crate::{ | ||||||
|     execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext, |     execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, | ||||||
|  |     Result, SearchContext, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| // Building these factories is not free. | // Building these factories is not free. | ||||||
| @@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100; | |||||||
|  |  | ||||||
| pub mod facet; | pub mod facet; | ||||||
| mod fst_utils; | mod fst_utils; | ||||||
|  | pub mod hybrid; | ||||||
| pub mod new; | pub mod new; | ||||||
|  |  | ||||||
| pub struct Search<'a> { | pub struct Search<'a> { | ||||||
| @@ -46,8 +49,11 @@ pub struct Search<'a> { | |||||||
|     scoring_strategy: ScoringStrategy, |     scoring_strategy: ScoringStrategy, | ||||||
|     words_limit: usize, |     words_limit: usize, | ||||||
|     exhaustive_number_hits: bool, |     exhaustive_number_hits: bool, | ||||||
|  |     /// TODO: Add semantic ratio or pass it directly to execute_hybrid() | ||||||
|     rtxn: &'a heed::RoTxn<'a>, |     rtxn: &'a heed::RoTxn<'a>, | ||||||
|     index: &'a Index, |     index: &'a Index, | ||||||
|  |     distribution_shift: Option<DistributionShift>, | ||||||
|  |     embedder_name: Option<String>, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<'a> Search<'a> { | impl<'a> Search<'a> { | ||||||
| @@ -67,6 +73,8 @@ impl<'a> Search<'a> { | |||||||
|             words_limit: 10, |             words_limit: 10, | ||||||
|             rtxn, |             rtxn, | ||||||
|             index, |             index, | ||||||
|  |             distribution_shift: None, | ||||||
|  |             embedder_name: None, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -75,8 +83,8 @@ impl<'a> Search<'a> { | |||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> { |     pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> { | ||||||
|         self.vector = Some(vector.into()); |         self.vector = Some(vector); | ||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -133,22 +141,66 @@ impl<'a> Search<'a> { | |||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn distribution_shift( | ||||||
|  |         &mut self, | ||||||
|  |         distribution_shift: Option<DistributionShift>, | ||||||
|  |     ) -> &mut Search<'a> { | ||||||
|  |         self.distribution_shift = distribution_shift; | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> { | ||||||
|  |         self.embedder_name = Some(embedder_name.into()); | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> { | ||||||
|  |         if has_vector_search { | ||||||
|  |             let ctx = SearchContext::new(self.index, self.rtxn); | ||||||
|  |             filtered_universe(&ctx, &self.filter) | ||||||
|  |         } else { | ||||||
|  |             Ok(self.execute()?.candidates) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn execute(&self) -> Result<SearchResult> { |     pub fn execute(&self) -> Result<SearchResult> { | ||||||
|  |         let embedder_name; | ||||||
|  |         let embedder_name = match &self.embedder_name { | ||||||
|  |             Some(embedder_name) => embedder_name, | ||||||
|  |             None => { | ||||||
|  |                 embedder_name = self.index.default_embedding_name(self.rtxn)?; | ||||||
|  |                 &embedder_name | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |  | ||||||
|         let mut ctx = SearchContext::new(self.index, self.rtxn); |         let mut ctx = SearchContext::new(self.index, self.rtxn); | ||||||
|  |  | ||||||
|         if let Some(searchable_attributes) = self.searchable_attributes { |         if let Some(searchable_attributes) = self.searchable_attributes { | ||||||
|             ctx.searchable_attributes(searchable_attributes)?; |             ctx.searchable_attributes(searchable_attributes)?; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         let universe = filtered_universe(&ctx, &self.filter)?; | ||||||
|         let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } = |         let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } = | ||||||
|             execute_search( |             match self.vector.as_ref() { | ||||||
|  |                 Some(vector) => execute_vector_search( | ||||||
|                     &mut ctx, |                     &mut ctx, | ||||||
|                 &self.query, |                     vector, | ||||||
|                 &self.vector, |                     self.scoring_strategy, | ||||||
|  |                     universe, | ||||||
|  |                     &self.sort_criteria, | ||||||
|  |                     self.geo_strategy, | ||||||
|  |                     self.offset, | ||||||
|  |                     self.limit, | ||||||
|  |                     self.distribution_shift, | ||||||
|  |                     embedder_name, | ||||||
|  |                 )?, | ||||||
|  |                 None => execute_search( | ||||||
|  |                     &mut ctx, | ||||||
|  |                     self.query.as_deref(), | ||||||
|                     self.terms_matching_strategy, |                     self.terms_matching_strategy, | ||||||
|                     self.scoring_strategy, |                     self.scoring_strategy, | ||||||
|                     self.exhaustive_number_hits, |                     self.exhaustive_number_hits, | ||||||
|                 &self.filter, |                     universe, | ||||||
|                     &self.sort_criteria, |                     &self.sort_criteria, | ||||||
|                     self.geo_strategy, |                     self.geo_strategy, | ||||||
|                     self.offset, |                     self.offset, | ||||||
| @@ -156,7 +208,8 @@ impl<'a> Search<'a> { | |||||||
|                     Some(self.words_limit), |                     Some(self.words_limit), | ||||||
|                     &mut DefaultSearchLogger, |                     &mut DefaultSearchLogger, | ||||||
|                     &mut DefaultSearchLogger, |                     &mut DefaultSearchLogger, | ||||||
|             )?; |                 )?, | ||||||
|  |             }; | ||||||
|  |  | ||||||
|         // consume context and located_query_terms to build MatchingWords. |         // consume context and located_query_terms to build MatchingWords. | ||||||
|         let matching_words = match located_query_terms { |         let matching_words = match located_query_terms { | ||||||
| @@ -185,6 +238,8 @@ impl fmt::Debug for Search<'_> { | |||||||
|             exhaustive_number_hits, |             exhaustive_number_hits, | ||||||
|             rtxn: _, |             rtxn: _, | ||||||
|             index: _, |             index: _, | ||||||
|  |             distribution_shift, | ||||||
|  |             embedder_name, | ||||||
|         } = self; |         } = self; | ||||||
|         f.debug_struct("Search") |         f.debug_struct("Search") | ||||||
|             .field("query", query) |             .field("query", query) | ||||||
| @@ -198,6 +253,8 @@ impl fmt::Debug for Search<'_> { | |||||||
|             .field("scoring_strategy", scoring_strategy) |             .field("scoring_strategy", scoring_strategy) | ||||||
|             .field("exhaustive_number_hits", exhaustive_number_hits) |             .field("exhaustive_number_hits", exhaustive_number_hits) | ||||||
|             .field("words_limit", words_limit) |             .field("words_limit", words_limit) | ||||||
|  |             .field("distribution_shift", distribution_shift) | ||||||
|  |             .field("embedder_name", embedder_name) | ||||||
|             .finish() |             .finish() | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -249,11 +306,16 @@ pub struct SearchForFacetValues<'a> { | |||||||
|     query: Option<String>, |     query: Option<String>, | ||||||
|     facet: String, |     facet: String, | ||||||
|     search_query: Search<'a>, |     search_query: Search<'a>, | ||||||
|  |     is_hybrid: bool, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<'a> SearchForFacetValues<'a> { | impl<'a> SearchForFacetValues<'a> { | ||||||
|     pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> { |     pub fn new( | ||||||
|         SearchForFacetValues { query: None, facet, search_query } |         facet: String, | ||||||
|  |         search_query: Search<'a>, | ||||||
|  |         is_hybrid: bool, | ||||||
|  |     ) -> SearchForFacetValues<'a> { | ||||||
|  |         SearchForFacetValues { query: None, facet, search_query, is_hybrid } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn query(&mut self, query: impl Into<String>) -> &mut Self { |     pub fn query(&mut self, query: impl Into<String>) -> &mut Self { | ||||||
| @@ -303,7 +365,9 @@ impl<'a> SearchForFacetValues<'a> { | |||||||
|             None => return Ok(vec![]), |             None => return Ok(vec![]), | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         let search_candidates = self.search_query.execute()?.candidates; |         let search_candidates = self | ||||||
|  |             .search_query | ||||||
|  |             .execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?; | ||||||
|  |  | ||||||
|         match self.query.as_ref() { |         match self.query.as_ref() { | ||||||
|             Some(query) => { |             Some(query) => { | ||||||
|   | |||||||
| @@ -107,12 +107,16 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> { | |||||||
|  |  | ||||||
|     /// Refill the internal buffer of cached docids based on the strategy. |     /// Refill the internal buffer of cached docids based on the strategy. | ||||||
|     /// Drop the rtree if we don't need it anymore. |     /// Drop the rtree if we don't need it anymore. | ||||||
|     fn fill_buffer(&mut self, ctx: &mut SearchContext) -> Result<()> { |     fn fill_buffer( | ||||||
|  |         &mut self, | ||||||
|  |         ctx: &mut SearchContext, | ||||||
|  |         geo_candidates: &RoaringBitmap, | ||||||
|  |     ) -> Result<()> { | ||||||
|         debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng"); |         debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng"); | ||||||
|         debug_assert!(self.cached_sorted_docids.is_empty()); |         debug_assert!(self.cached_sorted_docids.is_empty()); | ||||||
|  |  | ||||||
|         // lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree` |         // lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree` | ||||||
|         let rtree = if self.strategy.use_rtree(self.geo_candidates.len() as usize) { |         let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) { | ||||||
|             if let Some(rtree) = self.rtree.as_ref() { |             if let Some(rtree) = self.rtree.as_ref() { | ||||||
|                 // get rtree from cache |                 // get rtree from cache | ||||||
|                 Some(rtree) |                 Some(rtree) | ||||||
| @@ -131,7 +135,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> { | |||||||
|             if self.ascending { |             if self.ascending { | ||||||
|                 let point = lat_lng_to_xyz(&self.point); |                 let point = lat_lng_to_xyz(&self.point); | ||||||
|                 for point in rtree.nearest_neighbor_iter(&point) { |                 for point in rtree.nearest_neighbor_iter(&point) { | ||||||
|                     if self.geo_candidates.contains(point.data.0) { |                     if geo_candidates.contains(point.data.0) { | ||||||
|                         self.cached_sorted_docids.push_back(point.data); |                         self.cached_sorted_docids.push_back(point.data); | ||||||
|                         if self.cached_sorted_docids.len() >= cache_size { |                         if self.cached_sorted_docids.len() >= cache_size { | ||||||
|                             break; |                             break; | ||||||
| @@ -143,7 +147,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> { | |||||||
|                 // and we insert the points in reverse order they get reversed when emptying the cache later on |                 // and we insert the points in reverse order they get reversed when emptying the cache later on | ||||||
|                 let point = lat_lng_to_xyz(&opposite_of(self.point)); |                 let point = lat_lng_to_xyz(&opposite_of(self.point)); | ||||||
|                 for point in rtree.nearest_neighbor_iter(&point) { |                 for point in rtree.nearest_neighbor_iter(&point) { | ||||||
|                     if self.geo_candidates.contains(point.data.0) { |                     if geo_candidates.contains(point.data.0) { | ||||||
|                         self.cached_sorted_docids.push_front(point.data); |                         self.cached_sorted_docids.push_front(point.data); | ||||||
|                         if self.cached_sorted_docids.len() >= cache_size { |                         if self.cached_sorted_docids.len() >= cache_size { | ||||||
|                             break; |                             break; | ||||||
| @@ -155,8 +159,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> { | |||||||
|             // the iterative version |             // the iterative version | ||||||
|             let [lat, lng] = self.field_ids.unwrap(); |             let [lat, lng] = self.field_ids.unwrap(); | ||||||
|  |  | ||||||
|             let mut documents = self |             let mut documents = geo_candidates | ||||||
|                 .geo_candidates |  | ||||||
|                 .iter() |                 .iter() | ||||||
|                 .map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) }) |                 .map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) }) | ||||||
|                 .collect::<Result<Vec<(u32, [f64; 2])>>>()?; |                 .collect::<Result<Vec<(u32, [f64; 2])>>>()?; | ||||||
| @@ -216,9 +219,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | |||||||
|         assert!(self.query.is_none()); |         assert!(self.query.is_none()); | ||||||
|  |  | ||||||
|         self.query = Some(query.clone()); |         self.query = Some(query.clone()); | ||||||
|         self.geo_candidates &= universe; |  | ||||||
|  |  | ||||||
|         if self.geo_candidates.is_empty() { |         let geo_candidates = &self.geo_candidates & universe; | ||||||
|  |  | ||||||
|  |         if geo_candidates.is_empty() { | ||||||
|             return Ok(()); |             return Ok(()); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -226,7 +230,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | |||||||
|         let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat"); |         let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat"); | ||||||
|         let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng"); |         let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng"); | ||||||
|         self.field_ids = Some([lat, lng]); |         self.field_ids = Some([lat, lng]); | ||||||
|         self.fill_buffer(ctx)?; |         self.fill_buffer(ctx, &geo_candidates)?; | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -238,9 +242,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | |||||||
|         universe: &RoaringBitmap, |         universe: &RoaringBitmap, | ||||||
|     ) -> Result<Option<RankingRuleOutput<Q>>> { |     ) -> Result<Option<RankingRuleOutput<Q>>> { | ||||||
|         let query = self.query.as_ref().unwrap().clone(); |         let query = self.query.as_ref().unwrap().clone(); | ||||||
|         self.geo_candidates &= universe; |  | ||||||
|  |  | ||||||
|         if self.geo_candidates.is_empty() { |         let geo_candidates = &self.geo_candidates & universe; | ||||||
|  |  | ||||||
|  |         if geo_candidates.is_empty() { | ||||||
|             return Ok(Some(RankingRuleOutput { |             return Ok(Some(RankingRuleOutput { | ||||||
|                 query, |                 query, | ||||||
|                 candidates: universe.clone(), |                 candidates: universe.clone(), | ||||||
| @@ -261,7 +266,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | |||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
|         while let Some((id, point)) = next(&mut self.cached_sorted_docids) { |         while let Some((id, point)) = next(&mut self.cached_sorted_docids) { | ||||||
|             if self.geo_candidates.contains(id) { |             if geo_candidates.contains(id) { | ||||||
|                 return Ok(Some(RankingRuleOutput { |                 return Ok(Some(RankingRuleOutput { | ||||||
|                     query, |                     query, | ||||||
|                     candidates: RoaringBitmap::from_iter([id]), |                     candidates: RoaringBitmap::from_iter([id]), | ||||||
| @@ -276,7 +281,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | |||||||
|  |  | ||||||
|         // if we got out of this loop it means we've exhausted our cache. |         // if we got out of this loop it means we've exhausted our cache. | ||||||
|         // we need to refill it and run the function again. |         // we need to refill it and run the function again. | ||||||
|         self.fill_buffer(ctx)?; |         self.fill_buffer(ctx, &geo_candidates)?; | ||||||
|         self.next_bucket(ctx, logger, universe) |         self.next_bucket(ctx, logger, universe) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -498,19 +498,19 @@ mod tests { | |||||||
|  |  | ||||||
|     use super::*; |     use super::*; | ||||||
|     use crate::index::tests::TempIndex; |     use crate::index::tests::TempIndex; | ||||||
|     use crate::{execute_search, SearchContext}; |     use crate::{execute_search, filtered_universe, SearchContext}; | ||||||
|  |  | ||||||
|     impl<'a> MatcherBuilder<'a> { |     impl<'a> MatcherBuilder<'a> { | ||||||
|         fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { |         fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { | ||||||
|             let mut ctx = SearchContext::new(index, rtxn); |             let mut ctx = SearchContext::new(index, rtxn); | ||||||
|  |             let universe = filtered_universe(&ctx, &None).unwrap(); | ||||||
|             let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( |             let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( | ||||||
|                 &mut ctx, |                 &mut ctx, | ||||||
|                 &Some(query.to_string()), |                 Some(query), | ||||||
|                 &None, |  | ||||||
|                 crate::TermsMatchingStrategy::default(), |                 crate::TermsMatchingStrategy::default(), | ||||||
|                 crate::score_details::ScoringStrategy::Skip, |                 crate::score_details::ScoringStrategy::Skip, | ||||||
|                 false, |                 false, | ||||||
|                 &None, |                 universe, | ||||||
|                 &None, |                 &None, | ||||||
|                 crate::search::new::GeoSortStrategy::default(), |                 crate::search::new::GeoSortStrategy::default(), | ||||||
|                 0, |                 0, | ||||||
|   | |||||||
| @@ -16,6 +16,7 @@ mod small_bitmap; | |||||||
|  |  | ||||||
| mod exact_attribute; | mod exact_attribute; | ||||||
| mod sort; | mod sort; | ||||||
|  | mod vector_sort; | ||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests; | mod tests; | ||||||
| @@ -28,7 +29,6 @@ use db_cache::DatabaseCache; | |||||||
| use exact_attribute::ExactAttribute; | use exact_attribute::ExactAttribute; | ||||||
| use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; | use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; | ||||||
| use heed::RoTxn; | use heed::RoTxn; | ||||||
| use instant_distance::Search; |  | ||||||
| use interner::{DedupInterner, Interner}; | use interner::{DedupInterner, Interner}; | ||||||
| pub use logger::visual::VisualSearchLogger; | pub use logger::visual::VisualSearchLogger; | ||||||
| pub use logger::{DefaultSearchLogger, SearchLogger}; | pub use logger::{DefaultSearchLogger, SearchLogger}; | ||||||
| @@ -46,10 +46,11 @@ use self::geo_sort::GeoSort; | |||||||
| pub use self::geo_sort::Strategy as GeoSortStrategy; | pub use self::geo_sort::Strategy as GeoSortStrategy; | ||||||
| use self::graph_based_ranking_rule::Words; | use self::graph_based_ranking_rule::Words; | ||||||
| use self::interner::Interned; | use self::interner::Interned; | ||||||
| use crate::distance::NDotProductPoint; | use self::vector_sort::VectorSort; | ||||||
| use crate::error::FieldIdMapMissingEntry; | use crate::error::FieldIdMapMissingEntry; | ||||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||||
| use crate::search::new::distinct::apply_distinct_rule; | use crate::search::new::distinct::apply_distinct_rule; | ||||||
|  | use crate::vector::DistributionShift; | ||||||
| use crate::{ | use crate::{ | ||||||
|     AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, |     AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, | ||||||
| }; | }; | ||||||
| @@ -258,6 +259,80 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( | |||||||
|     Ok(ranking_rules) |     Ok(ranking_rules) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | fn get_ranking_rules_for_vector<'ctx>( | ||||||
|  |     ctx: &SearchContext<'ctx>, | ||||||
|  |     sort_criteria: &Option<Vec<AscDesc>>, | ||||||
|  |     geo_strategy: geo_sort::Strategy, | ||||||
|  |     limit_plus_offset: usize, | ||||||
|  |     target: &[f32], | ||||||
|  |     distribution_shift: Option<DistributionShift>, | ||||||
|  |     embedder_name: &str, | ||||||
|  | ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> { | ||||||
|  |     // query graph search | ||||||
|  |  | ||||||
|  |     let mut sort = false; | ||||||
|  |     let mut sorted_fields = HashSet::new(); | ||||||
|  |     let mut geo_sorted = false; | ||||||
|  |  | ||||||
|  |     let mut vector = false; | ||||||
|  |     let mut ranking_rules: Vec<BoxRankingRule<PlaceholderQuery>> = vec![]; | ||||||
|  |  | ||||||
|  |     let settings_ranking_rules = ctx.index.criteria(ctx.txn)?; | ||||||
|  |     for rr in settings_ranking_rules { | ||||||
|  |         match rr { | ||||||
|  |             crate::Criterion::Words | ||||||
|  |             | crate::Criterion::Typo | ||||||
|  |             | crate::Criterion::Proximity | ||||||
|  |             | crate::Criterion::Attribute | ||||||
|  |             | crate::Criterion::Exactness => { | ||||||
|  |                 if !vector { | ||||||
|  |                     let vector_candidates = ctx.index.documents_ids(ctx.txn)?; | ||||||
|  |                     let vector_sort = VectorSort::new( | ||||||
|  |                         ctx, | ||||||
|  |                         target.to_vec(), | ||||||
|  |                         vector_candidates, | ||||||
|  |                         limit_plus_offset, | ||||||
|  |                         distribution_shift, | ||||||
|  |                         embedder_name, | ||||||
|  |                     )?; | ||||||
|  |                     ranking_rules.push(Box::new(vector_sort)); | ||||||
|  |                     vector = true; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             crate::Criterion::Sort => { | ||||||
|  |                 if sort { | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |                 resolve_sort_criteria( | ||||||
|  |                     sort_criteria, | ||||||
|  |                     ctx, | ||||||
|  |                     &mut ranking_rules, | ||||||
|  |                     &mut sorted_fields, | ||||||
|  |                     &mut geo_sorted, | ||||||
|  |                     geo_strategy, | ||||||
|  |                 )?; | ||||||
|  |                 sort = true; | ||||||
|  |             } | ||||||
|  |             crate::Criterion::Asc(field_name) => { | ||||||
|  |                 if sorted_fields.contains(&field_name) { | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |                 sorted_fields.insert(field_name.clone()); | ||||||
|  |                 ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?)); | ||||||
|  |             } | ||||||
|  |             crate::Criterion::Desc(field_name) => { | ||||||
|  |                 if sorted_fields.contains(&field_name) { | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |                 sorted_fields.insert(field_name.clone()); | ||||||
|  |                 ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     Ok(ranking_rules) | ||||||
|  | } | ||||||
|  |  | ||||||
| /// Return the list of initialised ranking rules to be used for a query graph search. | /// Return the list of initialised ranking rules to be used for a query graph search. | ||||||
| fn get_ranking_rules_for_query_graph_search<'ctx>( | fn get_ranking_rules_for_query_graph_search<'ctx>( | ||||||
|     ctx: &SearchContext<'ctx>, |     ctx: &SearchContext<'ctx>, | ||||||
| @@ -422,15 +497,72 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( | |||||||
|     Ok(()) |     Ok(()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | pub fn filtered_universe(ctx: &SearchContext, filters: &Option<Filter>) -> Result<RoaringBitmap> { | ||||||
|  |     Ok(if let Some(filters) = filters { | ||||||
|  |         filters.evaluate(ctx.txn, ctx.index)? | ||||||
|  |     } else { | ||||||
|  |         ctx.index.documents_ids(ctx.txn)? | ||||||
|  |     }) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[allow(clippy::too_many_arguments)] | ||||||
|  | pub fn execute_vector_search( | ||||||
|  |     ctx: &mut SearchContext, | ||||||
|  |     vector: &[f32], | ||||||
|  |     scoring_strategy: ScoringStrategy, | ||||||
|  |     universe: RoaringBitmap, | ||||||
|  |     sort_criteria: &Option<Vec<AscDesc>>, | ||||||
|  |     geo_strategy: geo_sort::Strategy, | ||||||
|  |     from: usize, | ||||||
|  |     length: usize, | ||||||
|  |     distribution_shift: Option<DistributionShift>, | ||||||
|  |     embedder_name: &str, | ||||||
|  | ) -> Result<PartialSearchResult> { | ||||||
|  |     check_sort_criteria(ctx, sort_criteria.as_ref())?; | ||||||
|  |  | ||||||
|  |     // FIXME: input universe = universe & documents_with_vectors | ||||||
|  |     // for now if we're computing embeddings for ALL documents, we can assume that this is just universe | ||||||
|  |     let ranking_rules = get_ranking_rules_for_vector( | ||||||
|  |         ctx, | ||||||
|  |         sort_criteria, | ||||||
|  |         geo_strategy, | ||||||
|  |         from + length, | ||||||
|  |         vector, | ||||||
|  |         distribution_shift, | ||||||
|  |         embedder_name, | ||||||
|  |     )?; | ||||||
|  |  | ||||||
|  |     let mut placeholder_search_logger = logger::DefaultSearchLogger; | ||||||
|  |     let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> = | ||||||
|  |         &mut placeholder_search_logger; | ||||||
|  |  | ||||||
|  |     let BucketSortOutput { docids, scores, all_candidates } = bucket_sort( | ||||||
|  |         ctx, | ||||||
|  |         ranking_rules, | ||||||
|  |         &PlaceholderQuery, | ||||||
|  |         &universe, | ||||||
|  |         from, | ||||||
|  |         length, | ||||||
|  |         scoring_strategy, | ||||||
|  |         placeholder_search_logger, | ||||||
|  |     )?; | ||||||
|  |  | ||||||
|  |     Ok(PartialSearchResult { | ||||||
|  |         candidates: all_candidates, | ||||||
|  |         document_scores: scores, | ||||||
|  |         documents_ids: docids, | ||||||
|  |         located_query_terms: None, | ||||||
|  |     }) | ||||||
|  | } | ||||||
|  |  | ||||||
| #[allow(clippy::too_many_arguments)] | #[allow(clippy::too_many_arguments)] | ||||||
| pub fn execute_search( | pub fn execute_search( | ||||||
|     ctx: &mut SearchContext, |     ctx: &mut SearchContext, | ||||||
|     query: &Option<String>, |     query: Option<&str>, | ||||||
|     vector: &Option<Vec<f32>>, |  | ||||||
|     terms_matching_strategy: TermsMatchingStrategy, |     terms_matching_strategy: TermsMatchingStrategy, | ||||||
|     scoring_strategy: ScoringStrategy, |     scoring_strategy: ScoringStrategy, | ||||||
|     exhaustive_number_hits: bool, |     exhaustive_number_hits: bool, | ||||||
|     filters: &Option<Filter>, |     mut universe: RoaringBitmap, | ||||||
|     sort_criteria: &Option<Vec<AscDesc>>, |     sort_criteria: &Option<Vec<AscDesc>>, | ||||||
|     geo_strategy: geo_sort::Strategy, |     geo_strategy: geo_sort::Strategy, | ||||||
|     from: usize, |     from: usize, | ||||||
| @@ -439,60 +571,8 @@ pub fn execute_search( | |||||||
|     placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>, |     placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>, | ||||||
|     query_graph_logger: &mut dyn SearchLogger<QueryGraph>, |     query_graph_logger: &mut dyn SearchLogger<QueryGraph>, | ||||||
| ) -> Result<PartialSearchResult> { | ) -> Result<PartialSearchResult> { | ||||||
|     let mut universe = if let Some(filters) = filters { |  | ||||||
|         filters.evaluate(ctx.txn, ctx.index)? |  | ||||||
|     } else { |  | ||||||
|         ctx.index.documents_ids(ctx.txn)? |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     check_sort_criteria(ctx, sort_criteria.as_ref())?; |     check_sort_criteria(ctx, sort_criteria.as_ref())?; | ||||||
|  |  | ||||||
|     if let Some(vector) = vector { |  | ||||||
|         let mut search = Search::default(); |  | ||||||
|         let docids = match ctx.index.vector_hnsw(ctx.txn)? { |  | ||||||
|             Some(hnsw) => { |  | ||||||
|                 if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { |  | ||||||
|                     if vector.len() != expected_size { |  | ||||||
|                         return Err(UserError::InvalidVectorDimensions { |  | ||||||
|                             expected: expected_size, |  | ||||||
|                             found: vector.len(), |  | ||||||
|                         } |  | ||||||
|                         .into()); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 let vector = NDotProductPoint::new(vector.clone()); |  | ||||||
|  |  | ||||||
|                 let neighbors = hnsw.search(&vector, &mut search); |  | ||||||
|  |  | ||||||
|                 let mut docids = Vec::new(); |  | ||||||
|                 let mut uniq_docids = RoaringBitmap::new(); |  | ||||||
|                 for instant_distance::Item { distance: _, pid, point: _ } in neighbors { |  | ||||||
|                     let index = pid.into_inner(); |  | ||||||
|                     let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); |  | ||||||
|                     if universe.contains(docid) && uniq_docids.insert(docid) { |  | ||||||
|                         docids.push(docid); |  | ||||||
|                         if docids.len() == (from + length) { |  | ||||||
|                             break; |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 // return the nearest documents that are also part of the candidates |  | ||||||
|                 // along with a dummy list of scores that are useless in this context. |  | ||||||
|                 docids.into_iter().skip(from).take(length).collect() |  | ||||||
|             } |  | ||||||
|             None => Vec::new(), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         return Ok(PartialSearchResult { |  | ||||||
|             candidates: universe, |  | ||||||
|             document_scores: vec![Vec::new(); docids.len()], |  | ||||||
|             documents_ids: docids, |  | ||||||
|             located_query_terms: None, |  | ||||||
|         }); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let mut located_query_terms = None; |     let mut located_query_terms = None; | ||||||
|     let query_terms = if let Some(query) = query { |     let query_terms = if let Some(query) = query { | ||||||
|         // We make sure that the analyzer is aware of the stop words |         // We make sure that the analyzer is aware of the stop words | ||||||
| @@ -546,7 +626,7 @@ pub fn execute_search( | |||||||
|             terms_matching_strategy, |             terms_matching_strategy, | ||||||
|         )?; |         )?; | ||||||
|  |  | ||||||
|         universe = |         universe &= | ||||||
|             resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?; |             resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?; | ||||||
|  |  | ||||||
|         bucket_sort( |         bucket_sort( | ||||||
|   | |||||||
							
								
								
									
										170
									
								
								milli/src/search/new/vector_sort.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										170
									
								
								milli/src/search/new/vector_sort.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,170 @@ | |||||||
|  | use std::iter::FromIterator; | ||||||
|  |  | ||||||
|  | use ordered_float::OrderedFloat; | ||||||
|  | use roaring::RoaringBitmap; | ||||||
|  |  | ||||||
|  | use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; | ||||||
|  | use crate::score_details::{self, ScoreDetails}; | ||||||
|  | use crate::vector::DistributionShift; | ||||||
|  | use crate::{DocumentId, Result, SearchContext, SearchLogger}; | ||||||
|  |  | ||||||
|  | pub struct VectorSort<Q: RankingRuleQueryTrait> { | ||||||
|  |     query: Option<Q>, | ||||||
|  |     target: Vec<f32>, | ||||||
|  |     vector_candidates: RoaringBitmap, | ||||||
|  |     cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>, | ||||||
|  |     limit: usize, | ||||||
|  |     distribution_shift: Option<DistributionShift>, | ||||||
|  |     embedder_index: u8, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<Q: RankingRuleQueryTrait> VectorSort<Q> { | ||||||
|  |     pub fn new( | ||||||
|  |         ctx: &SearchContext, | ||||||
|  |         target: Vec<f32>, | ||||||
|  |         vector_candidates: RoaringBitmap, | ||||||
|  |         limit: usize, | ||||||
|  |         distribution_shift: Option<DistributionShift>, | ||||||
|  |         embedder_name: &str, | ||||||
|  |     ) -> Result<Self> { | ||||||
|  |         let embedder_index = ctx | ||||||
|  |             .index | ||||||
|  |             .embedder_category_id | ||||||
|  |             .get(ctx.txn, embedder_name)? | ||||||
|  |             .ok_or_else(|| crate::UserError::InvalidEmbedder(embedder_name.to_owned()))?; | ||||||
|  |  | ||||||
|  |         Ok(Self { | ||||||
|  |             query: None, | ||||||
|  |             target, | ||||||
|  |             vector_candidates, | ||||||
|  |             cached_sorted_docids: Default::default(), | ||||||
|  |             limit, | ||||||
|  |             distribution_shift, | ||||||
|  |             embedder_index, | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn fill_buffer( | ||||||
|  |         &mut self, | ||||||
|  |         ctx: &mut SearchContext<'_>, | ||||||
|  |         vector_candidates: &RoaringBitmap, | ||||||
|  |     ) -> Result<()> { | ||||||
|  |         let writer_index = (self.embedder_index as u16) << 8; | ||||||
|  |         let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX) | ||||||
|  |             .map_while(|k| { | ||||||
|  |                 arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy) | ||||||
|  |                     .map(Some) | ||||||
|  |                     .or_else(|e| match e { | ||||||
|  |                         arroy::Error::MissingMetadata => Ok(None), | ||||||
|  |                         e => Err(e), | ||||||
|  |                     }) | ||||||
|  |                     .transpose() | ||||||
|  |             }) | ||||||
|  |             .collect(); | ||||||
|  |  | ||||||
|  |         let readers = readers?; | ||||||
|  |  | ||||||
|  |         let target = &self.target; | ||||||
|  |         let mut results = Vec::new(); | ||||||
|  |  | ||||||
|  |         for reader in readers.iter() { | ||||||
|  |             let nns_by_vector = | ||||||
|  |                 reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?; | ||||||
|  |             let vectors: std::result::Result<Vec<_>, _> = nns_by_vector | ||||||
|  |                 .iter() | ||||||
|  |                 .map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap()) | ||||||
|  |                 .collect(); | ||||||
|  |             let vectors = vectors?; | ||||||
|  |             results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z))); | ||||||
|  |         } | ||||||
|  |         results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance)); | ||||||
|  |         self.cached_sorted_docids = results.into_iter(); | ||||||
|  |  | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> { | ||||||
|  |     fn id(&self) -> String { | ||||||
|  |         "vector_sort".to_owned() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn start_iteration( | ||||||
|  |         &mut self, | ||||||
|  |         ctx: &mut SearchContext<'ctx>, | ||||||
|  |         _logger: &mut dyn SearchLogger<Q>, | ||||||
|  |         universe: &RoaringBitmap, | ||||||
|  |         query: &Q, | ||||||
|  |     ) -> Result<()> { | ||||||
|  |         assert!(self.query.is_none()); | ||||||
|  |  | ||||||
|  |         self.query = Some(query.clone()); | ||||||
|  |         let vector_candidates = &self.vector_candidates & universe; | ||||||
|  |         self.fill_buffer(ctx, &vector_candidates)?; | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     #[allow(clippy::only_used_in_recursion)] | ||||||
|  |     fn next_bucket( | ||||||
|  |         &mut self, | ||||||
|  |         ctx: &mut SearchContext<'ctx>, | ||||||
|  |         _logger: &mut dyn SearchLogger<Q>, | ||||||
|  |         universe: &RoaringBitmap, | ||||||
|  |     ) -> Result<Option<RankingRuleOutput<Q>>> { | ||||||
|  |         let query = self.query.as_ref().unwrap().clone(); | ||||||
|  |         let vector_candidates = &self.vector_candidates & universe; | ||||||
|  |  | ||||||
|  |         if vector_candidates.is_empty() { | ||||||
|  |             return Ok(Some(RankingRuleOutput { | ||||||
|  |                 query, | ||||||
|  |                 candidates: universe.clone(), | ||||||
|  |                 score: ScoreDetails::Vector(score_details::Vector { | ||||||
|  |                     target_vector: self.target.clone(), | ||||||
|  |                     value_similarity: None, | ||||||
|  |                 }), | ||||||
|  |             })); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         for (docid, distance, vector) in self.cached_sorted_docids.by_ref() { | ||||||
|  |             if vector_candidates.contains(docid) { | ||||||
|  |                 let score = 1.0 - distance; | ||||||
|  |                 let score = self | ||||||
|  |                     .distribution_shift | ||||||
|  |                     .map(|distribution| distribution.shift(score)) | ||||||
|  |                     .unwrap_or(score); | ||||||
|  |                 return Ok(Some(RankingRuleOutput { | ||||||
|  |                     query, | ||||||
|  |                     candidates: RoaringBitmap::from_iter([docid]), | ||||||
|  |                     score: ScoreDetails::Vector(score_details::Vector { | ||||||
|  |                         target_vector: self.target.clone(), | ||||||
|  |                         value_similarity: Some((vector, score)), | ||||||
|  |                     }), | ||||||
|  |                 })); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         // if we got out of this loop it means we've exhausted our cache. | ||||||
|  |         // we need to refill it and run the function again. | ||||||
|  |         self.fill_buffer(ctx, &vector_candidates)?; | ||||||
|  |  | ||||||
|  |         // we tried filling the buffer, but it remained empty 😢 | ||||||
|  |         // it means we don't actually have any document remaining in the universe with a vector. | ||||||
|  |         // => exit | ||||||
|  |         if self.cached_sorted_docids.len() == 0 { | ||||||
|  |             return Ok(Some(RankingRuleOutput { | ||||||
|  |                 query, | ||||||
|  |                 candidates: universe.clone(), | ||||||
|  |                 score: ScoreDetails::Vector(score_details::Vector { | ||||||
|  |                     target_vector: self.target.clone(), | ||||||
|  |                     value_similarity: None, | ||||||
|  |                 }), | ||||||
|  |             })); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         self.next_bucket(ctx, _logger, universe) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) { | ||||||
|  |         self.query = None; | ||||||
|  |     } | ||||||
|  | } | ||||||
| @@ -42,7 +42,8 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { | |||||||
|             facet_id_is_empty_docids, |             facet_id_is_empty_docids, | ||||||
|             field_id_docid_facet_f64s, |             field_id_docid_facet_f64s, | ||||||
|             field_id_docid_facet_strings, |             field_id_docid_facet_strings, | ||||||
|             vector_id_docid, |             vector_arroy, | ||||||
|  |             embedder_category_id: _, | ||||||
|             documents, |             documents, | ||||||
|         } = self.index; |         } = self.index; | ||||||
|  |  | ||||||
| @@ -58,7 +59,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { | |||||||
|         self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; |         self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; | ||||||
|         self.index.delete_geo_rtree(self.wtxn)?; |         self.index.delete_geo_rtree(self.wtxn)?; | ||||||
|         self.index.delete_geo_faceted_documents_ids(self.wtxn)?; |         self.index.delete_geo_faceted_documents_ids(self.wtxn)?; | ||||||
|         self.index.delete_vector_hnsw(self.wtxn)?; |  | ||||||
|  |  | ||||||
|         // Clear the other databases. |         // Clear the other databases. | ||||||
|         external_documents_ids.clear(self.wtxn)?; |         external_documents_ids.clear(self.wtxn)?; | ||||||
| @@ -82,7 +82,9 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { | |||||||
|         facet_id_string_docids.clear(self.wtxn)?; |         facet_id_string_docids.clear(self.wtxn)?; | ||||||
|         field_id_docid_facet_f64s.clear(self.wtxn)?; |         field_id_docid_facet_f64s.clear(self.wtxn)?; | ||||||
|         field_id_docid_facet_strings.clear(self.wtxn)?; |         field_id_docid_facet_strings.clear(self.wtxn)?; | ||||||
|         vector_id_docid.clear(self.wtxn)?; |         // vector | ||||||
|  |         vector_arroy.clear(self.wtxn)?; | ||||||
|  |  | ||||||
|         documents.clear(self.wtxn)?; |         documents.clear(self.wtxn)?; | ||||||
|  |  | ||||||
|         Ok(number_of_documents) |         Ok(number_of_documents) | ||||||
|   | |||||||
| @@ -1,9 +1,10 @@ | |||||||
| use std::cmp::Ordering; | use std::cmp::Ordering; | ||||||
| use std::convert::TryFrom; | use std::convert::{TryFrom, TryInto}; | ||||||
| use std::fs::File; | use std::fs::File; | ||||||
| use std::io::{self, BufReader, BufWriter}; | use std::io::{self, BufReader, BufWriter}; | ||||||
| use std::mem::size_of; | use std::mem::size_of; | ||||||
| use std::str::from_utf8; | use std::str::from_utf8; | ||||||
|  | use std::sync::Arc; | ||||||
|  |  | ||||||
| use bytemuck::cast_slice; | use bytemuck::cast_slice; | ||||||
| use grenad::Writer; | use grenad::Writer; | ||||||
| @@ -13,13 +14,56 @@ use serde_json::{from_slice, Value}; | |||||||
|  |  | ||||||
| use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; | use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; | ||||||
| use crate::error::UserError; | use crate::error::UserError; | ||||||
|  | use crate::prompt::Prompt; | ||||||
| use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; | use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; | ||||||
| use crate::update::index_documents::helpers::try_split_at; | use crate::update::index_documents::helpers::try_split_at; | ||||||
| use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; | use crate::vector::Embedder; | ||||||
|  | use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors}; | ||||||
|  |  | ||||||
| /// The length of the elements that are always in the buffer when inserting new values. | /// The length of the elements that are always in the buffer when inserting new values. | ||||||
| const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); | const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); | ||||||
|  |  | ||||||
|  | pub struct ExtractedVectorPoints { | ||||||
|  |     // docid, _index -> KvWriterDelAdd -> Vector | ||||||
|  |     pub manual_vectors: grenad::Reader<BufReader<File>>, | ||||||
|  |     // docid -> () | ||||||
|  |     pub remove_vectors: grenad::Reader<BufReader<File>>, | ||||||
|  |     // docid -> prompt | ||||||
|  |     pub prompts: grenad::Reader<BufReader<File>>, | ||||||
|  | } | ||||||
|  |  | ||||||
/// How the vectors of one document change for a given embedder.
enum VectorStateDelta {
    /// Nothing to do for this document.
    NoChange,
    /// Remove all vectors, generated or manual, from this document.
    NowRemoved,

    /// Add the manually specified vectors, passed in the other grenad, and remove any
    /// previously generated vectors.
    ///
    /// Note: changing the value of the manually specified vector **should not record** this delta.
    WasGeneratedNowManual(Vec<Vec<f32>>),

    /// Manual vectors before and after the change: `(deleted, added)`.
    ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>),

    /// Add the vector computed from the specified prompt and remove any previous vector.
    ///
    /// Note: changing the value of the prompt **does require** recording this delta.
    NowGenerated(String),
}

impl VectorStateDelta {
    /// Flattens the delta into `(flag, prompt, (deleted_vectors, added_vectors))`.
    ///
    /// The flag is `true` for every variant documented as removing previously stored vectors
    /// (presumably it drives the removal; confirm against the caller). Parts that a variant
    /// does not carry are returned empty.
    fn into_values(self) -> (bool, String, (Vec<Vec<f32>>, Vec<Vec<f32>>)) {
        match self {
            Self::NoChange => (false, String::new(), (Vec::new(), Vec::new())),
            Self::NowRemoved => (true, String::new(), (Vec::new(), Vec::new())),
            Self::WasGeneratedNowManual(added) => (true, String::new(), (Vec::new(), added)),
            Self::ManualDelta(deleted, added) => (false, String::new(), (deleted, added)),
            Self::NowGenerated(prompt) => (true, prompt, (Vec::new(), Vec::new())),
        }
    }
}
|  |  | ||||||
| /// Extracts the embedding vector contained in each document under the `_vectors` field. | /// Extracts the embedding vector contained in each document under the `_vectors` field. | ||||||
| /// | /// | ||||||
| /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32> | /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32> | ||||||
| @@ -27,16 +71,35 @@ const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); | |||||||
| pub fn extract_vector_points<R: io::Read + io::Seek>( | pub fn extract_vector_points<R: io::Read + io::Seek>( | ||||||
|     obkv_documents: grenad::Reader<R>, |     obkv_documents: grenad::Reader<R>, | ||||||
|     indexer: GrenadParameters, |     indexer: GrenadParameters, | ||||||
|     vectors_fid: FieldId, |     field_id_map: &FieldsIdsMap, | ||||||
| ) -> Result<grenad::Reader<BufReader<File>>> { |     prompt: &Prompt, | ||||||
|  |     embedder_name: &str, | ||||||
|  | ) -> Result<ExtractedVectorPoints> { | ||||||
|     puffin::profile_function!(); |     puffin::profile_function!(); | ||||||
|  |  | ||||||
|     let mut writer = create_writer( |     // (docid, _index) -> KvWriterDelAdd -> Vector | ||||||
|  |     let mut manual_vectors_writer = create_writer( | ||||||
|         indexer.chunk_compression_type, |         indexer.chunk_compression_type, | ||||||
|         indexer.chunk_compression_level, |         indexer.chunk_compression_level, | ||||||
|         tempfile::tempfile()?, |         tempfile::tempfile()?, | ||||||
|     ); |     ); | ||||||
|  |  | ||||||
|  |     // (docid) -> (prompt) | ||||||
|  |     let mut prompts_writer = create_writer( | ||||||
|  |         indexer.chunk_compression_type, | ||||||
|  |         indexer.chunk_compression_level, | ||||||
|  |         tempfile::tempfile()?, | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     // (docid) -> () | ||||||
|  |     let mut remove_vectors_writer = create_writer( | ||||||
|  |         indexer.chunk_compression_type, | ||||||
|  |         indexer.chunk_compression_level, | ||||||
|  |         tempfile::tempfile()?, | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     let vectors_fid = field_id_map.id("_vectors"); | ||||||
|  |  | ||||||
|     let mut key_buffer = Vec::new(); |     let mut key_buffer = Vec::new(); | ||||||
|     let mut cursor = obkv_documents.into_cursor()?; |     let mut cursor = obkv_documents.into_cursor()?; | ||||||
|     while let Some((key, value)) = cursor.move_on_next()? { |     while let Some((key, value)) = cursor.move_on_next()? { | ||||||
| @@ -53,43 +116,157 @@ pub fn extract_vector_points<R: io::Read + io::Seek>( | |||||||
|         // lazily get it when needed |         // lazily get it when needed | ||||||
|         let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; |         let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; | ||||||
|  |  | ||||||
|         // first we retrieve the _vectors field |         let vectors_field = vectors_fid | ||||||
|         if let Some(value) = obkv.get(vectors_fid) { |             .and_then(|vectors_fid| obkv.get(vectors_fid)) | ||||||
|             let vectors_obkv = KvReaderDelAdd::new(value); |             .map(KvReaderDelAdd::new) | ||||||
|  |             .map(|obkv| to_vector_maps(obkv, document_id)) | ||||||
|  |             .transpose()?; | ||||||
|  |  | ||||||
|             // then we extract the values |         let (del_map, add_map) = vectors_field.unzip(); | ||||||
|             let del_vectors = vectors_obkv |         let del_map = del_map.flatten(); | ||||||
|                 .get(DelAdd::Deletion) |         let add_map = add_map.flatten(); | ||||||
|                 .map(|vectors| extract_vectors(vectors, document_id)) |  | ||||||
|                 .transpose()? |         let del_value = del_map.and_then(|mut map| map.remove(embedder_name)); | ||||||
|                 .flatten(); |         let add_value = add_map.and_then(|mut map| map.remove(embedder_name)); | ||||||
|             let add_vectors = vectors_obkv |  | ||||||
|                 .get(DelAdd::Addition) |         let delta = match (del_value, add_value) { | ||||||
|                 .map(|vectors| extract_vectors(vectors, document_id)) |             (Some(old), Some(new)) => { | ||||||
|                 .transpose()? |                 // no autogeneration | ||||||
|                 .flatten(); |                 let del_vectors = extract_vectors(old, document_id, embedder_name)?; | ||||||
|  |                 let add_vectors = extract_vectors(new, document_id, embedder_name)?; | ||||||
|  |  | ||||||
|  |                 if add_vectors.len() > u8::MAX.into() { | ||||||
|  |                     return Err(crate::Error::UserError(crate::UserError::TooManyVectors( | ||||||
|  |                         document_id().to_string(), | ||||||
|  |                         add_vectors.len(), | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 VectorStateDelta::ManualDelta(del_vectors, add_vectors) | ||||||
|  |             } | ||||||
|  |             (Some(_old), None) => { | ||||||
|  |                 // Do we keep this document? | ||||||
|  |                 let document_is_kept = obkv | ||||||
|  |                     .iter() | ||||||
|  |                     .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) | ||||||
|  |                     .any(|deladd| deladd.get(DelAdd::Addition).is_some()); | ||||||
|  |                 if document_is_kept { | ||||||
|  |                     // becomes autogenerated | ||||||
|  |                     VectorStateDelta::NowGenerated(prompt.render( | ||||||
|  |                         obkv, | ||||||
|  |                         DelAdd::Addition, | ||||||
|  |                         field_id_map, | ||||||
|  |                     )?) | ||||||
|  |                 } else { | ||||||
|  |                     VectorStateDelta::NowRemoved | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             (None, Some(new)) => { | ||||||
|  |                 // was possibly autogenerated, remove all vectors for that document | ||||||
|  |                 let add_vectors = extract_vectors(new, document_id, embedder_name)?; | ||||||
|  |                 if add_vectors.len() > u8::MAX.into() { | ||||||
|  |                     return Err(crate::Error::UserError(crate::UserError::TooManyVectors( | ||||||
|  |                         document_id().to_string(), | ||||||
|  |                         add_vectors.len(), | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 VectorStateDelta::WasGeneratedNowManual(add_vectors) | ||||||
|  |             } | ||||||
|  |             (None, None) => { | ||||||
|  |                 // Do we keep this document? | ||||||
|  |                 let document_is_kept = obkv | ||||||
|  |                     .iter() | ||||||
|  |                     .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) | ||||||
|  |                     .any(|deladd| deladd.get(DelAdd::Addition).is_some()); | ||||||
|  |  | ||||||
|  |                 if document_is_kept { | ||||||
|  |                     // Don't give up if the old prompt was failing | ||||||
|  |                     let old_prompt = | ||||||
|  |                         prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); | ||||||
|  |                     let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; | ||||||
|  |                     if old_prompt != new_prompt { | ||||||
|  |                         log::trace!( | ||||||
|  |                             "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" | ||||||
|  |                         ); | ||||||
|  |                         VectorStateDelta::NowGenerated(new_prompt) | ||||||
|  |                     } else { | ||||||
|  |                         log::trace!("⏭️ Prompt unmodified, skipping"); | ||||||
|  |                         VectorStateDelta::NoChange | ||||||
|  |                     } | ||||||
|  |                 } else { | ||||||
|  |                     VectorStateDelta::NowRemoved | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |  | ||||||
|         // and we finally push the unique vectors into the writer |         // and we finally push the unique vectors into the writer | ||||||
|         push_vectors_diff( |         push_vectors_diff( | ||||||
|                 &mut writer, |             &mut remove_vectors_writer, | ||||||
|  |             &mut prompts_writer, | ||||||
|  |             &mut manual_vectors_writer, | ||||||
|             &mut key_buffer, |             &mut key_buffer, | ||||||
|                 del_vectors.unwrap_or_default(), |             delta, | ||||||
|                 add_vectors.unwrap_or_default(), |  | ||||||
|         )?; |         )?; | ||||||
|     } |     } | ||||||
|     } |  | ||||||
|  |  | ||||||
|     writer_into_reader(writer) |     Ok(ExtractedVectorPoints { | ||||||
|  |         // docid, _index -> KvWriterDelAdd -> Vector | ||||||
|  |         manual_vectors: writer_into_reader(manual_vectors_writer)?, | ||||||
|  |         // docid -> () | ||||||
|  |         remove_vectors: writer_into_reader(remove_vectors_writer)?, | ||||||
|  |         // docid -> prompt | ||||||
|  |         prompts: writer_into_reader(prompts_writer)?, | ||||||
|  |     }) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn to_vector_maps( | ||||||
|  |     obkv: KvReaderDelAdd, | ||||||
|  |     document_id: impl Fn() -> Value, | ||||||
|  | ) -> Result<(Option<serde_json::Map<String, Value>>, Option<serde_json::Map<String, Value>>)> { | ||||||
|  |     let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?; | ||||||
|  |     let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?; | ||||||
|  |     Ok((del, add)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn to_vector_map( | ||||||
|  |     obkv: KvReaderDelAdd, | ||||||
|  |     side: DelAdd, | ||||||
|  |     document_id: &impl Fn() -> Value, | ||||||
|  | ) -> Result<Option<serde_json::Map<String, Value>>> { | ||||||
|  |     Ok(if let Some(value) = obkv.get(side) { | ||||||
|  |         let Ok(value) = from_slice(value) else { | ||||||
|  |             let value = from_slice(value).map_err(InternalError::SerdeJson)?; | ||||||
|  |             return Err(crate::Error::UserError(UserError::InvalidVectorsMapType { | ||||||
|  |                 document_id: document_id(), | ||||||
|  |                 value, | ||||||
|  |             })); | ||||||
|  |         }; | ||||||
|  |         Some(value) | ||||||
|  |     } else { | ||||||
|  |         None | ||||||
|  |     }) | ||||||
| } | } | ||||||
|  |  | ||||||
| /// Computes the diff between both Del and Add numbers and | /// Computes the diff between both Del and Add numbers and | ||||||
| /// only inserts the parts that differ in the sorter. | /// only inserts the parts that differ in the sorter. | ||||||
| fn push_vectors_diff( | fn push_vectors_diff( | ||||||
|     writer: &mut Writer<BufWriter<File>>, |     remove_vectors_writer: &mut Writer<BufWriter<File>>, | ||||||
|  |     prompts_writer: &mut Writer<BufWriter<File>>, | ||||||
|  |     manual_vectors_writer: &mut Writer<BufWriter<File>>, | ||||||
|     key_buffer: &mut Vec<u8>, |     key_buffer: &mut Vec<u8>, | ||||||
|     mut del_vectors: Vec<Vec<f32>>, |     delta: VectorStateDelta, | ||||||
|     mut add_vectors: Vec<Vec<f32>>, |  | ||||||
| ) -> Result<()> { | ) -> Result<()> { | ||||||
|  |     let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values(); | ||||||
|  |     if must_remove { | ||||||
|  |         key_buffer.truncate(TRUNCATE_SIZE); | ||||||
|  |         remove_vectors_writer.insert(&key_buffer, [])?; | ||||||
|  |     } | ||||||
|  |     if !prompt.is_empty() { | ||||||
|  |         key_buffer.truncate(TRUNCATE_SIZE); | ||||||
|  |         prompts_writer.insert(&key_buffer, prompt.as_bytes())?; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // We sort and dedup the vectors |     // We sort and dedup the vectors | ||||||
|     del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); |     del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); | ||||||
|     add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); |     add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); | ||||||
| @@ -114,7 +291,7 @@ fn push_vectors_diff( | |||||||
|                 let mut obkv = KvWriterDelAdd::memory(); |                 let mut obkv = KvWriterDelAdd::memory(); | ||||||
|                 obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; |                 obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; | ||||||
|                 let bytes = obkv.into_inner()?; |                 let bytes = obkv.into_inner()?; | ||||||
|                 writer.insert(&key_buffer, bytes)?; |                 manual_vectors_writer.insert(&key_buffer, bytes)?; | ||||||
|             } |             } | ||||||
|             EitherOrBoth::Right(vector) => { |             EitherOrBoth::Right(vector) => { | ||||||
|                 // We insert only the Add part of the Obkv to inform |                 // We insert only the Add part of the Obkv to inform | ||||||
| @@ -122,7 +299,7 @@ fn push_vectors_diff( | |||||||
|                 let mut obkv = KvWriterDelAdd::memory(); |                 let mut obkv = KvWriterDelAdd::memory(); | ||||||
|                 obkv.insert(DelAdd::Addition, cast_slice(&vector))?; |                 obkv.insert(DelAdd::Addition, cast_slice(&vector))?; | ||||||
|                 let bytes = obkv.into_inner()?; |                 let bytes = obkv.into_inner()?; | ||||||
|                 writer.insert(&key_buffer, bytes)?; |                 manual_vectors_writer.insert(&key_buffer, bytes)?; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -136,13 +313,112 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering { | |||||||
| } | } | ||||||
|  |  | ||||||
| /// Extracts the vectors from a JSON value. | /// Extracts the vectors from a JSON value. | ||||||
| fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> { | fn extract_vectors( | ||||||
|     match from_slice(value) { |     value: Value, | ||||||
|         Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)), |     document_id: impl Fn() -> Value, | ||||||
|  |     name: &str, | ||||||
|  | ) -> Result<Vec<Vec<f32>>> { | ||||||
|  |     // FIXME: ugly clone of the vectors here | ||||||
|  |     match serde_json::from_value(value.clone()) { | ||||||
|  |         Ok(vectors) => { | ||||||
|  |             Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default()) | ||||||
|  |         } | ||||||
|         Err(_) => Err(UserError::InvalidVectorsType { |         Err(_) => Err(UserError::InvalidVectorsType { | ||||||
|             document_id: document_id(), |             document_id: document_id(), | ||||||
|             value: from_slice(value).map_err(InternalError::SerdeJson)?, |             value, | ||||||
|  |             subfield: name.to_owned(), | ||||||
|         } |         } | ||||||
|         .into()), |         .into()), | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[logging_timer::time] | ||||||
|  | pub fn extract_embeddings<R: io::Read + io::Seek>( | ||||||
|  |     // docid, prompt | ||||||
|  |     prompt_reader: grenad::Reader<R>, | ||||||
|  |     indexer: GrenadParameters, | ||||||
|  |     embedder: Arc<Embedder>, | ||||||
|  | ) -> Result<grenad::Reader<BufReader<File>>> { | ||||||
|  |     let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; | ||||||
|  |  | ||||||
|  |     let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism | ||||||
|  |     let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk | ||||||
|  |  | ||||||
|  |     // docid, state with embedding | ||||||
|  |     let mut state_writer = create_writer( | ||||||
|  |         indexer.chunk_compression_type, | ||||||
|  |         indexer.chunk_compression_level, | ||||||
|  |         tempfile::tempfile()?, | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     let mut chunks = Vec::with_capacity(n_chunks); | ||||||
|  |     let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk); | ||||||
|  |     let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk); | ||||||
|  |     let mut chunks_ids = Vec::with_capacity(n_chunks); | ||||||
|  |     let mut cursor = prompt_reader.into_cursor()?; | ||||||
|  |  | ||||||
|  |     while let Some((key, value)) = cursor.move_on_next()? { | ||||||
|  |         let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); | ||||||
|  |         // SAFETY: precondition, the grenad value was saved from a string | ||||||
|  |         let prompt = unsafe { std::str::from_utf8_unchecked(value) }; | ||||||
|  |         if current_chunk.len() == current_chunk.capacity() { | ||||||
|  |             chunks.push(std::mem::replace( | ||||||
|  |                 &mut current_chunk, | ||||||
|  |                 Vec::with_capacity(n_vectors_per_chunk), | ||||||
|  |             )); | ||||||
|  |             chunks_ids.push(std::mem::replace( | ||||||
|  |                 &mut current_chunk_ids, | ||||||
|  |                 Vec::with_capacity(n_vectors_per_chunk), | ||||||
|  |             )); | ||||||
|  |         }; | ||||||
|  |         current_chunk.push(prompt.to_owned()); | ||||||
|  |         current_chunk_ids.push(docid); | ||||||
|  |  | ||||||
|  |         if chunks.len() == chunks.capacity() { | ||||||
|  |             let chunked_embeds = rt | ||||||
|  |                 .block_on( | ||||||
|  |                     embedder | ||||||
|  |                         .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), | ||||||
|  |                 ) | ||||||
|  |                 .map_err(crate::vector::Error::from) | ||||||
|  |                 .map_err(crate::Error::from)?; | ||||||
|  |  | ||||||
|  |             for (docid, embeddings) in chunks_ids | ||||||
|  |                 .iter() | ||||||
|  |                 .flat_map(|docids| docids.iter()) | ||||||
|  |                 .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) | ||||||
|  |             { | ||||||
|  |                 state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; | ||||||
|  |             } | ||||||
|  |             chunks_ids.clear(); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // send last chunk | ||||||
|  |     if !chunks.is_empty() { | ||||||
|  |         let chunked_embeds = rt | ||||||
|  |             .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) | ||||||
|  |             .map_err(crate::vector::Error::from) | ||||||
|  |             .map_err(crate::Error::from)?; | ||||||
|  |         for (docid, embeddings) in chunks_ids | ||||||
|  |             .iter() | ||||||
|  |             .flat_map(|docids| docids.iter()) | ||||||
|  |             .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) | ||||||
|  |         { | ||||||
|  |             state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if !current_chunk.is_empty() { | ||||||
|  |         let embeds = rt | ||||||
|  |             .block_on(embedder.embed(std::mem::take(&mut current_chunk))) | ||||||
|  |             .map_err(crate::vector::Error::from) | ||||||
|  |             .map_err(crate::Error::from)?; | ||||||
|  |  | ||||||
|  |         for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { | ||||||
|  |             state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     writer_into_reader(state_writer) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -23,7 +23,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids; | |||||||
| use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; | use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; | ||||||
| use self::extract_fid_word_count_docids::extract_fid_word_count_docids; | use self::extract_fid_word_count_docids::extract_fid_word_count_docids; | ||||||
| use self::extract_geo_points::extract_geo_points; | use self::extract_geo_points::extract_geo_points; | ||||||
| use self::extract_vector_points::extract_vector_points; | use self::extract_vector_points::{ | ||||||
|  |     extract_embeddings, extract_vector_points, ExtractedVectorPoints, | ||||||
|  | }; | ||||||
| use self::extract_word_docids::extract_word_docids; | use self::extract_word_docids::extract_word_docids; | ||||||
| use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; | use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; | ||||||
| use self::extract_word_position_docids::extract_word_position_docids; | use self::extract_word_position_docids::extract_word_position_docids; | ||||||
| @@ -33,7 +35,8 @@ use super::helpers::{ | |||||||
| }; | }; | ||||||
| use super::{helpers, TypedChunk}; | use super::{helpers, TypedChunk}; | ||||||
| use crate::proximity::ProximityPrecision; | use crate::proximity::ProximityPrecision; | ||||||
| use crate::{FieldId, Result}; | use crate::vector::EmbeddingConfigs; | ||||||
|  | use crate::{FieldId, FieldsIdsMap, Result}; | ||||||
|  |  | ||||||
| /// Extract data for each databases from obkv documents in parallel. | /// Extract data for each databases from obkv documents in parallel. | ||||||
| /// Send data in grenad file over provided Sender. | /// Send data in grenad file over provided Sender. | ||||||
| @@ -47,13 +50,14 @@ pub(crate) fn data_from_obkv_documents( | |||||||
|     faceted_fields: HashSet<FieldId>, |     faceted_fields: HashSet<FieldId>, | ||||||
|     primary_key_id: FieldId, |     primary_key_id: FieldId, | ||||||
|     geo_fields_ids: Option<(FieldId, FieldId)>, |     geo_fields_ids: Option<(FieldId, FieldId)>, | ||||||
|     vectors_field_id: Option<FieldId>, |     field_id_map: FieldsIdsMap, | ||||||
|     stop_words: Option<fst::Set<&[u8]>>, |     stop_words: Option<fst::Set<&[u8]>>, | ||||||
|     allowed_separators: Option<&[&str]>, |     allowed_separators: Option<&[&str]>, | ||||||
|     dictionary: Option<&[&str]>, |     dictionary: Option<&[&str]>, | ||||||
|     max_positions_per_attributes: Option<u32>, |     max_positions_per_attributes: Option<u32>, | ||||||
|     exact_attributes: HashSet<FieldId>, |     exact_attributes: HashSet<FieldId>, | ||||||
|     proximity_precision: ProximityPrecision, |     proximity_precision: ProximityPrecision, | ||||||
|  |     embedders: EmbeddingConfigs, | ||||||
| ) -> Result<()> { | ) -> Result<()> { | ||||||
|     puffin::profile_function!(); |     puffin::profile_function!(); | ||||||
|  |  | ||||||
| @@ -64,7 +68,8 @@ pub(crate) fn data_from_obkv_documents( | |||||||
|                 original_documents_chunk, |                 original_documents_chunk, | ||||||
|                 indexer, |                 indexer, | ||||||
|                 lmdb_writer_sx.clone(), |                 lmdb_writer_sx.clone(), | ||||||
|                 vectors_field_id, |                 field_id_map.clone(), | ||||||
|  |                 embedders.clone(), | ||||||
|             ) |             ) | ||||||
|         }) |         }) | ||||||
|         .collect::<Result<()>>()?; |         .collect::<Result<()>>()?; | ||||||
| @@ -276,24 +281,53 @@ fn send_original_documents_data( | |||||||
|     original_documents_chunk: Result<grenad::Reader<BufReader<File>>>, |     original_documents_chunk: Result<grenad::Reader<BufReader<File>>>, | ||||||
|     indexer: GrenadParameters, |     indexer: GrenadParameters, | ||||||
|     lmdb_writer_sx: Sender<Result<TypedChunk>>, |     lmdb_writer_sx: Sender<Result<TypedChunk>>, | ||||||
|     vectors_field_id: Option<FieldId>, |     field_id_map: FieldsIdsMap, | ||||||
|  |     embedders: EmbeddingConfigs, | ||||||
| ) -> Result<()> { | ) -> Result<()> { | ||||||
|     let original_documents_chunk = |     let original_documents_chunk = | ||||||
|         original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; |         original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; | ||||||
|  |  | ||||||
|     if let Some(vectors_field_id) = vectors_field_id { |  | ||||||
|     let documents_chunk_cloned = original_documents_chunk.clone(); |     let documents_chunk_cloned = original_documents_chunk.clone(); | ||||||
|     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); |     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); | ||||||
|     rayon::spawn(move || { |     rayon::spawn(move || { | ||||||
|             let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); |         for (name, (embedder, prompt)) in embedders { | ||||||
|             let _ = match result { |             let result = extract_vector_points( | ||||||
|                 Ok(vector_points) => { |                 documents_chunk_cloned.clone(), | ||||||
|                     lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) |                 indexer, | ||||||
|  |                 &field_id_map, | ||||||
|  |                 &prompt, | ||||||
|  |                 &name, | ||||||
|  |             ); | ||||||
|  |             match result { | ||||||
|  |                 Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { | ||||||
|  |                     let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) { | ||||||
|  |                         Ok(results) => Some(results), | ||||||
|  |                         Err(error) => { | ||||||
|  |                             let _ = lmdb_writer_sx_cloned.send(Err(error)); | ||||||
|  |                             None | ||||||
|                         } |                         } | ||||||
|                 Err(error) => lmdb_writer_sx_cloned.send(Err(error)), |  | ||||||
|                     }; |                     }; | ||||||
|         }); |  | ||||||
|  |                     if !(remove_vectors.is_empty() | ||||||
|  |                         && manual_vectors.is_empty() | ||||||
|  |                         && embeddings.as_ref().map_or(true, |e| e.is_empty())) | ||||||
|  |                     { | ||||||
|  |                         let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { | ||||||
|  |                             remove_vectors, | ||||||
|  |                             embeddings, | ||||||
|  |                             expected_dimension: embedder.dimensions(), | ||||||
|  |                             manual_vectors, | ||||||
|  |                             embedder_name: name, | ||||||
|  |                         })); | ||||||
|                     } |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 Err(error) => { | ||||||
|  |                     let _ = lmdb_writer_sx_cloned.send(Err(error)); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     }); | ||||||
|  |  | ||||||
|     // TODO: create a custom internal error |     // TODO: create a custom internal error | ||||||
|     lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap(); |     lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap(); | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ mod helpers; | |||||||
| mod transform; | mod transform; | ||||||
| mod typed_chunk; | mod typed_chunk; | ||||||
|  |  | ||||||
| use std::collections::HashSet; | use std::collections::{HashMap, HashSet}; | ||||||
| use std::io::{Cursor, Read, Seek}; | use std::io::{Cursor, Read, Seek}; | ||||||
| use std::iter::FromIterator; | use std::iter::FromIterator; | ||||||
| use std::num::NonZeroU32; | use std::num::NonZeroU32; | ||||||
| @@ -14,6 +14,7 @@ use crossbeam_channel::{Receiver, Sender}; | |||||||
| use heed::types::Str; | use heed::types::Str; | ||||||
| use heed::Database; | use heed::Database; | ||||||
| use log::debug; | use log::debug; | ||||||
|  | use rand::SeedableRng; | ||||||
| use roaring::RoaringBitmap; | use roaring::RoaringBitmap; | ||||||
| use serde::{Deserialize, Serialize}; | use serde::{Deserialize, Serialize}; | ||||||
| use slice_group_by::GroupBy; | use slice_group_by::GroupBy; | ||||||
| @@ -36,6 +37,7 @@ pub use crate::update::index_documents::helpers::CursorClonableMmap; | |||||||
| use crate::update::{ | use crate::update::{ | ||||||
|     IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, |     IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, | ||||||
| }; | }; | ||||||
|  | use crate::vector::EmbeddingConfigs; | ||||||
| use crate::{CboRoaringBitmapCodec, Index, Result}; | use crate::{CboRoaringBitmapCodec, Index, Result}; | ||||||
|  |  | ||||||
| static MERGED_DATABASE_COUNT: usize = 7; | static MERGED_DATABASE_COUNT: usize = 7; | ||||||
| @@ -78,6 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> { | |||||||
|     should_abort: FA, |     should_abort: FA, | ||||||
|     added_documents: u64, |     added_documents: u64, | ||||||
|     deleted_documents: u64, |     deleted_documents: u64, | ||||||
|  |     embedders: EmbeddingConfigs, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Default, Debug, Clone)] | #[derive(Default, Debug, Clone)] | ||||||
| @@ -121,6 +124,7 @@ where | |||||||
|             index, |             index, | ||||||
|             added_documents: 0, |             added_documents: 0, | ||||||
|             deleted_documents: 0, |             deleted_documents: 0, | ||||||
|  |             embedders: Default::default(), | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -167,6 +171,11 @@ where | |||||||
|         Ok((self, Ok(indexed_documents))) |         Ok((self, Ok(indexed_documents))) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self { | ||||||
|  |         self.embedders = embedders; | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Remove a batch of documents from the current builder. |     /// Remove a batch of documents from the current builder. | ||||||
|     /// |     /// | ||||||
|     /// Returns the number of documents deleted from the builder. |     /// Returns the number of documents deleted from the builder. | ||||||
| @@ -322,17 +331,18 @@ where | |||||||
|         // get filterable fields for facet databases |         // get filterable fields for facet databases | ||||||
|         let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?; |         let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?; | ||||||
|         // get the fid of the `_geo.lat` and `_geo.lng` fields. |         // get the fid of the `_geo.lat` and `_geo.lng` fields. | ||||||
|         let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") { |         let mut field_id_map = self.index.fields_ids_map(self.wtxn)?; | ||||||
|  |  | ||||||
|  |         // self.index.fields_ids_map($a)? ==>> field_id_map | ||||||
|  |         let geo_fields_ids = match field_id_map.id("_geo") { | ||||||
|             Some(gfid) => { |             Some(gfid) => { | ||||||
|                 let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid); |                 let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid); | ||||||
|                 let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid); |                 let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid); | ||||||
|                 // if `_geo` is faceted then we get the `lat` and `lng` |                 // if `_geo` is faceted then we get the `lat` and `lng` | ||||||
|                 if is_sortable || is_filterable { |                 if is_sortable || is_filterable { | ||||||
|                     let field_ids = self |                     let field_ids = field_id_map | ||||||
|                         .index |  | ||||||
|                         .fields_ids_map(self.wtxn)? |  | ||||||
|                         .insert("_geo.lat") |                         .insert("_geo.lat") | ||||||
|                         .zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng")) |                         .zip(field_id_map.insert("_geo.lng")) | ||||||
|                         .ok_or(UserError::AttributeLimitReached)?; |                         .ok_or(UserError::AttributeLimitReached)?; | ||||||
|                     Some(field_ids) |                     Some(field_ids) | ||||||
|                 } else { |                 } else { | ||||||
| @@ -341,8 +351,6 @@ where | |||||||
|             } |             } | ||||||
|             None => None, |             None => None, | ||||||
|         }; |         }; | ||||||
|         // get the fid of the `_vectors` field. |  | ||||||
|         let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); |  | ||||||
|  |  | ||||||
|         let stop_words = self.index.stop_words(self.wtxn)?; |         let stop_words = self.index.stop_words(self.wtxn)?; | ||||||
|         let separators = self.index.allowed_separators(self.wtxn)?; |         let separators = self.index.allowed_separators(self.wtxn)?; | ||||||
| @@ -364,6 +372,8 @@ where | |||||||
|             self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB |             self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB | ||||||
|         let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes; |         let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes; | ||||||
|  |  | ||||||
|  |         let cloned_embedder = self.embedders.clone(); | ||||||
|  |  | ||||||
|         // Run extraction pipeline in parallel. |         // Run extraction pipeline in parallel. | ||||||
|         pool.install(|| { |         pool.install(|| { | ||||||
|             puffin::profile_scope!("extract_and_send_grenad_chunks"); |             puffin::profile_scope!("extract_and_send_grenad_chunks"); | ||||||
| @@ -387,13 +397,14 @@ where | |||||||
|                     faceted_fields, |                     faceted_fields, | ||||||
|                     primary_key_id, |                     primary_key_id, | ||||||
|                     geo_fields_ids, |                     geo_fields_ids, | ||||||
|                     vectors_field_id, |                     field_id_map, | ||||||
|                     stop_words, |                     stop_words, | ||||||
|                     separators.as_deref(), |                     separators.as_deref(), | ||||||
|                     dictionary.as_deref(), |                     dictionary.as_deref(), | ||||||
|                     max_positions_per_attributes, |                     max_positions_per_attributes, | ||||||
|                     exact_attributes, |                     exact_attributes, | ||||||
|                     proximity_precision, |                     proximity_precision, | ||||||
|  |                     cloned_embedder, | ||||||
|                 ) |                 ) | ||||||
|             }); |             }); | ||||||
|  |  | ||||||
| @@ -402,7 +413,7 @@ where | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             // needs to be dropped to avoid channel waiting lock. |             // needs to be dropped to avoid channel waiting lock. | ||||||
|             drop(lmdb_writer_sx) |             drop(lmdb_writer_sx); | ||||||
|         }); |         }); | ||||||
|  |  | ||||||
|         let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0; |         let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0; | ||||||
| @@ -419,6 +430,8 @@ where | |||||||
|         let mut word_docids = None; |         let mut word_docids = None; | ||||||
|         let mut exact_word_docids = None; |         let mut exact_word_docids = None; | ||||||
|  |  | ||||||
|  |         let mut dimension = HashMap::new(); | ||||||
|  |  | ||||||
|         for result in lmdb_writer_rx { |         for result in lmdb_writer_rx { | ||||||
|             if (self.should_abort)() { |             if (self.should_abort)() { | ||||||
|                 return Err(Error::InternalError(InternalError::AbortedIndexation)); |                 return Err(Error::InternalError(InternalError::AbortedIndexation)); | ||||||
| @@ -448,6 +461,22 @@ where | |||||||
|                     word_position_docids = Some(cloneable_chunk); |                     word_position_docids = Some(cloneable_chunk); | ||||||
|                     TypedChunk::WordPositionDocids(chunk) |                     TypedChunk::WordPositionDocids(chunk) | ||||||
|                 } |                 } | ||||||
|  |                 TypedChunk::VectorPoints { | ||||||
|  |                     expected_dimension, | ||||||
|  |                     remove_vectors, | ||||||
|  |                     embeddings, | ||||||
|  |                     manual_vectors, | ||||||
|  |                     embedder_name, | ||||||
|  |                 } => { | ||||||
|  |                     dimension.insert(embedder_name.clone(), expected_dimension); | ||||||
|  |                     TypedChunk::VectorPoints { | ||||||
|  |                         remove_vectors, | ||||||
|  |                         embeddings, | ||||||
|  |                         expected_dimension, | ||||||
|  |                         manual_vectors, | ||||||
|  |                         embedder_name, | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|                 otherwise => otherwise, |                 otherwise => otherwise, | ||||||
|             }; |             }; | ||||||
|  |  | ||||||
| @@ -480,6 +509,33 @@ where | |||||||
|         // We write the primary key field id into the main database |         // We write the primary key field id into the main database | ||||||
|         self.index.put_primary_key(self.wtxn, &primary_key)?; |         self.index.put_primary_key(self.wtxn, &primary_key)?; | ||||||
|         let number_of_documents = self.index.number_of_documents(self.wtxn)?; |         let number_of_documents = self.index.number_of_documents(self.wtxn)?; | ||||||
|  |         let mut rng = rand::rngs::StdRng::seed_from_u64(42); | ||||||
|  |  | ||||||
|  |         for (embedder_name, dimension) in dimension { | ||||||
|  |             let wtxn = &mut *self.wtxn; | ||||||
|  |             let vector_arroy = self.index.vector_arroy; | ||||||
|  |  | ||||||
|  |             let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( | ||||||
|  |                 InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, | ||||||
|  |             )?; | ||||||
|  |  | ||||||
|  |             pool.install(|| { | ||||||
|  |                 let writer_index = (embedder_index as u16) << 8; | ||||||
|  |                 for k in 0..=u8::MAX { | ||||||
|  |                     let writer = arroy::Writer::prepare( | ||||||
|  |                         wtxn, | ||||||
|  |                         vector_arroy, | ||||||
|  |                         writer_index | (k as u16), | ||||||
|  |                         dimension, | ||||||
|  |                     )?; | ||||||
|  |                     if writer.is_empty(wtxn)? { | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|  |                     writer.build(wtxn, &mut rng, None)?; | ||||||
|  |                 } | ||||||
|  |                 Result::Ok(()) | ||||||
|  |             })?; | ||||||
|  |         } | ||||||
|  |  | ||||||
|         self.execute_prefix_databases( |         self.execute_prefix_databases( | ||||||
|             word_docids, |             word_docids, | ||||||
| @@ -694,6 +750,8 @@ fn execute_word_prefix_docids( | |||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests { | mod tests { | ||||||
|  |     use std::collections::BTreeMap; | ||||||
|  |  | ||||||
|     use big_s::S; |     use big_s::S; | ||||||
|     use fst::IntoStreamer; |     use fst::IntoStreamer; | ||||||
|     use heed::RwTxn; |     use heed::RwTxn; | ||||||
| @@ -703,6 +761,7 @@ mod tests { | |||||||
|     use crate::documents::documents_batch_reader_from_objects; |     use crate::documents::documents_batch_reader_from_objects; | ||||||
|     use crate::index::tests::TempIndex; |     use crate::index::tests::TempIndex; | ||||||
|     use crate::search::TermsMatchingStrategy; |     use crate::search::TermsMatchingStrategy; | ||||||
|  |     use crate::update::Setting; | ||||||
|     use crate::{db_snap, Filter, Search}; |     use crate::{db_snap, Filter, Search}; | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
| @@ -2494,18 +2553,39 @@ mod tests { | |||||||
|     /// Vectors must be of the same length. |     /// Vectors must be of the same length. | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_multiple_vectors() { |     fn test_multiple_vectors() { | ||||||
|  |         use crate::vector::settings::{EmbedderSettings, EmbeddingSettings}; | ||||||
|         let index = TempIndex::new(); |         let index = TempIndex::new(); | ||||||
|  |  | ||||||
|         index.add_documents(documents!([{"id": 0, "_vectors": [[0, 1, 2], [3, 4, 5]] }])).unwrap(); |         index | ||||||
|         index.add_documents(documents!([{"id": 1, "_vectors": [6, 7, 8] }])).unwrap(); |             .update_settings(|settings| { | ||||||
|  |                 let mut embedders = BTreeMap::default(); | ||||||
|  |                 embedders.insert( | ||||||
|  |                     "manual".to_string(), | ||||||
|  |                     Setting::Set(EmbeddingSettings { | ||||||
|  |                         embedder_options: Setting::Set(EmbedderSettings::UserProvided( | ||||||
|  |                             crate::vector::settings::UserProvidedSettings { dimensions: 3 }, | ||||||
|  |                         )), | ||||||
|  |                         document_template: Setting::NotSet, | ||||||
|  |                     }), | ||||||
|  |                 ); | ||||||
|  |                 settings.set_embedder_settings(embedders); | ||||||
|  |             }) | ||||||
|  |             .unwrap(); | ||||||
|  |  | ||||||
|         index |         index | ||||||
|             .add_documents( |             .add_documents( | ||||||
|                 documents!([{"id": 2, "_vectors": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }]), |                 documents!([{"id": 0, "_vectors": { "manual": [[0, 1, 2], [3, 4, 5]] } }]), | ||||||
|  |             ) | ||||||
|  |             .unwrap(); | ||||||
|  |         index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap(); | ||||||
|  |         index | ||||||
|  |             .add_documents( | ||||||
|  |                 documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]), | ||||||
|             ) |             ) | ||||||
|             .unwrap(); |             .unwrap(); | ||||||
|  |  | ||||||
|         let rtxn = index.read_txn().unwrap(); |         let rtxn = index.read_txn().unwrap(); | ||||||
|         let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap(); |         let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap(); | ||||||
|         assert_eq!(res.documents_ids.len(), 3); |         assert_eq!(res.documents_ids.len(), 3); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| use std::collections::{HashMap, HashSet}; | use std::collections::HashMap; | ||||||
| use std::convert::TryInto; | use std::convert::TryInto; | ||||||
| use std::fs::File; | use std::fs::File; | ||||||
| use std::io::{self, BufReader}; | use std::io::{self, BufReader}; | ||||||
| @@ -8,9 +8,7 @@ use charabia::{Language, Script}; | |||||||
| use grenad::MergerBuilder; | use grenad::MergerBuilder; | ||||||
| use heed::types::Bytes; | use heed::types::Bytes; | ||||||
| use heed::{PutFlags, RwTxn}; | use heed::{PutFlags, RwTxn}; | ||||||
| use log::error; |  | ||||||
| use obkv::{KvReader, KvWriter}; | use obkv::{KvReader, KvWriter}; | ||||||
| use ordered_float::OrderedFloat; |  | ||||||
| use roaring::RoaringBitmap; | use roaring::RoaringBitmap; | ||||||
|  |  | ||||||
| use super::helpers::{ | use super::helpers::{ | ||||||
| @@ -18,16 +16,15 @@ use super::helpers::{ | |||||||
|     valid_lmdb_key, CursorClonableMmap, |     valid_lmdb_key, CursorClonableMmap, | ||||||
| }; | }; | ||||||
| use super::{ClonableMmap, MergeFn}; | use super::{ClonableMmap, MergeFn}; | ||||||
| use crate::distance::NDotProductPoint; |  | ||||||
| use crate::error::UserError; |  | ||||||
| use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind}; | use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind}; | ||||||
| use crate::facet::FacetType; | use crate::facet::FacetType; | ||||||
| use crate::index::db_name::DOCUMENTS; | use crate::index::db_name::DOCUMENTS; | ||||||
| use crate::index::Hnsw; |  | ||||||
| use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; | use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; | ||||||
| use crate::update::facet::FacetsUpdate; | use crate::update::facet::FacetsUpdate; | ||||||
| use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; | use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; | ||||||
| use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError}; | use crate::{ | ||||||
|  |     lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError, | ||||||
|  | }; | ||||||
|  |  | ||||||
| pub(crate) enum TypedChunk { | pub(crate) enum TypedChunk { | ||||||
|     FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>), |     FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>), | ||||||
| @@ -47,7 +44,13 @@ pub(crate) enum TypedChunk { | |||||||
|     FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>), |     FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>), | ||||||
|     FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>), |     FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>), | ||||||
|     GeoPoints(grenad::Reader<BufReader<File>>), |     GeoPoints(grenad::Reader<BufReader<File>>), | ||||||
|     VectorPoints(grenad::Reader<BufReader<File>>), |     VectorPoints { | ||||||
|  |         remove_vectors: grenad::Reader<BufReader<File>>, | ||||||
|  |         embeddings: Option<grenad::Reader<BufReader<File>>>, | ||||||
|  |         expected_dimension: usize, | ||||||
|  |         manual_vectors: grenad::Reader<BufReader<File>>, | ||||||
|  |         embedder_name: String, | ||||||
|  |     }, | ||||||
|     ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), |     ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -100,8 +103,8 @@ impl TypedChunk { | |||||||
|             TypedChunk::GeoPoints(grenad) => { |             TypedChunk::GeoPoints(grenad) => { | ||||||
|                 format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) |                 format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) | ||||||
|             } |             } | ||||||
|             TypedChunk::VectorPoints(grenad) => { |             TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => { | ||||||
|                 format!("VectorPoints {{ number_of_entries: {} }}", grenad.len()) |                 format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name) | ||||||
|             } |             } | ||||||
|             TypedChunk::ScriptLanguageDocids(sl_map) => { |             TypedChunk::ScriptLanguageDocids(sl_map) => { | ||||||
|                 format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) |                 format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) | ||||||
| @@ -355,19 +358,77 @@ pub(crate) fn write_typed_chunk_into_index( | |||||||
|             index.put_geo_rtree(wtxn, &rtree)?; |             index.put_geo_rtree(wtxn, &rtree)?; | ||||||
|             index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; |             index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; | ||||||
|         } |         } | ||||||
|         TypedChunk::VectorPoints(vector_points) => { |         TypedChunk::VectorPoints { | ||||||
|             let mut vectors_set = HashSet::new(); |             remove_vectors, | ||||||
|             // We extract and store the previous vectors |             manual_vectors, | ||||||
|             if let Some(hnsw) = index.vector_hnsw(wtxn)? { |             embeddings, | ||||||
|                 for (pid, point) in hnsw.iter() { |             expected_dimension, | ||||||
|                     let pid_key = pid.into_inner(); |             embedder_name, | ||||||
|                     let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap(); |         } => { | ||||||
|                     let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); |             let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( | ||||||
|                     vectors_set.insert((docid, vector)); |                 InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, | ||||||
|  |             )?; | ||||||
|  |             let writer_index = (embedder_index as u16) << 8; | ||||||
|  |             // FIXME: allow customizing distance | ||||||
|  |             let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX) | ||||||
|  |                 .map(|k| { | ||||||
|  |                     arroy::Writer::prepare( | ||||||
|  |                         wtxn, | ||||||
|  |                         index.vector_arroy, | ||||||
|  |                         writer_index | (k as u16), | ||||||
|  |                         expected_dimension, | ||||||
|  |                     ) | ||||||
|  |                 }) | ||||||
|  |                 .collect(); | ||||||
|  |             let writers = writers?; | ||||||
|  |  | ||||||
|  |             // remove vectors for docids we want them removed | ||||||
|  |             let mut cursor = remove_vectors.into_cursor()?; | ||||||
|  |             while let Some((key, _)) = cursor.move_on_next()? { | ||||||
|  |                 let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); | ||||||
|  |  | ||||||
|  |                 for writer in &writers { | ||||||
|  |                     // Uses invariant: vectors are packed in the first writers. | ||||||
|  |                     if !writer.del_item(wtxn, docid)? { | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             let mut cursor = vector_points.into_cursor()?; |             // add generated embeddings | ||||||
|  |             if let Some(embeddings) = embeddings { | ||||||
|  |                 let mut cursor = embeddings.into_cursor()?; | ||||||
|  |                 while let Some((key, value)) = cursor.move_on_next()? { | ||||||
|  |                     let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); | ||||||
|  |                     let data = pod_collect_to_vec(value); | ||||||
|  |                     // it is a code error to have embeddings and not expected_dimension | ||||||
|  |                     let embeddings = | ||||||
|  |                         crate::vector::Embeddings::from_inner(data, expected_dimension) | ||||||
|  |                             // code error if we somehow got the wrong dimension | ||||||
|  |                             .unwrap(); | ||||||
|  |  | ||||||
|  |                     if embeddings.embedding_count() > u8::MAX.into() { | ||||||
|  |                         let external_docid = if let Ok(Some(Ok(index))) = index | ||||||
|  |                             .external_id_of(wtxn, std::iter::once(docid)) | ||||||
|  |                             .map(|it| it.into_iter().next()) | ||||||
|  |                         { | ||||||
|  |                             index | ||||||
|  |                         } else { | ||||||
|  |                             format!("internal docid={docid}") | ||||||
|  |                         }; | ||||||
|  |                         return Err(crate::Error::UserError(crate::UserError::TooManyVectors( | ||||||
|  |                             external_docid, | ||||||
|  |                             embeddings.embedding_count(), | ||||||
|  |                         ))); | ||||||
|  |                     } | ||||||
|  |                     for (embedding, writer) in embeddings.iter().zip(&writers) { | ||||||
|  |                         writer.add_item(wtxn, docid, embedding)?; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             // perform the manual diff | ||||||
|  |             let mut cursor = manual_vectors.into_cursor()?; | ||||||
|             while let Some((key, value)) = cursor.move_on_next()? { |             while let Some((key, value)) = cursor.move_on_next()? { | ||||||
|                 // convert the key back to a u32 (4 bytes) |                 // convert the key back to a u32 (4 bytes) | ||||||
|                 let (left, _index) = try_split_array_at(key).unwrap(); |                 let (left, _index) = try_split_array_at(key).unwrap(); | ||||||
| @@ -375,58 +436,52 @@ pub(crate) fn write_typed_chunk_into_index( | |||||||
|  |  | ||||||
|                 let vector_deladd_obkv = KvReaderDelAdd::new(value); |                 let vector_deladd_obkv = KvReaderDelAdd::new(value); | ||||||
|                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { |                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { | ||||||
|                     // convert the vector back to a Vec<f32> |                     let vector: Vec<f32> = pod_collect_to_vec(value); | ||||||
|                     let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); |  | ||||||
|                     let key = (docid, vector); |  | ||||||
|                     if !vectors_set.remove(&key) { |  | ||||||
|                         error!("Unable to delete the vector: {:?}", key.1); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { |  | ||||||
|                     // convert the vector back to a Vec<f32> |  | ||||||
|                     let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); |  | ||||||
|                     vectors_set.insert((docid, vector)); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             // Extract the most common vector dimension |                     let mut deleted_index = None; | ||||||
|             let expected_dimension_size = { |                     for (index, writer) in writers.iter().enumerate() { | ||||||
|                 let mut dims = HashMap::new(); |                         let Some(candidate) = writer.item_vector(wtxn, docid)? else { | ||||||
|                 vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1); |                             // uses invariant: vectors are packed in the first writers. | ||||||
|                 dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) |                             break; | ||||||
|                         }; |                         }; | ||||||
|  |                         if candidate == vector { | ||||||
|             // Ensure that the vector lengths are correct and |                             writer.del_item(wtxn, docid)?; | ||||||
|             // prepare the vectors before inserting them in the HNSW. |                             deleted_index = Some(index); | ||||||
|             let mut points = Vec::new(); |  | ||||||
|             let mut docids = Vec::new(); |  | ||||||
|             for (docid, vector) in vectors_set { |  | ||||||
|                 if expected_dimension_size.map_or(false, |expected| expected != vector.len()) { |  | ||||||
|                     return Err(UserError::InvalidVectorDimensions { |  | ||||||
|                         expected: expected_dimension_size.unwrap_or(vector.len()), |  | ||||||
|                         found: vector.len(), |  | ||||||
|                     } |  | ||||||
|                     .into()); |  | ||||||
|                 } else { |  | ||||||
|                     let vector = vector.into_iter().map(OrderedFloat::into_inner).collect(); |  | ||||||
|                     points.push(NDotProductPoint::new(vector)); |  | ||||||
|                     docids.push(docid); |  | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|             let hnsw_length = points.len(); |                     // 🥲 enforce invariant: vectors are packed in the first writers. | ||||||
|             let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points); |                     if let Some(deleted_index) = deleted_index { | ||||||
|  |                         let mut last_index_with_a_vector = None; | ||||||
|             assert_eq!(docids.len(), pids.len()); |                         for (index, writer) in writers.iter().enumerate().skip(deleted_index) { | ||||||
|  |                             let Some(candidate) = writer.item_vector(wtxn, docid)? else { | ||||||
|             // Store the vectors in the point-docid relation database |                                 break; | ||||||
|             index.vector_id_docid.clear(wtxn)?; |                             }; | ||||||
|             for (docid, pid) in docids.into_iter().zip(pids) { |                             last_index_with_a_vector = Some((index, candidate)); | ||||||
|                 index.vector_id_docid.put(wtxn, &pid.into_inner(), &docid)?; |                         } | ||||||
|  |                         if let Some((last_index, vector)) = last_index_with_a_vector { | ||||||
|  |                             // unwrap: computed the index from the list of writers | ||||||
|  |                             let writer = writers.get(last_index).unwrap(); | ||||||
|  |                             writer.del_item(wtxn, docid)?; | ||||||
|  |                             writers.get(deleted_index).unwrap().add_item(wtxn, docid, &vector)?; | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|             log::debug!("There are {} entries in the HNSW so far", hnsw_length); |                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { | ||||||
|             index.put_vector_hnsw(wtxn, &new_hnsw)?; |                     let vector = pod_collect_to_vec(value); | ||||||
|  |  | ||||||
|  |                     // overflow was detected during vector extraction. | ||||||
|  |                     for writer in &writers { | ||||||
|  |                         if !writer.contains_item(wtxn, docid)? { | ||||||
|  |                             writer.add_item(wtxn, docid, &vector)?; | ||||||
|  |                             break; | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             log::debug!("Finished vector chunk for {}", embedder_name); | ||||||
|         } |         } | ||||||
|         TypedChunk::ScriptLanguageDocids(sl_map) => { |         TypedChunk::ScriptLanguageDocids(sl_map) => { | ||||||
|             for (key, (deletion, addition)) in sl_map { |             for (key, (deletion, addition)) in sl_map { | ||||||
|   | |||||||
| @@ -1,9 +1,11 @@ | |||||||
| use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; | use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; | ||||||
|  | use std::convert::TryInto; | ||||||
| use std::result::Result as StdResult; | use std::result::Result as StdResult; | ||||||
|  | use std::sync::Arc; | ||||||
|  |  | ||||||
| use charabia::{Normalize, Tokenizer, TokenizerBuilder}; | use charabia::{Normalize, Tokenizer, TokenizerBuilder}; | ||||||
| use deserr::{DeserializeError, Deserr}; | use deserr::{DeserializeError, Deserr}; | ||||||
| use itertools::Itertools; | use itertools::{EitherOrBoth, Itertools}; | ||||||
| use serde::{Deserialize, Deserializer, Serialize, Serializer}; | use serde::{Deserialize, Deserializer, Serialize, Serializer}; | ||||||
| use time::OffsetDateTime; | use time::OffsetDateTime; | ||||||
|  |  | ||||||
| @@ -15,6 +17,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS | |||||||
| use crate::proximity::ProximityPrecision; | use crate::proximity::ProximityPrecision; | ||||||
| use crate::update::index_documents::IndexDocumentsMethod; | use crate::update::index_documents::IndexDocumentsMethod; | ||||||
| use crate::update::{IndexDocuments, UpdateIndexingStep}; | use crate::update::{IndexDocuments, UpdateIndexingStep}; | ||||||
|  | use crate::vector::settings::{EmbeddingSettings, PromptSettings}; | ||||||
|  | use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs}; | ||||||
| use crate::{FieldsIdsMap, Index, OrderBy, Result}; | use crate::{FieldsIdsMap, Index, OrderBy, Result}; | ||||||
|  |  | ||||||
| #[derive(Debug, Clone, PartialEq, Eq, Copy)] | #[derive(Debug, Clone, PartialEq, Eq, Copy)] | ||||||
| @@ -73,6 +77,13 @@ impl<T> Setting<T> { | |||||||
|             otherwise => otherwise, |             otherwise => otherwise, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn apply(&mut self, new: Self) { | ||||||
|  |         if let Setting::NotSet = new { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |         *self = new; | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<T: Serialize> Serialize for Setting<T> { | impl<T: Serialize> Serialize for Setting<T> { | ||||||
| @@ -129,6 +140,7 @@ pub struct Settings<'a, 't, 'i> { | |||||||
|     sort_facet_values_by: Setting<HashMap<String, OrderBy>>, |     sort_facet_values_by: Setting<HashMap<String, OrderBy>>, | ||||||
|     pagination_max_total_hits: Setting<usize>, |     pagination_max_total_hits: Setting<usize>, | ||||||
|     proximity_precision: Setting<ProximityPrecision>, |     proximity_precision: Setting<ProximityPrecision>, | ||||||
|  |     embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<'a, 't, 'i> Settings<'a, 't, 'i> { | impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||||
| @@ -161,6 +173,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|             sort_facet_values_by: Setting::NotSet, |             sort_facet_values_by: Setting::NotSet, | ||||||
|             pagination_max_total_hits: Setting::NotSet, |             pagination_max_total_hits: Setting::NotSet, | ||||||
|             proximity_precision: Setting::NotSet, |             proximity_precision: Setting::NotSet, | ||||||
|  |             embedder_settings: Setting::NotSet, | ||||||
|             indexer_config, |             indexer_config, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -343,6 +356,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|         self.proximity_precision = Setting::Reset; |         self.proximity_precision = Setting::Reset; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn set_embedder_settings(&mut self, value: BTreeMap<String, Setting<EmbeddingSettings>>) { | ||||||
|  |         self.embedder_settings = Setting::Set(value); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn reset_embedder_settings(&mut self) { | ||||||
|  |         self.embedder_settings = Setting::Reset; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     fn reindex<FP, FA>( |     fn reindex<FP, FA>( | ||||||
|         &mut self, |         &mut self, | ||||||
|         progress_callback: &FP, |         progress_callback: &FP, | ||||||
| @@ -377,6 +398,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|             fields_ids_map, |             fields_ids_map, | ||||||
|         )?; |         )?; | ||||||
|  |  | ||||||
|  |         let embedder_configs = self.index.embedding_configs(self.wtxn)?; | ||||||
|  |         let embedders = self.embedders(embedder_configs)?; | ||||||
|  |  | ||||||
|         // We index the generated `TransformOutput` which must contain |         // We index the generated `TransformOutput` which must contain | ||||||
|         // all the documents with fields in the newly defined searchable order. |         // all the documents with fields in the newly defined searchable order. | ||||||
|         let indexing_builder = IndexDocuments::new( |         let indexing_builder = IndexDocuments::new( | ||||||
| @@ -387,11 +411,33 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|             &progress_callback, |             &progress_callback, | ||||||
|             &should_abort, |             &should_abort, | ||||||
|         )?; |         )?; | ||||||
|  |  | ||||||
|  |         let indexing_builder = indexing_builder.with_embedders(embedders); | ||||||
|         indexing_builder.execute_raw(output)?; |         indexing_builder.execute_raw(output)?; | ||||||
|  |  | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     fn embedders( | ||||||
|  |         &self, | ||||||
|  |         embedding_configs: Vec<(String, EmbeddingConfig)>, | ||||||
|  |     ) -> Result<EmbeddingConfigs> { | ||||||
|  |         let res: Result<_> = embedding_configs | ||||||
|  |             .into_iter() | ||||||
|  |             .map(|(name, EmbeddingConfig { embedder_options, prompt })| { | ||||||
|  |                 let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?); | ||||||
|  |  | ||||||
|  |                 let embedder = Arc::new( | ||||||
|  |                     Embedder::new(embedder_options.clone()) | ||||||
|  |                         .map_err(crate::vector::Error::from) | ||||||
|  |                         .map_err(crate::Error::from)?, | ||||||
|  |                 ); | ||||||
|  |                 Ok((name, (embedder, prompt))) | ||||||
|  |             }) | ||||||
|  |             .collect(); | ||||||
|  |         res.map(EmbeddingConfigs::new) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     fn update_displayed(&mut self) -> Result<bool> { |     fn update_displayed(&mut self) -> Result<bool> { | ||||||
|         match self.displayed_fields { |         match self.displayed_fields { | ||||||
|             Setting::Set(ref fields) => { |             Setting::Set(ref fields) => { | ||||||
| @@ -890,6 +936,73 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|         Ok(changed) |         Ok(changed) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     fn update_embedding_configs(&mut self) -> Result<bool> { | ||||||
|  |         let update = match std::mem::take(&mut self.embedder_settings) { | ||||||
|  |             Setting::Set(configs) => { | ||||||
|  |                 let mut changed = false; | ||||||
|  |                 let old_configs = self.index.embedding_configs(self.wtxn)?; | ||||||
|  |                 let old_configs: BTreeMap<String, Setting<EmbeddingSettings>> = | ||||||
|  |                     old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect(); | ||||||
|  |  | ||||||
|  |                 let mut new_configs = BTreeMap::new(); | ||||||
|  |                 for joined in old_configs | ||||||
|  |                     .into_iter() | ||||||
|  |                     .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) | ||||||
|  |                 { | ||||||
|  |                     match joined { | ||||||
|  |                         EitherOrBoth::Both((name, mut old), (_, new)) => { | ||||||
|  |                             old.apply(new); | ||||||
|  |                             let new = validate_prompt(&name, old)?; | ||||||
|  |                             changed = true; | ||||||
|  |                             new_configs.insert(name, new); | ||||||
|  |                         } | ||||||
|  |                         EitherOrBoth::Left((name, setting)) => { | ||||||
|  |                             new_configs.insert(name, setting); | ||||||
|  |                         } | ||||||
|  |                         EitherOrBoth::Right((name, setting)) => { | ||||||
|  |                             let setting = validate_prompt(&name, setting)?; | ||||||
|  |                             changed = true; | ||||||
|  |                             new_configs.insert(name, setting); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 let new_configs: Vec<(String, EmbeddingConfig)> = new_configs | ||||||
|  |                     .into_iter() | ||||||
|  |                     .filter_map(|(name, setting)| match setting { | ||||||
|  |                         Setting::Set(value) => Some((name, value.into())), | ||||||
|  |                         Setting::Reset => None, | ||||||
|  |                         Setting::NotSet => Some((name, EmbeddingSettings::default().into())), | ||||||
|  |                     }) | ||||||
|  |                     .collect(); | ||||||
|  |  | ||||||
|  |                 self.index.embedder_category_id.clear(self.wtxn)?; | ||||||
|  |                 for (index, (embedder_name, _)) in new_configs.iter().enumerate() { | ||||||
|  |                     self.index.embedder_category_id.put_with_flags( | ||||||
|  |                         self.wtxn, | ||||||
|  |                         heed::PutFlags::APPEND, | ||||||
|  |                         embedder_name, | ||||||
|  |                         &index | ||||||
|  |                             .try_into() | ||||||
|  |                             .map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?, | ||||||
|  |                     )?; | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 if new_configs.is_empty() { | ||||||
|  |                     self.index.delete_embedding_configs(self.wtxn)?; | ||||||
|  |                 } else { | ||||||
|  |                     self.index.put_embedding_configs(self.wtxn, new_configs)?; | ||||||
|  |                 } | ||||||
|  |                 changed | ||||||
|  |             } | ||||||
|  |             Setting::Reset => { | ||||||
|  |                 self.index.delete_embedding_configs(self.wtxn)?; | ||||||
|  |                 true | ||||||
|  |             } | ||||||
|  |             Setting::NotSet => false, | ||||||
|  |         }; | ||||||
|  |         Ok(update) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()> |     pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()> | ||||||
|     where |     where | ||||||
|         FP: Fn(UpdateIndexingStep) + Sync, |         FP: Fn(UpdateIndexingStep) + Sync, | ||||||
| @@ -927,6 +1040,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|         let searchable_updated = self.update_searchable()?; |         let searchable_updated = self.update_searchable()?; | ||||||
|         let exact_attributes_updated = self.update_exact_attributes()?; |         let exact_attributes_updated = self.update_exact_attributes()?; | ||||||
|         let proximity_precision = self.update_proximity_precision()?; |         let proximity_precision = self.update_proximity_precision()?; | ||||||
|  |         // TODO: very rough approximation of the needs for reindexing where any change will result in | ||||||
|  |         // a full reindexing. | ||||||
|  |         // What can be done instead: | ||||||
|  |         // 1. Only change the distance on a distance change | ||||||
|  |         // 2. Only change the name -> embedder mapping on a name change | ||||||
|  |         // 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage | ||||||
|  |         let embedding_configs_updated = self.update_embedding_configs()?; | ||||||
|  |  | ||||||
|         if stop_words_updated |         if stop_words_updated | ||||||
|             || non_separator_tokens_updated |             || non_separator_tokens_updated | ||||||
| @@ -937,6 +1057,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|             || searchable_updated |             || searchable_updated | ||||||
|             || exact_attributes_updated |             || exact_attributes_updated | ||||||
|             || proximity_precision |             || proximity_precision | ||||||
|  |             || embedding_configs_updated | ||||||
|         { |         { | ||||||
|             self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?; |             self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?; | ||||||
|         } |         } | ||||||
| @@ -945,6 +1066,31 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | fn validate_prompt( | ||||||
|  |     name: &str, | ||||||
|  |     new: Setting<EmbeddingSettings>, | ||||||
|  | ) -> Result<Setting<EmbeddingSettings>> { | ||||||
|  |     match new { | ||||||
|  |         Setting::Set(EmbeddingSettings { | ||||||
|  |             embedder_options, | ||||||
|  |             document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }), | ||||||
|  |         }) => { | ||||||
|  |             // validate | ||||||
|  |             let template = crate::prompt::Prompt::new(template) | ||||||
|  |                 .map(|prompt| crate::prompt::PromptData::from(prompt).template) | ||||||
|  |                 .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; | ||||||
|  |  | ||||||
|  |             Ok(Setting::Set(EmbeddingSettings { | ||||||
|  |                 embedder_options, | ||||||
|  |                 document_template: Setting::Set(PromptSettings { | ||||||
|  |                     template: Setting::Set(template), | ||||||
|  |                 }), | ||||||
|  |             })) | ||||||
|  |         } | ||||||
|  |         new => Ok(new), | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests { | mod tests { | ||||||
|     use big_s::S; |     use big_s::S; | ||||||
| @@ -1763,6 +1909,7 @@ mod tests { | |||||||
|                     sort_facet_values_by, |                     sort_facet_values_by, | ||||||
|                     pagination_max_total_hits, |                     pagination_max_total_hits, | ||||||
|                     proximity_precision, |                     proximity_precision, | ||||||
|  |                     embedder_settings, | ||||||
|                 } = settings; |                 } = settings; | ||||||
|                 assert!(matches!(searchable_fields, Setting::NotSet)); |                 assert!(matches!(searchable_fields, Setting::NotSet)); | ||||||
|                 assert!(matches!(displayed_fields, Setting::NotSet)); |                 assert!(matches!(displayed_fields, Setting::NotSet)); | ||||||
| @@ -1785,6 +1932,7 @@ mod tests { | |||||||
|                 assert!(matches!(sort_facet_values_by, Setting::NotSet)); |                 assert!(matches!(sort_facet_values_by, Setting::NotSet)); | ||||||
|                 assert!(matches!(pagination_max_total_hits, Setting::NotSet)); |                 assert!(matches!(pagination_max_total_hits, Setting::NotSet)); | ||||||
|                 assert!(matches!(proximity_precision, Setting::NotSet)); |                 assert!(matches!(proximity_precision, Setting::NotSet)); | ||||||
|  |                 assert!(matches!(embedder_settings, Setting::NotSet)); | ||||||
|             }) |             }) | ||||||
|             .unwrap(); |             .unwrap(); | ||||||
|     } |     } | ||||||
|   | |||||||
							
								
								
									
										244
									
								
								milli/src/vector/error.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										244
									
								
								milli/src/vector/error.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,244 @@ | |||||||
|  | use std::path::PathBuf; | ||||||
|  |  | ||||||
|  | use hf_hub::api::sync::ApiError; | ||||||
|  |  | ||||||
|  | use crate::error::FaultSource; | ||||||
|  | use crate::vector::openai::OpenAiError; | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("Error while generating embeddings: {inner}")] | ||||||
|  | pub struct Error { | ||||||
|  |     pub inner: Box<ErrorKind>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<I: Into<ErrorKind>> From<I> for Error { | ||||||
|  |     fn from(value: I) -> Self { | ||||||
|  |         Self { inner: Box::new(value.into()) } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Error { | ||||||
|  |     pub fn fault(&self) -> FaultSource { | ||||||
|  |         match &*self.inner { | ||||||
|  |             ErrorKind::NewEmbedderError(inner) => inner.fault, | ||||||
|  |             ErrorKind::EmbedError(inner) => inner.fault, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | pub enum ErrorKind { | ||||||
|  |     #[error(transparent)] | ||||||
|  |     NewEmbedderError(#[from] NewEmbedderError), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     EmbedError(#[from] EmbedError), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("{fault}: {kind}")] | ||||||
|  | pub struct EmbedError { | ||||||
|  |     pub kind: EmbedErrorKind, | ||||||
|  |     pub fault: FaultSource, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | pub enum EmbedErrorKind { | ||||||
|  |     #[error("could not tokenize: {0}")] | ||||||
|  |     Tokenize(Box<dyn std::error::Error + Send + Sync>), | ||||||
|  |     #[error("unexpected tensor shape: {0}")] | ||||||
|  |     TensorShape(candle_core::Error), | ||||||
|  |     #[error("unexpected tensor value: {0}")] | ||||||
|  |     TensorValue(candle_core::Error), | ||||||
|  |     #[error("could not run model: {0}")] | ||||||
|  |     ModelForward(candle_core::Error), | ||||||
|  |     #[error("could not reach OpenAI: {0}")] | ||||||
|  |     OpenAiNetwork(reqwest::Error), | ||||||
|  |     #[error("unexpected response from OpenAI: {0}")] | ||||||
|  |     OpenAiUnexpected(reqwest::Error), | ||||||
|  |     #[error("could not authenticate against OpenAI: {0}")] | ||||||
|  |     OpenAiAuth(OpenAiError), | ||||||
|  |     #[error("sent too many requests to OpenAI: {0}")] | ||||||
|  |     OpenAiTooManyRequests(OpenAiError), | ||||||
|  |     #[error("received internal error from OpenAI: {0}")] | ||||||
|  |     OpenAiInternalServerError(OpenAiError), | ||||||
|  |     #[error("sent too many tokens in a request to OpenAI: {0}")] | ||||||
|  |     OpenAiTooManyTokens(OpenAiError), | ||||||
|  |     #[error("received unhandled HTTP status code {0} from OpenAI")] | ||||||
|  |     OpenAiUnhandledStatusCode(u16), | ||||||
|  |     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] | ||||||
|  |     ManualEmbed(String), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl EmbedError { | ||||||
|  |     pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self { | ||||||
|  |         Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn tensor_shape(inner: candle_core::Error) -> Self { | ||||||
|  |         Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn tensor_value(inner: candle_core::Error) -> Self { | ||||||
|  |         Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn model_forward(inner: candle_core::Error) -> Self { | ||||||
|  |         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn openai_network(inner: reqwest::Error) -> Self { | ||||||
|  |         Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { | ||||||
|  |         Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("{fault}: {kind}")] | ||||||
|  | pub struct NewEmbedderError { | ||||||
|  |     pub kind: NewEmbedderErrorKind, | ||||||
|  |     pub fault: FaultSource, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl NewEmbedderError { | ||||||
|  |     pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError { | ||||||
|  |         let open_config = OpenConfig { filename: config_filename, inner }; | ||||||
|  |  | ||||||
|  |         Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn deserialize_config( | ||||||
|  |         config: String, | ||||||
|  |         config_filename: PathBuf, | ||||||
|  |         inner: serde_json::Error, | ||||||
|  |     ) -> NewEmbedderError { | ||||||
|  |         let deserialize_config = DeserializeConfig { config, filename: config_filename, inner }; | ||||||
|  |         Self { | ||||||
|  |             kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config), | ||||||
|  |             fault: FaultSource::Runtime, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn open_tokenizer( | ||||||
|  |         tokenizer_filename: PathBuf, | ||||||
|  |         inner: Box<dyn std::error::Error + Send + Sync>, | ||||||
|  |     ) -> NewEmbedderError { | ||||||
|  |         let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner }; | ||||||
|  |         Self { | ||||||
|  |             kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer), | ||||||
|  |             fault: FaultSource::Runtime, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn new_api_fail(inner: ApiError) -> Self { | ||||||
|  |         Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn api_get(inner: ApiError) -> Self { | ||||||
|  |         Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn pytorch_weight(inner: candle_core::Error) -> Self { | ||||||
|  |         Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn safetensor_weight(inner: candle_core::Error) -> Self { | ||||||
|  |         Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn load_model(inner: candle_core::Error) -> Self { | ||||||
|  |         Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { | ||||||
|  |         Self { | ||||||
|  |             kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), | ||||||
|  |             fault: FaultSource::Runtime, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { | ||||||
|  |         Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { | ||||||
|  |         Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("could not open config at {filename:?}: {inner}")] | ||||||
|  | pub struct OpenConfig { | ||||||
|  |     pub filename: PathBuf, | ||||||
|  |     pub inner: std::io::Error, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")] | ||||||
|  | pub struct DeserializeConfig { | ||||||
|  |     pub config: String, | ||||||
|  |     pub filename: PathBuf, | ||||||
|  |     pub inner: serde_json::Error, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | #[error("could not open tokenizer at {filename}: {inner}")] | ||||||
|  | pub struct OpenTokenizer { | ||||||
|  |     pub filename: PathBuf, | ||||||
|  |     #[source] | ||||||
|  |     pub inner: Box<dyn std::error::Error + Send + Sync>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, thiserror::Error)] | ||||||
|  | pub enum NewEmbedderErrorKind { | ||||||
|  |     // hf | ||||||
|  |     #[error(transparent)] | ||||||
|  |     OpenConfig(OpenConfig), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     DeserializeConfig(DeserializeConfig), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     OpenTokenizer(OpenTokenizer), | ||||||
|  |     #[error("could not build weights from Pytorch weights: {0}")] | ||||||
|  |     PytorchWeight(candle_core::Error), | ||||||
|  |     #[error("could not build weights from Safetensor weights: {0}")] | ||||||
|  |     SafetensorWeight(candle_core::Error), | ||||||
|  |     #[error("could not spawn HG_HUB API client: {0}")] | ||||||
|  |     NewApiFail(ApiError), | ||||||
|  |     #[error("fetching file from HG_HUB failed: {0}")] | ||||||
|  |     ApiGet(ApiError), | ||||||
|  |     #[error("could not determine model dimensions: test embedding failed with {0}")] | ||||||
|  |     CouldNotDetermineDimension(EmbedError), | ||||||
|  |     #[error("loading model failed: {0}")] | ||||||
|  |     LoadModel(candle_core::Error), | ||||||
|  |     // openai | ||||||
|  |     #[error("initializing web client for sending embedding requests failed: {0}")] | ||||||
|  |     InitWebClient(reqwest::Error), | ||||||
|  |     #[error("The API key passed to Authorization error was in an invalid format: {0}")] | ||||||
|  |     InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), | ||||||
|  | } | ||||||
							
								
								
									
										195
									
								
								milli/src/vector/hf.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										195
									
								
								milli/src/vector/hf.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,195 @@ | |||||||
|  | use candle_core::Tensor; | ||||||
|  | use candle_nn::VarBuilder; | ||||||
|  | use candle_transformers::models::bert::{BertModel, Config, DTYPE}; | ||||||
|  | // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself | ||||||
|  | use hf_hub::api::sync::Api; | ||||||
|  | use hf_hub::{Repo, RepoType}; | ||||||
|  | use tokenizers::{PaddingParams, Tokenizer}; | ||||||
|  |  | ||||||
|  | pub use super::error::{EmbedError, Error, NewEmbedderError}; | ||||||
|  | use super::{DistributionShift, Embedding, Embeddings}; | ||||||
|  |  | ||||||
|  | #[derive( | ||||||
|  |     Debug, | ||||||
|  |     Clone, | ||||||
|  |     Copy, | ||||||
|  |     Default, | ||||||
|  |     Hash, | ||||||
|  |     PartialEq, | ||||||
|  |     Eq, | ||||||
|  |     serde::Deserialize, | ||||||
|  |     serde::Serialize, | ||||||
|  |     deserr::Deserr, | ||||||
|  | )] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | enum WeightSource { | ||||||
|  |     #[default] | ||||||
|  |     Safetensors, | ||||||
|  |     Pytorch, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] | ||||||
|  | pub struct EmbedderOptions { | ||||||
|  |     pub model: String, | ||||||
|  |     pub revision: Option<String>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl EmbedderOptions { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             model: "BAAI/bge-base-en-v1.5".to_string(), | ||||||
|  |             revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Default for EmbedderOptions { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         Self::new() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Perform embedding of documents and queries | ||||||
|  | pub struct Embedder { | ||||||
|  |     model: BertModel, | ||||||
|  |     tokenizer: Tokenizer, | ||||||
|  |     options: EmbedderOptions, | ||||||
|  |     dimensions: usize, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl std::fmt::Debug for Embedder { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Embedder") | ||||||
|  |             .field("model", &self.options.model) | ||||||
|  |             .field("tokenizer", &self.tokenizer) | ||||||
|  |             .field("options", &self.options) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Embedder { | ||||||
|  |     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { | ||||||
|  |         let device = candle_core::Device::Cpu; | ||||||
|  |         let repo = match options.revision.clone() { | ||||||
|  |             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), | ||||||
|  |             None => Repo::model(options.model.clone()), | ||||||
|  |         }; | ||||||
|  |         let (config_filename, tokenizer_filename, weights_filename, weight_source) = { | ||||||
|  |             let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; | ||||||
|  |             let api = api.repo(repo); | ||||||
|  |             let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; | ||||||
|  |             let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; | ||||||
|  |             let (weights, source) = { | ||||||
|  |                 api.get("pytorch_model.bin") | ||||||
|  |                     .map(|filename| (filename, WeightSource::Pytorch)) | ||||||
|  |                     .or_else(|_| { | ||||||
|  |                         api.get("model.safetensors") | ||||||
|  |                             .map(|filename| (filename, WeightSource::Safetensors)) | ||||||
|  |                     }) | ||||||
|  |                     .map_err(NewEmbedderError::api_get)? | ||||||
|  |             }; | ||||||
|  |             (config, tokenizer, weights, source) | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         let config = std::fs::read_to_string(&config_filename) | ||||||
|  |             .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; | ||||||
|  |         let config: Config = serde_json::from_str(&config).map_err(|inner| { | ||||||
|  |             NewEmbedderError::deserialize_config(config, config_filename, inner) | ||||||
|  |         })?; | ||||||
|  |         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) | ||||||
|  |             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; | ||||||
|  |  | ||||||
|  |         let vb = match weight_source { | ||||||
|  |             WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) | ||||||
|  |                 .map_err(NewEmbedderError::pytorch_weight)?, | ||||||
|  |             WeightSource::Safetensors => unsafe { | ||||||
|  |                 VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) | ||||||
|  |                     .map_err(NewEmbedderError::safetensor_weight)? | ||||||
|  |             }, | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; | ||||||
|  |  | ||||||
|  |         if let Some(pp) = tokenizer.get_padding_mut() { | ||||||
|  |             pp.strategy = tokenizers::PaddingStrategy::BatchLongest | ||||||
|  |         } else { | ||||||
|  |             let pp = PaddingParams { | ||||||
|  |                 strategy: tokenizers::PaddingStrategy::BatchLongest, | ||||||
|  |                 ..Default::default() | ||||||
|  |             }; | ||||||
|  |             tokenizer.with_padding(Some(pp)); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         let mut this = Self { model, tokenizer, options, dimensions: 0 }; | ||||||
|  |  | ||||||
|  |         let embeddings = this | ||||||
|  |             .embed(vec!["test".into()]) | ||||||
|  |             .map_err(NewEmbedderError::hf_could_not_determine_dimension)?; | ||||||
|  |         this.dimensions = embeddings.first().unwrap().dimension(); | ||||||
|  |  | ||||||
|  |         Ok(this) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn embed( | ||||||
|  |         &self, | ||||||
|  |         mut texts: Vec<String>, | ||||||
|  |     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { | ||||||
|  |         let tokens = match texts.len() { | ||||||
|  |             1 => vec![self | ||||||
|  |                 .tokenizer | ||||||
|  |                 .encode(texts.pop().unwrap(), true) | ||||||
|  |                 .map_err(EmbedError::tokenize)?], | ||||||
|  |             _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?, | ||||||
|  |         }; | ||||||
|  |         let token_ids = tokens | ||||||
|  |             .iter() | ||||||
|  |             .map(|tokens| { | ||||||
|  |                 let tokens = tokens.get_ids().to_vec(); | ||||||
|  |                 Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) | ||||||
|  |             }) | ||||||
|  |             .collect::<Result<Vec<_>, EmbedError>>()?; | ||||||
|  |  | ||||||
|  |         let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; | ||||||
|  |         let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; | ||||||
|  |         let embeddings = | ||||||
|  |             self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; | ||||||
|  |  | ||||||
|  |         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) | ||||||
|  |         let (_n_sentence, n_tokens, _hidden_size) = | ||||||
|  |             embeddings.dims3().map_err(EmbedError::tensor_shape)?; | ||||||
|  |  | ||||||
|  |         let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) | ||||||
|  |             .map_err(EmbedError::tensor_shape)?; | ||||||
|  |  | ||||||
|  |         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; | ||||||
|  |         Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn embed_chunks( | ||||||
|  |         &self, | ||||||
|  |         text_chunks: Vec<Vec<String>>, | ||||||
|  |     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||||
|  |         text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn chunk_count_hint(&self) -> usize { | ||||||
|  |         1 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||||
|  |         std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn dimensions(&self) -> usize { | ||||||
|  |         self.dimensions | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn distribution(&self) -> Option<DistributionShift> { | ||||||
|  |         if self.options.model == "BAAI/bge-base-en-v1.5" { | ||||||
|  |             Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 }) | ||||||
|  |         } else { | ||||||
|  |             None | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										34
									
								
								milli/src/vector/manual.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								milli/src/vector/manual.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | |||||||
|  | use super::error::EmbedError; | ||||||
|  | use super::Embeddings; | ||||||
|  |  | ||||||
/// Embedder that never computes embeddings itself: it only records a dimension count.
///
/// Calling `embed` with at least one text returns an error (see `embed`).
#[derive(Debug, Clone, Copy)]
pub struct Embedder {
    /// Dimension count reported by `dimensions()`.
    dimensions: usize,
}
|  |  | ||||||
|  | #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] | ||||||
|  | pub struct EmbedderOptions { | ||||||
|  |     pub dimensions: usize, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Embedder { | ||||||
|  |     pub fn new(options: EmbedderOptions) -> Self { | ||||||
|  |         Self { dimensions: options.dimensions } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||||
|  |         let Some(text) = texts.pop() else { return Ok(Default::default()) }; | ||||||
|  |         Err(EmbedError::embed_on_manual_embedder(text)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn dimensions(&self) -> usize { | ||||||
|  |         self.dimensions | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn embed_chunks( | ||||||
|  |         &self, | ||||||
|  |         text_chunks: Vec<Vec<String>>, | ||||||
|  |     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||||
|  |         text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										257
									
								
								milli/src/vector/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										257
									
								
								milli/src/vector/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,257 @@ | |||||||
|  | use std::collections::HashMap; | ||||||
|  | use std::sync::Arc; | ||||||
|  |  | ||||||
|  | use self::error::{EmbedError, NewEmbedderError}; | ||||||
|  | use crate::prompt::{Prompt, PromptData}; | ||||||
|  |  | ||||||
|  | pub mod error; | ||||||
|  | pub mod hf; | ||||||
|  | pub mod manual; | ||||||
|  | pub mod openai; | ||||||
|  | pub mod settings; | ||||||
|  |  | ||||||
|  | pub use self::error::Error; | ||||||
|  |  | ||||||
|  | pub type Embedding = Vec<f32>; | ||||||
|  |  | ||||||
/// Flat storage for zero or more embeddings that all share the same dimension.
///
/// The values of all embeddings are stored back to back in a single `Vec<F>`; the length of
/// that vector is always a multiple of `dimension`. When `dimension` is `0` the storage is
/// always empty.
pub struct Embeddings<F> {
    data: Vec<F>,
    dimension: usize,
}

impl<F> Embeddings<F> {
    /// Creates an empty container for embeddings of the given dimension.
    pub fn new(dimension: usize) -> Self {
        Self { data: Default::default(), dimension }
    }

    /// Wraps a single embedding; its length becomes the container's dimension.
    pub fn from_single_embedding(embedding: Vec<F>) -> Self {
        Self { dimension: embedding.len(), data: embedding }
    }

    /// Builds a container from already-flattened data.
    ///
    /// # Errors
    ///
    /// Gives `data` back when its length is not a multiple of `dimension`.
    pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
        let mut this = Self::new(dimension);
        this.append(data)?;
        Ok(this)
    }

    /// Number of embeddings currently stored (`0` when the dimension is `0`).
    pub fn embedding_count(&self) -> usize {
        // `checked_div` avoids the division-by-zero panic for a zero dimension.
        self.data.len().checked_div(self.dimension).unwrap_or(0)
    }

    /// Dimension shared by every stored embedding.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Consumes the container and returns the flattened values.
    pub fn into_inner(self) -> Vec<F> {
        self.data
    }

    /// Borrows the flattened values.
    pub fn as_inner(&self) -> &[F] {
        &self.data
    }

    /// Iterates over the stored embeddings, one `dimension`-long slice at a time.
    pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
        // `chunks_exact(0)` panics. With a zero dimension the storage is empty, so any
        // non-zero chunk size yields the expected empty iterator.
        self.data.as_slice().chunks_exact(self.dimension.max(1))
    }

    /// Adds exactly one embedding.
    ///
    /// # Errors
    ///
    /// Gives `embedding` back when its length differs from the container's dimension.
    pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
        if embedding.len() != self.dimension {
            return Err(embedding);
        }
        self.data.append(&mut embedding);
        Ok(())
    }

    /// Adds any number of flattened embeddings.
    ///
    /// # Errors
    ///
    /// Gives `embeddings` back when its length is not a multiple of the container's
    /// dimension. With a zero dimension only an empty vector is accepted.
    pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
        let is_whole_number_of_embeddings = match self.dimension {
            // `% 0` would panic: only "nothing" is a multiple of a zero dimension.
            0 => embeddings.is_empty(),
            dimension => embeddings.len() % dimension == 0,
        };
        if !is_whole_number_of_embeddings {
            return Err(embeddings);
        }
        self.data.append(&mut embeddings);
        Ok(())
    }
}
|  |  | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub enum Embedder { | ||||||
|  |     HuggingFace(hf::Embedder), | ||||||
|  |     OpenAi(openai::Embedder), | ||||||
|  |     UserProvided(manual::Embedder), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] | ||||||
|  | pub struct EmbeddingConfig { | ||||||
|  |     pub embedder_options: EmbedderOptions, | ||||||
|  |     pub prompt: PromptData, | ||||||
|  |     // TODO: add metrics and anything needed | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Clone, Default)] | ||||||
|  | pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>); | ||||||
|  |  | ||||||
|  | impl EmbeddingConfigs { | ||||||
|  |     pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self { | ||||||
|  |         Self(data) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> { | ||||||
|  |         self.0.get(name).cloned() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> { | ||||||
|  |         self.get_default_embedder_name().and_then(|default| self.get(&default)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn get_default_embedder_name(&self) -> Option<String> { | ||||||
|  |         let mut it = self.0.keys(); | ||||||
|  |         let first_name = it.next(); | ||||||
|  |         let second_name = it.next(); | ||||||
|  |         match (first_name, second_name) { | ||||||
|  |             (None, _) => None, | ||||||
|  |             (Some(first), None) => Some(first.to_owned()), | ||||||
|  |             (Some(_), Some(_)) => Some("default".to_owned()), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl IntoIterator for EmbeddingConfigs { | ||||||
|  |     type Item = (String, (Arc<Embedder>, Arc<Prompt>)); | ||||||
|  |  | ||||||
|  |     type IntoIter = std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>)>; | ||||||
|  |  | ||||||
|  |     fn into_iter(self) -> Self::IntoIter { | ||||||
|  |         self.0.into_iter() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] | ||||||
|  | pub enum EmbedderOptions { | ||||||
|  |     HuggingFace(hf::EmbedderOptions), | ||||||
|  |     OpenAi(openai::EmbedderOptions), | ||||||
|  |     UserProvided(manual::EmbedderOptions), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Default for EmbedderOptions { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         Self::HuggingFace(Default::default()) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl EmbedderOptions { | ||||||
|  |     pub fn huggingface() -> Self { | ||||||
|  |         Self::HuggingFace(hf::EmbedderOptions::new()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn openai(api_key: Option<String>) -> Self { | ||||||
|  |         Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Embedder { | ||||||
|  |     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { | ||||||
|  |         Ok(match options { | ||||||
|  |             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), | ||||||
|  |             EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), | ||||||
|  |             EmbedderOptions::UserProvided(options) => { | ||||||
|  |                 Self::UserProvided(manual::Embedder::new(options)) | ||||||
|  |             } | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub async fn embed( | ||||||
|  |         &self, | ||||||
|  |         texts: Vec<String>, | ||||||
|  |     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { | ||||||
|  |         match self { | ||||||
|  |             Embedder::HuggingFace(embedder) => embedder.embed(texts), | ||||||
|  |             Embedder::OpenAi(embedder) => embedder.embed(texts).await, | ||||||
|  |             Embedder::UserProvided(embedder) => embedder.embed(texts), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub async fn embed_chunks( | ||||||
|  |         &self, | ||||||
|  |         text_chunks: Vec<Vec<String>>, | ||||||
|  |     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||||
|  |         match self { | ||||||
|  |             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), | ||||||
|  |             Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, | ||||||
|  |             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn chunk_count_hint(&self) -> usize { | ||||||
|  |         match self { | ||||||
|  |             Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), | ||||||
|  |             Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), | ||||||
|  |             Embedder::UserProvided(_) => 1, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||||
|  |         match self { | ||||||
|  |             Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), | ||||||
|  |             Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), | ||||||
|  |             Embedder::UserProvided(_) => 1, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn dimensions(&self) -> usize { | ||||||
|  |         match self { | ||||||
|  |             Embedder::HuggingFace(embedder) => embedder.dimensions(), | ||||||
|  |             Embedder::OpenAi(embedder) => embedder.dimensions(), | ||||||
|  |             Embedder::UserProvided(embedder) => embedder.dimensions(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn distribution(&self) -> Option<DistributionShift> { | ||||||
|  |         match self { | ||||||
|  |             Embedder::HuggingFace(embedder) => embedder.distribution(), | ||||||
|  |             Embedder::OpenAi(embedder) => embedder.distribution(), | ||||||
|  |             Embedder::UserProvided(_embedder) => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
/// Mean and sigma of a score distribution, used to remap scores with [`DistributionShift::shift`].
#[derive(Debug, Clone, Copy)]
pub struct DistributionShift {
    /// Mean of the native score distribution.
    pub current_mean: f32,
    /// Sigma of the native score distribution.
    pub current_sigma: f32,
}

impl DistributionShift {
    /// Builds a shift from a mean and a sigma.
    ///
    /// `None` if sigma <= 0.
    pub fn new(mean: f32, sigma: f32) -> Option<Self> {
        if sigma <= 0.0 {
            return None;
        }
        Some(Self { current_mean: mean, current_sigma: sigma })
    }

    /// Remaps `score` from the native distribution to the target one.
    ///
    /// See <https://math.stackexchange.com/a/2894689>. The distribution of distances is
    /// (somewhat abusively) treated as a gaussian described by `current_mean` and
    /// `current_sigma`, and retargeted to a gaussian centered on 0.5 with a sigma of 0.4.
    /// The result is then clamped in the ]0, 1] interval.
    pub fn shift(&self, score: f32) -> f32 {
        const TARGET_MEAN: f32 = 0.5;
        const TARGET_SIGMA: f32 = 0.4;

        // a^2 sig1^2 = sig2^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
        let factor = TARGET_SIGMA / self.current_sigma;
        // a*mu1 + b = mu2 => b = mu2 - a*mu1
        let offset = TARGET_MEAN - (factor * self.current_mean);

        let shifted = factor * score + offset;

        // Non-positive scores become the smallest representable step above zero,
        // scores above one are capped at one.
        if shifted <= 0.0 {
            f32::EPSILON
        } else if shifted > 1.0 {
            1.0
        } else {
            shifted
        }
    }
}
							
								
								
									
										445
									
								
								milli/src/vector/openai.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										445
									
								
								milli/src/vector/openai.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,445 @@ | |||||||
|  | use std::fmt::Display; | ||||||
|  |  | ||||||
|  | use reqwest::StatusCode; | ||||||
|  | use serde::{Deserialize, Serialize}; | ||||||
|  |  | ||||||
|  | use super::error::{EmbedError, NewEmbedderError}; | ||||||
|  | use super::{DistributionShift, Embedding, Embeddings}; | ||||||
|  |  | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct Embedder { | ||||||
|  |     client: reqwest::Client, | ||||||
|  |     tokenizer: tiktoken_rs::CoreBPE, | ||||||
|  |     options: EmbedderOptions, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] | ||||||
|  | pub struct EmbedderOptions { | ||||||
|  |     pub api_key: Option<String>, | ||||||
|  |     pub embedding_model: EmbeddingModel, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive( | ||||||
|  |     Debug, | ||||||
|  |     Clone, | ||||||
|  |     Copy, | ||||||
|  |     Default, | ||||||
|  |     Hash, | ||||||
|  |     PartialEq, | ||||||
|  |     Eq, | ||||||
|  |     serde::Serialize, | ||||||
|  |     serde::Deserialize, | ||||||
|  |     deserr::Deserr, | ||||||
|  | )] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | pub enum EmbeddingModel { | ||||||
|  |     #[default] | ||||||
|  |     #[serde(rename = "text-embedding-ada-002")] | ||||||
|  |     #[deserr(rename = "text-embedding-ada-002")] | ||||||
|  |     TextEmbeddingAda002, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl EmbeddingModel { | ||||||
|  |     pub fn max_token(&self) -> usize { | ||||||
|  |         match self { | ||||||
|  |             EmbeddingModel::TextEmbeddingAda002 => 8191, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn dimensions(&self) -> usize { | ||||||
|  |         match self { | ||||||
|  |             EmbeddingModel::TextEmbeddingAda002 => 1536, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn name(&self) -> &'static str { | ||||||
|  |         match self { | ||||||
|  |             EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn from_name(name: &'static str) -> Option<Self> { | ||||||
|  |         match name { | ||||||
|  |             "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn distribution(&self) -> Option<DistributionShift> { | ||||||
|  |         match self { | ||||||
|  |             EmbeddingModel::TextEmbeddingAda002 => { | ||||||
|  |                 Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
/// URL that every embedding request of this module is POSTed to.
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|  |  | ||||||
|  | impl EmbedderOptions { | ||||||
|  |     pub fn with_default_model(api_key: Option<String>) -> Self { | ||||||
|  |         Self { api_key, embedding_model: Default::default() } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self { | ||||||
|  |         Self { api_key, embedding_model } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Embedder { | ||||||
|  |     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||||
|  |         let mut headers = reqwest::header::HeaderMap::new(); | ||||||
|  |         let mut inferred_api_key = Default::default(); | ||||||
|  |         let api_key = options.api_key.as_ref().unwrap_or_else(|| { | ||||||
|  |             inferred_api_key = infer_api_key(); | ||||||
|  |             &inferred_api_key | ||||||
|  |         }); | ||||||
|  |         headers.insert( | ||||||
|  |             reqwest::header::AUTHORIZATION, | ||||||
|  |             reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) | ||||||
|  |                 .map_err(NewEmbedderError::openai_invalid_api_key_format)?, | ||||||
|  |         ); | ||||||
|  |         headers.insert( | ||||||
|  |             reqwest::header::CONTENT_TYPE, | ||||||
|  |             reqwest::header::HeaderValue::from_static("application/json"), | ||||||
|  |         ); | ||||||
|  |         let client = reqwest::ClientBuilder::new() | ||||||
|  |             .default_headers(headers) | ||||||
|  |             .build() | ||||||
|  |             .map_err(NewEmbedderError::openai_initialize_web_client)?; | ||||||
|  |  | ||||||
|  |         // looking at the code it is very unclear that this can actually fail. | ||||||
|  |         let tokenizer = tiktoken_rs::cl100k_base().unwrap(); | ||||||
|  |  | ||||||
|  |         Ok(Self { options, client, tokenizer }) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||||
|  |         let mut tokenized = false; | ||||||
|  |  | ||||||
|  |         for attempt in 0..7 { | ||||||
|  |             let result = if tokenized { | ||||||
|  |                 self.try_embed_tokenized(&texts).await | ||||||
|  |             } else { | ||||||
|  |                 self.try_embed(&texts).await | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |             let retry_duration = match result { | ||||||
|  |                 Ok(embeddings) => return Ok(embeddings), | ||||||
|  |                 Err(retry) => { | ||||||
|  |                     log::warn!("Failed: {}", retry.error); | ||||||
|  |                     tokenized |= retry.must_tokenize(); | ||||||
|  |                     retry.into_duration(attempt) | ||||||
|  |                 } | ||||||
|  |             }?; | ||||||
|  |             log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis()); | ||||||
|  |             tokio::time::sleep(retry_duration).await; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         let result = if tokenized { | ||||||
|  |             self.try_embed_tokenized(&texts).await | ||||||
|  |         } else { | ||||||
|  |             self.try_embed(&texts).await | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         result.map_err(Retry::into_error) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> { | ||||||
|  |         if !response.status().is_success() { | ||||||
|  |             match response.status() { | ||||||
|  |                 StatusCode::UNAUTHORIZED => { | ||||||
|  |                     let error_response: OpenAiErrorResponse = response | ||||||
|  |                         .json() | ||||||
|  |                         .await | ||||||
|  |                         .map_err(EmbedError::openai_unexpected) | ||||||
|  |                         .map_err(Retry::retry_later)?; | ||||||
|  |  | ||||||
|  |                     return Err(Retry::give_up(EmbedError::openai_auth_error( | ||||||
|  |                         error_response.error, | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |                 StatusCode::TOO_MANY_REQUESTS => { | ||||||
|  |                     let error_response: OpenAiErrorResponse = response | ||||||
|  |                         .json() | ||||||
|  |                         .await | ||||||
|  |                         .map_err(EmbedError::openai_unexpected) | ||||||
|  |                         .map_err(Retry::retry_later)?; | ||||||
|  |  | ||||||
|  |                     return Err(Retry::rate_limited(EmbedError::openai_too_many_requests( | ||||||
|  |                         error_response.error, | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |                 StatusCode::INTERNAL_SERVER_ERROR => { | ||||||
|  |                     let error_response: OpenAiErrorResponse = response | ||||||
|  |                         .json() | ||||||
|  |                         .await | ||||||
|  |                         .map_err(EmbedError::openai_unexpected) | ||||||
|  |                         .map_err(Retry::retry_later)?; | ||||||
|  |                     return Err(Retry::retry_later(EmbedError::openai_internal_server_error( | ||||||
|  |                         error_response.error, | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |                 StatusCode::SERVICE_UNAVAILABLE => { | ||||||
|  |                     let error_response: OpenAiErrorResponse = response | ||||||
|  |                         .json() | ||||||
|  |                         .await | ||||||
|  |                         .map_err(EmbedError::openai_unexpected) | ||||||
|  |                         .map_err(Retry::retry_later)?; | ||||||
|  |                     return Err(Retry::retry_later(EmbedError::openai_internal_server_error( | ||||||
|  |                         error_response.error, | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |                 StatusCode::BAD_REQUEST => { | ||||||
|  |                     // Most probably, one text contained too many tokens | ||||||
|  |                     let error_response: OpenAiErrorResponse = response | ||||||
|  |                         .json() | ||||||
|  |                         .await | ||||||
|  |                         .map_err(EmbedError::openai_unexpected) | ||||||
|  |                         .map_err(Retry::retry_later)?; | ||||||
|  |  | ||||||
|  |                     log::warn!("OpenAI: input was too long, retrying on tokenized version. For best performance, limit the size of your prompt."); | ||||||
|  |  | ||||||
|  |                     return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens( | ||||||
|  |                         error_response.error, | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |                 code => { | ||||||
|  |                     return Err(Retry::give_up(EmbedError::openai_unhandled_status_code( | ||||||
|  |                         code.as_u16(), | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         Ok(response) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn try_embed<S: AsRef<str> + serde::Serialize>( | ||||||
|  |         &self, | ||||||
|  |         texts: &[S], | ||||||
|  |     ) -> Result<Vec<Embeddings<f32>>, Retry> { | ||||||
|  |         for text in texts { | ||||||
|  |             log::trace!("Received prompt: {}", text.as_ref()) | ||||||
|  |         } | ||||||
|  |         let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; | ||||||
|  |         let response = self | ||||||
|  |             .client | ||||||
|  |             .post(OPENAI_EMBEDDINGS_URL) | ||||||
|  |             .json(&request) | ||||||
|  |             .send() | ||||||
|  |             .await | ||||||
|  |             .map_err(EmbedError::openai_network) | ||||||
|  |             .map_err(Retry::retry_later)?; | ||||||
|  |  | ||||||
|  |         let response = Self::check_response(response).await?; | ||||||
|  |  | ||||||
|  |         let response: OpenAiResponse = response | ||||||
|  |             .json() | ||||||
|  |             .await | ||||||
|  |             .map_err(EmbedError::openai_unexpected) | ||||||
|  |             .map_err(Retry::retry_later)?; | ||||||
|  |  | ||||||
|  |         log::trace!("response: {:?}", response.data); | ||||||
|  |  | ||||||
|  |         Ok(response | ||||||
|  |             .data | ||||||
|  |             .into_iter() | ||||||
|  |             .map(|data| Embeddings::from_single_embedding(data.embedding)) | ||||||
|  |             .collect()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> { | ||||||
|  |         pub const OVERLAP_SIZE: usize = 200; | ||||||
|  |         let mut all_embeddings = Vec::with_capacity(text.len()); | ||||||
|  |         for text in text { | ||||||
|  |             let max_token_count = self.options.embedding_model.max_token(); | ||||||
|  |             let encoded = self.tokenizer.encode_ordinary(text.as_str()); | ||||||
|  |             let len = encoded.len(); | ||||||
|  |             if len < max_token_count { | ||||||
|  |                 all_embeddings.append(&mut self.try_embed(&[text]).await?); | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             let mut tokens = encoded.as_slice(); | ||||||
|  |             let mut embeddings_for_prompt = | ||||||
|  |                 Embeddings::new(self.options.embedding_model.dimensions()); | ||||||
|  |             while tokens.len() > max_token_count { | ||||||
|  |                 let window = &tokens[..max_token_count]; | ||||||
|  |                 embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap(); | ||||||
|  |  | ||||||
|  |                 tokens = &tokens[max_token_count - OVERLAP_SIZE..]; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             // end of text | ||||||
|  |             embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap(); | ||||||
|  |  | ||||||
|  |             all_embeddings.push(embeddings_for_prompt); | ||||||
|  |         } | ||||||
|  |         Ok(all_embeddings) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> { | ||||||
|  |         for attempt in 0..9 { | ||||||
|  |             let duration = match self.try_embed_tokens(tokens).await { | ||||||
|  |                 Ok(embedding) => return Ok(embedding), | ||||||
|  |                 Err(retry) => retry.into_duration(attempt), | ||||||
|  |             } | ||||||
|  |             .map_err(Retry::retry_later)?; | ||||||
|  |  | ||||||
|  |             tokio::time::sleep(duration).await; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error())) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> { | ||||||
|  |         let request = | ||||||
|  |             OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; | ||||||
|  |         let response = self | ||||||
|  |             .client | ||||||
|  |             .post(OPENAI_EMBEDDINGS_URL) | ||||||
|  |             .json(&request) | ||||||
|  |             .send() | ||||||
|  |             .await | ||||||
|  |             .map_err(EmbedError::openai_network) | ||||||
|  |             .map_err(Retry::retry_later)?; | ||||||
|  |  | ||||||
|  |         let response = Self::check_response(response).await?; | ||||||
|  |  | ||||||
|  |         let mut response: OpenAiResponse = response | ||||||
|  |             .json() | ||||||
|  |             .await | ||||||
|  |             .map_err(EmbedError::openai_unexpected) | ||||||
|  |             .map_err(Retry::retry_later)?; | ||||||
|  |         Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub async fn embed_chunks( | ||||||
|  |         &self, | ||||||
|  |         text_chunks: Vec<Vec<String>>, | ||||||
|  |     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||||
|  |         futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) | ||||||
|  |             .await | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn chunk_count_hint(&self) -> usize { | ||||||
|  |         10 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||||
|  |         10 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn dimensions(&self) -> usize { | ||||||
|  |         self.options.embedding_model.dimensions() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn distribution(&self) -> Option<DistributionShift> { | ||||||
|  |         self.options.embedding_model.distribution() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // retrying in case of failure | ||||||
|  |  | ||||||
|  | struct Retry { | ||||||
|  |     error: EmbedError, | ||||||
|  |     strategy: RetryStrategy, | ||||||
|  | } | ||||||
|  |  | ||||||
/// How a failed attempt should be followed up (see `Retry::into_duration`).
enum RetryStrategy {
    /// Do not retry: the error is returned to the caller.
    GiveUp,
    /// Retry after a delay of `10^attempt` milliseconds.
    Retry,
    /// Retry after 1 millisecond; also reported by `Retry::must_tokenize`.
    RetryTokenized,
    /// Retry after a delay of `100 + 10^attempt` milliseconds.
    RetryAfterRateLimit,
}
|  |  | ||||||
|  | impl Retry { | ||||||
|  |     fn give_up(error: EmbedError) -> Self { | ||||||
|  |         Self { error, strategy: RetryStrategy::GiveUp } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn retry_later(error: EmbedError) -> Self { | ||||||
|  |         Self { error, strategy: RetryStrategy::Retry } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn retry_tokenized(error: EmbedError) -> Self { | ||||||
|  |         Self { error, strategy: RetryStrategy::RetryTokenized } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn rate_limited(error: EmbedError) -> Self { | ||||||
|  |         Self { error, strategy: RetryStrategy::RetryAfterRateLimit } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> { | ||||||
|  |         match self.strategy { | ||||||
|  |             RetryStrategy::GiveUp => Err(self.error), | ||||||
|  |             RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))), | ||||||
|  |             RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), | ||||||
|  |             RetryStrategy::RetryAfterRateLimit => { | ||||||
|  |                 Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt))) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn must_tokenize(&self) -> bool { | ||||||
|  |         matches!(self.strategy, RetryStrategy::RetryTokenized) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn into_error(self) -> EmbedError { | ||||||
|  |         self.error | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // openai api structs | ||||||
|  |  | ||||||
|  | #[derive(Debug, Serialize)] | ||||||
|  | struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { | ||||||
|  |     model: &'a str, | ||||||
|  |     input: &'a [S], | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Serialize)] | ||||||
|  | struct OpenAiTokensRequest<'a> { | ||||||
|  |     model: &'a str, | ||||||
|  |     input: &'a [usize], | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Deserialize)] | ||||||
|  | struct OpenAiResponse { | ||||||
|  |     data: Vec<OpenAiEmbedding>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Deserialize)] | ||||||
|  | struct OpenAiErrorResponse { | ||||||
|  |     error: OpenAiError, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Deserialize)] | ||||||
|  | pub struct OpenAiError { | ||||||
|  |     message: String, | ||||||
|  |     // type: String, | ||||||
|  |     code: Option<String>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Display for OpenAiError { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         match &self.code { | ||||||
|  |             Some(code) => write!(f, "{} ({})", self.message, code), | ||||||
|  |             None => write!(f, "{}", self.message), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Deserialize)] | ||||||
|  | struct OpenAiEmbedding { | ||||||
|  |     embedding: Embedding, | ||||||
|  |     // object: String, | ||||||
|  |     // index: usize, | ||||||
|  | } | ||||||
|  |  | ||||||
/// Looks the API key up in the environment.
///
/// `MEILI_OPENAI_API_KEY` takes precedence over `OPENAI_API_KEY`; when neither variable can
/// be read, an empty string is returned.
fn infer_api_key() -> String {
    const CANDIDATES: [&str; 2] = ["MEILI_OPENAI_API_KEY", "OPENAI_API_KEY"];
    CANDIDATES.iter().find_map(|name| std::env::var(name).ok()).unwrap_or_default()
}
							
								
								
									
										292
									
								
								milli/src/vector/settings.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										292
									
								
								milli/src/vector/settings.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,292 @@ | |||||||
|  | use deserr::Deserr; | ||||||
|  | use serde::{Deserialize, Serialize}; | ||||||
|  |  | ||||||
|  | use crate::prompt::PromptData; | ||||||
|  | use crate::update::Setting; | ||||||
|  | use crate::vector::EmbeddingConfig; | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | pub struct EmbeddingSettings { | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")] | ||||||
|  |     #[deserr(default, rename = "source")] | ||||||
|  |     pub embedder_options: Setting<EmbedderSettings>, | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||||
|  |     #[deserr(default)] | ||||||
|  |     pub document_template: Setting<PromptSettings>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl EmbeddingSettings { | ||||||
|  |     pub fn apply(&mut self, new: Self) { | ||||||
|  |         let EmbeddingSettings { embedder_options, document_template: prompt } = new; | ||||||
|  |         self.embedder_options.apply(embedder_options); | ||||||
|  |         self.document_template.apply(prompt); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<EmbeddingConfig> for EmbeddingSettings { | ||||||
|  |     fn from(value: EmbeddingConfig) -> Self { | ||||||
|  |         Self { | ||||||
|  |             embedder_options: Setting::Set(value.embedder_options.into()), | ||||||
|  |             document_template: Setting::Set(value.prompt.into()), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<EmbeddingSettings> for EmbeddingConfig { | ||||||
|  |     fn from(value: EmbeddingSettings) -> Self { | ||||||
|  |         let mut this = Self::default(); | ||||||
|  |         let EmbeddingSettings { embedder_options, document_template: prompt } = value; | ||||||
|  |         if let Some(embedder_options) = embedder_options.set() { | ||||||
|  |             this.embedder_options = embedder_options.into(); | ||||||
|  |         } | ||||||
|  |         if let Some(prompt) = prompt.set() { | ||||||
|  |             this.prompt = prompt.into(); | ||||||
|  |         } | ||||||
|  |         this | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | pub struct PromptSettings { | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||||
|  |     #[deserr(default)] | ||||||
|  |     pub template: Setting<String>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl PromptSettings { | ||||||
|  |     pub fn apply(&mut self, new: Self) { | ||||||
|  |         let PromptSettings { template } = new; | ||||||
|  |         self.template.apply(template); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<PromptData> for PromptSettings { | ||||||
|  |     fn from(value: PromptData) -> Self { | ||||||
|  |         Self { template: Setting::Set(value.template) } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<PromptSettings> for PromptData { | ||||||
|  |     fn from(value: PromptSettings) -> Self { | ||||||
|  |         let mut this = PromptData::default(); | ||||||
|  |         let PromptSettings { template } = value; | ||||||
|  |         if let Some(template) = template.set() { | ||||||
|  |             this.template = template; | ||||||
|  |         } | ||||||
|  |         this | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | pub enum EmbedderSettings { | ||||||
|  |     HuggingFace(Setting<HfEmbedderSettings>), | ||||||
|  |     OpenAi(Setting<OpenAiEmbedderSettings>), | ||||||
|  |     UserProvided(UserProvidedSettings), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<E> Deserr<E> for EmbedderSettings | ||||||
|  | where | ||||||
|  |     E: deserr::DeserializeError, | ||||||
|  | { | ||||||
|  |     fn deserialize_from_value<V: deserr::IntoValue>( | ||||||
|  |         value: deserr::Value<V>, | ||||||
|  |         location: deserr::ValuePointerRef, | ||||||
|  |     ) -> Result<Self, E> { | ||||||
|  |         match value { | ||||||
|  |             deserr::Value::Map(map) => { | ||||||
|  |                 if deserr::Map::len(&map) != 1 { | ||||||
|  |                     return Err(deserr::take_cf_content(E::error::<V>( | ||||||
|  |                         None, | ||||||
|  |                         deserr::ErrorKind::Unexpected { | ||||||
|  |                             msg: format!( | ||||||
|  |                                 "Expected a single field, got {} fields", | ||||||
|  |                                 deserr::Map::len(&map) | ||||||
|  |                             ), | ||||||
|  |                         }, | ||||||
|  |                         location, | ||||||
|  |                     ))); | ||||||
|  |                 } | ||||||
|  |                 let mut it = deserr::Map::into_iter(map); | ||||||
|  |                 let (k, v) = it.next().unwrap(); | ||||||
|  |  | ||||||
|  |                 match k.as_str() { | ||||||
|  |                     "huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set( | ||||||
|  |                         HfEmbedderSettings::deserialize_from_value( | ||||||
|  |                             v.into_value(), | ||||||
|  |                             location.push_key(&k), | ||||||
|  |                         )?, | ||||||
|  |                     ))), | ||||||
|  |                     "openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set( | ||||||
|  |                         OpenAiEmbedderSettings::deserialize_from_value( | ||||||
|  |                             v.into_value(), | ||||||
|  |                             location.push_key(&k), | ||||||
|  |                         )?, | ||||||
|  |                     ))), | ||||||
|  |                     "userProvided" => Ok(EmbedderSettings::UserProvided( | ||||||
|  |                         UserProvidedSettings::deserialize_from_value( | ||||||
|  |                             v.into_value(), | ||||||
|  |                             location.push_key(&k), | ||||||
|  |                         )?, | ||||||
|  |                     )), | ||||||
|  |                     other => Err(deserr::take_cf_content(E::error::<V>( | ||||||
|  |                         None, | ||||||
|  |                         deserr::ErrorKind::UnknownKey { | ||||||
|  |                             key: other, | ||||||
|  |                             accepted: &["huggingFace", "openAi", "userProvided"], | ||||||
|  |                         }, | ||||||
|  |                         location, | ||||||
|  |                     ))), | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             _ => Err(deserr::take_cf_content(E::error::<V>( | ||||||
|  |                 None, | ||||||
|  |                 deserr::ErrorKind::IncorrectValueKind { | ||||||
|  |                     actual: value, | ||||||
|  |                     accepted: &[deserr::ValueKind::Map], | ||||||
|  |                 }, | ||||||
|  |                 location, | ||||||
|  |             ))), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Default for EmbedderSettings { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         Self::OpenAi(Default::default()) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<crate::vector::EmbedderOptions> for EmbedderSettings { | ||||||
|  |     fn from(value: crate::vector::EmbedderOptions) -> Self { | ||||||
|  |         match value { | ||||||
|  |             crate::vector::EmbedderOptions::HuggingFace(hf) => { | ||||||
|  |                 Self::HuggingFace(Setting::Set(hf.into())) | ||||||
|  |             } | ||||||
|  |             crate::vector::EmbedderOptions::OpenAi(openai) => { | ||||||
|  |                 Self::OpenAi(Setting::Set(openai.into())) | ||||||
|  |             } | ||||||
|  |             crate::vector::EmbedderOptions::UserProvided(user_provided) => { | ||||||
|  |                 Self::UserProvided(user_provided.into()) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<EmbedderSettings> for crate::vector::EmbedderOptions { | ||||||
|  |     fn from(value: EmbedderSettings) -> Self { | ||||||
|  |         match value { | ||||||
|  |             EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), | ||||||
|  |             EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), | ||||||
|  |             EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), | ||||||
|  |             EmbedderSettings::OpenAi(_setting) => { | ||||||
|  |                 Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None)) | ||||||
|  |             } | ||||||
|  |             EmbedderSettings::UserProvided(user_provided) => { | ||||||
|  |                 Self::UserProvided(user_provided.into()) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | pub struct HfEmbedderSettings { | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||||
|  |     #[deserr(default)] | ||||||
|  |     pub model: Setting<String>, | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||||
|  |     #[deserr(default)] | ||||||
|  |     pub revision: Setting<String>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl HfEmbedderSettings { | ||||||
|  |     pub fn apply(&mut self, new: Self) { | ||||||
|  |         let HfEmbedderSettings { model, revision } = new; | ||||||
|  |         self.model.apply(model); | ||||||
|  |         self.revision.apply(revision); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings { | ||||||
|  |     fn from(value: crate::vector::hf::EmbedderOptions) -> Self { | ||||||
|  |         Self { | ||||||
|  |             model: Setting::Set(value.model), | ||||||
|  |             revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions { | ||||||
|  |     fn from(value: HfEmbedderSettings) -> Self { | ||||||
|  |         let HfEmbedderSettings { model, revision } = value; | ||||||
|  |         let mut this = Self::default(); | ||||||
|  |         if let Some(model) = model.set() { | ||||||
|  |             this.model = model; | ||||||
|  |         } | ||||||
|  |         if let Some(revision) = revision.set() { | ||||||
|  |             this.revision = Some(revision); | ||||||
|  |         } | ||||||
|  |         this | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | pub struct OpenAiEmbedderSettings { | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||||
|  |     #[deserr(default)] | ||||||
|  |     pub api_key: Setting<String>, | ||||||
|  |     #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")] | ||||||
|  |     #[deserr(default, rename = "model")] | ||||||
|  |     pub embedding_model: Setting<crate::vector::openai::EmbeddingModel>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl OpenAiEmbedderSettings { | ||||||
|  |     pub fn apply(&mut self, new: Self) { | ||||||
|  |         let Self { api_key, embedding_model: embedding_mode } = new; | ||||||
|  |         self.api_key.apply(api_key); | ||||||
|  |         self.embedding_model.apply(embedding_mode); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings { | ||||||
|  |     fn from(value: crate::vector::openai::EmbedderOptions) -> Self { | ||||||
|  |         Self { | ||||||
|  |             api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset), | ||||||
|  |             embedding_model: Setting::Set(value.embedding_model), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions { | ||||||
|  |     fn from(value: OpenAiEmbedderSettings) -> Self { | ||||||
|  |         let OpenAiEmbedderSettings { api_key, embedding_model } = value; | ||||||
|  |         Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||||
|  | #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||||
|  | #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||||
|  | pub struct UserProvidedSettings { | ||||||
|  |     pub dimensions: usize, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<UserProvidedSettings> for crate::vector::manual::EmbedderOptions { | ||||||
|  |     fn from(value: UserProvidedSettings) -> Self { | ||||||
|  |         Self { dimensions: value.dimensions } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<crate::vector::manual::EmbedderOptions> for UserProvidedSettings { | ||||||
|  |     fn from(value: crate::vector::manual::EmbedderOptions) -> Self { | ||||||
|  |         Self { dimensions: value.dimensions } | ||||||
|  |     } | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user