mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 21:16:28 +00:00 
			
		
		
		
	feat: Move the multi-word rewriting algorithm into its own function
This commit is contained in:
		| @@ -14,7 +14,7 @@ meilidb-tokenizer = { path = "../meilidb-tokenizer", version = "0.1.0" } | ||||
| rayon = "1.0.3" | ||||
| sdset = "0.3.2" | ||||
| serde = { version = "1.0.88", features = ["derive"] } | ||||
| slice-group-by = "0.2.4" | ||||
| slice-group-by = "0.2.6" | ||||
| zerocopy = "0.2.2" | ||||
|  | ||||
| [dependencies.fst] | ||||
|   | ||||
| @@ -21,7 +21,7 @@ fn custom_log10(n: u8) -> f32 { | ||||
|  | ||||
| #[inline] | ||||
| fn sum_matches_typos(query_index: &[u32], distance: &[u8]) -> usize { | ||||
|     let mut number_words = 0; | ||||
|     let mut number_words: usize = 0; | ||||
|     let mut sum_typos = 0.0; | ||||
|     let mut index = 0; | ||||
|  | ||||
|   | ||||
| @@ -197,6 +197,110 @@ impl<'c, S, FI> QueryBuilder<'c, S, FI> | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn multiword_rewrite_matches( | ||||
|     mut matches: Vec<(DocumentId, TmpMatch)>, | ||||
|     query_enhancer: &QueryEnhancer, | ||||
| ) -> SetBuf<(DocumentId, TmpMatch)> | ||||
| { | ||||
|     let mut padded_matches = Vec::with_capacity(matches.len()); | ||||
|  | ||||
|     // we sort the matches by word index to make them rewritable | ||||
|     let start = Instant::now(); | ||||
|     matches.par_sort_unstable_by_key(|(id, match_)| (*id, match_.attribute, match_.word_index)); | ||||
|     info!("rewrite sort by word_index took {:.2?}", start.elapsed()); | ||||
|  | ||||
|     let start = Instant::now(); | ||||
|     // for each attribute of each document | ||||
|     for same_document_attribute in matches.linear_group_by_key(|(id, m)| (*id, m.attribute)) { | ||||
|  | ||||
|         // padding will only be applied | ||||
|         // to word indices in the same attribute | ||||
|         let mut padding = 0; | ||||
|         let mut iter = same_document_attribute.linear_group_by_key(|(_, m)| m.word_index); | ||||
|  | ||||
|         // for each match at the same position | ||||
|         // in this document attribute | ||||
|         while let Some(same_word_index) = iter.next() { | ||||
|  | ||||
|             // find the biggest padding | ||||
|             let mut biggest = 0; | ||||
|             for (id, match_) in same_word_index { | ||||
|  | ||||
|                 let mut replacement = query_enhancer.replacement(match_.query_index); | ||||
|                 let replacement_len = replacement.len(); | ||||
|                 let nexts = iter.remainder().linear_group_by_key(|(_, m)| m.word_index); | ||||
|  | ||||
|                 if let Some(query_index) = replacement.next() { | ||||
|                     let word_index = match_.word_index + padding as u16; | ||||
|                     let match_ = TmpMatch { query_index, word_index, ..match_.clone() }; | ||||
|                     padded_matches.push((*id, match_)); | ||||
|                 } | ||||
|  | ||||
|                 let mut found = false; | ||||
|  | ||||
|                 // look ahead and if there already is a match | ||||
|                 // corresponding to this padding word, abort the padding | ||||
|                 'padding: for (x, next_group) in nexts.enumerate() { | ||||
|  | ||||
|                     for (i, query_index) in replacement.clone().enumerate().skip(x) { | ||||
|                         let word_index = match_.word_index + padding as u16 + (i + 1) as u16; | ||||
|                         let padmatch = TmpMatch { query_index, word_index, ..match_.clone() }; | ||||
|  | ||||
|                         for (_, nmatch_) in next_group { | ||||
|                             let mut rep = query_enhancer.replacement(nmatch_.query_index); | ||||
|                             let query_index = rep.next().unwrap(); | ||||
|                             if query_index == padmatch.query_index { | ||||
|  | ||||
|                                 if !found { | ||||
|                                     // if we find a corresponding padding for the | ||||
|                                     // first time we must push preceding paddings | ||||
|                                     for (i, query_index) in replacement.clone().enumerate().take(i) { | ||||
|                                         let word_index = match_.word_index + padding as u16 + (i + 1) as u16; | ||||
|                                         let match_ = TmpMatch { query_index, word_index, ..match_.clone() }; | ||||
|                                         padded_matches.push((*id, match_)); | ||||
|                                         biggest = biggest.max(i + 1); | ||||
|                                     } | ||||
|                                 } | ||||
|  | ||||
|                                 padded_matches.push((*id, padmatch)); | ||||
|                                 found = true; | ||||
|                                 continue 'padding; | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     // if we do not find a corresponding padding in the | ||||
|                     // next groups so stop here and pad what was found | ||||
|                     break | ||||
|                 } | ||||
|  | ||||
|                 if !found { | ||||
|                     // if no padding was found in the following matches | ||||
|                     // we must insert the entire padding | ||||
|                     for (i, query_index) in replacement.enumerate() { | ||||
|                         let word_index = match_.word_index + padding as u16 + (i + 1) as u16; | ||||
|                         let match_ = TmpMatch { query_index, word_index, ..match_.clone() }; | ||||
|                         padded_matches.push((*id, match_)); | ||||
|                     } | ||||
|  | ||||
|                     biggest = biggest.max(replacement_len - 1); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             padding += biggest; | ||||
|         } | ||||
|     } | ||||
|     info!("main multiword rewrite took {:.2?}", start.elapsed()); | ||||
|  | ||||
|     let start = Instant::now(); | ||||
|     for document_matches in padded_matches.linear_group_by_key_mut(|(id, _)| *id) { | ||||
|         document_matches.sort_unstable(); | ||||
|     } | ||||
|     info!("final rewrite sort took {:.2?}", start.elapsed()); | ||||
|  | ||||
|     SetBuf::new_unchecked(padded_matches) | ||||
| } | ||||
|  | ||||
| impl<'c, S, FI> QueryBuilder<'c, S, FI> | ||||
| where S: Store, | ||||
| { | ||||
| @@ -217,22 +321,26 @@ where S: Store, | ||||
|         let mut matches = Vec::new(); | ||||
|         let mut highlights = Vec::new(); | ||||
|  | ||||
|         let mut query_db = std::time::Duration::default(); | ||||
|  | ||||
|         let start = Instant::now(); | ||||
|         while let Some((input, indexed_values)) = stream.next() { | ||||
|             for iv in indexed_values { | ||||
|                 let Automaton { is_exact, query_len, ref dfa } = automatons[iv.index]; | ||||
|                 let distance = dfa.eval(input).to_u8(); | ||||
|                 let is_exact = is_exact && distance == 0 && input.len() == query_len; | ||||
|  | ||||
|                 let start = Instant::now(); | ||||
|                 let doc_indexes = self.store.word_indexes(input)?; | ||||
|                 let doc_indexes = match doc_indexes { | ||||
|                     Some(doc_indexes) => doc_indexes, | ||||
|                     None => continue, | ||||
|                 }; | ||||
|                 query_db += start.elapsed(); | ||||
|  | ||||
|                 for di in doc_indexes.as_slice() { | ||||
|                     let attribute = searchables.map_or(Some(di.attribute), |r| r.get(di.attribute)); | ||||
|                     if let Some(attribute) = attribute { | ||||
|  | ||||
|                         let match_ = TmpMatch { | ||||
|                             query_index: iv.index as u32, | ||||
|                             distance, | ||||
| @@ -253,118 +361,28 @@ where S: Store, | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         info!("main query all took {:.2?} (get indexes {:.2?})", start.elapsed(), query_db); | ||||
|  | ||||
|         // we sort the matches to make them rewritable | ||||
|         matches.par_sort_unstable_by_key(|(id, match_)| (*id, match_.attribute, match_.word_index)); | ||||
|         info!("{} total matches to rewrite", matches.len()); | ||||
|  | ||||
|         let mut padded_matches = Vec::with_capacity(matches.len()); | ||||
|         for same_document in matches.linear_group_by(|a, b| a.0 == b.0) { | ||||
|  | ||||
|             for same_attribute in same_document.linear_group_by(|a, b| a.1.attribute == b.1.attribute) { | ||||
|  | ||||
|                 let mut padding = 0; | ||||
|                 let mut iter = same_attribute.linear_group_by(|a, b| a.1.word_index == b.1.word_index); | ||||
|                 while let Some(same_word_index) = iter.next() { | ||||
|  | ||||
|                     let mut biggest = 0; | ||||
|                     for (id, match_) in same_word_index { | ||||
|  | ||||
|                         let mut replacement = query_enhancer.replacement(match_.query_index); | ||||
|                         let replacement_len = replacement.len() - 1; | ||||
|                         let nexts = iter.remainder().linear_group_by(|a, b| a.1.word_index == b.1.word_index); | ||||
|  | ||||
|                         if let Some(query_index) = replacement.next() { | ||||
|                             let match_ = TmpMatch { | ||||
|                                 query_index, | ||||
|                                 word_index: match_.word_index + padding as u16, | ||||
|                                 ..match_.clone() | ||||
|                             }; | ||||
|                             padded_matches.push((*id, match_)); | ||||
|                         } | ||||
|  | ||||
|                         let mut found = false; | ||||
|  | ||||
|                         // look ahead and if there already is a match | ||||
|                         // corresponding to this padding word, abort the padding | ||||
|                         'padding: for (x, next_group) in nexts.enumerate() { | ||||
|  | ||||
|                             for (i, query_index) in replacement.clone().enumerate().skip(x) { | ||||
|                                 let padmatch_ = TmpMatch { | ||||
|                                     query_index, | ||||
|                                     word_index: match_.word_index + padding as u16 + (i + 1) as u16, | ||||
|                                     ..match_.clone() | ||||
|                                 }; | ||||
|  | ||||
|                                 for (_, nmatch_) in next_group { | ||||
|                                     let mut rep = query_enhancer.replacement(nmatch_.query_index); | ||||
|                                     let query_index = rep.next().unwrap(); | ||||
|                                     let nmatch_ = TmpMatch { query_index, ..nmatch_.clone() }; | ||||
|                                     if nmatch_.query_index == padmatch_.query_index { | ||||
|  | ||||
|                                         if !found { | ||||
|                                             // if we find a corresponding padding for the | ||||
|                                             // first time we must push preceding paddings | ||||
|                                             for (i, query_index) in replacement.clone().enumerate().take(i) { | ||||
|                                                 let match_ = TmpMatch { | ||||
|                                                     query_index, | ||||
|                                                     word_index: match_.word_index + padding as u16 + (i + 1) as u16, | ||||
|                                                     ..match_.clone() | ||||
|                                                 }; | ||||
|                                                 padded_matches.push((*id, match_)); | ||||
|                                                 biggest = biggest.max(i + 1); | ||||
|                                             } | ||||
|                                         } | ||||
|  | ||||
|                                         padded_matches.push((*id, padmatch_)); | ||||
|                                         found = true; | ||||
|                                         continue 'padding; | ||||
|                                     } | ||||
|                                 } | ||||
|                             } | ||||
|  | ||||
|                             // if we do not find a corresponding padding in the | ||||
|                             // next groups so stop here and pad what was found | ||||
|                             break | ||||
|                         } | ||||
|  | ||||
|                         if !found { | ||||
|                             // if no padding was found in the following matches | ||||
|                             // we must insert the entire padding | ||||
|                             for (i, query_index) in replacement.enumerate() { | ||||
|                                 let match_ = TmpMatch { | ||||
|                                     query_index, | ||||
|                                     word_index: match_.word_index + padding as u16 + (i + 1) as u16, | ||||
|                                     ..match_.clone() | ||||
|                                 }; | ||||
|                                 padded_matches.push((*id, match_)); | ||||
|                             } | ||||
|  | ||||
|                             biggest = biggest.max(replacement_len); | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     padding += biggest; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|         } | ||||
|  | ||||
|  | ||||
|         let matches = { | ||||
|             padded_matches.par_sort_unstable(); | ||||
|             SetBuf::new_unchecked(padded_matches) | ||||
|         }; | ||||
|         let start = Instant::now(); | ||||
|         let matches = multiword_rewrite_matches(matches, &query_enhancer); | ||||
|         info!("multiword rewrite took {:.2?}", start.elapsed()); | ||||
|  | ||||
|         let start = Instant::now(); | ||||
|         let highlights = { | ||||
|             highlights.par_sort_unstable_by_key(|(id, _)| *id); | ||||
|             SetBuf::new_unchecked(highlights) | ||||
|         }; | ||||
|         info!("sorting highlights took {:.2?}", start.elapsed()); | ||||
|  | ||||
|         let total_matches = matches.len(); | ||||
|         info!("{} total matches to classify", matches.len()); | ||||
|  | ||||
|         let start = Instant::now(); | ||||
|         let raw_documents = raw_documents_from(matches, highlights); | ||||
|         info!("making raw documents took {:.2?}", start.elapsed()); | ||||
|  | ||||
|         info!("{} total documents to classify", raw_documents.len()); | ||||
|         info!("{} total matches to classify", total_matches); | ||||
|  | ||||
|         Ok(raw_documents) | ||||
|     } | ||||
|   | ||||
| @@ -52,17 +52,20 @@ where S: AsRef<str>, | ||||
|     !original.map(AsRef::as_ref).eq(words.iter().map(AsRef::as_ref)) | ||||
| } | ||||
|  | ||||
| type Origin = usize; | ||||
| type RealLength = usize; | ||||
|  | ||||
| struct FakeIntervalTree { | ||||
|     intervals: Vec<(Range<usize>, (usize, usize))>, // origin, real_length | ||||
|     intervals: Vec<(Range<usize>, (Origin, RealLength))>, | ||||
| } | ||||
|  | ||||
| impl FakeIntervalTree { | ||||
|     fn new(mut intervals: Vec<(Range<usize>, (usize, usize))>) -> FakeIntervalTree { | ||||
|     fn new(mut intervals: Vec<(Range<usize>, (Origin, RealLength))>) -> FakeIntervalTree { | ||||
|         intervals.sort_unstable_by_key(|(r, _)| (r.start, r.end)); | ||||
|         FakeIntervalTree { intervals } | ||||
|     } | ||||
|  | ||||
|     fn query(&self, point: usize) -> Option<(Range<usize>, (usize, usize))> { | ||||
|     fn query(&self, point: usize) -> Option<(Range<usize>, (Origin, RealLength))> { | ||||
|         let element = self.intervals.binary_search_by(|(r, _)| { | ||||
|             if point >= r.start { | ||||
|                 if point < r.end { Equal } else { Less } | ||||
| @@ -81,7 +84,7 @@ impl FakeIntervalTree { | ||||
| pub struct QueryEnhancerBuilder<'a, S> { | ||||
|     query: &'a [S], | ||||
|     origins: Vec<usize>, | ||||
|     real_to_origin: Vec<(Range<usize>, (usize, usize))>, | ||||
|     real_to_origin: Vec<(Range<usize>, (Origin, RealLength))>, | ||||
| } | ||||
|  | ||||
| impl<S: AsRef<str>> QueryEnhancerBuilder<'_, S> { | ||||
| @@ -147,8 +150,8 @@ impl QueryEnhancer { | ||||
|         // query the fake interval tree with the real query index | ||||
|         let (range, (origin, real_length)) = | ||||
|             self.real_to_origin | ||||
|             .query(real) | ||||
|             .expect("real has never been declared"); | ||||
|                 .query(real) | ||||
|                 .expect("real has never been declared"); | ||||
|  | ||||
|         // if `real` is the end bound of the range | ||||
|         if (range.start + real_length - 1) == real { | ||||
|   | ||||
| @@ -74,8 +74,8 @@ pub fn raw_documents_from( | ||||
|     let mut docs_ranges: Vec<(_, Range, _)> = Vec::new(); | ||||
|     let mut matches2 = Matches::with_capacity(matches.len()); | ||||
|  | ||||
|     let matches = matches.linear_group_by(|(a, _), (b, _)| a == b); | ||||
|     let highlights = highlights.linear_group_by(|(a, _), (b, _)| a == b); | ||||
|     let matches = matches.linear_group_by_key(|(id, _)| *id); | ||||
|     let highlights = highlights.linear_group_by_key(|(id, _)| *id); | ||||
|  | ||||
|     for (mgroup, hgroup) in matches.zip(highlights) { | ||||
|         debug_assert_eq!(mgroup[0].0, hgroup[0].0); | ||||
|   | ||||
| @@ -21,10 +21,10 @@ impl<'a> SynonymsAddition<'a> { | ||||
|     pub fn add_synonym<S, T, I>(&mut self, synonym: S, alternatives: I) | ||||
|     where S: AsRef<str>, | ||||
|           T: AsRef<str>, | ||||
|           I: Iterator<Item=T>, | ||||
|           I: IntoIterator<Item=T>, | ||||
|     { | ||||
|         let synonym = normalize_str(synonym.as_ref()); | ||||
|         let alternatives = alternatives.map(|s| s.as_ref().to_lowercase()); | ||||
|         let alternatives = alternatives.into_iter().map(|s| s.as_ref().to_lowercase()); | ||||
|         self.synonyms.entry(synonym).or_insert_with(Vec::new).extend(alternatives); | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -31,9 +31,13 @@ pub struct Opt { | ||||
|     #[structopt(long = "schema", parse(from_os_str))] | ||||
|     pub schema_path: PathBuf, | ||||
|  | ||||
|     /// The file with the synonyms. | ||||
|     #[structopt(long = "synonyms", parse(from_os_str))] | ||||
|     pub synonyms: Option<PathBuf>, | ||||
|  | ||||
|     /// The path to the list of stop words (one by line). | ||||
|     #[structopt(long = "stop-words", parse(from_os_str))] | ||||
|     pub stop_words_path: Option<PathBuf>, | ||||
|     pub stop_words: Option<PathBuf>, | ||||
|  | ||||
|     #[structopt(long = "update-group-size")] | ||||
|     pub update_group_size: Option<usize>, | ||||
| @@ -45,12 +49,40 @@ struct Document<'a> ( | ||||
|     HashMap<Cow<'a, str>, Cow<'a, str>> | ||||
| ); | ||||
|  | ||||
| #[derive(Debug, Clone, Serialize, Deserialize)] | ||||
| #[serde(untagged)] | ||||
| pub enum Synonym { | ||||
|     OneWay(SynonymOneWay), | ||||
|     MultiWay { synonyms: Vec<String> }, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Serialize, Deserialize)] | ||||
| #[serde(rename_all = "camelCase")] | ||||
| pub struct SynonymOneWay { | ||||
|     pub search_terms: String, | ||||
|     pub synonyms: Synonyms, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Serialize, Deserialize)] | ||||
| #[serde(untagged)] | ||||
| pub enum Synonyms { | ||||
|     Multiple(Vec<String>), | ||||
|     Single(String), | ||||
| } | ||||
|  | ||||
| fn read_synomys(path: &Path) -> Result<Vec<Synonym>, Box<dyn Error>> { | ||||
|     let file = File::open(path)?; | ||||
|     let synonyms = serde_json::from_reader(file)?; | ||||
|     Ok(synonyms) | ||||
| } | ||||
|  | ||||
| fn index( | ||||
|     schema: Schema, | ||||
|     database_path: &Path, | ||||
|     csv_data_path: &Path, | ||||
|     update_group_size: Option<usize>, | ||||
|     stop_words: &HashSet<String>, | ||||
|     synonyms: Vec<Synonym>, | ||||
| ) -> Result<Database, Box<dyn Error>> | ||||
| { | ||||
|     let database = Database::start_default(database_path)?; | ||||
| @@ -62,6 +94,28 @@ fn index( | ||||
|  | ||||
|     let index = database.create_index("test", schema.clone())?; | ||||
|  | ||||
|     let mut synonyms_adder = index.synonyms_addition(); | ||||
|     for synonym in synonyms { | ||||
|         match synonym { | ||||
|             Synonym::OneWay(SynonymOneWay { search_terms, synonyms }) => { | ||||
|                 let alternatives = match synonyms { | ||||
|                     Synonyms::Multiple(alternatives) => alternatives, | ||||
|                     Synonyms::Single(alternative) => vec![alternative], | ||||
|                 }; | ||||
|                 synonyms_adder.add_synonym(search_terms, alternatives); | ||||
|             }, | ||||
|             Synonym::MultiWay { mut synonyms } => { | ||||
|                 for _ in 0..synonyms.len() { | ||||
|                     if let Some((synonym, alternatives)) = synonyms.split_first() { | ||||
|                         synonyms_adder.add_synonym(synonym, alternatives); | ||||
|                     } | ||||
|                     synonyms.rotate_left(1); | ||||
|                 } | ||||
|             }, | ||||
|         } | ||||
|     } | ||||
|     synonyms_adder.finalize()?; | ||||
|  | ||||
|     let mut rdr = csv::Reader::from_path(csv_data_path)?; | ||||
|     let mut raw_record = csv::StringRecord::new(); | ||||
|     let headers = rdr.headers()?.clone(); | ||||
| @@ -133,13 +187,25 @@ fn main() -> Result<(), Box<dyn Error>> { | ||||
|         Schema::from_toml(file)? | ||||
|     }; | ||||
|  | ||||
|     let stop_words = match opt.stop_words_path { | ||||
|     let stop_words = match opt.stop_words { | ||||
|         Some(ref path) => retrieve_stop_words(path)?, | ||||
|         None           => HashSet::new(), | ||||
|     }; | ||||
|  | ||||
|     let synonyms = match opt.synonyms { | ||||
|         Some(ref path) => read_synomys(path)?, | ||||
|         None           => Vec::new(), | ||||
|     }; | ||||
|  | ||||
|     let start = Instant::now(); | ||||
|     let result = index(schema, &opt.database_path, &opt.csv_data_path, opt.update_group_size, &stop_words); | ||||
|     let result = index( | ||||
|         schema, | ||||
|         &opt.database_path, | ||||
|         &opt.csv_data_path, | ||||
|         opt.update_group_size, | ||||
|         &stop_words, | ||||
|         synonyms, | ||||
|     ); | ||||
|  | ||||
|     if let Err(e) = result { | ||||
|         return Err(e.into()) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user