Return original string in facet distributions, work on facet tests

This commit is contained in:
Loïc Lecrenier
2022-09-07 17:56:38 +02:00
committed by Loïc Lecrenier
parent 27454e9828
commit fca4577e23
10 changed files with 350 additions and 213 deletions

View File

@ -4,8 +4,9 @@ use heed::Result;
use roaring::RoaringBitmap;
use super::{get_first_facet_value, get_highest_level};
use crate::heed_codec::facet::{
ByteSliceRef, FacetGroupKey, FacetGroupKeyCodec, FacetGroupValueCodec,
use crate::{
heed_codec::facet::{ByteSliceRef, FacetGroupKey, FacetGroupKeyCodec, FacetGroupValueCodec},
DocumentId,
};
pub fn iterate_over_facet_distribution<'t, CB>(
@ -16,7 +17,7 @@ pub fn iterate_over_facet_distribution<'t, CB>(
callback: CB,
) -> Result<()>
where
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
CB: FnMut(&'t [u8], u64, DocumentId) -> Result<ControlFlow<()>>,
{
let mut fd = FacetDistribution { rtxn, db, field_id, callback };
let highest_level =
@ -32,7 +33,7 @@ where
struct FacetDistribution<'t, CB>
where
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
CB: FnMut(&'t [u8], u64, DocumentId) -> Result<ControlFlow<()>>,
{
rtxn: &'t heed::RoTxn<'t>,
db: heed::Database<FacetGroupKeyCodec<ByteSliceRef>, FacetGroupValueCodec>,
@ -42,7 +43,7 @@ where
impl<'t, CB> FacetDistribution<'t, CB>
where
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
CB: FnMut(&'t [u8], u64, DocumentId) -> Result<ControlFlow<()>>,
{
fn iterate_level_0(
&mut self,
@ -62,7 +63,8 @@ where
}
let docids_in_common = value.bitmap.intersection_len(candidates);
if docids_in_common > 0 {
match (self.callback)(key.left_bound, docids_in_common) {
let any_docid = value.bitmap.iter().next().unwrap();
match (self.callback)(key.left_bound, docids_in_common, any_docid)? {
ControlFlow::Continue(_) => {}
ControlFlow::Break(_) => return Ok(ControlFlow::Break(())),
}
@ -112,50 +114,14 @@ where
#[cfg(test)]
mod tests {
use super::iterate_over_facet_distribution;
use crate::milli_snap;
use crate::search::facet::tests::get_random_looking_index;
use crate::{heed_codec::facet::OrderedF64Codec, search::facet::tests::get_simple_index};
use heed::BytesDecode;
use roaring::RoaringBitmap;
use std::ops::ControlFlow;
use super::iterate_over_facet_distribution;
use crate::heed_codec::facet::OrderedF64Codec;
use crate::milli_snap;
use crate::update::facet::tests::FacetIndex;
use heed::BytesDecode;
use rand::{Rng, SeedableRng};
use roaring::RoaringBitmap;
fn get_simple_index() -> FacetIndex<OrderedF64Codec> {
let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
let mut txn = index.env.write_txn().unwrap();
for i in 0..256u16 {
let mut bitmap = RoaringBitmap::new();
bitmap.insert(i as u32);
index.insert(&mut txn, 0, &(i as f64), &bitmap);
}
txn.commit().unwrap();
index
}
/// Builds a test facet index for field id `0` from 128 pseudo-random facet
/// values drawn from `0..256`.
///
/// The generator is seeded with a fixed all-zero seed, so the same sequence
/// of keys is produced on every run. For each drawn key `k`, the facet value
/// `k as f64` is inserted with the document ids `{k, k + 100}`.
///
/// NOTE(review): the same key may be drawn more than once; how repeated
/// inserts of one facet value are handled is up to `FacetIndex::insert`.
fn get_random_looking_index() -> FacetIndex<OrderedF64Codec> {
    let index = FacetIndex::<OrderedF64Codec>::new(4, 8, 5);
    let mut txn = index.env.write_txn().unwrap();

    // Fixed seed: keeps the generated keys stable between runs.
    let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
    // Draw all keys up front so the RNG borrow ends before the insert loop.
    let keys: Vec<u32> = std::iter::repeat_with(|| rng.gen_range(0..256)).take(128).collect();

    for key in keys {
        let mut bitmap = RoaringBitmap::new();
        bitmap.insert(key);
        bitmap.insert(key + 100);
        // `u32 -> f64` is lossless, so use the infallible conversion.
        index.insert(&mut txn, 0, &f64::from(key), &bitmap);
    }

    txn.commit().unwrap();
    index
}
// Snapshot test: formats the index returned by `get_random_looking_index`
// with its `Display` implementation (`{index}`) and passes the resulting
// string to `milli_snap!`.
// NOTE(review): presumably `milli_snap!` compares the string against a stored
// snapshot — confirm against the macro definition.
#[test]
fn random_looking_index_snap() {
let index = get_random_looking_index();
milli_snap!(format!("{index}"));
}
#[test]
fn filter_distribution_all() {
let indexes = [get_simple_index(), get_random_looking_index()];
@ -163,11 +129,17 @@ mod tests {
let txn = index.env.read_txn().unwrap();
let candidates = (0..=255).into_iter().collect::<RoaringBitmap>();
let mut results = String::new();
iterate_over_facet_distribution(&txn, index.content, 0, &candidates, |facet, count| {
let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
results.push_str(&format!("{facet}: {count}\n"));
ControlFlow::Continue(())
})
iterate_over_facet_distribution(
&txn,
index.content,
0,
&candidates,
|facet, count, _| {
let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
results.push_str(&format!("{facet}: {count}\n"));
Ok(ControlFlow::Continue(()))
},
)
.unwrap();
milli_snap!(results, i);
@ -182,17 +154,23 @@ mod tests {
let candidates = (0..=255).into_iter().collect::<RoaringBitmap>();
let mut results = String::new();
let mut nbr_facets = 0;
iterate_over_facet_distribution(&txn, index.content, 0, &candidates, |facet, count| {
let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
if nbr_facets == 100 {
return ControlFlow::Break(());
} else {
nbr_facets += 1;
results.push_str(&format!("{facet}: {count}\n"));
iterate_over_facet_distribution(
&txn,
index.content,
0,
&candidates,
|facet, count, _| {
let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
if nbr_facets == 100 {
return Ok(ControlFlow::Break(()));
} else {
nbr_facets += 1;
results.push_str(&format!("{facet}: {count}\n"));
ControlFlow::Continue(())
}
})
Ok(ControlFlow::Continue(()))
}
},
)
.unwrap();
milli_snap!(results, i);