Mirror of https://github.com/meilisearch/meilisearch.git
Synced 2025-12-24 05:16:59 +00:00

Compare commits: aggregate_… ... prototype-…

23 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | dd01613a63 |  |
|  | 70d975b399 |  |
|  | a8e6d946a7 |  |
|  | 7c1f72ae33 |  |
|  | 442a8f44c6 |  |
|  | 185a238c77 |  |
|  | a82bf776f3 |  |
|  | b2f86df127 |  |
|  | c3a5f51705 |  |
|  | 686d1f4c12 |  |
|  | ba75606731 |  |
|  | baf3b036d9 |  |
|  | 0d499f0055 |  |
|  | 7999c397c5 |  |
|  | c44db8b4bc |  |
|  | 9466949e34 |  |
|  | f051bbfd84 |  |
|  | 72b1c3df08 |  |
|  | 01d2ee5cc1 |  |
|  | e0c4682758 |  |
|  | d9b4b39922 |  |
|  | 4829348d6e |  |
|  | b6b6a80b76 |  |
.github/workflows/sdks-tests.yml (vendored) — 21 changed lines
```diff
@@ -16,8 +16,23 @@ env:
   MEILI_NO_ANALYTICS: 'true'
 
 jobs:
+  define-docker-image:
+    runs-on: ubuntu-latest
+    outputs:
+      docker-image: ${{ steps.define-image.outputs.docker-image }}
+    steps:
+      - uses: actions/checkout@v3
+      - name: Define the Docker image we need to use
+        id: define-image
+        run: |
+          event=${{ github.event.action }}
+          echo "docker-image=nightly" >> $GITHUB_OUTPUT
+          if [[ $event == 'workflow_dispatch' ]]; then
+            echo "docker-image=${{ github.event.inputs.docker_image }}" >> $GITHUB_OUTPUT
+          fi
+
   meilisearch-js-tests:
+    needs: define-docker-image
     name: JS SDK tests
     runs-on: ubuntu-latest
     services:
@@ -52,6 +67,7 @@ jobs:
       run: yarn test:env:browser
 
   instant-meilisearch-tests:
+    needs: define-docker-image
     name: instant-meilisearch tests
     runs-on: ubuntu-latest
     services:
@@ -78,6 +94,7 @@ jobs:
       run: yarn build
 
   meilisearch-php-tests:
+    needs: define-docker-image
     name: PHP SDK tests
     runs-on: ubuntu-latest
     services:
@@ -108,6 +125,7 @@ jobs:
       composer remove --dev guzzlehttp/guzzle http-interop/http-factory-guzzle
 
   meilisearch-python-tests:
+    needs: define-docker-image
     name: Python SDK tests
     runs-on: ubuntu-latest
     services:
@@ -132,6 +150,7 @@ jobs:
       run: pipenv run pytest
 
   meilisearch-go-tests:
+    needs: define-docker-image
     name: Go SDK tests
     runs-on: ubuntu-latest
     services:
@@ -161,6 +180,7 @@ jobs:
       run: go test -v ./...
 
   meilisearch-ruby-tests:
+    needs: define-docker-image
     name: Ruby SDK tests
     runs-on: ubuntu-latest
     services:
@@ -185,6 +205,7 @@ jobs:
       run: bundle exec rspec
 
   meilisearch-rust-tests:
+    needs: define-docker-image
     name: Rust SDK tests
     runs-on: ubuntu-latest
     services:
```
Cargo.lock (generated) — 64 changed lines
```diff
@@ -1207,6 +1207,12 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "doc-comment"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
+
 [[package]]
 name = "dump"
 version = "1.2.0"
@@ -1763,6 +1769,15 @@ dependencies = [
  "byteorder",
 ]
 
+[[package]]
+name = "hashbrown"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
+dependencies = [
+ "ahash 0.7.6",
+]
+
 [[package]]
 name = "hashbrown"
 version = "0.12.3"
@@ -1864,6 +1879,22 @@ dependencies = [
  "digest",
 ]
 
+[[package]]
+name = "hnsw"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752"
+dependencies = [
+ "ahash 0.7.6",
+ "hashbrown 0.11.2",
+ "libm",
+ "num-traits",
+ "rand_core",
+ "serde",
+ "smallvec",
+ "space",
+]
+
 [[package]]
 name = "http"
 version = "0.2.9"
@@ -1994,7 +2025,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
 dependencies = [
  "autocfg",
- "hashbrown",
+ "hashbrown 0.12.3",
  "serde",
 ]
 
@@ -2086,7 +2117,7 @@ checksum = "37228e06c75842d1097432d94d02f37fe3ebfca9791c2e8fef6e9db17ed128c1"
 dependencies = [
  "cedarwood",
  "fxhash",
- "hashbrown",
+ "hashbrown 0.12.3",
  "lazy_static",
 "phf",
 "phf_codegen",
@@ -2715,6 +2746,7 @@ dependencies = [
 "bimap",
 "bincode",
 "bstr",
+ "bytemuck",
 "byteorder",
 "charabia",
 "concat-arrays",
@@ -2730,6 +2762,7 @@ dependencies = [
 "geoutils",
 "grenad",
 "heed",
+ "hnsw",
 "insta",
 "itertools",
 "json-depth-checker",
@@ -2744,6 +2777,7 @@ dependencies = [
 "once_cell",
 "ordered-float",
 "rand",
+ "rand_pcg",
 "rayon",
 "roaring",
 "rstar",
@@ -2753,6 +2787,7 @@ dependencies = [
 "smallstr",
 "smallvec",
 "smartstring",
+ "space",
 "tempfile",
 "thiserror",
 "time",
@@ -3327,6 +3362,16 @@ dependencies = [
 "getrandom",
 ]
 
+[[package]]
+name = "rand_pcg"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e"
+dependencies = [
+ "rand_core",
+ "serde",
+]
+
 [[package]]
 name = "rayon"
 version = "1.7.0"
@@ -3764,6 +3809,9 @@ name = "smallvec"
 version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
+dependencies = [
+ "serde",
+]
 
 [[package]]
 name = "smartstring"
@@ -3786,6 +3834,16 @@ dependencies = [
 "winapi",
 ]
 
+[[package]]
+name = "space"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c5ab9701ae895386d13db622abf411989deff7109b13b46b6173bb4ce5c1d123"
+dependencies = [
+ "doc-comment",
+ "num-traits",
+]
+
 [[package]]
 name = "spin"
 version = "0.5.2"
@@ -4433,7 +4491,7 @@ version = "0.16.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9c531a2dc4c462b833788be2c07eef4e621d0e9edbd55bf280cc164c1c1aa043"
 dependencies = [
- "hashbrown",
+ "hashbrown 0.12.3",
 "once_cell",
 ]
```
```diff
@@ -217,6 +217,7 @@ InvalidDocumentFields              , InvalidRequest , BAD_REQUEST ;
 MissingDocumentFilter              , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentFilter              , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentGeoField            , InvalidRequest , BAD_REQUEST ;
+InvalidVectorDimensions            , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentId                  , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentLimit               , InvalidRequest , BAD_REQUEST ;
 InvalidDocumentOffset              , InvalidRequest , BAD_REQUEST ;
@@ -236,13 +237,10 @@ InvalidSearchHighlightPreTag       , InvalidRequest , BAD_REQUEST ;
 InvalidSearchHitsPerPage           , InvalidRequest , BAD_REQUEST ;
 InvalidSearchLimit                 , InvalidRequest , BAD_REQUEST ;
 InvalidSearchMatchingStrategy      , InvalidRequest , BAD_REQUEST ;
-InvalidMultiSearchMergeStrategy    , InvalidRequest , BAD_REQUEST ;
 InvalidSearchOffset                , InvalidRequest , BAD_REQUEST ;
 InvalidSearchPage                  , InvalidRequest , BAD_REQUEST ;
 InvalidSearchQ                     , InvalidRequest , BAD_REQUEST ;
 InvalidSearchShowMatchesPosition   , InvalidRequest , BAD_REQUEST ;
-InvalidSearchShowRankingScore      , InvalidRequest , BAD_REQUEST ;
-InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ;
 InvalidSearchSort                  , InvalidRequest , BAD_REQUEST ;
 InvalidSettingsDisplayedAttributes , InvalidRequest , BAD_REQUEST ;
 InvalidSettingsDistinctAttribute   , InvalidRequest , BAD_REQUEST ;
@@ -335,6 +333,7 @@ impl ErrorCode for milli::Error {
             UserError::InvalidSortableAttribute { .. } => Code::InvalidSearchSort,
             UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
             UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
+            UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
             UserError::SortError(_) => Code::InvalidSearchSort,
             UserError::InvalidMinTypoWordLenSetting(_, _) => {
                 Code::InvalidSettingsTypoTolerance
```
```diff
@@ -34,6 +34,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
 pub struct SearchQueryGet {
     #[deserr(default, error = DeserrQueryParamError<InvalidSearchQ>)]
     q: Option<String>,
+    #[deserr(default, error = DeserrQueryParamError<InvalidSearchQ>)]
+    vector: Option<Vec<f32>>,
     #[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError<InvalidSearchOffset>)]
     offset: Param<usize>,
     #[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError<InvalidSearchLimit>)]
@@ -56,10 +58,6 @@ pub struct SearchQueryGet {
     sort: Option<String>,
     #[deserr(default, error = DeserrQueryParamError<InvalidSearchShowMatchesPosition>)]
     show_matches_position: Param<bool>,
-    #[deserr(default, error = DeserrQueryParamError<InvalidSearchShowRankingScore>)]
-    show_ranking_score: Param<bool>,
-    #[deserr(default, error = DeserrQueryParamError<InvalidSearchShowRankingScoreDetails>)]
-    show_ranking_score_details: Param<bool>,
     #[deserr(default, error = DeserrQueryParamError<InvalidSearchFacets>)]
     facets: Option<CS<String>>,
     #[deserr( default = DEFAULT_HIGHLIGHT_PRE_TAG(), error = DeserrQueryParamError<InvalidSearchHighlightPreTag>)]
@@ -84,6 +82,7 @@ impl From<SearchQueryGet> for SearchQuery {
 
         Self {
             q: other.q,
+            vector: other.vector,
             offset: other.offset.0,
             limit: other.limit.0,
             page: other.page.as_deref().copied(),
@@ -95,8 +94,6 @@ impl From<SearchQueryGet> for SearchQuery {
             filter,
             sort: other.sort.map(|attr| fix_sort_query_parameters(&attr)),
             show_matches_position: other.show_matches_position.0,
-            show_ranking_score: other.show_ranking_score.0,
-            show_ranking_score_details: other.show_ranking_score_details.0,
             facets: other.facets.map(|o| o.into_iter().collect()),
             highlight_pre_tag: other.highlight_pre_tag,
             highlight_post_tag: other.highlight_post_tag,
```
```diff
@@ -1,26 +1,20 @@
-use std::collections::HashMap;
-
 use actix_http::StatusCode;
 use actix_web::web::{self, Data};
 use actix_web::{HttpRequest, HttpResponse};
 use deserr::actix_web::AwebJson;
-use deserr::Deserr;
 use index_scheduler::IndexScheduler;
 use log::debug;
 use meilisearch_types::deserr::DeserrJsonError;
-use meilisearch_types::error::deserr_codes::InvalidMultiSearchMergeStrategy;
 use meilisearch_types::error::ResponseError;
 use meilisearch_types::keys::actions;
-use meilisearch_types::milli::score_details::NotComparable;
 use serde::Serialize;
 
 use crate::analytics::{Analytics, MultiSearchAggregator};
 use crate::extractors::authentication::policies::ActionPolicy;
 use crate::extractors::authentication::{AuthenticationError, GuardedData};
 use crate::extractors::sequential_extractor::SeqHandler;
-use crate::milli::score_details::ScoreDetails;
 use crate::search::{
-    add_search_rules, perform_search, SearchHit, SearchQueryWithIndex, SearchResultWithIndex,
+    add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
 };
 
 pub fn configure(cfg: &mut web::ServiceConfig) {
@@ -29,34 +23,13 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
 
 #[derive(Serialize)]
 struct SearchResults {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    aggregate_hits: Option<Vec<SearchHitWithIndex>>,
     results: Vec<SearchResultWithIndex>,
 }
 
-#[derive(Serialize, Debug, Clone, PartialEq)]
-#[serde(rename_all = "camelCase")]
-struct SearchHitWithIndex {
-    pub index_uid: String,
-    #[serde(flatten)]
-    pub hit: SearchHit,
-}
-
 #[derive(Debug, deserr::Deserr)]
 #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
 pub struct SearchQueries {
     queries: Vec<SearchQueryWithIndex>,
-    #[deserr(default, error = DeserrJsonError<InvalidMultiSearchMergeStrategy>, default)]
-    merge_strategy: MergeStrategy,
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Deserr, Default)]
-#[deserr(rename_all = camelCase)]
-pub enum MergeStrategy {
-    #[default]
-    None,
-    ByNormalizedScore,
-    ByScoreDetails,
-}
-
 pub async fn multi_search_with_post(
@@ -65,13 +38,7 @@ pub async fn multi_search_with_post(
     req: HttpRequest,
     analytics: web::Data<dyn Analytics>,
 ) -> Result<HttpResponse, ResponseError> {
-    let SearchQueries { queries, merge_strategy } = params.into_inner();
-    // FIXME: REMOVE UNWRAP
-    let max_hits = queries
-        .iter()
-        .map(|SearchQueryWithIndex { limit, hits_per_page, .. }| hits_per_page.unwrap_or(*limit))
-        .max()
-        .unwrap();
+    let queries = params.into_inner().queries;
 
     let mut multi_aggregate = MultiSearchAggregator::from_queries(&queries, &req);
 
@@ -137,117 +104,7 @@ pub async fn multi_search_with_post(
 
     debug!("returns: {:?}", search_results);
 
-    let aggregate_hits = match merge_strategy {
-        MergeStrategy::None => None,
-        MergeStrategy::ByScoreDetails => Some(merge_by_score_details(&search_results, max_hits)),
-        MergeStrategy::ByNormalizedScore => {
-            Some(merge_by_normalized_score(&search_results, max_hits))
-        }
-    };
-
-    Ok(HttpResponse::Ok().json(SearchResults { aggregate_hits, results: search_results }))
-}
-
-fn merge_by_score_details(
-    search_results: &[SearchResultWithIndex],
-    max_hits: usize,
-) -> Vec<SearchHitWithIndex> {
-    let mut iterators: Vec<_> = search_results
-        .iter()
-        .filter_map(|SearchResultWithIndex { index_uid, result }| {
-            let mut it = result.hits.iter();
-            let next = it.next()?;
-            Some((index_uid, it, next))
-        })
-        .collect();
-
-    let mut hits = Vec::with_capacity(max_hits);
-
-    let mut inconsistent_indexes = HashMap::new();
-
-    for _ in 0..max_hits {
-        iterators.sort_by(|(left_uid, _, left_hit), (right_uid, _, right_hit)| {
-            let error = match ScoreDetails::partial_cmp_iter(
-                left_hit.ranking_score_raw.iter(),
-                right_hit.ranking_score_raw.iter(),
-            ) {
-                Ok(ord) => return ord,
-                Err(NotComparable(incomparable_index)) => incomparable_index,
-            };
-            inconsistent_indexes.entry((left_uid.to_owned(), right_uid.to_owned())).or_insert_with(
-                || {
-                    format!(
-                        "Detailed score {:?} is not comparable with {:?}: (left: {:#?}, right: {:#?})",
-                        left_hit.ranking_score_raw.get(error),
-                        right_hit.ranking_score_raw.get(error),
-                        left_hit.ranking_score_raw,
-                        right_hit.ranking_score_raw
-                    )
-                },
-            );
-            std::cmp::Ordering::Less
-        });
-        if !inconsistent_indexes.is_empty() {
-            let mut s = String::new();
-            for ((left_uid, right_uid), error) in &inconsistent_indexes {
-                use std::fmt::Write;
-                writeln!(s, "Indexes {} and {} are inconsistent: {}", left_uid, right_uid, error)
-                    .unwrap();
-            }
-            // Replace panic with proper error
-            panic!("{}", s);
-        }
-
-        let Some((index_uid, it, next)) = iterators.last_mut()
-        else {
-            break;
-        };
-
-        let hit = SearchHitWithIndex { index_uid: index_uid.clone(), hit: next.clone() };
-        if let Some(next_hit) = it.next() {
-            *next = next_hit;
-        } else {
-            iterators.pop();
-        }
-        hits.push(hit);
-    }
-    hits
-}
-
-fn merge_by_normalized_score(
-    search_results: &[SearchResultWithIndex],
-    max_hits: usize,
-) -> Vec<SearchHitWithIndex> {
-    let mut iterators: Vec<_> = search_results
-        .iter()
-        .filter_map(|SearchResultWithIndex { index_uid, result }| {
-            let mut it = result.hits.iter();
-            let next = it.next()?;
-            Some((index_uid, it, next))
-        })
-        .collect();
-
-    let mut hits = Vec::with_capacity(max_hits);
-
-    for _ in 0..max_hits {
-        iterators.sort_by_key(|(_, _, hit)| {
-            ScoreDetails::global_score_linear_scale(hit.ranking_score_raw.iter())
-        });
-
-        let Some((index_uid, it, next)) = iterators.last_mut()
-        else {
-            break;
-        };
-
-        let hit = SearchHitWithIndex { index_uid: index_uid.clone(), hit: next.clone() };
-        if let Some(next_hit) = it.next() {
-            *next = next_hit;
-        } else {
-            iterators.pop();
-        }
-        hits.push(hit);
-    }
-    hits
+    Ok(HttpResponse::Ok().json(SearchResults { results: search_results }))
 }
 
 /// Local `Result` extension trait to avoid `map_err` boilerplate.
```
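An aside on the removed merge loop above: it re-sorts the whole iterator list for every emitted hit (roughly O(max_hits · k log k) for k indexes). The usual alternative is a k-way merge over a binary heap, which only reorders the stream whose head changed. A minimal sketch of that idea — `Hit` and `IndexResults` are simplified stand-ins for illustration, not the crate's own types, and scores are assumed to already be comparable `u64`s:

```rust
use std::collections::BinaryHeap;

/// Simplified stand-ins for the removed `SearchHit`/`SearchResultWithIndex`.
struct Hit { score: u64 }
struct IndexResults { index_uid: String, hits: Vec<Hit> }

/// K-way merge of per-index result lists by descending score.
fn merge(results: &[IndexResults], max_hits: usize) -> Vec<(String, &Hit)> {
    // Heap entries: (score, index position, cursor into that index's hits).
    // BinaryHeap is a max-heap, so the best head is always on top.
    let mut heap: BinaryHeap<(u64, usize, usize)> = results
        .iter()
        .enumerate()
        .filter_map(|(i, r)| r.hits.first().map(|h| (h.score, i, 0)))
        .collect();

    let mut merged = Vec::with_capacity(max_hits);
    while merged.len() < max_hits {
        let Some((_, i, cursor)) = heap.pop() else { break };
        let index = &results[i];
        merged.push((index.index_uid.clone(), &index.hits[cursor]));
        // Advance only the stream we just consumed from.
        if let Some(next) = index.hits.get(cursor + 1) {
            heap.push((next.score, i, cursor + 1));
        }
    }
    merged
}
```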
```diff
@@ -9,7 +9,6 @@ use meilisearch_auth::IndexSearchRules;
 use meilisearch_types::deserr::DeserrJsonError;
 use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::index_uid::IndexUid;
-use meilisearch_types::milli::score_details::ScoreDetails;
 use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
 use meilisearch_types::{milli, Document};
 use milli::tokenizer::TokenizerBuilder;
@@ -32,11 +31,13 @@ pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string();
 pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
 pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
 
-#[derive(Debug, Clone, Default, PartialEq, Eq, Deserr)]
+#[derive(Debug, Clone, Default, PartialEq, Deserr)]
 #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
 pub struct SearchQuery {
     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
     pub q: Option<String>,
+    #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
+    pub vector: Option<Vec<f32>>,
     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
     pub offset: usize,
     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@@ -55,10 +56,6 @@ pub struct SearchQuery {
     pub attributes_to_highlight: Option<HashSet<String>>,
     #[deserr(default, error = DeserrJsonError<InvalidSearchShowMatchesPosition>, default)]
     pub show_matches_position: bool,
-    #[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScore>, default)]
-    pub show_ranking_score: bool,
-    #[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScoreDetails>, default)]
-    pub show_ranking_score_details: bool,
     #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
     pub filter: Option<Value>,
     #[deserr(default, error = DeserrJsonError<InvalidSearchSort>)]
@@ -85,13 +82,15 @@ impl SearchQuery {
 // This struct contains the fields of `SearchQuery` inline.
 // This is because neither deserr nor serde support `flatten` when using `deny_unknown_fields`.
 // The `From<SearchQueryWithIndex>` implementation ensures both structs remain up to date.
-#[derive(Debug, Clone, PartialEq, Eq, Deserr)]
+#[derive(Debug, Clone, PartialEq, Deserr)]
 #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
 pub struct SearchQueryWithIndex {
     #[deserr(error = DeserrJsonError<InvalidIndexUid>, missing_field_error = DeserrJsonError::missing_index_uid)]
     pub index_uid: IndexUid,
     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
     pub q: Option<String>,
+    #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
+    pub vector: Option<Vec<f32>>,
     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
     pub offset: usize,
     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@@ -108,10 +107,6 @@ pub struct SearchQueryWithIndex {
     pub crop_length: usize,
     #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToHighlight>)]
     pub attributes_to_highlight: Option<HashSet<String>>,
-    #[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScore>, default)]
-    pub show_ranking_score: bool,
-    #[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScoreDetails>, default)]
-    pub show_ranking_score_details: bool,
     #[deserr(default, error = DeserrJsonError<InvalidSearchShowMatchesPosition>, default)]
     pub show_matches_position: bool,
     #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
@@ -135,6 +130,7 @@ impl SearchQueryWithIndex {
         let SearchQueryWithIndex {
             index_uid,
             q,
+            vector,
             offset,
             limit,
             page,
@@ -143,8 +139,6 @@ impl SearchQueryWithIndex {
             attributes_to_crop,
             crop_length,
             attributes_to_highlight,
-            show_ranking_score,
-            show_ranking_score_details,
             show_matches_position,
             filter,
             sort,
@@ -158,6 +152,7 @@ impl SearchQueryWithIndex {
             index_uid,
             SearchQuery {
                 q,
+                vector,
                 offset,
                 limit,
                 page,
@@ -166,8 +161,6 @@ impl SearchQueryWithIndex {
                 attributes_to_crop,
                 crop_length,
                 attributes_to_highlight,
-                show_ranking_score,
-                show_ranking_score_details,
                 show_matches_position,
                 filter,
                 sort,
@@ -207,7 +200,7 @@ impl From<MatchingStrategy> for TermsMatchingStrategy {
     }
 }
 
-#[derive(Debug, Clone, Serialize, PartialEq)]
+#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
 pub struct SearchHit {
     #[serde(flatten)]
     pub document: Document,
@@ -215,12 +208,6 @@ pub struct SearchHit {
     pub formatted: Document,
     #[serde(rename = "_matchesPosition", skip_serializing_if = "Option::is_none")]
     pub matches_position: Option<MatchesPosition>,
-    #[serde(rename = "_rankingScore", skip_serializing_if = "Option::is_none")]
-    pub ranking_score: Option<u64>,
-    #[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")]
-    pub ranking_score_details: Option<serde_json::Map<String, serde_json::Value>>,
-    #[serde(skip)]
-    pub ranking_score_raw: Vec<ScoreDetails>,
 }
 
 #[derive(Serialize, Debug, Clone, PartialEq)]
@@ -289,6 +276,10 @@ pub fn perform_search(
 
     let mut search = index.search(&rtxn);
 
+    if let Some(ref vector) = query.vector {
+        search.vector(vector.clone());
+    }
+
     if let Some(ref query) = query.q {
         search.query(query);
     }
@@ -339,8 +330,7 @@ pub fn perform_search(
         search.sort_criteria(sort);
     }
 
-    let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
-        search.execute()?;
+    let milli::SearchResult { documents_ids, matching_words, candidates, .. } = search.execute()?;
 
     let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
 
@@ -412,7 +402,7 @@ pub fn perform_search(
 
     let documents_iter = index.documents(&rtxn, documents_ids)?;
 
-    for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
+    for (_id, obkv) in documents_iter {
         // First generate a document with all the displayed fields
         let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?;
 
@@ -436,19 +426,7 @@ pub fn perform_search(
             insert_geo_distance(sort, &mut document);
         }
 
-        let ranking_score =
-            query.show_ranking_score.then(|| ScoreDetails::global_score_linear_scale(score.iter()));
-        let ranking_score_details =
-            query.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter()));
-
-        let hit = SearchHit {
-            document,
-            formatted,
-            matches_position,
-            ranking_score_details,
-            ranking_score,
-            ranking_score_raw: score,
-        };
+        let hit = SearchHit { document, formatted, matches_position };
         documents.push(hit);
     }
```
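With `vector: Option<Vec<f32>>` on `SearchQuery`, a JSON search request can carry an embedding directly. A hedged sketch of what such a request body might look like — the embedding values and limit are purely illustrative:

```rust
use serde_json::json;

fn main() {
    // Hypothetical POST /indexes/{index_uid}/search payload exercising the
    // new `vector` field added in this patch.
    let body = json!({
        "vector": [0.052, -0.113, 0.331, 0.007],
        "limit": 10
    });
    assert!(body["vector"].is_array());
    println!("{body}");
}
```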
```diff
@@ -1,4 +1,3 @@
-use insta::{allow_duplicates, assert_json_snapshot};
 use serde_json::json;
 
 use super::*;
@@ -19,45 +18,30 @@ async fn formatted_contain_wildcard() {
             |response, code|
             {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "_formatted": {
-                        "id": "852",
-                        "cattos": "<em>pésti</em>"
-                      },
-                      "_matchesPosition": {
-                        "cattos": [
-                          {
-                            "start": 0,
-                            "length": 5
-                          }
-                        ]
-                      },
-                      "_rankingScore": "[score]"
-                    }
-                    "###);
-                }
-            }
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "_formatted": {
+                            "id": "852",
+                            "cattos": "<em>pésti</em>",
+                        },
+                        "_matchesPosition": {"cattos": [{"start": 0, "length": 5}]},
+                    })
+                );
+            }
         )
         .await;
 
     index
         .search(json!({ "q": "pésti", "attributesToRetrieve": ["*"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "cattos": "pésti",
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                    "cattos": "pésti",
+                })
+            );
         })
         .await;
 
@@ -66,30 +50,20 @@ async fn formatted_contain_wildcard() {
             json!({ "q": "pésti", "attributesToRetrieve": ["*"], "attributesToHighlight": ["id"], "showMatchesPosition": true }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "id": 852,
-                      "cattos": "pésti",
-                      "_formatted": {
-                        "id": "852",
-                        "cattos": "pésti"
-                      },
-                      "_matchesPosition": {
-                        "cattos": [
-                          {
-                            "start": 0,
-                            "length": 5
-                          }
-                        ]
-                      },
-                      "_rankingScore": "[score]"
-                    }
-                    "###)
-                }
-            })
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "id": 852,
+                        "cattos": "pésti",
+                        "_formatted": {
+                            "id": "852",
+                            "cattos": "pésti",
+                        },
+                        "_matchesPosition": {"cattos": [{"start": 0, "length": 5}]},
+                    })
+                );
+            }
         )
         .await;
 
     index
@@ -97,21 +71,17 @@
             json!({ "q": "pésti", "attributesToRetrieve": ["*"], "attributesToCrop": ["*"] }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "id": 852,
-                      "cattos": "pésti",
-                      "_formatted": {
-                        "id": "852",
-                        "cattos": "pésti"
-                      },
-                      "_rankingScore": "[score]"
-                    }
-                    "###);
-                }
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "id": 852,
+                        "cattos": "pésti",
+                        "_formatted": {
+                            "id": "852",
+                            "cattos": "pésti",
+                        }
+                    })
+                );
             },
         )
         .await;
@@ -119,21 +89,17 @@
 
     index
         .search(json!({ "q": "pésti", "attributesToCrop": ["*"] }), |response, code| {
            assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "cattos": "pésti",
-                  "_formatted": {
-                    "id": "852",
-                    "cattos": "pésti"
-                  },
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                    "cattos": "pésti",
+                    "_formatted": {
+                        "id": "852",
+                        "cattos": "pésti",
+                    }
+                })
+            );
         })
         .await;
 }
@@ -150,25 +116,21 @@ async fn format_nested() {
     index
         .search(json!({ "q": "pésti", "attributesToRetrieve": ["doggos"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "doggos": [
-                    {
-                      "name": "bobby",
-                      "age": 2
-                    },
-                    {
-                      "name": "buddy",
-                      "age": 4
-                    }
-                  ],
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "doggos": [
+                        {
+                            "name": "bobby",
+                            "age": 2,
+                        },
+                        {
+                            "name": "buddy",
+                            "age": 4,
+                        },
+                    ],
+                })
+            );
         })
         .await;
 
@@ -177,23 +139,19 @@
             json!({ "q": "pésti", "attributesToRetrieve": ["doggos.name"] }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "doggos": [
-                        {
-                          "name": "bobby"
-                        },
-                        {
-                          "name": "buddy"
-                        }
-                      ],
-                      "_rankingScore": "[score]"
-                    }
-                    "###)
-                }
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "doggos": [
+                            {
+                                "name": "bobby",
+                            },
+                            {
+                                "name": "buddy",
+                            },
+                        ],
+                    })
+                );
             },
         )
         .await;
@@ -203,31 +161,20 @@
             json!({ "q": "bobby", "attributesToRetrieve": ["doggos.name"], "showMatchesPosition": true }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "doggos": [
-                        {
-                          "name": "bobby"
-                        },
-                        {
-                          "name": "buddy"
-                        }
-                      ],
-                      "_matchesPosition": {
-                        "doggos.name": [
-                          {
-                            "start": 0,
-                            "length": 5
-                          }
-                        ]
-                      },
-                      "_rankingScore": "[score]"
-                    }
-                    "###)
-                }
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "doggos": [
+                            {
+                                "name": "bobby",
+                            },
+                            {
+                                "name": "buddy",
+                            },
+                        ],
+                        "_matchesPosition": {"doggos.name": [{"start": 0, "length": 5}]},
+                    })
+                );
+            }
         )
         .await;
@@ -236,25 +183,21 @@
         .search(json!({ "q": "pésti", "attributesToRetrieve": [], "attributesToHighlight": ["doggos.name"] }),
         |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "_formatted": {
-                    "doggos": [
-                      {
-                        "name": "bobby"
-                      },
-                      {
-                        "name": "buddy"
-                      }
-                    ]
-                  },
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "_formatted": {
+                        "doggos": [
+                            {
+                                "name": "bobby",
+                            },
+                            {
+                                "name": "buddy",
+                            },
+                        ],
+                    },
+                })
+            );
         })
         .await;
 
@@ -262,25 +205,21 @@
         .search(json!({ "q": "pésti", "attributesToRetrieve": [], "attributesToCrop": ["doggos.name"] }),
         |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "_formatted": {
-                    "doggos": [
-                      {
-                        "name": "bobby"
-                      },
-                      {
-                        "name": "buddy"
-                      }
-                    ]
-                  },
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "_formatted": {
+                        "doggos": [
+                            {
+                                "name": "bobby",
+                            },
+                            {
+                                "name": "buddy",
+                            },
+                        ],
+                    },
+                })
+            );
         })
         .await;
 
@@ -288,63 +227,55 @@
         .search(json!({ "q": "pésti", "attributesToRetrieve": ["doggos.name"], "attributesToHighlight": ["doggos.age"] }),
         |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "doggos": [
-                    {
-                      "name": "bobby"
-                    },
-                    {
-                      "name": "buddy"
-                    }
-                  ],
-                  "_formatted": {
-                    "doggos": [
-                      {
-                        "name": "bobby",
-                        "age": "2"
-                      },
-                      {
-                        "name": "buddy",
-                        "age": "4"
-                      }
-                    ]
-                  },
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
-        })
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "doggos": [
+                        {
+                            "name": "bobby",
+                        },
+                        {
+                            "name": "buddy",
+                        },
+                    ],
+                    "_formatted": {
+                        "doggos": [
+                            {
+                                "name": "bobby",
+                                "age": "2",
+                            },
+                            {
+                                "name": "buddy",
+                                "age": "4",
+                            },
+                        ],
+                    },
+                })
+            );
+        })
         .await;
 
     index
         .search(json!({ "q": "pésti", "attributesToRetrieve": [], "attributesToHighlight": ["doggos.age"], "attributesToCrop": ["doggos.name"] }),
         |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "_formatted": {
-                    "doggos": [
-                      {
-                        "name": "bobby",
-                        "age": "2"
-                      },
-                      {
-                        "name": "buddy",
-                        "age": "4"
-                      }
-                    ]
-                  },
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "_formatted": {
+                        "doggos": [
+                            {
+                                "name": "bobby",
+                                "age": "2",
+                            },
+                            {
+                                "name": "buddy",
+                                "age": "4",
+                            },
+                        ],
+                    },
+                })
+            );
+            }
        )
         .await;
@@ -366,70 +297,54 @@ async fn displayedattr_2_smol() {
         .search(json!({ "attributesToRetrieve": ["father", "id"], "attributesToHighlight": ["mother"], "attributesToCrop": ["cattos"] }),
         |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                })
+            );
         })
         .await;
 
     index
         .search(json!({ "attributesToRetrieve": ["id"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                })
+            );
         })
         .await;
 
     index
         .search(json!({ "attributesToHighlight": ["id"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "_formatted": {
-                    "id": "852"
-                  },
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                    "_formatted": {
+                        "id": "852",
+                    }
+                })
+            );
         })
         .await;
 
     index
         .search(json!({ "attributesToCrop": ["id"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "_formatted": {
-                    "id": "852"
-                  },
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                    "_formatted": {
+                        "id": "852",
+                    }
+                })
+            );
         })
         .await;
 
@@ -438,19 +353,15 @@ async fn displayedattr_2_smol() {
             json!({ "attributesToHighlight": ["id"], "attributesToCrop": ["id"] }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "id": 852,
-                      "_formatted": {
-                        "id": "852"
-                      },
-                      "_rankingScore": "[score]"
-                    }
-                    "###)
-                }
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "id": 852,
+                        "_formatted": {
+                            "id": "852",
+                        }
+                    })
+                );
             },
         )
         .await;
@@ -458,47 +369,31 @@
     index
         .search(json!({ "attributesToHighlight": ["cattos"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                })
+            );
         })
         .await;
 
     index
         .search(json!({ "attributesToCrop": ["cattos"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "id": 852,
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(
+                response["hits"][0],
+                json!({
+                    "id": 852,
+                })
+            );
        })
         .await;
 
     index
         .search(json!({ "attributesToRetrieve": ["cattos"] }), |response, code| {
             assert_eq!(code, 200, "{}", response);
-            allow_duplicates! {
-                assert_json_snapshot!(response["hits"][0],
-                    { "._rankingScore" => "[score]" },
-                    @r###"
-                {
-                  "_rankingScore": "[score]"
-                }
-                "###)
-            }
+            assert_eq!(response["hits"][0], json!({}));
         })
         .await;
 
@@ -507,15 +402,7 @@
             json!({ "attributesToRetrieve": ["cattos"], "attributesToHighlight": ["cattos"], "attributesToCrop": ["cattos"] }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "_rankingScore": "[score]"
-                    }
-                    "###)
-                }
+                assert_eq!(response["hits"][0], json!({}));
 
             }
         )
@@ -526,18 +413,14 @@
             json!({ "attributesToRetrieve": ["cattos"], "attributesToHighlight": ["id"] }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "_formatted": {
-                        "id": "852"
-                      },
-                      "_rankingScore": "[score]"
-                    }
-                    "###)
-                }
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "_formatted": {
+                            "id": "852",
+                        }
+                    })
+                );
             },
         )
         .await;
@@ -547,18 +430,14 @@
            json!({ "attributesToRetrieve": ["cattos"], "attributesToCrop": ["id"] }),
             |response, code| {
                 assert_eq!(code, 200, "{}", response);
-                allow_duplicates! {
-                    assert_json_snapshot!(response["hits"][0],
-                        { "._rankingScore" => "[score]" },
-                        @r###"
-                    {
-                      "_formatted": {
-                        "id": "852"
-                      },
-                      "_rankingScore": "[score]"
-                    }
-                    "###)
-                }
+                assert_eq!(
+                    response["hits"][0],
+                    json!({
+                        "_formatted": {
+                            "id": "852",
+                        }
+                    })
+                );
             },
         )
         .await;
```
```diff
@@ -65,15 +65,14 @@ async fn simple_search_single_index() {
         ]}))
         .await;
     snapshot!(code, @"200 OK");
-    insta::assert_json_snapshot!(response["results"], { "[].processingTimeMs" => "[time]", ".**._rankingScore" => "[score]" }, @r###"
+    insta::assert_json_snapshot!(response["results"], { "[].processingTimeMs" => "[time]" }, @r###"
     [
       {
         "indexUid": "test",
         "hits": [
           {
             "title": "Gläss",
-            "id": "450465",
-            "_rankingScore": "[score]"
+            "id": "450465"
           }
         ],
         "query": "glass",
@@ -87,8 +86,7 @@ async fn simple_search_single_index() {
         "hits": [
           {
             "title": "Captain Marvel",
-            "id": "299537",
-            "_rankingScore": "[score]"
+            "id": "299537"
           }
         ],
         "query": "captain",
@@ -172,15 +170,14 @@ async fn simple_search_two_indexes() {
         ]}))
         .await;
     snapshot!(code, @"200 OK");
-    insta::assert_json_snapshot!(response["results"], { "[].processingTimeMs" => "[time]", ".**._rankingScore" => "[score]" }, @r###"
+    insta::assert_json_snapshot!(response["results"], { "[].processingTimeMs" => "[time]" }, @r###"
     [
       {
         "indexUid": "test",
         "hits": [
           {
             "title": "Gläss",
-            "id": "450465",
-            "_rankingScore": "[score]"
+            "id": "450465"
           }
         ],
         "query": "glass",
@@ -206,8 +203,7 @@ async fn simple_search_two_indexes() {
             "age": 4
           }
         ],
-        "cattos": "pésti",
-        "_rankingScore": "[score]"
+        "cattos": "pésti"
       },
       {
         "id": 654,
@@ -222,8 +218,7 @@ async fn simple_search_two_indexes() {
         "cattos": [
           "simba",
           "pestiféré"
-        ],
-        "_rankingScore": "[score]"
+        ]
       }
     ],
     "query": "pésti",
```
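For context on what these test hunks strip out: the old assertions relied on insta's inline snapshots with redactions, where a selector map such as `{ "._rankingScore" => "[score]" }` replaces a volatile value with a stable placeholder before the comparison against the `@r###"…"###` snapshot. A minimal, self-contained sketch of the same mechanism — the test name and values are illustrative only:

```rust
#[cfg(test)]
mod tests {
    use insta::assert_json_snapshot;
    use serde_json::json;

    #[test]
    fn redaction_example() {
        // Hypothetical hit whose score changes from run to run.
        let hit = json!({ "id": 852, "_rankingScore": 0.871234 });
        // The redaction selector rewrites the volatile field before snapshotting.
        assert_json_snapshot!(hit, { "._rankingScore" => "[score]" }, @r###"
        {
          "_rankingScore": "[score]",
          "id": 852
        }
        "###);
    }
}
```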
```diff
@@ -15,6 +15,7 @@ license.workspace = true
 bimap = { version = "0.6.3", features = ["serde"] }
 bincode = "1.3.3"
 bstr = "1.4.0"
+bytemuck = { version = "1.13.1", features = ["extern_crate_alloc"] }
 byteorder = "1.4.3"
 charabia = { version = "0.7.2", default-features = false }
 concat-arrays = "0.1.2"
@@ -32,18 +33,21 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-features = false, features = [
     "lmdb",
     "sync-read-txn",
 ] }
+hnsw = { version = "0.11.0", features = ["serde1"] }
 json-depth-checker = { path = "../json-depth-checker" }
 levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
 memmap2 = "0.5.10"
 obkv = "0.2.0"
 once_cell = "1.17.1"
 ordered-float = "3.6.0"
+rand_pcg = { version = "0.3.1", features = ["serde1"] }
 rayon = "1.7.0"
 roaring = "0.10.1"
 rstar = { version = "0.10.0", features = ["serde"] }
 serde = { version = "1.0.160", features = ["derive"] }
 serde_json = { version = "1.0.95", features = ["preserve_order"] }
 slice-group-by = "0.3.0"
+space = "0.17.0"
 smallstr = { version = "0.3.0", features = ["serde"] }
 smallvec = "1.10.0"
 smartstring = "1.0.1"
```
```diff
@@ -52,6 +52,7 @@ fn main() -> Result<(), Box<dyn Error>> {
     let docs = execute_search(
         &mut ctx,
         &(!query.trim().is_empty()).then(|| query.trim().to_owned()),
+        &None,
         TermsMatchingStrategy::Last,
         false,
         &None,
```
milli/src/distance.rs (new file) — 34 lines
```diff
@@ -0,0 +1,34 @@
+use serde::{Deserialize, Serialize};
+use space::Metric;
+
+#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
+pub struct DotProduct;
+
+impl Metric<Vec<f32>> for DotProduct {
+    type Unit = u32;
+
+    // TODO explain me this function, I don't understand why f32.to_bits is ordered.
+    // I tried to do this and it wasn't OK <https://stackoverflow.com/a/43305015/1941280>
+    //
+    // Following <https://docs.rs/space/0.17.0/space/trait.Metric.html>.
+    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
+        let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum();
+        let dist = 1.0 - dist;
+        debug_assert!(!dist.is_nan());
+        dist.to_bits()
+    }
+}
+
+#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
+pub struct Euclidean;
+
+impl Metric<Vec<f32>> for Euclidean {
+    type Unit = u32;
+
+    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
+        let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum();
+        let dist = squared.sqrt();
+        debug_assert!(!dist.is_nan());
+        dist.to_bits()
+    }
+}
```
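The TODO in `DotProduct::distance` has a standard answer: for IEEE-754 floats that are non-negative and not NaN, `f32::to_bits` is order-preserving — the sign bit is 0, and the exponent/mantissa layout makes larger magnitudes produce larger bit patterns — so comparing the `u32` values orders the distances correctly. This holds only while the distance stays ≥ 0, which `1.0 - dot` guarantees for normalized vectors (dot ≤ 1). A small illustrative sketch of the property; this is an explanatory aside, not part of the patch:

```rust
// For non-negative, non-NaN IEEE-754 floats, to_bits() preserves ordering:
// the sign bit is 0 and exponent/mantissa compare like magnitudes.
fn main() {
    let values = [0.0_f32, 0.001, 0.5, 1.0, 2.0, 1000.0];
    for pair in values.windows(2) {
        assert!(pair[0] < pair[1]);
        assert!(pair[0].to_bits() < pair[1].to_bits());
    }
    // The property breaks for negative floats (the set sign bit flips the
    // order), which is why `distance` must never return a negative value.
    assert!((-1.0_f32).to_bits() > 1.0_f32.to_bits());
}
```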
```diff
@@ -110,9 +110,11 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and underscores (_).
     },
     #[error(transparent)]
     InvalidGeoField(#[from] GeoError),
+    #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)]
+    InvalidVectorDimensions { expected: usize, found: usize },
     #[error("{0}")]
     InvalidFilter(String),
-    #[error("Invalid type for filter subexpression: `expected {}, found: {1}`.", .0.join(", "))]
+    #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
     InvalidFilterExpression(&'static [&'static str], Value),
     #[error("Attribute `{}` is not sortable. {}",
         .field,
```
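A hedged sketch of the kind of guard that would produce the new `InvalidVectorDimensions` variant — only the variant and its message format come from the patch; the function itself and where it would be called are assumptions:

```rust
use crate::error::UserError;

// Hypothetical dimension check at vector-indexing time (illustrative only).
fn check_vector_dimensions(expected: usize, vector: &[f32]) -> Result<(), UserError> {
    if vector.len() != expected {
        // Renders as: "Invalid vector dimensions: expected: `4`, found: `3`."
        return Err(UserError::InvalidVectorDimensions { expected, found: vector.len() });
    }
    Ok(())
}
```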
```diff
@@ -8,10 +8,12 @@ use charabia::{Language, Script};
 use heed::flags::Flags;
 use heed::types::*;
 use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn};
+use rand_pcg::Pcg32;
 use roaring::RoaringBitmap;
 use rstar::RTree;
 use time::OffsetDateTime;
 
+use crate::distance::DotProduct;
 use crate::error::{InternalError, UserError};
 use crate::facet::FacetType;
 use crate::fields_ids_map::FieldsIdsMap;
@@ -26,6 +28,9 @@ use crate::{
     Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32,
 };
 
+/// The HNSW data-structure that we serialize, fill and search in.
+pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>;
+
 pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
 pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
 
@@ -42,6 +47,7 @@ pub mod main_key {
     pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
     pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
     pub const GEO_RTREE_KEY: &str = "geo-rtree";
+    pub const VECTOR_HNSW_KEY: &str = "vector-hnsw";
     pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids";
     pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids";
     pub const PRIMARY_KEY_KEY: &str = "primary-key";
@@ -86,6 +92,7 @@ pub mod db_name {
     pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids";
     pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
     pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
+    pub const VECTOR_ID_DOCID: &str = "vector-id-docids";
     pub const DOCUMENTS: &str = "documents";
     pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
 }
@@ -149,6 +156,9 @@ pub struct Index {
     /// Maps the document id, the facet field id and the strings.
     pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,
 
+    /// Maps a vector id to the document id that has it.
+    pub vector_id_docid: Database<OwnedType<BEU32>, OwnedType<BEU32>>,
+
     /// Maps the document id to the document as an obkv store.
     pub(crate) documents: Database<OwnedType<BEU32>, ObkvCodec>,
 }
@@ -162,7 +172,7 @@ impl Index {
     ) -> Result<Index> {
         use db_name::*;
 
-        options.max_dbs(23);
+        options.max_dbs(24);
         unsafe { options.flag(Flags::MdbAlwaysFreePages) };
 
         let env = options.open(path)?;
@@ -198,11 +208,11 @@ impl Index {
             env.create_database(&mut wtxn, Some(FACET_ID_IS_NULL_DOCIDS))?;
         let facet_id_is_empty_docids =
             env.create_database(&mut wtxn, Some(FACET_ID_IS_EMPTY_DOCIDS))?;
-
         let field_id_docid_facet_f64s =
             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
         let field_id_docid_facet_strings =
             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
+        let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?;
         let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;
         wtxn.commit()?;
@@ -231,6 +241,7 @@ impl Index {
             facet_id_is_empty_docids,
             field_id_docid_facet_f64s,
             field_id_docid_facet_strings,
+            vector_id_docid,
             documents,
         })
     }
@@ -502,6 +513,26 @@ impl Index {
         }
     }
 
+    /* vector HNSW */
+
+    /// Writes the provided `hnsw`.
+    pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> {
+        self.main.put::<_, Str, SerdeBincode<Hnsw>>(wtxn, main_key::VECTOR_HNSW_KEY, hnsw)
+    }
+
+    /// Deletes the `hnsw`.
+    pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
+        self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HNSW_KEY)
+    }
+
+    /// Returns the `hnsw`.
+    pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> {
+        match self.main.get::<_, Str, SerdeBincode<Hnsw>>(rtxn, main_key::VECTOR_HNSW_KEY)? {
+            Some(hnsw) => Ok(Some(hnsw)),
+            None => Ok(None),
+        }
+    }
+
     /* field distribution */
 
     /// Writes the field distribution which associates every field name with
@@ -1466,9 +1497,9 @@ pub(crate) mod tests {
 
         db_snap!(index, field_distribution,
             @r###"
-        age 1
-        id 2
-        name 2
+        age 1 |
+        id 2 |
+        name 2 |
        "###
         );
 
@@ -1486,9 +1517,9 @@ pub(crate) mod tests {
 
         db_snap!(index, field_distribution,
             @r###"
-        age 1
-        id 2
-        name 2
+        age 1 |
+        id 2 |
+        name 2 |
        "###
         );
 
@@ -1502,9 +1533,9 @@ pub(crate) mod tests {
 
         db_snap!(index, field_distribution,
             @r###"
-        has_dog 1
-        id 2
-        name 2
+        has_dog 1 |
+        id 2 |
+        name 2 |
        "###
         );
     }
@@ -2488,12 +2519,8 @@ pub(crate) mod tests {
 
         let rtxn = index.read_txn().unwrap();
         let search = Search::new(&rtxn, &index);
-        let SearchResult {
-            matching_words: _,
-            candidates: _,
-            document_scores: _,
-            mut documents_ids,
-        } = search.execute().unwrap();
+        let SearchResult { matching_words: _, candidates: _, mut documents_ids } =
+            search.execute().unwrap();
         let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap();
         documents_ids.sort_unstable();
         let docs = index.documents(&rtxn, documents_ids).unwrap();
```
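A hedged sketch of how the new `Index` accessors round-trip the serialized HNSW through LMDB's `main` database. Since `put_vector_hnsw` is `pub(crate)`, this is written as if it lived inside the milli crate; `index` is assumed to be an open `Index` and `hnsw` a populated graph:

```rust
// Sketch only: assumes we're inside the milli crate, with an open Index
// and an HNSW graph to persist.
fn store_and_reload(index: &crate::Index, hnsw: &crate::index::Hnsw) -> crate::Result<()> {
    let mut wtxn = index.write_txn()?;
    // Bincode-serialized under the "vector-hnsw" key of the main database.
    index.put_vector_hnsw(&mut wtxn, hnsw)?;
    wtxn.commit()?;

    let rtxn = index.read_txn()?;
    // Returns Ok(None) when no HNSW has been written yet.
    let reloaded = index.vector_hnsw(&rtxn)?;
    debug_assert!(reloaded.is_some());
    Ok(())
}
```

Note the `max_dbs(23)` → `max_dbs(24)` bump in the diff: adding the `vector-id-docids` database requires raising the LMDB database limit by one.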
```diff
@@ -10,6 +10,7 @@ pub mod documents;
 
 mod asc_desc;
 mod criterion;
+pub mod distance;
 mod error;
 mod external_documents_ids;
 pub mod facet;
@@ -17,7 +18,6 @@ mod fields_ids_map;
 pub mod heed_codec;
 pub mod index;
 pub mod proximity;
-pub mod score_details;
 mod search;
 pub mod update;
```
```diff
@@ -1,544 +0,0 @@
-use std::cmp::Ordering;
-
-use serde::Serialize;
-
-use crate::distance_between_two_points;
-
-#[derive(Debug, Clone, PartialEq)]
-pub enum ScoreDetails {
-    Words(Words),
-    Typo(Typo),
-    Proximity(Rank),
-    Fid(Rank),
-    Position(Rank),
-    ExactAttribute(ExactAttribute),
-    Exactness(Rank),
-    Sort(Sort),
-    GeoSort(GeoSort),
-}
-
-impl PartialOrd for ScoreDetails {
-    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
-        use ScoreDetails::*;
-        match (self, other) {
-            // matching left and right hands => defer to sub impl
-            (Words(left), Words(right)) => left.partial_cmp(right),
-            (Typo(left), Typo(right)) => left.partial_cmp(right),
-            (Proximity(left), Proximity(right)) => left.partial_cmp(right),
-            (Fid(left), Fid(right)) => left.partial_cmp(right),
-            (Position(left), Position(right)) => left.partial_cmp(right),
-            (ExactAttribute(left), ExactAttribute(right)) => left.partial_cmp(right),
-            (Exactness(left), Exactness(right)) => left.partial_cmp(right),
-            (Sort(left), Sort(right)) => left.partial_cmp(right),
-            (GeoSort(left), GeoSort(right)) => left.partial_cmp(right),
-            // non-matching left and right hands => None
-            // written this way rather than with a single `_` arm, so that adding a new variant
-            // still results in a compile error
-            (Words(_), _) => None,
-            (Typo(_), _) => None,
-            (Proximity(_), _) => None,
-            (Fid(_), _) => None,
-            (Position(_), _) => None,
-            (ExactAttribute(_), _) => None,
-            (Exactness(_), _) => None,
-            (Sort(_), _) => None,
-            (GeoSort(_), _) => None,
-        }
-    }
-}
-
-impl ScoreDetails {
-    pub fn local_score(&self) -> Option<f64> {
-        self.rank().map(Rank::local_score)
-    }
-
-    pub fn rank(&self) -> Option<Rank> {
-        match self {
-            ScoreDetails::Words(details) => Some(details.rank()),
-            ScoreDetails::Typo(details) => Some(details.rank()),
-            ScoreDetails::Proximity(details) => Some(*details),
-            ScoreDetails::Fid(details) => Some(*details),
-            ScoreDetails::Position(details) => Some(*details),
-            ScoreDetails::ExactAttribute(details) => Some(details.rank()),
-            ScoreDetails::Exactness(details) => Some(*details),
-            ScoreDetails::Sort(_) => None,
-            ScoreDetails::GeoSort(_) => None,
-        }
-    }
-
-    pub fn global_score<'a>(details: impl Iterator<Item = &'a Self>) -> f64 {
-        Rank::global_score(details.filter_map(Self::rank))
-    }
-
-    pub fn global_score_linear_scale<'a>(details: impl Iterator<Item = &'a Self>) -> u64 {
-        (Self::global_score(details) * LINEAR_SCALE_FACTOR).round() as u64
-    }
-
-    /// Panics
-    ///
-    /// - If Position is not preceded by Fid
-    /// - If Exactness is not preceded by ExactAttribute
-    /// - If a sort fid is not contained in the passed `fields_ids_map`.
-    pub fn to_json_map<'a>(
-        details: impl Iterator<Item = &'a Self>,
-    ) -> serde_json::Map<String, serde_json::Value> {
-        let mut order = 0;
-        let mut details_map = serde_json::Map::default();
-        for details in details {
-            match details {
-                ScoreDetails::Words(words) => {
-                    let words_details = serde_json::json!({
-                        "order": order,
-                        "matchingWords": words.matching_words,
-                        "maxMatchingWords": words.max_matching_words,
-                        "score": words.rank().local_score_linear_scale(),
-                    });
-                    details_map.insert("words".into(), words_details);
-                    order += 1;
-                }
-                ScoreDetails::Typo(typo) => {
-                    let typo_details = serde_json::json!({
-                        "order": order,
-                        "typoCount": typo.typo_count,
-                        "maxTypoCount": typo.max_typo_count,
-                        "score": typo.rank().local_score_linear_scale(),
-                    });
-                    details_map.insert("typo".into(), typo_details);
-                    order += 1;
-                }
-                ScoreDetails::Proximity(proximity) => {
-                    let proximity_details = serde_json::json!({
-                        "order": order,
-                        "score": proximity.local_score_linear_scale(),
-                    });
-                    details_map.insert("proximity".into(), proximity_details);
-                    order += 1;
-                }
-                ScoreDetails::Fid(fid) => {
-                    // For now, fid is a virtual rule always followed by the "position" rule
-                    let fid_details = serde_json::json!({
-                        "order": order,
-                        "attributes_ranking_order": fid.local_score_linear_scale(),
-                    });
-                    details_map.insert("attribute".into(), fid_details);
-                    order += 1;
-                }
-                ScoreDetails::Position(position) => {
-                    // For now, position is a virtual rule always preceded by the "fid" rule
-                    let attribute_details = details_map
-                        .get_mut("attribute")
-                        .expect("position not preceded by attribute");
-                    let attribute_details = attribute_details
-                        .as_object_mut()
-                        .expect("attribute details was not an object");
-                    attribute_details.insert(
-                        "attributes_query_word_order".into(),
-                        position.local_score_linear_scale().into(),
-                    );
-                    // do not update the order since this was already done by fid
-                }
-                ScoreDetails::ExactAttribute(exact_attribute) => {
-                    let exactness_details = serde_json::json!({
-                        "order": order,
-                        "exactIn": exact_attribute,
-                        "score": exact_attribute.rank().local_score_linear_scale(),
-                    });
-                    details_map.insert("exactness".into(), exactness_details);
-                    order += 1;
-                }
-                ScoreDetails::Exactness(details) => {
-                    // For now, exactness is a virtual rule always preceded by the "ExactAttribute" rule
-                    let exactness_details = details_map
-                        .get_mut("exactness")
-                        .expect("Exactness not preceded by exactAttribute");
-                    let exactness_details = exactness_details
-                        .as_object_mut()
-                        .expect("exactness details was not an object");
-                    if exactness_details.get("exactIn").expect("missing 'exactIn'")
-                        == &serde_json::json!(ExactAttribute::NoExactMatch)
-                    {
-                        let score = Rank::global_score_linear_scale(
-                            [ExactAttribute::NoExactMatch.rank(), *details].iter().copied(),
-                        );
-                        *exactness_details.get_mut("score").expect("missing score") = score.into();
-                    }
-                    // do not update the order since this was already done by exactAttribute
```
|
||||
}
|
||||
ScoreDetails::Sort(details) => {
|
||||
let sort = format!(
|
||||
"{}:{}",
|
||||
details.field_name,
|
||||
if details.ascending { "asc" } else { "desc" }
|
||||
);
|
||||
let sort_details = serde_json::json!({
|
||||
"order": order,
|
||||
"value": details.value,
|
||||
});
|
||||
details_map.insert(sort, sort_details);
|
||||
order += 1;
|
||||
}
|
||||
ScoreDetails::GeoSort(details) => {
|
||||
let sort = format!(
|
||||
"_geoPoint({}, {}):{}",
|
||||
details.target_point[0],
|
||||
details.target_point[1],
|
||||
if details.ascending { "asc" } else { "desc" }
|
||||
);
|
||||
let point = if let Some(value) = details.value {
|
||||
serde_json::json!({ "lat": value[0], "lng": value[1]})
|
||||
} else {
|
||||
serde_json::Value::Null
|
||||
};
|
||||
let sort_details = serde_json::json!({
|
||||
"order": order,
|
||||
"value": point,
|
||||
"distance": details.distance(),
|
||||
});
|
||||
details_map.insert(sort, sort_details);
|
||||
order += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
details_map
|
||||
}
|
||||
|
||||
pub fn partial_cmp_iter<'a>(
|
||||
mut left: impl Iterator<Item = &'a Self>,
|
||||
mut right: impl Iterator<Item = &'a Self>,
|
||||
) -> Result<Ordering, NotComparable> {
|
||||
let mut index = 0;
|
||||
let mut order = match (left.next(), right.next()) {
|
||||
(Some(left), Some(right)) => left.partial_cmp(right).incomparable(index)?,
|
||||
_ => return Ok(Ordering::Equal),
|
||||
};
|
||||
for (left, right) in left.zip(right) {
|
||||
if order != Ordering::Equal {
|
||||
return Ok(order);
|
||||
};
|
||||
|
||||
index += 1;
|
||||
order = left.partial_cmp(right).incomparable(index)?;
|
||||
}
|
||||
Ok(order)
|
||||
}
|
||||
}
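
To make the shape of the map concrete, a minimal sketch for a words + typo pipeline; the field names mirror the insertions above, while the concrete numbers are an assumed query outcome, not fixtures from the repo:

```rust
// Sketch only: serialize a words + typo pipeline with `to_json_map`.
let details = [
    ScoreDetails::Words(Words { matching_words: 3, max_matching_words: 4 }),
    ScoreDetails::Typo(Typo { typo_count: 1, max_typo_count: 2 }),
];
let map = ScoreDetails::to_json_map(details.iter());
// `map` is roughly:
// {
//   "words": { "order": 0, "matchingWords": 3, "maxMatchingWords": 4, "score": 750 },
//   "typo":  { "order": 1, "typoCount": 1, "maxTypoCount": 2, "score": 667 }
// }
```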

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NotComparable(pub usize);

trait OptionToNotComparable<T> {
    fn incomparable(self, index: usize) -> Result<T, NotComparable>;
}

impl<T> OptionToNotComparable<T> for Option<T> {
    fn incomparable(self, index: usize) -> Result<T, NotComparable> {
        match self {
            Some(t) => Ok(t),
            None => Err(NotComparable(index)),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Words {
    pub matching_words: u32,
    pub max_matching_words: u32,
}

impl PartialOrd for Words {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        (self.max_matching_words == other.max_matching_words)
            .then(|| self.matching_words.cmp(&other.matching_words))
    }
}

impl Words {
    pub fn rank(&self) -> Rank {
        Rank { rank: self.matching_words, max_rank: self.max_matching_words }
    }

    pub(crate) fn from_rank(rank: Rank) -> Words {
        Words { matching_words: rank.rank, max_matching_words: rank.max_rank }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Typo {
    pub typo_count: u32,
    pub max_typo_count: u32,
}

impl PartialOrd for Typo {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        (self.max_typo_count == other.max_typo_count).then(|| {
            // the order is reverted as having fewer typos gives a better score
            self.typo_count.cmp(&other.typo_count).reverse()
        })
    }
}

impl Typo {
    pub fn rank(&self) -> Rank {
        Rank {
            rank: self.max_typo_count - self.typo_count + 1,
            max_rank: (self.max_typo_count + 1),
        }
    }

    // max_rank = max_typo + 1
    // max_typo = max_rank - 1
    //
    // rank = max_typo - typo + 1
    // rank = max_rank - 1 - typo + 1
    // rank + typo = max_rank
    // typo = max_rank - rank
    pub fn from_rank(rank: Rank) -> Typo {
        Typo { typo_count: rank.max_rank - rank.rank, max_typo_count: rank.max_rank - 1 }
    }
}
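
The comment block above derives `typo = max_rank - rank`; a quick worked check of that algebra:

```rust
// One typo out of two allowed: rank = 2 - 1 + 1 = 2, max_rank = 2 + 1 = 3.
let typo = Typo { typo_count: 1, max_typo_count: 2 };
let rank = typo.rank();
assert_eq!(rank, Rank { rank: 2, max_rank: 3 });
// Inverting: typo_count = 3 - 2 = 1 and max_typo_count = 3 - 1 = 2.
assert_eq!(Typo::from_rank(rank), typo);
```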

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Rank {
    /// The ordinal rank, such that `max_rank` is the first rank, and 0 is the last rank.
    ///
    /// The higher the better. Documents with a rank of 0 have a score of 0 and are typically never returned
    /// (they don't match the query).
    pub rank: u32,
    /// The maximum possible rank. Documents with this rank have a score of 1.
    ///
    /// The max rank should not be 0.
    pub max_rank: u32,
}

impl PartialOrd for Rank {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        (self.max_rank == other.max_rank).then(|| self.rank.cmp(&other.rank))
    }
}

impl Rank {
    pub fn local_score(self) -> f64 {
        self.rank as f64 / self.max_rank as f64
    }

    pub fn local_score_linear_scale(self) -> u64 {
        (self.local_score() * LINEAR_SCALE_FACTOR).round() as u64
    }

    pub fn global_score(details: impl Iterator<Item = Self>) -> f64 {
        let mut rank = Rank { rank: 1, max_rank: 1 };
        for inner_rank in details {
            rank.rank -= 1;

            rank.rank *= inner_rank.max_rank;
            rank.max_rank *= inner_rank.max_rank;

            rank.rank += inner_rank.rank;
        }
        rank.local_score()
    }

    pub fn global_score_linear_scale(details: impl Iterator<Item = Self>) -> u64 {
        (Self::global_score(details) * LINEAR_SCALE_FACTOR).round() as u64
    }
}
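
`global_score` folds the per-rule ranks into one mixed-radix number, most significant rule first. A worked example, assuming 3/4 on the first rule and 2/3 on the second:

```rust
// ((3 - 1) * 3 + 2) / (4 * 3) = 8 / 12 = 2/3
let ranks = [Rank { rank: 3, max_rank: 4 }, Rank { rank: 2, max_rank: 3 }];
let score = Rank::global_score(ranks.iter().copied());
assert!((score - 2.0 / 3.0).abs() < f64::EPSILON);
// On the linear scale (LINEAR_SCALE_FACTOR = 1000.0, defined below) this rounds to 667.
assert_eq!(Rank::global_score_linear_scale(ranks.iter().copied()), 667);
```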

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum ExactAttribute {
    // Do not reorder as the order is significant, from least relevant to most relevant
    NoExactMatch,
    MatchesStart,
    MatchesFull,
}

impl ExactAttribute {
    pub fn rank(&self) -> Rank {
        let rank = match self {
            ExactAttribute::MatchesFull => 3,
            ExactAttribute::MatchesStart => 2,
            ExactAttribute::NoExactMatch => 1,
        };
        Rank { rank, max_rank: 3 }
    }
}

#[derive(Debug, Clone, PartialEq)]
pub struct Sort {
    pub field_name: String,
    pub ascending: bool,
    pub value: serde_json::Value,
}

impl PartialOrd for Sort {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        if self.field_name != other.field_name {
            return None;
        }
        if self.ascending != other.ascending {
            return None;
        }
        match (&self.value, &other.value) {
            (serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal),
            (serde_json::Value::Null, _) => Some(Ordering::Less),
            (_, serde_json::Value::Null) => Some(Ordering::Greater),
            // numbers are always before strings
            (serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater),
            (serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less),
            (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
                // FIXME: unwrap permitted here?
                let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?;
                // always reverted, as bigger is better
                Some(if self.ascending { order.reverse() } else { order })
            }
            (serde_json::Value::String(left), serde_json::Value::String(right)) => {
                let order = left.cmp(right);
                Some(if self.ascending { order.reverse() } else { order })
            }
            _ => None,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoSort {
    pub target_point: [f64; 2],
    pub ascending: bool,
    pub value: Option<[f64; 2]>,
}

impl PartialOrd for GeoSort {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        if self.target_point != other.target_point {
            return None;
        }
        if self.ascending != other.ascending {
            return None;
        }
        Some(match (self.distance(), other.distance()) {
            (None, None) => Ordering::Equal,
            (None, Some(_)) => Ordering::Less,
            (Some(_), None) => Ordering::Greater,
            (Some(left), Some(right)) => {
                let order = left.partial_cmp(&right)?;
                if self.ascending {
                    // when ascending, the one with the smallest distance has the best score
                    order.reverse()
                } else {
                    order
                }
            }
        })
    }
}

impl GeoSort {
    pub fn distance(&self) -> Option<f64> {
        self.value.map(|value| distance_between_two_points(&self.target_point, &value))
    }
}
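
To make the `Option` semantics concrete, a small sketch; the coordinates are made up, and `distance_between_two_points` is the crate helper imported at the top of the file:

```rust
// A target point with a located document: `distance` yields Some(meters).
let geo = GeoSort {
    target_point: [48.8566, 2.3522], // assumed: Paris
    ascending: true,
    value: Some([45.7640, 4.8357]), // assumed: Lyon
};
assert!(geo.distance().is_some());

// A document without a geo value never gets a distance; against a located
// document it compares as the worse candidate (Less = worse score here).
let no_geo = GeoSort { value: None, ..geo };
assert_eq!(no_geo.distance(), None);
assert_eq!(no_geo.partial_cmp(&geo), Some(std::cmp::Ordering::Less));
```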

const LINEAR_SCALE_FACTOR: f64 = 1000.0;

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn compare() {
        let left = [
            ScoreDetails::Words(Words { matching_words: 3, max_matching_words: 4 }),
            ScoreDetails::Sort(Sort {
                field_name: "doggo".into(),
                ascending: true,
                value: "Intel the Beagle".into(),
            }),
        ];
        let right = [
            ScoreDetails::Words(Words { matching_words: 3, max_matching_words: 4 }),
            ScoreDetails::Sort(Sort {
                field_name: "doggo".into(),
                ascending: true,
                value: "Max the Labrador".into(),
            }),
        ];
        assert_eq!(
            Ok(Ordering::Greater),
            ScoreDetails::partial_cmp_iter(left.iter(), right.iter())
        );
        // equal when all the common components are equal
        assert_eq!(
            Ok(Ordering::Equal),
            ScoreDetails::partial_cmp_iter(left[0..1].iter(), right.iter())
        );

        let right = [
            ScoreDetails::Words(Words { matching_words: 4, max_matching_words: 4 }),
            ScoreDetails::Sort(Sort {
                field_name: "doggo".into(),
                ascending: true,
                value: "Max the Labrador".into(),
            }),
        ];

        assert_eq!(Ok(Ordering::Less), ScoreDetails::partial_cmp_iter(left.iter(), right.iter()));
    }

    #[test]
    fn sort_not_comparable() {
        let left = [
            ScoreDetails::Words(Words { matching_words: 3, max_matching_words: 4 }),
            ScoreDetails::Sort(Sort {
                // not the same field name
                field_name: "catto".into(),
                ascending: true,
                value: "Sylver the cat".into(),
            }),
        ];
        let right = [
            ScoreDetails::Words(Words { matching_words: 3, max_matching_words: 4 }),
            ScoreDetails::Sort(Sort {
                field_name: "doggo".into(),
                ascending: true,
                value: "Max the Labrador".into(),
            }),
        ];
        assert_eq!(
            Err(NotComparable(1)),
            ScoreDetails::partial_cmp_iter(left.iter(), right.iter())
        );
        let left = [
            ScoreDetails::Words(Words { matching_words: 3, max_matching_words: 4 }),
            ScoreDetails::Sort(Sort {
                field_name: "doggo".into(),
                // Not the same order
                ascending: false,
                value: "Intel the Beagle".into(),
            }),
        ];
        let right = [
            ScoreDetails::Words(Words { matching_words: 3, max_matching_words: 4 }),
            ScoreDetails::Sort(Sort {
                field_name: "doggo".into(),
                ascending: true,
                value: "Max the Labrador".into(),
            }),
        ];
        assert_eq!(
            Err(NotComparable(1)),
            ScoreDetails::partial_cmp_iter(left.iter(), right.iter())
        );
    }

    #[test]
    fn sort_behavior() {
        let left = Sort { field_name: "price".into(), ascending: true, value: "5400".into() };
        let right = Sort { field_name: "price".into(), ascending: true, value: 53.into() };
        // number always better match than strings
        assert_eq!(Some(Ordering::Less), left.partial_cmp(&right));

        let left = Sort { field_name: "price".into(), ascending: false, value: "5400".into() };
        let right = Sort { field_name: "price".into(), ascending: false, value: 53.into() };
        // true regardless of the sort direction
        assert_eq!(Some(Ordering::Less), left.partial_cmp(&right));
    }
}
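
Taken together, a consumer of this module would typically fold a document's details into one comparable number. A minimal sketch, assuming `details` is the `Vec<ScoreDetails>` a ranking pipeline attached to one document:

```rust
// Sketch: collapse per-rule details into a single score in [0, 1].
fn overall_score(details: &[ScoreDetails]) -> f64 {
    // Sort and geo-sort details carry no Rank, so `global_score`
    // silently skips them via `filter_map(Self::rank)`.
    ScoreDetails::global_score(details.iter())
}
```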

@@ -7,7 +7,6 @@ use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, Filter, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, Matcher, MatcherBuilder, MatchingWords};
use self::new::PartialSearchResult;
use crate::score_details::ScoreDetails;
use crate::{
    execute_search, AscDesc, DefaultSearchLogger, DocumentId, Index, Result, SearchContext,
};
@@ -23,6 +22,7 @@ pub mod new;

pub struct Search<'a> {
    query: Option<String>,
    vector: Option<Vec<f32>>,
    // this should be linked to the String in the query
    filter: Option<Filter<'a>>,
    offset: usize,
@@ -40,6 +40,7 @@ impl<'a> Search<'a> {
    pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
        Search {
            query: None,
            vector: None,
            filter: None,
            offset: 0,
            limit: 20,
@@ -58,6 +59,11 @@ impl<'a> Search<'a> {
        self
    }

    pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> {
        self.vector = Some(vector.into());
        self
    }

    pub fn offset(&mut self, offset: usize) -> &mut Search<'a> {
        self.offset = offset;
        self
@@ -94,7 +100,7 @@ impl<'a> Search<'a> {
        self
    }

    /// Forces the search to exhaustively compute the number of candidates,
    /// Force the search to exhastivelly compute the number of candidates,
    /// this will increase the search time but allows finite pagination.
    pub fn exhaustive_number_hits(&mut self, exhaustive_number_hits: bool) -> &mut Search<'a> {
        self.exhaustive_number_hits = exhaustive_number_hits;
@@ -103,10 +109,11 @@ impl<'a> Search<'a> {

    pub fn execute(&self) -> Result<SearchResult> {
        let mut ctx = SearchContext::new(self.index, self.rtxn);
        let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
        let PartialSearchResult { located_query_terms, candidates, documents_ids } =
            execute_search(
                &mut ctx,
                &self.query,
                &self.vector,
                self.terms_matching_strategy,
                self.exhaustive_number_hits,
                &self.filter,
@@ -125,7 +132,7 @@ impl<'a> Search<'a> {
            None => MatchingWords::default(),
        };

        Ok(SearchResult { matching_words, candidates, document_scores, documents_ids })
        Ok(SearchResult { matching_words, candidates, documents_ids })
    }
}

@@ -133,6 +140,7 @@ impl fmt::Debug for Search<'_> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Search {
            query,
            vector: _,
            filter,
            offset,
            limit,
@@ -146,6 +154,7 @@ impl fmt::Debug for Search<'_> {
        } = self;
        f.debug_struct("Search")
            .field("query", query)
            .field("vector", &"[...]")
            .field("filter", filter)
            .field("offset", offset)
            .field("limit", limit)
@@ -161,8 +170,8 @@ impl fmt::Debug for Search<'_> {
pub struct SearchResult {
    pub matching_words: MatchingWords,
    pub candidates: RoaringBitmap,
    // TODO those documents ids should be associated with their criteria scores.
    pub documents_ids: Vec<DocumentId>,
    pub document_scores: Vec<Vec<ScoreDetails>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]

@@ -3,13 +3,11 @@ use roaring::RoaringBitmap;
use super::logger::SearchLogger;
use super::ranking_rules::{BoxRankingRule, RankingRuleQueryTrait};
use super::SearchContext;
use crate::score_details::ScoreDetails;
use crate::search::new::distinct::{apply_distinct_rule, distinct_single_docid, DistinctOutput};
use crate::Result;

pub struct BucketSortOutput {
    pub docids: Vec<u32>,
    pub scores: Vec<Vec<ScoreDetails>>,
    pub all_candidates: RoaringBitmap,
}

@@ -33,11 +31,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
    };

    if universe.len() < from as u64 {
        return Ok(BucketSortOutput {
            docids: vec![],
            scores: vec![],
            all_candidates: universe.clone(),
        });
        return Ok(BucketSortOutput { docids: vec![], all_candidates: universe.clone() });
    }
    if ranking_rules.is_empty() {
        if let Some(distinct_fid) = distinct_fid {
@@ -55,32 +49,22 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
            }
            let mut all_candidates = universe - excluded;
            all_candidates.extend(results.iter().copied());
            return Ok(BucketSortOutput {
                scores: vec![Default::default(); results.len()],
                docids: results,
                all_candidates,
            });
            return Ok(BucketSortOutput { docids: results, all_candidates });
        } else {
            let docids: Vec<u32> = universe.iter().skip(from).take(length).collect();
            return Ok(BucketSortOutput {
                scores: vec![Default::default(); docids.len()],
                docids,
                all_candidates: universe.clone(),
            });
            let docids = universe.iter().skip(from).take(length).collect();
            return Ok(BucketSortOutput { docids, all_candidates: universe.clone() });
        };
    }

    let ranking_rules_len = ranking_rules.len();

    logger.start_iteration_ranking_rule(0, ranking_rules[0].as_ref(), query, universe);

    ranking_rules[0].start_iteration(ctx, logger, universe, query)?;

    let mut ranking_rule_scores: Vec<ScoreDetails> = vec![];

    let mut ranking_rule_universes: Vec<RoaringBitmap> =
        vec![RoaringBitmap::default(); ranking_rules_len];
    ranking_rule_universes[0] = universe.clone();

    let mut cur_ranking_rule_index = 0;

    /// Finish iterating over the current ranking rule, yielding
@@ -105,16 +89,11 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
            } else {
                cur_ranking_rule_index -= 1;
            }
            // FIXME: check off by one
            if ranking_rule_scores.len() > cur_ranking_rule_index {
                ranking_rule_scores.pop();
            }
        };
    }

    let mut all_candidates = universe.clone();
    let mut valid_docids = vec![];
    let mut valid_scores = vec![];
    let mut cur_offset = 0usize;

    macro_rules! maybe_add_to_results {
@@ -125,23 +104,23 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
                length,
                logger,
                &mut valid_docids,
                &mut valid_scores,
                &mut all_candidates,
                &mut ranking_rule_universes,
                &mut ranking_rules,
                cur_ranking_rule_index,
                &mut cur_offset,
                distinct_fid,
                &ranking_rule_scores,
                $candidates,
            )?;
        };
    }

    while valid_docids.len() < length {
        // The universe for this bucket is zero, so we don't need to sort
        // anything, just go back to the parent ranking rule.
        if ranking_rule_universes[cur_ranking_rule_index].is_empty() {
        // The universe for this bucket is zero or one element, so we don't need to sort
        // anything, just extend the results and go back to the parent ranking rule.
        if ranking_rule_universes[cur_ranking_rule_index].len() <= 1 {
            let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
            maybe_add_to_results!(bucket);
            back!();
            continue;
        }
@@ -151,8 +130,6 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
            continue;
        };

        ranking_rule_scores.push(next_bucket.score);

        logger.next_bucket_ranking_rule(
            cur_ranking_rule_index,
            ranking_rules[cur_ranking_rule_index].as_ref(),
@@ -166,11 +143,10 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
        ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates;

        if cur_ranking_rule_index == ranking_rules_len - 1
            || next_bucket.candidates.len() <= 1
            || cur_offset + (next_bucket.candidates.len() as usize) < from
        {
            maybe_add_to_results!(next_bucket.candidates);
            // FIXME: use index based logic like all the other rules so that you don't have to maintain the pop/push?
            ranking_rule_scores.pop();
            continue;
        }

@@ -190,7 +166,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
        )?;
    }

    Ok(BucketSortOutput { docids: valid_docids, scores: valid_scores, all_candidates })
    Ok(BucketSortOutput { docids: valid_docids, all_candidates })
}

/// Add the candidates to the results. Take `distinct`, `from`, `length`, and `cur_offset`
@@ -203,18 +179,14 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
    logger: &mut dyn SearchLogger<Q>,

    valid_docids: &mut Vec<u32>,
    valid_scores: &mut Vec<Vec<ScoreDetails>>,
    all_candidates: &mut RoaringBitmap,

    ranking_rule_universes: &mut [RoaringBitmap],
    ranking_rules: &mut [BoxRankingRule<'ctx, Q>],

    cur_ranking_rule_index: usize,

    cur_offset: &mut usize,

    distinct_fid: Option<u16>,
    ranking_rule_scores: &[ScoreDetails],
    candidates: RoaringBitmap,
) -> Result<()> {
    // First apply the distinct rule on the candidates, reducing the universes if necessary
@@ -259,17 +231,13 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
            let candidates =
                candidates.iter().take(length - valid_docids.len()).copied().collect::<Vec<_>>();
            logger.add_to_results(&candidates);
            valid_docids.extend_from_slice(&candidates);
            valid_scores
                .extend(std::iter::repeat(ranking_rule_scores.to_owned()).take(candidates.len()));
            valid_docids.extend(&candidates);
        }
    } else {
        // if we have passed the offset already, add some of the documents (up to the limit)
        let candidates = candidates.iter().take(length - valid_docids.len()).collect::<Vec<u32>>();
        logger.add_to_results(&candidates);
        valid_docids.extend_from_slice(&candidates);
        valid_scores
            .extend(std::iter::repeat(ranking_rule_scores.to_owned()).take(candidates.len()));
        valid_docids.extend(&candidates);
    }

    *cur_offset += candidates.len() as usize;

@@ -2,7 +2,6 @@ use roaring::{MultiOps, RoaringBitmap};

use super::query_graph::QueryGraph;
use super::ranking_rules::{RankingRule, RankingRuleOutput};
use crate::score_details::{self, ScoreDetails};
use crate::search::new::query_graph::QueryNodeData;
use crate::search::new::query_term::ExactTerm;
use crate::{Result, SearchContext, SearchLogger};
@@ -245,13 +244,7 @@ impl State {
                candidates &= universe;
                (
                    State::AttributeStarts(query_graph.clone(), candidates_per_attribute),
                    Some(RankingRuleOutput {
                        query: query_graph,
                        candidates,
                        score: ScoreDetails::ExactAttribute(
                            score_details::ExactAttribute::MatchesFull,
                        ),
                    }),
                    Some(RankingRuleOutput { query: query_graph, candidates }),
                )
            }
            State::AttributeStarts(query_graph, candidates_per_attribute) => {
@@ -264,24 +257,12 @@ impl State {
                candidates &= universe;
                (
                    State::Empty(query_graph.clone()),
                    Some(RankingRuleOutput {
                        query: query_graph,
                        candidates,
                        score: ScoreDetails::ExactAttribute(
                            score_details::ExactAttribute::MatchesStart,
                        ),
                    }),
                    Some(RankingRuleOutput { query: query_graph, candidates }),
                )
            }
            State::Empty(query_graph) => (
                State::Empty(query_graph.clone()),
                Some(RankingRuleOutput {
                    query: query_graph,
                    candidates: universe.clone(),
                    score: ScoreDetails::ExactAttribute(
                        score_details::ExactAttribute::NoExactMatch,
                    ),
                }),
                Some(RankingRuleOutput { query: query_graph, candidates: universe.clone() }),
            ),
        };
        (state, output)

@@ -8,7 +8,6 @@ use rstar::RTree;

use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec};
use crate::score_details::{self, ScoreDetails};
use crate::{
    distance_between_two_points, lat_lng_to_xyz, GeoPoint, Index, Result, SearchContext,
    SearchLogger,
@@ -81,7 +80,7 @@ pub struct GeoSort<Q: RankingRuleQueryTrait> {
    field_ids: Option<[u16; 2]>,
    rtree: Option<RTree<GeoPoint>>,

    cached_sorted_docids: VecDeque<(u32, [f64; 2])>,
    cached_sorted_docids: VecDeque<u32>,
    geo_candidates: RoaringBitmap,
}

@@ -131,7 +130,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
                let point = lat_lng_to_xyz(&self.point);
                for point in rtree.nearest_neighbor_iter(&point) {
                    if self.geo_candidates.contains(point.data.0) {
                        self.cached_sorted_docids.push_back(point.data);
                        self.cached_sorted_docids.push_back(point.data.0);
                        if self.cached_sorted_docids.len() >= cache_size {
                            break;
                        }
@@ -143,7 +142,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
                let point = lat_lng_to_xyz(&opposite_of(self.point));
                for point in rtree.nearest_neighbor_iter(&point) {
                    if self.geo_candidates.contains(point.data.0) {
                        self.cached_sorted_docids.push_front(point.data);
                        self.cached_sorted_docids.push_front(point.data.0);
                        if self.cached_sorted_docids.len() >= cache_size {
                            break;
                        }
@@ -178,7 +177,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
            // computing the distance between two points is expensive thus we cache the result
            documents
                .sort_by_cached_key(|(_, p)| distance_between_two_points(&self.point, p) as usize);
            self.cached_sorted_docids.extend(documents.into_iter());
            self.cached_sorted_docids.extend(documents.into_iter().map(|(doc_id, _)| doc_id));
        };

        Ok(())
@@ -221,19 +220,12 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
        logger: &mut dyn SearchLogger<Q>,
        universe: &RoaringBitmap,
    ) -> Result<Option<RankingRuleOutput<Q>>> {
        assert!(universe.len() > 1);
        let query = self.query.as_ref().unwrap().clone();
        self.geo_candidates &= universe;

        if self.geo_candidates.is_empty() {
            return Ok(Some(RankingRuleOutput {
                query,
                candidates: universe.clone(),
                score: ScoreDetails::GeoSort(score_details::GeoSort {
                    target_point: self.point,
                    ascending: self.ascending,
                    value: None,
                }),
            }));
            return Ok(Some(RankingRuleOutput { query, candidates: universe.clone() }));
        }

        let ascending = self.ascending;
@@ -244,16 +236,11 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
                cache.pop_back()
            }
        };
        while let Some((id, point)) = next(&mut self.cached_sorted_docids) {
        while let Some(id) = next(&mut self.cached_sorted_docids) {
            if self.geo_candidates.contains(id) {
                return Ok(Some(RankingRuleOutput {
                    query,
                    candidates: RoaringBitmap::from_iter([id]),
                    score: ScoreDetails::GeoSort(score_details::GeoSort {
                        target_point: self.point,
                        ascending: self.ascending,
                        value: Some(point),
                    }),
                }));
            }
        }

@@ -50,7 +50,6 @@ use super::ranking_rule_graph::{
};
use super::small_bitmap::SmallBitmap;
use super::{QueryGraph, RankingRule, RankingRuleOutput, SearchContext};
use crate::score_details::Rank;
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::ranking_rule_graph::PathVisitor;
use crate::{Result, TermsMatchingStrategy};
@@ -119,8 +118,6 @@ pub struct GraphBasedRankingRuleState<G: RankingRuleGraphTrait> {
    all_costs: MappedInterner<QueryNode, Vec<u64>>,
    /// An index in the first element of `all_distances`, giving the cost of the next bucket
    cur_cost: u64,
    /// One above the highest possible cost for this rule
    next_max_cost: u64,
}

impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBasedRankingRule<G> {
@@ -142,12 +139,13 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
                let mut forbidden_nodes =
                    SmallBitmap::for_interned_values_in(&query_graph.nodes);
                let mut costs = query_graph.nodes.map(|_| None);
                // FIXME: this works because only words uses termsmatchingstrategy at the moment.
                let mut cost = 100;
                for ns in removal_order {
                    for n in ns.iter() {
                        *costs.get_mut(n) = Some((1, forbidden_nodes.clone()));
                        *costs.get_mut(n) = Some((cost, forbidden_nodes.clone()));
                    }
                    forbidden_nodes.union(&ns);
                    cost += 100;
                }
                costs
            }
@@ -164,16 +162,12 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
        // Then pre-compute the cost of all paths from each node to the end node
        let all_costs = graph.find_all_costs_to_end();

        let next_max_cost =
            all_costs.get(graph.query_graph.root_node).iter().copied().max().unwrap_or(0) + 1;

        let state = GraphBasedRankingRuleState {
            graph,
            conditions_cache: condition_docids_cache,
            dead_ends_cache,
            all_costs,
            cur_cost: 0,
            next_max_cost,
        };

        self.state = Some(state);
@@ -187,13 +181,17 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
        logger: &mut dyn SearchLogger<QueryGraph>,
        universe: &RoaringBitmap,
    ) -> Result<Option<RankingRuleOutput<QueryGraph>>> {
        // If universe.len() <= 1, the bucket sort algorithm
        // should not have called this function.
        assert!(universe.len() > 1);
        // Will crash if `next_bucket` is called before `start_iteration` or after `end_iteration`,
        // should never happen
        let mut state = self.state.take().unwrap();

        let all_costs = state.all_costs.get(state.graph.query_graph.root_node);
        // Retrieve the cost of the paths to compute
        let Some(&cost) = all_costs
        let Some(&cost) = state
            .all_costs
            .get(state.graph.query_graph.root_node)
            .iter()
            .find(|c| **c >= state.cur_cost) else {
            self.state = None;
@@ -209,12 +207,8 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
            dead_ends_cache,
            all_costs,
            cur_cost: _,
            next_max_cost,
        } = &mut state;

        let rank = *next_max_cost - cost;
        let score = G::rank_to_score(Rank { rank: rank as u32, max_rank: *next_max_cost as u32 });

        let mut universe = universe.clone();

        let mut used_conditions = SmallBitmap::for_interned_values_in(&graph.conditions_interner);
@@ -331,7 +325,7 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase

        self.state = Some(state);

        Ok(Some(RankingRuleOutput { query: next_query_graph, candidates: bucket, score }))
        Ok(Some(RankingRuleOutput { query: next_query_graph, candidates: bucket }))
    }

    fn end_iteration(

@@ -509,6 +509,7 @@ mod tests {
        let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
            &mut ctx,
            &Some(query.to_string()),
            &None,
            crate::TermsMatchingStrategy::default(),
            false,
            &None,

@@ -28,6 +28,7 @@ use db_cache::DatabaseCache;
use exact_attribute::ExactAttribute;
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
use heed::RoTxn;
use hnsw::Searcher;
use interner::{DedupInterner, Interner};
pub use logger::visual::VisualSearchLogger;
pub use logger::{DefaultSearchLogger, SearchLogger};
@@ -39,14 +40,16 @@ use ranking_rules::{
use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache};
use roaring::RoaringBitmap;
use sort::Sort;
use space::Neighbor;

use self::geo_sort::GeoSort;
pub use self::geo_sort::Strategy as GeoSortStrategy;
use self::graph_based_ranking_rule::Words;
use self::interner::Interned;
use crate::score_details::ScoreDetails;
use crate::search::new::distinct::apply_distinct_rule;
use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError};
use crate::{
    AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32,
};

/// A structure used throughout the execution of a search query.
pub struct SearchContext<'ctx> {
@@ -350,6 +353,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
pub fn execute_search(
    ctx: &mut SearchContext,
    query: &Option<String>,
    vector: &Option<Vec<f32>>,
    terms_matching_strategy: TermsMatchingStrategy,
    exhaustive_number_hits: bool,
    filters: &Option<Filter>,
@@ -427,15 +431,40 @@ pub fn execute_search(
        )?
    };

    let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output;
    let BucketSortOutput { docids, mut all_candidates } = bucket_sort_output;

    let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;
    let docids = match vector {
        Some(vector) => {
            // return the nearest documents that are also part of the candidates.
            let mut searcher = Searcher::new();
            let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default();
            let ef = hnsw.len().min(100);
            let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
            let neighbors = hnsw.nearest(vector, ef, &mut searcher, &mut dest[..]);

            let mut docids = Vec::new();
            for Neighbor { index, distance: _ } in neighbors.iter() {
                let index = BEU32::new(*index as u32);
                let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get();
                if universe.contains(docid) {
                    docids.push(docid);
                    if docids.len() == (from + length) {
                        break;
                    }
                }
            }

            docids.into_iter().skip(from).take(length).collect()
        }
        // return the search docids if the vector field is not specified
        None => docids,
    };

    // The candidates is the universe unless the exhaustive number of hits
    // is requested and a distinct attribute is set.
    if exhaustive_number_hits {
        if let Some(f) = ctx.index.distinct_field(ctx.txn)? {
            if let Some(distinct_fid) = fields_ids_map.id(f) {
            if let Some(distinct_fid) = ctx.index.fields_ids_map(ctx.txn)?.id(f) {
                all_candidates = apply_distinct_rule(ctx, distinct_fid, &all_candidates)?.remaining;
            }
        }
@@ -443,7 +472,6 @@ pub fn execute_search(

    Ok(PartialSearchResult {
        candidates: all_candidates,
        document_scores: scores,
        documents_ids: docids,
        located_query_terms,
    })
@@ -495,5 +523,4 @@ pub struct PartialSearchResult {
    pub located_query_terms: Option<Vec<LocatedQueryTerm>>,
    pub candidates: RoaringBitmap,
    pub documents_ids: Vec<DocumentId>,
    pub document_scores: Vec<Vec<ScoreDetails>>,
}

@@ -49,15 +49,10 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
                if let Some((cost_of_ignoring, forbidden_nodes)) =
                    cost_of_ignoring_node.get(dest_idx)
                {
                    let dest = graph_nodes.get(dest_idx);
                    let dest_size = match &dest.data {
                        QueryNodeData::Term(term) => term.term_ids.len(),
                        _ => panic!(),
                    };
                    let new_edge_id = edges_store.insert(Some(Edge {
                        source_node: source_id,
                        dest_node: dest_idx,
                        cost: *cost_of_ignoring * dest_size as u32,
                        cost: *cost_of_ignoring,
                        condition: None,
                        nodes_to_skip: forbidden_nodes.clone(),
                    }));

@@ -1,7 +1,6 @@
use roaring::RoaringBitmap;

use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::{ExactTerm, LocatedQueryTermSubset};
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids;
@@ -85,8 +84,4 @@ impl RankingRuleGraphTrait for ExactnessGraph {

        Ok(vec![(0, exact_condition), (dest_node.term_ids.len() as u32, skip_condition)])
    }

    fn rank_to_score(rank: Rank) -> ScoreDetails {
        ScoreDetails::Exactness(rank)
    }
}

@@ -2,7 +2,6 @@ use fxhash::FxHashSet;
use roaring::RoaringBitmap;

use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids_within_field_id;
@@ -69,7 +68,7 @@ impl RankingRuleGraphTrait for FidGraph {
        }

        let mut edges = vec![];
        for fid in all_fields.iter().copied() {
        for fid in all_fields {
            // TODO: We can improve performances and relevancy by storing
            // the term subsets associated to each field ids fetched.
            edges.push((
@@ -81,35 +80,6 @@ impl RankingRuleGraphTrait for FidGraph {
            ));
        }

        // always lookup the max_fid if we don't already and add an artificial condition for max scoring
        let max_fid: Option<u16> = {
            if let Some(max_fid) = ctx
                .index
                .searchable_fields_ids(ctx.txn)?
                .map(|field_ids| field_ids.into_iter().max())
            {
                max_fid
            } else {
                ctx.index.fields_ids_map(ctx.txn)?.ids().max()
            }
        };

        if let Some(max_fid) = max_fid {
            if !all_fields.contains(&max_fid) {
                edges.push((
                    max_fid as u32 * term.term_ids.len() as u32, // TODO improve the fid score i.e. fid^10.
                    conditions_interner.insert(FidCondition {
                        term: term.clone(), // TODO remove this ugly clone
                        fid: max_fid,
                    }),
                ));
            }
        }

        Ok(edges)
    }

    fn rank_to_score(rank: Rank) -> ScoreDetails {
        ScoreDetails::Fid(rank)
    }
}

@@ -41,7 +41,6 @@ use super::interner::{DedupInterner, FixedSizeInterner, Interned, MappedInterner
use super::query_term::LocatedQueryTermSubset;
use super::small_bitmap::SmallBitmap;
use super::{QueryGraph, QueryNode, SearchContext};
use crate::score_details::{Rank, ScoreDetails};
use crate::Result;

pub struct ComputedCondition {
@@ -111,9 +110,6 @@ pub trait RankingRuleGraphTrait: Sized + 'static {
        source_node: Option<&LocatedQueryTermSubset>,
        dest_node: &LocatedQueryTermSubset,
    ) -> Result<Vec<(u32, Interned<Self::Condition>)>>;

    /// Convert the rank of a path to its corresponding score for the ranking rule
    fn rank_to_score(rank: Rank) -> ScoreDetails;
}

/// The graph used by graph-based ranking rules.

@@ -2,7 +2,6 @@ use fxhash::{FxHashMap, FxHashSet};
use roaring::RoaringBitmap;

use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids_within_position;
@@ -106,20 +105,8 @@ impl RankingRuleGraphTrait for PositionGraph {
            ));
        }

        // artificial empty condition for computing max cost
        let max_cost = term.term_ids.len() as u32 * 10;
        edges.push((
            max_cost,
            conditions_interner
                .insert(PositionCondition { term: term.clone(), positions: Vec::default() }),
        ));

        Ok(edges)
    }

    fn rank_to_score(rank: Rank) -> ScoreDetails {
        ScoreDetails::Position(rank)
    }
}

fn cost_from_position(sum_positions: u32) -> u32 {

@@ -4,7 +4,6 @@ pub mod compute_docids;
use roaring::RoaringBitmap;

use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::SearchContext;
@@ -37,8 +36,4 @@ impl RankingRuleGraphTrait for ProximityGraph {
    ) -> Result<Vec<(u32, Interned<Self::Condition>)>> {
        build::build_edges(ctx, conditions_interner, source_term, dest_term)
    }

    fn rank_to_score(rank: Rank) -> ScoreDetails {
        ScoreDetails::Proximity(rank)
    }
}

@@ -1,7 +1,6 @@
use roaring::RoaringBitmap;

use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{self, Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids;
@@ -76,8 +75,4 @@ impl RankingRuleGraphTrait for TypoGraph {
        }
        Ok(edges)
    }

    fn rank_to_score(rank: Rank) -> ScoreDetails {
        ScoreDetails::Typo(score_details::Typo::from_rank(rank))
    }
}

@@ -1,7 +1,6 @@
use roaring::RoaringBitmap;

use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{self, Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids;
@@ -42,10 +41,9 @@ impl RankingRuleGraphTrait for WordsGraph {
        _from: Option<&LocatedQueryTermSubset>,
        to_term: &LocatedQueryTermSubset,
    ) -> Result<Vec<(u32, Interned<Self::Condition>)>> {
        Ok(vec![(0, conditions_interner.insert(WordsCondition { term: to_term.clone() }))])
    }

    fn rank_to_score(rank: Rank) -> ScoreDetails {
        ScoreDetails::Words(score_details::Words::from_rank(rank))
        Ok(vec![(
            to_term.term_ids.len() as u32,
            conditions_interner.insert(WordsCondition { term: to_term.clone() }),
        )])
    }
}

@@ -2,7 +2,6 @@ use roaring::RoaringBitmap;

use super::logger::SearchLogger;
use super::{QueryGraph, SearchContext};
use crate::score_details::ScoreDetails;
use crate::Result;

/// An internal trait implemented by only [`PlaceholderQuery`] and [`QueryGraph`]
@@ -67,6 +66,4 @@ pub struct RankingRuleOutput<Q> {
    pub query: Q,
    /// The allowed candidates for the child ranking rule
    pub candidates: RoaringBitmap,
    /// The score for the candidates of the current bucket
    pub score: ScoreDetails,
}

@@ -1,11 +1,9 @@
use heed::BytesDecode;
use roaring::RoaringBitmap;

use super::logger::SearchLogger;
use super::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait, SearchContext};
use crate::heed_codec::facet::{FacetGroupKeyCodec, OrderedF64Codec};
use crate::heed_codec::{ByteSliceRefCodec, StrRefCodec};
use crate::score_details::{self, ScoreDetails};
use crate::heed_codec::facet::FacetGroupKeyCodec;
use crate::heed_codec::ByteSliceRefCodec;
use crate::search::facet::{ascending_facet_sort, descending_facet_sort};
use crate::{FieldId, Index, Result};

@@ -69,7 +67,7 @@ impl<'ctx, Query> Sort<'ctx, Query> {
impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Sort<'ctx, Query> {
    fn id(&self) -> String {
        let Self { field_name, is_ascending, .. } = self;
        format!("{field_name}:{}", if *is_ascending { "asc" } else { "desc" })
        format!("{field_name}:{}", if *is_ascending { "asc" } else { "desc " })
    }
    fn start_iteration(
        &mut self,
@@ -120,43 +118,12 @@ impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Sort<'ctx,

            (itertools::Either::Right(number_iter), itertools::Either::Right(string_iter))
        };
        let number_iter = number_iter.map(|r| -> Result<_> {
            let (docids, bytes) = r?;
            Ok((
                docids,
                serde_json::Value::Number(
                    serde_json::Number::from_f64(
                        OrderedF64Codec::bytes_decode(bytes).expect("some number"),
                    )
                    .expect("too big float"),
                ),
            ))
        });
        let string_iter = string_iter.map(|r| -> Result<_> {
            let (docids, bytes) = r?;
            Ok((
                docids,
                serde_json::Value::String(
                    StrRefCodec::bytes_decode(bytes).expect("some string").to_owned(),
                ),
            ))
        });

        let query_graph = parent_query.clone();
        let ascending = self.is_ascending;
        let field_name = self.field_name.clone();
        RankingRuleOutputIterWrapper::new(Box::new(number_iter.chain(string_iter).map(
            move |r| {
                let (docids, value) = r?;
                Ok(RankingRuleOutput {
                    query: query_graph.clone(),
                    candidates: docids,
                    score: ScoreDetails::Sort(score_details::Sort {
                        field_name: field_name.clone(),
                        ascending,
                        value,
                    }),
                })
                let (docids, _) = r?;
                Ok(RankingRuleOutput { query: query_graph.clone(), candidates: docids })
            },
        )))
    }
@@ -183,15 +150,7 @@ impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Sort<'ctx,
            Ok(Some(bucket))
        } else {
            let query = self.original_query.as_ref().unwrap().clone();
            Ok(Some(RankingRuleOutput {
                query,
                candidates: universe.clone(),
                score: ScoreDetails::Sort(score_details::Sort {
                    field_name: self.field_name.clone(),
                    ascending: self.is_ascending,
                    value: serde_json::Value::Null,
                }),
            }))
            Ok(Some(RankingRuleOutput { query, candidates: universe.clone() }))
        }
    }

@@ -318,7 +318,7 @@ pub fn snap_field_distributions(index: &Index) -> String {
    let rtxn = index.read_txn().unwrap();
    let mut snap = String::new();
    for (field, count) in index.field_distribution(&rtxn).unwrap() {
        writeln!(&mut snap, "{field:<16} {count:<6}").unwrap();
        writeln!(&mut snap, "{field:<16} {count:<6} |").unwrap();
    }
    snap
}
@@ -328,7 +328,7 @@ pub fn snap_fields_ids_map(index: &Index) -> String {
    let mut snap = String::new();
    for field_id in fields_ids_map.ids() {
        let name = fields_ids_map.name(field_id).unwrap();
        writeln!(&mut snap, "{field_id:<3} {name:<16}").unwrap();
        writeln!(&mut snap, "{field_id:<3} {name:<16} |").unwrap();
    }
    snap
}

@@ -1,7 +1,7 @@
---
source: milli/src/index.rs
---
age 1
id 2
name 2
age 1 |
id 2 |
name 2 |

@@ -1,7 +1,7 @@
---
source: milli/src/index.rs
---
age 1
id 2
name 2
age 1 |
id 2 |
name 2 |

@@ -39,6 +39,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
            facet_id_is_empty_docids,
            field_id_docid_facet_f64s,
            field_id_docid_facet_strings,
            vector_id_docid,
            documents,
        } = self.index;

@@ -57,6 +58,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
        self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
        self.index.delete_geo_rtree(self.wtxn)?;
        self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
        self.index.delete_vector_hnsw(self.wtxn)?;

        // We clean all the faceted documents ids.
        for field_id in faceted_fields {
@@ -95,6 +97,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
        facet_id_string_docids.clear(self.wtxn)?;
        field_id_docid_facet_f64s.clear(self.wtxn)?;
        field_id_docid_facet_strings.clear(self.wtxn)?;
        vector_id_docid.clear(self.wtxn)?;
        documents.clear(self.wtxn)?;

        Ok(number_of_documents)

@@ -4,8 +4,10 @@ use std::collections::{BTreeSet, HashMap, HashSet};
use fst::IntoStreamer;
use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice};
use heed::{BytesDecode, BytesEncode, Database, RwIter};
use hnsw::Searcher;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use space::KnnPoints;
use time::OffsetDateTime;

use super::facet::delete::FacetsDelete;
@@ -14,6 +16,7 @@ use crate::error::InternalError;
use crate::facet::FacetType;
use crate::heed_codec::facet::FieldDocIdFacetCodec;
use crate::heed_codec::CboRoaringBitmapCodec;
use crate::index::Hnsw;
use crate::{
    ExternalDocumentsIds, FieldId, FieldIdMapMissingEntry, Index, Result, RoaringBitmapCodec, BEU32,
};
@@ -247,6 +250,7 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> {
            facet_id_exists_docids,
            facet_id_is_null_docids,
            facet_id_is_empty_docids,
            vector_id_docid,
            documents,
        } = self.index;
        // Remove from the documents database
@@ -436,6 +440,30 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> {
            &self.to_delete_docids,
        )?;

        // An ugly and slow way to remove the vectors from the HNSW
        // It basically reconstructs the HNSW from scratch without editing the current one.
        let current_hnsw = self.index.vector_hnsw(self.wtxn)?.unwrap_or_default();
        if !current_hnsw.is_empty() {
            let mut new_hnsw = Hnsw::default();
            let mut searcher = Searcher::new();
            let mut new_vector_id_docids = Vec::new();

            for result in vector_id_docid.iter(self.wtxn)? {
                let (vector_id, docid) = result?;
                if !self.to_delete_docids.contains(docid.get()) {
                    let vector = current_hnsw.get_point(vector_id.get() as usize).clone();
                    let vector_id = new_hnsw.insert(vector, &mut searcher);
                    new_vector_id_docids.push((vector_id as u32, docid));
                }
            }

            vector_id_docid.clear(self.wtxn)?;
            for (vector_id, docid) in new_vector_id_docids {
                vector_id_docid.put(self.wtxn, &BEU32::new(vector_id), &docid)?;
            }
            self.index.put_vector_hnsw(self.wtxn, &new_hnsw)?;
        }

        self.index.put_soft_deleted_documents_ids(self.wtxn, &RoaringBitmap::new())?;

        Ok(DetailedDocumentDeletionResult {

@@ -0,0 +1,40 @@
use std::fs::File;
use std::io;

use bytemuck::cast_slice;
use serde_json::from_slice;

use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::{FieldId, InternalError, Result};

/// Extracts the embedding vector contained in each document under the `_vector` field.
///
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
#[logging_timer::time]
pub fn extract_vector_points<R: io::Read + io::Seek>(
    obkv_documents: grenad::Reader<R>,
    indexer: GrenadParameters,
    vector_fid: FieldId,
) -> Result<grenad::Reader<File>> {
    let mut writer = create_writer(
        indexer.chunk_compression_type,
        indexer.chunk_compression_level,
        tempfile::tempfile()?,
    );

    let mut cursor = obkv_documents.into_cursor()?;
    while let Some((docid_bytes, value)) = cursor.move_on_next()? {
        let obkv = obkv::KvReader::new(value);

        // first we get the _vector field
        if let Some(vector) = obkv.get(vector_fid) {
            // try to extract the vector
            let vector: Vec<f32> = from_slice(vector).map_err(InternalError::SerdeJson).unwrap();
            let bytes = cast_slice(&vector);
            writer.insert(docid_bytes, bytes)?;
        }
        // else => the _vector object was `null`, there is nothing to do
    }

    writer_into_reader(writer)
}
||||
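To see the byte layout this extractor writes, here is a minimal round-trip sketch using the same two crates (serde_json and bytemuck); the sample vector values are made up:

use bytemuck::allocation::pod_collect_to_vec;
use bytemuck::cast_slice;

fn main() {
    // what a user-supplied `_vector` field could contain (made-up values)
    let raw = br#"[0.1, 0.2, 0.3]"#;
    let vector: Vec<f32> = serde_json::from_slice(raw).unwrap();

    // the raw f32 bytes (native endianness) that get stored in grenad
    let bytes: &[u8] = cast_slice(&vector);
    assert_eq!(bytes.len(), vector.len() * std::mem::size_of::<f32>());

    // the reading side (write_typed_chunk_into_index below) recovers the Vec<f32>
    let back: Vec<f32> = pod_collect_to_vec(bytes);
    assert_eq!(vector, back);
}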
@@ -4,6 +4,7 @@ mod extract_facet_string_docids;
mod extract_fid_docid_facet_values;
mod extract_fid_word_count_docids;
mod extract_geo_points;
mod extract_vector_points;
mod extract_word_docids;
mod extract_word_fid_docids;
mod extract_word_pair_proximity_docids;
@@ -22,6 +23,7 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
use self::extract_geo_points::extract_geo_points;
use self::extract_vector_points::extract_vector_points;
use self::extract_word_docids::extract_word_docids;
use self::extract_word_fid_docids::extract_word_fid_docids;
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
@@ -45,6 +47,7 @@ pub(crate) fn data_from_obkv_documents(
    faceted_fields: HashSet<FieldId>,
    primary_key_id: FieldId,
    geo_fields_ids: Option<(FieldId, FieldId)>,
    vector_field_id: Option<FieldId>,
    stop_words: Option<fst::Set<&[u8]>>,
    max_positions_per_attributes: Option<u32>,
    exact_attributes: HashSet<FieldId>,
@@ -69,6 +72,7 @@ pub(crate) fn data_from_obkv_documents(
        &faceted_fields,
        primary_key_id,
        geo_fields_ids,
        vector_field_id,
        &stop_words,
        max_positions_per_attributes,
    )
@@ -279,6 +283,7 @@ fn send_and_extract_flattened_documents_data(
    faceted_fields: &HashSet<FieldId>,
    primary_key_id: FieldId,
    geo_fields_ids: Option<(FieldId, FieldId)>,
    vector_field_id: Option<FieldId>,
    stop_words: &Option<fst::Set<&[u8]>>,
    max_positions_per_attributes: Option<u32>,
) -> Result<(
@@ -307,6 +312,20 @@ fn send_and_extract_flattened_documents_data(
        });
    }

    if let Some(vector_field_id) = vector_field_id {
        let documents_chunk_cloned = flattened_documents_chunk.clone();
        let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
        rayon::spawn(move || {
            let result = extract_vector_points(documents_chunk_cloned, indexer, vector_field_id);
            let _ = match result {
                Ok(vector_points) => {
                    lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
                }
                Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
            };
        });
    }

    let (docid_word_positions_chunk, docid_fid_facet_values_chunks): (Result<_>, Result<_>) =
        rayon::join(
            || {
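The vector extraction follows the same fan-out pattern as the other extractors in this hunk: clone the chunk and the channel sender, spawn the extractor on the rayon pool, and ship the Result back over the channel. A stripped-down sketch of that pattern, assuming a crossbeam_channel sender like the lmdb_writer_sx used here:

use crossbeam_channel::unbounded;

fn main() {
    // the indexer's writer channel, reduced to a toy payload type
    let (sender, receiver) = unbounded::<Result<Vec<u8>, String>>();

    let sender_cloned = sender.clone();
    rayon::spawn(move || {
        // stand-in for extract_vector_points(documents_chunk_cloned, ...)
        let result: Result<Vec<u8>, String> = Ok(vec![1, 2, 3]);
        // the send result is deliberately discarded (`let _ = ...` above):
        // if the receiving side is gone, indexing is already shutting down
        let _ = sender_cloned.send(result);
    });
    drop(sender);

    // the writer side drains the channel until every sender is dropped
    for chunk in receiver {
        println!("received: {:?}", chunk);
    }
}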
@@ -304,6 +304,8 @@ where
            }
            None => None,
        };
        // get the fid of the `_vector` field, if any document declares one
        let vector_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vector");

        let stop_words = self.index.stop_words(self.wtxn)?;
        let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?;
@@ -340,6 +342,7 @@ where
            faceted_fields,
            primary_key_id,
            geo_fields_ids,
            vector_field_id,
            stop_words,
            max_positions_per_attributes,
            exact_attributes,
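Since the lookup returns an Option, vector extraction is skipped entirely when no document has ever declared a `_vector` field. A toy stand-in for that lookup, using a HashMap in place of milli's FieldsIdsMap (names and ids here are hypothetical):

use std::collections::HashMap;

type FieldId = u16;

fn main() {
    // toy stand-in for the index's fields id map
    let mut fields_ids_map: HashMap<String, FieldId> = HashMap::new();
    fields_ids_map.insert("title".to_string(), 0);
    fields_ids_map.insert("_vector".to_string(), 1);

    // `None` here would disable vector extraction, as in the diff above
    let vector_field_id: Option<FieldId> = fields_ids_map.get("_vector").copied();
    assert_eq!(vector_field_id, Some(1));
}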
@@ -4,20 +4,24 @@ use std::convert::TryInto;
use std::fs::File;
use std::io;

use bytemuck::allocation::pod_collect_to_vec;
use charabia::{Language, Script};
use grenad::MergerBuilder;
use heed::types::ByteSlice;
use heed::RwTxn;
use hnsw::Searcher;
use roaring::RoaringBitmap;
use space::KnnPoints;

use super::helpers::{
    self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap,
};
use super::{ClonableMmap, MergeFn};
use crate::error::UserError;
use crate::facet::FacetType;
use crate::update::facet::FacetsUpdate;
use crate::update::index_documents::helpers::as_cloneable_grenad;
use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result};
use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32};

pub(crate) enum TypedChunk {
    FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@@ -38,6 +42,7 @@ pub(crate) enum TypedChunk {
    FieldIdFacetIsNullDocids(grenad::Reader<File>),
    FieldIdFacetIsEmptyDocids(grenad::Reader<File>),
    GeoPoints(grenad::Reader<File>),
    VectorPoints(grenad::Reader<File>),
    ScriptLanguageDocids(HashMap<(Script, Language), RoaringBitmap>),
}
@@ -221,6 +226,38 @@ pub(crate) fn write_typed_chunk_into_index(
            index.put_geo_rtree(wtxn, &rtree)?;
            index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
        }
        TypedChunk::VectorPoints(vector_points) => {
            let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default();
            let mut searcher = Searcher::new();

            // fetch the dimensions of an already-stored point, if any,
            // to validate the incoming vectors against them
            let mut expected_dimensions = match index.vector_id_docid.iter(wtxn)?.next() {
                Some(result) => {
                    let (vector_id, _) = result?;
                    Some(hnsw.get_point(vector_id.get() as usize).len())
                }
                None => None,
            };

            let mut cursor = vector_points.into_cursor()?;
            while let Some((key, value)) = cursor.move_on_next()? {
                // convert the key back to a u32 (4 bytes)
                let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
                // convert the value back to a Vec<f32>
                let vector: Vec<f32> = pod_collect_to_vec(value);

                // TODO move this error check into the vector extractor
                let found = vector.len();
                let expected = *expected_dimensions.get_or_insert(found);
                if expected != found {
                    return Err(UserError::InvalidVectorDimensions { expected, found })?;
                }

                let vector_id = hnsw.insert(vector, &mut searcher) as u32;
                index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
            }
            log::debug!("There are {} entries in the HNSW so far", hnsw.len());
            index.put_vector_hnsw(wtxn, &hnsw)?;
        }
        TypedChunk::ScriptLanguageDocids(hash_pair) => {
            let mut buffer = Vec::new();
            for (key, value) in hash_pair {
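The dimension check in the VectorPoints arm uses Option::get_or_insert so that the first vector seen defines the expected dimensionality when the index is still empty. A self-contained sketch of that validation logic, with a plain error type standing in for UserError::InvalidVectorDimensions:

#[derive(Debug, PartialEq)]
struct InvalidDimensions {
    expected: usize,
    found: usize,
}

fn check_dimensions(vectors: &[Vec<f32>]) -> Result<(), InvalidDimensions> {
    let mut expected_dimensions: Option<usize> = None;
    for vector in vectors {
        let found = vector.len();
        // the first vector fixes the expected dimensions for all the others
        let expected = *expected_dimensions.get_or_insert(found);
        if expected != found {
            return Err(InvalidDimensions { expected, found });
        }
    }
    Ok(())
}

fn main() {
    assert!(check_dimensions(&[vec![0.0; 3], vec![1.0; 3]]).is_ok());
    assert_eq!(
        check_dimensions(&[vec![0.0; 3], vec![1.0; 2]]),
        Err(InvalidDimensions { expected: 3, found: 2 })
    );
}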