From 5da92a3d53a56af7719afe5ba097c32bf41ca93e Mon Sep 17 00:00:00 2001 From: hdt3213 Date: Mon, 14 Apr 2025 22:50:32 +0800 Subject: [PATCH] test geo sort reached max_bucket_size --- crates/milli/src/search/new/geo_sort.rs | 22 ++++++- crates/milli/src/search/new/tests/geo_sort.rs | 66 +++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/crates/milli/src/search/new/geo_sort.rs b/crates/milli/src/search/new/geo_sort.rs index 47694c6ce..1c52b0a5b 100644 --- a/crates/milli/src/search/new/geo_sort.rs +++ b/crates/milli/src/search/new/geo_sort.rs @@ -71,6 +71,26 @@ impl Strategy { } } +#[cfg(not(test))] +fn default_max_bucket_size() -> u64 { + 1000 +} + +#[cfg(test)] +static DEFAULT_MAX_BUCKET_SIZE: std::sync::Mutex<u64> = std::sync::Mutex::new(1000); + +#[cfg(test)] +pub fn set_default_max_bucket_size(n: u64) { + let mut size = DEFAULT_MAX_BUCKET_SIZE.lock().unwrap(); + *size = n; +} + +#[cfg(test)] +fn default_max_bucket_size() -> u64 { + let max_size = *(DEFAULT_MAX_BUCKET_SIZE.lock().unwrap()); + max_size +} + pub struct GeoSort { query: Option, @@ -105,7 +125,7 @@ impl GeoSort { field_ids: None, rtree: None, cached_sorted_docids: VecDeque::new(), - max_bucket_size: 1000, + max_bucket_size: default_max_bucket_size(), distance_error_margin: 1.0, }) } diff --git a/crates/milli/src/search/new/tests/geo_sort.rs b/crates/milli/src/search/new/tests/geo_sort.rs index ff946d226..e5993925a 100644 --- a/crates/milli/src/search/new/tests/geo_sort.rs +++ b/crates/milli/src/search/new/tests/geo_sort.rs @@ -4,6 +4,7 @@ This module tests the `geo_sort` ranking rule use big_s::S; use heed::RoTxn; +use itertools::Itertools; use maplit::hashset; use crate::constants::RESERVED_GEO_FIELD_NAME; @@ -136,6 +137,71 @@ fn test_geo_sort_with_following_ranking_rules() { insta::assert_snapshot!(format!("{scores:#?}")); } +#[test] +fn test_geo_sort_reached_max_bucket_size() { + let index = create_index(); + + index + .add_documents(documents!([ + { "id": 1 }, { "id": 4 }, 
{ "id": 3 }, { "id": 2 }, { "id": 5 }, + { "id": 6, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 10 }, + { "id": 7, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 9 }, + { "id": 8, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 8 }, + { "id": 9, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 7 }, + { "id": 10, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score":6 }, + { "id": 11, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 5 }, + { "id": 12, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 10 }, + { "id": 13, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 9 }, + { "id": 14, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 8 }, + { "id": 15, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 7 }, + ])) + .unwrap(); + + crate::search::new::geo_sort::set_default_max_bucket_size(2); + let rtxn = index.read_txn().unwrap(); + + let mut s = Search::new(&rtxn, &index); + s.scoring_strategy(crate::score_details::ScoringStrategy::Detailed); + s.sort_criteria(vec![ + AscDesc::Asc(Member::Geo([0., 0.])), + AscDesc::Desc(Member::Field("score".to_string())), + ]); + + /* We should not expect the results to obey the following ranking rules when the bucket size limit is reached, + * nor should we expect Iteration and rtree to give exactly the same order for the same bucket in this case.*/ + s.geo_sort_strategy(GeoSortStrategy::AlwaysIterative(1000)); + let SearchResult { documents_ids, .. 
} = s.execute().unwrap(); + let iterative_ids = collect_field_values(&index, &rtxn, "id", &documents_ids); + + assert_eq!(iterative_ids.len(), 15); + for id_str in &iterative_ids[0..6] { + let id = id_str.parse::<u64>().unwrap(); + assert!(id >= 6 && id <= 11) + } + for id_str in &iterative_ids[6..10] { + let id = id_str.parse::<u64>().unwrap(); + assert!(id >= 12 && id <= 15) + } + let no_geo_ids = iterative_ids[10..].iter().collect_vec(); + insta::assert_snapshot!(format!("{no_geo_ids:?}"), @r#"["1", "4", "3", "2", "5"]"#); + + s.geo_sort_strategy(GeoSortStrategy::AlwaysRtree(1000)); + let SearchResult { documents_ids, .. } = s.execute().unwrap(); + let rtree_ids = collect_field_values(&index, &rtxn, "id", &documents_ids); + + assert_eq!(rtree_ids.len(), 15); + for id_str in &rtree_ids[0..6] { + let id = id_str.parse::<u64>().unwrap(); + assert!(id >= 6 && id <= 11) + } + for id_str in &rtree_ids[6..10] { + let id = id_str.parse::<u64>().unwrap(); + assert!(id >= 12 && id <= 15) + } + let no_geo_ids = rtree_ids[10..].iter().collect_vec(); + insta::assert_snapshot!(format!("{no_geo_ids:?}"), @r#"["1", "4", "3", "2", "5"]"#); +} + #[test] fn test_geo_sort_around_the_edge_of_the_flat_earth() { let index = create_index();