Mirror of https://github.com/meilisearch/meilisearch.git (synced 2025-07-18 12:20:48 +00:00)

Compare commits (17 commits)
SHA1 (author and date columns were empty in the source view):
1a083d54fc
3a97d30cd9
3bdb2b06be
64afed4dbe
b26ddfcc3d
c59f3f2f95
049bd45849
2491db8746
425bc92ce6
cbd065ed46
b9f365a965
3f21daf2e7
d77df4ecdb
fdac97e3c8
bbdfbd8ea1
da7c796be1
014eaea428
Cargo.lock (generated, 30 changed lines)

@@ -491,7 +491,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"

 [[package]]
 name = "benchmarks"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "bytes",
@@ -1402,7 +1402,7 @@ dependencies = [

 [[package]]
 name = "dump"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "big_s",
@@ -1634,7 +1634,7 @@ dependencies = [

 [[package]]
 name = "file-store"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "faux",
  "tempfile",
@@ -1656,7 +1656,7 @@ dependencies = [

 [[package]]
 name = "filter-parser"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "insta",
  "nom",
@@ -1687,7 +1687,7 @@ dependencies = [

 [[package]]
 name = "flatten-serde-json"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "criterion",
  "serde_json",
@@ -1805,7 +1805,7 @@ dependencies = [

 [[package]]
 name = "fuzzers"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "arbitrary",
  "clap",
@@ -2763,7 +2763,7 @@ dependencies = [

 [[package]]
 name = "index-scheduler"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "big_s",
@@ -2960,7 +2960,7 @@ dependencies = [

 [[package]]
 name = "json-depth-checker"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "criterion",
  "serde_json",
@@ -3472,7 +3472,7 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"

 [[package]]
 name = "meili-snap"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "insta",
  "md5",
@@ -3481,7 +3481,7 @@ dependencies = [

 [[package]]
 name = "meilisearch"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "actix-cors",
  "actix-http",
@@ -3572,7 +3572,7 @@ dependencies = [

 [[package]]
 name = "meilisearch-auth"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "base64 0.21.5",
  "enum-iterator",
@@ -3591,7 +3591,7 @@ dependencies = [

 [[package]]
 name = "meilisearch-types"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "actix-web",
  "anyhow",
@@ -3621,7 +3621,7 @@ dependencies = [

 [[package]]
 name = "meilitool"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "anyhow",
  "clap",
@@ -3669,7 +3669,7 @@ dependencies = [

 [[package]]
 name = "milli"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "arroy",
  "big_s",
@@ -4076,7 +4076,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"

 [[package]]
 name = "permissive-json-pointer"
-version = "1.6.0"
+version = "1.6.2"
 dependencies = [
  "big_s",
  "serde_json",
Cargo.toml (workspace root)

@@ -19,7 +19,7 @@ members = [
 ]

 [workspace.package]
-version = "1.6.0"
+version = "1.6.2"
 authors = ["Quentin de Quelen <quentin@dequelen.me>", "Clément Renault <clement@meilisearch.com>"]
 description = "Meilisearch HTTP server"
 homepage = "https://meilisearch.com"
meilisearch/Cargo.toml

@@ -154,5 +154,5 @@ greek = ["meilisearch-types/greek"]
 khmer = ["meilisearch-types/khmer"]

 [package.metadata.mini-dashboard]
-assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.12/build.zip"
-sha1 = "acfe9a018c93eb0604ea87ee87bff7df5474e18e"
+assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.13/build.zip"
+sha1 = "e20cc9b390003c6c844f4b8bcc5c5013191a77ff"
@@ -64,7 +64,7 @@ impl Display for Value {
         write!(
             f,
             "{}",
-            json_string!(self, { ".enqueuedAt" => "[date]", ".processedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" })
+            json_string!(self, { ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" })
         )
     }
 }
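The fix above matters because task payloads contain `startedAt`, never `processedAt`, so the old mapping redacted a key that does not exist and left a raw timestamp in every snapshot. A minimal, self-contained sketch of the same redaction idea, using plain serde_json instead of the project's `json_string!` macro:

use serde_json::{json, Value};

// Replace volatile task fields with stable placeholders before snapshotting,
// mirroring what the `json_string!` mapping above does.
fn redact(mut task: Value) -> Value {
    if let Some(obj) = task.as_object_mut() {
        for key in ["enqueuedAt", "startedAt", "finishedAt"] {
            if obj.contains_key(key) {
                obj.insert(key.to_string(), json!("[date]"));
            }
        }
        if obj.contains_key("duration") {
            obj.insert("duration".to_string(), json!("[duration]"));
        }
    }
    task
}

fn main() {
    let task = json!({ "uid": 1, "startedAt": "2024-01-15T12:00:00Z" });
    assert_eq!(redact(task)["startedAt"], "[date]");
}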
@@ -1760,6 +1760,181 @@ async fn add_documents_invalid_geo_field() {
       "finishedAt": "[date]"
     }
     "###);
+
+    // The three next tests are related to #4333
+
+    // `_geo.lng` is `null` while `lat` is a valid number
+    let documents = json!([
+        {
+            "id": "12",
+            "_geo": { "lng": null, "lat": 67 }
+        }
+    ]);
+
+    let (response, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let response = index.wait_task(response.uid()).await;
+    snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
+        @r###"
+    {
+      "uid": 14,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 0
+      },
+      "error": {
+        "message": "Could not parse longitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    // `_geo.lat` is `null` while `lng` is a valid number
+    let documents = json!([
+        {
+            "id": "12",
+            "_geo": { "lng": 35, "lat": null }
+        }
+    ]);
+
+    let (response, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let response = index.wait_task(response.uid()).await;
+    snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
+        @r###"
+    {
+      "uid": 15,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 0
+      },
+      "error": {
+        "message": "Could not parse latitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    // both `_geo.lat` and `_geo.lng` are `null`
+    let documents = json!([
+        {
+            "id": "13",
+            "_geo": { "lng": null, "lat": null }
+        }
+    ]);
+
+    let (response, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let response = index.wait_task(response.uid()).await;
+    snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
+        @r###"
+    {
+      "uid": 16,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 0
+      },
+      "error": {
+        "message": "Could not parse latitude nor longitude in the document with the id: `13`. Was expecting finite numbers but instead got `null` and `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
 }
+
+// Related to #4333
+#[actix_rt::test]
+async fn add_invalid_geo_and_then_settings() {
+    let server = Server::new().await;
+    let index = server.index("test");
+    index.create(Some("id")).await;
+
+    // `_geo` has both `lat` and `lng` set to `null`; the addition succeeds
+    // because `_geo` is not yet referenced by any index setting
+    let documents = json!([
+        {
+            "id": "11",
+            "_geo": { "lat": null, "lng": null },
+        }
+    ]);
+    let (ret, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let ret = index.wait_task(ret.uid()).await;
+    snapshot!(ret, @r###"
+    {
+      "uid": 1,
+      "indexUid": "test",
+      "status": "succeeded",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 1
+      },
+      "error": null,
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    let (ret, code) = index.update_settings(json!({"sortableAttributes": ["_geo"]})).await;
+    snapshot!(code, @"202 Accepted");
+    let ret = index.wait_task(ret.uid()).await;
+    snapshot!(ret, @r###"
+    {
+      "uid": 2,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "settingsUpdate",
+      "canceledBy": null,
+      "details": {
+        "sortableAttributes": [
+          "_geo"
+        ]
+      },
+      "error": {
+        "message": "Could not parse latitude in the document with the id: `\"11\"`. Was expecting a finite number but instead got `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+}

 #[actix_rt::test]
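Every case above expects the same failure shape: a `_geo` coordinate must be a finite JSON number, and `null` is rejected with a field-specific message. A hypothetical, dependency-light sketch of that check; the function name and string-based error are ours, while the real extractor lives in milli and reports through its own error types:

use serde_json::Value;

// Reject anything that is not a finite number, as the tests above require.
fn extract_finite_float(field: &str, value: &Value) -> Result<f64, String> {
    match value.as_f64() {
        Some(n) if n.is_finite() => Ok(n),
        // `null`, strings, objects and arrays all end up here
        _ => Err(format!(
            "Could not parse {field}. Was expecting a finite number but instead got `{value}`."
        )),
    }
}

fn main() {
    let geo: Value = serde_json::json!({ "lng": null, "lat": 67 });
    assert!(extract_finite_float("longitude", &geo["lng"]).is_err());
    assert!(extract_finite_float("latitude", &geo["lat"]).is_ok());
}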
@@ -87,6 +87,52 @@ async fn simple_search() {
    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
 }

+#[actix_rt::test]
+async fn highlighter() {
+    let server = Server::new().await;
+    let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
+
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+            "hybrid": {"semanticRatio": 0.2},
+            "attributesToHighlight": [
+                "desc"
+            ],
+            "highlightPreTag": "**BEGIN**",
+            "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}}}]"###);
+
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+            "hybrid": {"semanticRatio": 0.8},
+            "attributesToHighlight": [
+                "desc"
+            ],
+            "highlightPreTag": "**BEGIN**",
+            "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}},"_semanticScore":0.9472136}]"###);
+
+    // no highlighting on full semantic
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+            "hybrid": {"semanticRatio": 1.0},
+            "attributesToHighlight": [
+                "desc"
+            ],
+            "highlightPreTag": "**BEGIN**",
+            "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}}]"###);
+}
+
 #[actix_rt::test]
 async fn invalid_semantic_ratio() {
     let server = Server::new().await;
@@ -222,72 +222,3 @@ where
         Ok(ControlFlow::Continue(()))
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use std::ops::ControlFlow;
-
-    use heed::BytesDecode;
-    use roaring::RoaringBitmap;
-
-    use super::lexicographically_iterate_over_facet_distribution;
-    use crate::heed_codec::facet::OrderedF64Codec;
-    use crate::milli_snap;
-    use crate::search::facet::tests::{get_random_looking_index, get_simple_index};
-
-    #[test]
-    fn filter_distribution_all() {
-        let indexes = [get_simple_index(), get_random_looking_index()];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (0..=255).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            lexicographically_iterate_over_facet_distribution(
-                &txn,
-                index.content,
-                0,
-                &candidates,
-                |facet, count, _| {
-                    let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
-                    results.push_str(&format!("{facet}: {count}\n"));
-                    Ok(ControlFlow::Continue(()))
-                },
-            )
-            .unwrap();
-            milli_snap!(results, i);
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_distribution_all_stop_early() {
-        let indexes = [get_simple_index(), get_random_looking_index()];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (0..=255).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            let mut nbr_facets = 0;
-            lexicographically_iterate_over_facet_distribution(
-                &txn,
-                index.content,
-                0,
-                &candidates,
-                |facet, count, _| {
-                    let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
-                    if nbr_facets == 100 {
-                        Ok(ControlFlow::Break(()))
-                    } else {
-                        nbr_facets += 1;
-                        results.push_str(&format!("{facet}: {count}\n"));
-                        Ok(ControlFlow::Continue(()))
-                    }
-                },
-            )
-            .unwrap();
-            milli_snap!(results, i);
-
-            txn.commit().unwrap();
-        }
-    }
-}
@@ -303,347 +303,3 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
         Ok(())
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use std::ops::Bound;
-
-    use roaring::RoaringBitmap;
-
-    use super::find_docids_of_facet_within_bounds;
-    use crate::heed_codec::facet::{FacetGroupKeyCodec, OrderedF64Codec};
-    use crate::milli_snap;
-    use crate::search::facet::tests::{
-        get_random_looking_index, get_random_looking_index_with_multiple_field_ids,
-        get_simple_index, get_simple_index_with_multiple_field_ids,
-    };
-    use crate::snapshot_tests::display_bitmap;
-
-    #[test]
-    fn random_looking_index_snap() {
-        let index = get_random_looking_index();
-        milli_snap!(format!("{index}"), @"3256c76a7c1b768a013e78d5fa6e9ff9");
-    }
-
-    #[test]
-    fn random_looking_index_with_multiple_field_ids_snap() {
-        let index = get_random_looking_index_with_multiple_field_ids();
-        milli_snap!(format!("{index}"), @"c3e5fe06a8f1c404ed4935b32c90a89b");
-    }
-
-    #[test]
-    fn simple_index_snap() {
-        let index = get_simple_index();
-        milli_snap!(format!("{index}"), @"5dbfa134cc44abeb3ab6242fc182e48e");
-    }
-
-    #[test]
-    fn simple_index_with_multiple_field_ids_snap() {
-        let index = get_simple_index_with_multiple_field_ids();
-        milli_snap!(format!("{index}"), @"a4893298218f682bc76357f46777448c");
-    }
-
-    #[test]
-    fn filter_range_increasing() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Included(0.);
-                let end = Bound::Included(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!("0 <= . <= {i} : {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("included_{i}"));
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Excluded(0.);
-                let end = Bound::Excluded(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!("0 < . < {i} : {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("excluded_{i}"));
-            txn.commit().unwrap();
-        }
-    }
-    #[test]
-    fn filter_range_decreasing() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-
-            let mut results = String::new();
-
-            for i in (0..=255).rev() {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Included(255.);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!("{i} <= . <= 255 : {}\n", display_bitmap(&docids)));
-            }
-
-            milli_snap!(results, format!("included_{i}"));
-
-            let mut results = String::new();
-
-            for i in (0..=255).rev() {
-                let i = i as f64;
-                let start = Bound::Excluded(i);
-                let end = Bound::Excluded(255.);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!("{i} < . < 255 : {}\n", display_bitmap(&docids)));
-            }
-
-            milli_snap!(results, format!("excluded_{i}"));
-
-            txn.commit().unwrap();
-        }
-    }
-    #[test]
-    fn filter_range_pinch() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-
-            let mut results = String::new();
-
-            for i in (0..=128).rev() {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Included(255. - i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!(
-                    "{i} <= . <= {r} : {docids}\n",
-                    r = 255. - i,
-                    docids = display_bitmap(&docids)
-                ));
-            }
-
-            milli_snap!(results, format!("included_{i}"));
-
-            let mut results = String::new();
-
-            for i in (0..=128).rev() {
-                let i = i as f64;
-                let start = Bound::Excluded(i);
-                let end = Bound::Excluded(255. - i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                results.push_str(&format!(
-                    "{i} < . < {r} {docids}\n",
-                    r = 255. - i,
-                    docids = display_bitmap(&docids)
-                ));
-            }
-
-            milli_snap!(results, format!("excluded_{i}"));
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_range_unbounded() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Unbounded;
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!(">= {i}: {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("start_from_included_{i}"));
-            let mut results = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Unbounded;
-                let end = Bound::Included(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results.push_str(&format!("<= {i}: {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results, format!("end_at_included_{i}"));
-
-            let mut docids = RoaringBitmap::new();
-            find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                &txn,
-                index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                0,
-                &Bound::Unbounded,
-                &Bound::Unbounded,
-                &mut docids,
-            )
-            .unwrap();
-            milli_snap!(
-                &format!("all field_id 0: {}\n", display_bitmap(&docids)),
-                format!("unbounded_field_id_0_{i}")
-            );
-
-            let mut docids = RoaringBitmap::new();
-            find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                &txn,
-                index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                1,
-                &Bound::Unbounded,
-                &Bound::Unbounded,
-                &mut docids,
-            )
-            .unwrap();
-            milli_snap!(
-                &format!("all field_id 1: {}\n", display_bitmap(&docids)),
-                format!("unbounded_field_id_1_{i}")
-            );
-
-            drop(txn);
-        }
-    }
-
-    #[test]
-    fn filter_range_exact() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-            get_random_looking_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let mut results_0 = String::new();
-            let mut results_1 = String::new();
-            for i in 0..=255 {
-                let i = i as f64;
-                let start = Bound::Included(i);
-                let end = Bound::Included(i);
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    0,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results_0.push_str(&format!("{i}: {}\n", display_bitmap(&docids)));
-
-                let mut docids = RoaringBitmap::new();
-                find_docids_of_facet_within_bounds::<OrderedF64Codec>(
-                    &txn,
-                    index.content.remap_key_type::<FacetGroupKeyCodec<OrderedF64Codec>>(),
-                    1,
-                    &start,
-                    &end,
-                    &mut docids,
-                )
-                .unwrap();
-                #[allow(clippy::format_push_string)]
-                results_1.push_str(&format!("{i}: {}\n", display_bitmap(&docids)));
-            }
-            milli_snap!(results_0, format!("field_id_0_exact_{i}"));
-            milli_snap!(results_1, format!("field_id_1_exact_{i}"));
-
-            drop(txn);
-        }
-    }
-}
@@ -112,119 +112,3 @@ impl<'t, 'e> Iterator for AscendingFacetSort<'t, 'e> {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use roaring::RoaringBitmap;
-
-    use crate::milli_snap;
-    use crate::search::facet::facet_sort_ascending::ascending_facet_sort;
-    use crate::search::facet::tests::{
-        get_random_looking_index, get_random_looking_string_index_with_multiple_field_ids,
-        get_simple_index, get_simple_string_index_with_multiple_field_ids,
-    };
-    use crate::snapshot_tests::display_bitmap;
-
-    #[test]
-    fn filter_sort_ascending() {
-        let indexes = [get_simple_index(), get_random_looking_index()];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (200..=300).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            let iter = ascending_facet_sort(&txn, index.content, 0, candidates).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            milli_snap!(results, i);
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_sort_ascending_multiple_field_ids() {
-        let indexes = [
-            get_simple_string_index_with_multiple_field_ids(),
-            get_random_looking_string_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (200..=300).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            let iter = ascending_facet_sort(&txn, index.content, 0, candidates.clone()).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            milli_snap!(results, format!("{i}-0"));
-
-            let mut results = String::new();
-            let iter = ascending_facet_sort(&txn, index.content, 1, candidates).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            milli_snap!(results, format!("{i}-1"));
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_sort_ascending_with_no_candidates() {
-        let indexes = [
-            get_simple_string_index_with_multiple_field_ids(),
-            get_random_looking_string_index_with_multiple_field_ids(),
-        ];
-        for (_i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = RoaringBitmap::new();
-            let mut results = String::new();
-            let iter = ascending_facet_sort(&txn, index.content, 0, candidates.clone()).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            assert!(results.is_empty());
-
-            let mut results = String::new();
-            let iter = ascending_facet_sort(&txn, index.content, 1, candidates).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            assert!(results.is_empty());
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_sort_ascending_with_inexisting_field_id() {
-        let indexes = [
-            get_simple_string_index_with_multiple_field_ids(),
-            get_random_looking_string_index_with_multiple_field_ids(),
-        ];
-        for (_i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = RoaringBitmap::new();
-            let mut results = String::new();
-            let iter = ascending_facet_sort(&txn, index.content, 3, candidates.clone()).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            assert!(results.is_empty());
-
-            txn.commit().unwrap();
-        }
-    }
-}
@@ -117,128 +117,3 @@ impl<'t> Iterator for DescendingFacetSort<'t> {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use roaring::RoaringBitmap;
-
-    use crate::heed_codec::facet::FacetGroupKeyCodec;
-    use crate::heed_codec::BytesRefCodec;
-    use crate::milli_snap;
-    use crate::search::facet::facet_sort_descending::descending_facet_sort;
-    use crate::search::facet::tests::{
-        get_random_looking_index, get_random_looking_string_index_with_multiple_field_ids,
-        get_simple_index, get_simple_index_with_multiple_field_ids,
-        get_simple_string_index_with_multiple_field_ids,
-    };
-    use crate::snapshot_tests::display_bitmap;
-
-    #[test]
-    fn filter_sort_descending() {
-        let indexes = [
-            get_simple_index(),
-            get_random_looking_index(),
-            get_simple_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (200..=300).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            let db = index.content.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
-            let iter = descending_facet_sort(&txn, db, 0, candidates).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            milli_snap!(results, i);
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_sort_descending_multiple_field_ids() {
-        let indexes = [
-            get_simple_string_index_with_multiple_field_ids(),
-            get_random_looking_string_index_with_multiple_field_ids(),
-        ];
-        for (i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = (200..=300).collect::<RoaringBitmap>();
-            let mut results = String::new();
-            let db = index.content.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
-            let iter = descending_facet_sort(&txn, db, 0, candidates.clone()).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            milli_snap!(results, format!("{i}-0"));
-
-            let mut results = String::new();
-
-            let iter = descending_facet_sort(&txn, db, 1, candidates).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            milli_snap!(results, format!("{i}-1"));
-
-            txn.commit().unwrap();
-        }
-    }
-    #[test]
-    fn filter_sort_ascending_with_no_candidates() {
-        let indexes = [
-            get_simple_string_index_with_multiple_field_ids(),
-            get_random_looking_string_index_with_multiple_field_ids(),
-        ];
-        for (_i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = RoaringBitmap::new();
-            let mut results = String::new();
-            let iter = descending_facet_sort(&txn, index.content, 0, candidates.clone()).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            assert!(results.is_empty());
-
-            let mut results = String::new();
-            let iter = descending_facet_sort(&txn, index.content, 1, candidates).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            assert!(results.is_empty());
-
-            txn.commit().unwrap();
-        }
-    }
-
-    #[test]
-    fn filter_sort_ascending_with_inexisting_field_id() {
-        let indexes = [
-            get_simple_string_index_with_multiple_field_ids(),
-            get_random_looking_string_index_with_multiple_field_ids(),
-        ];
-        for (_i, index) in indexes.iter().enumerate() {
-            let txn = index.env.read_txn().unwrap();
-            let candidates = RoaringBitmap::new();
-            let mut results = String::new();
-            let iter = descending_facet_sort(&txn, index.content, 3, candidates.clone()).unwrap();
-            for el in iter {
-                let (docids, _) = el.unwrap();
-                results.push_str(&display_bitmap(&docids));
-                results.push('\n');
-            }
-            assert!(results.is_empty());
-
-            txn.commit().unwrap();
-        }
-    }
-}
@@ -116,109 +116,3 @@ pub(crate) fn get_highest_level<'t>(
         })
         .unwrap_or(0))
 }
-
-#[cfg(test)]
-pub(crate) mod tests {
-    use rand::{Rng, SeedableRng};
-    use roaring::RoaringBitmap;
-
-    use crate::heed_codec::facet::OrderedF64Codec;
-    use crate::heed_codec::StrRefCodec;
-    use crate::update::facet::test_helpers::FacetIndex;
-
-    pub fn get_simple_index() -> FacetIndex<OrderedF64Codec> {
-        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
-        let mut txn = index.env.write_txn().unwrap();
-        for i in 0..256u16 {
-            let mut bitmap = RoaringBitmap::new();
-            bitmap.insert(i as u32);
-            index.insert(&mut txn, 0, &(i as f64), &bitmap);
-        }
-        txn.commit().unwrap();
-        index
-    }
-    pub fn get_random_looking_index() -> FacetIndex<OrderedF64Codec> {
-        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
-        let mut txn = index.env.write_txn().unwrap();
-        let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
-
-        for (_i, key) in std::iter::from_fn(|| Some(rng.gen_range(0..256))).take(128).enumerate() {
-            let mut bitmap = RoaringBitmap::new();
-            bitmap.insert(key);
-            bitmap.insert(key + 100);
-            index.insert(&mut txn, 0, &(key as f64), &bitmap);
-        }
-        txn.commit().unwrap();
-        index
-    }
-    pub fn get_simple_index_with_multiple_field_ids() -> FacetIndex<OrderedF64Codec> {
-        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
-        let mut txn = index.env.write_txn().unwrap();
-        for fid in 0..2 {
-            for i in 0..256u16 {
-                let mut bitmap = RoaringBitmap::new();
-                bitmap.insert(i as u32);
-                index.insert(&mut txn, fid, &(i as f64), &bitmap);
-            }
-        }
-        txn.commit().unwrap();
-        index
-    }
-    pub fn get_random_looking_index_with_multiple_field_ids() -> FacetIndex<OrderedF64Codec> {
-        let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
-        let mut txn = index.env.write_txn().unwrap();
-
-        let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
-        let keys =
-            std::iter::from_fn(|| Some(rng.gen_range(0..256))).take(128).collect::<Vec<u32>>();
-        for fid in 0..2 {
-            for (_i, &key) in keys.iter().enumerate() {
-                let mut bitmap = RoaringBitmap::new();
-                bitmap.insert(key);
-                bitmap.insert(key + 100);
-                index.insert(&mut txn, fid, &(key as f64), &bitmap);
-            }
-        }
-        txn.commit().unwrap();
-        index
-    }
-    pub fn get_simple_string_index_with_multiple_field_ids() -> FacetIndex<StrRefCodec> {
-        let index = FacetIndex::<StrRefCodec>::new(4, 8, 5);
-        let mut txn = index.env.write_txn().unwrap();
-        for fid in 0..2 {
-            for i in 0..256u16 {
-                let mut bitmap = RoaringBitmap::new();
-                bitmap.insert(i as u32);
-                if i % 2 == 0 {
-                    index.insert(&mut txn, fid, &format!("{i}").as_str(), &bitmap);
-                } else {
-                    index.insert(&mut txn, fid, &"", &bitmap);
-                }
-            }
-        }
-        txn.commit().unwrap();
-        index
-    }
-    pub fn get_random_looking_string_index_with_multiple_field_ids() -> FacetIndex<StrRefCodec> {
-        let index = FacetIndex::<StrRefCodec>::new(4, 8, 5);
-        let mut txn = index.env.write_txn().unwrap();
-
-        let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
-        let keys =
-            std::iter::from_fn(|| Some(rng.gen_range(0..256))).take(128).collect::<Vec<u32>>();
-        for fid in 0..2 {
-            for (_i, &key) in keys.iter().enumerate() {
-                let mut bitmap = RoaringBitmap::new();
-                bitmap.insert(key);
-                bitmap.insert(key + 100);
-                if key % 2 == 0 {
-                    index.insert(&mut txn, fid, &format!("{key}").as_str(), &bitmap);
-                } else {
-                    index.insert(&mut txn, fid, &"", &bitmap);
-                }
-            }
-        }
-        txn.commit().unwrap();
-        index
-    }
-}
@@ -102,7 +102,7 @@ impl ScoreWithRatioResult {
         }

         SearchResult {
-            matching_words: left.matching_words,
+            matching_words: right.matching_words,
             candidates: left.candidates | right.candidates,
             documents_ids,
             document_scores,
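This one-line fix decides whose `matching_words` survive when the two halves of a hybrid search are merged; highlighting is driven by those words, which is exactly what the `highlighter` test above exercises. A simplified sketch of such a merge with invented stand-in types (reading `right` as the keyword side is our inference from the tests, and milli's real `candidates` is a RoaringBitmap):

// Stand-in types for illustration only.
struct SearchResult {
    matching_words: Vec<String>,
    candidates: u64, // bit-set stand-in; `|` unions the two candidate sets
    documents_ids: Vec<u32>,
}

fn merge(left: SearchResult, right: SearchResult) -> SearchResult {
    let mut documents_ids = left.documents_ids;
    documents_ids.extend(right.documents_ids);
    documents_ids.sort_unstable();
    documents_ids.dedup();
    SearchResult {
        // the fix: keep the keyword-side words so highlighting still works
        matching_words: right.matching_words,
        candidates: left.candidates | right.candidates,
        documents_ids,
    }
}

fn main() {
    let semantic =
        SearchResult { matching_words: vec![], candidates: 0b0110, documents_ids: vec![2, 3] };
    let keyword = SearchResult {
        matching_words: vec!["marvel".into()],
        candidates: 0b0011,
        documents_ids: vec![1, 3],
    };
    let merged = merge(semantic, keyword);
    assert_eq!(merged.matching_words, vec!["marvel".to_string()]);
    assert_eq!(merged.candidates, 0b0111);
}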
@@ -407,54 +407,6 @@ mod tests {
         test("large_group_small_min_level", 16, 2);
         test("odd_group_odd_min_level", 7, 3);
     }
-    #[test]
-    fn insert_delete_field_insert() {
-        let test = |name: &str, group_size: u8, min_level_size: u8| {
-            let index =
-                FacetIndex::<OrderedF64Codec>::new(group_size, 0 /*NA*/, min_level_size);
-            let mut wtxn = index.env.write_txn().unwrap();
-
-            let mut elements = Vec::<((u16, f64), RoaringBitmap)>::new();
-            for i in 0..100u32 {
-                // field id = 0, left_bound = i, docids = [i]
-                elements.push(((0, i as f64), once(i).collect()));
-            }
-            for i in 0..100u32 {
-                // field id = 1, left_bound = i, docids = [i]
-                elements.push(((1, i as f64), once(i).collect()));
-            }
-            index.bulk_insert(&mut wtxn, &[0, 1], elements.iter());
-
-            index.verify_structure_validity(&wtxn, 0);
-            index.verify_structure_validity(&wtxn, 1);
-            // delete all the elements for the facet id 0
-            for i in 0..100u32 {
-                index.delete_single_docid(&mut wtxn, 0, &(i as f64), i);
-            }
-            index.verify_structure_validity(&wtxn, 0);
-            index.verify_structure_validity(&wtxn, 1);
-
-            let mut elements = Vec::<((u16, f64), RoaringBitmap)>::new();
-            // then add some elements again for the facet id 1
-            for i in 0..110u32 {
-                // field id = 1, left_bound = i, docids = [i]
-                elements.push(((1, i as f64), once(i).collect()));
-            }
-            index.verify_structure_validity(&wtxn, 0);
-            index.verify_structure_validity(&wtxn, 1);
-            index.bulk_insert(&mut wtxn, &[0, 1], elements.iter());
-
-            wtxn.commit().unwrap();
-
-            milli_snap!(format!("{index}"), name);
-        };
-
-        test("default", 4, 5);
-        test("small_group_small_min_level", 2, 2);
-        test("small_group_large_min_level", 2, 128);
-        test("large_group_small_min_level", 16, 2);
-        test("odd_group_odd_min_level", 7, 3);
-    }

     #[test]
     fn bug_3165() {
(One file's diff was suppressed by the viewer because it is too large.)
@@ -72,7 +72,6 @@ two methods.

 Related PR: https://github.com/meilisearch/milli/pull/619
 */

 pub const FACET_MAX_GROUP_SIZE: u8 = 8;
 pub const FACET_GROUP_SIZE: u8 = 4;
 pub const FACET_MIN_LEVEL_SIZE: u8 = 5;
@@ -88,17 +87,14 @@ use heed::BytesEncode;
 use log::debug;
 use time::OffsetDateTime;

-use self::incremental::FacetsUpdateIncremental;
 use super::FacetsUpdateBulk;
 use crate::facet::FacetType;
-use crate::heed_codec::facet::{FacetGroupKey, FacetGroupKeyCodec, FacetGroupValueCodec};
-use crate::heed_codec::BytesRefCodec;
+use crate::heed_codec::facet::FacetGroupKey;
 use crate::update::index_documents::create_sorter;
 use crate::update::merge_btreeset_string;
 use crate::{BEU16StrCodec, Index, Result, MAX_FACET_VALUE_LENGTH};

 pub mod bulk;
 pub mod incremental;

 /// A builder used to add new elements to the `facet_id_string_docids` or `facet_id_f64_docids` databases.
 ///
@@ -106,11 +102,9 @@ pub mod incremental;
 /// a bulk update method or an incremental update method.
 pub struct FacetsUpdate<'i> {
     index: &'i Index,
-    database: heed::Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
     facet_type: FacetType,
     delta_data: grenad::Reader<BufReader<File>>,
     group_size: u8,
-    max_group_size: u8,
     min_level_size: u8,
 }
 impl<'i> FacetsUpdate<'i> {
@@ -119,19 +113,9 @@ impl<'i> FacetsUpdate<'i> {
         facet_type: FacetType,
         delta_data: grenad::Reader<BufReader<File>>,
     ) -> Self {
-        let database = match facet_type {
-            FacetType::String => {
-                index.facet_id_string_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>()
-            }
-            FacetType::Number => {
-                index.facet_id_f64_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>()
-            }
-        };
         Self {
             index,
-            database,
             group_size: FACET_GROUP_SIZE,
-            max_group_size: FACET_MAX_GROUP_SIZE,
             min_level_size: FACET_MIN_LEVEL_SIZE,
             facet_type,
             delta_data,
@@ -145,30 +129,16 @@ impl<'i> FacetsUpdate<'i> {
         debug!("Computing and writing the facet values levels docids into LMDB on disk...");
         self.index.set_updated_at(wtxn, &OffsetDateTime::now_utc())?;

-        // See self::comparison_bench::benchmark_facet_indexing
-        if self.delta_data.len() >= (self.database.len(wtxn)? / 50) {
-            let field_ids =
-                self.index.faceted_fields_ids(wtxn)?.iter().copied().collect::<Vec<_>>();
-            let bulk_update = FacetsUpdateBulk::new(
-                self.index,
-                field_ids,
-                self.facet_type,
-                self.delta_data,
-                self.group_size,
-                self.min_level_size,
-            );
-            bulk_update.execute(wtxn)?;
-        } else {
-            let incremental_update = FacetsUpdateIncremental::new(
-                self.index,
-                self.facet_type,
-                self.delta_data,
-                self.group_size,
-                self.min_level_size,
-                self.max_group_size,
-            );
-            incremental_update.execute(wtxn)?;
-        }
+        let field_ids = self.index.faceted_fields_ids(wtxn)?.iter().copied().collect::<Vec<_>>();
+        let bulk_update = FacetsUpdateBulk::new(
+            self.index,
+            field_ids,
+            self.facet_type,
+            self.delta_data,
+            self.group_size,
+            self.min_level_size,
+        );
+        bulk_update.execute(wtxn)?;

         // We clear the list of normalized-for-search facets
         // and the previous FSTs to compute everything from scratch
@@ -264,7 +234,6 @@ impl<'i> FacetsUpdate<'i> {
 pub(crate) mod test_helpers {
     use std::cell::Cell;
     use std::fmt::Display;
     use std::iter::FromIterator;
     use std::marker::PhantomData;
     use std::rc::Rc;

@@ -280,7 +249,6 @@ pub(crate) mod test_helpers {
     use crate::search::facet::get_highest_level;
     use crate::snapshot_tests::display_bitmap;
     use crate::update::del_add::{DelAdd, KvWriterDelAdd};
     use crate::update::FacetsUpdateIncrementalInner;
     use crate::CboRoaringBitmapCodec;

     /// Utility function to generate a string whose position in a lexicographically
@@ -396,49 +364,6 @@ pub(crate) mod test_helpers {
         self.min_level_size.set(std::cmp::max(1, min_level_size));
     }

-    pub fn insert<'a>(
-        &self,
-        wtxn: &'a mut RwTxn,
-        field_id: u16,
-        key: &'a <BoundCodec as BytesEncode<'a>>::EItem,
-        docids: &RoaringBitmap,
-    ) {
-        let update = FacetsUpdateIncrementalInner {
-            db: self.content,
-            group_size: self.group_size.get(),
-            min_level_size: self.min_level_size.get(),
-            max_group_size: self.max_group_size.get(),
-        };
-        let key_bytes = BoundCodec::bytes_encode(key).unwrap();
-        update.insert(wtxn, field_id, &key_bytes, docids).unwrap();
-    }
-    pub fn delete_single_docid<'a>(
-        &self,
-        wtxn: &'a mut RwTxn,
-        field_id: u16,
-        key: &'a <BoundCodec as BytesEncode<'a>>::EItem,
-        docid: u32,
-    ) {
-        self.delete(wtxn, field_id, key, &RoaringBitmap::from_iter(std::iter::once(docid)))
-    }
-
-    pub fn delete<'a>(
-        &self,
-        wtxn: &'a mut RwTxn,
-        field_id: u16,
-        key: &'a <BoundCodec as BytesEncode<'a>>::EItem,
-        docids: &RoaringBitmap,
-    ) {
-        let update = FacetsUpdateIncrementalInner {
-            db: self.content,
-            group_size: self.group_size.get(),
-            min_level_size: self.min_level_size.get(),
-            max_group_size: self.max_group_size.get(),
-        };
-        let key_bytes = BoundCodec::bytes_encode(key).unwrap();
-        update.delete(wtxn, field_id, &key_bytes, docids).unwrap();
-    }

     pub fn bulk_insert<'a, 'b>(
         &self,
         wtxn: &'a mut RwTxn,
@@ -555,63 +480,3 @@ pub(crate) mod test_helpers {
         }
     }
 }
-
-#[allow(unused)]
-#[cfg(test)]
-mod comparison_bench {
-    use std::iter::once;
-
-    use rand::Rng;
-    use roaring::RoaringBitmap;
-
-    use super::test_helpers::FacetIndex;
-    use crate::heed_codec::facet::OrderedF64Codec;
-
-    // This is a simple test to get an intuition on the relative speed
-    // of the incremental vs. bulk indexer.
-    //
-    // The benchmark shows the worst-case scenario for the incremental indexer, since
-    // each facet value contains only one document ID.
-    //
-    // In that scenario, it appears that the incremental indexer is about 50 times slower than the
-    // bulk indexer.
-    // #[test]
-    fn benchmark_facet_indexing() {
-        let mut facet_value = 0;
-
-        let mut r = rand::thread_rng();
-
-        for i in 1..=20 {
-            let size = 50_000 * i;
-            let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
-
-            let mut txn = index.env.write_txn().unwrap();
-            let mut elements = Vec::<((u16, f64), RoaringBitmap)>::new();
-            for i in 0..size {
-                // field id = 0, left_bound = i, docids = [i]
-                elements.push(((0, facet_value as f64), once(i).collect()));
-                facet_value += 1;
-            }
-            let timer = std::time::Instant::now();
-            index.bulk_insert(&mut txn, &[0], elements.iter());
-            let time_spent = timer.elapsed().as_millis();
-            println!("bulk {size} : {time_spent}ms");
-
-            txn.commit().unwrap();
-
-            for nbr_doc in [1, 100, 1000, 10_000] {
-                let mut txn = index.env.write_txn().unwrap();
-                let timer = std::time::Instant::now();
-                //
-                // insert one document
-                //
-                for _ in 0..nbr_doc {
-                    index.insert(&mut txn, 0, &r.gen(), &once(1).collect());
-                }
-                let time_spent = timer.elapsed().as_millis();
-                println!("  add {nbr_doc} : {time_spent}ms");
-                txn.abort();
-            }
-        }
-    }
-}
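The `execute` rewrite above deletes a strategy choice: a delta touching at least one fiftieth of the database used to trigger a bulk rebuild, while smaller deltas went through the incremental updater. After the patch the bulk path always runs, which is also why the struct could drop its `database` and `max_group_size` fields. A sketch of the removed decision rule; the threshold comes from the deleted condition, and the diff does not say why the incremental path was retired:

#[derive(Debug, PartialEq)]
enum Strategy {
    Bulk,
    Incremental,
}

// Mirrors the deleted `delta_data.len() >= database.len() / 50` check.
fn choose_strategy(delta_len: u64, database_len: u64) -> Strategy {
    if delta_len >= database_len / 50 {
        Strategy::Bulk
    } else {
        Strategy::Incremental
    }
}

fn main() {
    assert_eq!(choose_strategy(100, 1_000), Strategy::Bulk);
    assert_eq!(choose_strategy(10, 1_000), Strategy::Incremental);
}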
@@ -34,7 +34,9 @@ pub fn extract_geo_points<R: io::Read + io::Seek>(
     // since we only need the primary key when we throw an error
     // we create this getter to lazily get it when needed
     let document_id = || -> Value {
-        let document_id = obkv.get(primary_key_id).unwrap();
+        let reader = KvReaderDelAdd::new(obkv.get(primary_key_id).unwrap());
+        let document_id =
+            reader.get(DelAdd::Deletion).or(reader.get(DelAdd::Addition)).unwrap();
         serde_json::from_slice(document_id).unwrap()
     };
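The closure above is the "compute it only on failure" pattern: decoding the primary key costs something, and only the error message needs it, so the happy path never pays. The del/add change just means the stored key is now an envelope with a deletion and an addition side, either of which can hold the value. A generic, self-contained illustration with names of our own:

// The `describe_source` closure runs only on the error path.
fn first_non_negative(values: &[i64], describe_source: impl Fn() -> String) -> Result<i64, String> {
    values
        .iter()
        .copied()
        .find(|v| *v >= 0)
        .ok_or_else(|| format!("no non-negative value in {}", describe_source()))
}

fn main() {
    // Never invoked on success.
    assert_eq!(first_non_negative(&[-1, 3], || unreachable!()), Ok(3));
    // Only the failing call pays for building the description.
    assert!(first_non_negative(&[-1, -2], || "document `12`".to_string()).is_err());
}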
@@ -339,9 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     indexer: GrenadParameters,
     embedder: Arc<Embedder>,
 ) -> Result<grenad::Reader<BufReader<File>>> {
-    let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
-
-    let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
+    let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
     let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk

     // docid, state with embedding
@@ -375,11 +373,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
         current_chunk_ids.push(docid);

         if chunks.len() == chunks.capacity() {
-            let chunked_embeds = rt
-                .block_on(
-                    embedder
-                        .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
-                )
+            let chunked_embeds = embedder
+                .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
                 .map_err(crate::vector::Error::from)
                 .map_err(crate::Error::from)?;

@@ -396,8 +391,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(

     // send last chunk
     if !chunks.is_empty() {
-        let chunked_embeds = rt
-            .block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
+        let chunked_embeds = embedder
+            .embed_chunks(std::mem::take(&mut chunks))
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
         for (docid, embeddings) in chunks_ids
@@ -410,13 +405,15 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     }

     if !current_chunk.is_empty() {
-        let embeds = rt
-            .block_on(embedder.embed(std::mem::take(&mut current_chunk)))
+        let embeds = embedder
+            .embed_chunks(vec![std::mem::take(&mut current_chunk)])
            .map_err(crate::vector::Error::from)
            .map_err(crate::Error::from)?;

-        for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
-            state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+        if let Some(embeds) = embeds.first() {
+            for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
+                state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+            }
         }
     }
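The loop being patched batches prompts on two levels: `n_vectors_per_chunk` prompts form one request, and `n_chunks` requests are flushed together; the patch itself only swaps `rt.block_on(...)` for the now-synchronous `embed_chunks`. A stripped-down sketch of that batching shape, where the `embed_chunks` stub and the concrete counts are ours:

// Stub standing in for the real embedder call; returns one vector per prompt.
fn embed_chunks(chunks: Vec<Vec<String>>) -> Vec<Vec<Vec<f32>>> {
    chunks.iter().map(|c| c.iter().map(|_| vec![0.0f32; 3]).collect()).collect()
}

fn main() {
    let n_chunks = 2; // chunk-level parallelism hint
    let n_vectors_per_chunk = 4; // prompts per request
    let mut chunks: Vec<Vec<String>> = Vec::with_capacity(n_chunks);
    let mut current_chunk: Vec<String> = Vec::with_capacity(n_vectors_per_chunk);

    for i in 0..10 {
        current_chunk.push(format!("prompt {i}"));
        if current_chunk.len() == n_vectors_per_chunk {
            // a chunk is full: queue it and start a fresh one
            chunks.push(std::mem::replace(
                &mut current_chunk,
                Vec::with_capacity(n_vectors_per_chunk),
            ));
        }
        if chunks.len() == n_chunks {
            // a batch is full: embed it, keeping the allocation pattern of the diff
            let _embeds =
                embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)));
        }
    }
    // flush queued-but-unsent chunks, as the `if !chunks.is_empty()` branch does
    if !chunks.is_empty() {
        let _embeds = embed_chunks(std::mem::take(&mut chunks));
    }
    // then the partial chunk, as `if !current_chunk.is_empty()` does
    if !current_chunk.is_empty() {
        let _embeds = embed_chunks(vec![std::mem::take(&mut current_chunk)]);
    }
}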
@@ -1,7 +1,6 @@
 pub use self::available_documents_ids::AvailableDocumentsIds;
 pub use self::clear_documents::ClearDocuments;
 pub use self::facet::bulk::FacetsUpdateBulk;
 pub use self::facet::incremental::FacetsUpdateIncrementalInner;
 pub use self::index_documents::{
     merge_btreeset_string, merge_cbo_roaring_bitmaps, merge_roaring_bitmaps,
     DocumentAdditionResult, DocumentId, IndexDocuments, IndexDocumentsConfig, IndexDocumentsMethod,
@ -67,6 +67,10 @@ pub enum EmbedErrorKind {
|
||||
OpenAiUnhandledStatusCode(u16),
|
||||
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
||||
ManualEmbed(String),
|
||||
#[error("could not initialize asynchronous runtime: {0}")]
|
||||
OpenAiRuntimeInit(std::io::Error),
|
||||
#[error("initializing web client for sending embedding requests failed: {0}")]
|
||||
InitWebClient(reqwest::Error),
|
||||
}
|
||||
|
||||
impl EmbedError {
|
||||
@ -117,6 +121,14 @@ impl EmbedError {
|
||||
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@ -183,10 +195,6 @@ impl NewEmbedderError {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
|
||||
}
|
||||
@@ -237,8 +245,6 @@ pub enum NewEmbedderErrorKind {
     #[error("loading model failed: {0}")]
     LoadModel(candle_core::Error),
     // openai
-    #[error("initializing web client for sending embedding requests failed: {0}")]
-    InitWebClient(reqwest::Error),
     #[error("The API key passed to Authorization error was in an invalid format: {0}")]
     InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
 }

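Note on the error plumbing above: web-client construction can now fail at embedding time rather than embedder-construction time, so the InitWebClient variant moves from NewEmbedderErrorKind to EmbedErrorKind, and each new variant gets a constructor that tags a FaultSource. A simplified, std-only mirror of that pattern (the real code derives Display with thiserror and uses typed kinds):

    #[derive(Debug)]
    enum FaultSource {
        User,
        Runtime,
    }

    #[derive(Debug)]
    struct EmbedError {
        kind: String,
        fault: FaultSource,
    }

    impl EmbedError {
        // Mirrors openai_runtime_init from the diff: pair the kind with who is at fault.
        fn openai_runtime_init(inner: std::io::Error) -> Self {
            Self {
                kind: format!("could not initialize asynchronous runtime: {inner}"),
                fault: FaultSource::Runtime,
            }
        }
    }

    fn main() {
        let io = std::io::Error::new(std::io::ErrorKind::Other, "no reactor");
        println!("{:?}", EmbedError::openai_runtime_init(io));
    }
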
@@ -145,7 +145,8 @@ impl Embedder {
         let token_ids = tokens
             .iter()
             .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
+                let mut tokens = tokens.get_ids().to_vec();
+                tokens.truncate(512);
                 Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
             })
             .collect::<Result<Vec<_>, EmbedError>>()?;

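Note on the truncation above: BERT-style encoders have a fixed number of position embeddings (512 here), so prompts that tokenize to more ids must be clamped before Tensor::new, or inference fails on out-of-range positions. A toy illustration of the clamp itself:

    fn main() {
        // A prompt tokenized to more ids than the encoder has positions for...
        let mut ids: Vec<u32> = (0..2048).collect();
        // ...is clamped to the model's 512-position window before building the tensor.
        ids.truncate(512);
        assert_eq!(ids.len(), 512);
    }
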
@@ -163,18 +163,24 @@ impl Embedder {
     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => embedder.embed(texts).await,
+            Embedder::OpenAi(embedder) => {
+                let client = embedder.new_client()?;
+                embedder.embed(texts, &client).await
+            }
             Embedder::UserProvided(embedder) => embedder.embed(texts),
         }
     }

-    pub async fn embed_chunks(
+    /// # Panics
+    ///
+    /// - if called from an asynchronous context
+    pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
-            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
+            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
         }
     }

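Note on the signature change above: embed_chunks is now a plain blocking function, and its new doc comment warns that it panics when called from an asynchronous context (the OpenAI arm calls block_on on a runtime it builds itself, and block_on panics inside another runtime). A minimal sketch of the safe calling patterns, assuming the tokio crate; embed_all is a hypothetical stand-in for Embedder::embed_chunks:

    // Hypothetical stand-in for the now-blocking Embedder::embed_chunks.
    fn embed_all(chunks: Vec<Vec<String>>) -> usize {
        chunks.iter().map(Vec::len).sum()
    }

    #[tokio::main]
    async fn main() {
        let chunks = vec![vec!["a".into()], vec!["b".into(), "c".into()]];
        // From async code, hop to a blocking thread instead of calling it inline:
        let n = tokio::task::spawn_blocking(move || embed_all(chunks)).await.unwrap();
        println!("embedded {n} prompts");
    }

From synchronous code, such as milli's extraction threads above, it is simply a direct call.
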
@@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings};

 #[derive(Debug)]
 pub struct Embedder {
-    client: reqwest::Client,
+    headers: reqwest::header::HeaderMap,
     tokenizer: tiktoken_rs::CoreBPE,
     options: EmbedderOptions,
 }

@@ -95,6 +95,13 @@ impl EmbedderOptions {
 }

 impl Embedder {
+    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
+        reqwest::ClientBuilder::new()
+            .default_headers(self.headers.clone())
+            .build()
+            .map_err(EmbedError::openai_initialize_web_client)
+    }
+
     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
         let mut headers = reqwest::header::HeaderMap::new();
         let mut inferred_api_key = Default::default();

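Note on the two hunks above: the OpenAI embedder no longer stores a reqwest::Client, only the prepared HeaderMap, and new_client builds a fresh client on demand. The apparent rationale (an inference, not stated in the diff) is that a Client's connection pool is driven by the runtime it is used on, so each short-lived runtime created by embed_chunks gets its own client rather than reusing one across runtimes. A minimal mirror of the builder call, assuming the reqwest crate:

    // Mirrors Embedder::new_client: keep the headers, rebuild the cheap Client
    // whenever a new runtime needs one.
    fn new_client(
        headers: &reqwest::header::HeaderMap,
    ) -> Result<reqwest::Client, reqwest::Error> {
        reqwest::ClientBuilder::new().default_headers(headers.clone()).build()
    }
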
@@ -111,25 +118,25 @@ impl Embedder {
             reqwest::header::CONTENT_TYPE,
             reqwest::header::HeaderValue::from_static("application/json"),
         );
-        let client = reqwest::ClientBuilder::new()
-            .default_headers(headers)
-            .build()
-            .map_err(NewEmbedderError::openai_initialize_web_client)?;

         // looking at the code it is very unclear that this can actually fail.
         let tokenizer = tiktoken_rs::cl100k_base().unwrap();

-        Ok(Self { options, client, tokenizer })
+        Ok(Self { options, headers, tokenizer })
     }

-    pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+    pub async fn embed(
+        &self,
+        texts: Vec<String>,
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
         let mut tokenized = false;

         for attempt in 0..7 {
            let result = if tokenized {
-                self.try_embed_tokenized(&texts).await
+                self.try_embed_tokenized(&texts, client).await
             } else {
-                self.try_embed(&texts).await
+                self.try_embed(&texts, client).await
             };

            let retry_duration = match result {
@@ -145,9 +152,9 @@ impl Embedder {
         }

         let result = if tokenized {
-            self.try_embed_tokenized(&texts).await
+            self.try_embed_tokenized(&texts, client).await
         } else {
-            self.try_embed(&texts).await
+            self.try_embed(&texts, client).await
         };

         result.map_err(Retry::into_error)
@@ -225,13 +232,13 @@ impl Embedder {
     async fn try_embed<S: AsRef<str> + serde::Serialize>(
         &self,
         texts: &[S],
+        client: &reqwest::Client,
     ) -> Result<Vec<Embeddings<f32>>, Retry> {
         for text in texts {
             log::trace!("Received prompt: {}", text.as_ref())
         }
         let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -256,7 +263,11 @@ impl Embedder {
             .collect())
     }

-    async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
+    async fn try_embed_tokenized(
+        &self,
+        text: &[String],
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, Retry> {
         pub const OVERLAP_SIZE: usize = 200;
         let mut all_embeddings = Vec::with_capacity(text.len());
         for text in text {
@@ -264,7 +275,7 @@ impl Embedder {
             let encoded = self.tokenizer.encode_ordinary(text.as_str());
             let len = encoded.len();
             if len < max_token_count {
-                all_embeddings.append(&mut self.try_embed(&[text]).await?);
+                all_embeddings.append(&mut self.try_embed(&[text], client).await?);
                 continue;
             }

@@ -273,22 +284,26 @@ impl Embedder {
                 Embeddings::new(self.options.embedding_model.dimensions());
             while tokens.len() > max_token_count {
                 let window = &tokens[..max_token_count];
-                embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
+                embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();

                 tokens = &tokens[max_token_count - OVERLAP_SIZE..];
             }

             // end of text
-            embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
+            embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();

             all_embeddings.push(embeddings_for_prompt);
         }
         Ok(all_embeddings)
     }

-    async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         for attempt in 0..9 {
-            let duration = match self.try_embed_tokens(tokens).await {
+            let duration = match self.try_embed_tokens(tokens, client).await {
                 Ok(embedding) => return Ok(embedding),
                 Err(retry) => retry.into_duration(attempt),
             }
@@ -297,14 +312,19 @@ impl Embedder {
             tokio::time::sleep(duration).await;
         }

-        self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
+        self.try_embed_tokens(tokens, client)
+            .await
+            .map_err(|retry| Retry::give_up(retry.into_error()))
     }

-    async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn try_embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         let request =
             OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -322,12 +342,19 @@ impl Embedder {
         Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
     }

-    pub async fn embed_chunks(
+    pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
-            .await
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_io()
+            .enable_time()
+            .build()
+            .map_err(EmbedError::openai_runtime_init)?;
+        let client = self.new_client()?;
+        rt.block_on(futures::future::try_join_all(
+            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
+        ))
     }

     pub fn chunk_count_hint(&self) -> usize {
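Note on the final hunk: this is where the async boundary now lives. embed_chunks builds a single-threaded tokio runtime, creates one client for it, and drives all per-chunk embed futures with try_join_all under block_on, so callers never need a runtime of their own. A self-contained sketch of that pattern, assuming the tokio and futures crates (toy futures in place of the real requests):

    fn main() -> Result<(), std::io::Error> {
        // As in the diff: a current-thread runtime with IO and timers enabled.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()?; // the diff maps this error to EmbedError::openai_runtime_init
        let chunks: Vec<Vec<&str>> = vec![vec!["a", "b"], vec!["c"]];
        // Drive one future per chunk to completion from blocking code.
        let lens = rt.block_on(futures::future::try_join_all(
            chunks
                .into_iter()
                .map(|chunk| async move { Ok::<_, std::io::Error>(chunk.len()) }),
        ))?;
        assert_eq!(lens, vec![2, 1]);
        Ok(())
    }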