This commit is contained in:
Louis Dureuil
2025-07-17 11:28:30 +02:00
parent 5d363205a5
commit f0b55e0349
8 changed files with 564 additions and 5 deletions

View File

@@ -575,6 +575,63 @@ impl<'b> ExtractorBbqueueSender<'b> {
Ok(())
}
/// Sends all the embeddings of one embedder for `docid` as a single flat
/// `f32` buffer over the bbqueue channel.
///
/// `flat_embeddings` contains the embeddings laid out back-to-back and
/// `dimensions` is the length of each individual embedding (`0` means no
/// vectors, in which case only the header is serialized).
///
/// When the payload exceeds the largest grant the queue can hand out, the
/// bytes are spilled to a memory-mapped temporary file and sent as a
/// `ReceiverAction::LargeVectors` instead.
fn set_vectors_flat(
    &self,
    docid: u32,
    embedder_id: u8,
    dimensions: usize,
    flat_embeddings: &[f32],
) -> crate::Result<()> {
    let max_grant = self.max_grant;
    let refcell = self.producers.get().unwrap();
    let mut producer = refcell.0.borrow_mut_or_yield();

    let arroy_set_vector = ArroySetVectors { docid, embedder_id, _padding: [0; 3] };
    let payload_header = EntryHeader::ArroySetVectors(arroy_set_vector);
    // We are taking the number of floats in the flat embeddings, so we mustn't
    // use the dimensions here: the payload size is the total float count.
    let total_length = EntryHeader::total_set_vectors_size(flat_embeddings.len(), 1);
    if total_length > max_grant {
        // Too big for the ring buffer: write the raw little-endian f32 bytes
        // to a temp file, mmap it, and hand the mapping to the receiver.
        let mut value_file = tempfile::tempfile().map(BufWriter::new)?;
        let mut embedding_bytes = bytemuck::cast_slice(flat_embeddings);
        io::copy(&mut embedding_bytes, &mut value_file)?;

        let value_file = value_file.into_inner().map_err(|ie| ie.into_error())?;
        // SAFETY-NOTE(review): maps a freshly written private temp file;
        // assumes nothing truncates/mutates it while mapped — confirm.
        let embeddings = unsafe { Mmap::map(&value_file)? };

        let large_vectors = LargeVectors { docid, embedder_id, embeddings };
        self.sender.send(ReceiverAction::LargeVectors(large_vectors)).unwrap();

        return Ok(());
    }

    // Spin loop to have a frame the size we requested.
    reserve_and_write_grant(
        &mut producer,
        total_length,
        &self.sender,
        &self.sent_messages_attempts,
        &self.blocking_sent_messages_attempts,
        |grant| {
            // Serialize the header first, then copy each embedding right
            // after it, one `dimensions`-sized chunk at a time.
            let header_size = payload_header.header_size();
            let (header_bytes, remaining) = grant.split_at_mut(header_size);
            payload_header.serialize_into(header_bytes);

            if dimensions != 0 {
                let output_iter =
                    remaining.chunks_exact_mut(dimensions * mem::size_of::<f32>());

                for (embedding, output) in flat_embeddings.chunks(dimensions).zip(output_iter) {
                    output.copy_from_slice(bytemuck::cast_slice(embedding));
                }
            }

            Ok(())
        },
    )?;

    Ok(())
}
fn set_vectors(
&self,
docid: u32,
@@ -640,7 +697,7 @@ impl<'b> ExtractorBbqueueSender<'b> {
docid: u32,
embedder_id: u8,
extractor_id: u8,
embedding: Option<Embedding>,
embedding: Option<&[f32]>,
) -> crate::Result<()> {
let max_grant = self.max_grant;
let refcell = self.producers.get().unwrap();
@@ -648,7 +705,7 @@ impl<'b> ExtractorBbqueueSender<'b> {
// If there are no vectors we specify the dimensions
// to zero to allocate no extra space at all
let dimensions = embedding.as_ref().map_or(0, |emb| emb.len());
let dimensions = embedding.map_or(0, |emb| emb.len());
let arroy_set_vector =
ArroySetVector { docid, embedder_id, extractor_id, _padding: [0; 2] };
@@ -1081,12 +1138,22 @@ impl EmbeddingSender<'_, '_> {
self.0.set_vectors(docid, embedder_id, &embeddings[..])
}
/// Forwards a flat buffer of embeddings for `docid` to the underlying
/// bbqueue sender; `dimensions` is the length of each individual embedding.
pub fn set_vectors_flat(
    &self,
    docid: DocumentId,
    embedder_id: u8,
    dimensions: usize,
    flat_embeddings: &[f32],
) -> crate::Result<()> {
    // Pure delegation to the wrapped extractor sender.
    let inner = &self.0;
    inner.set_vectors_flat(docid, embedder_id, dimensions, flat_embeddings)
}
pub fn set_vector(
&self,
docid: DocumentId,
embedder_id: u8,
extractor_id: u8,
embedding: Option<Embedding>,
embedding: Option<&[f32]>,
) -> crate::Result<()> {
self.0.set_vector_for_extractor(docid, embedder_id, extractor_id, embedding)
}

View File

@@ -469,7 +469,7 @@ impl<'doc> OnEmbed<'doc> for OnEmbeddingDocumentUpdates<'doc, '_> {
response.metadata.docid,
self.embedder_id,
response.metadata.extractor_id,
response.embedding,
response.embedding.as_deref(),
)
.unwrap();
}

View File

@@ -21,8 +21,10 @@ use super::thread_local::ThreadLocal;
use crate::documents::PrimaryKey;
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::progress::{EmbedderStats, Progress};
use crate::update::new::indexer::vector::Visitable;
use crate::update::settings::SettingsDelta;
use crate::update::GrenadParameters;
use crate::vector::db::EmbeddingStatus;
use crate::vector::settings::{EmbedderAction, RemoveFragments, WriteBackToDocuments};
use crate::vector::{ArroyWrapper, Embedder, RuntimeEmbedders};
use crate::{FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result, ThreadPoolNoAbort};
@@ -37,6 +39,7 @@ mod partial_dump;
mod post_processing;
pub mod settings_changes;
mod update_by_function;
mod vector;
mod write;
static LOG_MEMORY_METRICS_ONCE: Once = Once::new();
@@ -336,6 +339,115 @@ where
Ok(congestion)
}
/// Replays pre-extracted vectors (`visitables`) into the index: spawns an
/// extractor thread that pushes embeddings through a bbqueue channel, drains
/// that channel into the database on this thread, then builds the arroy
/// vector stores.
///
/// `statuses` overrides the stored [`EmbeddingStatus`] of each named embedder
/// before the vectors are imported.
///
/// NOTE(review): the `DC` type parameter is never used in the signature or
/// body — callers will need a turbofish to name it; presumably a leftover
/// from the regular indexing entry point, confirm before stabilizing.
#[allow(clippy::too_many_arguments)]
pub fn import_vectors<'indexer, DC, MSP, V>(
    visitables: &[V],
    statuses: HashMap<String, EmbeddingStatus>,
    wtxn: &mut RwTxn,
    index: &Index,
    pool: &ThreadPoolNoAbort,
    grenad_parameters: GrenadParameters,
    embedders: RuntimeEmbedders,
    must_stop_processing: &'indexer MSP,
    progress: &'indexer Progress,
) -> Result<ChannelCongestion>
where
    MSP: Fn() -> bool + Sync,
    V: Visitable + Sync,
{
    let mut bbbuffers = Vec::new();
    let finished_extraction = AtomicBool::new(false);

    // The indexing memory cap is reused as arroy's build-memory budget.
    let arroy_memory = grenad_parameters.max_memory;

    let (_, total_bbbuffer_capacity) =
        indexer_memory_settings(pool.current_num_threads(), grenad_parameters);

    let (extractor_sender, writer_receiver) = pool
        .install(|| extractor_writer_bbqueue(&mut bbbuffers, total_bbbuffer_capacity, 1000))
        .unwrap();

    let index_embeddings = index.embedding_configs().embedding_configs(wtxn)?;

    let congestion = thread::scope(|s| -> Result<ChannelCongestion> {
        let indexer_span = tracing::Span::current();
        let embedders = &embedders;
        let finished_extraction = &finished_extraction;

        // Extraction runs on its own scoped thread so this thread can
        // concurrently drain `writer_receiver` into the database.
        let extractor_handle =
            Builder::new().name(S("indexer-extractors")).spawn_scoped(s, move || {
                pool.install(move || {
                    vector::import_vectors(
                        visitables,
                        statuses,
                        must_stop_processing,
                        progress,
                        indexer_span,
                        extractor_sender,
                        finished_extraction,
                        index,
                        embedders,
                    )
                })
                .unwrap()
            })?;

        let vector_arroy = index.vector_arroy;

        // Build one arroy writer per configured embedder, keyed by its
        // internal embedder id.
        let arroy_writers: Result<HashMap<_, _>> = embedders
            .inner_as_ref()
            .iter()
            .map(|(embedder_name, runtime)| {
                let embedder_index = index
                    .embedding_configs()
                    .embedder_id(wtxn, embedder_name)?
                    .ok_or(InternalError::DatabaseMissingEntry {
                        db_name: "embedder_category_id",
                        key: None,
                    })?;

                let dimensions = runtime.embedder.dimensions();
                let writer = ArroyWrapper::new(vector_arroy, embedder_index, runtime.is_quantized);

                Ok((
                    embedder_index,
                    (embedder_name.as_str(), &*runtime.embedder, writer, dimensions),
                ))
            })
            .collect();

        let mut arroy_writers = arroy_writers?;

        let congestion =
            write_to_db(writer_receiver, finished_extraction, index, wtxn, &arroy_writers)?;

        progress.update_progress(IndexingStep::WaitingForExtractors);

        // Propagate any error raised on the extractor thread.
        let () = extractor_handle.join().unwrap()?;

        progress.update_progress(IndexingStep::WritingEmbeddingsToDatabase);

        pool.install(|| {
            build_vectors(
                index,
                wtxn,
                progress,
                index_embeddings,
                arroy_memory,
                &mut arroy_writers,
                None,
                &must_stop_processing,
            )
        })
        .unwrap()?;

        progress.update_progress(IndexingStep::Finalizing);

        Ok(congestion) as Result<_>
    })?;

    Ok(congestion)
}
fn arroy_writers_from_embedder_actions<'indexer>(
index: &Index,
embedder_actions: &'indexer BTreeMap<String, EmbedderAction>,

View File

@@ -0,0 +1,213 @@
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use hashbrown::HashMap;
use heed::{RoTxn, WithoutTls};
use rayon::iter::IntoParallelIterator as _;
use tracing::Span;
use crate::progress::Progress;
use crate::update::new::channel::{EmbeddingSender, ExtractorBbqueueSender};
use crate::update::new::parallel_iterator_ext::ParallelIteratorExt as _;
use crate::update::new::steps::IndexingStep;
use crate::vector::db::EmbeddingStatus;
use crate::vector::RuntimeEmbedders;
use crate::{DocumentId, Index, InternalError, Result};
// Plan:
// 1. provide a parallel iterator of visitables
// 2. implement `Visitable` on `dump::VectorReader`
// 3. add a "skip vectors" option to the regular indexing operations
// 4. call `import_vectors`
// 5. write the vector files
/// Callbacks invoked while walking a stream of stored vectors.
///
/// The stream is hierarchical: switching embedder resets the current store
/// and document, then vectors are reported at the current position.
pub trait Visitor {
    type Error: 'static + std::fmt::Debug;

    /// Called when the stream switches to the embedder named `name`.
    /// Returns the dimension count of that embedder's vectors.
    fn on_current_embedder_change(&mut self, name: &str)
        -> std::result::Result<usize, Self::Error>;

    /// Called when the stream switches vector store: `Some(name)` selects the
    /// fragment store with that name on the current embedder, `None` clears
    /// the fragment selection.
    fn on_current_store_change(
        &mut self,
        name: Option<&str>,
    ) -> std::result::Result<(), Self::Error>;

    /// Called when the stream switches to the document with this external id.
    fn on_current_docid_change(
        &mut self,
        external_docid: &str,
    ) -> std::result::Result<(), Self::Error>;

    /// Called with a single embedding for the current
    /// (embedder, store, docid) position.
    fn on_set_vector(&mut self, v: &[f32]) -> std::result::Result<(), Self::Error>;

    /// Called with every embedding of the current document laid out
    /// back-to-back in one flat buffer.
    fn on_set_vectors_flat(&mut self, v: &[f32]) -> std::result::Result<(), Self::Error>;
}
/// A source of stored vectors that can replay its stream into a [`Visitor`].
///
/// The outer `Result` carries errors from the visitable itself (e.g. reading
/// its underlying data); the inner `Result` carries errors raised by the
/// visitor's callbacks.
pub trait Visitable {
    type Error: std::fmt::Debug;
    fn visit<V: Visitor>(
        &self,
        v: &mut V,
    ) -> std::result::Result<std::result::Result<(), V::Error>, Self::Error>;
}
/// [`Visitor`] that forwards every visited vector to the indexing channel,
/// tracking the stream's current (embedder, store, document) position.
struct ImportVectorVisitor<'a, 'b, MSP> {
    // Current embedder, set by `on_current_embedder_change`; `None` until then.
    embedder: Option<EmbedderData>,
    // Current fragment store id; `None` when no fragment store is selected.
    store_id: Option<u8>,
    // Internal id of the current document, set by `on_current_docid_change`.
    docid: Option<DocumentId>,
    // Channel used to push the vectors to the writer side.
    sender: EmbeddingSender<'a, 'b>,
    // Read transaction used to resolve embedder ids and external docids.
    rtxn: RoTxn<'a, WithoutTls>,
    index: &'a Index,
    runtimes: &'a RuntimeEmbedders,
    // Cooperative cancellation flag, checked at every callback.
    must_stop_processing: MSP,
}
impl<'a, 'b, MSP> ImportVectorVisitor<'a, 'b, MSP>
where
    MSP: Fn() -> bool + Sync,
{
    /// Builds a visitor with no embedder, store or document selected yet;
    /// the stream's callbacks will populate those positions.
    pub fn new(
        sender: EmbeddingSender<'a, 'b>,
        index: &'a Index,
        rtxn: RoTxn<'a, WithoutTls>,
        runtimes: &'a RuntimeEmbedders,
        must_stop_processing: MSP,
    ) -> Self {
        Self {
            sender,
            rtxn,
            index,
            runtimes,
            must_stop_processing,
            // No position in the vector stream until the first callbacks fire.
            embedder: None,
            store_id: None,
            docid: None,
        }
    }
}
/// Cached information about the embedder currently selected by the stream.
struct EmbedderData {
    // Internal embedder id from the index's embedding configs.
    id: u8,
    // Dimension count of this embedder's embeddings.
    dimensions: usize,
    // Embedder name, kept to look up its runtime (e.g. for fragments).
    name: String,
}
impl<MSP> Visitor for ImportVectorVisitor<'_, '_, MSP>
where
    MSP: Fn() -> bool + Sync,
{
    type Error = crate::Error;

    /// Selects `name` as the current embedder, resets the store and document
    /// positions, and returns the embedder's dimension count.
    ///
    /// NOTE(review): panics (`unwrap`) when the embedder is unknown to the
    /// index or has no runtime — presumably guaranteed by the caller; confirm.
    fn on_current_embedder_change(
        &mut self,
        name: &str,
    ) -> std::result::Result<usize, Self::Error> {
        if (self.must_stop_processing)() {
            return Err(InternalError::AbortedIndexation.into());
        }
        let embedder_id = self.index.embedding_configs().embedder_id(&self.rtxn, name)?.unwrap();
        let embedder_name = name.to_string();
        let runtime_embedder = self.runtimes.get(name).unwrap();
        let dimensions = runtime_embedder.embedder.dimensions();
        self.embedder = Some(EmbedderData { id: embedder_id, dimensions, name: embedder_name });
        // A new embedder invalidates any previously selected store/document.
        self.store_id = None;
        self.docid = None;
        Ok(dimensions)
    }

    /// Selects the current fragment store by name; `None` clears the selection.
    ///
    /// NOTE(review): the lookup binary-searches `fragments()`, which assumes
    /// the list is sorted by name, and panics when the fragment (or the
    /// current embedder) is missing — confirm both invariants.
    fn on_current_store_change(
        &mut self,
        name: Option<&str>,
    ) -> std::result::Result<(), Self::Error> {
        if (self.must_stop_processing)() {
            return Err(InternalError::AbortedIndexation.into());
        }
        self.store_id = if let Some(fragment_name) = name {
            let embedder_name = self.embedder.as_ref().map(|e| &e.name).unwrap();
            let fragments = self.runtimes.get(embedder_name).unwrap().fragments();
            Some(
                fragments[fragments
                    .binary_search_by(|fragment| fragment.name.as_str().cmp(fragment_name))
                    .unwrap()]
                .id,
            )
        } else {
            None
        };
        Ok(())
    }

    /// Resolves `external_docid` to its internal id and selects it as the
    /// current document.
    ///
    /// NOTE(review): panics when the document does not exist in the index —
    /// presumably the vector stream only references existing docs; confirm.
    fn on_current_docid_change(
        &mut self,
        external_docid: &str,
    ) -> std::result::Result<(), Self::Error> {
        if (self.must_stop_processing)() {
            return Err(InternalError::AbortedIndexation.into());
        }
        let docid = self.index.external_documents_ids().get(&self.rtxn, external_docid)?.unwrap();
        self.docid = Some(docid);
        Ok(())
    }

    /// Sends one embedding for the current (document, embedder, store) triple.
    /// Panics if any of the three positions has not been set yet.
    fn on_set_vector(&mut self, v: &[f32]) -> std::result::Result<(), Self::Error> {
        if (self.must_stop_processing)() {
            return Err(InternalError::AbortedIndexation.into());
        }
        self.sender.set_vector(
            self.docid.unwrap(),
            self.embedder.as_ref().unwrap().id,
            self.store_id.unwrap(),
            Some(v),
        )
    }

    /// Sends every embedding of the current document as one flat buffer,
    /// using the current embedder's dimension count to delimit them.
    /// Panics if the embedder or document position has not been set yet.
    fn on_set_vectors_flat(&mut self, v: &[f32]) -> std::result::Result<(), Self::Error> {
        if (self.must_stop_processing)() {
            return Err(InternalError::AbortedIndexation.into());
        }
        let embedder = self.embedder.as_ref().unwrap();
        self.sender.set_vectors_flat(self.docid.unwrap(), embedder.id, embedder.dimensions, v)
    }
}
/// Extractor side of the vector import: pushes the provided embedding
/// statuses, then replays every visitable in parallel through an
/// [`ImportVectorVisitor`] so all vectors flow into `extractor_sender`.
///
/// Sets `finished_extraction` once done so the database-writing side knows
/// no more frames will arrive.
#[allow(clippy::too_many_arguments)]
pub(super) fn import_vectors<MSP, V: Visitable + Sync>(
    visitables: &[V],
    statuses: HashMap<String, EmbeddingStatus>,
    must_stop_processing: MSP,
    progress: &Progress,
    indexer_span: Span,
    extractor_sender: ExtractorBbqueueSender,
    finished_extraction: &AtomicBool,
    index: &Index,
    runtimes: &RuntimeEmbedders,
) -> Result<()>
where
    MSP: Fn() -> bool + Sync,
{
    let span = tracing::trace_span!(target: "indexing::vectors", parent: &indexer_span, "import");
    let _entered = span.enter();

    let rtxn = index.read_txn()?;
    let embedders = index.embedding_configs();
    let embedding_sender = extractor_sender.embeddings();

    // Push the overriding embedding statuses first, silently skipping
    // embedders unknown to the index.
    for (name, status) in statuses {
        let Some(mut info) = embedders.embedder_info(&rtxn, &name)? else { continue };
        info.embedding_status = status;
        embedding_sender.embedding_status(&name, info)?;
    }

    // Replay each visitable in parallel; every rayon worker is initialized
    // with its own read transaction and visitor.
    visitables.into_par_iter().try_arc_for_each_try_init(
        || {
            let rtxn = index.read_txn()?;
            let v = ImportVectorVisitor::new(
                extractor_sender.embeddings(),
                index,
                rtxn,
                runtimes,
                &must_stop_processing,
            );
            Ok(v)
        },
        // NOTE(review): `unwrap()` panics on a visitable-side (outer) error;
        // only visitor-side errors are propagated — confirm this is intended.
        |context, visitable| visitable.visit(context).unwrap().map_err(Arc::new),
    )?;

    progress.update_progress(IndexingStep::WaitingForDatabaseWrites);
    finished_extraction.store(true, std::sync::atomic::Ordering::Relaxed);

    Result::Ok(())
}