Store the word positions under the documents

This commit is contained in:
Clément Renault
2020-09-05 18:03:06 +02:00
parent 580ed1119a
commit dc88a86259
7 changed files with 72 additions and 563 deletions

10
Cargo.lock generated
View File

@@ -85,15 +85,6 @@ dependencies = [
"warp", "warp",
] ]
[[package]]
name = "astar-iter"
version = "0.1.0"
source = "git+https://github.com/Kerollmops/astar-iter#87cb97a11c701f1a6025b72b673a8bfd0ca249a5"
dependencies = [
"indexmap",
"num-traits",
]
[[package]] [[package]]
name = "atty" name = "atty"
version = "0.2.11" version = "0.2.11"
@@ -990,7 +981,6 @@ dependencies = [
"arc-cache", "arc-cache",
"askama", "askama",
"askama_warp", "askama_warp",
"astar-iter",
"bitpacking", "bitpacking",
"bstr", "bstr",
"byteorder", "byteorder",

View File

@@ -8,7 +8,6 @@ default-run = "indexer"
[dependencies] [dependencies]
anyhow = "1.0.28" anyhow = "1.0.28"
arc-cache = { git = "https://github.com/Kerollmops/rust-arc-cache.git", rev = "56530f2" } arc-cache = { git = "https://github.com/Kerollmops/rust-arc-cache.git", rev = "56530f2" }
astar-iter = { git = "https://github.com/Kerollmops/astar-iter" }
bitpacking = "0.8.2" bitpacking = "0.8.2"
bstr = "0.2.13" bstr = "0.2.13"
byteorder = "1.3.4" byteorder = "1.3.4"

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::fs::File; use std::fs::File;
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
@@ -13,6 +14,7 @@ use cow_utils::CowUtils;
use csv::StringRecord; use csv::StringRecord;
use flate2::read::GzDecoder; use flate2::read::GzDecoder;
use fst::IntoStreamer; use fst::IntoStreamer;
use heed::BytesDecode;
use heed::BytesEncode; use heed::BytesEncode;
use heed::EnvOpenOptions; use heed::EnvOpenOptions;
use heed::types::*; use heed::types::*;
@@ -25,7 +27,7 @@ use structopt::StructOpt;
use milli::heed_codec::CsvStringRecordCodec; use milli::heed_codec::CsvStringRecordCodec;
use milli::tokenizer::{simple_tokenizer, only_words}; use milli::tokenizer::{simple_tokenizer, only_words};
use milli::{SmallVec32, Index, DocumentId, Position, Attribute, BEU32}; use milli::{SmallVec32, Index, DocumentId, BEU32, StrBEU32Codec};
const LMDB_MAX_KEY_LENGTH: usize = 511; const LMDB_MAX_KEY_LENGTH: usize = 511;
const ONE_MILLION: usize = 1_000_000; const ONE_MILLION: usize = 1_000_000;
@@ -37,10 +39,8 @@ const HEADERS_KEY: &[u8] = b"\0headers";
const DOCUMENTS_IDS_KEY: &[u8] = b"\x04documents-ids"; const DOCUMENTS_IDS_KEY: &[u8] = b"\x04documents-ids";
const WORDS_FST_KEY: &[u8] = b"\x06words-fst"; const WORDS_FST_KEY: &[u8] = b"\x06words-fst";
const DOCUMENTS_IDS_BYTE: u8 = 4; const DOCUMENTS_IDS_BYTE: u8 = 4;
const WORD_ATTRIBUTE_DOCIDS_BYTE: u8 = 3; const WORD_DOCIDS_BYTE: u8 = 2;
const WORD_FOUR_POSITIONS_DOCIDS_BYTE: u8 = 5; const WORD_DOCID_POSITIONS_BYTE: u8 = 1;
const WORD_POSITION_DOCIDS_BYTE: u8 = 2;
const WORD_POSITIONS_BYTE: u8 = 1;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
#[global_allocator] #[global_allocator]
@@ -125,10 +125,7 @@ fn lmdb_key_valid_size(key: &[u8]) -> bool {
type MergeFn = fn(&[u8], &[Vec<u8>]) -> Result<Vec<u8>, ()>; type MergeFn = fn(&[u8], &[Vec<u8>]) -> Result<Vec<u8>, ()>;
struct Store { struct Store {
word_positions: ArcCache<SmallVec32<u8>, RoaringBitmap>, word_docids: ArcCache<SmallVec32<u8>, RoaringBitmap>,
word_position_docids: ArcCache<(SmallVec32<u8>, Position), RoaringBitmap>,
word_four_positions_docids: ArcCache<(SmallVec32<u8>, Position), RoaringBitmap>,
word_attribute_docids: ArcCache<(SmallVec32<u8>, Attribute), RoaringBitmap>,
documents_ids: RoaringBitmap, documents_ids: RoaringBitmap,
sorter: Sorter<MergeFn>, sorter: Sorter<MergeFn>,
documents_sorter: Sorter<MergeFn>, documents_sorter: Sorter<MergeFn>,
@@ -162,10 +159,7 @@ impl Store {
} }
Store { Store {
word_positions: ArcCache::new(arc_cache_size), word_docids: ArcCache::new(arc_cache_size),
word_position_docids: ArcCache::new(arc_cache_size),
word_four_positions_docids: ArcCache::new(arc_cache_size),
word_attribute_docids: ArcCache::new(arc_cache_size),
documents_ids: RoaringBitmap::new(), documents_ids: RoaringBitmap::new(),
sorter: builder.build(), sorter: builder.build(),
documents_sorter: documents_builder.build(), documents_sorter: documents_builder.build(),
@@ -173,65 +167,48 @@ impl Store {
} }
// Save the documents ids under the position and word we have seen it. // Save the documents ids under the position and word we have seen it.
pub fn insert_word_position_docid(&mut self, word: &str, position: Position, id: DocumentId) -> anyhow::Result<()> { pub fn insert_word_docid(&mut self, word: &str, id: DocumentId) -> anyhow::Result<()> {
let word_vec = SmallVec32::from(word.as_bytes()); let word_vec = SmallVec32::from(word.as_bytes());
let ids = RoaringBitmap::from_iter(Some(id)); let ids = RoaringBitmap::from_iter(Some(id));
let (_, lrus) = self.word_position_docids.insert((word_vec, position), ids, |old, new| old.union_with(&new)); let (_, lrus) = self.word_docids.insert(word_vec, ids, |old, new| old.union_with(&new));
Self::write_word_position_docids(&mut self.sorter, lrus)?; Self::write_word_docids(&mut self.sorter, lrus)?;
self.insert_word_position(word, position)?;
self.insert_word_four_positions_docid(word, position, id)?;
self.insert_word_attribute_docid(word, position / MAX_POSITION as u32, id)?;
Ok(()) Ok(())
} }
pub fn insert_word_four_positions_docid(&mut self, word: &str, position: Position, id: DocumentId) -> anyhow::Result<()> {
let position = position - position % 4;
let word_vec = SmallVec32::from(word.as_bytes());
let ids = RoaringBitmap::from_iter(Some(id));
let (_, lrus) = self.word_four_positions_docids.insert((word_vec, position), ids, |old, new| old.union_with(&new));
Self::write_word_four_positions_docids(&mut self.sorter, lrus)
}
// Save the positions where this word has been seen.
pub fn insert_word_position(&mut self, word: &str, position: Position) -> anyhow::Result<()> {
let word = SmallVec32::from(word.as_bytes());
let position = RoaringBitmap::from_iter(Some(position));
let (_, lrus) = self.word_positions.insert(word, position, |old, new| old.union_with(&new));
Self::write_word_positions(&mut self.sorter, lrus)
}
// Save the documents ids under the attribute and word we have seen it.
fn insert_word_attribute_docid(&mut self, word: &str, attribute: Attribute, id: DocumentId) -> anyhow::Result<()> {
let word = SmallVec32::from(word.as_bytes());
let ids = RoaringBitmap::from_iter(Some(id));
let (_, lrus) = self.word_attribute_docids.insert((word, attribute), ids, |old, new| old.union_with(&new));
Self::write_word_attribute_docids(&mut self.sorter, lrus)
}
pub fn write_headers(&mut self, headers: &StringRecord) -> anyhow::Result<()> { pub fn write_headers(&mut self, headers: &StringRecord) -> anyhow::Result<()> {
let headers = CsvStringRecordCodec::bytes_encode(headers) let headers = CsvStringRecordCodec::bytes_encode(headers)
.with_context(|| format!("could not encode csv record"))?; .with_context(|| format!("could not encode csv record"))?;
Ok(self.sorter.insert(HEADERS_KEY, headers)?) Ok(self.sorter.insert(HEADERS_KEY, headers)?)
} }
pub fn write_document(&mut self, id: DocumentId, record: &StringRecord) -> anyhow::Result<()> { pub fn write_document(
&mut self,
id: DocumentId,
iter: impl IntoIterator<Item=(String, RoaringBitmap)>,
record: &StringRecord,
) -> anyhow::Result<()>
{
let record = CsvStringRecordCodec::bytes_encode(record) let record = CsvStringRecordCodec::bytes_encode(record)
.with_context(|| format!("could not encode csv record"))?; .with_context(|| format!("could not encode csv record"))?;
self.documents_ids.insert(id); self.documents_ids.insert(id);
Ok(self.documents_sorter.insert(id.to_be_bytes(), record)?) self.documents_sorter.insert(id.to_be_bytes(), record)?;
Self::write_docid_word_positions(&mut self.sorter, id, iter)?;
Ok(())
} }
fn write_word_positions<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()> fn write_docid_word_positions<I>(sorter: &mut Sorter<MergeFn>, id: DocumentId, iter: I) -> anyhow::Result<()>
where I: IntoIterator<Item=(SmallVec32<u8>, RoaringBitmap)> where I: IntoIterator<Item=(String, RoaringBitmap)>
{ {
// postings ids keys are all prefixed // postings positions ids keys are all prefixed
let mut key = vec![WORD_POSITIONS_BYTE]; let mut key = vec![WORD_DOCID_POSITIONS_BYTE];
let mut buffer = Vec::new(); let mut buffer = Vec::new();
for (word, positions) in iter { for (word, positions) in iter {
key.truncate(1); key.truncate(1);
key.extend_from_slice(&word); key.extend_from_slice(word.as_bytes());
// We serialize the positions into a buffer // We prefix the words by the document id.
key.extend_from_slice(&id.to_be_bytes());
// We serialize the document ids into a buffer
buffer.clear(); buffer.clear();
buffer.reserve(positions.serialized_size()); buffer.reserve(positions.serialized_size());
positions.serialize_into(&mut buffer)?; positions.serialize_into(&mut buffer)?;
@@ -244,68 +221,16 @@ impl Store {
Ok(()) Ok(())
} }
fn write_word_position_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()> fn write_word_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()>
where I: IntoIterator<Item=((SmallVec32<u8>, Position), RoaringBitmap)> where I: IntoIterator<Item=(SmallVec32<u8>, RoaringBitmap)>
{ {
// postings positions ids keys are all prefixed // postings positions ids keys are all prefixed
let mut key = vec![WORD_POSITION_DOCIDS_BYTE]; let mut key = vec![WORD_DOCIDS_BYTE];
let mut buffer = Vec::new(); let mut buffer = Vec::new();
for ((word, pos), ids) in iter { for (word, ids) in iter {
key.truncate(1); key.truncate(1);
key.extend_from_slice(&word); key.extend_from_slice(&word);
// we postfix the word by the positions it appears in
key.extend_from_slice(&pos.to_be_bytes());
// We serialize the document ids into a buffer
buffer.clear();
buffer.reserve(ids.serialized_size());
ids.serialize_into(&mut buffer)?;
// that we write under the generated key into MTBL
if lmdb_key_valid_size(&key) {
sorter.insert(&key, &buffer)?;
}
}
Ok(())
}
fn write_word_four_positions_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()>
where I: IntoIterator<Item=((SmallVec32<u8>, Position), RoaringBitmap)>
{
// postings positions ids keys are all prefixed
let mut key = vec![WORD_FOUR_POSITIONS_DOCIDS_BYTE];
let mut buffer = Vec::new();
for ((word, pos), ids) in iter {
key.truncate(1);
key.extend_from_slice(&word);
// we postfix the word by the positions it appears in
key.extend_from_slice(&pos.to_be_bytes());
// We serialize the document ids into a buffer
buffer.clear();
buffer.reserve(ids.serialized_size());
ids.serialize_into(&mut buffer)?;
// that we write under the generated key into MTBL
if lmdb_key_valid_size(&key) {
sorter.insert(&key, &buffer)?;
}
}
Ok(())
}
fn write_word_attribute_docids<I>(sorter: &mut Sorter<MergeFn>, iter: I) -> anyhow::Result<()>
where I: IntoIterator<Item=((SmallVec32<u8>, Attribute), RoaringBitmap)>
{
// postings attributes keys are all prefixed
let mut key = vec![WORD_ATTRIBUTE_DOCIDS_BYTE];
let mut buffer = Vec::new();
for ((word, attr), ids) in iter {
key.truncate(1);
key.extend_from_slice(&word);
// we postfix the word by the positions it appears in
key.extend_from_slice(&attr.to_be_bytes());
// We serialize the document ids into a buffer // We serialize the document ids into a buffer
buffer.clear(); buffer.clear();
buffer.reserve(ids.serialized_size()); buffer.reserve(ids.serialized_size());
@@ -327,10 +252,7 @@ impl Store {
} }
pub fn finish(mut self) -> anyhow::Result<(Reader<Mmap>, Reader<Mmap>)> { pub fn finish(mut self) -> anyhow::Result<(Reader<Mmap>, Reader<Mmap>)> {
Self::write_word_positions(&mut self.sorter, self.word_positions)?; Self::write_word_docids(&mut self.sorter, self.word_docids)?;
Self::write_word_position_docids(&mut self.sorter, self.word_position_docids)?;
Self::write_word_four_positions_docids(&mut self.sorter, self.word_four_positions_docids)?;
Self::write_word_attribute_docids(&mut self.sorter, self.word_attribute_docids)?;
Self::write_documents_ids(&mut self.sorter, self.documents_ids)?; Self::write_documents_ids(&mut self.sorter, self.documents_ids)?;
let mut wtr = tempfile::tempfile().map(Writer::new)?; let mut wtr = tempfile::tempfile().map(Writer::new)?;
@@ -339,7 +261,8 @@ impl Store {
let mut iter = self.sorter.into_iter()?; let mut iter = self.sorter.into_iter()?;
while let Some(result) = iter.next() { while let Some(result) = iter.next() {
let (key, val) = result?; let (key, val) = result?;
if let Some((&1, word)) = key.split_first() { if let Some((&1, bytes)) = key.split_first() {
let (word, _docid) = StrBEU32Codec::bytes_decode(bytes).unwrap();
// This is a lexicographically ordered word position // This is a lexicographically ordered word position
// we use the key to construct the words fst. // we use the key to construct the words fst.
builder.insert(word)?; builder.insert(word)?;
@@ -389,12 +312,7 @@ fn merge(key: &[u8], values: &[Vec<u8>]) -> Result<Vec<u8>, ()> {
Ok(values[0].to_vec()) Ok(values[0].to_vec())
}, },
key => match key[0] { key => match key[0] {
DOCUMENTS_IDS_BYTE DOCUMENTS_IDS_BYTE | WORD_DOCIDS_BYTE | WORD_DOCID_POSITIONS_BYTE => {
| WORD_POSITIONS_BYTE
| WORD_POSITION_DOCIDS_BYTE
| WORD_FOUR_POSITIONS_DOCIDS_BYTE
| WORD_ATTRIBUTE_DOCIDS_BYTE =>
{
let (head, tail) = values.split_first().unwrap(); let (head, tail) = values.split_first().unwrap();
let mut head = RoaringBitmap::deserialize_from(head.as_slice()).unwrap(); let mut head = RoaringBitmap::deserialize_from(head.as_slice()).unwrap();
@@ -427,24 +345,14 @@ fn lmdb_writer(wtxn: &mut heed::RwTxn, index: &Index, key: &[u8], val: &[u8]) ->
// Write the documents ids list // Write the documents ids list
index.main.put::<_, Str, ByteSlice>(wtxn, "documents-ids", val)?; index.main.put::<_, Str, ByteSlice>(wtxn, "documents-ids", val)?;
} }
else if key.starts_with(&[WORD_POSITIONS_BYTE]) { else if key.starts_with(&[WORD_DOCIDS_BYTE]) {
// Write the postings lists // Write the postings lists
index.word_positions.as_polymorph() index.word_docids.as_polymorph()
.put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?;
} }
else if key.starts_with(&[WORD_POSITION_DOCIDS_BYTE]) { else if key.starts_with(&[WORD_DOCID_POSITIONS_BYTE]) {
// Write the postings lists // Write the postings lists
index.word_position_docids.as_polymorph() index.word_docid_positions.as_polymorph()
.put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?;
}
else if key.starts_with(&[WORD_FOUR_POSITIONS_DOCIDS_BYTE]) {
// Write the postings lists
index.word_four_positions_docids.as_polymorph()
.put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?;
}
else if key.starts_with(&[WORD_ATTRIBUTE_DOCIDS_BYTE]) {
// Write the attribute postings lists
index.word_attribute_docids.as_polymorph()
.put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?;
} }
@@ -499,6 +407,7 @@ fn index_csv(
let mut before = Instant::now(); let mut before = Instant::now();
let mut document_id: usize = 0; let mut document_id: usize = 0;
let mut document = csv::StringRecord::new(); let mut document = csv::StringRecord::new();
let mut word_positions = HashMap::new();
while rdr.read_record(&mut document)? { while rdr.read_record(&mut document)? {
// We skip documents that must not be indexed by this thread. // We skip documents that must not be indexed by this thread.
@@ -512,14 +421,15 @@ fn index_csv(
let document_id = DocumentId::try_from(document_id).context("generated id is too big")?; let document_id = DocumentId::try_from(document_id).context("generated id is too big")?;
for (attr, content) in document.iter().enumerate().take(MAX_ATTRIBUTES) { for (attr, content) in document.iter().enumerate().take(MAX_ATTRIBUTES) {
for (pos, (_, token)) in simple_tokenizer(&content).filter(only_words).enumerate().take(MAX_POSITION) { for (pos, (_, token)) in simple_tokenizer(&content).filter(only_words).enumerate().take(MAX_POSITION) {
let word = token.cow_to_lowercase(); let word = token.to_lowercase();
let position = (attr * MAX_POSITION + pos) as u32; let position = (attr * MAX_POSITION + pos) as u32;
store.insert_word_position_docid(&word, position, document_id)?; store.insert_word_docid(&word, document_id)?;
word_positions.entry(word).or_insert_with(RoaringBitmap::new).insert(position);
} }
} }
// We write the document in the database. // We write the document in the database.
store.write_document(document_id, &document)?; store.write_document(document_id, word_positions.drain(), &document)?;
} }
// Compute the document id of the the next document. // Compute the document id of the the next document.

View File

@@ -8,8 +8,7 @@ impl<'a> heed::BytesDecode<'a> for StrBEU32Codec {
type DItem = (&'a str, u32); type DItem = (&'a str, u32);
fn bytes_decode(bytes: &'a [u8]) -> Option<Self::DItem> { fn bytes_decode(bytes: &'a [u8]) -> Option<Self::DItem> {
let str_len = bytes.len().checked_sub(4)?; let (str_bytes, n_bytes) = bytes.split_at(bytes.len() - 4);
let (str_bytes, n_bytes) = bytes.split_at(str_len);
let s = str::from_utf8(str_bytes).ok()?; let s = str::from_utf8(str_bytes).ok()?;
let n = n_bytes.try_into().map(u32::from_be_bytes).ok()?; let n = n_bytes.try_into().map(u32::from_be_bytes).ok()?;
Some((s, n)) Some((s, n))

View File

@@ -1,5 +1,4 @@
mod criterion; mod criterion;
mod node;
mod query_tokens; mod query_tokens;
mod search; mod search;
pub mod heed_codec; pub mod heed_codec;
@@ -16,7 +15,7 @@ use heed::{PolyDatabase, Database};
pub use self::search::{Search, SearchResult}; pub use self::search::{Search, SearchResult};
pub use self::criterion::{Criterion, default_criteria}; pub use self::criterion::{Criterion, default_criteria};
use self::heed_codec::{RoaringBitmapCodec, StrBEU32Codec, CsvStringRecordCodec}; pub use self::heed_codec::{RoaringBitmapCodec, StrBEU32Codec, CsvStringRecordCodec};
pub type FastMap4<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher32>>; pub type FastMap4<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher32>>;
pub type FastMap8<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher64>>; pub type FastMap8<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher64>>;
@@ -36,14 +35,10 @@ const DOCUMENTS_IDS_KEY: &str = "documents-ids";
pub struct Index { pub struct Index {
/// Contains many different types (e.g. the documents CSV headers). /// Contains many different types (e.g. the documents CSV headers).
pub main: PolyDatabase, pub main: PolyDatabase,
/// A word and all the positions where it appears in the whole dataset. /// A word and all the documents ids containing the word.
pub word_positions: Database<Str, RoaringBitmapCodec>, pub word_docids: Database<Str, RoaringBitmapCodec>,
/// Maps a word at a position (u32) and all the documents ids where the given word appears. /// Maps a word and a document id (u32) to all the positions where the given word appears.
pub word_position_docids: Database<StrBEU32Codec, RoaringBitmapCodec>, pub word_docid_positions: Database<StrBEU32Codec, RoaringBitmapCodec>,
/// Maps a word and a range of 4 positions, i.e. 0..4, 4..8, 12..16.
pub word_four_positions_docids: Database<StrBEU32Codec, RoaringBitmapCodec>,
/// Maps a word and an attribute (u32) to all the documents ids where the given word appears.
pub word_attribute_docids: Database<StrBEU32Codec, RoaringBitmapCodec>,
/// Maps the document id to the document as a CSV line. /// Maps the document id to the document as a CSV line.
pub documents: Database<OwnedType<BEU32>, ByteSlice>, pub documents: Database<OwnedType<BEU32>, ByteSlice>,
} }
@@ -52,10 +47,8 @@ impl Index {
pub fn new(env: &heed::Env) -> anyhow::Result<Index> { pub fn new(env: &heed::Env) -> anyhow::Result<Index> {
Ok(Index { Ok(Index {
main: env.create_poly_database(None)?, main: env.create_poly_database(None)?,
word_positions: env.create_database(Some("word-positions"))?, word_docids: env.create_database(Some("word-docids"))?,
word_position_docids: env.create_database(Some("word-position-docids"))?, word_docid_positions: env.create_database(Some("word-docid-positions"))?,
word_four_positions_docids: env.create_database(Some("word-four-positions-docids"))?,
word_attribute_docids: env.create_database(Some("word-attribute-docids"))?,
documents: env.create_database(Some("documents"))?, documents: env.create_database(Some("documents"))?,
}) })
} }

View File

@@ -1,109 +0,0 @@
use std::cmp;
use roaring::RoaringBitmap;
const ONE_ATTRIBUTE: u32 = 1000;
const MAX_DISTANCE: u32 = 8;
fn index_proximity(lhs: u32, rhs: u32) -> u32 {
if lhs <= rhs {
cmp::min(rhs - lhs, MAX_DISTANCE)
} else {
cmp::min((lhs - rhs) + 1, MAX_DISTANCE)
}
}
pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 {
let (lhs_attr, lhs_index) = extract_position(lhs);
let (rhs_attr, rhs_index) = extract_position(rhs);
if lhs_attr != rhs_attr { MAX_DISTANCE }
else { index_proximity(lhs_index, rhs_index) }
}
// Returns the attribute and index parts.
pub fn extract_position(position: u32) -> (u32, u32) {
(position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
}
// Returns the group of four positions in which this position reside (i.e. 0, 4, 12).
pub fn group_of_four(position: u32) -> u32 {
position - position % 4
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Node {
// Is this node is the first node.
Uninit,
Init {
// The layer where this node located.
layer: usize,
// The position where this node is located.
position: u32,
// The parent position from the above layer.
parent_position: u32,
},
}
impl Node {
// TODO we must skip the successors that have already been seen
// TODO we must skip the successors that doesn't return any documents
// this way we are able to skip entire paths
pub fn successors<F>(&self, positions: &[RoaringBitmap], contains_documents: &mut F) -> Vec<(Node, u32)>
where F: FnMut((usize, u32), (usize, u32)) -> bool,
{
match self {
Node::Uninit => {
positions[0].iter().map(|position| {
(Node::Init { layer: 0, position, parent_position: 0 }, 0)
}).collect()
},
// We reached the highest layer
n @ Node::Init { .. } if n.is_complete(positions) => vec![],
Node::Init { layer, position, .. } => {
positions[layer + 1].iter().filter_map(|p| {
let proximity = positions_proximity(*position, p);
let node = Node::Init {
layer: layer + 1,
position: p,
parent_position: *position,
};
// We do not produce the nodes we have already seen in previous iterations loops.
if node.is_reachable(contains_documents) {
Some((node, proximity))
} else {
None
}
}).collect()
}
}
}
pub fn is_complete(&self, positions: &[RoaringBitmap]) -> bool {
match self {
Node::Uninit => false,
Node::Init { layer, .. } => *layer == positions.len() - 1,
}
}
pub fn position(&self) -> Option<u32> {
match self {
Node::Uninit => None,
Node::Init { position, .. } => Some(*position),
}
}
pub fn is_reachable<F>(&self, contains_documents: &mut F) -> bool
where F: FnMut((usize, u32), (usize, u32)) -> bool,
{
match self {
Node::Uninit => true,
Node::Init { layer, position, parent_position, .. } => {
match layer.checked_sub(1) {
Some(parent_layer) => {
(contains_documents)((parent_layer, *parent_position), (*layer, *position))
},
None => true,
}
},
}
}
}

View File

@@ -1,8 +1,5 @@
use std::cell::RefCell;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use astar_iter::AstarBagIter;
use fst::{IntoStreamer, Streamer}; use fst::{IntoStreamer, Streamer};
use levenshtein_automata::DFA; use levenshtein_automata::DFA;
use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder;
@@ -10,7 +7,6 @@ use log::debug;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use crate::node::{self, Node};
use crate::query_tokens::{QueryTokens, QueryToken}; use crate::query_tokens::{QueryTokens, QueryToken};
use crate::{Index, DocumentId, Position, Attribute}; use crate::{Index, DocumentId, Position, Attribute};
@@ -86,69 +82,52 @@ impl<'a> Search<'a> {
.collect() .collect()
} }
/// Fetch the words from the given FST related to the given DFAs along with the associated /// Fetch the words from the given FST related to the
/// positions and the unions of those positions where the words found appears in the documents. /// given DFAs along with the associated documents ids.
fn fetch_words_positions( fn fetch_words_docids(
rtxn: &heed::RoTxn, rtxn: &heed::RoTxn,
index: &Index, index: &Index,
fst: &fst::Set<&[u8]>, fst: &fst::Set<&[u8]>,
dfas: Vec<(String, bool, DFA)>, dfas: Vec<(String, bool, DFA)>,
) -> anyhow::Result<(Vec<Vec<(String, u8, RoaringBitmap)>>, Vec<RoaringBitmap>)> ) -> anyhow::Result<Vec<(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)>>
{ {
// A Vec storing all the derived words from the original query words, associated // A Vec storing all the derived words from the original query words, associated
// with the distance from the original word and the positions it appears at. // with the distance from the original word and the docids where the words appears.
// The index the derived words appears in the Vec corresponds to the original query let mut derived_words = Vec::<(HashMap::<String, (u8, RoaringBitmap)>, RoaringBitmap)>::with_capacity(dfas.len());
// word position.
let mut derived_words = Vec::<Vec::<(String, u8, RoaringBitmap)>>::with_capacity(dfas.len());
// A Vec storing the unions of all of each of the derived words positions. The index
// the union appears in the Vec corresponds to the original query word position.
let mut union_positions = Vec::<RoaringBitmap>::with_capacity(dfas.len());
for (_word, _is_prefix, dfa) in dfas { for (_word, _is_prefix, dfa) in dfas {
let mut acc_derived_words = Vec::new(); let mut acc_derived_words = HashMap::new();
let mut acc_union_positions = RoaringBitmap::new(); let mut unions_docids = RoaringBitmap::new();
let mut stream = fst.search_with_state(&dfa).into_stream(); let mut stream = fst.search_with_state(&dfa).into_stream();
while let Some((word, state)) = stream.next() { while let Some((word, state)) = stream.next() {
let word = std::str::from_utf8(word)?; let word = std::str::from_utf8(word)?;
let positions = index.word_positions.get(rtxn, word)?.unwrap(); let docids = index.word_docids.get(rtxn, word)?.unwrap();
let distance = dfa.distance(state); let distance = dfa.distance(state);
acc_union_positions.union_with(&positions); unions_docids.union_with(&docids);
acc_derived_words.push((word.to_string(), distance.to_u8(), positions)); acc_derived_words.insert(word.to_string(), (distance.to_u8(), docids));
} }
derived_words.push(acc_derived_words); derived_words.push((acc_derived_words, unions_docids));
union_positions.push(acc_union_positions);
} }
Ok((derived_words, union_positions)) Ok(derived_words)
} }
/// Returns the set of docids that contains all of the query words. /// Returns the set of docids that contains all of the query words.
fn compute_candidates( fn compute_candidates(
rtxn: &heed::RoTxn, rtxn: &heed::RoTxn,
index: &Index, index: &Index,
derived_words: &[Vec<(String, u8, RoaringBitmap)>], derived_words: &[(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)],
) -> anyhow::Result<RoaringBitmap> ) -> anyhow::Result<RoaringBitmap>
{ {
// we do a union between all the docids of each of the derived words, // we do a union between all the docids of each of the derived words,
// we got N unions (the number of original query words), we then intersect them. // we got N unions (the number of original query words), we then intersect them.
// TODO we must store the words documents ids to avoid these unions.
let mut candidates = RoaringBitmap::new(); let mut candidates = RoaringBitmap::new();
let number_of_attributes = index.number_of_attributes(rtxn)?.map_or(0, |n| n as u32);
for (i, derived_words) in derived_words.iter().enumerate() {
let mut union_docids = RoaringBitmap::new();
for (word, _distance, _positions) in derived_words {
for attr in 0..number_of_attributes {
if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attr))? {
union_docids.union_with(&docids);
}
}
}
for (i, (_, union_docids)) in derived_words.iter().enumerate() {
if i == 0 { if i == 0 {
candidates = union_docids; candidates = union_docids.clone();
} else { } else {
candidates.intersect_with(&union_docids); candidates.intersect_with(&union_docids);
} }
@@ -157,161 +136,6 @@ impl<'a> Search<'a> {
Ok(candidates) Ok(candidates)
} }
/// Returns the union of the same position for all the given words.
fn union_word_position(
rtxn: &heed::RoTxn,
index: &Index,
words: &[(String, u8, RoaringBitmap)],
position: Position,
) -> anyhow::Result<RoaringBitmap>
{
let mut union_docids = RoaringBitmap::new();
for (word, _distance, positions) in words {
if positions.contains(position) {
if let Some(docids) = index.word_position_docids.get(rtxn, &(word, position))? {
union_docids.union_with(&docids);
}
}
}
Ok(union_docids)
}
/// Returns the union of the same gorup of four positions for all the given words.
fn union_word_four_positions(
rtxn: &heed::RoTxn,
index: &Index,
words: &[(String, u8, RoaringBitmap)],
group: Position,
) -> anyhow::Result<RoaringBitmap>
{
let mut union_docids = RoaringBitmap::new();
for (word, _distance, _positions) in words {
// TODO would be better to check if the group exist
if let Some(docids) = index.word_four_positions_docids.get(rtxn, &(word, group))? {
union_docids.union_with(&docids);
}
}
Ok(union_docids)
}
/// Returns the union of the same attribute for all the given words.
fn union_word_attribute(
rtxn: &heed::RoTxn,
index: &Index,
words: &[(String, u8, RoaringBitmap)],
attribute: Attribute,
) -> anyhow::Result<RoaringBitmap>
{
let mut union_docids = RoaringBitmap::new();
for (word, _distance, _positions) in words {
if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attribute))? {
union_docids.union_with(&docids);
}
}
Ok(union_docids)
}
// Returns `true` if there is documents in common between the two words and positions given.
fn contains_documents(
rtxn: &heed::RoTxn,
index: &Index,
(lword, lpos): (usize, u32),
(rword, rpos): (usize, u32),
candidates: &RoaringBitmap,
derived_words: &[Vec<(String, u8, RoaringBitmap)>],
union_cache: &mut HashMap<(usize, u32), RoaringBitmap>,
non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>,
group_four_union_cache: &mut HashMap<(usize, u32), RoaringBitmap>,
group_four_non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>,
attribute_union_cache: &mut HashMap<(usize, u32), RoaringBitmap>,
attribute_non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>,
) -> bool
{
if lpos == rpos { return false }
// TODO move this function to a better place.
let (lattr, _) = node::extract_position(lpos);
let (rattr, _) = node::extract_position(rpos);
if lattr == rattr {
// TODO move this function to a better place.
let lgroup = node::group_of_four(lpos);
let rgroup = node::group_of_four(rpos);
// We can't compute a disjunction on a group of four positions if those
// two positions are in the same group, we must go down to the position.
if lgroup == rgroup {
// We retrieve or compute the intersection between the two given words and positions.
*non_disjoint_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| {
// We retrieve or compute the unions for the two words and positions.
union_cache.entry((lword, lpos)).or_insert_with(|| {
let words = &derived_words[lword];
Self::union_word_position(rtxn, index, words, lpos).unwrap()
});
union_cache.entry((rword, rpos)).or_insert_with(|| {
let words = &derived_words[rword];
Self::union_word_position(rtxn, index, words, rpos).unwrap()
});
// TODO is there a way to avoid this double gets?
let lunion_docids = union_cache.get(&(lword, lpos)).unwrap();
let runion_docids = union_cache.get(&(rword, rpos)).unwrap();
// We first check that the docids of these unions are part of the candidates.
if lunion_docids.is_disjoint(candidates) { return false }
if runion_docids.is_disjoint(candidates) { return false }
!lunion_docids.is_disjoint(&runion_docids)
})
} else {
// We retrieve or compute the intersection between the two given words and positions.
*group_four_non_disjoint_cache.entry(((lword, lgroup), (rword, rgroup))).or_insert_with(|| {
// We retrieve or compute the unions for the two words and group of four positions.
group_four_union_cache.entry((lword, lgroup)).or_insert_with(|| {
let words = &derived_words[lword];
Self::union_word_four_positions(rtxn, index, words, lgroup).unwrap()
});
group_four_union_cache.entry((rword, rgroup)).or_insert_with(|| {
let words = &derived_words[rword];
Self::union_word_four_positions(rtxn, index, words, rgroup).unwrap()
});
// TODO is there a way to avoid this double gets?
let lunion_group_docids = group_four_union_cache.get(&(lword, lgroup)).unwrap();
let runion_group_docids = group_four_union_cache.get(&(rword, rgroup)).unwrap();
// We first check that the docids of these unions are part of the candidates.
if lunion_group_docids.is_disjoint(candidates) { return false }
if runion_group_docids.is_disjoint(candidates) { return false }
!lunion_group_docids.is_disjoint(&runion_group_docids)
})
}
} else {
*attribute_non_disjoint_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| {
// We retrieve or compute the unions for the two words and positions.
attribute_union_cache.entry((lword, lattr)).or_insert_with(|| {
let words = &derived_words[lword];
Self::union_word_attribute(rtxn, index, words, lattr).unwrap()
});
attribute_union_cache.entry((rword, rattr)).or_insert_with(|| {
let words = &derived_words[rword];
Self::union_word_attribute(rtxn, index, words, rattr).unwrap()
});
// TODO is there a way to avoid this double gets?
let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap();
let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap();
// We first check that the docids of these unions are part of the candidates.
if lunion_docids.is_disjoint(candidates) { return false }
if runion_docids.is_disjoint(candidates) { return false }
!lunion_docids.is_disjoint(&runion_docids)
})
}
}
pub fn execute(&self) -> anyhow::Result<SearchResult> { pub fn execute(&self) -> anyhow::Result<SearchResult> {
let rtxn = self.rtxn; let rtxn = self.rtxn;
let index = self.index; let index = self.index;
@@ -333,111 +157,14 @@ impl<'a> Search<'a> {
return Ok(Default::default()); return Ok(Default::default());
} }
let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?; let derived_words = Self::fetch_words_docids(rtxn, index, &fst, dfas)?;
let candidates = Self::compute_candidates(rtxn, index, &derived_words)?; let candidates = Self::compute_candidates(rtxn, index, &derived_words)?;
debug!("candidates: {:?}", candidates); debug!("candidates: {:?}", candidates);
let union_cache = HashMap::new(); let documents = vec![candidates];
let mut non_disjoint_cache = HashMap::new();
let mut group_four_union_cache = HashMap::new(); let found_words = derived_words.into_iter().flat_map(|(w, _)| w).map(|(w, _)| w).collect();
let mut group_four_non_disjoint_cache = HashMap::new();
let mut attribute_union_cache = HashMap::new();
let mut attribute_non_disjoint_cache = HashMap::new();
let candidates = Rc::new(RefCell::new(candidates));
let union_cache = Rc::new(RefCell::new(union_cache));
let candidates_cloned = candidates.clone();
let union_cache_cloned = union_cache.clone();
let mut contains_documents = |left, right| {
Self::contains_documents(
rtxn, index,
left, right,
&candidates_cloned.borrow(),
&derived_words,
&mut union_cache_cloned.borrow_mut(),
&mut non_disjoint_cache,
&mut group_four_union_cache,
&mut group_four_non_disjoint_cache,
&mut attribute_union_cache,
&mut attribute_non_disjoint_cache,
)
};
let astar_iter = AstarBagIter::new(
Node::Uninit, // start
|n| n.successors(&union_positions, &mut contains_documents), // successors
|_| 0, // heuristic
|n| n.is_complete(&union_positions), // success
);
let mut documents = Vec::new();
for (paths, proximity) in astar_iter {
let mut union_cache = union_cache.borrow_mut();
let mut candidates = candidates.borrow_mut();
let mut positions: Vec<Vec<_>> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect();
positions.sort_unstable();
debug!("Found {} positions with a proximity of {}", positions.len(), proximity);
let mut same_proximity_union = RoaringBitmap::default();
for positions in positions {
// Precompute the potentially missing unions
positions.iter().enumerate().for_each(|(word, pos)| {
union_cache.entry((word, *pos)).or_insert_with(|| {
let words = &&derived_words[word];
Self::union_word_position(rtxn, index, words, *pos).unwrap()
});
});
// Retrieve the unions along with the popularity of it.
let mut to_intersect = Vec::new();
for (word, pos) in positions.into_iter().enumerate() {
let docids = union_cache.get(&(word, pos)).unwrap();
to_intersect.push((docids.len(), docids));
}
// Sort the unions by popularity to help reduce
// the number of documents as soon as possible.
to_intersect.sort_unstable_by_key(|(l, _)| *l);
// Intersect all the unions in the inverse popularity order.
let mut intersect_docids = RoaringBitmap::new();
for (i, (_, union_docids)) in to_intersect.into_iter().enumerate() {
if i == 0 {
intersect_docids = union_docids.clone();
} else {
intersect_docids.intersect_with(union_docids);
}
}
same_proximity_union.union_with(&intersect_docids);
}
// We achieve to find valid documents ids so we remove them from the candidates list.
candidates.difference_with(&same_proximity_union);
// We remove documents we have already been seen in previous
// fetches from this set of documents we just fetched.
for previous_documents in &documents {
same_proximity_union.difference_with(previous_documents);
}
if !same_proximity_union.is_empty() {
documents.push(same_proximity_union);
}
// We found enough documents we can stop here.
if documents.iter().map(RoaringBitmap::len).sum::<u64>() >= limit as u64 {
break;
}
}
let found_words = derived_words.into_iter().flatten().map(|(w, _, _)| w).collect();
let documents_ids = documents.iter().flatten().take(limit).collect(); let documents_ids = documents.iter().flatten().take(limit).collect();
Ok(SearchResult { found_words, documents_ids }) Ok(SearchResult { found_words, documents_ids })