Remove the vectors from the documents database

This commit is contained in:
Tamo
2024-05-22 15:27:09 +02:00
parent 7a84697570
commit 84e498299b
14 changed files with 407 additions and 51 deletions

View File

@ -10,16 +10,16 @@ use bytemuck::cast_slice;
use grenad::Writer;
use itertools::EitherOrBoth;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use serde_json::Value;
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::index_documents::helpers::try_split_at;
use crate::update::settings::InnerIndexSettingsDiff;
use crate::vector::parsed_vectors::{ParsedVectorsDiff, RESERVED_VECTORS_FIELD_NAME};
use crate::vector::Embedder;
use crate::{DocumentId, Result, ThreadPoolNoAbort};
use crate::{try_split_array_at, DocumentId, Result, ThreadPoolNoAbort};
/// The length of the elements that are always in the buffer when inserting new values.
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
@ -35,6 +35,8 @@ pub struct ExtractedVectorPoints {
// embedder
pub embedder_name: String,
pub embedder: Arc<Embedder>,
pub user_defined: RoaringBitmap,
pub remove_from_user_defined: RoaringBitmap,
}
enum VectorStateDelta {
@ -80,6 +82,11 @@ struct EmbedderVectorExtractor {
prompts_writer: Writer<BufWriter<File>>,
// (docid) -> ()
remove_vectors_writer: Writer<BufWriter<File>>,
// The docids of the documents that contain a user-defined embedding
user_defined: RoaringBitmap,
// The docids of the documents that contain an auto-generated embedding,
// and that must therefore be removed from the user-defined set
remove_from_user_defined: RoaringBitmap,
}
/// Extracts the embedding vector contained in each document under the `_vectors` field.
@ -134,6 +141,8 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined: RoaringBitmap::new(),
remove_from_user_defined: RoaringBitmap::new(),
});
}
@ -141,13 +150,15 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
// this must always be serialized as (docid, external_docid);
const SIZE_OF_DOCUMENTID: usize = std::mem::size_of::<DocumentId>();
let (docid_bytes, external_id_bytes) =
try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap();
try_split_array_at::<u8, SIZE_OF_DOCUMENTID>(key).unwrap();
debug_assert!(from_utf8(external_id_bytes).is_ok());
let docid = DocumentId::from_be_bytes(docid_bytes);
let obkv = obkv::KvReader::new(value);
key_buffer.clear();
key_buffer.extend_from_slice(docid_bytes);
key_buffer.extend_from_slice(docid_bytes.as_slice());
// since we only need the primary key when we throw an error, we create this getter to
// lazily get it when needed
@ -163,10 +174,22 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined,
remove_from_user_defined,
} in extractors.iter_mut()
{
let delta = match parsed_vectors.remove(embedder_name) {
(Some(old), Some(new)) => {
match (old.is_user_provided(), new.is_user_provided()) {
(true, true) | (false, false) => (),
(true, false) => {
remove_from_user_defined.insert(docid);
}
(false, true) => {
user_defined.insert(docid);
}
}
// no autogeneration
let del_vectors = old.into_array_of_vectors();
let add_vectors = new.into_array_of_vectors();
@ -187,6 +210,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
remove_from_user_defined.insert(docid);
// becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render(
obkv,
@ -198,6 +222,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
}
}
(None, Some(new)) => {
if new.is_user_provided() {
user_defined.insert(docid);
} else {
remove_from_user_defined.insert(docid);
}
// was possibly autogenerated, remove all vectors for that document
let add_vectors = new.into_array_of_vectors();
if add_vectors.len() > usize::from(u8::MAX) {
@ -239,6 +268,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
VectorStateDelta::NoChange
}
} else {
remove_from_user_defined.remove(docid);
VectorStateDelta::NowRemoved
}
}
@ -265,18 +295,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined,
remove_from_user_defined,
} in extractors
{
results.push(ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
manual_vectors: writer_into_reader(manual_vectors_writer)?,
// docid -> ()
remove_vectors: writer_into_reader(remove_vectors_writer)?,
// docid -> prompt
prompts: writer_into_reader(prompts_writer)?,
embedder,
embedder_name,
user_defined,
remove_from_user_defined,
})
}

View File

@ -238,6 +238,8 @@ fn send_original_documents_data(
prompts,
embedder_name,
embedder,
user_defined,
remove_from_user_defined: auto_generated,
} in extracted_vectors
{
let embeddings = match extract_embeddings(
@ -262,6 +264,8 @@ fn send_original_documents_data(
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name,
user_defined,
remove_from_user_defined: auto_generated,
}));
}
}

View File

@ -501,6 +501,8 @@ where
embeddings,
manual_vectors,
embedder_name,
user_defined,
remove_from_user_defined,
} => {
dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints {
@ -509,6 +511,8 @@ where
expected_dimension,
manual_vectors,
embedder_name,
user_defined,
remove_from_user_defined,
}
}
otherwise => otherwise,
@ -2616,10 +2620,11 @@ mod tests {
let rtxn = index.read_txn().unwrap();
let mut embedding_configs = index.embedding_configs(&rtxn).unwrap();
let (embedder_name, embedder) = embedding_configs.pop().unwrap();
let (embedder_name, embedder, user_defined) = embedding_configs.pop().unwrap();
insta::assert_snapshot!(embedder_name, @"manual");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[0, 1, 2]>");
let embedder =
std::sync::Arc::new(crate::vector::Embedder::new(embedder.embedder_options).unwrap());
assert_eq!("manual", embedder_name);
let res = index
.search(&rtxn)
.semantic(embedder_name, embedder, Some([0.0, 1.0, 2.0].to_vec()))

View File

@ -90,6 +90,8 @@ pub(crate) enum TypedChunk {
expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
user_defined: RoaringBitmap,
remove_from_user_defined: RoaringBitmap,
},
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
}
@ -155,7 +157,7 @@ pub(crate) fn write_typed_chunk_into_index(
let mut iter = merger.into_stream_merger_iter()?;
let embedders: BTreeSet<_> =
index.embedding_configs(wtxn)?.into_iter().map(|(k, _v)| k).collect();
index.embedding_configs(wtxn)?.into_iter().map(|(name, _, _)| name).collect();
let mut vectors_buffer = Vec::new();
while let Some((key, reader)) = iter.next()? {
let mut writer: KvWriter<_, FieldId> = KvWriter::memory();
@ -181,7 +183,7 @@ pub(crate) fn write_typed_chunk_into_index(
// if the `_vectors` field cannot be parsed as map of vectors, just write it as-is
break 'vectors Some(addition);
};
vectors.retain_user_provided_vectors(&embedders);
vectors.retain_not_embedded_vectors(&embedders);
let crate::vector::parsed_vectors::ParsedVectors(vectors) = vectors;
if vectors.is_empty() {
// skip writing empty `_vectors` map
@ -619,6 +621,8 @@ pub(crate) fn write_typed_chunk_into_index(
let mut remove_vectors_builder = MergerBuilder::new(keep_first as MergeFn);
let mut manual_vectors_builder = MergerBuilder::new(keep_first as MergeFn);
let mut embeddings_builder = MergerBuilder::new(keep_first as MergeFn);
let mut user_defined = RoaringBitmap::new();
let mut remove_from_user_defined = RoaringBitmap::new();
let mut params = None;
for typed_chunk in typed_chunks {
let TypedChunk::VectorPoints {
@ -627,6 +631,8 @@ pub(crate) fn write_typed_chunk_into_index(
embeddings,
expected_dimension,
embedder_name,
user_defined: ud,
remove_from_user_defined: rud,
} = typed_chunk
else {
unreachable!();
@ -639,11 +645,21 @@ pub(crate) fn write_typed_chunk_into_index(
if let Some(embeddings) = embeddings {
embeddings_builder.push(embeddings.into_cursor()?);
}
user_defined |= ud;
remove_from_user_defined |= rud;
}
// `typed_chunks` always contains at least 1 chunk.
let Some((expected_dimension, embedder_name)) = params else { unreachable!() };
let mut embedding_configs = index.embedding_configs(&wtxn)?;
let (_name, _conf, ud) =
embedding_configs.iter_mut().find(|config| config.0 == embedder_name).unwrap();
*ud -= remove_from_user_defined;
*ud |= user_defined;
index.put_embedding_configs(wtxn, embedding_configs)?;
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;