Move geosort code out of search

This commit is contained in:
Mubelotix
2025-06-30 13:12:00 +02:00
parent 63827bbee0
commit e35d58b531
10 changed files with 257 additions and 243 deletions

View File

@@ -641,14 +641,10 @@ fn documents_by_query(
}; };
let sort_criteria = if let Some(sort) = &sort { let sort_criteria = if let Some(sort) = &sort {
let sorts: Vec<_> = let sorts: Vec<_> = match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() {
match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() {
Ok(sorts) => sorts, Ok(sorts) => sorts,
Err(asc_desc_error) => { Err(asc_desc_error) => {
return Err(milli::Error::from(milli::SortError::from( return Err(milli::Error::from(milli::SortError::from(asc_desc_error)).into())
asc_desc_error,
))
.into())
} }
}; };
Some(sorts) Some(sorts)

View File

@@ -1,14 +1,70 @@
use std::collections::VecDeque; use crate::{
distance_between_two_points,
use heed::RoTxn; heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec},
lat_lng_to_xyz,
search::new::{facet_string_values, facet_values_prefix_key},
GeoPoint, Index,
};
use heed::{
types::{Bytes, Unit},
RoPrefix, RoTxn,
};
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use rstar::RTree; use rstar::RTree;
use std::collections::VecDeque;
use crate::{ #[derive(Debug, Clone, Copy)]
distance_between_two_points, lat_lng_to_xyz, pub struct GeoSortParameter {
search::new::geo_sort::{geo_value, opposite_of}, // Define the strategy used by the geo sort
GeoPoint, GeoSortStrategy, Index, pub strategy: GeoSortStrategy,
}; // Limit the number of docs in a single bucket to avoid unexpectedly large overhead
pub max_bucket_size: u64,
// Considering the errors of GPS and geographical calculations, distances less than distance_error_margin will be treated as equal
pub distance_error_margin: f64,
}
impl Default for GeoSortParameter {
    /// Builds the default geo sort parameters: the default [`GeoSortStrategy`],
    /// a `max_bucket_size` of 1000 and a `distance_error_margin` of 1.0.
    fn default() -> Self {
        let strategy = GeoSortStrategy::default();
        let max_bucket_size = 1000;
        let distance_error_margin = 1.0;
        Self { strategy, max_bucket_size, distance_error_margin }
    }
}
/// Strategy followed by the geo sort.
///
/// Every variant carries the cache size. For [`GeoSortStrategy::Dynamic`] that same
/// number is also the candidate count at which we switch from the iterative
/// strategy to the rtree.
#[derive(Debug, Clone, Copy)]
pub enum GeoSortStrategy {
    AlwaysIterative(usize),
    AlwaysRtree(usize),
    Dynamic(usize),
}

impl Default for GeoSortStrategy {
    /// The default strategy is `Dynamic` with a value of 1000.
    fn default() -> Self {
        Self::Dynamic(1000)
    }
}

impl GeoSortStrategy {
    /// Tells whether the rtree should be used for the given number of candidates.
    ///
    /// `AlwaysRtree` always answers `true`, `AlwaysIterative` always answers `false`,
    /// and `Dynamic` answers `true` once `candidates` reaches its stored value.
    pub fn use_rtree(&self, candidates: usize) -> bool {
        match *self {
            Self::AlwaysRtree(_) => true,
            Self::AlwaysIterative(_) => false,
            Self::Dynamic(threshold) => threshold <= candidates,
        }
    }

    /// Returns the cache size carried by the strategy, whichever variant it is.
    pub fn cache_size(&self) -> usize {
        // Every variant wraps exactly one `usize`, so this pattern is irrefutable.
        let (Self::AlwaysIterative(size) | Self::AlwaysRtree(size) | Self::Dynamic(size)) = *self;
        size
    }
}
// TODO: Make it take a mut reference to cache // TODO: Make it take a mut reference to cache
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@@ -74,7 +130,8 @@ pub fn fill_cache(
.map(|id| -> crate::Result<_> { Ok((id, geo_value(id, lat, lng, index, txn)?)) }) .map(|id| -> crate::Result<_> { Ok((id, geo_value(id, lat, lng, index, txn)?)) })
.collect::<crate::Result<Vec<(u32, [f64; 2])>>>()?; .collect::<crate::Result<Vec<(u32, [f64; 2])>>>()?;
// computing the distance between two points is expensive thus we cache the result // computing the distance between two points is expensive thus we cache the result
documents.sort_by_cached_key(|(_, p)| distance_between_two_points(&target_point, p) as usize); documents
.sort_by_cached_key(|(_, p)| distance_between_two_points(&target_point, p) as usize);
cached_sorted_docids.extend(documents); cached_sorted_docids.extend(documents);
}; };
@@ -86,19 +143,13 @@ pub fn next_bucket(
index: &Index, index: &Index,
txn: &RoTxn<heed::AnyTls>, txn: &RoTxn<heed::AnyTls>,
universe: &RoaringBitmap, universe: &RoaringBitmap,
strategy: GeoSortStrategy,
ascending: bool, ascending: bool,
target_point: [f64; 2], target_point: [f64; 2],
field_ids: &Option<[u16; 2]>, field_ids: &Option<[u16; 2]>,
rtree: &mut Option<RTree<GeoPoint>>, rtree: &mut Option<RTree<GeoPoint>>,
cached_sorted_docids: &mut VecDeque<(u32, [f64; 2])>, cached_sorted_docids: &mut VecDeque<(u32, [f64; 2])>,
geo_candidates: &RoaringBitmap, geo_candidates: &RoaringBitmap,
parameter: GeoSortParameter,
// Limit the number of docs in a single bucket to avoid unexpectedly large overhead
max_bucket_size: u64,
// Considering the errors of GPS and geographical calculations, distances less than distance_error_margin will be treated as equal
distance_error_margin: f64,
) -> crate::Result<Option<(RoaringBitmap, Option<[f64; 2]>)>> { ) -> crate::Result<Option<(RoaringBitmap, Option<[f64; 2]>)>> {
let mut geo_candidates = geo_candidates & universe; let mut geo_candidates = geo_candidates & universe;
@@ -130,7 +181,7 @@ pub fn next_bucket(
if geo_candidates.contains(id) { if geo_candidates.contains(id) {
let distance = distance_between_two_points(&target_point, &point); let distance = distance_between_two_points(&target_point, &point);
if let Some((point0, bucket_distance)) = current_distance.as_ref() { if let Some((point0, bucket_distance)) = current_distance.as_ref() {
if (bucket_distance - distance).abs() > distance_error_margin { if (bucket_distance - distance).abs() > parameter.distance_error_margin {
// different distance, point belongs to next bucket // different distance, point belongs to next bucket
put_back(cached_sorted_docids, (id, point)); put_back(cached_sorted_docids, (id, point));
return Ok(Some((current_bucket, Some(point0.to_owned())))); return Ok(Some((current_bucket, Some(point0.to_owned()))));
@@ -140,7 +191,7 @@ pub fn next_bucket(
// remove from candidates to prevent it from being added to the cache again // remove from candidates to prevent it from being added to the cache again
geo_candidates.remove(id); geo_candidates.remove(id);
// current bucket size reaches limit, force return // current bucket size reaches limit, force return
if current_bucket.len() == max_bucket_size { if current_bucket.len() == parameter.max_bucket_size {
return Ok(Some((current_bucket, Some(point0.to_owned())))); return Ok(Some((current_bucket, Some(point0.to_owned()))));
} }
} }
@@ -150,7 +201,7 @@ pub fn next_bucket(
current_bucket.insert(id); current_bucket.insert(id);
geo_candidates.remove(id); geo_candidates.remove(id);
// current bucket size reaches limit, force return // current bucket size reaches limit, force return
if current_bucket.len() == max_bucket_size { if current_bucket.len() == parameter.max_bucket_size {
return Ok(Some((current_bucket, Some(point.to_owned())))); return Ok(Some((current_bucket, Some(point.to_owned()))));
} }
} }
@@ -160,7 +211,7 @@ pub fn next_bucket(
fill_cache( fill_cache(
index, index,
txn, txn,
strategy, parameter.strategy,
ascending, ascending,
target_point, target_point,
field_ids, field_ids,
@@ -180,3 +231,65 @@ pub fn next_bucket(
} }
} }
} }
/// Iterates over the number values stored for field `field_id` of document `docid`.
///
/// The lookup is a prefix iteration over `index.field_id_docid_facet_f64s`, using the
/// key built by `facet_values_prefix_key(field_id, docid)`.
fn facet_number_values<'a>(
    docid: u32,
    field_id: u16,
    index: &Index,
    txn: &'a RoTxn<'a>,
) -> crate::Result<RoPrefix<'a, FieldDocIdFacetCodec<OrderedF64Codec>, Unit>> {
    let prefix = facet_values_prefix_key(field_id, docid);
    // The database is viewed with raw byte keys so that the prefix can be used as-is.
    let raw_db = index.field_id_docid_facet_f64s.remap_key_type::<Bytes>();
    let values = raw_db.prefix_iter(txn, &prefix)?;
    // Switch the key codec to the one named in the return type.
    Ok(values.remap_key_type())
}
/// Reads the `[lat, lng]` pair of a single document.
///
/// Each coordinate is first looked up in the facet number index; when nothing is found
/// there, the facet string index is used instead and the value is parsed as `f64`
/// (as the geo extraction behaves).
///
/// # Errors
///
/// Errors coming from either facet lookup are converted and returned.
///
/// # Panics
///
/// Panics when the string fallback cannot be parsed as `f64`, or when neither index
/// holds a value for one of the two fields.
pub(crate) fn geo_value(
    docid: u32,
    field_lat: u16,
    field_lng: u16,
    index: &Index,
    rtxn: &RoTxn<'_>,
) -> crate::Result<[f64; 2]> {
    let extract_geo = |geo_field: u16| -> crate::Result<f64> {
        // Preferred source: the facet number index.
        if let Some(entry) = facet_number_values(docid, geo_field, index, rtxn)?.next() {
            let ((_, _, geo), ()) = entry?;
            return Ok(geo);
        }

        // Fallback: the facet string index, whose value must be parsed.
        let Some(entry) = facet_string_values(docid, geo_field, index, rtxn)?.next() else {
            panic!("A geo faceted document doesn't contain any lat or lng");
        };
        let (_, geo) = entry?;
        Ok(geo.parse::<f64>().expect("cannot parse geo field as f64"))
    };

    // Array elements are evaluated left to right: latitude first, then longitude.
    Ok([extract_geo(field_lat)?, extract_geo(field_lng)?])
}
/// Returns the antipodal coordinate of `coord`.
///
/// The first component is negated; the second is shifted by 180, downwards when it is
/// strictly positive and upwards otherwise.
pub(crate) fn opposite_of(mut coord: [f64; 2]) -> [f64; 2] {
    let [first, second] = &mut coord;
    *first *= -1.;
    // in the case of x,0 we want to return x,180
    *second += if *second > 0. { -180. } else { 180. };
    coord
}

View File

@@ -1,9 +1,9 @@
mod builder; mod builder;
mod enriched; mod enriched;
pub mod geo_sort;
mod primary_key; mod primary_key;
mod reader; mod reader;
mod serde_impl; mod serde_impl;
pub mod geo_sort;
use std::fmt::Debug; use std::fmt::Debug;
use std::io; use std::io;
@@ -20,6 +20,7 @@ pub use primary_key::{
pub use reader::{DocumentsBatchCursor, DocumentsBatchCursorError, DocumentsBatchReader}; pub use reader::{DocumentsBatchCursor, DocumentsBatchCursorError, DocumentsBatchReader};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub use self::geo_sort::{GeoSortParameter, GeoSortStrategy};
use crate::error::{FieldIdMapMissingEntry, InternalError}; use crate::error::{FieldIdMapMissingEntry, InternalError};
use crate::{FieldId, Object, Result}; use crate::{FieldId, Object, Result};

View File

@@ -1,6 +1,16 @@
use roaring::RoaringBitmap; use crate::{
heed_codec::{
facet::{FacetGroupKeyCodec, FacetGroupValueCodec},
BytesRefCodec,
},
search::{
facet::{ascending_facet_sort, descending_facet_sort},
new::check_sort_criteria,
},
AscDesc, DocumentId, Member,
};
use heed::Database; use heed::Database;
use crate::{heed_codec::{facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, BytesRefCodec}, search::{facet::{ascending_facet_sort, descending_facet_sort}, new::check_sort_criteria}, AscDesc, DocumentId, Member}; use roaring::RoaringBitmap;
/// Builder for a [`SortedDocumentsIterator`]. /// Builder for a [`SortedDocumentsIterator`].
/// Most builders won't ever be built, because pagination will skip them. /// Most builders won't ever be built, because pagination will skip them.
@@ -15,13 +25,8 @@ pub struct SortedDocumentsIteratorBuilder<'ctx> {
impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> { impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> {
/// Performs the sort and builds a [`SortedDocumentsIterator`]. /// Performs the sort and builds a [`SortedDocumentsIterator`].
fn build(self) -> heed::Result<SortedDocumentsIterator<'ctx>> { fn build(self) -> heed::Result<SortedDocumentsIterator<'ctx>> {
let SortedDocumentsIteratorBuilder { let SortedDocumentsIteratorBuilder { rtxn, number_db, string_db, fields, candidates } =
rtxn, self;
number_db,
string_db,
fields,
candidates,
} = self;
let size = candidates.len() as usize; let size = candidates.len() as usize;
// There is no point sorting a 1-element array // There is no point sorting a 1-element array
@@ -42,33 +47,13 @@ impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> {
// Perform the sort on the first field // Perform the sort on the first field
let (number_iter, string_iter) = if ascending { let (number_iter, string_iter) = if ascending {
let number_iter = ascending_facet_sort( let number_iter = ascending_facet_sort(rtxn, number_db, field_id, candidates.clone())?;
rtxn, let string_iter = ascending_facet_sort(rtxn, string_db, field_id, candidates)?;
number_db,
field_id,
candidates.clone(),
)?;
let string_iter = ascending_facet_sort(
rtxn,
string_db,
field_id,
candidates,
)?;
(itertools::Either::Left(number_iter), itertools::Either::Left(string_iter)) (itertools::Either::Left(number_iter), itertools::Either::Left(string_iter))
} else { } else {
let number_iter = descending_facet_sort( let number_iter = descending_facet_sort(rtxn, number_db, field_id, candidates.clone())?;
rtxn, let string_iter = descending_facet_sort(rtxn, string_db, field_id, candidates)?;
number_db,
field_id,
candidates.clone(),
)?;
let string_iter = descending_facet_sort(
rtxn,
string_db,
field_id,
candidates,
)?;
(itertools::Either::Right(number_iter), itertools::Either::Right(string_iter)) (itertools::Either::Right(number_iter), itertools::Either::Right(string_iter))
}; };
@@ -76,7 +61,8 @@ impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> {
// Create builders for the next level of the tree // Create builders for the next level of the tree
let number_db2 = number_db; let number_db2 = number_db;
let string_db2 = string_db; let string_db2 = string_db;
let number_iter = number_iter.map(move |r| -> heed::Result<SortedDocumentsIteratorBuilder> { let number_iter =
number_iter.map(move |r| -> heed::Result<SortedDocumentsIteratorBuilder> {
let (docids, _bytes) = r?; let (docids, _bytes) = r?;
Ok(SortedDocumentsIteratorBuilder { Ok(SortedDocumentsIteratorBuilder {
rtxn, rtxn,
@@ -86,7 +72,8 @@ impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> {
candidates: docids, candidates: docids,
}) })
}); });
let string_iter = string_iter.map(move |r| -> heed::Result<SortedDocumentsIteratorBuilder> { let string_iter =
string_iter.map(move |r| -> heed::Result<SortedDocumentsIteratorBuilder> {
let (docids, _bytes) = r?; let (docids, _bytes) = r?;
Ok(SortedDocumentsIteratorBuilder { Ok(SortedDocumentsIteratorBuilder {
rtxn, rtxn,
@@ -112,7 +99,7 @@ pub enum SortedDocumentsIterator<'ctx> {
Leaf { Leaf {
/// The exact number of documents remaining /// The exact number of documents remaining
size: usize, size: usize,
values: Box<dyn Iterator<Item = DocumentId> + 'ctx> values: Box<dyn Iterator<Item = DocumentId> + 'ctx>,
}, },
Branch { Branch {
/// The current child, got from the children iterator /// The current child, got from the children iterator
@@ -120,20 +107,27 @@ pub enum SortedDocumentsIterator<'ctx> {
/// The exact number of documents remaining, excluding documents in the current child /// The exact number of documents remaining, excluding documents in the current child
next_children_size: usize, next_children_size: usize,
/// Iterators to become the current child once it is exhausted /// Iterators to become the current child once it is exhausted
next_children: Box<dyn Iterator<Item = heed::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx>, next_children:
} Box<dyn Iterator<Item = heed::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx>,
},
} }
impl SortedDocumentsIterator<'_> { impl SortedDocumentsIterator<'_> {
/// Takes care of updating the current child if it is `None`, and also updates the size /// Takes care of updating the current child if it is `None`, and also updates the size
fn update_current<'ctx>(current_child: &mut Option<Box<SortedDocumentsIterator<'ctx>>>, next_children_size: &mut usize, next_children: &mut Box<dyn Iterator<Item = heed::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx>) -> heed::Result<()> { fn update_current<'ctx>(
current_child: &mut Option<Box<SortedDocumentsIterator<'ctx>>>,
next_children_size: &mut usize,
next_children: &mut Box<
dyn Iterator<Item = heed::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx,
>,
) -> heed::Result<()> {
if current_child.is_none() { if current_child.is_none() {
*current_child = match next_children.next() { *current_child = match next_children.next() {
Some(Ok(builder)) => { Some(Ok(builder)) => {
let next_child = Box::new(builder.build()?); let next_child = Box::new(builder.build()?);
*next_children_size -= next_child.size_hint().0; *next_children_size -= next_child.size_hint().0;
Some(next_child) Some(next_child)
}, }
Some(Err(e)) => return Err(e), Some(Err(e)) => return Err(e),
None => return Ok(()), None => return Ok(()),
}; };
@@ -150,15 +144,23 @@ impl Iterator for SortedDocumentsIterator<'_> {
let (current_child, next_children, next_children_size) = match self { let (current_child, next_children, next_children_size) = match self {
SortedDocumentsIterator::Leaf { values, size } => { SortedDocumentsIterator::Leaf { values, size } => {
*size = size.saturating_sub(n); *size = size.saturating_sub(n);
return values.nth(n).map(Ok) return values.nth(n).map(Ok);
}, }
SortedDocumentsIterator::Branch { current_child, next_children, next_children_size } => (current_child, next_children, next_children_size), SortedDocumentsIterator::Branch {
current_child,
next_children,
next_children_size,
} => (current_child, next_children, next_children_size),
}; };
// Otherwise don't directly iterate over children, skip them if we know we will go further // Otherwise don't directly iterate over children, skip them if we know we will go further
let mut to_skip = n - 1; let mut to_skip = n - 1;
while to_skip > 0 { while to_skip > 0 {
if let Err(e) = SortedDocumentsIterator::update_current(current_child, next_children_size, next_children) { if let Err(e) = SortedDocumentsIterator::update_current(
current_child,
next_children_size,
next_children,
) {
return Some(Err(e)); return Some(Err(e));
} }
let Some(inner) = current_child else { let Some(inner) = current_child else {
@@ -183,8 +185,14 @@ impl Iterator for SortedDocumentsIterator<'_> {
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
let size = match self { let size = match self {
SortedDocumentsIterator::Leaf { size, .. } => *size, SortedDocumentsIterator::Leaf { size, .. } => *size,
SortedDocumentsIterator::Branch { next_children_size, current_child: Some(current_child), .. } => current_child.size_hint().0 + next_children_size, SortedDocumentsIterator::Branch {
SortedDocumentsIterator::Branch { next_children_size, current_child: None, .. } => *next_children_size, next_children_size,
current_child: Some(current_child),
..
} => current_child.size_hint().0 + next_children_size,
SortedDocumentsIterator::Branch { next_children_size, current_child: None, .. } => {
*next_children_size
}
}; };
(size, Some(size)) (size, Some(size))
@@ -198,12 +206,20 @@ impl Iterator for SortedDocumentsIterator<'_> {
*size -= 1; *size -= 1;
} }
result result
}, }
SortedDocumentsIterator::Branch { current_child, next_children_size, next_children } => { SortedDocumentsIterator::Branch {
current_child,
next_children_size,
next_children,
} => {
let mut result = None; let mut result = None;
while result.is_none() { while result.is_none() {
// Ensure we have selected an iterator to work with // Ensure we have selected an iterator to work with
if let Err(e) = SortedDocumentsIterator::update_current(current_child, next_children_size, next_children) { if let Err(e) = SortedDocumentsIterator::update_current(
current_child,
next_children_size,
next_children,
) {
return Some(Err(e)); return Some(Err(e));
} }
let Some(inner) = current_child else { let Some(inner) = current_child else {
@@ -267,18 +283,9 @@ pub fn recursive_facet_sort<'ctx>(
} }
} }
let number_db = index let number_db = index.facet_id_f64_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
.facet_id_f64_docids let string_db =
.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>(); index.facet_id_string_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
let string_db = index
.facet_id_string_docids
.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
Ok(SortedDocuments { Ok(SortedDocuments { rtxn, fields, number_db, string_db, candidates })
rtxn,
fields,
number_db,
string_db,
candidates,
})
} }

View File

@@ -1,7 +1,7 @@
pub mod facet_sort_recursive;
mod facet_type; mod facet_type;
mod facet_value; mod facet_value;
pub mod value_encoding; pub mod value_encoding;
pub mod facet_sort_recursive;
pub use self::facet_type::FacetType; pub use self::facet_type::FacetType;
pub use self::facet_value::FacetValue; pub use self::facet_value::FacetValue;

View File

@@ -43,12 +43,13 @@ use std::fmt;
use std::hash::BuildHasherDefault; use std::hash::BuildHasherDefault;
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
pub use documents::GeoSortStrategy;
pub use filter_parser::{Condition, FilterCondition, Span, Token}; pub use filter_parser::{Condition, FilterCondition, Span, Token};
use fxhash::{FxHasher32, FxHasher64}; use fxhash::{FxHasher32, FxHasher64};
pub use grenad::CompressionType; pub use grenad::CompressionType;
pub use search::new::{ pub use search::new::{
execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext, execute_search, filtered_universe, DefaultSearchLogger, SearchContext, SearchLogger,
SearchLogger, VisualSearchLogger, VisualSearchLogger,
}; };
use serde_json::Value; use serde_json::Value;
pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};

View File

@@ -9,6 +9,8 @@ use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::{execute_vector_search, PartialSearchResult, VectorStoreStats}; use self::new::{execute_vector_search, PartialSearchResult, VectorStoreStats};
use crate::documents::GeoSortParameter;
use crate::documents::GeoSortStrategy;
use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features}; use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features};
use crate::index::MatchingStrategy; use crate::index::MatchingStrategy;
use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoringStrategy};
@@ -46,7 +48,7 @@ pub struct Search<'a> {
sort_criteria: Option<Vec<AscDesc>>, sort_criteria: Option<Vec<AscDesc>>,
distinct: Option<String>, distinct: Option<String>,
searchable_attributes: Option<&'a [String]>, searchable_attributes: Option<&'a [String]>,
geo_param: new::GeoSortParameter, geo_param: GeoSortParameter,
terms_matching_strategy: TermsMatchingStrategy, terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy, scoring_strategy: ScoringStrategy,
words_limit: usize, words_limit: usize,
@@ -69,7 +71,7 @@ impl<'a> Search<'a> {
sort_criteria: None, sort_criteria: None,
distinct: None, distinct: None,
searchable_attributes: None, searchable_attributes: None,
geo_param: new::GeoSortParameter::default(), geo_param: GeoSortParameter::default(),
terms_matching_strategy: TermsMatchingStrategy::default(), terms_matching_strategy: TermsMatchingStrategy::default(),
scoring_strategy: Default::default(), scoring_strategy: Default::default(),
exhaustive_number_hits: false, exhaustive_number_hits: false,
@@ -145,7 +147,7 @@ impl<'a> Search<'a> {
} }
#[cfg(test)] #[cfg(test)]
pub fn geo_sort_strategy(&mut self, strategy: new::GeoSortStrategy) -> &mut Search<'a> { pub fn geo_sort_strategy(&mut self, strategy: GeoSortStrategy) -> &mut Search<'a> {
self.geo_param.strategy = strategy; self.geo_param.strategy = strategy;
self self
} }

View File

@@ -118,7 +118,7 @@ pub fn facet_string_values<'a>(
} }
#[allow(clippy::drop_non_drop)] #[allow(clippy::drop_non_drop)]
fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE] { pub(crate) fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE] {
concat_arrays::concat_arrays!(distinct.to_be_bytes(), id.to_be_bytes()) concat_arrays::concat_arrays!(distinct.to_be_bytes(), id.to_be_bytes())
} }

View File

@@ -8,6 +8,7 @@ use rstar::RTree;
use super::facet_string_values; use super::facet_string_values;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::documents::geo_sort::{fill_cache, next_bucket}; use crate::documents::geo_sort::{fill_cache, next_bucket};
use crate::documents::{GeoSortParameter, GeoSortStrategy};
use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec}; use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec};
use crate::score_details::{self, ScoreDetails}; use crate::score_details::{self, ScoreDetails};
use crate::{GeoPoint, Index, Result, SearchContext, SearchLogger}; use crate::{GeoPoint, Index, Result, SearchContext, SearchLogger};
@@ -20,75 +21,10 @@ fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE
concat_arrays::concat_arrays!(distinct.to_be_bytes(), id.to_be_bytes()) concat_arrays::concat_arrays!(distinct.to_be_bytes(), id.to_be_bytes())
} }
/// Iterates over the number values stored for field `field_id` of document `docid`.
///
/// The lookup is a prefix iteration over `index.field_id_docid_facet_f64s`, using the
/// key built by `facet_values_prefix_key(field_id, docid)`.
fn facet_number_values<'a>(
    docid: u32,
    field_id: u16,
    index: &Index,
    txn: &'a RoTxn<'a>,
) -> Result<RoPrefix<'a, FieldDocIdFacetCodec<OrderedF64Codec>, Unit>> {
    let prefix = facet_values_prefix_key(field_id, docid);
    // The database is viewed with raw byte keys so that the prefix can be used as-is.
    let raw_db = index.field_id_docid_facet_f64s.remap_key_type::<Bytes>();
    let values = raw_db.prefix_iter(txn, &prefix)?;
    // Switch the key codec to the one named in the return type.
    Ok(values.remap_key_type())
}
/// Parameters of the geo sort.
#[derive(Debug, Clone, Copy)]
pub struct Parameter {
    /// Strategy used by the geo sort.
    pub strategy: Strategy,
    /// Cap on the number of docs in a single bucket, to avoid unexpectedly large overhead.
    pub max_bucket_size: u64,
    /// Considering the errors of GPS and geographical calculations, distances that differ
    /// by less than this margin are treated as equal.
    pub distance_error_margin: f64,
}

impl Default for Parameter {
    /// Builds the default parameters: the default [`Strategy`], a `max_bucket_size`
    /// of 1000 and a `distance_error_margin` of 1.0.
    fn default() -> Self {
        let strategy = Strategy::default();
        let max_bucket_size = 1000;
        let distance_error_margin = 1.0;
        Self { strategy, max_bucket_size, distance_error_margin }
    }
}
/// Strategy followed by the geo sort.
///
/// Every variant carries the cache size. For [`Strategy::Dynamic`] that same number is
/// also the candidate count at which we switch from the iterative strategy to the rtree.
#[derive(Debug, Clone, Copy)]
pub enum Strategy {
    AlwaysIterative(usize),
    AlwaysRtree(usize),
    Dynamic(usize),
}

impl Default for Strategy {
    /// The default strategy is `Dynamic` with a value of 1000.
    fn default() -> Self {
        Self::Dynamic(1000)
    }
}

impl Strategy {
    /// Tells whether the rtree should be used for the given number of candidates.
    ///
    /// `AlwaysRtree` always answers `true`, `AlwaysIterative` always answers `false`,
    /// and `Dynamic` answers `true` once `candidates` reaches its stored value.
    pub fn use_rtree(&self, candidates: usize) -> bool {
        match *self {
            Self::AlwaysRtree(_) => true,
            Self::AlwaysIterative(_) => false,
            Self::Dynamic(threshold) => threshold <= candidates,
        }
    }

    /// Returns the cache size carried by the strategy, whichever variant it is.
    pub fn cache_size(&self) -> usize {
        // Every variant wraps exactly one `usize`, so this pattern is irrefutable.
        let (Self::AlwaysIterative(size) | Self::AlwaysRtree(size) | Self::Dynamic(size)) = *self;
        size
    }
}
pub struct GeoSort<Q: RankingRuleQueryTrait> { pub struct GeoSort<Q: RankingRuleQueryTrait> {
query: Option<Q>, query: Option<Q>,
strategy: Strategy, strategy: GeoSortStrategy,
ascending: bool, ascending: bool,
point: [f64; 2], point: [f64; 2],
field_ids: Option<[u16; 2]>, field_ids: Option<[u16; 2]>,
@@ -105,12 +41,12 @@ pub struct GeoSort<Q: RankingRuleQueryTrait> {
impl<Q: RankingRuleQueryTrait> GeoSort<Q> { impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
pub fn new( pub fn new(
parameter: Parameter, parameter: GeoSortParameter,
geo_faceted_docids: RoaringBitmap, geo_faceted_docids: RoaringBitmap,
point: [f64; 2], point: [f64; 2],
ascending: bool, ascending: bool,
) -> Result<Self> { ) -> Result<Self> {
let Parameter { strategy, max_bucket_size, distance_error_margin } = parameter; let GeoSortParameter { strategy, max_bucket_size, distance_error_margin } = parameter;
Ok(Self { Ok(Self {
query: None, query: None,
strategy, strategy,
@@ -148,37 +84,6 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
} }
} }
/// Reads the `[lat, lng]` pair of a single document.
///
/// Each coordinate is first looked up in the facet number index; when nothing is found
/// there, the facet string index is used instead and the value is parsed as `f64`
/// (as the geo extraction behaves).
///
/// # Errors
///
/// Errors coming from either facet lookup are converted and returned.
///
/// # Panics
///
/// Panics when the string fallback cannot be parsed as `f64`, or when neither index
/// holds a value for one of the two fields.
pub(crate) fn geo_value(
    docid: u32,
    field_lat: u16,
    field_lng: u16,
    index: &Index,
    rtxn: &RoTxn<'_>,
) -> Result<[f64; 2]> {
    let extract_geo = |geo_field: u16| -> Result<f64> {
        // Preferred source: the facet number index.
        if let Some(entry) = facet_number_values(docid, geo_field, index, rtxn)?.next() {
            let ((_, _, geo), ()) = entry?;
            return Ok(geo);
        }

        // Fallback: the facet string index, whose value must be parsed.
        let Some(entry) = facet_string_values(docid, geo_field, index, rtxn)?.next() else {
            panic!("A geo faceted document doesn't contain any lat or lng");
        };
        let (_, geo) = entry?;
        Ok(geo.parse::<f64>().expect("cannot parse geo field as f64"))
    };

    // Array elements are evaluated left to right: latitude first, then longitude.
    Ok([extract_geo(field_lat)?, extract_geo(field_lng)?])
}
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> { impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
fn id(&self) -> String { fn id(&self) -> String {
"geo_sort".to_owned() "geo_sort".to_owned()
@@ -224,15 +129,17 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
ctx.index, ctx.index,
ctx.txn, ctx.txn,
universe, universe,
self.strategy,
self.ascending, self.ascending,
self.point, self.point,
&self.field_ids, &self.field_ids,
&mut self.rtree, &mut self.rtree,
&mut self.cached_sorted_docids, &mut self.cached_sorted_docids,
&self.geo_candidates, &self.geo_candidates,
self.max_bucket_size, GeoSortParameter {
self.distance_error_margin, strategy: self.strategy,
max_bucket_size: self.max_bucket_size,
distance_error_margin: self.distance_error_margin,
},
) )
.map(|o| { .map(|o| {
o.map(|(candidates, point)| RankingRuleOutput { o.map(|(candidates, point)| RankingRuleOutput {
@@ -254,16 +161,3 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
self.cached_sorted_docids.clear(); self.cached_sorted_docids.clear();
} }
} }
/// Returns the antipodal coordinate of `coord`.
///
/// The first component is negated; the second is shifted by 180, downwards when it is
/// strictly positive and upwards otherwise.
pub(crate) fn opposite_of(mut coord: [f64; 2]) -> [f64; 2] {
    let [first, second] = &mut coord;
    *first *= -1.;
    // in the case of x,0 we want to return x,180
    *second += if *second > 0. { -180. } else { 180. };
    coord
}

View File

@@ -46,14 +46,14 @@ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache};
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use sort::Sort; use sort::Sort;
use self::distinct::facet_string_values; pub(crate) use self::distinct::{facet_string_values, facet_values_prefix_key};
use self::geo_sort::GeoSort; use self::geo_sort::GeoSort;
pub use self::geo_sort::{Parameter as GeoSortParameter, Strategy as GeoSortStrategy};
use self::graph_based_ranking_rule::Words; use self::graph_based_ranking_rule::Words;
use self::interner::Interned; use self::interner::Interned;
use self::vector_sort::VectorSort; use self::vector_sort::VectorSort;
use crate::attribute_patterns::{match_pattern, PatternMatch}; use crate::attribute_patterns::{match_pattern, PatternMatch};
use crate::constants::RESERVED_GEO_FIELD_NAME; use crate::constants::RESERVED_GEO_FIELD_NAME;
use crate::documents::GeoSortParameter;
use crate::index::PrefixSearch; use crate::index::PrefixSearch;
use crate::localized_attributes_rules::LocalizedFieldIds; use crate::localized_attributes_rules::LocalizedFieldIds;
use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoringStrategy};
@@ -319,7 +319,7 @@ fn resolve_negative_phrases(
fn get_ranking_rules_for_placeholder_search<'ctx>( fn get_ranking_rules_for_placeholder_search<'ctx>(
ctx: &SearchContext<'ctx>, ctx: &SearchContext<'ctx>,
sort_criteria: &Option<Vec<AscDesc>>, sort_criteria: &Option<Vec<AscDesc>>,
geo_param: geo_sort::Parameter, geo_param: GeoSortParameter,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> { ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
let mut sort = false; let mut sort = false;
let mut sorted_fields = HashSet::new(); let mut sorted_fields = HashSet::new();
@@ -371,7 +371,7 @@ fn get_ranking_rules_for_placeholder_search<'ctx>(
fn get_ranking_rules_for_vector<'ctx>( fn get_ranking_rules_for_vector<'ctx>(
ctx: &SearchContext<'ctx>, ctx: &SearchContext<'ctx>,
sort_criteria: &Option<Vec<AscDesc>>, sort_criteria: &Option<Vec<AscDesc>>,
geo_param: geo_sort::Parameter, geo_param: GeoSortParameter,
limit_plus_offset: usize, limit_plus_offset: usize,
target: &[f32], target: &[f32],
embedder_name: &str, embedder_name: &str,
@@ -448,7 +448,7 @@ fn get_ranking_rules_for_vector<'ctx>(
fn get_ranking_rules_for_query_graph_search<'ctx>( fn get_ranking_rules_for_query_graph_search<'ctx>(
ctx: &SearchContext<'ctx>, ctx: &SearchContext<'ctx>,
sort_criteria: &Option<Vec<AscDesc>>, sort_criteria: &Option<Vec<AscDesc>>,
geo_param: geo_sort::Parameter, geo_param: GeoSortParameter,
terms_matching_strategy: TermsMatchingStrategy, terms_matching_strategy: TermsMatchingStrategy,
) -> Result<Vec<BoxRankingRule<'ctx, QueryGraph>>> { ) -> Result<Vec<BoxRankingRule<'ctx, QueryGraph>>> {
// query graph search // query graph search
@@ -559,7 +559,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
ranking_rules: &mut Vec<BoxRankingRule<'ctx, Query>>, ranking_rules: &mut Vec<BoxRankingRule<'ctx, Query>>,
sorted_fields: &mut HashSet<String>, sorted_fields: &mut HashSet<String>,
geo_sorted: &mut bool, geo_sorted: &mut bool,
geo_param: geo_sort::Parameter, geo_param: GeoSortParameter,
) -> Result<()> { ) -> Result<()> {
let sort_criteria = sort_criteria.clone().unwrap_or_default(); let sort_criteria = sort_criteria.clone().unwrap_or_default();
ranking_rules.reserve(sort_criteria.len()); ranking_rules.reserve(sort_criteria.len());
@@ -629,7 +629,7 @@ pub fn execute_vector_search(
universe: RoaringBitmap, universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>, sort_criteria: &Option<Vec<AscDesc>>,
distinct: &Option<String>, distinct: &Option<String>,
geo_param: geo_sort::Parameter, geo_param: GeoSortParameter,
from: usize, from: usize,
length: usize, length: usize,
embedder_name: &str, embedder_name: &str,
@@ -692,7 +692,7 @@ pub fn execute_search(
mut universe: RoaringBitmap, mut universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>, sort_criteria: &Option<Vec<AscDesc>>,
distinct: &Option<String>, distinct: &Option<String>,
geo_param: geo_sort::Parameter, geo_param: GeoSortParameter,
from: usize, from: usize,
length: usize, length: usize,
words_limit: Option<usize>, words_limit: Option<usize>,