mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 21:16:28 +00:00 
			
		
		
		
	Store the word positions under the documents
This commit is contained in:
		| @@ -1,3 +1,4 @@ | ||||
| use std::collections::HashMap; | ||||
| use std::convert::{TryFrom, TryInto}; | ||||
| use std::fs::File; | ||||
| use std::io::{self, Read, Write}; | ||||
| @@ -13,6 +14,7 @@ use cow_utils::CowUtils; | ||||
| use csv::StringRecord; | ||||
| use flate2::read::GzDecoder; | ||||
| use fst::IntoStreamer; | ||||
| use heed::BytesDecode; | ||||
| use heed::BytesEncode; | ||||
| use heed::EnvOpenOptions; | ||||
| use heed::types::*; | ||||
| @@ -25,7 +27,7 @@ use structopt::StructOpt; | ||||
|  | ||||
| use milli::heed_codec::CsvStringRecordCodec; | ||||
| use milli::tokenizer::{simple_tokenizer, only_words}; | ||||
| use milli::{SmallVec32, Index, DocumentId, Position, Attribute, BEU32}; | ||||
| use milli::{SmallVec32, Index, DocumentId, BEU32, StrBEU32Codec}; | ||||
|  | ||||
| const LMDB_MAX_KEY_LENGTH: usize = 511; | ||||
| const ONE_MILLION: usize = 1_000_000; | ||||
| @@ -37,10 +39,8 @@ const HEADERS_KEY: &[u8] = b"\0headers"; | ||||
| const DOCUMENTS_IDS_KEY: &[u8] = b"\x04documents-ids"; | ||||
| const WORDS_FST_KEY: &[u8] = b"\x06words-fst"; | ||||
| const DOCUMENTS_IDS_BYTE: u8 = 4; | ||||
| const WORD_ATTRIBUTE_DOCIDS_BYTE: u8 = 3; | ||||
| const WORD_FOUR_POSITIONS_DOCIDS_BYTE: u8 = 5; | ||||
| const WORD_POSITION_DOCIDS_BYTE: u8 = 2; | ||||
| const WORD_POSITIONS_BYTE: u8 = 1; | ||||
| const WORD_DOCIDS_BYTE: u8 = 2; | ||||
| const WORD_DOCID_POSITIONS_BYTE: u8 = 1; | ||||
|  | ||||
| #[cfg(target_os = "linux")] | ||||
| #[global_allocator] | ||||
| @@ -125,10 +125,7 @@ fn lmdb_key_valid_size(key: &[u8]) -> bool { | ||||
| type MergeFn = fn(&[u8], &[Vec<u8>]) -> Result<Vec<u8>, ()>; | ||||
|  | ||||
| struct Store { | ||||
|     word_positions: ArcCache<SmallVec32<u8>, RoaringBitmap>, | ||||
|     word_position_docids: ArcCache<(SmallVec32<u8>, Position), RoaringBitmap>, | ||||
|     word_four_positions_docids: ArcCache<(SmallVec32<u8>, Position), RoaringBitmap>, | ||||
|     word_attribute_docids: ArcCache<(SmallVec32<u8>, Attribute), RoaringBitmap>, | ||||
|     word_docids: ArcCache<SmallVec32<u8>, RoaringBitmap>, | ||||
|     documents_ids: RoaringBitmap, | ||||
|     sorter: Sorter<MergeFn>, | ||||
|     documents_sorter: Sorter<MergeFn>, | ||||
| @@ -162,10 +159,7 @@ impl Store { | ||||
|         } | ||||
|  | ||||
|         Store { | ||||
|             word_positions: ArcCache::new(arc_cache_size), | ||||
|             word_position_docids: ArcCache::new(arc_cache_size), | ||||
|             word_four_positions_docids: ArcCache::new(arc_cache_size), | ||||
|             word_attribute_docids: ArcCache::new(arc_cache_size), | ||||
|             word_docids: ArcCache::new(arc_cache_size), | ||||
|             documents_ids: RoaringBitmap::new(), | ||||
|             sorter: builder.build(), | ||||
|             documents_sorter: documents_builder.build(), | ||||
| @@ -173,65 +167,48 @@ impl Store { | ||||
|     } | ||||
|  | ||||
|     // Save the documents ids under the position and word we have seen it. | ||||
|     pub fn insert_word_position_docid(&mut self, word: &str, position: Position, id: DocumentId) -> anyhow::Result<()> { | ||||
|     pub fn insert_word_docid(&mut self, word: &str, id: DocumentId) -> anyhow::Result<()> { | ||||
|         let word_vec = SmallVec32::from(word.as_bytes()); | ||||
|         let ids = RoaringBitmap::from_iter(Some(id)); | ||||
|         let (_, lrus) = self.word_position_docids.insert((word_vec, position), ids, |old, new| old.union_with(&new)); | ||||
|         Self::write_word_position_docids(&mut self.sorter, lrus)?; | ||||
|         self.insert_word_position(word, position)?; | ||||
|         self.insert_word_four_positions_docid(word, position, id)?; | ||||
|         self.insert_word_attribute_docid(word, position / MAX_POSITION as u32, id)?; | ||||
|         let (_, lrus) = self.word_docids.insert(word_vec, ids, |old, new| old.union_with(&new)); | ||||
|         Self::write_word_docids(&mut self.sorter, lrus)?; | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     pub fn insert_word_four_positions_docid(&mut self, word: &str, position: Position, id: DocumentId) -> anyhow::Result<()> { | ||||
|         let position = position - position % 4; | ||||
|         let word_vec = SmallVec32::from(word.as_bytes()); | ||||
|         let ids = RoaringBitmap::from_iter(Some(id)); | ||||
|         let (_, lrus) = self.word_four_positions_docids.insert((word_vec, position), ids, |old, new| old.union_with(&new)); | ||||
|         Self::write_word_four_positions_docids(&mut self.sorter, lrus) | ||||
|     } | ||||
|  | ||||
|     // Save the positions where this word has been seen. | ||||
|     pub fn insert_word_position(&mut self, word: &str, position: Position) -> anyhow::Result<()> { | ||||
|         let word = SmallVec32::from(word.as_bytes()); | ||||
|         let position = RoaringBitmap::from_iter(Some(position)); | ||||
|         let (_, lrus) = self.word_positions.insert(word, position, |old, new| old.union_with(&new)); | ||||
|         Self::write_word_positions(&mut self.sorter, lrus) | ||||
|     } | ||||
|  | ||||
|     // Save the documents ids under the attribute and word we have seen it. | ||||
|     fn insert_word_attribute_docid(&mut self, word: &str, attribute: Attribute, id: DocumentId) -> anyhow::Result<()> { | ||||
|         let word = SmallVec32::from(word.as_bytes()); | ||||
|         let ids = RoaringBitmap::from_iter(Some(id)); | ||||
|         let (_, lrus) = self.word_attribute_docids.insert((word, attribute), ids, |old, new| old.union_with(&new)); | ||||
|         Self::write_word_attribute_docids(&mut self.sorter, lrus) | ||||
|     } | ||||
|  | ||||
|     pub fn write_headers(&mut self, headers: &StringRecord) -> anyhow::Result<()> { | ||||
|         let headers = CsvStringRecordCodec::bytes_encode(headers) | ||||
|             .with_context(|| format!("could not encode csv record"))?; | ||||
|         Ok(self.sorter.insert(HEADERS_KEY, headers)?) | ||||
|     } | ||||
|  | ||||
|     pub fn write_document(&mut self, id: DocumentId, record: &StringRecord) -> anyhow::Result<()> { | ||||
|     pub fn write_document( | ||||
|         &mut self, | ||||
|         id: DocumentId, | ||||
|         iter: impl IntoIterator<Item=(String, RoaringBitmap)>, | ||||
|         record: &StringRecord, | ||||
|     ) -> anyhow::Result<()> | ||||
|     { | ||||
|         let record = CsvStringRecordCodec::bytes_encode(record) | ||||
|             .with_context(|| format!("could not encode csv record"))?; | ||||
|         self.documents_ids.insert(id); | ||||
|         Ok(self.documents_sorter.insert(id.to_be_bytes(), record)?) | ||||
|         self.documents_sorter.insert(id.to_be_bytes(), record)?; | ||||
|         Self::write_docid_word_positions(&mut self.sorter, id, iter)?; | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     fn write_word_positions<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()> | ||||
|     where I: IntoIterator<Item=(SmallVec32<u8>, RoaringBitmap)> | ||||
|     fn write_docid_word_positions<I>(sorter: &mut Sorter<MergeFn>, id: DocumentId, iter: I) -> anyhow::Result<()> | ||||
|     where I: IntoIterator<Item=(String, RoaringBitmap)> | ||||
|     { | ||||
|         // postings ids keys are all prefixed | ||||
|         let mut key = vec![WORD_POSITIONS_BYTE]; | ||||
|         // postings positions ids keys are all prefixed | ||||
|         let mut key = vec![WORD_DOCID_POSITIONS_BYTE]; | ||||
|         let mut buffer = Vec::new(); | ||||
|  | ||||
|         for (word, positions) in iter { | ||||
|             key.truncate(1); | ||||
|             key.extend_from_slice(&word); | ||||
|             // We serialize the positions into a buffer | ||||
|             key.extend_from_slice(word.as_bytes()); | ||||
|             // We prefix the words by the document id. | ||||
|             key.extend_from_slice(&id.to_be_bytes()); | ||||
|             // We serialize the document ids into a buffer | ||||
|             buffer.clear(); | ||||
|             buffer.reserve(positions.serialized_size()); | ||||
|             positions.serialize_into(&mut buffer)?; | ||||
| @@ -244,68 +221,16 @@ impl Store { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     fn write_word_position_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()> | ||||
|     where I: IntoIterator<Item=((SmallVec32<u8>, Position), RoaringBitmap)> | ||||
|     fn write_word_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()> | ||||
|     where I: IntoIterator<Item=(SmallVec32<u8>, RoaringBitmap)> | ||||
|     { | ||||
|         // postings positions ids keys are all prefixed | ||||
|         let mut key = vec![WORD_POSITION_DOCIDS_BYTE]; | ||||
|         let mut key = vec![WORD_DOCIDS_BYTE]; | ||||
|         let mut buffer = Vec::new(); | ||||
|  | ||||
|         for ((word, pos), ids) in iter { | ||||
|         for (word, ids) in iter { | ||||
|             key.truncate(1); | ||||
|             key.extend_from_slice(&word); | ||||
|             // we postfix the word by the positions it appears in | ||||
|             key.extend_from_slice(&pos.to_be_bytes()); | ||||
|             // We serialize the document ids into a buffer | ||||
|             buffer.clear(); | ||||
|             buffer.reserve(ids.serialized_size()); | ||||
|             ids.serialize_into(&mut buffer)?; | ||||
|             // that we write under the generated key into MTBL | ||||
|             if lmdb_key_valid_size(&key) { | ||||
|                 sorter.insert(&key, &buffer)?; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     fn write_word_four_positions_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()> | ||||
|     where I: IntoIterator<Item=((SmallVec32<u8>, Position), RoaringBitmap)> | ||||
|     { | ||||
|         // postings positions ids keys are all prefixed | ||||
|         let mut key = vec![WORD_FOUR_POSITIONS_DOCIDS_BYTE]; | ||||
|         let mut buffer = Vec::new(); | ||||
|  | ||||
|         for ((word, pos), ids) in iter { | ||||
|             key.truncate(1); | ||||
|             key.extend_from_slice(&word); | ||||
|             // we postfix the word by the positions it appears in | ||||
|             key.extend_from_slice(&pos.to_be_bytes()); | ||||
|             // We serialize the document ids into a buffer | ||||
|             buffer.clear(); | ||||
|             buffer.reserve(ids.serialized_size()); | ||||
|             ids.serialize_into(&mut buffer)?; | ||||
|             // that we write under the generated key into MTBL | ||||
|             if lmdb_key_valid_size(&key) { | ||||
|                 sorter.insert(&key, &buffer)?; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     fn write_word_attribute_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()> | ||||
|     where I: IntoIterator<Item=((SmallVec32<u8>, Attribute), RoaringBitmap)> | ||||
|     { | ||||
|         // postings attributes keys are all prefixed | ||||
|         let mut key = vec![WORD_ATTRIBUTE_DOCIDS_BYTE]; | ||||
|         let mut buffer = Vec::new(); | ||||
|  | ||||
|         for ((word, attr), ids) in iter { | ||||
|             key.truncate(1); | ||||
|             key.extend_from_slice(&word); | ||||
|             // we postfix the word by the positions it appears in | ||||
|             key.extend_from_slice(&attr.to_be_bytes()); | ||||
|             // We serialize the document ids into a buffer | ||||
|             buffer.clear(); | ||||
|             buffer.reserve(ids.serialized_size()); | ||||
| @@ -327,10 +252,7 @@ impl Store { | ||||
|     } | ||||
|  | ||||
|     pub fn finish(mut self) -> anyhow::Result<(Reader<Mmap>, Reader<Mmap>)> { | ||||
|         Self::write_word_positions(&mut self.sorter, self.word_positions)?; | ||||
|         Self::write_word_position_docids(&mut self.sorter, self.word_position_docids)?; | ||||
|         Self::write_word_four_positions_docids(&mut self.sorter, self.word_four_positions_docids)?; | ||||
|         Self::write_word_attribute_docids(&mut self.sorter, self.word_attribute_docids)?; | ||||
|         Self::write_word_docids(&mut self.sorter, self.word_docids)?; | ||||
|         Self::write_documents_ids(&mut self.sorter, self.documents_ids)?; | ||||
|  | ||||
|         let mut wtr = tempfile::tempfile().map(Writer::new)?; | ||||
| @@ -339,7 +261,8 @@ impl Store { | ||||
|         let mut iter = self.sorter.into_iter()?; | ||||
|         while let Some(result) = iter.next() { | ||||
|             let (key, val) = result?; | ||||
|             if let Some((&1, word)) = key.split_first() { | ||||
|             if let Some((&1, bytes)) = key.split_first() { | ||||
|                 let (word, _docid) = StrBEU32Codec::bytes_decode(bytes).unwrap(); | ||||
|                 // This is a lexicographically ordered word position | ||||
|                 // we use the key to construct the words fst. | ||||
|                 builder.insert(word)?; | ||||
| @@ -389,12 +312,7 @@ fn merge(key: &[u8], values: &[Vec<u8>]) -> Result<Vec<u8>, ()> { | ||||
|             Ok(values[0].to_vec()) | ||||
|         }, | ||||
|         key => match key[0] { | ||||
|                 DOCUMENTS_IDS_BYTE | ||||
|               | WORD_POSITIONS_BYTE | ||||
|               | WORD_POSITION_DOCIDS_BYTE | ||||
|               | WORD_FOUR_POSITIONS_DOCIDS_BYTE | ||||
|               | WORD_ATTRIBUTE_DOCIDS_BYTE => | ||||
|             { | ||||
|             DOCUMENTS_IDS_BYTE | WORD_DOCIDS_BYTE | WORD_DOCID_POSITIONS_BYTE => { | ||||
|                 let (head, tail) = values.split_first().unwrap(); | ||||
|  | ||||
|                 let mut head = RoaringBitmap::deserialize_from(head.as_slice()).unwrap(); | ||||
| @@ -427,24 +345,14 @@ fn lmdb_writer(wtxn: &mut heed::RwTxn, index: &Index, key: &[u8], val: &[u8]) -> | ||||
|         // Write the documents ids list | ||||
|         index.main.put::<_, Str, ByteSlice>(wtxn, "documents-ids", val)?; | ||||
|     } | ||||
|     else if key.starts_with(&[WORD_POSITIONS_BYTE]) { | ||||
|     else if key.starts_with(&[WORD_DOCIDS_BYTE]) { | ||||
|         // Write the postings lists | ||||
|         index.word_positions.as_polymorph() | ||||
|         index.word_docids.as_polymorph() | ||||
|             .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; | ||||
|     } | ||||
|     else if key.starts_with(&[WORD_POSITION_DOCIDS_BYTE]) { | ||||
|     else if key.starts_with(&[WORD_DOCID_POSITIONS_BYTE]) { | ||||
|         // Write the postings lists | ||||
|         index.word_position_docids.as_polymorph() | ||||
|             .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; | ||||
|     } | ||||
|     else if key.starts_with(&[WORD_FOUR_POSITIONS_DOCIDS_BYTE]) { | ||||
|         // Write the postings lists | ||||
|         index.word_four_positions_docids.as_polymorph() | ||||
|             .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; | ||||
|     } | ||||
|     else if key.starts_with(&[WORD_ATTRIBUTE_DOCIDS_BYTE]) { | ||||
|         // Write the attribute postings lists | ||||
|         index.word_attribute_docids.as_polymorph() | ||||
|         index.word_docid_positions.as_polymorph() | ||||
|             .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; | ||||
|     } | ||||
|  | ||||
| @@ -499,6 +407,7 @@ fn index_csv( | ||||
|     let mut before = Instant::now(); | ||||
|     let mut document_id: usize = 0; | ||||
|     let mut document = csv::StringRecord::new(); | ||||
|     let mut word_positions = HashMap::new(); | ||||
|     while rdr.read_record(&mut document)? { | ||||
|  | ||||
|         // We skip documents that must not be indexed by this thread. | ||||
| @@ -512,14 +421,15 @@ fn index_csv( | ||||
|             let document_id = DocumentId::try_from(document_id).context("generated id is too big")?; | ||||
|             for (attr, content) in document.iter().enumerate().take(MAX_ATTRIBUTES) { | ||||
|                 for (pos, (_, token)) in simple_tokenizer(&content).filter(only_words).enumerate().take(MAX_POSITION) { | ||||
|                     let word = token.cow_to_lowercase(); | ||||
|                     let word = token.to_lowercase(); | ||||
|                     let position = (attr * MAX_POSITION + pos) as u32; | ||||
|                     store.insert_word_position_docid(&word, position, document_id)?; | ||||
|                     store.insert_word_docid(&word, document_id)?; | ||||
|                     word_positions.entry(word).or_insert_with(RoaringBitmap::new).insert(position); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // We write the document in the database. | ||||
|             store.write_document(document_id, &document)?; | ||||
|             store.write_document(document_id, word_positions.drain(), &document)?; | ||||
|         } | ||||
|  | ||||
|         // Compute the document id of the the next document. | ||||
|   | ||||
| @@ -8,8 +8,7 @@ impl<'a> heed::BytesDecode<'a> for StrBEU32Codec { | ||||
|     type DItem = (&'a str, u32); | ||||
|  | ||||
|     fn bytes_decode(bytes: &'a [u8]) -> Option<Self::DItem> { | ||||
|         let str_len = bytes.len().checked_sub(4)?; | ||||
|         let (str_bytes, n_bytes) = bytes.split_at(str_len); | ||||
|         let (str_bytes, n_bytes) = bytes.split_at(bytes.len() - 4); | ||||
|         let s = str::from_utf8(str_bytes).ok()?; | ||||
|         let n = n_bytes.try_into().map(u32::from_be_bytes).ok()?; | ||||
|         Some((s, n)) | ||||
|   | ||||
							
								
								
									
										21
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								src/lib.rs
									
									
									
									
									
								
							| @@ -1,5 +1,4 @@ | ||||
| mod criterion; | ||||
| mod node; | ||||
| mod query_tokens; | ||||
| mod search; | ||||
| pub mod heed_codec; | ||||
| @@ -16,7 +15,7 @@ use heed::{PolyDatabase, Database}; | ||||
|  | ||||
| pub use self::search::{Search, SearchResult}; | ||||
| pub use self::criterion::{Criterion, default_criteria}; | ||||
| use self::heed_codec::{RoaringBitmapCodec, StrBEU32Codec, CsvStringRecordCodec}; | ||||
| pub use self::heed_codec::{RoaringBitmapCodec, StrBEU32Codec, CsvStringRecordCodec}; | ||||
|  | ||||
| pub type FastMap4<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher32>>; | ||||
| pub type FastMap8<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher64>>; | ||||
| @@ -36,14 +35,10 @@ const DOCUMENTS_IDS_KEY: &str = "documents-ids"; | ||||
| pub struct Index { | ||||
|     /// Contains many different types (e.g. the documents CSV headers). | ||||
|     pub main: PolyDatabase, | ||||
|     /// A word and all the positions where it appears in the whole dataset. | ||||
|     pub word_positions: Database<Str, RoaringBitmapCodec>, | ||||
|     /// Maps a word at a position (u32) and all the documents ids where the given word appears. | ||||
|     pub word_position_docids: Database<StrBEU32Codec, RoaringBitmapCodec>, | ||||
|     /// Maps a word and a range of 4 positions, i.e. 0..4, 4..8, 12..16. | ||||
|     pub word_four_positions_docids: Database<StrBEU32Codec, RoaringBitmapCodec>, | ||||
|     /// Maps a word and an attribute (u32) to all the documents ids where the given word appears. | ||||
|     pub word_attribute_docids: Database<StrBEU32Codec, RoaringBitmapCodec>, | ||||
|     /// A word and all the documents ids containing the word. | ||||
|     pub word_docids: Database<Str, RoaringBitmapCodec>, | ||||
|     /// Maps a word and a document id (u32) to all the positions where the given word appears. | ||||
|     pub word_docid_positions: Database<StrBEU32Codec, RoaringBitmapCodec>, | ||||
|     /// Maps the document id to the document as a CSV line. | ||||
|     pub documents: Database<OwnedType<BEU32>, ByteSlice>, | ||||
| } | ||||
| @@ -52,10 +47,8 @@ impl Index { | ||||
|     pub fn new(env: &heed::Env) -> anyhow::Result<Index> { | ||||
|         Ok(Index { | ||||
|             main: env.create_poly_database(None)?, | ||||
|             word_positions: env.create_database(Some("word-positions"))?, | ||||
|             word_position_docids: env.create_database(Some("word-position-docids"))?, | ||||
|             word_four_positions_docids: env.create_database(Some("word-four-positions-docids"))?, | ||||
|             word_attribute_docids: env.create_database(Some("word-attribute-docids"))?, | ||||
|             word_docids: env.create_database(Some("word-docids"))?, | ||||
|             word_docid_positions: env.create_database(Some("word-docid-positions"))?, | ||||
|             documents: env.create_database(Some("documents"))?, | ||||
|         }) | ||||
|     } | ||||
|   | ||||
							
								
								
									
										109
									
								
								src/node.rs
									
									
									
									
									
								
							
							
						
						
									
										109
									
								
								src/node.rs
									
									
									
									
									
								
							| @@ -1,109 +0,0 @@ | ||||
| use std::cmp; | ||||
| use roaring::RoaringBitmap; | ||||
|  | ||||
| const ONE_ATTRIBUTE: u32 = 1000; | ||||
| const MAX_DISTANCE: u32 = 8; | ||||
|  | ||||
| fn index_proximity(lhs: u32, rhs: u32) -> u32 { | ||||
|     if lhs <= rhs { | ||||
|         cmp::min(rhs - lhs, MAX_DISTANCE) | ||||
|     } else { | ||||
|         cmp::min((lhs - rhs) + 1, MAX_DISTANCE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 { | ||||
|     let (lhs_attr, lhs_index) = extract_position(lhs); | ||||
|     let (rhs_attr, rhs_index) = extract_position(rhs); | ||||
|     if lhs_attr != rhs_attr { MAX_DISTANCE } | ||||
|     else { index_proximity(lhs_index, rhs_index) } | ||||
| } | ||||
|  | ||||
| // Returns the attribute and index parts. | ||||
| pub fn extract_position(position: u32) -> (u32, u32) { | ||||
|     (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE) | ||||
| } | ||||
|  | ||||
| // Returns the group of four positions in which this position reside (i.e. 0, 4, 12). | ||||
| pub fn group_of_four(position: u32) -> u32 { | ||||
|     position - position % 4 | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] | ||||
| pub enum Node { | ||||
|     // Is this node is the first node. | ||||
|     Uninit, | ||||
|     Init { | ||||
|         // The layer where this node located. | ||||
|         layer: usize, | ||||
|         // The position where this node is located. | ||||
|         position: u32, | ||||
|         // The parent position from the above layer. | ||||
|         parent_position: u32, | ||||
|     }, | ||||
| } | ||||
|  | ||||
| impl Node { | ||||
|     // TODO we must skip the successors that have already been seen | ||||
|     // TODO we must skip the successors that doesn't return any documents | ||||
|     //      this way we are able to skip entire paths | ||||
|     pub fn successors<F>(&self, positions: &[RoaringBitmap], contains_documents: &mut F) -> Vec<(Node, u32)> | ||||
|     where F: FnMut((usize, u32), (usize, u32)) -> bool, | ||||
|     { | ||||
|         match self { | ||||
|             Node::Uninit => { | ||||
|                 positions[0].iter().map(|position| { | ||||
|                     (Node::Init { layer: 0, position, parent_position: 0 }, 0) | ||||
|                 }).collect() | ||||
|             }, | ||||
|             // We reached the highest layer | ||||
|             n @ Node::Init { .. } if n.is_complete(positions) => vec![], | ||||
|             Node::Init { layer, position, .. } => { | ||||
|                 positions[layer + 1].iter().filter_map(|p| { | ||||
|                     let proximity = positions_proximity(*position, p); | ||||
|                     let node = Node::Init { | ||||
|                         layer: layer + 1, | ||||
|                         position: p, | ||||
|                         parent_position: *position, | ||||
|                     }; | ||||
|                     // We do not produce the nodes we have already seen in previous iterations loops. | ||||
|                     if node.is_reachable(contains_documents) { | ||||
|                         Some((node, proximity)) | ||||
|                     } else { | ||||
|                         None | ||||
|                     } | ||||
|                 }).collect() | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn is_complete(&self, positions: &[RoaringBitmap]) -> bool { | ||||
|         match self { | ||||
|             Node::Uninit => false, | ||||
|             Node::Init { layer, .. } => *layer == positions.len() - 1, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn position(&self) -> Option<u32> { | ||||
|         match self { | ||||
|             Node::Uninit => None, | ||||
|             Node::Init { position, .. } => Some(*position), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn is_reachable<F>(&self, contains_documents: &mut F) -> bool | ||||
|     where F: FnMut((usize, u32), (usize, u32)) -> bool, | ||||
|     { | ||||
|         match self { | ||||
|             Node::Uninit => true, | ||||
|             Node::Init { layer, position, parent_position, .. } => { | ||||
|                 match layer.checked_sub(1) { | ||||
|                     Some(parent_layer) => { | ||||
|                         (contains_documents)((parent_layer, *parent_position), (*layer, *position)) | ||||
|                     }, | ||||
|                     None => true, | ||||
|                 } | ||||
|             }, | ||||
|         } | ||||
|     } | ||||
| } | ||||
							
								
								
									
										311
									
								
								src/search.rs
									
									
									
									
									
								
							
							
						
						
									
										311
									
								
								src/search.rs
									
									
									
									
									
								
							| @@ -1,8 +1,5 @@ | ||||
| use std::cell::RefCell; | ||||
| use std::collections::{HashMap, HashSet}; | ||||
| use std::rc::Rc; | ||||
|  | ||||
| use astar_iter::AstarBagIter; | ||||
| use fst::{IntoStreamer, Streamer}; | ||||
| use levenshtein_automata::DFA; | ||||
| use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; | ||||
| @@ -10,7 +7,6 @@ use log::debug; | ||||
| use once_cell::sync::Lazy; | ||||
| use roaring::RoaringBitmap; | ||||
|  | ||||
| use crate::node::{self, Node}; | ||||
| use crate::query_tokens::{QueryTokens, QueryToken}; | ||||
| use crate::{Index, DocumentId, Position, Attribute}; | ||||
|  | ||||
| @@ -86,69 +82,52 @@ impl<'a> Search<'a> { | ||||
|         .collect() | ||||
|     } | ||||
|  | ||||
|     /// Fetch the words from the given FST related to the given DFAs along with the associated | ||||
|     /// positions and the unions of those positions where the words found appears in the documents. | ||||
|     fn fetch_words_positions( | ||||
|     /// Fetch the words from the given FST related to the | ||||
|     /// given DFAs along with the associated documents ids. | ||||
|     fn fetch_words_docids( | ||||
|         rtxn: &heed::RoTxn, | ||||
|         index: &Index, | ||||
|         fst: &fst::Set<&[u8]>, | ||||
|         dfas: Vec<(String, bool, DFA)>, | ||||
|     ) -> anyhow::Result<(Vec<Vec<(String, u8, RoaringBitmap)>>, Vec<RoaringBitmap>)> | ||||
|     ) -> anyhow::Result<Vec<(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)>> | ||||
|     { | ||||
|         // A Vec storing all the derived words from the original query words, associated | ||||
|         // with the distance from the original word and the positions it appears at. | ||||
|         // The index the derived words appears in the Vec corresponds to the original query | ||||
|         // word position. | ||||
|         let mut derived_words = Vec::<Vec::<(String, u8, RoaringBitmap)>>::with_capacity(dfas.len()); | ||||
|         // A Vec storing the unions of all of each of the derived words positions. The index | ||||
|         // the union appears in the Vec corresponds to the original query word position. | ||||
|         let mut union_positions = Vec::<RoaringBitmap>::with_capacity(dfas.len()); | ||||
|         // with the distance from the original word and the docids where the words appears. | ||||
|         let mut derived_words = Vec::<(HashMap::<String, (u8, RoaringBitmap)>, RoaringBitmap)>::with_capacity(dfas.len()); | ||||
|  | ||||
|         for (_word, _is_prefix, dfa) in dfas { | ||||
|  | ||||
|             let mut acc_derived_words = Vec::new(); | ||||
|             let mut acc_union_positions = RoaringBitmap::new(); | ||||
|             let mut acc_derived_words = HashMap::new(); | ||||
|             let mut unions_docids = RoaringBitmap::new(); | ||||
|             let mut stream = fst.search_with_state(&dfa).into_stream(); | ||||
|             while let Some((word, state)) = stream.next() { | ||||
|  | ||||
|                 let word = std::str::from_utf8(word)?; | ||||
|                 let positions = index.word_positions.get(rtxn, word)?.unwrap(); | ||||
|                 let docids = index.word_docids.get(rtxn, word)?.unwrap(); | ||||
|                 let distance = dfa.distance(state); | ||||
|                 acc_union_positions.union_with(&positions); | ||||
|                 acc_derived_words.push((word.to_string(), distance.to_u8(), positions)); | ||||
|                 unions_docids.union_with(&docids); | ||||
|                 acc_derived_words.insert(word.to_string(), (distance.to_u8(), docids)); | ||||
|             } | ||||
|             derived_words.push(acc_derived_words); | ||||
|             union_positions.push(acc_union_positions); | ||||
|             derived_words.push((acc_derived_words, unions_docids)); | ||||
|         } | ||||
|  | ||||
|         Ok((derived_words, union_positions)) | ||||
|         Ok(derived_words) | ||||
|     } | ||||
|  | ||||
|     /// Returns the set of docids that contains all of the query words. | ||||
|     fn compute_candidates( | ||||
|         rtxn: &heed::RoTxn, | ||||
|         index: &Index, | ||||
|         derived_words: &[Vec<(String, u8, RoaringBitmap)>], | ||||
|         derived_words: &[(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)], | ||||
|     ) -> anyhow::Result<RoaringBitmap> | ||||
|     { | ||||
|         // we do a union between all the docids of each of the derived words, | ||||
|         // we got N unions (the number of original query words), we then intersect them. | ||||
|         // TODO we must store the words documents ids to avoid these unions. | ||||
|         let mut candidates = RoaringBitmap::new(); | ||||
|         let number_of_attributes = index.number_of_attributes(rtxn)?.map_or(0, |n| n as u32); | ||||
|  | ||||
|         for (i, derived_words) in derived_words.iter().enumerate() { | ||||
|             let mut union_docids = RoaringBitmap::new(); | ||||
|             for (word, _distance, _positions) in derived_words { | ||||
|                 for attr in 0..number_of_attributes { | ||||
|                     if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attr))? { | ||||
|                         union_docids.union_with(&docids); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|         for (i, (_, union_docids)) in derived_words.iter().enumerate() { | ||||
|             if i == 0 { | ||||
|                 candidates = union_docids; | ||||
|                 candidates = union_docids.clone(); | ||||
|             } else { | ||||
|                 candidates.intersect_with(&union_docids); | ||||
|             } | ||||
| @@ -157,161 +136,6 @@ impl<'a> Search<'a> { | ||||
|         Ok(candidates) | ||||
|     } | ||||
|  | ||||
|     /// Returns the union of the same position for all the given words. | ||||
|     fn union_word_position( | ||||
|         rtxn: &heed::RoTxn, | ||||
|         index: &Index, | ||||
|         words: &[(String, u8, RoaringBitmap)], | ||||
|         position: Position, | ||||
|     ) -> anyhow::Result<RoaringBitmap> | ||||
|     { | ||||
|         let mut union_docids = RoaringBitmap::new(); | ||||
|         for (word, _distance, positions) in words { | ||||
|             if positions.contains(position) { | ||||
|                 if let Some(docids) = index.word_position_docids.get(rtxn, &(word, position))? { | ||||
|                     union_docids.union_with(&docids); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         Ok(union_docids) | ||||
|     } | ||||
|  | ||||
|     /// Returns the union of the same gorup of four positions for all the given words. | ||||
|     fn union_word_four_positions( | ||||
|         rtxn: &heed::RoTxn, | ||||
|         index: &Index, | ||||
|         words: &[(String, u8, RoaringBitmap)], | ||||
|         group: Position, | ||||
|     ) -> anyhow::Result<RoaringBitmap> | ||||
|     { | ||||
|         let mut union_docids = RoaringBitmap::new(); | ||||
|         for (word, _distance, _positions) in words { | ||||
|             // TODO would be better to check if the group exist | ||||
|             if let Some(docids) = index.word_four_positions_docids.get(rtxn, &(word, group))? { | ||||
|                 union_docids.union_with(&docids); | ||||
|             } | ||||
|         } | ||||
|         Ok(union_docids) | ||||
|     } | ||||
|  | ||||
|     /// Returns the union of the same attribute for all the given words. | ||||
|     fn union_word_attribute( | ||||
|         rtxn: &heed::RoTxn, | ||||
|         index: &Index, | ||||
|         words: &[(String, u8, RoaringBitmap)], | ||||
|         attribute: Attribute, | ||||
|     ) -> anyhow::Result<RoaringBitmap> | ||||
|     { | ||||
|         let mut union_docids = RoaringBitmap::new(); | ||||
|         for (word, _distance, _positions) in words { | ||||
|             if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attribute))? { | ||||
|                 union_docids.union_with(&docids); | ||||
|             } | ||||
|         } | ||||
|         Ok(union_docids) | ||||
|     } | ||||
|  | ||||
|     // Returns `true` if there is documents in common between the two words and positions given. | ||||
|     fn contains_documents( | ||||
|         rtxn: &heed::RoTxn, | ||||
|         index: &Index, | ||||
|         (lword, lpos): (usize, u32), | ||||
|         (rword, rpos): (usize, u32), | ||||
|         candidates: &RoaringBitmap, | ||||
|         derived_words: &[Vec<(String, u8, RoaringBitmap)>], | ||||
|         union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, | ||||
|         non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, | ||||
|         group_four_union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, | ||||
|         group_four_non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, | ||||
|         attribute_union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, | ||||
|         attribute_non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, | ||||
|     ) -> bool | ||||
|     { | ||||
|         if lpos == rpos { return false } | ||||
|  | ||||
|         // TODO move this function to a better place. | ||||
|         let (lattr, _) = node::extract_position(lpos); | ||||
|         let (rattr, _) = node::extract_position(rpos); | ||||
|  | ||||
|         if lattr == rattr { | ||||
|             // TODO move this function to a better place. | ||||
|             let lgroup = node::group_of_four(lpos); | ||||
|             let rgroup = node::group_of_four(rpos); | ||||
|  | ||||
|             // We can't compute a disjunction on a group of four positions if those | ||||
|             // two positions are in the same group, we must go down to the position. | ||||
|             if lgroup == rgroup { | ||||
|                 // We retrieve or compute the intersection between the two given words and positions. | ||||
|                 *non_disjoint_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { | ||||
|                     // We retrieve or compute the unions for the two words and positions. | ||||
|                     union_cache.entry((lword, lpos)).or_insert_with(|| { | ||||
|                         let words = &derived_words[lword]; | ||||
|                         Self::union_word_position(rtxn, index, words, lpos).unwrap() | ||||
|                     }); | ||||
|                     union_cache.entry((rword, rpos)).or_insert_with(|| { | ||||
|                         let words = &derived_words[rword]; | ||||
|                         Self::union_word_position(rtxn, index, words, rpos).unwrap() | ||||
|                     }); | ||||
|  | ||||
|                     // TODO is there a way to avoid this double gets? | ||||
|                     let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); | ||||
|                     let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); | ||||
|  | ||||
|                     // We first check that the docids of these unions are part of the candidates. | ||||
|                     if lunion_docids.is_disjoint(candidates) { return false } | ||||
|                     if runion_docids.is_disjoint(candidates) { return false } | ||||
|  | ||||
|                     !lunion_docids.is_disjoint(&runion_docids) | ||||
|                 }) | ||||
|             } else { | ||||
|                 // We retrieve or compute the intersection between the two given words and positions. | ||||
|                 *group_four_non_disjoint_cache.entry(((lword, lgroup), (rword, rgroup))).or_insert_with(|| { | ||||
|                     // We retrieve or compute the unions for the two words and group of four positions. | ||||
|                     group_four_union_cache.entry((lword, lgroup)).or_insert_with(|| { | ||||
|                         let words = &derived_words[lword]; | ||||
|                         Self::union_word_four_positions(rtxn, index, words, lgroup).unwrap() | ||||
|                     }); | ||||
|                     group_four_union_cache.entry((rword, rgroup)).or_insert_with(|| { | ||||
|                         let words = &derived_words[rword]; | ||||
|                         Self::union_word_four_positions(rtxn, index, words, rgroup).unwrap() | ||||
|                     }); | ||||
|  | ||||
|                     // TODO is there a way to avoid this double gets? | ||||
|                     let lunion_group_docids = group_four_union_cache.get(&(lword, lgroup)).unwrap(); | ||||
|                     let runion_group_docids = group_four_union_cache.get(&(rword, rgroup)).unwrap(); | ||||
|  | ||||
|                     // We first check that the docids of these unions are part of the candidates. | ||||
|                     if lunion_group_docids.is_disjoint(candidates) { return false } | ||||
|                     if runion_group_docids.is_disjoint(candidates) { return false } | ||||
|  | ||||
|                     !lunion_group_docids.is_disjoint(&runion_group_docids) | ||||
|                 }) | ||||
|             } | ||||
|         } else { | ||||
|             *attribute_non_disjoint_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| { | ||||
|                 // We retrieve or compute the unions for the two words and positions. | ||||
|                 attribute_union_cache.entry((lword, lattr)).or_insert_with(|| { | ||||
|                     let words = &derived_words[lword]; | ||||
|                     Self::union_word_attribute(rtxn, index, words, lattr).unwrap() | ||||
|                 }); | ||||
|                 attribute_union_cache.entry((rword, rattr)).or_insert_with(|| { | ||||
|                     let words = &derived_words[rword]; | ||||
|                     Self::union_word_attribute(rtxn, index, words, rattr).unwrap() | ||||
|                 }); | ||||
|  | ||||
|                 // TODO is there a way to avoid this double gets? | ||||
|                 let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap(); | ||||
|                 let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap(); | ||||
|  | ||||
|                 // We first check that the docids of these unions are part of the candidates. | ||||
|                 if lunion_docids.is_disjoint(candidates) { return false } | ||||
|                 if runion_docids.is_disjoint(candidates) { return false } | ||||
|  | ||||
|                 !lunion_docids.is_disjoint(&runion_docids) | ||||
|             }) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn execute(&self) -> anyhow::Result<SearchResult> { | ||||
|         let rtxn = self.rtxn; | ||||
|         let index = self.index; | ||||
| @@ -333,111 +157,14 @@ impl<'a> Search<'a> { | ||||
|             return Ok(Default::default()); | ||||
|         } | ||||
|  | ||||
|         let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?; | ||||
|         let derived_words = Self::fetch_words_docids(rtxn, index, &fst, dfas)?; | ||||
|         let candidates = Self::compute_candidates(rtxn, index, &derived_words)?; | ||||
|  | ||||
|         debug!("candidates: {:?}", candidates); | ||||
|  | ||||
|         let union_cache = HashMap::new(); | ||||
|         let mut non_disjoint_cache = HashMap::new(); | ||||
|         let documents = vec![candidates]; | ||||
|  | ||||
|         let mut group_four_union_cache = HashMap::new(); | ||||
|         let mut group_four_non_disjoint_cache = HashMap::new(); | ||||
|  | ||||
|         let mut attribute_union_cache = HashMap::new(); | ||||
|         let mut attribute_non_disjoint_cache = HashMap::new(); | ||||
|  | ||||
|         let candidates = Rc::new(RefCell::new(candidates)); | ||||
|         let union_cache = Rc::new(RefCell::new(union_cache)); | ||||
|  | ||||
|         let candidates_cloned = candidates.clone(); | ||||
|         let union_cache_cloned = union_cache.clone(); | ||||
|         let mut contains_documents = |left, right| { | ||||
|             Self::contains_documents( | ||||
|                 rtxn, index, | ||||
|                 left, right, | ||||
|                 &candidates_cloned.borrow(), | ||||
|                 &derived_words, | ||||
|                 &mut union_cache_cloned.borrow_mut(), | ||||
|                 &mut non_disjoint_cache, | ||||
|                 &mut group_four_union_cache, | ||||
|                 &mut group_four_non_disjoint_cache, | ||||
|                 &mut attribute_union_cache, | ||||
|                 &mut attribute_non_disjoint_cache, | ||||
|             ) | ||||
|         }; | ||||
|  | ||||
|         let astar_iter = AstarBagIter::new( | ||||
|             Node::Uninit, // start | ||||
|             |n| n.successors(&union_positions, &mut contains_documents), // successors | ||||
|             |_| 0, // heuristic | ||||
|             |n| n.is_complete(&union_positions), // success | ||||
|         ); | ||||
|  | ||||
|         let mut documents = Vec::new(); | ||||
|         for (paths, proximity) in astar_iter { | ||||
|             let mut union_cache = union_cache.borrow_mut(); | ||||
|             let mut candidates = candidates.borrow_mut(); | ||||
|  | ||||
|             let mut positions: Vec<Vec<_>> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect(); | ||||
|             positions.sort_unstable(); | ||||
|  | ||||
|             debug!("Found {} positions with a proximity of {}", positions.len(), proximity); | ||||
|  | ||||
|             let mut same_proximity_union = RoaringBitmap::default(); | ||||
|             for positions in positions { | ||||
|                 // Precompute the potentially missing unions | ||||
|                 positions.iter().enumerate().for_each(|(word, pos)| { | ||||
|                     union_cache.entry((word, *pos)).or_insert_with(|| { | ||||
|                         let words = &&derived_words[word]; | ||||
|                         Self::union_word_position(rtxn, index, words, *pos).unwrap() | ||||
|                     }); | ||||
|                 }); | ||||
|  | ||||
|                 // Retrieve the unions along with the popularity of it. | ||||
|                 let mut to_intersect = Vec::new(); | ||||
|                 for (word, pos) in positions.into_iter().enumerate() { | ||||
|                     let docids = union_cache.get(&(word, pos)).unwrap(); | ||||
|                     to_intersect.push((docids.len(), docids)); | ||||
|                 } | ||||
|  | ||||
|                 // Sort the unions by popularity to help reduce | ||||
|                 // the number of documents as soon as possible. | ||||
|                 to_intersect.sort_unstable_by_key(|(l, _)| *l); | ||||
|  | ||||
|                 // Intersect all the unions in the inverse popularity order. | ||||
|                 let mut intersect_docids = RoaringBitmap::new(); | ||||
|                 for (i, (_, union_docids)) in to_intersect.into_iter().enumerate() { | ||||
|                     if i == 0 { | ||||
|                         intersect_docids = union_docids.clone(); | ||||
|                     } else { | ||||
|                         intersect_docids.intersect_with(union_docids); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 same_proximity_union.union_with(&intersect_docids); | ||||
|             } | ||||
|  | ||||
|             // We achieve to find valid documents ids so we remove them from the candidates list. | ||||
|             candidates.difference_with(&same_proximity_union); | ||||
|  | ||||
|             // We remove documents we have already been seen in previous | ||||
|             // fetches from this set of documents we just fetched. | ||||
|             for previous_documents in &documents { | ||||
|                 same_proximity_union.difference_with(previous_documents); | ||||
|             } | ||||
|  | ||||
|             if !same_proximity_union.is_empty() { | ||||
|                 documents.push(same_proximity_union); | ||||
|             } | ||||
|  | ||||
|             // We found enough documents we can stop here. | ||||
|             if documents.iter().map(RoaringBitmap::len).sum::<u64>() >= limit as u64 { | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let found_words = derived_words.into_iter().flatten().map(|(w, _, _)| w).collect(); | ||||
|         let found_words = derived_words.into_iter().flat_map(|(w, _)| w).map(|(w, _)| w).collect(); | ||||
|         let documents_ids = documents.iter().flatten().take(limit).collect(); | ||||
|  | ||||
|         Ok(SearchResult { found_words, documents_ids }) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user