WIP multi embedders

fixed template bugs
This commit is contained in:
Louis Dureuil
2023-12-12 21:19:48 +01:00
parent abbe131084
commit 922a640188
20 changed files with 438 additions and 158 deletions

View File

@ -71,8 +71,8 @@ impl VectorStateDelta {
pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters,
field_id_map: FieldsIdsMap,
prompt: Option<&Prompt>,
field_id_map: &FieldsIdsMap,
prompt: &Prompt,
) -> Result<ExtractedVectorPoints> {
puffin::profile_function!();
@ -142,14 +142,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// becomes autogenerated
match prompt {
Some(prompt) => VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
&field_id_map,
)?),
None => VectorStateDelta::NowRemoved,
}
VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
field_id_map,
)?)
} else {
VectorStateDelta::NowRemoved
}
@ -162,26 +159,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
match prompt {
Some(prompt) => {
// Don't give up if the old prompt was failing
let old_prompt = prompt
.render(obkv, DelAdd::Deletion, &field_id_map)
.unwrap_or_default();
let new_prompt =
prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
}
// We no longer have a prompt, so we need to remove any existing vector
None => VectorStateDelta::NowRemoved,
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
@ -196,24 +185,16 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
match prompt {
Some(prompt) => {
// Don't give up if the old prompt was failing
let old_prompt = prompt
.render(obkv, DelAdd::Deletion, &field_id_map)
.unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
}
None => VectorStateDelta::NowRemoved,
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}");
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
@ -322,7 +303,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
) -> Result<(grenad::Reader<BufReader<File>>, Option<usize>)> {
) -> Result<grenad::Reader<BufReader<File>>> {
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
@ -341,8 +322,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let mut chunks_ids = Vec::with_capacity(n_chunks);
let mut cursor = prompt_reader.into_cursor()?;
let mut expected_dimension = None;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
// SAFETY: precondition, the grenad value was saved from a string
@ -367,7 +346,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
)
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
@ -376,7 +354,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
}
chunks_ids.clear();
}
@ -387,7 +364,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let chunked_embeds = rt
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
@ -395,7 +371,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
}
}
@ -403,14 +378,12 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let embeds = rt
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
}
}
Ok((writer_into_reader(state_writer)?, expected_dimension))
writer_into_reader(state_writer)
}

View File

@ -292,43 +292,42 @@ fn send_original_documents_data(
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || {
let (embedder, prompt) = embedders.get("default").cloned().unzip();
let result =
extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref());
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
/// FIXME: support multiple embedders
let results = embedder.and_then(|embedder| {
match extract_embeddings(prompts, indexer, embedder.clone()) {
for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points(
documents_chunk_cloned.clone(),
indexer,
&field_id_map,
&prompt,
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
None
}
}
});
let (embeddings, expected_dimension) = results.unzip();
let expected_dimension = expected_dimension.flatten();
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
/// FIXME FIXME FIXME
if expected_dimension.is_some() {
};
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
remove_vectors,
embeddings,
/// FIXME: compute an expected dimension from the manual vectors if any
expected_dimension: expected_dimension.unwrap(),
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name: name,
}));
}
}
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
}
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
};
}
});
// TODO: create a custom internal error