mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 13:06:27 +00:00 
			
		
		
		
	Merge pull request #5716 from meilisearch/document-sorting
Allow sorting on the /documents route
This commit is contained in:
		
							
								
								
									
										6
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -11,12 +11,18 @@ | ||||
| /bench | ||||
| /_xtask_benchmark.ms | ||||
| /benchmarks | ||||
| .DS_Store | ||||
|  | ||||
| # Snapshots | ||||
| ## ... large | ||||
| *.full.snap | ||||
| ## ... unreviewed | ||||
| *.snap.new | ||||
| ## ... pending | ||||
| *.pending-snap | ||||
|  | ||||
| # Tmp files | ||||
| .tmp* | ||||
|  | ||||
| # Database snapshot | ||||
| crates/meilisearch/db.snapshot | ||||
|   | ||||
| @@ -51,3 +51,8 @@ harness = false | ||||
| [[bench]] | ||||
| name = "indexing" | ||||
| harness = false | ||||
|  | ||||
| [[bench]] | ||||
| name = "sort" | ||||
| harness = false | ||||
|  | ||||
|   | ||||
							
								
								
									
										114
									
								
								crates/benchmarks/benches/sort.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								crates/benchmarks/benches/sort.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | ||||
| //! This benchmark module is used to compare the performance of sorting documents in /search VS /documents | ||||
| //! | ||||
| //! The tests/benchmarks were designed in the context of a query returning only 20 documents. | ||||
|  | ||||
| mod datasets_paths; | ||||
| mod utils; | ||||
|  | ||||
| use criterion::{criterion_group, criterion_main}; | ||||
| use milli::update::Settings; | ||||
| use utils::Conf; | ||||
|  | ||||
| #[cfg(not(windows))] | ||||
| #[global_allocator] | ||||
| static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; | ||||
|  | ||||
| /// Applies the index settings shared by every sort benchmark to `builder`. | ||||
| /// | ||||
| /// The displayed fields are registered first, then the sortable fields; both lists | ||||
| /// are turned into owned `String`s before being handed to the settings builder. | ||||
| fn base_conf(builder: &mut Settings) { | ||||
|     const DISPLAYED: [&str; 6] = | ||||
|         ["geonameid", "name", "asciiname", "alternatenames", "_geo", "population"]; | ||||
|     const SORTABLE: [&str; 6] = | ||||
|         ["_geo", "name", "population", "elevation", "timezone", "modification-date"]; | ||||
|  | ||||
|     builder.set_displayed_fields(DISPLAYED.iter().copied().map(String::from).collect()); | ||||
|     builder.set_sortable_fields(SORTABLE.iter().copied().map(String::from).collect()); | ||||
| } | ||||
|  | ||||
| /// Configuration common to all sort benchmark groups; each group only overrides | ||||
| /// the fields it needs (name, sort criteria, sample size) via struct update syntax. | ||||
| #[rustfmt::skip] | ||||
| const BASE_CONF: Conf = Conf { | ||||
|     primary_key: Some("geonameid"), | ||||
|     dataset: datasets_paths::SMOL_ALL_COUNTRIES, | ||||
|     dataset_format: "jsonl", | ||||
|     configure: base_conf, | ||||
|     // Also bench fetching documents directly, without going through a search query. | ||||
|     get_documents: true, | ||||
|     // A single placeholder (empty) query: only sorting and pagination vary. | ||||
|     queries: &[""], | ||||
|     // Each entry is an `(offset, limit)` pair. | ||||
|     offsets: &[ | ||||
|         Some((0, 20)),      // most common real-world request | ||||
|         Some((0, 500)),     // request ranging over many documents | ||||
|         Some((980, 20)),    // worst request likely to happen in the real world | ||||
|         Some((800_000, 20)) // worst request overall | ||||
|     ], | ||||
|     ..Conf::BASE | ||||
| }; | ||||
|  | ||||
| /// Declares every sort scenario and hands them to `utils::run_benches`. | ||||
| /// | ||||
| /// All scenarios start from `BASE_CONF` and override only the group name, the sort | ||||
| /// criteria and, for the slower geo-based sorts, the sample size. | ||||
| /// "Similar" values refer to `timezone` and "different" values to `name`. | ||||
| fn bench_sort(c: &mut criterion::Criterion) { | ||||
|     #[rustfmt::skip] | ||||
|     let confs = &[ | ||||
|         // Baseline: no sort criteria at all. | ||||
|         Conf { | ||||
|             group_name: "without sort", | ||||
|             sort: None, | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         Conf { | ||||
|             group_name: "sort on many different values", | ||||
|             sort: Some(vec!["name:asc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         Conf { | ||||
|             group_name: "sort on many similar values", | ||||
|             sort: Some(vec!["timezone:desc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         Conf { | ||||
|             group_name: "sort on many similar then different values", | ||||
|             sort: Some(vec!["timezone:desc", "name:asc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         // The criteria order is the reverse of the previous group: `name` (different | ||||
|         // values) must come first, otherwise both groups measure the same sort. | ||||
|         Conf { | ||||
|             group_name: "sort on many different then similar values", | ||||
|             sort: Some(vec!["name:asc", "timezone:desc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         Conf { | ||||
|             group_name: "geo sort", | ||||
|             sample_size: Some(10), | ||||
|             sort: Some(vec!["_geoPoint(45.4777599, 9.1967508):asc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         Conf { | ||||
|             group_name: "sort on many similar values then geo sort", | ||||
|             sample_size: Some(50), | ||||
|             sort: Some(vec!["timezone:desc", "_geoPoint(45.4777599, 9.1967508):asc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         Conf { | ||||
|             group_name: "sort on many different values then geo sort", | ||||
|             sample_size: Some(50), | ||||
|             sort: Some(vec!["name:desc", "_geoPoint(45.4777599, 9.1967508):asc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|  | ||||
|         Conf { | ||||
|             group_name: "sort on many fields", | ||||
|             sort: Some(vec!["population:asc", "name:asc", "elevation:asc", "timezone:asc"]), | ||||
|             ..BASE_CONF | ||||
|         }, | ||||
|     ]; | ||||
|  | ||||
|     utils::run_benches(c, confs); | ||||
| } | ||||
|  | ||||
| criterion_group!(benches, bench_sort); | ||||
| criterion_main!(benches); | ||||
| @@ -9,6 +9,7 @@ use anyhow::Context; | ||||
| use bumpalo::Bump; | ||||
| use criterion::BenchmarkId; | ||||
| use memmap2::Mmap; | ||||
| use milli::documents::sort::recursive_sort; | ||||
| use milli::heed::EnvOpenOptions; | ||||
| use milli::progress::Progress; | ||||
| use milli::update::new::indexer; | ||||
| @@ -35,6 +36,12 @@ pub struct Conf<'a> { | ||||
|     pub configure: fn(&mut Settings), | ||||
|     pub filter: Option<&'a str>, | ||||
|     pub sort: Option<Vec<&'a str>>, | ||||
|     /// set to skip documents (offset, limit) | ||||
|     pub offsets: &'a [Option<(usize, usize)>], | ||||
|     /// enable if you want to bench getting documents without querying | ||||
|     pub get_documents: bool, | ||||
|     /// configure the benchmark sample size | ||||
|     pub sample_size: Option<usize>, | ||||
|     /// enable or disable the optional words on the query | ||||
|     pub optional_words: bool, | ||||
|     /// primary key; if it is `None`, docids are auto-generated for every document | ||||
| @@ -52,6 +59,9 @@ impl Conf<'_> { | ||||
|         configure: |_| (), | ||||
|         filter: None, | ||||
|         sort: None, | ||||
|         offsets: &[None], | ||||
|         get_documents: false, | ||||
|         sample_size: None, | ||||
|         optional_words: true, | ||||
|         primary_key: None, | ||||
|     }; | ||||
| @@ -145,13 +155,26 @@ pub fn run_benches(c: &mut criterion::Criterion, confs: &[Conf]) { | ||||
|         let file_name = Path::new(conf.dataset).file_name().and_then(|f| f.to_str()).unwrap(); | ||||
|         let name = format!("{}: {}", file_name, conf.group_name); | ||||
|         let mut group = c.benchmark_group(&name); | ||||
|         if let Some(sample_size) = conf.sample_size { | ||||
|             group.sample_size(sample_size); | ||||
|         } | ||||
|  | ||||
|         for &query in conf.queries { | ||||
|             group.bench_with_input(BenchmarkId::from_parameter(query), &query, |b, &query| { | ||||
|             for offset in conf.offsets { | ||||
|                 let parameter = match offset { | ||||
|                     None => query.to_string(), | ||||
|                     Some((offset, limit)) => format!("{query}[{offset}:{limit}]"), | ||||
|                 }; | ||||
|                 group.bench_with_input( | ||||
|                     BenchmarkId::from_parameter(parameter), | ||||
|                     &query, | ||||
|                     |b, &query| { | ||||
|                         b.iter(|| { | ||||
|                             let rtxn = index.read_txn().unwrap(); | ||||
|                             let mut search = index.search(&rtxn); | ||||
|                     search.query(query).terms_matching_strategy(TermsMatchingStrategy::default()); | ||||
|                             search | ||||
|                                 .query(query) | ||||
|                                 .terms_matching_strategy(TermsMatchingStrategy::default()); | ||||
|                             if let Some(filter) = conf.filter { | ||||
|                                 let filter = Filter::from_str(filter).unwrap().unwrap(); | ||||
|                                 search.filter(filter); | ||||
| @@ -160,10 +183,51 @@ pub fn run_benches(c: &mut criterion::Criterion, confs: &[Conf]) { | ||||
|                                 let sort = sort.iter().map(|sort| sort.parse().unwrap()).collect(); | ||||
|                                 search.sort_criteria(sort); | ||||
|                             } | ||||
|                             if let Some((offset, limit)) = offset { | ||||
|                                 search.offset(*offset).limit(*limit); | ||||
|                             } | ||||
|  | ||||
|                             let _ids = search.execute().unwrap(); | ||||
|                         }); | ||||
|                     }, | ||||
|                 ); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if conf.get_documents { | ||||
|             for offset in conf.offsets { | ||||
|                 let parameter = match offset { | ||||
|                     None => String::from("get_documents"), | ||||
|                     Some((offset, limit)) => format!("get_documents[{offset}:{limit}]"), | ||||
|                 }; | ||||
|                 group.bench_with_input(BenchmarkId::from_parameter(parameter), &(), |b, &()| { | ||||
|                     b.iter(|| { | ||||
|                         let rtxn = index.read_txn().unwrap(); | ||||
|                         if let Some(sort) = &conf.sort { | ||||
|                             let sort = sort.iter().map(|sort| sort.parse().unwrap()).collect(); | ||||
|                             let all_docs = index.documents_ids(&rtxn).unwrap(); | ||||
|                             let facet_sort = | ||||
|                                 recursive_sort(&index, &rtxn, sort, &all_docs).unwrap(); | ||||
|                             let iter = facet_sort.iter().unwrap(); | ||||
|                             if let Some((offset, limit)) = offset { | ||||
|                                 let _results = iter.skip(*offset).take(*limit).collect::<Vec<_>>(); | ||||
|                             } else { | ||||
|                                 let _results = iter.collect::<Vec<_>>(); | ||||
|                             } | ||||
|                         } else { | ||||
|                             let all_docs = index.documents_ids(&rtxn).unwrap(); | ||||
|                             if let Some((offset, limit)) = offset { | ||||
|                                 let _results = | ||||
|                                     all_docs.iter().skip(*offset).take(*limit).collect::<Vec<_>>(); | ||||
|                             } else { | ||||
|                                 let _results = all_docs.iter().collect::<Vec<_>>(); | ||||
|                             } | ||||
|                         } | ||||
|                     }); | ||||
|                 }); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         group.finish(); | ||||
|  | ||||
|         index.prepare_for_closing().wait(); | ||||
|   | ||||
| @@ -237,6 +237,7 @@ InvalidDocumentRetrieveVectors                 , InvalidRequest       , BAD_REQU | ||||
| MissingDocumentFilter                          , InvalidRequest       , BAD_REQUEST ; | ||||
| MissingDocumentEditionFunction                 , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidDocumentFilter                          , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidDocumentSort                            , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidDocumentGeoField                        , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidVectorDimensions                        , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidVectorsType                             , InvalidRequest       , BAD_REQUEST ; | ||||
| @@ -477,7 +478,8 @@ impl ErrorCode for milli::Error { | ||||
|                     UserError::InvalidDistinctAttribute { .. } => Code::InvalidSearchDistinct, | ||||
|                     UserError::SortRankingRuleMissing => Code::InvalidSearchSort, | ||||
|                     UserError::InvalidFacetsDistribution { .. } => Code::InvalidSearchFacets, | ||||
|                     UserError::InvalidSortableAttribute { .. } => Code::InvalidSearchSort, | ||||
|                     UserError::InvalidSearchSortableAttribute { .. } => Code::InvalidSearchSort, | ||||
|                     UserError::InvalidDocumentSortableAttribute { .. } => Code::InvalidDocumentSort, | ||||
|                     UserError::InvalidSearchableAttribute { .. } => { | ||||
|                         Code::InvalidSearchAttributesToSearchOn | ||||
|                     } | ||||
| @@ -493,7 +495,8 @@ impl ErrorCode for milli::Error { | ||||
|                     UserError::InvalidVectorsMapType { .. } | ||||
|                     | UserError::InvalidVectorsEmbedderConf { .. } => Code::InvalidVectorsType, | ||||
|                     UserError::TooManyVectors(_, _) => Code::TooManyVectors, | ||||
|                     UserError::SortError(_) => Code::InvalidSearchSort, | ||||
|                     UserError::SortError { search: true, .. } => Code::InvalidSearchSort, | ||||
|                     UserError::SortError { search: false, .. } => Code::InvalidDocumentSort, | ||||
|                     UserError::InvalidMinTypoWordLenSetting(_, _) => { | ||||
|                         Code::InvalidSettingsTypoTolerance | ||||
|                     } | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								crates/meilisearch/db.snapshot
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								crates/meilisearch/db.snapshot
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| @@ -104,6 +104,4 @@ impl Analytics for MockAnalytics { | ||||
|         _request: &HttpRequest, | ||||
|     ) { | ||||
|     } | ||||
|     fn get_fetch_documents(&self, _documents_query: &DocumentFetchKind, _request: &HttpRequest) {} | ||||
|     fn post_fetch_documents(&self, _documents_query: &DocumentFetchKind, _request: &HttpRequest) {} | ||||
| } | ||||
|   | ||||
| @@ -73,12 +73,6 @@ pub enum DocumentDeletionKind { | ||||
|     PerFilter, | ||||
| } | ||||
|  | ||||
| #[derive(Copy, Clone, Debug, PartialEq, Eq)] | ||||
| pub enum DocumentFetchKind { | ||||
|     PerDocumentId { retrieve_vectors: bool }, | ||||
|     Normal { with_filter: bool, limit: usize, offset: usize, retrieve_vectors: bool }, | ||||
| } | ||||
|  | ||||
| /// To send an event to segment, your event must be able to aggregate itself with another event of the same type. | ||||
| pub trait Aggregate: 'static + mopa::Any + Send { | ||||
|     /// The name of the event that will be sent to segment. | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| use std::collections::HashSet; | ||||
| use std::io::{ErrorKind, Seek as _}; | ||||
| use std::marker::PhantomData; | ||||
| use std::str::FromStr; | ||||
|  | ||||
| use actix_web::http::header::CONTENT_TYPE; | ||||
| use actix_web::web::Data; | ||||
| @@ -17,9 +18,10 @@ use meilisearch_types::error::deserr_codes::*; | ||||
| use meilisearch_types::error::{Code, ResponseError}; | ||||
| use meilisearch_types::heed::RoTxn; | ||||
| use meilisearch_types::index_uid::IndexUid; | ||||
| use meilisearch_types::milli::documents::sort::recursive_sort; | ||||
| use meilisearch_types::milli::update::IndexDocumentsMethod; | ||||
| use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors; | ||||
| use meilisearch_types::milli::DocumentId; | ||||
| use meilisearch_types::milli::{AscDesc, DocumentId}; | ||||
| use meilisearch_types::serde_cs::vec::CS; | ||||
| use meilisearch_types::star_or::OptionStarOrList; | ||||
| use meilisearch_types::tasks::KindWithContent; | ||||
| @@ -42,6 +44,7 @@ use crate::extractors::authentication::policies::*; | ||||
| use crate::extractors::authentication::GuardedData; | ||||
| use crate::extractors::payload::Payload; | ||||
| use crate::extractors::sequential_extractor::SeqHandler; | ||||
| use crate::routes::indexes::search::fix_sort_query_parameters; | ||||
| use crate::routes::{ | ||||
|     get_task_id, is_dry_run, PaginationView, SummarizedTaskView, PAGINATION_DEFAULT_LIMIT, | ||||
| }; | ||||
| @@ -135,6 +138,8 @@ pub struct DocumentsFetchAggregator<Method: AggregateMethod> { | ||||
|     per_document_id: bool, | ||||
|     // if a filter was used | ||||
|     per_filter: bool, | ||||
|     // if documents were sorted | ||||
|     sort: bool, | ||||
|  | ||||
|     #[serde(rename = "vector.retrieve_vectors")] | ||||
|     retrieve_vectors: bool, | ||||
| @@ -151,39 +156,6 @@ pub struct DocumentsFetchAggregator<Method: AggregateMethod> { | ||||
|     marker: std::marker::PhantomData<Method>, | ||||
| } | ||||
|  | ||||
| #[derive(Copy, Clone, Debug, PartialEq, Eq)] | ||||
| pub enum DocumentFetchKind { | ||||
|     PerDocumentId { retrieve_vectors: bool }, | ||||
|     Normal { with_filter: bool, limit: usize, offset: usize, retrieve_vectors: bool, ids: usize }, | ||||
| } | ||||
|  | ||||
| impl<Method: AggregateMethod> DocumentsFetchAggregator<Method> { | ||||
|     pub fn from_query(query: &DocumentFetchKind) -> Self { | ||||
|         let (limit, offset, retrieve_vectors) = match query { | ||||
|             DocumentFetchKind::PerDocumentId { retrieve_vectors } => (1, 0, *retrieve_vectors), | ||||
|             DocumentFetchKind::Normal { limit, offset, retrieve_vectors, .. } => { | ||||
|                 (*limit, *offset, *retrieve_vectors) | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         let ids = match query { | ||||
|             DocumentFetchKind::Normal { ids, .. } => *ids, | ||||
|             DocumentFetchKind::PerDocumentId { .. } => 0, | ||||
|         }; | ||||
|  | ||||
|         Self { | ||||
|             per_document_id: matches!(query, DocumentFetchKind::PerDocumentId { .. }), | ||||
|             per_filter: matches!(query, DocumentFetchKind::Normal { with_filter, .. } if *with_filter), | ||||
|             max_limit: limit, | ||||
|             max_offset: offset, | ||||
|             retrieve_vectors, | ||||
|             max_document_ids: ids, | ||||
|  | ||||
|             marker: PhantomData, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<Method: AggregateMethod> Aggregate for DocumentsFetchAggregator<Method> { | ||||
|     fn event_name(&self) -> &'static str { | ||||
|         Method::event_name() | ||||
| @@ -193,6 +165,7 @@ impl<Method: AggregateMethod> Aggregate for DocumentsFetchAggregator<Method> { | ||||
|         Box::new(Self { | ||||
|             per_document_id: self.per_document_id | new.per_document_id, | ||||
|             per_filter: self.per_filter | new.per_filter, | ||||
|             sort: self.sort | new.sort, | ||||
|             retrieve_vectors: self.retrieve_vectors | new.retrieve_vectors, | ||||
|             max_limit: self.max_limit.max(new.max_limit), | ||||
|             max_offset: self.max_offset.max(new.max_offset), | ||||
| @@ -276,6 +249,7 @@ pub async fn get_document( | ||||
|             retrieve_vectors: param_retrieve_vectors.0, | ||||
|             per_document_id: true, | ||||
|             per_filter: false, | ||||
|             sort: false, | ||||
|             max_limit: 0, | ||||
|             max_offset: 0, | ||||
|             max_document_ids: 0, | ||||
| @@ -406,6 +380,8 @@ pub struct BrowseQueryGet { | ||||
|     #[param(default, value_type = Option<String>, example = "popularity > 1000")] | ||||
|     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFilter>)] | ||||
|     filter: Option<String>, | ||||
|     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentSort>)] | ||||
|     sort: Option<String>, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Deserr, ToSchema)] | ||||
| @@ -430,6 +406,9 @@ pub struct BrowseQuery { | ||||
|     #[schema(default, value_type = Option<Value>, example = "popularity > 1000")] | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidDocumentFilter>)] | ||||
|     filter: Option<Value>, | ||||
|     #[schema(default, value_type = Option<Vec<String>>, example = json!(["title:asc", "rating:desc"]))] | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidDocumentSort>)] | ||||
|     sort: Option<Vec<String>>, | ||||
| } | ||||
|  | ||||
| /// Get documents with POST | ||||
| @@ -495,6 +474,7 @@ pub async fn documents_by_query_post( | ||||
|     analytics.publish( | ||||
|         DocumentsFetchAggregator::<DocumentsPOST> { | ||||
|             per_filter: body.filter.is_some(), | ||||
|             sort: body.sort.is_some(), | ||||
|             retrieve_vectors: body.retrieve_vectors, | ||||
|             max_limit: body.limit, | ||||
|             max_offset: body.offset, | ||||
| @@ -571,7 +551,7 @@ pub async fn get_documents( | ||||
| ) -> Result<HttpResponse, ResponseError> { | ||||
|     debug!(parameters = ?params, "Get documents GET"); | ||||
|  | ||||
|     let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids } = | ||||
|     let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids, sort } = | ||||
|         params.into_inner(); | ||||
|  | ||||
|     let filter = match filter { | ||||
| @@ -582,20 +562,20 @@ pub async fn get_documents( | ||||
|         None => None, | ||||
|     }; | ||||
|  | ||||
|     let ids = ids.map(|ids| ids.into_iter().map(Into::into).collect()); | ||||
|  | ||||
|     let query = BrowseQuery { | ||||
|         offset: offset.0, | ||||
|         limit: limit.0, | ||||
|         fields: fields.merge_star_and_none(), | ||||
|         retrieve_vectors: retrieve_vectors.0, | ||||
|         filter, | ||||
|         ids, | ||||
|         ids: ids.map(|ids| ids.into_iter().map(Into::into).collect()), | ||||
|         sort: sort.map(|attr| fix_sort_query_parameters(&attr)), | ||||
|     }; | ||||
|  | ||||
|     analytics.publish( | ||||
|         DocumentsFetchAggregator::<DocumentsGET> { | ||||
|             per_filter: query.filter.is_some(), | ||||
|             sort: query.sort.is_some(), | ||||
|             retrieve_vectors: query.retrieve_vectors, | ||||
|             max_limit: query.limit, | ||||
|             max_offset: query.offset, | ||||
| @@ -615,7 +595,7 @@ fn documents_by_query( | ||||
|     query: BrowseQuery, | ||||
| ) -> Result<HttpResponse, ResponseError> { | ||||
|     let index_uid = IndexUid::try_from(index_uid.into_inner())?; | ||||
|     let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids } = query; | ||||
|     let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids, sort } = query; | ||||
|  | ||||
|     let retrieve_vectors = RetrieveVectors::new(retrieve_vectors); | ||||
|  | ||||
| @@ -633,6 +613,18 @@ fn documents_by_query( | ||||
|         None | ||||
|     }; | ||||
|  | ||||
|     let sort_criteria = if let Some(sort) = &sort { | ||||
|         let sorts: Vec<_> = match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() { | ||||
|             Ok(sorts) => sorts, | ||||
|             Err(asc_desc_error) => { | ||||
|                 return Err(milli::SortError::from(asc_desc_error).into_document_error().into()) | ||||
|             } | ||||
|         }; | ||||
|         Some(sorts) | ||||
|     } else { | ||||
|         None | ||||
|     }; | ||||
|  | ||||
|     let index = index_scheduler.index(&index_uid)?; | ||||
|     let (total, documents) = retrieve_documents( | ||||
|         &index, | ||||
| @@ -643,6 +635,7 @@ fn documents_by_query( | ||||
|         fields, | ||||
|         retrieve_vectors, | ||||
|         index_scheduler.features(), | ||||
|         sort_criteria, | ||||
|     )?; | ||||
|  | ||||
|     let ret = PaginationView::new(offset, limit, total as usize, documents); | ||||
| @@ -1494,6 +1487,7 @@ fn retrieve_documents<S: AsRef<str>>( | ||||
|     attributes_to_retrieve: Option<Vec<S>>, | ||||
|     retrieve_vectors: RetrieveVectors, | ||||
|     features: RoFeatures, | ||||
|     sort_criteria: Option<Vec<AscDesc>>, | ||||
| ) -> Result<(u64, Vec<Document>), ResponseError> { | ||||
|     let rtxn = index.read_txn()?; | ||||
|     let filter = &filter; | ||||
| @@ -1526,15 +1520,32 @@ fn retrieve_documents<S: AsRef<str>>( | ||||
|         })? | ||||
|     } | ||||
|  | ||||
|     let (it, number_of_documents) = { | ||||
|     let (it, number_of_documents) = if let Some(sort) = sort_criteria { | ||||
|         let number_of_documents = candidates.len(); | ||||
|         let facet_sort = recursive_sort(index, &rtxn, sort, &candidates)?; | ||||
|         let iter = facet_sort.iter()?; | ||||
|         let mut documents = Vec::with_capacity(limit); | ||||
|         for result in iter.skip(offset).take(limit) { | ||||
|             documents.push(result?); | ||||
|         } | ||||
|         ( | ||||
|             itertools::Either::Left(some_documents( | ||||
|                 index, | ||||
|                 &rtxn, | ||||
|                 documents.into_iter(), | ||||
|                 retrieve_vectors, | ||||
|             )?), | ||||
|             number_of_documents, | ||||
|         ) | ||||
|     } else { | ||||
|         let number_of_documents = candidates.len(); | ||||
|         ( | ||||
|             some_documents( | ||||
|             itertools::Either::Right(some_documents( | ||||
|                 index, | ||||
|                 &rtxn, | ||||
|                 candidates.into_iter().skip(offset).take(limit), | ||||
|                 retrieve_vectors, | ||||
|             )?, | ||||
|             )?), | ||||
|             number_of_documents, | ||||
|         ) | ||||
|     }; | ||||
|   | ||||
| @@ -745,9 +745,8 @@ impl SearchByIndex { | ||||
|                         match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() { | ||||
|                             Ok(sorts) => sorts, | ||||
|                             Err(asc_desc_error) => { | ||||
|                                 return Err(milli::Error::from(milli::SortError::from( | ||||
|                                     asc_desc_error, | ||||
|                                 )) | ||||
|                                 return Err(milli::SortError::from(asc_desc_error) | ||||
|                                     .into_search_error() | ||||
|                                     .into()) | ||||
|                             } | ||||
|                         }; | ||||
|   | ||||
| @@ -1092,7 +1092,7 @@ pub fn prepare_search<'t>( | ||||
|         let sort = match sort.iter().map(|s| AscDesc::from_str(s)).collect() { | ||||
|             Ok(sorts) => sorts, | ||||
|             Err(asc_desc_error) => { | ||||
|                 return Err(milli::Error::from(SortError::from(asc_desc_error)).into()) | ||||
|                 return Err(SortError::from(asc_desc_error).into_search_error().into()) | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|   | ||||
| @@ -562,5 +562,7 @@ pub struct GetAllDocumentsOptions { | ||||
|     pub offset: Option<usize>, | ||||
|     #[serde(skip_serializing_if = "Option::is_none")] | ||||
|     pub fields: Option<Vec<&'static str>>, | ||||
|     #[serde(skip_serializing_if = "Option::is_none")] | ||||
|     pub sort: Option<Vec<&'static str>>, | ||||
|     pub retrieve_vectors: bool, | ||||
| } | ||||
|   | ||||
| @@ -5,8 +5,8 @@ use urlencoding::encode as urlencode; | ||||
|  | ||||
| use crate::common::encoder::Encoder; | ||||
| use crate::common::{ | ||||
|     shared_does_not_exists_index, shared_empty_index, shared_index_with_test_set, | ||||
|     GetAllDocumentsOptions, Server, Value, | ||||
|     shared_does_not_exists_index, shared_empty_index, shared_index_with_geo_documents, | ||||
|     shared_index_with_test_set, GetAllDocumentsOptions, Server, Value, | ||||
| }; | ||||
| use crate::json; | ||||
|  | ||||
| @@ -83,6 +83,311 @@ async fn get_document() { | ||||
|     ); | ||||
| } | ||||
|  | ||||
| #[actix_rt::test] | ||||
| async fn get_document_sorted() { | ||||
|     let server = Server::new_shared(); | ||||
|     let index = server.unique_index(); | ||||
|     index.load_test_set().await; | ||||
|  | ||||
|     let (task, _status_code) = | ||||
|         index.update_settings_sortable_attributes(json!(["age", "email", "gender", "name"])).await; | ||||
|     server.wait_task(task.uid()).await.succeeded(); | ||||
|  | ||||
|     let (response, _code) = index | ||||
|         .get_all_documents(GetAllDocumentsOptions { | ||||
|             fields: Some(vec!["id", "age", "email"]), | ||||
|             sort: Some(vec!["age:asc", "email:desc"]), | ||||
|             ..Default::default() | ||||
|         }) | ||||
|         .await; | ||||
|     let results = response["results"].as_array().unwrap(); | ||||
|     snapshot!(json_string!(results), @r#" | ||||
|     [ | ||||
|       { | ||||
|         "id": 5, | ||||
|         "age": 20, | ||||
|         "email": "warrenwatson@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 6, | ||||
|         "age": 20, | ||||
|         "email": "sheliaberry@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 57, | ||||
|         "age": 20, | ||||
|         "email": "kaitlinconner@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 45, | ||||
|         "age": 20, | ||||
|         "email": "irenebennett@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 40, | ||||
|         "age": 21, | ||||
|         "email": "staffordemerson@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 41, | ||||
|         "age": 21, | ||||
|         "email": "salinasgamble@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 63, | ||||
|         "age": 21, | ||||
|         "email": "knowleshebert@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 50, | ||||
|         "age": 21, | ||||
|         "email": "guerramcintyre@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 44, | ||||
|         "age": 22, | ||||
|         "email": "jonispears@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 56, | ||||
|         "age": 23, | ||||
|         "email": "tuckerbarry@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 51, | ||||
|         "age": 23, | ||||
|         "email": "keycervantes@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 60, | ||||
|         "age": 23, | ||||
|         "email": "jodyherrera@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 70, | ||||
|         "age": 23, | ||||
|         "email": "glassperkins@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 75, | ||||
|         "age": 24, | ||||
|         "email": "emmajacobs@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 68, | ||||
|         "age": 24, | ||||
|         "email": "angelinadyer@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 17, | ||||
|         "age": 25, | ||||
|         "email": "ortegabrennan@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 76, | ||||
|         "age": 25, | ||||
|         "email": "claricegardner@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 43, | ||||
|         "age": 25, | ||||
|         "email": "arnoldbender@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 12, | ||||
|         "age": 25, | ||||
|         "email": "aidakirby@chorizon.com" | ||||
|       }, | ||||
|       { | ||||
|         "id": 9, | ||||
|         "age": 26, | ||||
|         "email": "kellimendez@chorizon.com" | ||||
|       } | ||||
|     ] | ||||
|     "#); | ||||
|  | ||||
|     let (response, _code) = index | ||||
|         .get_all_documents(GetAllDocumentsOptions { | ||||
|             fields: Some(vec!["id", "gender", "name"]), | ||||
|             sort: Some(vec!["gender:asc", "name:asc"]), | ||||
|             ..Default::default() | ||||
|         }) | ||||
|         .await; | ||||
|     let results = response["results"].as_array().unwrap(); | ||||
|     snapshot!(json_string!(results), @r#" | ||||
|     [ | ||||
|       { | ||||
|         "id": 3, | ||||
|         "name": "Adeline Flynn", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 12, | ||||
|         "name": "Aida Kirby", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 68, | ||||
|         "name": "Angelina Dyer", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 15, | ||||
|         "name": "Aurelia Contreras", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 36, | ||||
|         "name": "Barbra Valenzuela", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 23, | ||||
|         "name": "Blanca Mcclain", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 53, | ||||
|         "name": "Caitlin Burnett", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 71, | ||||
|         "name": "Candace Sawyer", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 65, | ||||
|         "name": "Carole Rowland", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 33, | ||||
|         "name": "Cecilia Greer", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 1, | ||||
|         "name": "Cherry Orr", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 38, | ||||
|         "name": "Christina Short", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 7, | ||||
|         "name": "Chrystal Boyd", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 76, | ||||
|         "name": "Clarice Gardner", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 73, | ||||
|         "name": "Eleanor Shepherd", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 75, | ||||
|         "name": "Emma Jacobs", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 16, | ||||
|         "name": "Estella Bass", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 62, | ||||
|         "name": "Estelle Ramirez", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 20, | ||||
|         "name": "Florence Long", | ||||
|         "gender": "female" | ||||
|       }, | ||||
|       { | ||||
|         "id": 42, | ||||
|         "name": "Graciela Russell", | ||||
|         "gender": "female" | ||||
|       } | ||||
|     ] | ||||
|     "#); | ||||
| } | ||||
|  | ||||
| /// Fetches all documents of the shared geo index sorted by ascending distance to the | ||||
| /// point `(45.4777599, 9.1967508)` and snapshots the returned `results` array. | ||||
| /// | ||||
| /// In the expected snapshot, the document whose `_geo` matches the target point comes first | ||||
| /// (its coordinates are stored as strings), and the document without a `_geo` field comes last. | ||||
| #[actix_rt::test] | ||||
| async fn get_document_geosorted() { | ||||
|     let index = shared_index_with_geo_documents().await; | ||||
|  | ||||
|     // Only the sort rule is set; every other option keeps its default value. | ||||
|     let (response, _code) = index | ||||
|         .get_all_documents(GetAllDocumentsOptions { | ||||
|             sort: Some(vec!["_geoPoint(45.4777599, 9.1967508):asc"]), | ||||
|             ..Default::default() | ||||
|         }) | ||||
|         .await; | ||||
|     let results = response["results"].as_array().unwrap(); | ||||
|     snapshot!(json_string!(results), @r#" | ||||
|     [ | ||||
|       { | ||||
|         "id": 2, | ||||
|         "name": "La Bella Italia", | ||||
|         "address": "456 Elm Street, Townsville", | ||||
|         "type": "Italian", | ||||
|         "rating": 9, | ||||
|         "_geo": { | ||||
|           "lat": "45.4777599", | ||||
|           "lng": "9.1967508" | ||||
|         } | ||||
|       }, | ||||
|       { | ||||
|         "id": 1, | ||||
|         "name": "Taco Truck", | ||||
|         "address": "444 Salsa Street, Burritoville", | ||||
|         "type": "Mexican", | ||||
|         "rating": 9, | ||||
|         "_geo": { | ||||
|           "lat": 34.0522, | ||||
|           "lng": -118.2437 | ||||
|         } | ||||
|       }, | ||||
|       { | ||||
|         "id": 3, | ||||
|         "name": "Crêpe Truck", | ||||
|         "address": "2 Billig Avenue, Rouenville", | ||||
|         "type": "French", | ||||
|         "rating": 10 | ||||
|       } | ||||
|     ] | ||||
|     "#); | ||||
| } | ||||
|  | ||||
| /// Requests documents sorted by `name:asc` on the shared test-set index and snapshots the | ||||
| /// whole response, which is expected to be an `invalid_document_sort` error stating that | ||||
| /// `name` is not sortable. | ||||
| #[actix_rt::test] | ||||
| async fn get_document_sort_the_unsortable() { | ||||
|     let index = shared_index_with_test_set().await; | ||||
|  | ||||
|     let (response, _code) = index | ||||
|         .get_all_documents(GetAllDocumentsOptions { | ||||
|             fields: Some(vec!["id", "name"]), | ||||
|             sort: Some(vec!["name:asc"]), | ||||
|             ..Default::default() | ||||
|         }) | ||||
|         .await; | ||||
|  | ||||
|     // The full response body is snapshotted here (not only `results`) because it is an error payload. | ||||
|     snapshot!(json_string!(response), @r#" | ||||
|     { | ||||
|       "message": "Attribute `name` is not sortable. This index does not have configured sortable attributes.", | ||||
|       "code": "invalid_document_sort", | ||||
|       "type": "invalid_request", | ||||
|       "link": "https://docs.meilisearch.com/errors#invalid_document_sort" | ||||
|     } | ||||
|     "#); | ||||
| } | ||||
|  | ||||
| #[actix_rt::test] | ||||
| async fn error_get_unexisting_index_all_documents() { | ||||
|     let index = shared_does_not_exists_index().await; | ||||
|   | ||||
| @@ -101,14 +101,7 @@ async fn reset_embedder_documents() { | ||||
|     server.wait_task(response.uid()).await; | ||||
|  | ||||
|     // Make sure the documents are still present | ||||
|     let (documents, _code) = index | ||||
|         .get_all_documents(GetAllDocumentsOptions { | ||||
|             limit: None, | ||||
|             offset: None, | ||||
|             retrieve_vectors: false, | ||||
|             fields: None, | ||||
|         }) | ||||
|         .await; | ||||
|     let (documents, _code) = index.get_all_documents(GetAllDocumentsOptions::default()).await; | ||||
|     snapshot!(json_string!(documents), @r###" | ||||
|     { | ||||
|       "results": [ | ||||
|   | ||||
| @@ -168,6 +168,16 @@ pub enum SortError { | ||||
|     ReservedNameForFilter { name: String }, | ||||
| } | ||||
|  | ||||
| impl SortError { | ||||
|     /// Converts this sort error into an [`Error`] tagged as originating from a search | ||||
|     /// (`search: true`). | ||||
|     pub fn into_search_error(self) -> Error { | ||||
|         let user_error = UserError::SortError { error: self, search: true }; | ||||
|         Error::UserError(user_error) | ||||
|     } | ||||
|  | ||||
|     /// Converts this sort error into an [`Error`] tagged as originating from a document | ||||
|     /// request (`search: false`). | ||||
|     pub fn into_document_error(self) -> Error { | ||||
|         let user_error = UserError::SortError { error: self, search: false }; | ||||
|         Error::UserError(user_error) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<AscDescError> for SortError { | ||||
|     fn from(error: AscDescError) -> Self { | ||||
|         match error { | ||||
| @@ -190,12 +200,6 @@ impl From<AscDescError> for SortError { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<SortError> for Error { | ||||
|     fn from(error: SortError) -> Self { | ||||
|         Self::UserError(UserError::SortError(error)) | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use big_s::S; | ||||
|   | ||||
							
								
								
									
										294
									
								
								crates/milli/src/documents/geo_sort.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										294
									
								
								crates/milli/src/documents/geo_sort.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,294 @@ | ||||
| use crate::{ | ||||
|     distance_between_two_points, | ||||
|     heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec}, | ||||
|     lat_lng_to_xyz, | ||||
|     search::new::{facet_string_values, facet_values_prefix_key}, | ||||
|     GeoPoint, Index, | ||||
| }; | ||||
| use heed::{ | ||||
|     types::{Bytes, Unit}, | ||||
|     RoPrefix, RoTxn, | ||||
| }; | ||||
| use roaring::RoaringBitmap; | ||||
| use rstar::RTree; | ||||
| use std::collections::VecDeque; | ||||
|  | ||||
| /// Tuning knobs of the geo sort: the strategy used to fill the cache of sorted documents, | ||||
| /// the maximum size of a bucket and the tolerance used when comparing distances. | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| pub struct GeoSortParameter { | ||||
|     /// The strategy used by the geo sort (iterative, rtree, or chosen dynamically). | ||||
|     pub strategy: GeoSortStrategy, | ||||
|     /// Maximum number of documents in a single bucket, to avoid unexpectedly large overhead. | ||||
|     pub max_bucket_size: u64, | ||||
|     /// Considering the errors of GPS and geographical calculations, two distances that differ | ||||
|     /// by no more than `distance_error_margin` are treated as equal. | ||||
|     pub distance_error_margin: f64, | ||||
| } | ||||
|  | ||||
| impl Default for GeoSortParameter { | ||||
|     /// Builds the default parameters: the default [`GeoSortStrategy`], buckets of at most | ||||
|     /// 1000 documents and a distance error margin of 1.0. | ||||
|     fn default() -> Self { | ||||
|         let strategy = GeoSortStrategy::default(); | ||||
|         Self { strategy, max_bucket_size: 1000, distance_error_margin: 1.0 } | ||||
|     } | ||||
| } | ||||
| /// Defines the strategy used by the geo sort. | ||||
| /// | ||||
| /// The parameter of every variant represents the cache size and, in the case of the | ||||
| /// `Dynamic` strategy, it is also the number of candidates from which we move from the | ||||
| /// iterative strategy to the rtree. | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| pub enum GeoSortStrategy { | ||||
|     /// Never use the rtree. | ||||
|     AlwaysIterative(usize), | ||||
|     /// Always use the rtree. | ||||
|     AlwaysRtree(usize), | ||||
|     /// Use the rtree only when the number of candidates reaches the parameter. | ||||
|     Dynamic(usize), | ||||
| } | ||||
|  | ||||
| impl Default for GeoSortStrategy { | ||||
|     /// The default strategy is `Dynamic` with a parameter of 1000. | ||||
|     fn default() -> Self { | ||||
|         Self::Dynamic(1000) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl GeoSortStrategy { | ||||
|     /// Tells whether the rtree must be used to sort `candidates` documents. | ||||
|     /// | ||||
|     /// The `Dynamic` strategy switches to the rtree once the number of candidates | ||||
|     /// reaches its parameter. | ||||
|     pub fn use_rtree(&self, candidates: usize) -> bool { | ||||
|         match *self { | ||||
|             Self::AlwaysIterative(_) => false, | ||||
|             Self::AlwaysRtree(_) => true, | ||||
|             Self::Dynamic(threshold) => candidates >= threshold, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Returns the cache size, which is the parameter carried by every variant. | ||||
|     pub fn cache_size(&self) -> usize { | ||||
|         let (Self::AlwaysIterative(size) | Self::AlwaysRtree(size) | Self::Dynamic(size)) = | ||||
|             *self; | ||||
|         size | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Fills `cached_sorted_docids` with documents of `geo_candidates` ordered by their distance | ||||
| /// to `target_point`. | ||||
| /// | ||||
| /// Depending on `strategy` and on the number of candidates, the documents come either from | ||||
| /// the rtree (loaded from the index on first use and stored in `rtree`) or from an iterative | ||||
| /// scan that reads the coordinates of every candidate and sorts them. | ||||
| /// | ||||
| /// With the rtree, at most `strategy.cache_size()` documents are cached; the iterative | ||||
| /// version caches every candidate. | ||||
| /// | ||||
| /// The cache is expected to be empty when this function is called (checked in debug builds). | ||||
| /// | ||||
| /// # Panics | ||||
| /// | ||||
| /// Panics if the rtree is needed but the index has none, or if the iterative version is | ||||
| /// needed and `field_ids` is `None`. | ||||
| #[allow(clippy::too_many_arguments)] | ||||
| pub fn fill_cache( | ||||
|     index: &Index, | ||||
|     txn: &RoTxn<heed::AnyTls>, | ||||
|     strategy: GeoSortStrategy, | ||||
|     ascending: bool, | ||||
|     target_point: [f64; 2], | ||||
|     field_ids: &Option<[u16; 2]>, | ||||
|     rtree: &mut Option<RTree<GeoPoint>>, | ||||
|     geo_candidates: &RoaringBitmap, | ||||
|     cached_sorted_docids: &mut VecDeque<(u32, [f64; 2])>, | ||||
| ) -> crate::Result<()> { | ||||
|     debug_assert!(cached_sorted_docids.is_empty()); | ||||
|  | ||||
|     // lazily initialize the rtree if needed by the strategy, and cache it in the `rtree` argument | ||||
|     let rtree = if strategy.use_rtree(geo_candidates.len() as usize) { | ||||
|         if let Some(rtree) = rtree.as_ref() { | ||||
|             // get rtree from cache | ||||
|             Some(rtree) | ||||
|         } else { | ||||
|             let rtree2 = index.geo_rtree(txn)?.expect("geo candidates but no rtree"); | ||||
|             // insert rtree in cache and returns it. | ||||
|             // Can't use `get_or_insert_with` because getting the rtree from the DB is a fallible operation. | ||||
|             Some(&*rtree.insert(rtree2)) | ||||
|         } | ||||
|     } else { | ||||
|         None | ||||
|     }; | ||||
|  | ||||
|     let cache_size = strategy.cache_size(); | ||||
|     if let Some(rtree) = rtree { | ||||
|         if ascending { | ||||
|             let point = lat_lng_to_xyz(&target_point); | ||||
|             for point in rtree.nearest_neighbor_iter(&point) { | ||||
|                 if geo_candidates.contains(point.data.0) { | ||||
|                     cached_sorted_docids.push_back(point.data); | ||||
|                     if cached_sorted_docids.len() >= cache_size { | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|             // In the case of the desc geo sort we look for the points closest to the opposite of the queried point. | ||||
|             // They are inserted at the front, i.e. in reverse order; the order is restored when the cache is emptied from the back later on. | ||||
|             let point = lat_lng_to_xyz(&opposite_of(target_point)); | ||||
|             for point in rtree.nearest_neighbor_iter(&point) { | ||||
|                 if geo_candidates.contains(point.data.0) { | ||||
|                     cached_sorted_docids.push_front(point.data); | ||||
|                     if cached_sorted_docids.len() >= cache_size { | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         // the iterative version: `field_ids` holds the ids of the lat and lng fields | ||||
|         let [lat, lng] = field_ids.expect("fill_buffer can't be called without the lat&lng"); | ||||
|  | ||||
|         let mut documents = geo_candidates | ||||
|             .iter() | ||||
|             .map(|id| -> crate::Result<_> { Ok((id, geo_value(id, lat, lng, index, txn)?)) }) | ||||
|             .collect::<crate::Result<Vec<(u32, [f64; 2])>>>()?; | ||||
|         // computing the distance between two points is expensive thus we cache the result | ||||
|         // (the distance is truncated to an integer to be used as the sort key) | ||||
|         documents | ||||
|             .sort_by_cached_key(|(_, p)| distance_between_two_points(&target_point, p) as usize); | ||||
|         // always stored by increasing distance, whatever the value of `ascending` | ||||
|         cached_sorted_docids.extend(documents); | ||||
|     }; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| /// Computes the next bucket of documents of `universe`, ordered by distance to `target_point`. | ||||
| /// | ||||
| /// A bucket contains the documents whose distance to the target differs from the distance of | ||||
| /// the first document of the bucket by at most `parameter.distance_error_margin`, and never | ||||
| /// more than `parameter.max_bucket_size` documents. | ||||
| /// | ||||
| /// Documents are taken from `cached_sorted_docids` (from the front when `ascending`, from the | ||||
| /// back otherwise), which is refilled with [`fill_cache`] whenever it is exhausted. | ||||
| /// | ||||
| /// Returns the bucket along with the coordinates of its first document. When no document of | ||||
| /// `universe` is a geo candidate, or when the candidates are exhausted before any document was | ||||
| /// put in the bucket, a copy of the whole `universe` is returned with `None` as coordinates. | ||||
| /// As written, this function never returns `Ok(None)`. | ||||
| #[allow(clippy::too_many_arguments)] | ||||
| pub fn next_bucket( | ||||
|     index: &Index, | ||||
|     txn: &RoTxn<heed::AnyTls>, | ||||
|     universe: &RoaringBitmap, | ||||
|     ascending: bool, | ||||
|     target_point: [f64; 2], | ||||
|     field_ids: &Option<[u16; 2]>, | ||||
|     rtree: &mut Option<RTree<GeoPoint>>, | ||||
|     cached_sorted_docids: &mut VecDeque<(u32, [f64; 2])>, | ||||
|     geo_candidates: &RoaringBitmap, | ||||
|     parameter: GeoSortParameter, | ||||
| ) -> crate::Result<Option<(RoaringBitmap, Option<[f64; 2]>)>> { | ||||
|     // Only the geo candidates that are part of the universe are considered. | ||||
|     // This local bitmap shadows the argument and shrinks as documents enter the bucket. | ||||
|     let mut geo_candidates = geo_candidates & universe; | ||||
|  | ||||
|     if geo_candidates.is_empty() { | ||||
|         return Ok(Some((universe.clone(), None))); | ||||
|     } | ||||
|  | ||||
|     // Takes the next document out of the cache, from the end matching the sort direction. | ||||
|     let next = |cache: &mut VecDeque<_>| { | ||||
|         if ascending { | ||||
|             cache.pop_front() | ||||
|         } else { | ||||
|             cache.pop_back() | ||||
|         } | ||||
|     }; | ||||
|     // Gives back a document taken with `next`, at the end it was taken from. | ||||
|     let put_back = |cache: &mut VecDeque<_>, x: _| { | ||||
|         if ascending { | ||||
|             cache.push_front(x) | ||||
|         } else { | ||||
|             cache.push_back(x) | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     let mut current_bucket = RoaringBitmap::new(); | ||||
|     // current_distance stores the first point and distance in current bucket | ||||
|     let mut current_distance: Option<([f64; 2], f64)> = None; | ||||
|     loop { | ||||
|         // The loop will only exit when we have found all points with equal distance or have exhausted the candidates. | ||||
|         if let Some((id, point)) = next(cached_sorted_docids) { | ||||
|             // Cached documents that are not candidates (anymore) are silently dropped. | ||||
|             if geo_candidates.contains(id) { | ||||
|                 let distance = distance_between_two_points(&target_point, &point); | ||||
|                 if let Some((point0, bucket_distance)) = current_distance.as_ref() { | ||||
|                     if (bucket_distance - distance).abs() > parameter.distance_error_margin { | ||||
|                         // different distance, point belongs to next bucket | ||||
|                         put_back(cached_sorted_docids, (id, point)); | ||||
|                         return Ok(Some((current_bucket, Some(point0.to_owned())))); | ||||
|                     } else { | ||||
|                         // same distance, point belongs to current bucket | ||||
|                         current_bucket.insert(id); | ||||
|                         // remove from candidates to prevent it from being added to the cache again | ||||
|                         geo_candidates.remove(id); | ||||
|                         // current bucket size reaches limit, force return | ||||
|                         if current_bucket.len() == parameter.max_bucket_size { | ||||
|                             return Ok(Some((current_bucket, Some(point0.to_owned())))); | ||||
|                         } | ||||
|                     } | ||||
|                 } else { | ||||
|                     // first doc in current bucket | ||||
|                     current_distance = Some((point, distance)); | ||||
|                     current_bucket.insert(id); | ||||
|                     geo_candidates.remove(id); | ||||
|                     // current bucket size reaches limit, force return | ||||
|                     if current_bucket.len() == parameter.max_bucket_size { | ||||
|                         return Ok(Some((current_bucket, Some(point.to_owned())))); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|             // cache exhausted, we need to refill it | ||||
|             fill_cache( | ||||
|                 index, | ||||
|                 txn, | ||||
|                 parameter.strategy, | ||||
|                 ascending, | ||||
|                 target_point, | ||||
|                 field_ids, | ||||
|                 rtree, | ||||
|                 &geo_candidates, | ||||
|                 cached_sorted_docids, | ||||
|             )?; | ||||
|  | ||||
|             if cached_sorted_docids.is_empty() { | ||||
|                 // candidates exhausted, exit | ||||
|                 if let Some((point0, _)) = current_distance.as_ref() { | ||||
|                     return Ok(Some((current_bucket, Some(point0.to_owned())))); | ||||
|                 } else { | ||||
|                     return Ok(Some((universe.clone(), None))); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Returns an iterator over the number values stored for the field `field_id` of the | ||||
| /// document `docid`. | ||||
| fn facet_number_values<'a>( | ||||
|     docid: u32, | ||||
|     field_id: u16, | ||||
|     index: &Index, | ||||
|     txn: &'a RoTxn<'a>, | ||||
| ) -> crate::Result<RoPrefix<'a, FieldDocIdFacetCodec<OrderedF64Codec>, Unit>> { | ||||
|     let prefix = facet_values_prefix_key(field_id, docid); | ||||
|  | ||||
|     // The lookup is done on raw bytes; the typed key codec is restored on the resulting iterator. | ||||
|     let raw_db = index.field_id_docid_facet_f64s.remap_key_type::<Bytes>(); | ||||
|     let values = raw_db.prefix_iter(txn, &prefix)?; | ||||
|  | ||||
|     Ok(values.remap_key_type()) | ||||
| } | ||||
|  | ||||
| /// Reads the latitude and longitude of a single document and returns them as `[lat, lng]`. | ||||
| /// | ||||
| /// Each coordinate is first looked up in the facet number index. When it is not found there, | ||||
| /// it is read from the facet string index and parsed as `f64` (as the geo extraction behaves). | ||||
| /// | ||||
| /// # Panics | ||||
| /// | ||||
| /// Panics if a coordinate is found in neither index, or if its string value is not a valid `f64`. | ||||
| pub(crate) fn geo_value( | ||||
|     docid: u32, | ||||
|     field_lat: u16, | ||||
|     field_lng: u16, | ||||
|     index: &Index, | ||||
|     rtxn: &RoTxn<'_>, | ||||
| ) -> crate::Result<[f64; 2]> { | ||||
|     let extract_geo = |geo_field: u16| -> crate::Result<f64> { | ||||
|         // A numeric facet value takes precedence when there is one. | ||||
|         if let Some(entry) = facet_number_values(docid, geo_field, index, rtxn)?.next() { | ||||
|             let ((_, _, geo), ()) = entry?; | ||||
|             return Ok(geo); | ||||
|         } | ||||
|  | ||||
|         // Otherwise fall back to the string facet value. | ||||
|         match facet_string_values(docid, geo_field, index, rtxn)?.next() { | ||||
|             Some(entry) => { | ||||
|                 let (_, geo) = entry?; | ||||
|                 Ok(geo.parse::<f64>().expect("cannot parse geo field as f64")) | ||||
|             } | ||||
|             None => panic!("A geo faceted document doesn't contain any lat or lng"), | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     // The latitude is extracted first, then the longitude. | ||||
|     let lat = extract_geo(field_lat)?; | ||||
|     let lng = extract_geo(field_lng)?; | ||||
|     Ok([lat, lng]) | ||||
| } | ||||
|  | ||||
| /// Computes the antipodal coordinate of `coord`, given as `[lat, lng]`. | ||||
| /// | ||||
| /// The latitude is negated and the longitude is shifted by 180 degrees: a strictly positive | ||||
| /// longitude is decreased, any other longitude is increased. | ||||
| pub(crate) fn opposite_of(mut coord: [f64; 2]) -> [f64; 2] { | ||||
|     coord[0] = -coord[0]; | ||||
|     // in the case of x,0 we want to return x,180 | ||||
|     let shift = if coord[1] > 0. { -180. } else { 180. }; | ||||
|     coord[1] += shift; | ||||
|  | ||||
|     coord | ||||
| } | ||||
| @@ -1,8 +1,10 @@ | ||||
| mod builder; | ||||
| mod enriched; | ||||
| pub mod geo_sort; | ||||
| mod primary_key; | ||||
| mod reader; | ||||
| mod serde_impl; | ||||
| pub mod sort; | ||||
|  | ||||
| use std::fmt::Debug; | ||||
| use std::io; | ||||
| @@ -19,6 +21,7 @@ pub use primary_key::{ | ||||
| pub use reader::{DocumentsBatchCursor, DocumentsBatchCursorError, DocumentsBatchReader}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| pub use self::geo_sort::{GeoSortParameter, GeoSortStrategy}; | ||||
| use crate::error::{FieldIdMapMissingEntry, InternalError}; | ||||
| use crate::{FieldId, Object, Result}; | ||||
|  | ||||
|   | ||||
							
								
								
									
										444
									
								
								crates/milli/src/documents/sort.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										444
									
								
								crates/milli/src/documents/sort.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,444 @@ | ||||
| use std::collections::{BTreeSet, VecDeque}; | ||||
|  | ||||
| use crate::{ | ||||
|     constants::RESERVED_GEO_FIELD_NAME, | ||||
|     documents::{geo_sort::next_bucket, GeoSortParameter}, | ||||
|     heed_codec::{ | ||||
|         facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, | ||||
|         BytesRefCodec, | ||||
|     }, | ||||
|     is_faceted, | ||||
|     search::facet::{ascending_facet_sort, descending_facet_sort}, | ||||
|     AscDesc, DocumentId, Member, UserError, | ||||
| }; | ||||
| use heed::Database; | ||||
| use roaring::RoaringBitmap; | ||||
|  | ||||
| /// A sort criterion expressed with field ids, either on a facet or on the distance to a point. | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| enum AscDescId { | ||||
|     /// Sort on the values of the field `field_id`, in ascending order when `ascending` is true. | ||||
|     Facet { field_id: u16, ascending: bool }, | ||||
|     /// Sort on the distance to `target_point`, in ascending order when `ascending` is true. | ||||
|     // NOTE(review): `field_ids` presumably holds the ids of the lat and lng fields, in that order — confirm. | ||||
|     Geo { field_ids: [u16; 2], target_point: [f64; 2], ascending: bool }, | ||||
| } | ||||
|  | ||||
| /// A [`SortedDocumentsIterator`] allows efficient access to a continuous range of sorted documents. | ||||
| /// This is ideal in the context of paginated queries in which only a small number of documents are needed at a time. | ||||
| /// Search operations will only be performed upon access. | ||||
| /// | ||||
| /// The iterator is a tree: a `Leaf` directly yields document ids, while a `Branch` yields the | ||||
| /// documents of its children one child after the other, building each child only when it is reached. | ||||
| pub enum SortedDocumentsIterator<'ctx> { | ||||
|     /// Documents that can be yielded as they are, in the order of `values`. | ||||
|     Leaf { | ||||
|         /// The exact number of documents remaining | ||||
|         size: usize, | ||||
|         /// The remaining documents | ||||
|         values: Box<dyn Iterator<Item = DocumentId> + 'ctx>, | ||||
|     }, | ||||
|     /// A sequence of lazily built children, iterated one after the other. | ||||
|     Branch { | ||||
|         /// The current child, got from the children iterator | ||||
|         current_child: Option<Box<SortedDocumentsIterator<'ctx>>>, | ||||
|         /// The exact number of documents remaining, excluding documents in the current child | ||||
|         next_children_size: usize, | ||||
|         /// Iterators to become the current child once it is exhausted | ||||
|         next_children: | ||||
|             Box<dyn Iterator<Item = crate::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx>, | ||||
|     }, | ||||
| } | ||||
|  | ||||
| impl SortedDocumentsIterator<'_> { | ||||
|     /// Makes sure a current child is selected, when one is available. | ||||
|     /// | ||||
|     /// Does nothing when there already is a current child. Otherwise, the next builder is | ||||
|     /// taken from `next_children` and built, the size of the resulting child is subtracted | ||||
|     /// from `next_children_size`, and the child becomes the current one. When no builder is | ||||
|     /// left, `current_child` stays `None`. | ||||
|     fn update_current<'ctx>( | ||||
|         current_child: &mut Option<Box<SortedDocumentsIterator<'ctx>>>, | ||||
|         next_children_size: &mut usize, | ||||
|         next_children: &mut Box< | ||||
|             dyn Iterator<Item = crate::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx, | ||||
|         >, | ||||
|     ) -> crate::Result<()> { | ||||
|         if current_child.is_some() { | ||||
|             return Ok(()); | ||||
|         } | ||||
|  | ||||
|         let Some(builder) = next_children.next() else { | ||||
|             // No more children: nothing to select. | ||||
|             return Ok(()); | ||||
|         }; | ||||
|  | ||||
|         let next_child = Box::new(builder?.build()?); | ||||
|         *next_children_size -= next_child.size_hint().0; | ||||
|         *current_child = Some(next_child); | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Iterator for SortedDocumentsIterator<'_> { | ||||
|     type Item = crate::Result<DocumentId>; | ||||
|  | ||||
|     /// Returns the `n`th document (zero-based) of the remaining sorted documents, consuming it | ||||
|     /// along with the `n` documents that precede it. | ||||
|     /// | ||||
|     /// Implementing the `nth` method allows for efficient access to the nth document in the sorted order. | ||||
|     /// It's used by `skip` internally. | ||||
|     /// The default implementation of `nth` would iterate over all children, which is inefficient for large datasets. | ||||
|     /// This implementation will jump over whole chunks of children until it gets close. | ||||
|     fn nth(&mut self, n: usize) -> Option<Self::Item> { | ||||
|         if n == 0 { | ||||
|             return self.next(); | ||||
|         } | ||||
|  | ||||
|         // If it's at the leaf level, just forward the call to the values iterator | ||||
|         let (current_child, next_children, next_children_size) = match self { | ||||
|             SortedDocumentsIterator::Leaf { values, size } => { | ||||
|                 // `nth(n)` consumes `n + 1` values: the `n` skipped ones and the returned one. | ||||
|                 // Keeping `size` exact matters because parents rely on it to skip whole children. | ||||
|                 *size = size.saturating_sub(n.saturating_add(1)); | ||||
|                 return values.nth(n).map(Ok); | ||||
|             } | ||||
|             SortedDocumentsIterator::Branch { | ||||
|                 current_child, | ||||
|                 next_children, | ||||
|                 next_children_size, | ||||
|             } => (current_child, next_children, next_children_size), | ||||
|         }; | ||||
|  | ||||
|         // Otherwise don't directly iterate over children, skip them if we know we will go further. | ||||
|         // `to_skip` is the number of documents that still have to be discarded before the one to return. | ||||
|         let mut to_skip = n; | ||||
|         while to_skip > 0 { | ||||
|             if let Err(e) = SortedDocumentsIterator::update_current( | ||||
|                 current_child, | ||||
|                 next_children_size, | ||||
|                 next_children, | ||||
|             ) { | ||||
|                 return Some(Err(e)); | ||||
|             } | ||||
|             let Some(inner) = current_child else { | ||||
|                 return None; // No more inner iterators, everything has been consumed. | ||||
|             }; | ||||
|  | ||||
|             let inner_size = inner.size_hint().0; | ||||
|             if to_skip >= inner_size { | ||||
|                 // Every document of the current child has to be discarded. | ||||
|                 // Skip it as a whole and continue with the next one. | ||||
|                 to_skip -= inner_size; | ||||
|                 *current_child = None; | ||||
|             } else { | ||||
|                 // The current child contains the requested document, so we can forward the call to it. | ||||
|                 return inner.nth(to_skip); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // Exactly `n` documents were discarded by skipping whole children: | ||||
|         // the requested document is the next one. | ||||
|         self.next() | ||||
|     } | ||||
|  | ||||
|     /// Iterators need to keep track of their size so that they can be skipped efficiently by the `nth` method. | ||||
|     fn size_hint(&self) -> (usize, Option<usize>) { | ||||
|         let size = match self { | ||||
|             SortedDocumentsIterator::Leaf { size, .. } => *size, | ||||
|             SortedDocumentsIterator::Branch { | ||||
|                 next_children_size, | ||||
|                 current_child: Some(current_child), | ||||
|                 .. | ||||
|             } => current_child.size_hint().0 + next_children_size, | ||||
|             SortedDocumentsIterator::Branch { next_children_size, current_child: None, .. } => { | ||||
|                 *next_children_size | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         (size, Some(size)) | ||||
|     } | ||||
|  | ||||
|     /// Returns the next document in the sorted order, moving on to the next child whenever the | ||||
|     /// current one is exhausted. | ||||
|     fn next(&mut self) -> Option<Self::Item> { | ||||
|         match self { | ||||
|             SortedDocumentsIterator::Leaf { values, size } => { | ||||
|                 let result = values.next().map(Ok); | ||||
|                 if result.is_some() { | ||||
|                     // Saturating so that an inexact `size` can never cause an underflow. | ||||
|                     *size = size.saturating_sub(1); | ||||
|                 } | ||||
|                 result | ||||
|             } | ||||
|             SortedDocumentsIterator::Branch { | ||||
|                 current_child, | ||||
|                 next_children_size, | ||||
|                 next_children, | ||||
|             } => { | ||||
|                 let mut result = None; | ||||
|                 while result.is_none() { | ||||
|                     // Ensure we have selected an iterator to work with | ||||
|                     if let Err(e) = SortedDocumentsIterator::update_current( | ||||
|                         current_child, | ||||
|                         next_children_size, | ||||
|                         next_children, | ||||
|                     ) { | ||||
|                         return Some(Err(e)); | ||||
|                     } | ||||
|                     let Some(inner) = current_child else { | ||||
|                         return None; | ||||
|                     }; | ||||
|  | ||||
|                     result = inner.next(); | ||||
|  | ||||
|                     // If the current iterator is exhausted, we need to try the next one | ||||
|                     if result.is_none() { | ||||
|                         *current_child = None; | ||||
|                     } | ||||
|                 } | ||||
|                 result | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Builder for a [`SortedDocumentsIterator`]. | ||||
| /// Most builders won't ever be built, because pagination will skip them. | ||||
| pub struct SortedDocumentsIteratorBuilder<'ctx> { | ||||
|     /// The index the documents are read from. | ||||
|     index: &'ctx crate::Index, | ||||
|     /// The read transaction used for every database access. | ||||
|     rtxn: &'ctx heed::RoTxn<'ctx>, | ||||
|     // NOTE(review): presumably the facet database holding number values — confirm against the constructor. | ||||
|     number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|     // NOTE(review): presumably the facet database holding string values — confirm against the constructor. | ||||
|     string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|     /// The sort criteria left to apply; once empty, the candidates are yielded as they are. | ||||
|     fields: &'ctx [AscDescId], | ||||
|     /// The documents to sort; their number is the size of the built iterator. | ||||
|     candidates: RoaringBitmap, | ||||
|     /// Forwarded unchanged to the facet and geo builders. | ||||
|     // NOTE(review): presumably the documents of the index that have a geo point — confirm. | ||||
|     geo_candidates: &'ctx RoaringBitmap, | ||||
| } | ||||
|  | ||||
| impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> { | ||||
|     /// Performs the sort and builds a [`SortedDocumentsIterator`]. | ||||
|     fn build(self) -> crate::Result<SortedDocumentsIterator<'ctx>> { | ||||
|         let size = self.candidates.len() as usize; | ||||
|  | ||||
|         match self.fields { | ||||
|             [] => Ok(SortedDocumentsIterator::Leaf { | ||||
|                 size, | ||||
|                 values: Box::new(self.candidates.into_iter()), | ||||
|             }), | ||||
|             [AscDescId::Facet { field_id, ascending }, next_fields @ ..] => { | ||||
|                 SortedDocumentsIteratorBuilder::build_facet( | ||||
|                     self.index, | ||||
|                     self.rtxn, | ||||
|                     self.number_db, | ||||
|                     self.string_db, | ||||
|                     next_fields, | ||||
|                     self.candidates, | ||||
|                     self.geo_candidates, | ||||
|                     *field_id, | ||||
|                     *ascending, | ||||
|                 ) | ||||
|             } | ||||
|             [AscDescId::Geo { field_ids, target_point, ascending }, next_fields @ ..] => { | ||||
|                 SortedDocumentsIteratorBuilder::build_geo( | ||||
|                     self.index, | ||||
|                     self.rtxn, | ||||
|                     self.number_db, | ||||
|                     self.string_db, | ||||
|                     next_fields, | ||||
|                     self.candidates, | ||||
|                     self.geo_candidates, | ||||
|                     *field_ids, | ||||
|                     *target_point, | ||||
|                     *ascending, | ||||
|                 ) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Builds a [`SortedDocumentsIterator`] based on the results of a facet sort. | ||||
|     #[allow(clippy::too_many_arguments)] | ||||
|     fn build_facet( | ||||
|         index: &'ctx crate::Index, | ||||
|         rtxn: &'ctx heed::RoTxn<'ctx>, | ||||
|         number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|         string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|         next_fields: &'ctx [AscDescId], | ||||
|         candidates: RoaringBitmap, | ||||
|         geo_candidates: &'ctx RoaringBitmap, | ||||
|         field_id: u16, | ||||
|         ascending: bool, | ||||
|     ) -> crate::Result<SortedDocumentsIterator<'ctx>> { | ||||
|         let size = candidates.len() as usize; | ||||
|  | ||||
|         // Perform the sort on the first field | ||||
|         let (number_iter, string_iter) = if ascending { | ||||
|             let number_iter = ascending_facet_sort(rtxn, number_db, field_id, candidates.clone())?; | ||||
|             let string_iter = ascending_facet_sort(rtxn, string_db, field_id, candidates)?; | ||||
|  | ||||
|             (itertools::Either::Left(number_iter), itertools::Either::Left(string_iter)) | ||||
|         } else { | ||||
|             let number_iter = descending_facet_sort(rtxn, number_db, field_id, candidates.clone())?; | ||||
|             let string_iter = descending_facet_sort(rtxn, string_db, field_id, candidates)?; | ||||
|  | ||||
|             (itertools::Either::Right(number_iter), itertools::Either::Right(string_iter)) | ||||
|         }; | ||||
|  | ||||
|         // Create builders for the next level of the tree | ||||
|         let number_iter = number_iter.map(|r| r.map(|(d, _)| d)); | ||||
|         let string_iter = string_iter.map(|r| r.map(|(d, _)| d)); | ||||
|         let next_children = number_iter.chain(string_iter).map(move |r| { | ||||
|             Ok(SortedDocumentsIteratorBuilder { | ||||
|                 index, | ||||
|                 rtxn, | ||||
|                 number_db, | ||||
|                 string_db, | ||||
|                 fields: next_fields, | ||||
|                 candidates: r?, | ||||
|                 geo_candidates, | ||||
|             }) | ||||
|         }); | ||||
|  | ||||
|         Ok(SortedDocumentsIterator::Branch { | ||||
|             current_child: None, | ||||
|             next_children_size: size, | ||||
|             next_children: Box::new(next_children), | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     /// Builds a [`SortedDocumentsIterator`] based on the (lazy) results of a geo sort. | ||||
|     #[allow(clippy::too_many_arguments)] | ||||
|     fn build_geo( | ||||
|         index: &'ctx crate::Index, | ||||
|         rtxn: &'ctx heed::RoTxn<'ctx>, | ||||
|         number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|         string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|         next_fields: &'ctx [AscDescId], | ||||
|         candidates: RoaringBitmap, | ||||
|         geo_candidates: &'ctx RoaringBitmap, | ||||
|         field_ids: [u16; 2], | ||||
|         target_point: [f64; 2], | ||||
|         ascending: bool, | ||||
|     ) -> crate::Result<SortedDocumentsIterator<'ctx>> { | ||||
|         let mut cache = VecDeque::new(); | ||||
|         let mut rtree = None; | ||||
|         let size = candidates.len() as usize; | ||||
|         let not_geo_candidates = candidates.clone() - geo_candidates; | ||||
|         let mut geo_remaining = size - not_geo_candidates.len() as usize; | ||||
|         let mut not_geo_candidates = Some(not_geo_candidates); | ||||
|  | ||||
|         let next_children = std::iter::from_fn(move || { | ||||
|             // Find the next bucket of geo-sorted documents. | ||||
|             // next_bucket loops and will go back to the beginning so we use a variable to track how many are left. | ||||
|             if geo_remaining > 0 { | ||||
|                 if let Ok(Some((docids, _point))) = next_bucket( | ||||
|                     index, | ||||
|                     rtxn, | ||||
|                     &candidates, | ||||
|                     ascending, | ||||
|                     target_point, | ||||
|                     &Some(field_ids), | ||||
|                     &mut rtree, | ||||
|                     &mut cache, | ||||
|                     geo_candidates, | ||||
|                     GeoSortParameter::default(), | ||||
|                 ) { | ||||
|                     geo_remaining -= docids.len() as usize; | ||||
|                     return Some(Ok(SortedDocumentsIteratorBuilder { | ||||
|                         index, | ||||
|                         rtxn, | ||||
|                         number_db, | ||||
|                         string_db, | ||||
|                         fields: next_fields, | ||||
|                         candidates: docids, | ||||
|                         geo_candidates, | ||||
|                     })); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // Once all geo candidates have been processed, we can return the others | ||||
|             if let Some(not_geo_candidates) = not_geo_candidates.take() { | ||||
|                 if !not_geo_candidates.is_empty() { | ||||
|                     return Some(Ok(SortedDocumentsIteratorBuilder { | ||||
|                         index, | ||||
|                         rtxn, | ||||
|                         number_db, | ||||
|                         string_db, | ||||
|                         fields: next_fields, | ||||
|                         candidates: not_geo_candidates, | ||||
|                         geo_candidates, | ||||
|                     })); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             None | ||||
|         }); | ||||
|  | ||||
|         Ok(SortedDocumentsIterator::Branch { | ||||
|             current_child: None, | ||||
|             next_children_size: size, | ||||
|             next_children: Box::new(next_children), | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// A structure owning the data needed during the lifetime of a [`SortedDocumentsIterator`]. | ||||
| pub struct SortedDocuments<'ctx> { | ||||
|     index: &'ctx crate::Index, | ||||
|     rtxn: &'ctx heed::RoTxn<'ctx>, | ||||
|     fields: Vec<AscDescId>, | ||||
|     number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|     string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>, | ||||
|     candidates: &'ctx RoaringBitmap, | ||||
|     geo_candidates: RoaringBitmap, | ||||
| } | ||||
|  | ||||
| impl<'ctx> SortedDocuments<'ctx> { | ||||
|     pub fn iter(&'ctx self) -> crate::Result<SortedDocumentsIterator<'ctx>> { | ||||
|         let builder = SortedDocumentsIteratorBuilder { | ||||
|             index: self.index, | ||||
|             rtxn: self.rtxn, | ||||
|             number_db: self.number_db, | ||||
|             string_db: self.string_db, | ||||
|             fields: &self.fields, | ||||
|             candidates: self.candidates.clone(), | ||||
|             geo_candidates: &self.geo_candidates, | ||||
|         }; | ||||
|         builder.build() | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn recursive_sort<'ctx>( | ||||
|     index: &'ctx crate::Index, | ||||
|     rtxn: &'ctx heed::RoTxn<'ctx>, | ||||
|     sort: Vec<AscDesc>, | ||||
|     candidates: &'ctx RoaringBitmap, | ||||
| ) -> crate::Result<SortedDocuments<'ctx>> { | ||||
|     let sortable_fields: BTreeSet<_> = index.sortable_fields(rtxn)?.into_iter().collect(); | ||||
|     let fields_ids_map = index.fields_ids_map(rtxn)?; | ||||
|  | ||||
|     // Retrieve the field ids that are used for sorting | ||||
|     let mut fields = Vec::new(); | ||||
|     let mut need_geo_candidates = false; | ||||
|     for asc_desc in sort { | ||||
|         let (field, geofield) = match asc_desc { | ||||
|             AscDesc::Asc(Member::Field(field)) => (Some((field, true)), None), | ||||
|             AscDesc::Desc(Member::Field(field)) => (Some((field, false)), None), | ||||
|             AscDesc::Asc(Member::Geo(target_point)) => (None, Some((target_point, true))), | ||||
|             AscDesc::Desc(Member::Geo(target_point)) => (None, Some((target_point, false))), | ||||
|         }; | ||||
|         if let Some((field, ascending)) = field { | ||||
|             if is_faceted(&field, &sortable_fields) { | ||||
|                 if let Some(field_id) = fields_ids_map.id(&field) { | ||||
|                     fields.push(AscDescId::Facet { field_id, ascending }); | ||||
|                     continue; | ||||
|                 } | ||||
|             } | ||||
|             return Err(UserError::InvalidDocumentSortableAttribute { | ||||
|                 field: field.to_string(), | ||||
|                 sortable_fields: sortable_fields.clone(), | ||||
|             } | ||||
|             .into()); | ||||
|         } | ||||
|         if let Some((target_point, ascending)) = geofield { | ||||
|             if sortable_fields.contains(RESERVED_GEO_FIELD_NAME) { | ||||
|                 if let (Some(lat), Some(lng)) = | ||||
|                     (fields_ids_map.id("_geo.lat"), fields_ids_map.id("_geo.lng")) | ||||
|                 { | ||||
|                     need_geo_candidates = true; | ||||
|                     fields.push(AscDescId::Geo { field_ids: [lat, lng], target_point, ascending }); | ||||
|                     continue; | ||||
|                 } | ||||
|             } | ||||
|             return Err(UserError::InvalidDocumentSortableAttribute { | ||||
|                 field: RESERVED_GEO_FIELD_NAME.to_string(), | ||||
|                 sortable_fields: sortable_fields.clone(), | ||||
|             } | ||||
|             .into()); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     let geo_candidates = if need_geo_candidates { | ||||
|         index.geo_faceted_documents_ids(rtxn)? | ||||
|     } else { | ||||
|         RoaringBitmap::new() | ||||
|     }; | ||||
|  | ||||
|     let number_db = index.facet_id_f64_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>(); | ||||
|     let string_db = | ||||
|         index.facet_id_string_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>(); | ||||
|  | ||||
|     Ok(SortedDocuments { index, rtxn, fields, number_db, string_db, candidates, geo_candidates }) | ||||
| } | ||||
| @@ -191,7 +191,21 @@ and can not be more than 511 bytes.", .document_id.to_string() | ||||
|                 ), | ||||
|         } | ||||
|     )] | ||||
|     InvalidSortableAttribute { field: String, valid_fields: BTreeSet<String>, hidden_fields: bool }, | ||||
|     InvalidSearchSortableAttribute { | ||||
|         field: String, | ||||
|         valid_fields: BTreeSet<String>, | ||||
|         hidden_fields: bool, | ||||
|     }, | ||||
|     #[error("Attribute `{}` is not sortable. {}", | ||||
|         .field, | ||||
|         match .sortable_fields.is_empty() { | ||||
|             true => "This index does not have configured sortable attributes.".to_string(), | ||||
|             false => format!("Available sortable attributes are: `{}`.", | ||||
|                     sortable_fields.iter().map(AsRef::as_ref).collect::<Vec<&str>>().join(", ") | ||||
|                 ), | ||||
|         } | ||||
|     )] | ||||
|     InvalidDocumentSortableAttribute { field: String, sortable_fields: BTreeSet<String> }, | ||||
|     #[error("Attribute `{}` is not filterable and thus, cannot be used as distinct attribute. {}", | ||||
|         .field, | ||||
|         match (.valid_patterns.is_empty(), .matching_rule_index) { | ||||
| @@ -272,8 +286,8 @@ and can not be more than 511 bytes.", .document_id.to_string() | ||||
|     PrimaryKeyCannotBeChanged(String), | ||||
|     #[error(transparent)] | ||||
|     SerdeJson(serde_json::Error), | ||||
|     #[error(transparent)] | ||||
|     SortError(#[from] SortError), | ||||
|     #[error("{error}")] | ||||
|     SortError { error: SortError, search: bool }, | ||||
|     #[error("An unknown internal document id have been used: `{document_id}`.")] | ||||
|     UnknownInternalDocumentId { document_id: DocumentId }, | ||||
|     #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] | ||||
| @@ -616,7 +630,7 @@ fn conditionally_lookup_for_error_message() { | ||||
|     ]; | ||||
|  | ||||
|     for (list, suffix) in messages { | ||||
|         let err = UserError::InvalidSortableAttribute { | ||||
|         let err = UserError::InvalidSearchSortableAttribute { | ||||
|             field: "name".to_string(), | ||||
|             valid_fields: list, | ||||
|             hidden_fields: false, | ||||
|   | ||||
| @@ -43,12 +43,13 @@ use std::fmt; | ||||
| use std::hash::BuildHasherDefault; | ||||
|  | ||||
| use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; | ||||
| pub use documents::GeoSortStrategy; | ||||
| pub use filter_parser::{Condition, FilterCondition, Span, Token}; | ||||
| use fxhash::{FxHasher32, FxHasher64}; | ||||
| pub use grenad::CompressionType; | ||||
| pub use search::new::{ | ||||
|     execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext, | ||||
|     SearchLogger, VisualSearchLogger, | ||||
|     execute_search, filtered_universe, DefaultSearchLogger, SearchContext, SearchLogger, | ||||
|     VisualSearchLogger, | ||||
| }; | ||||
| use serde_json::Value; | ||||
| pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; | ||||
|   | ||||
| @@ -9,6 +9,7 @@ use roaring::bitmap::RoaringBitmap; | ||||
| pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; | ||||
| pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; | ||||
| use self::new::{execute_vector_search, PartialSearchResult, VectorStoreStats}; | ||||
| use crate::documents::GeoSortParameter; | ||||
| use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features}; | ||||
| use crate::index::MatchingStrategy; | ||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||
| @@ -47,7 +48,7 @@ pub struct Search<'a> { | ||||
|     sort_criteria: Option<Vec<AscDesc>>, | ||||
|     distinct: Option<String>, | ||||
|     searchable_attributes: Option<&'a [String]>, | ||||
|     geo_param: new::GeoSortParameter, | ||||
|     geo_param: GeoSortParameter, | ||||
|     terms_matching_strategy: TermsMatchingStrategy, | ||||
|     scoring_strategy: ScoringStrategy, | ||||
|     words_limit: usize, | ||||
| @@ -71,7 +72,7 @@ impl<'a> Search<'a> { | ||||
|             sort_criteria: None, | ||||
|             distinct: None, | ||||
|             searchable_attributes: None, | ||||
|             geo_param: new::GeoSortParameter::default(), | ||||
|             geo_param: GeoSortParameter::default(), | ||||
|             terms_matching_strategy: TermsMatchingStrategy::default(), | ||||
|             scoring_strategy: Default::default(), | ||||
|             exhaustive_number_hits: false, | ||||
| @@ -149,7 +150,7 @@ impl<'a> Search<'a> { | ||||
|     } | ||||
|  | ||||
|     #[cfg(test)] | ||||
|     pub fn geo_sort_strategy(&mut self, strategy: new::GeoSortStrategy) -> &mut Search<'a> { | ||||
|     pub fn geo_sort_strategy(&mut self, strategy: crate::GeoSortStrategy) -> &mut Search<'a> { | ||||
|         self.geo_param.strategy = strategy; | ||||
|         self | ||||
|     } | ||||
|   | ||||
| @@ -82,7 +82,7 @@ fn facet_value_docids( | ||||
| } | ||||
|  | ||||
| /// Return an iterator over each number value in the given field of the given document. | ||||
| fn facet_number_values<'a>( | ||||
| pub(crate) fn facet_number_values<'a>( | ||||
|     docid: u32, | ||||
|     field_id: u16, | ||||
|     index: &Index, | ||||
| @@ -118,7 +118,7 @@ pub fn facet_string_values<'a>( | ||||
| } | ||||
|  | ||||
| #[allow(clippy::drop_non_drop)] | ||||
| fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE] { | ||||
| pub(crate) fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE] { | ||||
|     concat_arrays::concat_arrays!(distinct.to_be_bytes(), id.to_be_bytes()) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,96 +1,18 @@ | ||||
| use std::collections::VecDeque; | ||||
|  | ||||
| use heed::types::{Bytes, Unit}; | ||||
| use heed::{RoPrefix, RoTxn}; | ||||
| use roaring::RoaringBitmap; | ||||
| use rstar::RTree; | ||||
|  | ||||
| use super::facet_string_values; | ||||
| use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; | ||||
| use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec}; | ||||
| use crate::documents::geo_sort::{fill_cache, next_bucket}; | ||||
| use crate::documents::{GeoSortParameter, GeoSortStrategy}; | ||||
| use crate::score_details::{self, ScoreDetails}; | ||||
| use crate::{ | ||||
|     distance_between_two_points, lat_lng_to_xyz, GeoPoint, Index, Result, SearchContext, | ||||
|     SearchLogger, | ||||
| }; | ||||
|  | ||||
/// Number of bytes of a big-endian `u16` field id.
const FID_SIZE: usize = 2;
/// Number of bytes of a big-endian `u32` document id.
const DOCID_SIZE: usize = 4;

/// Builds the key prefix made of the field id followed by the document id,
/// both encoded as big-endian bytes.
fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE] {
    let [f0, f1] = distinct.to_be_bytes();
    let [d0, d1, d2, d3] = id.to_be_bytes();
    [f0, f1, d0, d1, d2, d3]
}
|  | ||||
| /// Return an iterator over each number value in the given field of the given document. | ||||
| fn facet_number_values<'a>( | ||||
|     docid: u32, | ||||
|     field_id: u16, | ||||
|     index: &Index, | ||||
|     txn: &'a RoTxn<'a>, | ||||
| ) -> Result<RoPrefix<'a, FieldDocIdFacetCodec<OrderedF64Codec>, Unit>> { | ||||
|     let key = facet_values_prefix_key(field_id, docid); | ||||
|  | ||||
|     let iter = index | ||||
|         .field_id_docid_facet_f64s | ||||
|         .remap_key_type::<Bytes>() | ||||
|         .prefix_iter(txn, &key)? | ||||
|         .remap_key_type(); | ||||
|  | ||||
|     Ok(iter) | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| pub struct Parameter { | ||||
|     // Define the strategy used by the geo sort | ||||
|     pub strategy: Strategy, | ||||
|     // Limit the number of docs in a single bucket to avoid unexpectedly large overhead | ||||
|     pub max_bucket_size: u64, | ||||
|     // Considering the errors of GPS and geographical calculations, distances less than distance_error_margin will be treated as equal | ||||
|     pub distance_error_margin: f64, | ||||
| } | ||||
|  | ||||
| impl Default for Parameter { | ||||
|     fn default() -> Self { | ||||
|         Self { strategy: Strategy::default(), max_bucket_size: 1000, distance_error_margin: 1.0 } | ||||
|     } | ||||
| } | ||||
/// How the geo sort retrieves its sorted documents.
///
/// The value carried by every variant is the cache size. For
/// [`Strategy::Dynamic`] it is also the number of candidates from which the
/// rtree is preferred over the iterative strategy.
#[derive(Debug, Clone, Copy)]
pub enum Strategy {
    AlwaysIterative(usize),
    AlwaysRtree(usize),
    Dynamic(usize),
}

impl Default for Strategy {
    /// Dynamic strategy with a cache size (and switch point) of 1000.
    fn default() -> Self {
        Self::Dynamic(1000)
    }
}

impl Strategy {
    /// Tells whether the rtree must be used for the given number of candidates.
    pub fn use_rtree(&self, candidates: usize) -> bool {
        match *self {
            Self::AlwaysIterative(_) => false,
            Self::AlwaysRtree(_) => true,
            Self::Dynamic(threshold) => candidates >= threshold,
        }
    }

    /// Returns the cache size carried by the strategy.
    pub fn cache_size(&self) -> usize {
        let (Self::AlwaysIterative(size) | Self::AlwaysRtree(size) | Self::Dynamic(size)) = *self;
        size
    }
}
| use crate::{GeoPoint, Result, SearchContext, SearchLogger}; | ||||
|  | ||||
| pub struct GeoSort<Q: RankingRuleQueryTrait> { | ||||
|     query: Option<Q>, | ||||
|  | ||||
|     strategy: Strategy, | ||||
|     strategy: GeoSortStrategy, | ||||
|     ascending: bool, | ||||
|     point: [f64; 2], | ||||
|     field_ids: Option<[u16; 2]>, | ||||
| @@ -107,12 +29,12 @@ pub struct GeoSort<Q: RankingRuleQueryTrait> { | ||||
|  | ||||
| impl<Q: RankingRuleQueryTrait> GeoSort<Q> { | ||||
|     pub fn new( | ||||
|         parameter: Parameter, | ||||
|         parameter: GeoSortParameter, | ||||
|         geo_faceted_docids: RoaringBitmap, | ||||
|         point: [f64; 2], | ||||
|         ascending: bool, | ||||
|     ) -> Result<Self> { | ||||
|         let Parameter { strategy, max_bucket_size, distance_error_margin } = parameter; | ||||
|         let GeoSortParameter { strategy, max_bucket_size, distance_error_margin } = parameter; | ||||
|         Ok(Self { | ||||
|             query: None, | ||||
|             strategy, | ||||
| @@ -134,98 +56,22 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> { | ||||
|         ctx: &mut SearchContext<'_>, | ||||
|         geo_candidates: &RoaringBitmap, | ||||
|     ) -> Result<()> { | ||||
|         debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng"); | ||||
|         debug_assert!(self.cached_sorted_docids.is_empty()); | ||||
|  | ||||
|         // lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree` | ||||
|         let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) { | ||||
|             if let Some(rtree) = self.rtree.as_ref() { | ||||
|                 // get rtree from cache | ||||
|                 Some(rtree) | ||||
|             } else { | ||||
|                 let rtree = ctx.index.geo_rtree(ctx.txn)?.expect("geo candidates but no rtree"); | ||||
|                 // insert rtree in cache and returns it. | ||||
|                 // Can't use `get_or_insert_with` because getting the rtree from the DB is a fallible operation. | ||||
|                 Some(&*self.rtree.insert(rtree)) | ||||
|             } | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
|  | ||||
|         let cache_size = self.strategy.cache_size(); | ||||
|         if let Some(rtree) = rtree { | ||||
|             if self.ascending { | ||||
|                 let point = lat_lng_to_xyz(&self.point); | ||||
|                 for point in rtree.nearest_neighbor_iter(&point) { | ||||
|                     if geo_candidates.contains(point.data.0) { | ||||
|                         self.cached_sorted_docids.push_back(point.data); | ||||
|                         if self.cached_sorted_docids.len() >= cache_size { | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } else { | ||||
|                 // in the case of the desc geo sort we look for the closest point to the opposite of the queried point | ||||
|                 // and we insert the points in reverse order they get reversed when emptying the cache later on | ||||
|                 let point = lat_lng_to_xyz(&opposite_of(self.point)); | ||||
|                 for point in rtree.nearest_neighbor_iter(&point) { | ||||
|                     if geo_candidates.contains(point.data.0) { | ||||
|                         self.cached_sorted_docids.push_front(point.data); | ||||
|                         if self.cached_sorted_docids.len() >= cache_size { | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|             // the iterative version | ||||
|             let [lat, lng] = self.field_ids.unwrap(); | ||||
|  | ||||
|             let mut documents = geo_candidates | ||||
|                 .iter() | ||||
|                 .map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) }) | ||||
|                 .collect::<Result<Vec<(u32, [f64; 2])>>>()?; | ||||
|             // computing the distance between two points is expensive thus we cache the result | ||||
|             documents | ||||
|                 .sort_by_cached_key(|(_, p)| distance_between_two_points(&self.point, p) as usize); | ||||
|             self.cached_sorted_docids.extend(documents); | ||||
|         }; | ||||
|         fill_cache( | ||||
|             ctx.index, | ||||
|             ctx.txn, | ||||
|             self.strategy, | ||||
|             self.ascending, | ||||
|             self.point, | ||||
|             &self.field_ids, | ||||
|             &mut self.rtree, | ||||
|             geo_candidates, | ||||
|             &mut self.cached_sorted_docids, | ||||
|         )?; | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Extracts the lat and long values from a single document. | ||||
| /// | ||||
| /// If it is not able to find it in the facet number index it will extract it | ||||
| /// from the facet string index and parse it as f64 (as the geo extraction behaves). | ||||
| fn geo_value( | ||||
|     docid: u32, | ||||
|     field_lat: u16, | ||||
|     field_lng: u16, | ||||
|     index: &Index, | ||||
|     rtxn: &RoTxn<'_>, | ||||
| ) -> Result<[f64; 2]> { | ||||
|     let extract_geo = |geo_field: u16| -> Result<f64> { | ||||
|         match facet_number_values(docid, geo_field, index, rtxn)?.next() { | ||||
|             Some(Ok(((_, _, geo), ()))) => Ok(geo), | ||||
|             Some(Err(e)) => Err(e.into()), | ||||
|             None => match facet_string_values(docid, geo_field, index, rtxn)?.next() { | ||||
|                 Some(Ok((_, geo))) => { | ||||
|                     Ok(geo.parse::<f64>().expect("cannot parse geo field as f64")) | ||||
|                 } | ||||
|                 Some(Err(e)) => Err(e.into()), | ||||
|                 None => panic!("A geo faceted document doesn't contain any lat or lng"), | ||||
|             }, | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     let lat = extract_geo(field_lat)?; | ||||
|     let lng = extract_geo(field_lng)?; | ||||
|  | ||||
|     Ok([lat, lng]) | ||||
| } | ||||
|  | ||||
| impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | ||||
|     fn id(&self) -> String { | ||||
|         "geo_sort".to_owned() | ||||
| @@ -267,124 +113,33 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | ||||
|     ) -> Result<Option<RankingRuleOutput<Q>>> { | ||||
|         let query = self.query.as_ref().unwrap().clone(); | ||||
|  | ||||
|         let mut geo_candidates = &self.geo_candidates & universe; | ||||
|  | ||||
|         if geo_candidates.is_empty() { | ||||
|             return Ok(Some(RankingRuleOutput { | ||||
|         next_bucket( | ||||
|             ctx.index, | ||||
|             ctx.txn, | ||||
|             universe, | ||||
|             self.ascending, | ||||
|             self.point, | ||||
|             &self.field_ids, | ||||
|             &mut self.rtree, | ||||
|             &mut self.cached_sorted_docids, | ||||
|             &self.geo_candidates, | ||||
|             GeoSortParameter { | ||||
|                 strategy: self.strategy, | ||||
|                 max_bucket_size: self.max_bucket_size, | ||||
|                 distance_error_margin: self.distance_error_margin, | ||||
|             }, | ||||
|         ) | ||||
|         .map(|o| { | ||||
|             o.map(|(candidates, point)| RankingRuleOutput { | ||||
|                 query, | ||||
|                 candidates: universe.clone(), | ||||
|                 candidates, | ||||
|                 score: ScoreDetails::GeoSort(score_details::GeoSort { | ||||
|                     target_point: self.point, | ||||
|                     ascending: self.ascending, | ||||
|                     value: None, | ||||
|                     value: point, | ||||
|                 }), | ||||
|             })); | ||||
|         } | ||||
|  | ||||
|         let ascending = self.ascending; | ||||
|         let next = |cache: &mut VecDeque<_>| { | ||||
|             if ascending { | ||||
|                 cache.pop_front() | ||||
|             } else { | ||||
|                 cache.pop_back() | ||||
|             } | ||||
|         }; | ||||
|         let put_back = |cache: &mut VecDeque<_>, x: _| { | ||||
|             if ascending { | ||||
|                 cache.push_front(x) | ||||
|             } else { | ||||
|                 cache.push_back(x) | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         let mut current_bucket = RoaringBitmap::new(); | ||||
|         // current_distance stores the first point and distance in current bucket | ||||
|         let mut current_distance: Option<([f64; 2], f64)> = None; | ||||
|         loop { | ||||
|             // The loop will only exit when we have found all points with equal distance or have exhausted the candidates. | ||||
|             if let Some((id, point)) = next(&mut self.cached_sorted_docids) { | ||||
|                 if geo_candidates.contains(id) { | ||||
|                     let distance = distance_between_two_points(&self.point, &point); | ||||
|                     if let Some((point0, bucket_distance)) = current_distance.as_ref() { | ||||
|                         if (bucket_distance - distance).abs() > self.distance_error_margin { | ||||
|                             // different distance, point belongs to next bucket | ||||
|                             put_back(&mut self.cached_sorted_docids, (id, point)); | ||||
|                             return Ok(Some(RankingRuleOutput { | ||||
|                                 query, | ||||
|                                 candidates: current_bucket, | ||||
|                                 score: ScoreDetails::GeoSort(score_details::GeoSort { | ||||
|                                     target_point: self.point, | ||||
|                                     ascending: self.ascending, | ||||
|                                     value: Some(point0.to_owned()), | ||||
|                                 }), | ||||
|                             })); | ||||
|                         } else { | ||||
|                             // same distance, point belongs to current bucket | ||||
|                             current_bucket.insert(id); | ||||
|                             // remove from candidates to prevent it from being added to the cache again | ||||
|                             geo_candidates.remove(id); | ||||
|                             // current bucket size reaches limit, force return | ||||
|                             if current_bucket.len() == self.max_bucket_size { | ||||
|                                 return Ok(Some(RankingRuleOutput { | ||||
|                                     query, | ||||
|                                     candidates: current_bucket, | ||||
|                                     score: ScoreDetails::GeoSort(score_details::GeoSort { | ||||
|                                         target_point: self.point, | ||||
|                                         ascending: self.ascending, | ||||
|                                         value: Some(point0.to_owned()), | ||||
|                                     }), | ||||
|                                 })); | ||||
|                             } | ||||
|                         } | ||||
|                     } else { | ||||
|                         // first doc in current bucket | ||||
|                         current_distance = Some((point, distance)); | ||||
|                         current_bucket.insert(id); | ||||
|                         geo_candidates.remove(id); | ||||
|                         // current bucket size reaches limit, force return | ||||
|                         if current_bucket.len() == self.max_bucket_size { | ||||
|                             return Ok(Some(RankingRuleOutput { | ||||
|                                 query, | ||||
|                                 candidates: current_bucket, | ||||
|                                 score: ScoreDetails::GeoSort(score_details::GeoSort { | ||||
|                                     target_point: self.point, | ||||
|                                     ascending: self.ascending, | ||||
|                                     value: Some(point.to_owned()), | ||||
|                                 }), | ||||
|                             })); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } else { | ||||
|                 // cache exhausted, we need to refill it | ||||
|                 self.fill_buffer(ctx, &geo_candidates)?; | ||||
|  | ||||
|                 if self.cached_sorted_docids.is_empty() { | ||||
|                     // candidates exhausted, exit | ||||
|                     if let Some((point0, _)) = current_distance.as_ref() { | ||||
|                         return Ok(Some(RankingRuleOutput { | ||||
|                             query, | ||||
|                             candidates: current_bucket, | ||||
|                             score: ScoreDetails::GeoSort(score_details::GeoSort { | ||||
|                                 target_point: self.point, | ||||
|                                 ascending: self.ascending, | ||||
|                                 value: Some(point0.to_owned()), | ||||
|                             }), | ||||
|                         })); | ||||
|                     } else { | ||||
|                         return Ok(Some(RankingRuleOutput { | ||||
|                             query, | ||||
|                             candidates: universe.clone(), | ||||
|                             score: ScoreDetails::GeoSort(score_details::GeoSort { | ||||
|                                 target_point: self.point, | ||||
|                                 ascending: self.ascending, | ||||
|                                 value: None, | ||||
|                             }), | ||||
|                         })); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|             }) | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     #[tracing::instrument(level = "trace", skip_all, target = "search::geo_sort")] | ||||
| @@ -394,16 +149,3 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { | ||||
|         self.cached_sorted_docids.clear(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Compute the antipodal coordinate of `coord`: component 0 is negated and component 1 is shifted by 180 (presumably `[lat, lng]` in degrees; confirm against callers). | ||||
| fn opposite_of(mut coord: [f64; 2]) -> [f64; 2] { | ||||
|     coord[0] *= -1.; | ||||
|     // a second component of exactly 0 takes the `else` branch, so `[x, 0]` becomes `[-x, 180]` rather than `[-x, -180]` | ||||
|     if coord[1] > 0. { | ||||
|         coord[1] -= 180.; | ||||
|     } else { | ||||
|         coord[1] += 180.; | ||||
|     } | ||||
|  | ||||
|     coord | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| mod bucket_sort; | ||||
| mod db_cache; | ||||
| mod distinct; | ||||
| mod geo_sort; | ||||
| pub(crate) mod geo_sort; | ||||
| mod graph_based_ranking_rule; | ||||
| mod interner; | ||||
| mod limits; | ||||
| @@ -46,14 +46,14 @@ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; | ||||
| use roaring::RoaringBitmap; | ||||
| use sort::Sort; | ||||
|  | ||||
| use self::distinct::facet_string_values; | ||||
| pub(crate) use self::distinct::{facet_string_values, facet_values_prefix_key}; | ||||
| use self::geo_sort::GeoSort; | ||||
| pub use self::geo_sort::{Parameter as GeoSortParameter, Strategy as GeoSortStrategy}; | ||||
| use self::graph_based_ranking_rule::Words; | ||||
| use self::interner::Interned; | ||||
| use self::vector_sort::VectorSort; | ||||
| use crate::attribute_patterns::{match_pattern, PatternMatch}; | ||||
| use crate::constants::RESERVED_GEO_FIELD_NAME; | ||||
| use crate::documents::GeoSortParameter; | ||||
| use crate::index::PrefixSearch; | ||||
| use crate::localized_attributes_rules::LocalizedFieldIds; | ||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||
| @@ -319,7 +319,7 @@ fn resolve_negative_phrases( | ||||
| fn get_ranking_rules_for_placeholder_search<'ctx>( | ||||
|     ctx: &SearchContext<'ctx>, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     geo_param: geo_sort::Parameter, | ||||
|     geo_param: GeoSortParameter, | ||||
| ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> { | ||||
|     let mut sort = false; | ||||
|     let mut sorted_fields = HashSet::new(); | ||||
| @@ -371,7 +371,7 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( | ||||
| fn get_ranking_rules_for_vector<'ctx>( | ||||
|     ctx: &SearchContext<'ctx>, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     geo_param: geo_sort::Parameter, | ||||
|     geo_param: GeoSortParameter, | ||||
|     limit_plus_offset: usize, | ||||
|     target: &[f32], | ||||
|     embedder_name: &str, | ||||
| @@ -448,7 +448,7 @@ fn get_ranking_rules_for_vector<'ctx>( | ||||
| fn get_ranking_rules_for_query_graph_search<'ctx>( | ||||
|     ctx: &SearchContext<'ctx>, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     geo_param: geo_sort::Parameter, | ||||
|     geo_param: GeoSortParameter, | ||||
|     terms_matching_strategy: TermsMatchingStrategy, | ||||
| ) -> Result<Vec<BoxRankingRule<'ctx, QueryGraph>>> { | ||||
|     // query graph search | ||||
| @@ -559,7 +559,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( | ||||
|     ranking_rules: &mut Vec<BoxRankingRule<'ctx, Query>>, | ||||
|     sorted_fields: &mut HashSet<String>, | ||||
|     geo_sorted: &mut bool, | ||||
|     geo_param: geo_sort::Parameter, | ||||
|     geo_param: GeoSortParameter, | ||||
| ) -> Result<()> { | ||||
|     let sort_criteria = sort_criteria.clone().unwrap_or_default(); | ||||
|     ranking_rules.reserve(sort_criteria.len()); | ||||
| @@ -631,7 +631,7 @@ pub fn execute_vector_search( | ||||
|     universe: RoaringBitmap, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     distinct: &Option<String>, | ||||
|     geo_param: geo_sort::Parameter, | ||||
|     geo_param: GeoSortParameter, | ||||
|     from: usize, | ||||
|     length: usize, | ||||
|     embedder_name: &str, | ||||
| @@ -697,7 +697,7 @@ pub fn execute_search( | ||||
|     mut universe: RoaringBitmap, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     distinct: &Option<String>, | ||||
|     geo_param: geo_sort::Parameter, | ||||
|     geo_param: GeoSortParameter, | ||||
|     from: usize, | ||||
|     length: usize, | ||||
|     words_limit: Option<usize>, | ||||
| @@ -881,7 +881,7 @@ pub fn execute_search( | ||||
|     }) | ||||
| } | ||||
|  | ||||
| fn check_sort_criteria( | ||||
| pub(crate) fn check_sort_criteria( | ||||
|     ctx: &SearchContext<'_>, | ||||
|     sort_criteria: Option<&Vec<AscDesc>>, | ||||
| ) -> Result<()> { | ||||
| @@ -911,7 +911,7 @@ fn check_sort_criteria( | ||||
|                 let (valid_fields, hidden_fields) = | ||||
|                     ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?; | ||||
|  | ||||
|                 return Err(UserError::InvalidSortableAttribute { | ||||
|                 return Err(UserError::InvalidSearchSortableAttribute { | ||||
|                     field: field.to_string(), | ||||
|                     valid_fields, | ||||
|                     hidden_fields, | ||||
| @@ -922,7 +922,7 @@ fn check_sort_criteria( | ||||
|                 let (valid_fields, hidden_fields) = | ||||
|                     ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?; | ||||
|  | ||||
|                 return Err(UserError::InvalidSortableAttribute { | ||||
|                 return Err(UserError::InvalidSearchSortableAttribute { | ||||
|                     field: RESERVED_GEO_FIELD_NAME.to_string(), | ||||
|                     valid_fields, | ||||
|                     hidden_fields, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user