Introduce a WordDerivationsCache struct

This commit is contained in:
Clément Renault
2021-03-05 11:02:24 +01:00
committed by Kerollmops
parent 2606c92ef9
commit 5fcaedb880
7 changed files with 169 additions and 100 deletions

View File

@ -1,5 +1,7 @@
use std::borrow::Cow;
use std::collections::hash_map::{HashMap, Entry};
use std::fmt;
use std::str::Utf8Error;
use std::time::Instant;
use fst::{IntoStreamer, Streamer, Set};
@ -97,8 +99,9 @@ impl<'a> Search<'a> {
let mut offset = self.offset;
let mut limit = self.limit;
let mut documents_ids = Vec::new();
let mut words_derivations_cache = WordDerivationsCache::new();
let mut initial_candidates = RoaringBitmap::new();
while let Some(CriterionResult { candidates, bucket_candidates, .. }) = criteria.next()? {
while let Some(CriterionResult { candidates, bucket_candidates, .. }) = criteria.next(&mut words_derivations_cache)? {
debug!("Number of candidates found {}", candidates.len());
@ -145,24 +148,32 @@ pub struct SearchResult {
pub documents_ids: Vec<DocumentId>,
}
// NOTE(review): this span appears to be a diff hunk rendered without its
// `+`/`-` markers, so two versions of `word_derivations` seem to be
// interleaved below (two signatures, two return types, two copies of the
// DFA/stream setup and of the collection loop). Presumably the variant that
// returns `anyhow::Result<Vec<(String, u8)>>` is the pre-change code and the
// variant taking `cache` is the post-change code; confirm against the commit.
//
// Cached variant, as far as the visible lines show:
//   * looks up `(word, is_prefix, max_typo)` in `cache`;
//   * on a hit, returns the slice already stored for that key;
//   * on a miss, builds a DFA via `build_dfa`, streams every match of `fst`
//     against it, records each matched word with `dfa.distance(state).to_u8()`,
//     stores the resulting vector in the cache and returns a borrow of it.
// The returned slice borrows from `cache` for the lifetime `'c`.
// Errors: a matched key that is not valid UTF-8 yields `Utf8Error`; in that
// case the function returns before `entry.insert`, so nothing is cached.
pub fn word_derivations(
// Cache type: key is (word, is_prefix, max_typo); value is the list of
// (derived word, distance as u8) pairs computed for that key.
pub type WordDerivationsCache = HashMap<(String, bool, u8), Vec<(String, u8)>>;
pub fn word_derivations<'c>(
word: &str,
is_prefix: bool,
max_typo: u8,
fst: &fst::Set<Cow<[u8]>>,
) -> anyhow::Result<Vec<(String, u8)>>
cache: &'c mut WordDerivationsCache,
) -> Result<&'c [(String, u8)], Utf8Error>
{
let mut derived_words = Vec::new();
let dfa = build_dfa(word, max_typo, is_prefix);
let mut stream = fst.search_with_state(&dfa).into_stream();
// The key owns its `String`, so `word.to_string()` allocates on every call,
// including cache hits.
match cache.entry((word.to_string(), is_prefix, max_typo)) {
// Hit: hand back the stored vector; `into_mut` ties the borrow to the map.
Entry::Occupied(entry) => Ok(entry.into_mut()),
// Miss: compute the derivations, then store them.
Entry::Vacant(entry) => {
let mut derived_words = Vec::new();
let dfa = build_dfa(word, max_typo, is_prefix);
let mut stream = fst.search_with_state(&dfa).into_stream();
while let Some((word, state)) = stream.next() {
let word = std::str::from_utf8(word)?;
let distance = dfa.distance(state);
derived_words.push((word.to_string(), distance.to_u8()));
while let Some((word, state)) = stream.next() {
// `word` here shadows the parameter: it is the raw key bytes yielded by the
// stream, decoded as UTF-8 (`?` propagates the decoding error).
let word = std::str::from_utf8(word)?;
let distance = dfa.distance(state);
derived_words.push((word.to_string(), distance.to_u8()));
}
// `insert` stores the vector and returns a mutable reference to it, which
// is returned to the caller as a slice.
Ok(entry.insert(derived_words))
},
}
Ok(derived_words)
}
pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA {