Remove the vectors from the documents database

This commit is contained in:
Tamo
2024-05-22 15:27:09 +02:00
parent 7a84697570
commit 84e498299b
14 changed files with 407 additions and 51 deletions

View File

@ -10,16 +10,16 @@ use bytemuck::cast_slice;
use grenad::Writer;
use itertools::EitherOrBoth;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use serde_json::Value;
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::index_documents::helpers::try_split_at;
use crate::update::settings::InnerIndexSettingsDiff;
use crate::vector::parsed_vectors::{ParsedVectorsDiff, RESERVED_VECTORS_FIELD_NAME};
use crate::vector::Embedder;
use crate::{DocumentId, Result, ThreadPoolNoAbort};
use crate::{try_split_array_at, DocumentId, Result, ThreadPoolNoAbort};
/// The length of the elements that are always in the buffer when inserting new values.
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
@ -35,6 +35,8 @@ pub struct ExtractedVectorPoints {
// embedder
pub embedder_name: String,
pub embedder: Arc<Embedder>,
pub user_defined: RoaringBitmap,
pub remove_from_user_defined: RoaringBitmap,
}
enum VectorStateDelta {
@ -80,6 +82,11 @@ struct EmbedderVectorExtractor {
prompts_writer: Writer<BufWriter<File>>,
// (docid) -> ()
remove_vectors_writer: Writer<BufWriter<File>>,
// The docids of the documents that contains a user defined embedding
user_defined: RoaringBitmap,
// The docids of the documents that contains an auto-generated embedding
remove_from_user_defined: RoaringBitmap,
}
/// Extracts the embedding vector contained in each document under the `_vectors` field.
@ -134,6 +141,8 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined: RoaringBitmap::new(),
remove_from_user_defined: RoaringBitmap::new(),
});
}
@ -141,13 +150,15 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
// this must always be serialized as (docid, external_docid);
const SIZE_OF_DOCUMENTID: usize = std::mem::size_of::<DocumentId>();
let (docid_bytes, external_id_bytes) =
try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap();
try_split_array_at::<u8, SIZE_OF_DOCUMENTID>(key).unwrap();
debug_assert!(from_utf8(external_id_bytes).is_ok());
let docid = DocumentId::from_be_bytes(docid_bytes);
let obkv = obkv::KvReader::new(value);
key_buffer.clear();
key_buffer.extend_from_slice(docid_bytes);
key_buffer.extend_from_slice(docid_bytes.as_slice());
// since we only need the primary key when we throw an error we create this getter to
// lazily get it when needed
@ -163,10 +174,22 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined,
remove_from_user_defined,
} in extractors.iter_mut()
{
let delta = match parsed_vectors.remove(embedder_name) {
(Some(old), Some(new)) => {
match (old.is_user_provided(), new.is_user_provided()) {
(true, true) | (false, false) => (),
(true, false) => {
remove_from_user_defined.insert(docid);
}
(false, true) => {
user_defined.insert(docid);
}
}
// no autogeneration
let del_vectors = old.into_array_of_vectors();
let add_vectors = new.into_array_of_vectors();
@ -187,6 +210,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
remove_from_user_defined.insert(docid);
// becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render(
obkv,
@ -198,6 +222,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
}
}
(None, Some(new)) => {
if new.is_user_provided() {
user_defined.insert(docid);
} else {
remove_from_user_defined.insert(docid);
}
// was possibly autogenerated, remove all vectors for that document
let add_vectors = new.into_array_of_vectors();
if add_vectors.len() > usize::from(u8::MAX) {
@ -239,6 +268,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
VectorStateDelta::NoChange
}
} else {
remove_from_user_defined.remove(docid);
VectorStateDelta::NowRemoved
}
}
@ -265,18 +295,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined,
remove_from_user_defined,
} in extractors
{
results.push(ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
manual_vectors: writer_into_reader(manual_vectors_writer)?,
// docid -> ()
remove_vectors: writer_into_reader(remove_vectors_writer)?,
// docid -> prompt
prompts: writer_into_reader(prompts_writer)?,
embedder,
embedder_name,
user_defined,
remove_from_user_defined,
})
}

View File

@ -238,6 +238,8 @@ fn send_original_documents_data(
prompts,
embedder_name,
embedder,
user_defined,
remove_from_user_defined: auto_generated,
} in extracted_vectors
{
let embeddings = match extract_embeddings(
@ -262,6 +264,8 @@ fn send_original_documents_data(
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name,
user_defined,
remove_from_user_defined: auto_generated,
}));
}
}