Merge branch 'main' into fix-threshold-overcounting-bug

This commit is contained in:
Mubelotix
2025-07-07 12:26:37 +02:00
141 changed files with 6398 additions and 1608 deletions

View File

@@ -288,6 +288,8 @@ and can not be more than 511 bytes.", .document_id.to_string()
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
#[error("Too many embedders in the configuration. Found {0}, but limited to 256.")]
TooManyEmbedders(usize),
#[error("Too many fragments in the configuration. Found {0}, but limited to 256.")]
TooManyFragments(usize),
#[error("Cannot find embedder with name `{0}`.")]
InvalidSearchEmbedder(String),
#[error("Cannot find embedder with name `{0}`.")]

View File

@@ -30,7 +30,8 @@ use crate::order_by_map::OrderByMap;
use crate::prompt::PromptData;
use crate::proximity::ProximityPrecision;
use crate::update::new::StdResult;
use crate::vector::{ArroyStats, ArroyWrapper, Embedding, EmbeddingConfig};
use crate::vector::db::IndexEmbeddingConfigs;
use crate::vector::{ArroyStats, ArroyWrapper, Embedding};
use crate::{
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
FacetDistribution, FieldDistribution, FieldId, FieldIdMapMissingEntry, FieldIdWordCountCodec,
@@ -177,7 +178,7 @@ pub struct Index {
pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,
/// Maps an embedder name to its id in the arroy store.
pub embedder_category_id: Database<Str, U8>,
pub(crate) embedder_category_id: Database<Unspecified, Unspecified>,
/// Vector store based on arroy™.
pub vector_arroy: arroy::Database<Unspecified>,
@@ -1745,34 +1746,6 @@ impl Index {
self.main.remap_key_type::<Str>().delete(txn, main_key::LOCALIZED_ATTRIBUTES_RULES)
}
/// Put the embedding configs:
/// 1. The name of the embedder
/// 2. The configuration option for this embedder
/// 3. The list of documents with a user provided embedding
pub(crate) fn put_embedding_configs(
&self,
wtxn: &mut RwTxn<'_>,
configs: Vec<IndexEmbeddingConfig>,
) -> heed::Result<()> {
self.main.remap_types::<Str, SerdeJson<Vec<IndexEmbeddingConfig>>>().put(
wtxn,
main_key::EMBEDDING_CONFIGS,
&configs,
)
}
pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> {
self.main.remap_key_type::<Str>().delete(wtxn, main_key::EMBEDDING_CONFIGS)
}
pub fn embedding_configs(&self, rtxn: &RoTxn<'_>) -> Result<Vec<IndexEmbeddingConfig>> {
Ok(self
.main
.remap_types::<Str, SerdeJson<Vec<IndexEmbeddingConfig>>>()
.get(rtxn, main_key::EMBEDDING_CONFIGS)?
.unwrap_or_default())
}
pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> {
self.main.remap_types::<Str, BEU64>().put(wtxn, main_key::SEARCH_CUTOFF, &cutoff)
}
@@ -1785,19 +1758,29 @@ impl Index {
self.main.remap_key_type::<Str>().delete(wtxn, main_key::SEARCH_CUTOFF)
}
pub fn embedding_configs(&self) -> IndexEmbeddingConfigs {
IndexEmbeddingConfigs::new(self.main, self.embedder_category_id)
}
pub fn embeddings(
&self,
rtxn: &RoTxn<'_>,
docid: DocumentId,
) -> Result<BTreeMap<String, Vec<Embedding>>> {
) -> Result<BTreeMap<String, (Vec<Embedding>, bool)>> {
let mut res = BTreeMap::new();
let embedding_configs = self.embedding_configs(rtxn)?;
for config in embedding_configs {
let embedder_id = self.embedder_category_id.get(rtxn, &config.name)?.unwrap();
let reader =
ArroyWrapper::new(self.vector_arroy, embedder_id, config.config.quantized());
let embedders = self.embedding_configs();
for config in embedders.embedding_configs(rtxn)? {
let embedder_info = embedders.embedder_info(rtxn, &config.name)?.unwrap();
let reader = ArroyWrapper::new(
self.vector_arroy,
embedder_info.embedder_id,
config.config.quantized(),
);
let embeddings = reader.item_vectors(rtxn, docid)?;
res.insert(config.name.to_owned(), embeddings);
res.insert(
config.name.to_owned(),
(embeddings, embedder_info.embedding_status.must_regenerate(docid)),
);
}
Ok(res)
}
@@ -1809,9 +1792,9 @@ impl Index {
pub fn arroy_stats(&self, rtxn: &RoTxn<'_>) -> Result<ArroyStats> {
let mut stats = ArroyStats::default();
let embedding_configs = self.embedding_configs(rtxn)?;
for config in embedding_configs {
let embedder_id = self.embedder_category_id.get(rtxn, &config.name)?.unwrap();
let embedding_configs = self.embedding_configs();
for config in embedding_configs.embedding_configs(rtxn)? {
let embedder_id = embedding_configs.embedder_id(rtxn, &config.name)?.unwrap();
let reader =
ArroyWrapper::new(self.vector_arroy, embedder_id, config.config.quantized());
reader.aggregate_stats(rtxn, &mut stats)?;
@@ -1936,13 +1919,6 @@ impl Index {
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct IndexEmbeddingConfig {
pub name: String,
pub config: EmbeddingConfig,
pub user_provided: RoaringBitmap,
}
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct ChatConfig {
pub description: String,

View File

@@ -6,12 +6,18 @@ use liquid::{ObjectView, ValueView};
#[derive(Debug, Clone)]
pub struct Context<'a, D: ObjectView, F: ArrayView> {
document: &'a D,
fields: &'a F,
fields: Option<&'a F>,
}
impl<'a, D: ObjectView, F: ArrayView> Context<'a, D, F> {
pub fn new(document: &'a D, fields: &'a F) -> Self {
Self { document, fields }
Self { document, fields: Some(fields) }
}
}
impl<'a, D: ObjectView> Context<'a, D, Vec<bool>> {
pub fn without_fields(document: &'a D) -> Self {
Self { document, fields: None }
}
}
@@ -21,17 +27,27 @@ impl<D: ObjectView, F: ArrayView> ObjectView for Context<'_, D, F> {
}
fn size(&self) -> i64 {
2
if self.fields.is_some() {
2
} else {
1
}
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s)))
let keys = if self.fields.is_some() {
either::Either::Left(["doc", "fields"])
} else {
either::Either::Right(["doc"])
};
Box::new(keys.into_iter().map(KStringCow::from_static))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(
std::iter::once(self.document.as_value())
.chain(std::iter::once(self.fields.as_value())),
.chain(self.fields.iter().map(|fields| fields.as_value())),
)
}
@@ -40,13 +56,13 @@ impl<D: ObjectView, F: ArrayView> ObjectView for Context<'_, D, F> {
}
fn contains_key(&self, index: &str) -> bool {
index == "doc" || index == "fields"
index == "doc" || (index == "fields" && self.fields.is_some())
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
match index {
"doc" => Some(self.document.as_value()),
"fields" => Some(self.fields.as_value()),
match (index, &self.fields) {
("doc", _) => Some(self.document.as_value()),
("fields", Some(fields)) => Some(fields.as_value()),
_ => None,
}
}

View File

@@ -144,18 +144,19 @@ impl ValueView for Document<'_> {
use crate::update::new::document::Document as DocumentTrait;
#[derive(Debug)]
pub struct ParseableDocument<'doc, D> {
pub struct ParseableDocument<'a, 'doc, D: DocumentTrait<'a> + Debug> {
document: D,
doc_alloc: &'doc Bump,
_marker: std::marker::PhantomData<&'a ()>,
}
impl<'doc, D> ParseableDocument<'doc, D> {
impl<'a, 'doc, D: DocumentTrait<'a> + Debug> ParseableDocument<'a, 'doc, D> {
pub fn new(document: D, doc_alloc: &'doc Bump) -> Self {
Self { document, doc_alloc }
Self { document, doc_alloc, _marker: std::marker::PhantomData }
}
}
impl<'doc, D: DocumentTrait<'doc> + Debug> ObjectView for ParseableDocument<'doc, D> {
impl<'a, D: DocumentTrait<'a> + Debug> ObjectView for ParseableDocument<'a, '_, D> {
fn as_value(&self) -> &dyn ValueView {
self
}
@@ -195,7 +196,7 @@ impl<'doc, D: DocumentTrait<'doc> + Debug> ObjectView for ParseableDocument<'doc
}
}
impl<'doc, D: DocumentTrait<'doc> + Debug> ValueView for ParseableDocument<'doc, D> {
impl<'a, D: DocumentTrait<'a> + Debug> ValueView for ParseableDocument<'a, '_, D> {
fn as_debug(&self) -> &dyn Debug {
self
}

View File

@@ -121,10 +121,10 @@ impl<D: ObjectView> ObjectView for FieldValue<'_, D> {
pub struct OwnedFields<'a, D: ObjectView>(Vec<FieldValue<'a, D>>);
#[derive(Debug)]
pub struct BorrowedFields<'a, 'map, D: ObjectView> {
pub struct BorrowedFields<'a, 'doc, 'map, D: ObjectView> {
document: &'a D,
field_id_map: &'a RefCell<GlobalFieldsIdsMap<'map>>,
doc_alloc: &'a Bump,
doc_alloc: &'doc Bump,
}
impl<'a, D: ObjectView> OwnedFields<'a, D> {
@@ -138,11 +138,11 @@ impl<'a, D: ObjectView> OwnedFields<'a, D> {
}
}
impl<'a, 'map, D: ObjectView> BorrowedFields<'a, 'map, D> {
impl<'a, 'doc, 'map, D: ObjectView> BorrowedFields<'a, 'doc, 'map, D> {
pub fn new(
document: &'a D,
field_id_map: &'a RefCell<GlobalFieldsIdsMap<'map>>,
doc_alloc: &'a Bump,
doc_alloc: &'doc Bump,
) -> Self {
Self { document, field_id_map, doc_alloc }
}
@@ -170,7 +170,7 @@ impl<D: ObjectView> ArrayView for OwnedFields<'_, D> {
}
}
impl<D: ObjectView> ArrayView for BorrowedFields<'_, '_, D> {
impl<D: ObjectView> ArrayView for BorrowedFields<'_, '_, '_, D> {
fn as_value(&self) -> &dyn ValueView {
self
}
@@ -212,7 +212,7 @@ impl<D: ObjectView> ArrayView for BorrowedFields<'_, '_, D> {
}
}
impl<D: ObjectView> ValueView for BorrowedFields<'_, '_, D> {
impl<D: ObjectView> ValueView for BorrowedFields<'_, '_, '_, D> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
@@ -288,11 +288,11 @@ impl<D: ObjectView> ValueView for OwnedFields<'_, D> {
}
}
struct ArraySource<'a, 'map, D: ObjectView> {
s: &'a BorrowedFields<'a, 'map, D>,
struct ArraySource<'a, 'doc, 'map, D: ObjectView> {
s: &'a BorrowedFields<'a, 'doc, 'map, D>,
}
impl<D: ObjectView> fmt::Display for ArraySource<'_, '_, D> {
impl<D: ObjectView> fmt::Display for ArraySource<'_, '_, '_, D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "[")?;
for item in self.s.values() {
@@ -303,11 +303,11 @@ impl<D: ObjectView> fmt::Display for ArraySource<'_, '_, D> {
}
}
struct ArrayRender<'a, 'map, D: ObjectView> {
s: &'a BorrowedFields<'a, 'map, D>,
struct ArrayRender<'a, 'doc, 'map, D: ObjectView> {
s: &'a BorrowedFields<'a, 'doc, 'map, D>,
}
impl<D: ObjectView> fmt::Display for ArrayRender<'_, '_, D> {
impl<D: ObjectView> fmt::Display for ArrayRender<'_, '_, '_, D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for item in self.s.values() {
write!(f, "{}", item.render())?;

View File

@@ -9,12 +9,11 @@ use std::fmt::Debug;
use std::num::NonZeroUsize;
use bumpalo::Bump;
use document::ParseableDocument;
pub(crate) use document::{Document, ParseableDocument};
use error::{NewPromptError, RenderPromptError};
use fields::{BorrowedFields, OwnedFields};
pub use fields::{BorrowedFields, OwnedFields};
use self::context::Context;
use self::document::Document;
pub use self::context::Context;
use crate::fields_ids_map::metadata::FieldIdMapWithMetadata;
use crate::update::del_add::DelAdd;
use crate::GlobalFieldsIdsMap;
@@ -108,8 +107,8 @@ impl Prompt {
}
pub fn render_document<
'a, // lifetime of the borrow of the document
'doc: 'a, // lifetime of the allocator, will live for an entire chunk of documents
'a, // lifetime of the borrow of the document
'doc, // lifetime of the allocator, will live for an entire chunk of documents
>(
&self,
external_docid: &str,

View File

@@ -7,6 +7,7 @@ use roaring::RoaringBitmap;
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::search::new::{distinct_fid, distinct_single_docid};
use crate::search::SemanticSearch;
use crate::vector::SearchQuery;
use crate::{Index, MatchingWords, Result, Search, SearchResult};
struct ScoreWithRatioResult {
@@ -226,12 +227,9 @@ impl Search<'_> {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
}
// no vector search against placeholder search
let Some(query) = search.query.take() else {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
};
// no embedder, no semantic search
let Some(SemanticSearch { vector, embedder_name, embedder, quantized }) = semantic else {
let Some(SemanticSearch { vector, embedder_name, embedder, quantized, media }) = semantic
else {
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
};
@@ -242,9 +240,17 @@ impl Search<'_> {
let span = tracing::trace_span!(target: "search::hybrid", "embed_one");
let _entered = span.enter();
let q = search.query.as_deref();
let media = media.as_ref();
let query = match (q, media) {
(Some(text), None) => SearchQuery::Text(text),
(q, media) => SearchQuery::Media { q, media },
};
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
match embedder.embed_search(&query, Some(deadline)) {
match embedder.embed_search(query, Some(deadline)) {
Ok(embedding) => embedding,
Err(error) => {
tracing::error!(error=%error, "Embedding failed");
@@ -258,8 +264,13 @@ impl Search<'_> {
}
};
search.semantic =
Some(SemanticSearch { vector: Some(vector_query), embedder_name, embedder, quantized });
search.semantic = Some(SemanticSearch {
vector: Some(vector_query),
embedder_name,
embedder,
quantized,
media,
});
// TODO: would be better to have two distinct functions at this point
let vector_results = search.execute()?;

View File

@@ -12,7 +12,7 @@ use self::new::{execute_vector_search, PartialSearchResult, VectorStoreStats};
use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features};
use crate::index::MatchingStrategy;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::Embedder;
use crate::vector::{Embedder, Embedding};
use crate::{
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Error, Index,
Result, SearchContext, TimeBudget, UserError,
@@ -32,6 +32,7 @@ pub mod similar;
#[derive(Debug, Clone)]
pub struct SemanticSearch {
vector: Option<Vec<f32>>,
media: Option<serde_json::Value>,
embedder_name: String,
embedder: Arc<Embedder>,
quantized: bool,
@@ -95,9 +96,10 @@ impl<'a> Search<'a> {
embedder_name: String,
embedder: Arc<Embedder>,
quantized: bool,
vector: Option<Vec<f32>>,
vector: Option<Embedding>,
media: Option<serde_json::Value>,
) -> &mut Search<'a> {
self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector });
self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector, media });
self
}
@@ -238,26 +240,30 @@ impl<'a> Search<'a> {
degraded,
used_negative_operator,
} = match self.semantic.as_ref() {
Some(SemanticSearch { vector: Some(vector), embedder_name, embedder, quantized }) => {
execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
self.exhaustive_number_hits,
self.max_total_hits,
universe,
&self.sort_criteria,
&self.distinct,
self.geo_param,
self.offset,
self.limit,
embedder_name,
embedder,
*quantized,
self.time_budget.clone(),
self.ranking_score_threshold,
)?
}
Some(SemanticSearch {
vector: Some(vector),
embedder_name,
embedder,
quantized,
media: _,
}) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
self.exhaustive_number_hits,
self.max_total_hits,
universe,
&self.sort_criteria,
&self.distinct,
self.geo_param,
self.offset,
self.limit,
embedder_name,
embedder,
*quantized,
self.time_budget.clone(),
self.ranking_score_threshold,
)?,
_ => execute_search(
&mut ctx,
self.query.as_deref(),

View File

@@ -8,7 +8,7 @@ use maplit::{btreemap, hashset};
use crate::progress::Progress;
use crate::update::new::indexer;
use crate::update::{IndexerConfig, Settings};
use crate::vector::EmbeddingConfigs;
use crate::vector::RuntimeEmbedders;
use crate::{db_snap, Criterion, FilterableAttributesRule, Index};
pub const CONTENT: &str = include_str!("../../../../tests/assets/test_set.ndjson");
use crate::constants::RESERVED_GEO_FIELD_NAME;
@@ -55,7 +55,7 @@ pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index {
let db_fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
let mut file = tempfile::tempfile().unwrap();

View File

@@ -32,8 +32,8 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
) -> Result<Self> {
let embedder_index = ctx
.index
.embedder_category_id
.get(ctx.txn, embedder_name)?
.embedding_configs()
.embedder_id(ctx.txn, embedder_name)?
.ok_or_else(|| crate::UserError::InvalidSearchEmbedder(embedder_name.to_owned()))?;
Ok(Self {

View File

@@ -64,10 +64,13 @@ impl<'a> Similar<'a> {
let universe = universe;
let embedder_index =
self.index.embedder_category_id.get(self.rtxn, &self.embedder_name)?.ok_or_else(
|| crate::UserError::InvalidSimilarEmbedder(self.embedder_name.to_owned()),
)?;
let embedder_index = self
.index
.embedding_configs()
.embedder_id(self.rtxn, &self.embedder_name)?
.ok_or_else(|| {
crate::UserError::InvalidSimilarEmbedder(self.embedder_name.to_owned())
})?;
let reader = ArroyWrapper::new(self.index.vector_arroy, embedder_index, self.quantized);
let results = reader.nns_by_item(

View File

@@ -18,7 +18,7 @@ use crate::update::{
self, IndexDocumentsConfig, IndexDocumentsMethod, IndexerConfig, Setting, Settings,
};
use crate::vector::settings::{EmbedderSource, EmbeddingSettings};
use crate::vector::EmbeddingConfigs;
use crate::vector::RuntimeEmbedders;
use crate::{db_snap, obkv_to_json, Filter, FilterableAttributesRule, Index, Search, SearchResult};
pub(crate) struct TempIndex {
@@ -66,7 +66,7 @@ impl TempIndex {
let db_fields_ids_map = self.inner.fields_ids_map(&rtxn)?;
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.embedding_configs;
let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.runtime_embedders;
let mut indexer = indexer::DocumentOperation::new();
match self.index_documents_config.update_method {
IndexDocumentsMethod::ReplaceDocuments => {
@@ -151,7 +151,7 @@ impl TempIndex {
let db_fields_ids_map = self.inner.fields_ids_map(&rtxn)?;
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.embedding_configs;
let embedders = InnerIndexSettings::from_index(&self.inner, &rtxn, None)?.runtime_embedders;
let mut indexer = indexer::DocumentOperation::new();
let external_document_ids: Vec<_> =
@@ -223,7 +223,7 @@ fn aborting_indexation() {
let db_fields_ids_map = index.inner.fields_ids_map(&rtxn).unwrap();
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
let payload = documents!([
{ "id": 1, "name": "kevin" },

View File

@@ -1,7 +1,7 @@
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use rayon::{ThreadPool, ThreadPoolBuilder};
use rayon::{BroadcastContext, ThreadPool, ThreadPoolBuilder};
use thiserror::Error;
/// A rayon ThreadPool wrapper that can catch panics in the pool
@@ -32,6 +32,22 @@ impl ThreadPoolNoAbort {
}
}
pub fn broadcast<OP, R>(&self, op: OP) -> Result<Vec<R>, PanicCatched>
where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
self.active_operations.fetch_add(1, Ordering::Relaxed);
let output = self.thread_pool.broadcast(op);
self.active_operations.fetch_sub(1, Ordering::Relaxed);
// While reseting the pool panic catcher we return an error if we catched one.
if self.pool_catched_panic.swap(false, Ordering::SeqCst) {
Err(PanicCatched)
} else {
Ok(output)
}
}
pub fn current_num_threads(&self) -> usize {
self.thread_pool.current_num_threads()
}

View File

@@ -64,11 +64,7 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
// Remove all user-provided bits from the configs
let mut configs = self.index.embedding_configs(self.wtxn)?;
for config in configs.iter_mut() {
config.user_provided.clear();
}
self.index.put_embedding_configs(self.wtxn, configs)?;
self.index.embedding_configs().clear_embedder_info_docids(self.wtxn)?;
// Clear the other databases.
external_documents_ids.clear(self.wtxn)?;

View File

@@ -23,16 +23,17 @@ use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, Extra
use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
use self::extract_geo_points::extract_geo_points;
use self::extract_vector_points::{
extract_embeddings, extract_vector_points, ExtractedVectorPoints,
extract_embeddings_from_prompts, extract_vector_points, ExtractedVectorPoints,
};
use self::extract_word_docids::extract_word_docids;
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
use self::extract_word_position_docids::extract_word_position_docids;
use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters};
use super::{helpers, TypedChunk};
use crate::index::IndexEmbeddingConfig;
use crate::progress::EmbedderStats;
use crate::update::index_documents::extract::extract_vector_points::extract_embeddings_from_fragments;
use crate::update::settings::InnerIndexSettingsDiff;
use crate::vector::db::EmbedderInfo;
use crate::vector::error::PossibleEmbeddingMistakes;
use crate::{FieldId, Result, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
@@ -46,9 +47,9 @@ pub(crate) fn data_from_obkv_documents(
indexer: GrenadParameters,
lmdb_writer_sx: Sender<Result<TypedChunk>>,
primary_key_id: FieldId,
embedders_configs: Arc<Vec<IndexEmbeddingConfig>>,
settings_diff: Arc<InnerIndexSettingsDiff>,
max_positions_per_attributes: Option<u32>,
embedder_info: Arc<Vec<(String, EmbedderInfo)>>,
possible_embedding_mistakes: Arc<PossibleEmbeddingMistakes>,
embedder_stats: &Arc<EmbedderStats>,
) -> Result<()> {
@@ -61,8 +62,8 @@ pub(crate) fn data_from_obkv_documents(
original_documents_chunk,
indexer,
lmdb_writer_sx.clone(),
embedders_configs.clone(),
settings_diff.clone(),
embedder_info.clone(),
possible_embedding_mistakes.clone(),
embedder_stats.clone(),
)
@@ -213,7 +214,7 @@ fn run_extraction_task<FE, FS, M>(
})
}
fn request_threads() -> &'static ThreadPoolNoAbort {
pub fn request_threads() -> &'static ThreadPoolNoAbort {
static REQUEST_THREADS: OnceLock<ThreadPoolNoAbort> = OnceLock::new();
REQUEST_THREADS.get_or_init(|| {
@@ -231,8 +232,8 @@ fn send_original_documents_data(
original_documents_chunk: Result<grenad::Reader<BufReader<File>>>,
indexer: GrenadParameters,
lmdb_writer_sx: Sender<Result<TypedChunk>>,
embedders_configs: Arc<Vec<IndexEmbeddingConfig>>,
settings_diff: Arc<InnerIndexSettingsDiff>,
embedder_info: Arc<Vec<(String, EmbedderInfo)>>,
possible_embedding_mistakes: Arc<PossibleEmbeddingMistakes>,
embedder_stats: Arc<EmbedderStats>,
) -> Result<()> {
@@ -241,11 +242,10 @@ fn send_original_documents_data(
let index_vectors = (settings_diff.reindex_vectors() || !settings_diff.settings_update_only())
// no point in indexing vectors without embedders
&& (!settings_diff.new.embedding_configs.inner_as_ref().is_empty());
&& (!settings_diff.new.runtime_embedders.inner_as_ref().is_empty());
if index_vectors {
let settings_diff = settings_diff.clone();
let embedders_configs = embedders_configs.clone();
let original_documents_chunk = original_documents_chunk.clone();
let lmdb_writer_sx = lmdb_writer_sx.clone();
@@ -253,8 +253,8 @@ fn send_original_documents_data(
match extract_vector_points(
original_documents_chunk.clone(),
indexer,
&embedders_configs,
&settings_diff,
embedder_info.as_slice(),
&possible_embedding_mistakes,
) {
Ok((extracted_vectors, unused_vectors_distribution)) => {
@@ -262,16 +262,16 @@ fn send_original_documents_data(
manual_vectors,
remove_vectors,
prompts,
inputs,
embedder_name,
embedder,
add_to_user_provided,
remove_from_user_provided,
runtime,
embedding_status_delta,
} in extracted_vectors
{
let embeddings = match extract_embeddings(
let embeddings_from_prompts = match extract_embeddings_from_prompts(
prompts,
indexer,
embedder.clone(),
runtime.clone(),
&embedder_name,
&possible_embedding_mistakes,
&embedder_stats,
@@ -284,18 +284,37 @@ fn send_original_documents_data(
None
}
};
let embeddings_from_fragments = match extract_embeddings_from_fragments(
inputs,
indexer,
runtime.clone(),
&embedder_name,
&possible_embedding_mistakes,
&embedder_stats,
&unused_vectors_distribution,
request_threads(),
) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx.send(Err(error));
None
}
};
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().is_none_or(|e| e.is_empty()))
&& embeddings_from_prompts.as_ref().is_none_or(|e| e.is_empty())
&& embeddings_from_fragments.as_ref().is_none_or(|e| e.is_empty()))
{
let _ = lmdb_writer_sx.send(Ok(TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension: embedder.dimensions(),
embeddings_from_prompts,
embeddings_from_fragments,
expected_dimension: runtime.embedder.dimensions(),
manual_vectors,
embedder_name,
add_to_user_provided,
remove_from_user_provided,
embedding_status_delta,
}));
}
}

View File

@@ -12,6 +12,7 @@ use std::sync::Arc;
use crossbeam_channel::{Receiver, Sender};
use enrich::enrich_documents_batch;
pub use extract::request_threads;
use grenad::{Merger, MergerBuilder};
use hashbrown::HashMap;
use heed::types::Str;
@@ -37,7 +38,8 @@ pub use crate::update::index_documents::helpers::CursorClonableMmap;
use crate::update::{
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
};
use crate::vector::{ArroyWrapper, EmbeddingConfigs};
use crate::vector::db::EmbedderInfo;
use crate::vector::{ArroyWrapper, RuntimeEmbedders};
use crate::{CboRoaringBitmapCodec, Index, Result, UserError};
static MERGED_DATABASE_COUNT: usize = 7;
@@ -80,7 +82,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
should_abort: FA,
added_documents: u64,
deleted_documents: u64,
embedders: EmbeddingConfigs,
embedders: RuntimeEmbedders,
embedder_stats: &'t Arc<EmbedderStats>,
}
@@ -171,7 +173,7 @@ where
Ok((self, Ok(indexed_documents)))
}
pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self {
pub fn with_embedders(mut self, embedders: RuntimeEmbedders) -> Self {
self.embedders = embedders;
self
}
@@ -225,7 +227,13 @@ where
settings_diff.new.recompute_searchables(self.wtxn, self.index)?;
let settings_diff = Arc::new(settings_diff);
let embedders_configs = Arc::new(self.index.embedding_configs(self.wtxn)?);
let embedder_infos: heed::Result<Vec<(String, EmbedderInfo)>> = self
.index
.embedding_configs()
.iter_embedder_info(self.wtxn)?
.map(|res| res.map(|(name, info)| (name.to_owned(), info)))
.collect();
let embedder_infos = Arc::new(embedder_infos?);
let possible_embedding_mistakes =
crate::vector::error::PossibleEmbeddingMistakes::new(&field_distribution);
@@ -327,9 +335,9 @@ where
pool_params,
lmdb_writer_sx.clone(),
primary_key_id,
embedders_configs.clone(),
settings_diff_cloned,
max_positions_per_attributes,
embedder_infos,
Arc::new(possible_embedding_mistakes),
&embedder_stats
)
@@ -429,21 +437,21 @@ where
TypedChunk::VectorPoints {
expected_dimension,
remove_vectors,
embeddings,
embeddings_from_prompts,
embeddings_from_fragments,
manual_vectors,
embedder_name,
add_to_user_provided,
remove_from_user_provided,
embedding_status_delta,
} => {
dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints {
remove_vectors,
embeddings,
embeddings_from_prompts,
embeddings_from_fragments,
expected_dimension,
manual_vectors,
embedder_name,
add_to_user_provided,
remove_from_user_provided,
embedding_status_delta,
}
}
otherwise => otherwise,
@@ -479,7 +487,7 @@ where
// we should insert it in `dimension`
for (name, action) in settings_diff.embedding_config_updates.iter() {
if action.is_being_quantized && !dimension.contains_key(name.as_str()) {
let index = self.index.embedder_category_id.get(self.wtxn, name)?.ok_or(
let index = self.index.embedding_configs().embedder_id(self.wtxn, name)?.ok_or(
InternalError::DatabaseMissingEntry {
db_name: "embedder_category_id",
key: None,
@@ -487,7 +495,9 @@ where
)?;
let reader =
ArroyWrapper::new(self.index.vector_arroy, index, action.was_quantized);
let dim = reader.dimensions(self.wtxn)?;
let Some(dim) = reader.dimensions(self.wtxn)? else {
continue;
};
dimension.insert(name.to_string(), dim);
}
}
@@ -497,12 +507,19 @@ where
let vector_arroy = self.index.vector_arroy;
let cancel = &self.should_abort;
let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
let embedder_index =
self.index.embedding_configs().embedder_id(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry {
db_name: "embedder_category_id",
key: None,
},
)?;
let embedder_config = settings_diff.embedding_config_updates.get(&embedder_name);
let was_quantized =
settings_diff.old.embedding_configs.get(&embedder_name).is_some_and(|conf| conf.2);
let was_quantized = settings_diff
.old
.runtime_embedders
.get(&embedder_name)
.is_some_and(|conf| conf.is_quantized);
let is_quantizing = embedder_config.is_some_and(|action| action.is_being_quantized);
pool.install(|| {
@@ -772,11 +789,11 @@ mod tests {
use crate::constants::RESERVED_GEO_FIELD_NAME;
use crate::documents::mmap_from_objects;
use crate::index::tests::TempIndex;
use crate::index::IndexEmbeddingConfig;
use crate::progress::Progress;
use crate::search::TermsMatchingStrategy;
use crate::update::new::indexer;
use crate::update::Setting;
use crate::vector::db::IndexEmbeddingConfig;
use crate::{all_obkv_to_json, db_snap, Filter, FilterableAttributesRule, Search, UserError};
#[test]
@@ -2027,7 +2044,7 @@ mod tests {
new_fields_ids_map,
primary_key,
&document_changes,
EmbeddingConfigs::default(),
RuntimeEmbedders::default(),
&|| false,
&Progress::default(),
&Default::default(),
@@ -2115,7 +2132,7 @@ mod tests {
new_fields_ids_map,
primary_key,
&document_changes,
EmbeddingConfigs::default(),
RuntimeEmbedders::default(),
&|| false,
&Progress::default(),
&Default::default(),
@@ -2276,7 +2293,7 @@ mod tests {
]);
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.replace_documents(&documents).unwrap();
indexer.delete_documents(&["2"]);
@@ -2342,7 +2359,7 @@ mod tests {
indexer.delete_documents(&["1", "2"]);
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let (document_changes, _operation_stats, primary_key) = indexer
.into_changes(
&indexer_alloc,
@@ -2393,7 +2410,7 @@ mod tests {
{ "id": 3, "name": "jean", "age": 25 },
]);
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.update_documents(&documents).unwrap();
@@ -2445,7 +2462,7 @@ mod tests {
{ "id": 3, "legs": 4 },
]);
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.update_documents(&documents).unwrap();
indexer.delete_documents(&["1", "2"]);
@@ -2495,7 +2512,7 @@ mod tests {
let mut new_fields_ids_map = db_fields_ids_map.clone();
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.delete_documents(&["1", "2"]);
@@ -2551,7 +2568,7 @@ mod tests {
let mut new_fields_ids_map = db_fields_ids_map.clone();
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.delete_documents(&["1", "2", "1", "2"]);
@@ -2610,7 +2627,7 @@ mod tests {
let mut new_fields_ids_map = db_fields_ids_map.clone();
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
let documents = documents!([
@@ -2660,7 +2677,7 @@ mod tests {
let mut new_fields_ids_map = db_fields_ids_map.clone();
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.delete_documents(&["1"]);
@@ -2774,6 +2791,8 @@ mod tests {
document_template: Setting::NotSet,
document_template_max_bytes: Setting::NotSet,
url: Setting::NotSet,
indexing_fragments: Setting::NotSet,
search_fragments: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
distribution: Setting::NotSet,
@@ -2800,17 +2819,27 @@ mod tests {
.unwrap();
let rtxn = index.read_txn().unwrap();
let mut embedding_configs = index.embedding_configs(&rtxn).unwrap();
let IndexEmbeddingConfig { name: embedder_name, config: embedder, user_provided } =
let embedders = index.embedding_configs();
let mut embedding_configs = embedders.embedding_configs(&rtxn).unwrap();
let IndexEmbeddingConfig { name: embedder_name, config: embedder, fragments } =
embedding_configs.pop().unwrap();
let info = embedders.embedder_info(&rtxn, &embedder_name).unwrap().unwrap();
insta::assert_snapshot!(info.embedder_id, @"0");
insta::assert_debug_snapshot!(info.embedding_status.user_provided_docids(), @"RoaringBitmap<[0, 1, 2]>");
insta::assert_debug_snapshot!(info.embedding_status.skip_regenerate_docids(), @"RoaringBitmap<[0, 1, 2]>");
insta::assert_snapshot!(embedder_name, @"manual");
insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[0, 1, 2]>");
insta::assert_debug_snapshot!(fragments, @r###"
FragmentConfigs(
[],
)
"###);
let embedder = std::sync::Arc::new(
crate::vector::Embedder::new(embedder.embedder_options, 0).unwrap(),
);
let res = index
.search(&rtxn)
.semantic(embedder_name, embedder, false, Some([0.0, 1.0, 2.0].to_vec()))
.semantic(embedder_name, embedder, false, Some([0.0, 1.0, 2.0].to_vec()), None)
.execute()
.unwrap();
assert_eq!(res.documents_ids.len(), 3);
@@ -2859,7 +2888,7 @@ mod tests {
let mut new_fields_ids_map = db_fields_ids_map.clone();
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
// OP
@@ -2920,7 +2949,7 @@ mod tests {
let mut new_fields_ids_map = db_fields_ids_map.clone();
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.delete_documents(&["1"]);
@@ -2979,7 +3008,7 @@ mod tests {
let mut new_fields_ids_map = db_fields_ids_map.clone();
let indexer_alloc = Bump::new();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
let documents = documents!([

View File

@@ -31,7 +31,7 @@ use crate::update::index_documents::GrenadParameters;
use crate::update::settings::{InnerIndexSettings, InnerIndexSettingsDiff};
use crate::update::{AvailableIds, UpdateIndexingStep};
use crate::vector::parsed_vectors::{ExplicitVectors, VectorOrArrayOfVectors};
use crate::vector::settings::WriteBackToDocuments;
use crate::vector::settings::{RemoveFragments, WriteBackToDocuments};
use crate::vector::ArroyWrapper;
use crate::{FieldDistribution, FieldId, FieldIdMapMissingEntry, Index, Result};
@@ -933,10 +933,47 @@ impl<'a, 'i> Transform<'a, 'i> {
// delete all vectors from the embedders that need removal
for (_, (reader, _)) in readers {
let dimensions = reader.dimensions(wtxn)?;
let Some(dimensions) = reader.dimensions(wtxn)? else {
continue;
};
reader.clear(wtxn, dimensions)?;
}
// remove all vectors for the specified fragments
for (embedder_name, RemoveFragments { fragment_ids }, was_quantized) in
settings_diff.embedding_config_updates.iter().filter_map(|(name, action)| {
action.remove_fragments().map(|fragments| (name, fragments, action.was_quantized))
})
{
let Some(infos) = self.index.embedding_configs().embedder_info(wtxn, embedder_name)?
else {
continue;
};
let arroy =
ArroyWrapper::new(self.index.vector_arroy, infos.embedder_id, was_quantized);
let Some(dimensions) = arroy.dimensions(wtxn)? else {
continue;
};
for fragment_id in fragment_ids {
// we must keep the user provided embeddings that ended up in this store
if infos.embedding_status.user_provided_docids().is_empty() {
// no user provided: clear store
arroy.clear_store(wtxn, *fragment_id, dimensions)?;
continue;
}
// some user provided, remove only the ids that are not user provided
let to_delete = arroy.items_in_store(wtxn, *fragment_id, |items| {
items - infos.embedding_status.user_provided_docids()
})?;
for to_delete in to_delete {
arroy.del_item_in_store(wtxn, to_delete, *fragment_id, dimensions)?;
}
}
}
let grenad_params = GrenadParameters {
chunk_compression_type: self.indexer_settings.chunk_compression_type,
chunk_compression_level: self.indexer_settings.chunk_compression_level,

View File

@@ -4,6 +4,7 @@ use std::fs::File;
use std::io::{self, BufReader};
use bytemuck::allocation::pod_collect_to_vec;
use byteorder::{BigEndian, ReadBytesExt as _};
use grenad::{MergeFunction, Merger, MergerBuilder};
use heed::types::Bytes;
use heed::{BytesDecode, RwTxn};
@@ -18,7 +19,6 @@ use super::helpers::{
use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind};
use crate::facet::FacetType;
use crate::index::db_name::DOCUMENTS;
use crate::index::IndexEmbeddingConfig;
use crate::proximity::MAX_DISTANCE;
use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd};
use crate::update::facet::FacetsUpdate;
@@ -26,6 +26,7 @@ use crate::update::index_documents::helpers::{
as_cloneable_grenad, try_split_array_at, KeepLatestObkv,
};
use crate::update::settings::InnerIndexSettingsDiff;
use crate::vector::db::{EmbeddingStatusDelta, IndexEmbeddingConfig};
use crate::vector::ArroyWrapper;
use crate::{
lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, FieldId, GeoPoint, Index, InternalError,
@@ -86,12 +87,14 @@ pub(crate) enum TypedChunk {
GeoPoints(grenad::Reader<BufReader<File>>),
VectorPoints {
remove_vectors: grenad::Reader<BufReader<File>>,
embeddings: Option<grenad::Reader<BufReader<File>>>,
// docid -> vector
embeddings_from_prompts: Option<grenad::Reader<BufReader<File>>>,
// docid, extractor_id -> Option<vector>,
embeddings_from_fragments: Option<grenad::Reader<BufReader<File>>>,
expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
add_to_user_provided: RoaringBitmap,
remove_from_user_provided: RoaringBitmap,
embedding_status_delta: EmbeddingStatusDelta,
},
}
@@ -155,6 +158,7 @@ pub(crate) fn write_typed_chunk_into_index(
let mut iter = merger.into_stream_merger_iter()?;
let embedders: BTreeSet<_> = index
.embedding_configs()
.embedding_configs(wtxn)?
.into_iter()
.map(|IndexEmbeddingConfig { name, .. }| name)
@@ -614,57 +618,66 @@ pub(crate) fn write_typed_chunk_into_index(
let span = tracing::trace_span!(target: "indexing::write_db", "vector_points");
let _entered = span.enter();
let embedders = index.embedding_configs();
let mut remove_vectors_builder = MergerBuilder::new(KeepFirst);
let mut manual_vectors_builder = MergerBuilder::new(KeepFirst);
let mut embeddings_builder = MergerBuilder::new(KeepFirst);
let mut add_to_user_provided = RoaringBitmap::new();
let mut remove_from_user_provided = RoaringBitmap::new();
let mut embeddings_from_prompts_builder = MergerBuilder::new(KeepFirst);
let mut embeddings_from_fragments_builder = MergerBuilder::new(KeepFirst);
let mut params = None;
let mut infos = None;
for typed_chunk in typed_chunks {
let TypedChunk::VectorPoints {
remove_vectors,
manual_vectors,
embeddings,
embeddings_from_prompts,
embeddings_from_fragments,
expected_dimension,
embedder_name,
add_to_user_provided: aud,
remove_from_user_provided: rud,
embedding_status_delta,
} = typed_chunk
else {
unreachable!();
};
if infos.is_none() {
infos = Some(embedders.embedder_info(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry {
db_name: "embedder_category_id",
key: None,
},
)?);
}
params = Some((expected_dimension, embedder_name));
remove_vectors_builder.push(remove_vectors.into_cursor()?);
manual_vectors_builder.push(manual_vectors.into_cursor()?);
if let Some(embeddings) = embeddings {
embeddings_builder.push(embeddings.into_cursor()?);
if let Some(embeddings) = embeddings_from_prompts {
embeddings_from_prompts_builder.push(embeddings.into_cursor()?);
}
if let Some(embeddings) = embeddings_from_fragments {
embeddings_from_fragments_builder.push(embeddings.into_cursor()?);
}
if let Some(infos) = &mut infos {
embedding_status_delta.apply_to(&mut infos.embedding_status);
}
add_to_user_provided |= aud;
remove_from_user_provided |= rud;
}
// typed chunks has always at least 1 chunk.
let Some((expected_dimension, embedder_name)) = params else { unreachable!() };
let Some(infos) = infos else { unreachable!() };
let mut embedding_configs = index.embedding_configs(wtxn)?;
let index_embedder_config = embedding_configs
.iter_mut()
.find(|IndexEmbeddingConfig { name, .. }| name == &embedder_name)
.unwrap();
index_embedder_config.user_provided -= remove_from_user_provided;
index_embedder_config.user_provided |= add_to_user_provided;
embedders.put_embedder_info(wtxn, &embedder_name, &infos)?;
index.put_embedding_configs(wtxn, embedding_configs)?;
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
let binary_quantized =
settings_diff.old.embedding_configs.get(&embedder_name).is_some_and(|conf| conf.2);
let binary_quantized = settings_diff
.old
.runtime_embedders
.get(&embedder_name)
.is_some_and(|conf| conf.is_quantized);
// FIXME: allow customizing distance
let writer = ArroyWrapper::new(index.vector_arroy, embedder_index, binary_quantized);
let writer = ArroyWrapper::new(index.vector_arroy, infos.embedder_id, binary_quantized);
// remove vectors for docids we want them removed
let merger = remove_vectors_builder.build();
@@ -674,8 +687,8 @@ pub(crate) fn write_typed_chunk_into_index(
writer.del_items(wtxn, expected_dimension, docid)?;
}
// add generated embeddings
let merger = embeddings_builder.build();
// add generated embeddings -- from prompts
let merger = embeddings_from_prompts_builder.build();
let mut iter = merger.into_stream_merger_iter()?;
while let Some((key, value)) = iter.next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
@@ -702,6 +715,24 @@ pub(crate) fn write_typed_chunk_into_index(
writer.add_items(wtxn, docid, &embeddings)?;
}
// add generated embeddings -- from fragments
let merger = embeddings_from_fragments_builder.build();
let mut iter = merger.into_stream_merger_iter()?;
while let Some((mut key, value)) = iter.next()? {
let docid = key.read_u32::<BigEndian>().unwrap();
let extractor_id = key.read_u8().unwrap();
if value.is_empty() {
writer.del_item_in_store(wtxn, docid, extractor_id, expected_dimension)?;
} else {
let data = pod_collect_to_vec(value);
// it is a code error to have embeddings and not expected_dimension
if data.len() != expected_dimension {
panic!("wrong dimensions")
}
writer.add_item_in_store(wtxn, docid, extractor_id, &data)?;
}
}
// perform the manual diff
let merger = manual_vectors_builder.build();
let mut iter = merger.into_stream_merger_iter()?;

View File

@@ -4,7 +4,7 @@ pub use self::clear_documents::ClearDocuments;
pub use self::concurrent_available_ids::ConcurrentAvailableIds;
pub use self::facet::bulk::FacetsUpdateBulk;
pub use self::facet::incremental::FacetsUpdateIncrementalInner;
pub use self::index_documents::*;
pub use self::index_documents::{request_threads, *};
pub use self::indexer_config::{default_thread_pool_and_threads, IndexerConfig};
pub use self::new::ChannelCongestion;
pub use self::settings::{validate_embedding_settings, Setting, Settings};

View File

@@ -138,6 +138,7 @@ pub enum ReceiverAction {
WakeUp,
LargeEntry(LargeEntry),
LargeVectors(LargeVectors),
LargeVector(LargeVector),
}
/// An entry that cannot fit in the BBQueue buffers has been
@@ -174,6 +175,24 @@ impl LargeVectors {
}
}
#[derive(Debug)]
pub struct LargeVector {
/// The document id associated to the large embedding.
pub docid: DocumentId,
/// The embedder id in which to insert the large embedding.
pub embedder_id: u8,
/// The extractor id in which to insert the large embedding.
pub extractor_id: u8,
/// The large embedding that must be written.
pub embedding: Mmap,
}
impl LargeVector {
pub fn read_embedding(&self, dimensions: usize) -> &[f32] {
self.embedding.chunks_exact(dimensions).map(bytemuck::cast_slice).next().unwrap()
}
}
impl<'a> WriterBbqueueReceiver<'a> {
/// Tries to receive an action to do until the timeout occurs
/// and if it does, consider it as a spurious wake up.
@@ -238,6 +257,7 @@ pub enum EntryHeader {
DbOperation(DbOperation),
ArroyDeleteVector(ArroyDeleteVector),
ArroySetVectors(ArroySetVectors),
ArroySetVector(ArroySetVector),
}
impl EntryHeader {
@@ -250,6 +270,7 @@ impl EntryHeader {
EntryHeader::DbOperation(_) => 0,
EntryHeader::ArroyDeleteVector(_) => 1,
EntryHeader::ArroySetVectors(_) => 2,
EntryHeader::ArroySetVector(_) => 3,
}
}
@@ -274,11 +295,17 @@ impl EntryHeader {
Self::variant_size() + mem::size_of::<ArroySetVectors>() + embedding_size * count
}
fn total_set_vector_size(dimensions: usize) -> usize {
let embedding_size = dimensions * mem::size_of::<f32>();
Self::variant_size() + mem::size_of::<ArroySetVector>() + embedding_size
}
fn header_size(&self) -> usize {
let payload_size = match self {
EntryHeader::DbOperation(op) => mem::size_of_val(op),
EntryHeader::ArroyDeleteVector(adv) => mem::size_of_val(adv),
EntryHeader::ArroySetVectors(asvs) => mem::size_of_val(asvs),
EntryHeader::ArroySetVector(asv) => mem::size_of_val(asv),
};
Self::variant_size() + payload_size
}
@@ -301,6 +328,11 @@ impl EntryHeader {
let header = checked::pod_read_unaligned(header_bytes);
EntryHeader::ArroySetVectors(header)
}
3 => {
let header_bytes = &remaining[..mem::size_of::<ArroySetVector>()];
let header = checked::pod_read_unaligned(header_bytes);
EntryHeader::ArroySetVector(header)
}
id => panic!("invalid variant id: {id}"),
}
}
@@ -311,6 +343,7 @@ impl EntryHeader {
EntryHeader::DbOperation(op) => bytemuck::bytes_of(op),
EntryHeader::ArroyDeleteVector(adv) => bytemuck::bytes_of(adv),
EntryHeader::ArroySetVectors(asvs) => bytemuck::bytes_of(asvs),
EntryHeader::ArroySetVector(asv) => bytemuck::bytes_of(asv),
};
*first = self.variant_id();
remaining.copy_from_slice(payload_bytes);
@@ -379,6 +412,37 @@ impl ArroySetVectors {
}
}
#[derive(Debug, Clone, Copy, NoUninit, CheckedBitPattern)]
#[repr(C)]
/// The embeddings are in the remaining space and represents
/// non-aligned [f32] each with dimensions f32s.
pub struct ArroySetVector {
pub docid: DocumentId,
pub embedder_id: u8,
pub extractor_id: u8,
_padding: [u8; 2],
}
impl ArroySetVector {
fn embeddings_bytes<'a>(frame: &'a FrameGrantR<'_>) -> &'a [u8] {
let skip = EntryHeader::variant_size() + mem::size_of::<Self>();
&frame[skip..]
}
/// Read the embedding and write it into an aligned `f32` Vec.
pub fn read_all_embeddings_into_vec<'v>(
&self,
frame: &FrameGrantR<'_>,
vec: &'v mut Vec<f32>,
) -> &'v [f32] {
let embeddings_bytes = Self::embeddings_bytes(frame);
let embeddings_count = embeddings_bytes.len() / mem::size_of::<f32>();
vec.resize(embeddings_count, 0.0);
bytemuck::cast_slice_mut(vec.as_mut()).copy_from_slice(embeddings_bytes);
&vec[..]
}
}
#[derive(Debug, Clone, Copy, NoUninit, CheckedBitPattern)]
#[repr(u16)]
pub enum Database {
@@ -398,6 +462,7 @@ pub enum Database {
FacetIdStringDocids,
FieldIdDocidFacetStrings,
FieldIdDocidFacetF64s,
VectorEmbedderCategoryId,
}
impl Database {
@@ -419,6 +484,7 @@ impl Database {
Database::FacetIdStringDocids => index.facet_id_string_docids.remap_types(),
Database::FieldIdDocidFacetStrings => index.field_id_docid_facet_strings.remap_types(),
Database::FieldIdDocidFacetF64s => index.field_id_docid_facet_f64s.remap_types(),
Database::VectorEmbedderCategoryId => index.embedder_category_id.remap_types(),
}
}
@@ -440,6 +506,7 @@ impl Database {
Database::FacetIdStringDocids => db_name::FACET_ID_STRING_DOCIDS,
Database::FieldIdDocidFacetStrings => db_name::FIELD_ID_DOCID_FACET_STRINGS,
Database::FieldIdDocidFacetF64s => db_name::FIELD_ID_DOCID_FACET_F64S,
Database::VectorEmbedderCategoryId => db_name::VECTOR_EMBEDDER_CATEGORY_ID,
}
}
}
@@ -568,6 +635,82 @@ impl<'b> ExtractorBbqueueSender<'b> {
Ok(())
}
fn set_vector_for_extractor(
&self,
docid: u32,
embedder_id: u8,
extractor_id: u8,
embedding: Option<Embedding>,
) -> crate::Result<()> {
let max_grant = self.max_grant;
let refcell = self.producers.get().unwrap();
let mut producer = refcell.0.borrow_mut_or_yield();
// If there are no vectors we specify the dimensions
// to zero to allocate no extra space at all
let dimensions = embedding.as_ref().map_or(0, |emb| emb.len());
let arroy_set_vector =
ArroySetVector { docid, embedder_id, extractor_id, _padding: [0; 2] };
let payload_header = EntryHeader::ArroySetVector(arroy_set_vector);
let total_length = EntryHeader::total_set_vector_size(dimensions);
if total_length > max_grant {
let mut value_file = tempfile::tempfile().map(BufWriter::new)?;
let embedding = embedding.expect("set_vector without a vector does not fit in RAM");
let mut embedding_bytes = bytemuck::cast_slice(&embedding);
io::copy(&mut embedding_bytes, &mut value_file)?;
let value_file = value_file.into_inner().map_err(|ie| ie.into_error())?;
let embedding = unsafe { Mmap::map(&value_file)? };
let large_vectors = LargeVector { docid, embedder_id, extractor_id, embedding };
self.sender.send(ReceiverAction::LargeVector(large_vectors)).unwrap();
return Ok(());
}
// Spin loop to have a frame the size we requested.
reserve_and_write_grant(
&mut producer,
total_length,
&self.sender,
&self.sent_messages_attempts,
&self.blocking_sent_messages_attempts,
|grant| {
let header_size = payload_header.header_size();
let (header_bytes, remaining) = grant.split_at_mut(header_size);
payload_header.serialize_into(header_bytes);
if dimensions != 0 {
let output_iter =
remaining.chunks_exact_mut(dimensions * mem::size_of::<f32>());
for (embedding, output) in embedding.iter().zip(output_iter) {
output.copy_from_slice(bytemuck::cast_slice(embedding));
}
}
Ok(())
},
)?;
Ok(())
}
fn embedding_status(
&self,
name: &str,
infos: crate::vector::db::EmbedderInfo,
) -> crate::Result<()> {
let bytes = infos.to_bytes().map_err(|_| {
InternalError::Serialization(crate::SerializationError::Encoding {
db_name: Some(Database::VectorEmbedderCategoryId.database_name()),
})
})?;
self.write_key_value(Database::VectorEmbedderCategoryId, name.as_bytes(), &bytes)
}
fn write_key_value(&self, database: Database, key: &[u8], value: &[u8]) -> crate::Result<()> {
let key_length = key.len().try_into().ok().and_then(NonZeroU16::new).ok_or_else(|| {
InternalError::StorePut {
@@ -942,9 +1085,18 @@ impl EmbeddingSender<'_, '_> {
&self,
docid: DocumentId,
embedder_id: u8,
embedding: Embedding,
extractor_id: u8,
embedding: Option<Embedding>,
) -> crate::Result<()> {
self.0.set_vectors(docid, embedder_id, &[embedding])
self.0.set_vector_for_extractor(docid, embedder_id, extractor_id, embedding)
}
pub(crate) fn embedding_status(
&self,
name: &str,
infos: crate::vector::db::EmbedderInfo,
) -> crate::Result<()> {
self.0.embedding_status(name, infos)
}
}

View File

@@ -12,6 +12,7 @@ use super::vector_document::VectorDocument;
use super::{KvReaderFieldId, KvWriterFieldId};
use crate::constants::{RESERVED_GEO_FIELD_NAME, RESERVED_VECTORS_FIELD_NAME};
use crate::documents::FieldIdMapper;
use crate::update::del_add::KvReaderDelAdd;
use crate::update::new::thread_local::{FullySend, MostlySend, ThreadLocal};
use crate::update::new::vector_document::VectorDocumentFromDb;
use crate::vector::settings::EmbedderAction;
@@ -469,6 +470,110 @@ impl<'doc> Versions<'doc> {
}
}
#[derive(Debug)]
pub struct KvDelAddDocument<'a, Mapper: FieldIdMapper> {
document: &'a obkv::KvReaderU16,
side: crate::update::del_add::DelAdd,
fields_ids_map: &'a Mapper,
}
impl<'a, Mapper: FieldIdMapper> KvDelAddDocument<'a, Mapper> {
pub fn new(
document: &'a obkv::KvReaderU16,
side: crate::update::del_add::DelAdd,
fields_ids_map: &'a Mapper,
) -> Self {
Self { document, side, fields_ids_map }
}
fn get(&self, k: &str) -> Result<Option<&'a RawValue>> {
let Some(id) = self.fields_ids_map.id(k) else { return Ok(None) };
let Some(value) = self.document.get(id) else { return Ok(None) };
let Some(value) = KvReaderDelAdd::from_slice(value).get(self.side) else { return Ok(None) };
let value = serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?;
Ok(Some(value))
}
}
impl<'a, Mapper: FieldIdMapper> Document<'a> for KvDelAddDocument<'a, Mapper> {
fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<(&'a str, &'a RawValue)>> {
let mut it = self.document.iter();
std::iter::from_fn(move || loop {
let (fid, value) = it.next()?;
let Some(value) = KvReaderDelAdd::from_slice(value).get(self.side) else {
continue;
};
let name = match self.fields_ids_map.name(fid).ok_or(
InternalError::FieldIdMapMissingEntry(crate::FieldIdMapMissingEntry::FieldId {
field_id: fid,
process: "getting current document",
}),
) {
Ok(name) => name,
Err(error) => return Some(Err(error.into())),
};
if name == RESERVED_VECTORS_FIELD_NAME || name == RESERVED_GEO_FIELD_NAME {
continue;
}
let res = (|| {
let value =
serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?;
Ok((name, value))
})();
return Some(res);
})
}
fn top_level_fields_count(&self) -> usize {
let mut it = self.document.iter();
std::iter::from_fn(move || loop {
let (fid, value) = it.next()?;
let Some(_) = KvReaderDelAdd::from_slice(value).get(self.side) else {
continue;
};
let name = match self.fields_ids_map.name(fid).ok_or(
InternalError::FieldIdMapMissingEntry(crate::FieldIdMapMissingEntry::FieldId {
field_id: fid,
process: "getting current document",
}),
) {
Ok(name) => name,
Err(_) => return Some(()),
};
if name == RESERVED_VECTORS_FIELD_NAME || name == RESERVED_GEO_FIELD_NAME {
continue;
}
return Some(());
})
.count()
}
fn top_level_field(&self, k: &str) -> Result<Option<&'a RawValue>> {
if k == RESERVED_VECTORS_FIELD_NAME || k == RESERVED_GEO_FIELD_NAME {
return Ok(None);
}
self.get(k)
}
fn vectors_field(&self) -> Result<Option<&'a RawValue>> {
self.get(RESERVED_VECTORS_FIELD_NAME)
}
fn geo_field(&self) -> Result<Option<&'a RawValue>> {
self.get(RESERVED_GEO_FIELD_NAME)
}
}
pub struct DocumentIdentifiers<'doc> {
docid: DocumentId,
external_document_id: &'doc str,

View File

@@ -11,7 +11,7 @@ use super::vector_document::{
use crate::attribute_patterns::PatternMatch;
use crate::documents::FieldIdMapper;
use crate::update::new::document::DocumentIdentifiers;
use crate::vector::EmbeddingConfigs;
use crate::vector::RuntimeEmbedders;
use crate::{DocumentId, Index, InternalError, Result};
pub enum DocumentChange<'doc> {
@@ -70,7 +70,7 @@ impl<'doc> Insertion<'doc> {
pub fn inserted_vectors(
&self,
doc_alloc: &'doc Bump,
embedders: &'doc EmbeddingConfigs,
embedders: &'doc RuntimeEmbedders,
) -> Result<Option<VectorDocumentFromVersions<'doc>>> {
VectorDocumentFromVersions::new(self.external_document_id, &self.new, doc_alloc, embedders)
}
@@ -241,7 +241,7 @@ impl<'doc> Update<'doc> {
pub fn only_changed_vectors(
&self,
doc_alloc: &'doc Bump,
embedders: &'doc EmbeddingConfigs,
embedders: &'doc RuntimeEmbedders,
) -> Result<Option<VectorDocumentFromVersions<'doc>>> {
VectorDocumentFromVersions::new(self.external_document_id, &self.new, doc_alloc, embedders)
}
@@ -252,7 +252,7 @@ impl<'doc> Update<'doc> {
index: &'doc Index,
mapper: &'doc Mapper,
doc_alloc: &'doc Bump,
embedders: &'doc EmbeddingConfigs,
embedders: &'doc RuntimeEmbedders,
) -> Result<Option<MergedVectorDocument<'doc>>> {
if self.from_scratch {
MergedVectorDocument::without_db(

View File

@@ -7,8 +7,7 @@ use hashbrown::HashMap;
use super::DelAddRoaringBitmap;
use crate::constants::RESERVED_GEO_FIELD_NAME;
use crate::update::new::channel::{DocumentsSender, ExtractorBbqueueSender};
use crate::update::new::document::{write_to_obkv, Document};
use crate::update::new::document::{DocumentContext, DocumentIdentifiers};
use crate::update::new::document::{write_to_obkv, Document, DocumentContext, DocumentIdentifiers};
use crate::update::new::indexer::document_changes::{Extractor, IndexingContext};
use crate::update::new::indexer::settings_changes::{
settings_change_extract, DocumentsIndentifiers, SettingsChangeExtractor,
@@ -19,16 +18,16 @@ use crate::update::new::vector_document::VectorDocument;
use crate::update::new::DocumentChange;
use crate::update::settings::SettingsDelta;
use crate::vector::settings::EmbedderAction;
use crate::vector::EmbeddingConfigs;
use crate::vector::RuntimeEmbedders;
use crate::Result;
pub struct DocumentsExtractor<'a, 'b> {
document_sender: DocumentsSender<'a, 'b>,
embedders: &'a EmbeddingConfigs,
embedders: &'a RuntimeEmbedders,
}
impl<'a, 'b> DocumentsExtractor<'a, 'b> {
pub fn new(document_sender: DocumentsSender<'a, 'b>, embedders: &'a EmbeddingConfigs) -> Self {
pub fn new(document_sender: DocumentsSender<'a, 'b>, embedders: &'a RuntimeEmbedders) -> Self {
Self { document_sender, embedders }
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -13,21 +13,17 @@ use super::super::thread_local::{FullySend, ThreadLocal};
use super::super::FacetFieldIdsDelta;
use super::document_changes::{extract, DocumentChanges, IndexingContext};
use super::settings_changes::settings_change_extract;
use crate::documents::FieldIdMapper;
use crate::documents::PrimaryKey;
use crate::index::IndexEmbeddingConfig;
use crate::progress::EmbedderStats;
use crate::progress::MergingWordCache;
use crate::documents::{FieldIdMapper, PrimaryKey};
use crate::progress::{EmbedderStats, MergingWordCache};
use crate::proximity::ProximityPrecision;
use crate::update::new::extract::EmbeddingExtractor;
use crate::update::new::indexer::settings_changes::DocumentsIndentifiers;
use crate::update::new::merger::merge_and_send_rtree;
use crate::update::new::{merge_and_send_docids, merge_and_send_facet_docids, FacetDatabases};
use crate::update::settings::SettingsDelta;
use crate::vector::EmbeddingConfigs;
use crate::Index;
use crate::InternalError;
use crate::{Result, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
use crate::vector::db::{EmbedderInfo, IndexEmbeddingConfig};
use crate::vector::RuntimeEmbedders;
use crate::{Index, InternalError, Result, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
#[allow(clippy::too_many_arguments)]
pub(super) fn extract_all<'pl, 'extractor, DC, MSP>(
@@ -35,7 +31,7 @@ pub(super) fn extract_all<'pl, 'extractor, DC, MSP>(
indexing_context: IndexingContext<MSP>,
indexer_span: Span,
extractor_sender: ExtractorBbqueueSender,
embedders: &EmbeddingConfigs,
embedders: &RuntimeEmbedders,
extractor_allocs: &'extractor mut ThreadLocal<FullySend<Bump>>,
finished_extraction: &AtomicBool,
field_distribution: &mut BTreeMap<String, u64>,
@@ -275,14 +271,19 @@ where
let span = tracing::debug_span!(target: "indexing::documents::merge", "vectors");
let _entered = span.enter();
let embedder_configs = index.embedding_configs();
for config in &mut index_embeddings {
let mut infos = embedder_configs.embedder_info(&rtxn, &config.name)?.unwrap();
'data: for data in datastore.iter_mut() {
let data = &mut data.get_mut().0;
let Some(deladd) = data.remove(&config.name) else {
let Some(delta) = data.remove(&config.name) else {
continue 'data;
};
deladd.apply_to(&mut config.user_provided, modified_docids);
delta.apply_to(&mut infos.embedding_status);
}
extractor_sender.embeddings().embedding_status(&config.name, infos).unwrap();
}
}
}
@@ -332,12 +333,11 @@ pub(super) fn extract_all_settings_changes<MSP, SD>(
finished_extraction: &AtomicBool,
field_distribution: &mut BTreeMap<String, u64>,
mut index_embeddings: Vec<IndexEmbeddingConfig>,
modified_docids: &mut RoaringBitmap,
embedder_stats: &EmbedderStats,
) -> Result<Vec<IndexEmbeddingConfig>>
where
MSP: Fn() -> bool + Sync,
SD: SettingsDelta,
SD: SettingsDelta + Sync,
{
// Create the list of document ids to extract
let rtxn = indexing_context.index.read_txn()?;
@@ -368,10 +368,7 @@ where
// extract the remaining embeddings
let extractor = SettingsChangeEmbeddingExtractor::new(
settings_delta.new_embedders(),
settings_delta.old_embedders(),
settings_delta.embedder_actions(),
settings_delta.new_embedder_category_id(),
settings_delta,
embedder_stats,
embedding_sender,
field_distribution,
@@ -395,14 +392,25 @@ where
let span = tracing::debug_span!(target: "indexing::documents::merge", "vectors");
let _entered = span.enter();
let embedder_configs = indexing_context.index.embedding_configs();
for config in &mut index_embeddings {
// retrieve infos for existing embedder or create a fresh one
let mut infos =
embedder_configs.embedder_info(&rtxn, &config.name)?.unwrap_or_else(|| {
let embedder_id =
*settings_delta.new_embedder_category_id().get(&config.name).unwrap();
EmbedderInfo { embedder_id, embedding_status: Default::default() }
});
'data: for data in datastore.iter_mut() {
let data = &mut data.get_mut().0;
let Some(deladd) = data.remove(&config.name) else {
let Some(delta) = data.remove(&config.name) else {
continue 'data;
};
deladd.apply_to(&mut config.user_provided, modified_docids);
delta.apply_to(&mut infos.embedding_status);
}
extractor_sender.embeddings().embedding_status(&config.name, infos).unwrap();
}
}
}

View File

@@ -23,8 +23,8 @@ use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::progress::{EmbedderStats, Progress};
use crate::update::settings::SettingsDelta;
use crate::update::GrenadParameters;
use crate::vector::settings::{EmbedderAction, WriteBackToDocuments};
use crate::vector::{ArroyWrapper, Embedder, EmbeddingConfigs};
use crate::vector::settings::{EmbedderAction, RemoveFragments, WriteBackToDocuments};
use crate::vector::{ArroyWrapper, Embedder, RuntimeEmbedders};
use crate::{FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result, ThreadPoolNoAbort};
pub(crate) mod de;
@@ -54,7 +54,7 @@ pub fn index<'pl, 'indexer, 'index, DC, MSP>(
new_fields_ids_map: FieldsIdsMap,
new_primary_key: Option<PrimaryKey<'pl>>,
document_changes: &DC,
embedders: EmbeddingConfigs,
embedders: RuntimeEmbedders,
must_stop_processing: &'indexer MSP,
progress: &'indexer Progress,
embedder_stats: &'indexer EmbedderStats,
@@ -93,7 +93,7 @@ where
grenad_parameters: &grenad_parameters,
};
let index_embeddings = index.embedding_configs(wtxn)?;
let index_embeddings = index.embedding_configs().embedding_configs(wtxn)?;
let mut field_distribution = index.field_distribution(wtxn)?;
let mut document_ids = index.documents_ids(wtxn)?;
let mut modified_docids = roaring::RoaringBitmap::new();
@@ -133,20 +133,21 @@ where
let arroy_writers: Result<HashMap<_, _>> = embedders
.inner_as_ref()
.iter()
.map(|(embedder_name, (embedder, _, was_quantized))| {
let embedder_index = index.embedder_category_id.get(wtxn, embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry {
.map(|(embedder_name, runtime)| {
let embedder_index = index
.embedding_configs()
.embedder_id(wtxn, embedder_name)?
.ok_or(InternalError::DatabaseMissingEntry {
db_name: "embedder_category_id",
key: None,
},
)?;
})?;
let dimensions = embedder.dimensions();
let writer = ArroyWrapper::new(vector_arroy, embedder_index, *was_quantized);
let dimensions = runtime.embedder.dimensions();
let writer = ArroyWrapper::new(vector_arroy, embedder_index, runtime.is_quantized);
Ok((
embedder_index,
(embedder_name.as_str(), embedder.as_ref(), writer, dimensions),
(embedder_name.as_str(), &*runtime.embedder, writer, dimensions),
))
})
.collect();
@@ -220,7 +221,7 @@ where
MSP: Fn() -> bool + Sync,
SD: SettingsDelta + Sync,
{
delete_old_embedders(wtxn, index, settings_delta)?;
delete_old_embedders_and_fragments(wtxn, index, settings_delta)?;
let mut bbbuffers = Vec::new();
let finished_extraction = AtomicBool::new(false);
@@ -253,16 +254,14 @@ where
grenad_parameters: &grenad_parameters,
};
let index_embeddings = index.embedding_configs(wtxn)?;
let index_embeddings = index.embedding_configs().embedding_configs(wtxn)?;
let mut field_distribution = index.field_distribution(wtxn)?;
let mut modified_docids = roaring::RoaringBitmap::new();
let congestion = thread::scope(|s| -> Result<ChannelCongestion> {
let indexer_span = tracing::Span::current();
let finished_extraction = &finished_extraction;
// prevent moving the field_distribution and document_ids in the inner closure...
let field_distribution = &mut field_distribution;
let modified_docids = &mut modified_docids;
let extractor_handle =
Builder::new().name(S("indexer-extractors")).spawn_scoped(s, move || {
pool.install(move || {
@@ -275,7 +274,6 @@ where
finished_extraction,
field_distribution,
index_embeddings,
modified_docids,
&embedder_stats,
)
})
@@ -341,7 +339,7 @@ where
fn arroy_writers_from_embedder_actions<'indexer>(
index: &Index,
embedder_actions: &'indexer BTreeMap<String, EmbedderAction>,
embedders: &'indexer EmbeddingConfigs,
embedders: &'indexer RuntimeEmbedders,
index_embedder_category_ids: &'indexer std::collections::HashMap<String, u8>,
) -> Result<HashMap<u8, (&'indexer str, &'indexer Embedder, ArroyWrapper, usize)>> {
let vector_arroy = index.vector_arroy;
@@ -349,7 +347,7 @@ fn arroy_writers_from_embedder_actions<'indexer>(
embedders
.inner_as_ref()
.iter()
.filter_map(|(embedder_name, (embedder, _, _))| match embedder_actions.get(embedder_name) {
.filter_map(|(embedder_name, runtime)| match embedder_actions.get(embedder_name) {
None => None,
Some(action) if action.write_back().is_some() => None,
Some(action) => {
@@ -364,25 +362,65 @@ fn arroy_writers_from_embedder_actions<'indexer>(
};
let writer =
ArroyWrapper::new(vector_arroy, embedder_category_id, action.was_quantized);
let dimensions = embedder.dimensions();
let dimensions = runtime.embedder.dimensions();
Some(Ok((
embedder_category_id,
(embedder_name.as_str(), embedder.as_ref(), writer, dimensions),
(embedder_name.as_str(), runtime.embedder.as_ref(), writer, dimensions),
)))
}
})
.collect()
}
fn delete_old_embedders<SD>(wtxn: &mut RwTxn<'_>, index: &Index, settings_delta: &SD) -> Result<()>
fn delete_old_embedders_and_fragments<SD>(
wtxn: &mut RwTxn<'_>,
index: &Index,
settings_delta: &SD,
) -> Result<()>
where
SD: SettingsDelta,
{
for action in settings_delta.embedder_actions().values() {
if let Some(WriteBackToDocuments { embedder_id, .. }) = action.write_back() {
let reader = ArroyWrapper::new(index.vector_arroy, *embedder_id, action.was_quantized);
let dimensions = reader.dimensions(wtxn)?;
reader.clear(wtxn, dimensions)?;
let Some(WriteBackToDocuments { embedder_id, .. }) = action.write_back() else {
continue;
};
let reader = ArroyWrapper::new(index.vector_arroy, *embedder_id, action.was_quantized);
let Some(dimensions) = reader.dimensions(wtxn)? else {
continue;
};
reader.clear(wtxn, dimensions)?;
}
// remove all vectors for the specified fragments
for (embedder_name, RemoveFragments { fragment_ids }, was_quantized) in
settings_delta.embedder_actions().iter().filter_map(|(name, action)| {
action.remove_fragments().map(|fragments| (name, fragments, action.was_quantized))
})
{
let Some(infos) = index.embedding_configs().embedder_info(wtxn, embedder_name)? else {
continue;
};
let arroy = ArroyWrapper::new(index.vector_arroy, infos.embedder_id, was_quantized);
let Some(dimensions) = arroy.dimensions(wtxn)? else {
continue;
};
for fragment_id in fragment_ids {
// we must keep the user provided embeddings that ended up in this store
if infos.embedding_status.user_provided_docids().is_empty() {
// no user provided: clear store
arroy.clear_store(wtxn, *fragment_id, dimensions)?;
continue;
}
// some user provided, remove only the ids that are not user provided
let to_delete = arroy.items_in_store(wtxn, *fragment_id, |items| {
items - infos.embedding_status.user_provided_docids()
})?;
for to_delete in to_delete {
arroy.del_item_in_store(wtxn, to_delete, *fragment_id, dimensions)?;
}
}
}

View File

@@ -11,11 +11,11 @@ use super::super::channel::*;
use crate::database_stats::DatabaseStats;
use crate::documents::PrimaryKey;
use crate::fields_ids_map::metadata::FieldIdMapWithMetadata;
use crate::index::IndexEmbeddingConfig;
use crate::progress::Progress;
use crate::update::settings::InnerIndexSettings;
use crate::vector::db::IndexEmbeddingConfig;
use crate::vector::settings::EmbedderAction;
use crate::vector::{ArroyWrapper, Embedder, EmbeddingConfigs, Embeddings};
use crate::vector::{ArroyWrapper, Embedder, Embeddings, RuntimeEmbedders};
use crate::{Error, Index, InternalError, Result, UserError};
pub fn write_to_db(
@@ -64,6 +64,14 @@ pub fn write_to_db(
writer.del_items(wtxn, *dimensions, docid)?;
writer.add_items(wtxn, docid, &embeddings)?;
}
ReceiverAction::LargeVector(
large_vector @ LargeVector { docid, embedder_id, extractor_id, .. },
) => {
let (_, _, writer, dimensions) =
arroy_writers.get(&embedder_id).expect("requested a missing embedder");
let embedding = large_vector.read_embedding(*dimensions);
writer.add_item_in_store(wtxn, docid, extractor_id, embedding)?;
}
}
// Every time the is a message in the channel we search
@@ -137,7 +145,7 @@ where
)?;
}
index.put_embedding_configs(wtxn, index_embeddings)?;
index.embedding_configs().put_embedding_configs(wtxn, index_embeddings)?;
Ok(())
}
@@ -147,7 +155,7 @@ pub(super) fn update_index(
wtxn: &mut RwTxn<'_>,
new_fields_ids_map: FieldIdMapWithMetadata,
new_primary_key: Option<PrimaryKey<'_>>,
embedders: EmbeddingConfigs,
embedders: RuntimeEmbedders,
field_distribution: std::collections::BTreeMap<String, u64>,
document_ids: roaring::RoaringBitmap,
) -> Result<()> {
@@ -226,14 +234,36 @@ pub fn write_from_bbqueue(
arroy_writers.get(&embedder_id).expect("requested a missing embedder");
let mut embeddings = Embeddings::new(*dimensions);
let all_embeddings = asvs.read_all_embeddings_into_vec(frame, aligned_embedding);
if embeddings.append(all_embeddings.to_vec()).is_err() {
return Err(Error::UserError(UserError::InvalidVectorDimensions {
expected: *dimensions,
found: all_embeddings.len(),
}));
}
writer.del_items(wtxn, *dimensions, docid)?;
writer.add_items(wtxn, docid, &embeddings)?;
if !all_embeddings.is_empty() {
if embeddings.append(all_embeddings.to_vec()).is_err() {
return Err(Error::UserError(UserError::InvalidVectorDimensions {
expected: *dimensions,
found: all_embeddings.len(),
}));
}
writer.add_items(wtxn, docid, &embeddings)?;
}
}
EntryHeader::ArroySetVector(
asv @ ArroySetVector { docid, embedder_id, extractor_id, .. },
) => {
let frame = frame_with_header.frame();
let (_, _, writer, dimensions) =
arroy_writers.get(&embedder_id).expect("requested a missing embedder");
let embedding = asv.read_all_embeddings_into_vec(frame, aligned_embedding);
if embedding.is_empty() {
writer.del_item_in_store(wtxn, docid, extractor_id, *dimensions)?;
} else {
if embedding.len() != *dimensions {
return Err(Error::UserError(UserError::InvalidVectorDimensions {
expected: *dimensions,
found: embedding.len(),
}));
}
writer.add_item_in_store(wtxn, docid, extractor_id, embedding)?;
}
}
}
}

View File

@@ -12,9 +12,9 @@ use super::document::{Document, DocumentFromDb, DocumentFromVersions, Versions};
use super::indexer::de::DeserrRawValue;
use crate::constants::RESERVED_VECTORS_FIELD_NAME;
use crate::documents::FieldIdMapper;
use crate::index::IndexEmbeddingConfig;
use crate::vector::db::{EmbeddingStatus, IndexEmbeddingConfig};
use crate::vector::parsed_vectors::{RawVectors, RawVectorsError, VectorOrArrayOfVectors};
use crate::vector::{ArroyWrapper, Embedding, EmbeddingConfigs};
use crate::vector::{ArroyWrapper, Embedding, RuntimeEmbedders};
use crate::{DocumentId, Index, InternalError, Result, UserError};
#[derive(Serialize)]
@@ -109,7 +109,7 @@ impl<'t> VectorDocumentFromDb<'t> {
None => None,
};
let embedding_config = index.embedding_configs(rtxn)?;
let embedding_config = index.embedding_configs().embedding_configs(rtxn)?;
Ok(Some(Self { docid, embedding_config, index, vectors_field, rtxn, doc_alloc }))
}
@@ -118,6 +118,7 @@ impl<'t> VectorDocumentFromDb<'t> {
&self,
embedder_id: u8,
config: &IndexEmbeddingConfig,
status: &EmbeddingStatus,
) -> Result<VectorEntry<'t>> {
let reader =
ArroyWrapper::new(self.index.vector_arroy, embedder_id, config.config.quantized());
@@ -126,7 +127,7 @@ impl<'t> VectorDocumentFromDb<'t> {
Ok(VectorEntry {
has_configured_embedder: true,
embeddings: Some(Embeddings::FromDb(vectors)),
regenerate: !config.user_provided.contains(self.docid),
regenerate: status.must_regenerate(self.docid),
implicit: false,
})
}
@@ -137,9 +138,9 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> {
self.embedding_config
.iter()
.map(|config| {
let embedder_id =
self.index.embedder_category_id.get(self.rtxn, &config.name)?.unwrap();
let entry = self.entry_from_db(embedder_id, config)?;
let info =
self.index.embedding_configs().embedder_info(self.rtxn, &config.name)?.unwrap();
let entry = self.entry_from_db(info.embedder_id, config, &info.embedding_status)?;
let config_name = self.doc_alloc.alloc_str(config.name.as_str());
Ok((&*config_name, entry))
})
@@ -156,11 +157,11 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> {
}
fn vectors_for_key(&self, key: &str) -> Result<Option<VectorEntry<'t>>> {
Ok(match self.index.embedder_category_id.get(self.rtxn, key)? {
Some(embedder_id) => {
Ok(match self.index.embedding_configs().embedder_info(self.rtxn, key)? {
Some(info) => {
let config =
self.embedding_config.iter().find(|config| config.name == key).unwrap();
Some(self.entry_from_db(embedder_id, config)?)
Some(self.entry_from_db(info.embedder_id, config, &info.embedding_status)?)
}
None => match self.vectors_field.as_ref().and_then(|obkv| obkv.get(key)) {
Some(embedding_from_doc) => {
@@ -222,7 +223,7 @@ fn entry_from_raw_value(
pub struct VectorDocumentFromVersions<'doc> {
external_document_id: &'doc str,
vectors: RawMap<'doc, FxBuildHasher>,
embedders: &'doc EmbeddingConfigs,
embedders: &'doc RuntimeEmbedders,
}
impl<'doc> VectorDocumentFromVersions<'doc> {
@@ -230,7 +231,7 @@ impl<'doc> VectorDocumentFromVersions<'doc> {
external_document_id: &'doc str,
versions: &Versions<'doc>,
bump: &'doc Bump,
embedders: &'doc EmbeddingConfigs,
embedders: &'doc RuntimeEmbedders,
) -> Result<Option<Self>> {
let document = DocumentFromVersions::new(versions);
if let Some(vectors_field) = document.vectors_field()? {
@@ -283,7 +284,7 @@ impl<'doc> MergedVectorDocument<'doc> {
db_fields_ids_map: &'doc Mapper,
versions: &Versions<'doc>,
doc_alloc: &'doc Bump,
embedders: &'doc EmbeddingConfigs,
embedders: &'doc RuntimeEmbedders,
) -> Result<Option<Self>> {
let db = VectorDocumentFromDb::new(docid, index, rtxn, db_fields_ids_map, doc_alloc)?;
let new_doc =
@@ -295,7 +296,7 @@ impl<'doc> MergedVectorDocument<'doc> {
external_document_id: &'doc str,
versions: &Versions<'doc>,
doc_alloc: &'doc Bump,
embedders: &'doc EmbeddingConfigs,
embedders: &'doc RuntimeEmbedders,
) -> Result<Option<Self>> {
let Some(new_doc) =
VectorDocumentFromVersions::new(external_document_id, versions, doc_alloc, embedders)?

View File

@@ -7,7 +7,6 @@ use std::sync::Arc;
use charabia::{Normalize, Tokenizer, TokenizerBuilder};
use deserr::{DeserializeError, Deserr};
use itertools::{merge_join_by, EitherOrBoth, Itertools};
use roaring::RoaringBitmap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
@@ -23,22 +22,25 @@ use crate::error::UserError::{self, InvalidChatSettingsDocumentTemplateMaxBytes}
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::filterable_attributes_rules::match_faceted_field;
use crate::index::{
ChatConfig, IndexEmbeddingConfig, PrefixSearch, SearchParameters,
DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
ChatConfig, PrefixSearch, SearchParameters, DEFAULT_MIN_WORD_LEN_ONE_TYPO,
DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
};
use crate::order_by_map::OrderByMap;
use crate::progress::EmbedderStats;
use crate::progress::Progress;
use crate::progress::{EmbedderStats, Progress};
use crate::prompt::{default_max_bytes, default_template_text, PromptData};
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::new::indexer::reindex;
use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::db::{FragmentConfigs, IndexEmbeddingConfig};
use crate::vector::json_template::JsonTemplate;
use crate::vector::settings::{
EmbedderAction, EmbedderSource, EmbeddingSettings, NestingContext, ReindexAction,
SubEmbeddingSettings, WriteBackToDocuments,
EmbedderAction, EmbedderSource, EmbeddingSettings, EmbeddingValidationContext, NestingContext,
ReindexAction, SubEmbeddingSettings, WriteBackToDocuments,
};
use crate::vector::{
Embedder, EmbeddingConfig, RuntimeEmbedder, RuntimeEmbedders, RuntimeFragment,
};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{
ChannelCongestion, FieldId, FilterableAttributesRule, Index, LocalizedAttributesRule, Result,
};
@@ -1044,22 +1046,27 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
match std::mem::take(&mut self.embedder_settings) {
Setting::Set(configs) => self.update_embedding_configs_set(configs),
Setting::Reset => {
let embedders = self.index.embedding_configs();
// all vectors should be written back to documents
let old_configs = self.index.embedding_configs(self.wtxn)?;
let old_configs = embedders.embedding_configs(self.wtxn)?;
let remove_all: Result<BTreeMap<String, EmbedderAction>> = old_configs
.into_iter()
.map(|IndexEmbeddingConfig { name, config, user_provided }| -> Result<_> {
let embedder_id =
self.index.embedder_category_id.get(self.wtxn, &name)?.ok_or(
crate::InternalError::DatabaseMissingEntry {
db_name: crate::index::db_name::VECTOR_EMBEDDER_CATEGORY_ID,
key: None,
},
)?;
.map(|IndexEmbeddingConfig { name, config, fragments: _ }| -> Result<_> {
let embedder_info = embedders.embedder_info(self.wtxn, &name)?.ok_or(
crate::InternalError::DatabaseMissingEntry {
db_name: crate::index::db_name::VECTOR_EMBEDDER_CATEGORY_ID,
key: None,
},
)?;
Ok((
name,
EmbedderAction::with_write_back(
WriteBackToDocuments { embedder_id, user_provided },
WriteBackToDocuments {
embedder_id: embedder_info.embedder_id,
user_provided: embedder_info
.embedding_status
.into_user_provided(),
},
config.quantized(),
),
))
@@ -1069,7 +1076,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
let remove_all = remove_all?;
self.index.embedder_category_id.clear(self.wtxn)?;
self.index.delete_embedding_configs(self.wtxn)?;
embedders.delete_embedding_configs(self.wtxn)?;
Ok(remove_all)
}
Setting::NotSet => Ok(Default::default()),
@@ -1081,12 +1088,12 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
configs: BTreeMap<String, Setting<EmbeddingSettings>>,
) -> Result<BTreeMap<String, EmbedderAction>> {
use crate::vector::settings::SettingsDiff;
let old_configs = self.index.embedding_configs(self.wtxn)?;
let old_configs: BTreeMap<String, (EmbeddingSettings, RoaringBitmap)> = old_configs
let embedders = self.index.embedding_configs();
let old_configs = embedders.embedding_configs(self.wtxn)?;
let old_configs: BTreeMap<String, (EmbeddingSettings, FragmentConfigs)> = old_configs
.into_iter()
.map(|IndexEmbeddingConfig { name, config, user_provided }| {
(name, (config.into(), user_provided))
.map(|IndexEmbeddingConfig { name, config, fragments }| {
(name, (config.into(), fragments))
})
.collect();
let mut updated_configs = BTreeMap::new();
@@ -1097,71 +1104,111 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
{
match joined {
// updated config
EitherOrBoth::Both((name, (old, user_provided)), (_, new)) => {
EitherOrBoth::Both((name, (old, mut fragments)), (_, new)) => {
let was_quantized = old.binary_quantized.set().unwrap_or_default();
let settings_diff = SettingsDiff::from_settings(&name, old, new)?;
match settings_diff {
SettingsDiff::Remove => {
let info = embedders.remove_embedder(self.wtxn, &name)?.ok_or(
crate::InternalError::DatabaseMissingEntry {
db_name: crate::index::db_name::VECTOR_EMBEDDER_CATEGORY_ID,
key: None,
},
)?;
tracing::debug!(
embedder = name,
user_provided = user_provided.len(),
user_provided = info.embedding_status.user_provided_docids().len(),
"removing embedder"
);
let embedder_id =
self.index.embedder_category_id.get(self.wtxn, &name)?.ok_or(
crate::InternalError::DatabaseMissingEntry {
db_name: crate::index::db_name::VECTOR_EMBEDDER_CATEGORY_ID,
key: None,
},
)?;
// free id immediately
self.index.embedder_category_id.delete(self.wtxn, &name)?;
embedder_actions.insert(
name,
EmbedderAction::with_write_back(
WriteBackToDocuments { embedder_id, user_provided },
WriteBackToDocuments {
embedder_id: info.embedder_id,
user_provided: info.embedding_status.into_user_provided(),
},
was_quantized,
),
);
}
SettingsDiff::Reindex { action, updated_settings, quantize } => {
tracing::debug!(
embedder = name,
user_provided = user_provided.len(),
?action,
"reindex embedder"
);
embedder_actions.insert(
name.clone(),
let mut remove_fragments = None;
let updated_settings = Setting::Set(updated_settings);
if let ReindexAction::RegenerateFragments(regenerate_fragments) =
&action
{
let it = regenerate_fragments
.iter()
.filter(|(_, action)| {
matches!(
action,
crate::vector::settings::RegenerateFragment::Remove
)
})
.map(|(name, _)| name.as_str());
remove_fragments = fragments.remove_fragments(it);
let it = regenerate_fragments
.iter()
.filter(|(_, action)| {
matches!(
action,
crate::vector::settings::RegenerateFragment::Add
)
})
.map(|(name, _)| name.clone());
fragments.add_new_fragments(it)?;
} else {
// needs full reindex of fragments
fragments = FragmentConfigs::new();
fragments.add_new_fragments(
crate::vector::settings::fragments_from_settings(
&updated_settings,
),
)?;
}
tracing::debug!(embedder = name, ?action, "reindex embedder");
let embedder_action =
EmbedderAction::with_reindex(action, was_quantized)
.with_is_being_quantized(quantize),
);
let new =
validate_embedding_settings(Setting::Set(updated_settings), &name)?;
updated_configs.insert(name, (new, user_provided));
.with_is_being_quantized(quantize);
let embedder_action = if let Some(remove_fragments) = remove_fragments {
embedder_action.with_remove_fragments(remove_fragments)
} else {
embedder_action
};
embedder_actions.insert(name.clone(), embedder_action);
let new = validate_embedding_settings(
updated_settings,
&name,
EmbeddingValidationContext::FullSettings,
)?;
updated_configs.insert(name, (new, fragments));
}
SettingsDiff::UpdateWithoutReindex { updated_settings, quantize } => {
tracing::debug!(
embedder = name,
user_provided = user_provided.len(),
"update without reindex embedder"
);
let new =
validate_embedding_settings(Setting::Set(updated_settings), &name)?;
tracing::debug!(embedder = name, "update without reindex embedder");
let new = validate_embedding_settings(
Setting::Set(updated_settings),
&name,
EmbeddingValidationContext::FullSettings,
)?;
if quantize {
embedder_actions.insert(
name.clone(),
EmbedderAction::default().with_is_being_quantized(true),
);
}
updated_configs.insert(name, (new, user_provided));
updated_configs.insert(name, (new, fragments));
}
}
}
// unchanged config
EitherOrBoth::Left((name, (setting, user_provided))) => {
EitherOrBoth::Left((name, (setting, fragments))) => {
tracing::debug!(embedder = name, "unchanged embedder");
updated_configs.insert(name, (Setting::Set(setting), user_provided));
updated_configs.insert(name, (Setting::Set(setting), fragments));
}
// new config
EitherOrBoth::Right((name, mut setting)) => {
@@ -1171,52 +1218,51 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
crate::vector::settings::EmbeddingSettings::apply_default_openai_model(
&mut setting,
);
let setting = validate_embedding_settings(setting, &name)?;
let setting = validate_embedding_settings(
setting,
&name,
EmbeddingValidationContext::FullSettings,
)?;
embedder_actions.insert(
name.clone(),
EmbedderAction::with_reindex(ReindexAction::FullReindex, false),
);
updated_configs.insert(name, (setting, RoaringBitmap::new()));
let mut fragments = FragmentConfigs::new();
fragments.add_new_fragments(
crate::vector::settings::fragments_from_settings(&setting),
)?;
updated_configs.insert(name, (setting, fragments));
}
}
}
let mut free_indices: [bool; u8::MAX as usize] = [true; u8::MAX as usize];
for res in self.index.embedder_category_id.iter(self.wtxn)? {
let (_name, id) = res?;
free_indices[id as usize] = false;
}
let mut free_indices = free_indices.iter_mut().enumerate();
let mut find_free_index =
move || free_indices.find(|(_, free)| **free).map(|(index, _)| index as u8);
for (name, action) in embedder_actions.iter() {
// ignore actions that are not possible for a new embedder
if matches!(action.reindex(), Some(ReindexAction::FullReindex))
&& self.index.embedder_category_id.get(self.wtxn, name)?.is_none()
{
let id =
find_free_index().ok_or(UserError::TooManyEmbedders(updated_configs.len()))?;
tracing::debug!(embedder = name, id, "assigning free id to new embedder");
self.index.embedder_category_id.put(self.wtxn, name, &id)?;
}
}
embedders.add_new_embedders(
self.wtxn,
embedder_actions
.iter()
// ignore actions that are not possible for a new embedder, most critically deleted embedders
.filter(|(_, action)| matches!(action.reindex(), Some(ReindexAction::FullReindex)))
.map(|(name, _)| name.as_str()),
updated_configs.len(),
)?;
let updated_configs: Vec<IndexEmbeddingConfig> = updated_configs
.into_iter()
.filter_map(|(name, (config, user_provided))| match config {
.filter_map(|(name, (config, fragments))| match config {
Setting::Set(config) => {
Some(IndexEmbeddingConfig { name, config: config.into(), user_provided })
Some(IndexEmbeddingConfig { name, config: config.into(), fragments })
}
Setting::Reset => None,
Setting::NotSet => Some(IndexEmbeddingConfig {
name,
config: EmbeddingSettings::default().into(),
user_provided,
fragments: Default::default(),
}),
})
.collect();
if updated_configs.is_empty() {
self.index.delete_embedding_configs(self.wtxn)?;
embedders.delete_embedding_configs(self.wtxn)?;
} else {
self.index.put_embedding_configs(self.wtxn, updated_configs)?;
embedders.put_embedding_configs(self.wtxn, updated_configs)?;
}
Ok(embedder_actions)
}
@@ -1543,6 +1589,7 @@ pub struct InnerIndexSettingsDiff {
/// The set of only the additional searchable fields.
/// If any other searchable field has been modified, is set to None.
pub(crate) only_additional_fields: Option<HashSet<String>>,
fragment_diffs: BTreeMap<String, Vec<(Option<usize>, usize)>>,
// Cache the check to see if all the stop_words, allowed_separators, dictionary,
// exact_attributes, proximity_precision are different.
@@ -1611,13 +1658,13 @@ impl InnerIndexSettingsDiff {
// if the user-defined searchables changed, then we need to reindex prompts.
if cache_user_defined_searchables {
for (embedder_name, (config, _, _quantized)) in
new_settings.embedding_configs.inner_as_ref()
{
let was_quantized =
old_settings.embedding_configs.get(embedder_name).is_some_and(|conf| conf.2);
for (embedder_name, runtime) in new_settings.runtime_embedders.inner_as_ref() {
let was_quantized = old_settings
.runtime_embedders
.get(embedder_name)
.is_some_and(|conf| conf.is_quantized);
// skip embedders that don't use document templates
if !config.uses_document_template() {
if !runtime.embedder.uses_document_template() {
continue;
}
@@ -1630,22 +1677,86 @@ impl InnerIndexSettingsDiff {
was_quantized,
));
}
std::collections::btree_map::Entry::Occupied(entry) => {
std::collections::btree_map::Entry::Occupied(mut entry) => {
// future-proofing, make sure to destructure here so that any new field is taken into account in this case
// case in point: adding `remove_fragments` was detected.
let EmbedderAction {
was_quantized: _,
is_being_quantized: _,
write_back: _, // We are deleting this embedder, so no point in regeneration
reindex: _, // We are already fully reindexing
} = entry.get();
write_back, // We are deleting this embedder, so no point in regeneration
reindex,
remove_fragments: _,
} = entry.get_mut();
// fixup reindex to make sure we regenerate all fragments
*reindex = match reindex.take() {
Some(reindex) => Some(reindex), // We are at least regenerating prompts
None => {
if write_back.is_none() {
Some(ReindexAction::RegeneratePrompts) // quantization case
} else {
None
}
}
};
}
};
}
}
// build the fragment diffs
let mut fragment_diffs = BTreeMap::new();
for (embedder_name, embedder_action) in &embedding_config_updates {
let Some(new_embedder) = new_settings.runtime_embedders.get(embedder_name) else {
continue;
};
let regenerate_fragments =
if let Some(ReindexAction::RegenerateFragments(regenerate_fragments)) =
embedder_action.reindex()
{
either::Either::Left(
regenerate_fragments
.iter()
.filter(|(_, action)| {
!matches!(
action,
crate::vector::settings::RegenerateFragment::Remove
)
})
.map(|(name, _)| name),
)
} else {
either::Either::Right(
new_embedder.fragments().iter().map(|fragment| &fragment.name),
)
};
let old_embedder = old_settings.runtime_embedders.get(embedder_name);
let mut fragments = Vec::new();
for fragment_name in regenerate_fragments {
let Ok(new) = new_embedder
.fragments()
.binary_search_by_key(&fragment_name, |fragment| &fragment.name)
else {
continue;
};
let old = old_embedder.as_ref().and_then(|old_embedder| {
old_embedder
.fragments()
.binary_search_by_key(&fragment_name, |fragment| &fragment.name)
.ok()
});
fragments.push((old, new));
}
fragment_diffs.insert(embedder_name.clone(), fragments);
}
InnerIndexSettingsDiff {
old: old_settings,
new: new_settings,
primary_key_id,
fragment_diffs,
embedding_config_updates,
settings_update_only,
only_additional_fields,
@@ -1790,7 +1901,7 @@ pub(crate) struct InnerIndexSettings {
pub exact_attributes: HashSet<FieldId>,
pub disabled_typos_terms: DisabledTyposTerms,
pub proximity_precision: ProximityPrecision,
pub embedding_configs: EmbeddingConfigs,
pub runtime_embedders: RuntimeEmbedders,
pub embedder_category_id: HashMap<String, u8>,
pub geo_fields_ids: Option<(FieldId, FieldId)>,
pub prefix_search: PrefixSearch,
@@ -1801,7 +1912,7 @@ impl InnerIndexSettings {
pub fn from_index(
index: &Index,
rtxn: &heed::RoTxn<'_>,
embedding_configs: Option<EmbeddingConfigs>,
runtime_embedders: Option<RuntimeEmbedders>,
) -> Result<Self> {
let stop_words = index.stop_words(rtxn)?;
let stop_words = stop_words.map(|sw| sw.map_data(Vec::from).unwrap());
@@ -1810,13 +1921,13 @@ impl InnerIndexSettings {
let mut fields_ids_map = index.fields_ids_map(rtxn)?;
let exact_attributes = index.exact_attributes_ids(rtxn)?;
let proximity_precision = index.proximity_precision(rtxn)?.unwrap_or_default();
let embedding_configs = match embedding_configs {
let runtime_embedders = match runtime_embedders {
Some(embedding_configs) => embedding_configs,
None => embedders(index.embedding_configs(rtxn)?)?,
None => embedders(index.embedding_configs().embedding_configs(rtxn)?)?,
};
let embedder_category_id = index
.embedder_category_id
.iter(rtxn)?
.embedding_configs()
.iter_embedder_id(rtxn)?
.map(|r| r.map(|(k, v)| (k.to_string(), v)))
.collect::<heed::Result<_>>()?;
let prefix_search = index.prefix_search(rtxn)?.unwrap_or_default();
@@ -1857,7 +1968,7 @@ impl InnerIndexSettings {
sortable_fields,
exact_attributes,
proximity_precision,
embedding_configs,
runtime_embedders,
embedder_category_id,
geo_fields_ids,
prefix_search,
@@ -1900,28 +2011,49 @@ impl InnerIndexSettings {
}
}
fn embedders(embedding_configs: Vec<IndexEmbeddingConfig>) -> Result<EmbeddingConfigs> {
fn embedders(embedding_configs: Vec<IndexEmbeddingConfig>) -> Result<RuntimeEmbedders> {
let res: Result<_> = embedding_configs
.into_iter()
.map(
|IndexEmbeddingConfig {
name,
config: EmbeddingConfig { embedder_options, prompt, quantized },
..
fragments,
}| {
let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?);
let document_template = prompt.try_into().map_err(crate::Error::from)?;
let embedder = Arc::new(
let embedder =
// cache_cap: no cache needed for indexing purposes
Embedder::new(embedder_options.clone(), 0)
Arc::new(Embedder::new(embedder_options.clone(), 0)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?,
);
Ok((name, (embedder, prompt, quantized.unwrap_or_default())))
.map_err(crate::Error::from)?);
let fragments = fragments
.into_inner()
.into_iter()
.map(|fragment| {
let template = JsonTemplate::new(
embedder_options.fragment(&fragment.name).unwrap().clone(),
)
.unwrap();
RuntimeFragment { name: fragment.name, id: fragment.id, template }
})
.collect();
Ok((
name,
Arc::new(RuntimeEmbedder::new(
embedder,
document_template,
fragments,
quantized.unwrap_or_default(),
)),
))
},
)
.collect();
res.map(EmbeddingConfigs::new)
res.map(RuntimeEmbedders::new)
}
fn validate_prompt(
@@ -1958,6 +2090,7 @@ fn validate_prompt(
pub fn validate_embedding_settings(
settings: Setting<EmbeddingSettings>,
name: &str,
context: EmbeddingValidationContext,
) -> Result<Setting<EmbeddingSettings>> {
let Setting::Set(settings) = settings else { return Ok(settings) };
let EmbeddingSettings {
@@ -1970,6 +2103,8 @@ pub fn validate_embedding_settings(
document_template,
document_template_max_bytes,
url,
indexing_fragments,
search_fragments,
request,
response,
search_embedder,
@@ -1996,9 +2131,106 @@ pub fn validate_embedding_settings(
})?;
}
// used below
enum WithFragments {
Yes {
indexing_fragments: BTreeMap<String, serde_json::Value>,
search_fragments: BTreeMap<String, serde_json::Value>,
},
No,
Maybe,
}
let with_fragments = {
let has_reset = matches!(indexing_fragments, Setting::Reset)
|| matches!(search_fragments, Setting::Reset);
let indexing_fragments: BTreeMap<_, _> = indexing_fragments
.as_ref()
.set()
.iter()
.flat_map(|map| map.iter())
.filter_map(|(name, fragment)| {
Some((name.clone(), fragment.as_ref().map(|fragment| fragment.value.clone())?))
})
.collect();
let search_fragments: BTreeMap<_, _> = search_fragments
.as_ref()
.set()
.iter()
.flat_map(|map| map.iter())
.filter_map(|(name, fragment)| {
Some((name.clone(), fragment.as_ref().map(|fragment| fragment.value.clone())?))
})
.collect();
let has_fragments = !indexing_fragments.is_empty() || !search_fragments.is_empty();
if context == EmbeddingValidationContext::FullSettings {
let are_fragments_inconsistent =
indexing_fragments.is_empty() ^ search_fragments.is_empty();
if are_fragments_inconsistent {
return Err(crate::vector::error::NewEmbedderError::rest_inconsistent_fragments(
indexing_fragments.is_empty(),
indexing_fragments,
search_fragments,
))
.map_err(|error| crate::UserError::VectorEmbeddingError(error.into()).into());
}
}
if has_fragments {
if context == EmbeddingValidationContext::SettingsPartialUpdate
&& matches!(document_template, Setting::Set(_))
{
return Err(
crate::vector::error::NewEmbedderError::rest_document_template_and_fragments(
indexing_fragments.len(),
search_fragments.len(),
),
)
.map_err(|error| crate::UserError::VectorEmbeddingError(error.into()).into());
}
WithFragments::Yes { indexing_fragments, search_fragments }
} else if has_reset || context == EmbeddingValidationContext::FullSettings {
WithFragments::No
} else {
// if we are working with partial settings, the user could have changed only the `request` and not given again the fragments
WithFragments::Maybe
}
};
if let Some(request) = request.as_ref().set() {
let request = crate::vector::rest::Request::new(request.to_owned())
.map_err(|error| crate::UserError::VectorEmbeddingError(error.into()))?;
let request = match with_fragments {
WithFragments::Yes { indexing_fragments, search_fragments } => {
crate::vector::rest::RequestData::new(
request.to_owned(),
indexing_fragments,
search_fragments,
)
.map_err(|error| crate::UserError::VectorEmbeddingError(error.into()))
}
WithFragments::No => crate::vector::rest::RequestData::new(
request.to_owned(),
Default::default(),
Default::default(),
)
.map_err(|error| crate::UserError::VectorEmbeddingError(error.into())),
WithFragments::Maybe => {
let mut indexing_fragments = BTreeMap::new();
indexing_fragments.insert("test".to_string(), serde_json::json!("test"));
crate::vector::rest::RequestData::new(
request.to_owned(),
indexing_fragments,
Default::default(),
)
.or_else(|_| {
crate::vector::rest::RequestData::new(
request.to_owned(),
Default::default(),
Default::default(),
)
})
.map_err(|error| crate::UserError::VectorEmbeddingError(error.into()))
}
}?;
if let Some(response) = response.as_ref().set() {
crate::vector::rest::Response::new(response.to_owned(), &request)
.map_err(|error| crate::UserError::VectorEmbeddingError(error.into()))?;
@@ -2017,6 +2249,8 @@ pub fn validate_embedding_settings(
document_template,
document_template_max_bytes,
url,
indexing_fragments,
search_fragments,
request,
response,
search_embedder,
@@ -2036,6 +2270,8 @@ pub fn validate_embedding_settings(
&dimensions,
&api_key,
&url,
&indexing_fragments,
&search_fragments,
&request,
&response,
&document_template,
@@ -2114,6 +2350,8 @@ pub fn validate_embedding_settings(
&embedder.dimensions,
&embedder.api_key,
&embedder.url,
&embedder.indexing_fragments,
&embedder.search_fragments,
&embedder.request,
&embedder.response,
&embedder.document_template,
@@ -2169,6 +2407,8 @@ pub fn validate_embedding_settings(
&embedder.dimensions,
&embedder.api_key,
&embedder.url,
&embedder.indexing_fragments,
&embedder.search_fragments,
&embedder.request,
&embedder.response,
&embedder.document_template,
@@ -2201,6 +2441,8 @@ pub fn validate_embedding_settings(
document_template,
document_template_max_bytes,
url,
indexing_fragments,
search_fragments,
request,
response,
search_embedder,
@@ -2231,20 +2473,32 @@ fn deserialize_sub_embedder(
/// Implement this trait for the settings delta type.
/// This is used in the new settings update flow and will allow to easily replace the old settings delta type: `InnerIndexSettingsDiff`.
pub trait SettingsDelta {
fn new_embedders(&self) -> &EmbeddingConfigs;
fn old_embedders(&self) -> &EmbeddingConfigs;
fn new_embedders(&self) -> &RuntimeEmbedders;
fn old_embedders(&self) -> &RuntimeEmbedders;
fn new_embedder_category_id(&self) -> &HashMap<String, u8>;
fn embedder_actions(&self) -> &BTreeMap<String, EmbedderAction>;
fn try_for_each_fragment_diff<F, E>(
&self,
embedder_name: &str,
for_each: F,
) -> std::result::Result<(), E>
where
F: FnMut(FragmentDiff) -> std::result::Result<(), E>;
fn new_fields_ids_map(&self) -> &FieldIdMapWithMetadata;
}
pub struct FragmentDiff<'a> {
pub old: Option<&'a RuntimeFragment>,
pub new: &'a RuntimeFragment,
}
impl SettingsDelta for InnerIndexSettingsDiff {
fn new_embedders(&self) -> &EmbeddingConfigs {
&self.new.embedding_configs
fn new_embedders(&self) -> &RuntimeEmbedders {
&self.new.runtime_embedders
}
fn old_embedders(&self) -> &EmbeddingConfigs {
&self.old.embedding_configs
fn old_embedders(&self) -> &RuntimeEmbedders {
&self.old.runtime_embedders
}
fn new_embedder_category_id(&self) -> &HashMap<String, u8> {
@@ -2258,6 +2512,37 @@ impl SettingsDelta for InnerIndexSettingsDiff {
fn new_fields_ids_map(&self) -> &FieldIdMapWithMetadata {
&self.new.fields_ids_map
}
fn try_for_each_fragment_diff<F, E>(
&self,
embedder_name: &str,
mut for_each: F,
) -> std::result::Result<(), E>
where
F: FnMut(FragmentDiff) -> std::result::Result<(), E>,
{
let Some(fragment_diff) = self.fragment_diffs.get(embedder_name) else { return Ok(()) };
for (old, new) in fragment_diff {
let Some(new_runtime) = self.new.runtime_embedders.get(embedder_name) else {
continue;
};
let new = new_runtime.fragments().get(*new).unwrap();
match old {
Some(old) => {
if let Some(old_runtime) = self.old.runtime_embedders.get(embedder_name) {
let old = &old_runtime.fragments().get(*old).unwrap();
for_each(FragmentDiff { old: Some(old), new })?;
} else {
for_each(FragmentDiff { old: None, new })?;
}
}
None => for_each(FragmentDiff { old: None, new })?,
};
}
Ok(())
}
}
#[cfg(test)]

View File

@@ -0,0 +1,443 @@
//! Module containing types and methods to store meta-information about the embedders and fragments
use std::borrow::Cow;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use heed::types::{SerdeJson, Str, U8};
use heed::{BytesEncode, Database, RoTxn, RwTxn, Unspecified};
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use crate::vector::settings::RemoveFragments;
use crate::vector::EmbeddingConfig;
use crate::{CboRoaringBitmapCodec, DocumentId, UserError};
#[derive(Debug, Deserialize, Serialize)]
pub struct IndexEmbeddingConfig {
pub name: String,
pub config: EmbeddingConfig,
#[serde(default)]
pub fragments: FragmentConfigs,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct FragmentConfigs(Vec<FragmentConfig>);
impl FragmentConfigs {
pub fn new() -> Self {
Default::default()
}
pub fn as_slice(&self) -> &[FragmentConfig] {
self.0.as_slice()
}
pub fn into_inner(self) -> Vec<FragmentConfig> {
self.0
}
pub fn remove_fragments<'a>(
&mut self,
fragments: impl IntoIterator<Item = &'a str>,
) -> Option<RemoveFragments> {
let mut remove_fragments = Vec::new();
for fragment in fragments {
let Ok(index_to_remove) = self.0.binary_search_by_key(&fragment, |f| &f.name) else {
continue;
};
let fragment = self.0.swap_remove(index_to_remove);
remove_fragments.push(fragment.id);
}
(!remove_fragments.is_empty()).then_some(RemoveFragments { fragment_ids: remove_fragments })
}
pub fn add_new_fragments(
&mut self,
new_fragments: impl IntoIterator<Item = String>,
) -> crate::Result<()> {
let mut free_indices: [bool; u8::MAX as usize] = [true; u8::MAX as usize];
for FragmentConfig { id, name: _ } in self.0.iter() {
free_indices[*id as usize] = false;
}
let mut free_indices = free_indices.iter_mut().enumerate();
let mut find_free_index =
move || free_indices.find(|(_, free)| **free).map(|(index, _)| index as u8);
let mut new_fragments = new_fragments.into_iter();
for name in &mut new_fragments {
let id = match find_free_index() {
Some(id) => id,
None => {
let more = (&mut new_fragments).count();
return Err(UserError::TooManyFragments(u8::MAX as usize + more + 1).into());
}
};
self.0.push(FragmentConfig { id, name });
}
Ok(())
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FragmentConfig {
pub id: u8,
pub name: String,
}
pub struct IndexEmbeddingConfigs {
main: Database<Unspecified, Unspecified>,
embedder_info: Database<Str, EmbedderInfoCodec>,
}
pub struct EmbedderInfo {
pub embedder_id: u8,
pub embedding_status: EmbeddingStatus,
}
impl EmbedderInfo {
pub fn to_bytes(&self) -> Result<Cow<'_, [u8]>, heed::BoxedError> {
EmbedderInfoCodec::bytes_encode(self)
}
}
/// Optimized struct to hold the list of documents that are `user_provided` and `must_regenerate`.
///
/// Because most documents have the same value for `user_provided` and `must_regenerate`, we store only
/// the `user_provided` and a list of the documents for which `must_regenerate` assumes the other value
/// than `user_provided`.
#[derive(Default)]
pub struct EmbeddingStatus {
user_provided: RoaringBitmap,
skip_regenerate_different_from_user_provided: RoaringBitmap,
}
impl EmbeddingStatus {
pub fn new() -> Self {
Default::default()
}
/// Whether the document contains user-provided vectors for that embedder.
pub fn is_user_provided(&self, docid: DocumentId) -> bool {
self.user_provided.contains(docid)
}
/// Whether vectors should be regenerated for that document and that embedder.
pub fn must_regenerate(&self, docid: DocumentId) -> bool {
let invert = self.skip_regenerate_different_from_user_provided.contains(docid);
let user_provided = self.user_provided.contains(docid);
!(user_provided ^ invert)
}
pub fn is_user_provided_must_regenerate(&self, docid: DocumentId) -> (bool, bool) {
let invert = self.skip_regenerate_different_from_user_provided.contains(docid);
let user_provided = self.user_provided.contains(docid);
(user_provided, !(user_provided ^ invert))
}
pub fn user_provided_docids(&self) -> &RoaringBitmap {
&self.user_provided
}
pub fn skip_regenerate_docids(&self) -> RoaringBitmap {
&self.user_provided ^ &self.skip_regenerate_different_from_user_provided
}
pub(crate) fn into_user_provided(self) -> RoaringBitmap {
self.user_provided
}
}
#[derive(Default)]
pub struct EmbeddingStatusDelta {
del_status: EmbeddingStatus,
add_status: EmbeddingStatus,
}
impl EmbeddingStatusDelta {
pub fn new() -> Self {
Self::default()
}
pub fn needs_change(
old_is_user_provided: bool,
old_must_regenerate: bool,
new_is_user_provided: bool,
new_must_regenerate: bool,
) -> bool {
let old_skip_regenerate_different_user_provided =
old_is_user_provided == old_must_regenerate;
let new_skip_regenerate_different_user_provided =
new_is_user_provided == new_must_regenerate;
old_is_user_provided != new_is_user_provided
|| old_skip_regenerate_different_user_provided
!= new_skip_regenerate_different_user_provided
}
pub fn needs_clear(is_user_provided: bool, must_regenerate: bool) -> bool {
Self::needs_change(is_user_provided, must_regenerate, false, true)
}
pub fn clear_docid(
&mut self,
docid: DocumentId,
is_user_provided: bool,
must_regenerate: bool,
) {
self.push_delta(docid, is_user_provided, must_regenerate, false, true);
}
pub fn push_delta(
&mut self,
docid: DocumentId,
old_is_user_provided: bool,
old_must_regenerate: bool,
new_is_user_provided: bool,
new_must_regenerate: bool,
) {
// must_regenerate == !skip_regenerate
let old_skip_regenerate_different_user_provided =
old_is_user_provided == old_must_regenerate;
let new_skip_regenerate_different_user_provided =
new_is_user_provided == new_must_regenerate;
match (old_is_user_provided, new_is_user_provided) {
(true, true) | (false, false) => { /* no change */ }
(true, false) => {
self.del_status.user_provided.insert(docid);
}
(false, true) => {
self.add_status.user_provided.insert(docid);
}
}
match (
old_skip_regenerate_different_user_provided,
new_skip_regenerate_different_user_provided,
) {
(true, true) | (false, false) => { /* no change */ }
(true, false) => {
self.del_status.skip_regenerate_different_from_user_provided.insert(docid);
}
(false, true) => {
self.add_status.skip_regenerate_different_from_user_provided.insert(docid);
}
}
}
pub fn push_new(&mut self, docid: DocumentId, is_user_provided: bool, must_regenerate: bool) {
self.push_delta(
docid,
!is_user_provided,
!must_regenerate,
is_user_provided,
must_regenerate,
);
}
pub fn apply_to(&self, status: &mut EmbeddingStatus) {
status.user_provided -= &self.del_status.user_provided;
status.user_provided |= &self.add_status.user_provided;
status.skip_regenerate_different_from_user_provided -=
&self.del_status.skip_regenerate_different_from_user_provided;
status.skip_regenerate_different_from_user_provided |=
&self.add_status.skip_regenerate_different_from_user_provided;
}
}
struct EmbedderInfoCodec;
impl<'a> heed::BytesDecode<'a> for EmbedderInfoCodec {
type DItem = EmbedderInfo;
fn bytes_decode(mut bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
let embedder_id = bytes.read_u8()?;
// Support all version that didn't store the embedding status
if bytes.is_empty() {
return Ok(EmbedderInfo { embedder_id, embedding_status: EmbeddingStatus::new() });
}
let first_bitmap_size = bytes.read_u32::<BigEndian>()?;
let first_bitmap_bytes = &bytes[..first_bitmap_size as usize];
let user_provided = CboRoaringBitmapCodec::bytes_decode(first_bitmap_bytes)?;
let skip_regenerate_different_from_user_provided =
CboRoaringBitmapCodec::bytes_decode(&bytes[first_bitmap_size as usize..])?;
Ok(EmbedderInfo {
embedder_id,
embedding_status: EmbeddingStatus {
user_provided,
skip_regenerate_different_from_user_provided,
},
})
}
}
impl<'a> heed::BytesEncode<'a> for EmbedderInfoCodec {
type EItem = EmbedderInfo;
fn bytes_encode(item: &'a Self::EItem) -> Result<Cow<'a, [u8]>, heed::BoxedError> {
let first_bitmap_size =
CboRoaringBitmapCodec::serialized_size(&item.embedding_status.user_provided);
let second_bitmap_size = CboRoaringBitmapCodec::serialized_size(
&item.embedding_status.skip_regenerate_different_from_user_provided,
);
let mut bytes = Vec::with_capacity(1 + 4 + first_bitmap_size + second_bitmap_size);
bytes.write_u8(item.embedder_id)?;
bytes.write_u32::<BigEndian>(first_bitmap_size.try_into()?)?;
CboRoaringBitmapCodec::serialize_into_writer(
&item.embedding_status.user_provided,
&mut bytes,
)?;
CboRoaringBitmapCodec::serialize_into_writer(
&item.embedding_status.skip_regenerate_different_from_user_provided,
&mut bytes,
)?;
Ok(bytes.into())
}
}
impl IndexEmbeddingConfigs {
pub(crate) fn new(
main: Database<Unspecified, Unspecified>,
embedder_info: Database<Unspecified, Unspecified>,
) -> Self {
Self { main, embedder_info: embedder_info.remap_types() }
}
pub(crate) fn put_embedding_configs(
&self,
wtxn: &mut RwTxn<'_>,
configs: Vec<IndexEmbeddingConfig>,
) -> heed::Result<()> {
self.main.remap_types::<Str, SerdeJson<Vec<IndexEmbeddingConfig>>>().put(
wtxn,
crate::index::main_key::EMBEDDING_CONFIGS,
&configs,
)
}
pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> {
self.main.remap_key_type::<Str>().delete(wtxn, crate::index::main_key::EMBEDDING_CONFIGS)
}
pub fn embedding_configs(&self, rtxn: &RoTxn<'_>) -> heed::Result<Vec<IndexEmbeddingConfig>> {
Ok(self
.main
.remap_types::<Str, SerdeJson<Vec<IndexEmbeddingConfig>>>()
.get(rtxn, crate::index::main_key::EMBEDDING_CONFIGS)?
.unwrap_or_default())
}
pub fn embedder_id(&self, rtxn: &RoTxn<'_>, name: &str) -> heed::Result<Option<u8>> {
self.embedder_info.remap_data_type::<U8>().get(rtxn, name)
}
pub fn put_fresh_embedder_id(
&self,
wtxn: &mut RwTxn<'_>,
name: &str,
embedder_id: u8,
) -> heed::Result<()> {
let info = EmbedderInfo { embedder_id, embedding_status: EmbeddingStatus::new() };
self.put_embedder_info(wtxn, name, &info)
}
/// Iterate through the passed list of embedder names, associating a fresh embedder id to any new names.
///
/// Passing the name of a currently existing embedder is not an error, and will not modify its embedder id,
/// so it is not necessary to differentiate between new and existing embedders before calling this function.
pub fn add_new_embedders<'a>(
&self,
wtxn: &mut RwTxn<'_>,
embedder_names: impl IntoIterator<Item = &'a str>,
total_embedder_count: usize,
) -> crate::Result<()> {
let mut free_indices: [bool; u8::MAX as usize] = [true; u8::MAX as usize];
for res in self.embedder_info.iter(wtxn)? {
let (_name, EmbedderInfo { embedder_id, embedding_status: _ }) = res?;
free_indices[embedder_id as usize] = false;
}
let mut free_indices = free_indices.iter_mut().enumerate();
let mut find_free_index =
move || free_indices.find(|(_, free)| **free).map(|(index, _)| index as u8);
for embedder_name in embedder_names {
if self.embedder_id(wtxn, embedder_name)?.is_some() {
continue;
}
let embedder_id = find_free_index()
.ok_or(crate::UserError::TooManyEmbedders(total_embedder_count))?;
tracing::debug!(
embedder = embedder_name,
embedder_id,
"assigning free id to new embedder"
);
self.put_fresh_embedder_id(wtxn, embedder_name, embedder_id)?;
}
Ok(())
}
pub fn embedder_info(
&self,
rtxn: &RoTxn<'_>,
name: &str,
) -> heed::Result<Option<EmbedderInfo>> {
self.embedder_info.get(rtxn, name)
}
/// Clear the list of docids that are `user_provided` or `must_regenerate` across all embedders.
pub fn clear_embedder_info_docids(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<()> {
let mut it = self.embedder_info.iter_mut(wtxn)?;
while let Some(res) = it.next() {
let (embedder_name, info) = res?;
let embedder_name = embedder_name.to_owned();
// SAFETY: we copied the `embedder_name` so are not using the reference while using put
unsafe {
it.put_current(
&embedder_name,
&EmbedderInfo {
embedder_id: info.embedder_id,
embedding_status: EmbeddingStatus::new(),
},
)?;
}
}
Ok(())
}
pub fn iter_embedder_info<'a>(
&self,
rtxn: &'a RoTxn<'_>,
) -> heed::Result<impl Iterator<Item = heed::Result<(&'a str, EmbedderInfo)>>> {
self.embedder_info.iter(rtxn)
}
pub fn iter_embedder_id<'a>(
&self,
rtxn: &'a RoTxn<'_>,
) -> heed::Result<impl Iterator<Item = heed::Result<(&'a str, u8)>>> {
self.embedder_info.remap_data_type::<U8>().iter(rtxn)
}
pub fn remove_embedder(
&self,
wtxn: &mut RwTxn<'_>,
name: &str,
) -> heed::Result<Option<EmbedderInfo>> {
let info = self.embedder_info.get(wtxn, name)?;
self.embedder_info.delete(wtxn, name)?;
Ok(info)
}
pub fn put_embedder_info(
&self,
wtxn: &mut RwTxn<'_>,
name: &str,
info: &EmbedderInfo,
) -> heed::Result<()> {
self.embedder_info.put(wtxn, name, info)
}
}

View File

@@ -3,6 +3,7 @@ use std::path::PathBuf;
use bumpalo::Bump;
use hf_hub::api::sync::ApiError;
use itertools::Itertools as _;
use super::parsed_vectors::ParsedVectorsDiff;
use super::rest::ConfigurationSource;
@@ -101,6 +102,32 @@ pub enum EmbedErrorKind {
MissingEmbedding,
#[error(transparent)]
PanicInThreadPool(#[from] PanicCatched),
#[error("`media` requested but the configuration doesn't have source `rest`")]
RestMediaNotARest,
#[error("`media` requested, and the configuration has source `rest`, but the configuration doesn't have `searchFragments`.")]
RestMediaNotAFragment,
#[error("Query matches multiple search fragments.\n - Note: First matched fragment `{name}`.\n - Note: Second matched fragment `{second_name}`.\n - Note: {}",
{
serde_json::json!({
"q": q,
"media": media
})
})]
RestSearchMatchesMultipleFragments {
name: String,
second_name: String,
q: Option<String>,
media: Option<serde_json::Value>,
},
#[error("Query matches no search fragment.\n - Note: {}",
{
serde_json::json!({
"q": q,
"media": media
})
})]
RestSearchMatchesNoFragment { q: Option<String>, media: Option<serde_json::Value> },
}
fn option_info(info: Option<&str>, prefix: &str) -> String {
@@ -210,6 +237,44 @@ impl EmbedError {
pub(crate) fn rest_extraction_error(error: String) -> EmbedError {
Self { kind: EmbedErrorKind::RestExtractionError(error), fault: FaultSource::Runtime }
}
pub(crate) fn rest_media_not_a_rest() -> EmbedError {
Self { kind: EmbedErrorKind::RestMediaNotARest, fault: FaultSource::User }
}
pub(crate) fn rest_media_not_a_fragment() -> EmbedError {
Self { kind: EmbedErrorKind::RestMediaNotAFragment, fault: FaultSource::User }
}
pub(crate) fn rest_search_matches_multiple_fragments(
name: &str,
second_name: &str,
q: Option<&str>,
media: Option<&serde_json::Value>,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestSearchMatchesMultipleFragments {
name: name.to_string(),
second_name: second_name.to_string(),
q: q.map(String::from),
media: media.cloned(),
},
fault: FaultSource::User,
}
}
pub(crate) fn rest_search_matches_no_fragment(
q: Option<&str>,
media: Option<&serde_json::Value>,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestSearchMatchesNoFragment {
q: q.map(String::from),
media: media.cloned(),
},
fault: FaultSource::User,
}
}
}
#[derive(Debug, thiserror::Error)]
@@ -382,6 +447,49 @@ impl NewEmbedderError {
fault: FaultSource::User,
}
}
pub(crate) fn rest_cannot_infer_dimensions_for_fragment() -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::RestCannotInferDimensionsForFragment,
fault: FaultSource::User,
}
}
pub(crate) fn rest_inconsistent_fragments(
indexing_fragments_is_empty: bool,
indexing_fragments: BTreeMap<String, serde_json::Value>,
search_fragments: BTreeMap<String, serde_json::Value>,
) -> NewEmbedderError {
let message = if indexing_fragments_is_empty {
format!("`indexingFragments` is empty, but `searchFragments` declares {} fragments: {}{}\n - Hint: declare at least one fragment in `indexingFragments` or remove fragments from `searchFragments` by setting them to `null`",
search_fragments.len(),
search_fragments.keys().take(3).join(", "), if search_fragments.len() > 3 { ", ..." } else { "" }
)
} else {
format!("`searchFragments` is empty, but `indexingFragments` declares {} fragments: {}{}\n - Hint: declare at least one fragment in `searchFragments` or remove fragments from `indexingFragments` by setting them to `null`",
indexing_fragments.len(),
indexing_fragments.keys().take(3).join(", "), if indexing_fragments.len() > 3 { ", ..." } else { "" }
)
};
Self {
kind: NewEmbedderErrorKind::RestInconsistentFragments { message },
fault: FaultSource::User,
}
}
pub(crate) fn rest_document_template_and_fragments(
indexing_fragments_len: usize,
search_fragments_len: usize,
) -> Self {
Self {
kind: NewEmbedderErrorKind::RestDocumentTemplateAndFragments {
indexing_fragments_len,
search_fragments_len,
},
fault: FaultSource::User,
}
}
}
#[derive(Debug, Clone, Copy)]
@@ -499,6 +607,12 @@ pub enum NewEmbedderErrorKind {
CompositeEmbeddingCountMismatch { search_count: usize, index_count: usize },
#[error("error while generating test embeddings.\n - the embeddings produced at search time and indexing time are not similar enough.\n - angular distance {distance:.2}\n - Meilisearch requires a maximum distance of {MAX_COMPOSITE_DISTANCE}.\n - Note: check that both embedders produce similar embeddings.{hint}")]
CompositeEmbeddingValueMismatch { distance: f32, hint: CompositeEmbedderContainsHuggingFace },
#[error("cannot infer `dimensions` for an embedder using `indexingFragments`.\n - Note: Specify `dimensions` explicitly or don't use `indexingFragments`.")]
RestCannotInferDimensionsForFragment,
#[error("inconsistent fragments: {message}")]
RestInconsistentFragments { message: String },
#[error("cannot pass both fragments and a document template.\n - Note: {indexing_fragments_len} fragments declared in `indexingFragments` and {search_fragments_len} fragments declared in `search_fragments_len`.\n - Hint: remove the declared fragments or remove the `documentTemplate`")]
RestDocumentTemplateAndFragments { indexing_fragments_len: usize, search_fragments_len: usize },
}
pub struct PossibleEmbeddingMistakes {

View File

@@ -0,0 +1,244 @@
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::fmt::Debug;
use bumpalo::Bump;
use serde_json::Value;
use super::json_template::{self, JsonTemplate};
use crate::prompt::error::RenderPromptError;
use crate::prompt::Prompt;
use crate::update::new::document::Document;
use crate::vector::RuntimeFragment;
use crate::GlobalFieldsIdsMap;
/// Trait for types that extract embedder inputs from a document.
///
/// An embedder input can then be sent to an embedder by using an [`super::session::EmbedSession`].
pub trait Extractor<'doc> {
/// The embedder input that is extracted from documents by this extractor.
///
/// The inputs have to be comparable for equality so that diffing is possible.
type Input: PartialEq;
/// The error that can happen while extracting from a document.
type Error;
/// Metadata associated with a document.
type DocumentMetadata;
/// Extract the embedder input from a document and its metadata.
fn extract<'a, D: Document<'a> + Debug>(
&self,
doc: D,
meta: &Self::DocumentMetadata,
) -> Result<Option<Self::Input>, Self::Error>;
/// Unique `id` associated with this extractor.
///
/// This will serve to decide where to store the vectors in the vector store.
/// The id should be stable for a given extractor.
fn extractor_id(&self) -> u8;
/// The result of diffing the embedder inputs extracted from two versions of a document.
///
/// # Parameters
///
/// - `old`: old version of the document
/// - `new`: new version of the document
/// - `meta`: metadata associated to the document
fn diff_documents<'a, OD: Document<'a> + Debug, ND: Document<'a> + Debug>(
&self,
old: OD,
new: ND,
meta: &Self::DocumentMetadata,
) -> Result<ExtractorDiff<Self::Input>, Self::Error>
where
'doc: 'a,
{
let old_input = self.extract(old, meta);
let new_input = self.extract(new, meta);
to_diff(old_input, new_input)
}
/// The result of diffing the embedder inputs extracted from a document by two versions of this extractor.
///
/// # Parameters
///
/// - `doc`: the document from which to extract the embedder inputs
/// - `meta`: metadata associated to the document
/// - `old`: If `Some`, the old version of this extractor. If `None`, this is equivalent to calling `ExtractorDiff::Added(self.extract(_))`.
fn diff_settings<'a, D: Document<'a> + Debug>(
&self,
doc: D,
meta: &Self::DocumentMetadata,
old: Option<&Self>,
) -> Result<ExtractorDiff<Self::Input>, Self::Error> {
let old_input = if let Some(old) = old { old.extract(&doc, meta) } else { Ok(None) };
let new_input = self.extract(&doc, meta);
to_diff(old_input, new_input)
}
/// Returns an extractor wrapping `self` and set to ignore all errors arising from extracting with this extractor.
fn ignore_errors(self) -> IgnoreErrorExtractor<Self>
where
Self: Sized,
{
IgnoreErrorExtractor(self)
}
}
fn to_diff<I: PartialEq, E>(
old_input: Result<Option<I>, E>,
new_input: Result<Option<I>, E>,
) -> Result<ExtractorDiff<I>, E> {
let old_input = old_input.ok().unwrap_or(None);
let new_input = new_input?;
Ok(match (old_input, new_input) {
(Some(old), Some(new)) if old == new => ExtractorDiff::Unchanged,
(None, None) => ExtractorDiff::Unchanged,
(None, Some(input)) => ExtractorDiff::Added(input),
(Some(_), None) => ExtractorDiff::Removed,
(Some(_), Some(input)) => ExtractorDiff::Updated(input),
})
}
pub enum ExtractorDiff<Input> {
Removed,
Added(Input),
Updated(Input),
Unchanged,
}
impl<Input> ExtractorDiff<Input> {
pub fn into_input(self) -> Option<Input> {
match self {
ExtractorDiff::Removed => None,
ExtractorDiff::Added(input) => Some(input),
ExtractorDiff::Updated(input) => Some(input),
ExtractorDiff::Unchanged => None,
}
}
pub fn needs_change(&self) -> bool {
match self {
ExtractorDiff::Removed => true,
ExtractorDiff::Added(_) => true,
ExtractorDiff::Updated(_) => true,
ExtractorDiff::Unchanged => false,
}
}
pub fn into_list_of_changes(
named_diffs: impl IntoIterator<Item = (String, Self)>,
) -> BTreeMap<String, Option<Input>> {
named_diffs
.into_iter()
.filter(|(_, diff)| diff.needs_change())
.map(|(name, diff)| (name, diff.into_input()))
.collect()
}
}
pub struct DocumentTemplateExtractor<'a, 'b, 'c> {
doc_alloc: &'a Bump,
field_id_map: &'a RefCell<GlobalFieldsIdsMap<'b>>,
template: &'c Prompt,
}
impl<'a, 'b, 'c> DocumentTemplateExtractor<'a, 'b, 'c> {
pub fn new(
template: &'c Prompt,
doc_alloc: &'a Bump,
field_id_map: &'a RefCell<GlobalFieldsIdsMap<'b>>,
) -> Self {
Self { template, doc_alloc, field_id_map }
}
}
impl<'doc> Extractor<'doc> for DocumentTemplateExtractor<'doc, '_, '_> {
type DocumentMetadata = &'doc str;
type Input = &'doc str;
type Error = RenderPromptError;
fn extractor_id(&self) -> u8 {
0
}
fn extract<'a, D: Document<'a> + Debug>(
&self,
doc: D,
external_docid: &Self::DocumentMetadata,
) -> Result<Option<Self::Input>, Self::Error> {
Ok(Some(self.template.render_document(
external_docid,
doc,
self.field_id_map,
self.doc_alloc,
)?))
}
}
pub struct RequestFragmentExtractor<'a> {
fragment: &'a JsonTemplate,
extractor_id: u8,
doc_alloc: &'a Bump,
}
impl<'a> RequestFragmentExtractor<'a> {
pub fn new(fragment: &'a RuntimeFragment, doc_alloc: &'a Bump) -> Self {
Self { fragment: &fragment.template, extractor_id: fragment.id, doc_alloc }
}
}
impl<'doc> Extractor<'doc> for RequestFragmentExtractor<'doc> {
type DocumentMetadata = ();
type Input = Value;
type Error = json_template::Error;
fn extractor_id(&self) -> u8 {
self.extractor_id
}
fn extract<'a, D: Document<'a> + Debug>(
&self,
doc: D,
_meta: &Self::DocumentMetadata,
) -> Result<Option<Self::Input>, Self::Error> {
Ok(Some(self.fragment.render_document(doc, self.doc_alloc)?))
}
}
pub struct IgnoreErrorExtractor<E>(E);
impl<'doc, E> Extractor<'doc> for IgnoreErrorExtractor<E>
where
E: Extractor<'doc>,
{
type DocumentMetadata = E::DocumentMetadata;
type Input = E::Input;
type Error = Infallible;
fn extractor_id(&self) -> u8 {
self.0.extractor_id()
}
fn extract<'a, D: Document<'a> + Debug>(
&self,
doc: D,
meta: &Self::DocumentMetadata,
) -> Result<Option<Self::Input>, Self::Error> {
Ok(self.0.extract(doc, meta).ok().flatten())
}
}
#[derive(Debug)]
pub enum Infallible {}
impl From<Infallible> for crate::Error {
fn from(_: Infallible) -> Self {
unreachable!("Infallible values cannot be built")
}
}

View File

@@ -1,20 +1,17 @@
//! Module to manipulate JSON templates.
//! Module to manipulate JSON values containing placeholder strings.
//!
//! This module allows two main operations:
//! 1. Render JSON values from a template and a context value.
//! 2. Retrieve data from a template and JSON values.
#![warn(rustdoc::broken_intra_doc_links)]
#![warn(missing_docs)]
//! 1. Render JSON values from a template value containing placeholders and a value to inject.
//! 2. Extract data from a template value containing placeholders and a concrete JSON value that fits the template value.
use serde::Deserialize;
use serde_json::{Map, Value};
type ValuePath = Vec<PathComponent>;
use super::{format_value, inject_value, path_with_root, PathComponent, ValuePath};
/// Encapsulates a JSON template and allows injecting and extracting values from it.
#[derive(Debug)]
pub struct ValueTemplate {
pub struct InjectableValue {
template: Value,
value_kind: ValueKind,
}
@@ -32,34 +29,13 @@ struct ArrayPath {
value_path_in_array: ValuePath,
}
/// Component of a path to a Value
#[derive(Debug, Clone)]
pub enum PathComponent {
/// A key inside of an object
MapKey(String),
/// An index inside of an array
ArrayIndex(usize),
}
impl PartialEq for PathComponent {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::MapKey(l0), Self::MapKey(r0)) => l0 == r0,
(Self::ArrayIndex(l0), Self::ArrayIndex(r0)) => l0 == r0,
_ => false,
}
}
}
impl Eq for PathComponent {}
/// Error that occurs when no few value was provided to a template for injection.
/// Error that occurs when no value was provided to a template for injection.
#[derive(Debug)]
pub struct MissingValue;
/// Error that occurs when trying to parse a template in [`ValueTemplate::new`]
/// Error that occurs when trying to parse a template in [`InjectableValue::new`]
#[derive(Debug)]
pub enum TemplateParsingError {
pub enum InjectableParsingError {
/// A repeat string appears inside a repeated value
NestedRepeatString(ValuePath),
/// A repeat string appears outside of an array
@@ -85,42 +61,42 @@ pub enum TemplateParsingError {
},
}
impl TemplateParsingError {
impl InjectableParsingError {
/// Produce an error message from the error kind, the name of the root object, the placeholder string and the repeat string
pub fn error_message(&self, root: &str, placeholder: &str, repeat: &str) -> String {
match self {
TemplateParsingError::NestedRepeatString(path) => {
InjectableParsingError::NestedRepeatString(path) => {
format!(
r#"in {}: "{repeat}" appears nested inside of a value that is itself repeated"#,
path_with_root(root, path)
)
}
TemplateParsingError::RepeatStringNotInArray(path) => format!(
InjectableParsingError::RepeatStringNotInArray(path) => format!(
r#"in {}: "{repeat}" appears outside of an array"#,
path_with_root(root, path)
),
TemplateParsingError::BadIndexForRepeatString(path, index) => format!(
InjectableParsingError::BadIndexForRepeatString(path, index) => format!(
r#"in {}: "{repeat}" expected at position #1, but found at position #{index}"#,
path_with_root(root, path)
),
TemplateParsingError::MissingPlaceholderInRepeatedValue(path) => format!(
InjectableParsingError::MissingPlaceholderInRepeatedValue(path) => format!(
r#"in {}: Expected "{placeholder}" inside of the repeated value"#,
path_with_root(root, path)
),
TemplateParsingError::MultipleRepeatString(current, previous) => format!(
InjectableParsingError::MultipleRepeatString(current, previous) => format!(
r#"in {}: Found "{repeat}", but it was already present in {}"#,
path_with_root(root, current),
path_with_root(root, previous)
),
TemplateParsingError::MultiplePlaceholderString(current, previous) => format!(
InjectableParsingError::MultiplePlaceholderString(current, previous) => format!(
r#"in {}: Found "{placeholder}", but it was already present in {}"#,
path_with_root(root, current),
path_with_root(root, previous)
),
TemplateParsingError::MissingPlaceholderString => {
InjectableParsingError::MissingPlaceholderString => {
format!(r#"in `{root}`: "{placeholder}" not found"#)
}
TemplateParsingError::BothArrayAndSingle {
InjectableParsingError::BothArrayAndSingle {
single_path,
path_to_array,
array_to_placeholder,
@@ -140,41 +116,41 @@ impl TemplateParsingError {
fn prepend_path(self, mut prepended_path: ValuePath) -> Self {
match self {
TemplateParsingError::NestedRepeatString(mut path) => {
InjectableParsingError::NestedRepeatString(mut path) => {
prepended_path.append(&mut path);
TemplateParsingError::NestedRepeatString(prepended_path)
InjectableParsingError::NestedRepeatString(prepended_path)
}
TemplateParsingError::RepeatStringNotInArray(mut path) => {
InjectableParsingError::RepeatStringNotInArray(mut path) => {
prepended_path.append(&mut path);
TemplateParsingError::RepeatStringNotInArray(prepended_path)
InjectableParsingError::RepeatStringNotInArray(prepended_path)
}
TemplateParsingError::BadIndexForRepeatString(mut path, index) => {
InjectableParsingError::BadIndexForRepeatString(mut path, index) => {
prepended_path.append(&mut path);
TemplateParsingError::BadIndexForRepeatString(prepended_path, index)
InjectableParsingError::BadIndexForRepeatString(prepended_path, index)
}
TemplateParsingError::MissingPlaceholderInRepeatedValue(mut path) => {
InjectableParsingError::MissingPlaceholderInRepeatedValue(mut path) => {
prepended_path.append(&mut path);
TemplateParsingError::MissingPlaceholderInRepeatedValue(prepended_path)
InjectableParsingError::MissingPlaceholderInRepeatedValue(prepended_path)
}
TemplateParsingError::MultipleRepeatString(mut path, older_path) => {
InjectableParsingError::MultipleRepeatString(mut path, older_path) => {
let older_prepended_path =
prepended_path.iter().cloned().chain(older_path).collect();
prepended_path.append(&mut path);
TemplateParsingError::MultipleRepeatString(prepended_path, older_prepended_path)
InjectableParsingError::MultipleRepeatString(prepended_path, older_prepended_path)
}
TemplateParsingError::MultiplePlaceholderString(mut path, older_path) => {
InjectableParsingError::MultiplePlaceholderString(mut path, older_path) => {
let older_prepended_path =
prepended_path.iter().cloned().chain(older_path).collect();
prepended_path.append(&mut path);
TemplateParsingError::MultiplePlaceholderString(
InjectableParsingError::MultiplePlaceholderString(
prepended_path,
older_prepended_path,
)
}
TemplateParsingError::MissingPlaceholderString => {
TemplateParsingError::MissingPlaceholderString
InjectableParsingError::MissingPlaceholderString => {
InjectableParsingError::MissingPlaceholderString
}
TemplateParsingError::BothArrayAndSingle {
InjectableParsingError::BothArrayAndSingle {
single_path,
mut path_to_array,
array_to_placeholder,
@@ -184,7 +160,7 @@ impl TemplateParsingError {
prepended_path.iter().cloned().chain(single_path).collect();
prepended_path.append(&mut path_to_array);
// we don't prepend the array_to_placeholder path as it is the array path that is prepended
TemplateParsingError::BothArrayAndSingle {
InjectableParsingError::BothArrayAndSingle {
single_path: single_prepended_path,
path_to_array: prepended_path,
array_to_placeholder,
@@ -194,7 +170,7 @@ impl TemplateParsingError {
}
}
/// Error that occurs when [`ValueTemplate::extract`] fails.
/// Error that occurs when [`InjectableValue::extract`] fails.
#[derive(Debug)]
pub struct ExtractionError {
/// The cause of the failure
@@ -336,27 +312,6 @@ enum LastNamedObject<'a> {
NestedArrayInsideObject { object_name: &'a str, index: usize, nesting_level: usize },
}
/// Builds a string representation of a path, preprending the name of the root value.
pub fn path_with_root<'a>(
root: &str,
path: impl IntoIterator<Item = &'a PathComponent> + 'a,
) -> String {
use std::fmt::Write as _;
let mut res = format!("`{root}");
for component in path.into_iter() {
match component {
PathComponent::MapKey(key) => {
let _ = write!(&mut res, ".{key}");
}
PathComponent::ArrayIndex(index) => {
let _ = write!(&mut res, "[{index}]");
}
}
}
res.push('`');
res
}
/// Context where an extraction failure happened
///
/// The operation that failed
@@ -405,7 +360,7 @@ enum ArrayParsingContext<'a> {
NotNested(&'a mut Option<ArrayPath>),
}
impl ValueTemplate {
impl InjectableValue {
/// Prepare a template for injection or extraction.
///
/// # Parameters
@@ -419,12 +374,12 @@ impl ValueTemplate {
///
/// # Errors
///
/// - [`TemplateParsingError`]: refer to the documentation of this type
/// - [`InjectableParsingError`]: refer to the documentation of this type
pub fn new(
template: Value,
placeholder_string: &str,
repeat_string: &str,
) -> Result<Self, TemplateParsingError> {
) -> Result<Self, InjectableParsingError> {
let mut value_path = None;
let mut array_path = None;
let mut current_path = Vec::new();
@@ -438,11 +393,11 @@ impl ValueTemplate {
)?;
let value_kind = match (array_path, value_path) {
(None, None) => return Err(TemplateParsingError::MissingPlaceholderString),
(None, None) => return Err(InjectableParsingError::MissingPlaceholderString),
(None, Some(value_path)) => ValueKind::Single(value_path),
(Some(array_path), None) => ValueKind::Array(array_path),
(Some(array_path), Some(value_path)) => {
return Err(TemplateParsingError::BothArrayAndSingle {
return Err(InjectableParsingError::BothArrayAndSingle {
single_path: value_path,
path_to_array: array_path.path_to_array,
array_to_placeholder: array_path.value_path_in_array,
@@ -564,29 +519,29 @@ impl ValueTemplate {
value_path: &mut Option<ValuePath>,
mut array_path: &mut ArrayParsingContext,
current_path: &mut ValuePath,
) -> Result<(), TemplateParsingError> {
) -> Result<(), InjectableParsingError> {
// two modes for parsing array.
match array {
// 1. array contains a repeat string in second position
[first, second, rest @ ..] if second == repeat_string => {
let ArrayParsingContext::NotNested(array_path) = &mut array_path else {
return Err(TemplateParsingError::NestedRepeatString(current_path.clone()));
return Err(InjectableParsingError::NestedRepeatString(current_path.clone()));
};
if let Some(array_path) = array_path {
return Err(TemplateParsingError::MultipleRepeatString(
return Err(InjectableParsingError::MultipleRepeatString(
current_path.clone(),
array_path.path_to_array.clone(),
));
}
if first == repeat_string {
return Err(TemplateParsingError::BadIndexForRepeatString(
return Err(InjectableParsingError::BadIndexForRepeatString(
current_path.clone(),
0,
));
}
if let Some(position) = rest.iter().position(|value| value == repeat_string) {
let position = position + 2;
return Err(TemplateParsingError::BadIndexForRepeatString(
return Err(InjectableParsingError::BadIndexForRepeatString(
current_path.clone(),
position,
));
@@ -609,7 +564,9 @@ impl ValueTemplate {
value_path.ok_or_else(|| {
let mut repeated_value_path = current_path.clone();
repeated_value_path.push(PathComponent::ArrayIndex(0));
TemplateParsingError::MissingPlaceholderInRepeatedValue(repeated_value_path)
InjectableParsingError::MissingPlaceholderInRepeatedValue(
repeated_value_path,
)
})?
};
**array_path = Some(ArrayPath {
@@ -621,7 +578,7 @@ impl ValueTemplate {
// 2. array does not contain a repeat string
array => {
if let Some(position) = array.iter().position(|value| value == repeat_string) {
return Err(TemplateParsingError::BadIndexForRepeatString(
return Err(InjectableParsingError::BadIndexForRepeatString(
current_path.clone(),
position,
));
@@ -650,7 +607,7 @@ impl ValueTemplate {
value_path: &mut Option<ValuePath>,
array_path: &mut ArrayParsingContext,
current_path: &mut ValuePath,
) -> Result<(), TemplateParsingError> {
) -> Result<(), InjectableParsingError> {
for (key, value) in object.iter() {
current_path.push(PathComponent::MapKey(key.to_owned()));
Self::parse_value(
@@ -673,12 +630,12 @@ impl ValueTemplate {
value_path: &mut Option<ValuePath>,
array_path: &mut ArrayParsingContext,
current_path: &mut ValuePath,
) -> Result<(), TemplateParsingError> {
) -> Result<(), InjectableParsingError> {
match value {
Value::String(str) => {
if placeholder_string == str {
if let Some(value_path) = value_path {
return Err(TemplateParsingError::MultiplePlaceholderString(
return Err(InjectableParsingError::MultiplePlaceholderString(
current_path.clone(),
value_path.clone(),
));
@@ -687,7 +644,9 @@ impl ValueTemplate {
*value_path = Some(current_path.clone());
}
if repeat_string == str {
return Err(TemplateParsingError::RepeatStringNotInArray(current_path.clone()));
return Err(InjectableParsingError::RepeatStringNotInArray(
current_path.clone(),
));
}
}
Value::Null | Value::Bool(_) | Value::Number(_) => {}
@@ -712,27 +671,6 @@ impl ValueTemplate {
}
}
fn inject_value(rendered: &mut Value, injection_path: &Vec<PathComponent>, injected_value: Value) {
let mut current_value = rendered;
for injection_component in injection_path {
current_value = match injection_component {
PathComponent::MapKey(key) => current_value.get_mut(key).unwrap(),
PathComponent::ArrayIndex(index) => current_value.get_mut(index).unwrap(),
}
}
*current_value = injected_value;
}
fn format_value(value: &Value) -> String {
match value {
Value::Array(array) => format!("an array of size {}", array.len()),
Value::Object(object) => {
format!("an object with {} field(s)", object.len())
}
value => value.to_string(),
}
}
fn extract_value<T>(
extraction_path: &[PathComponent],
initial_value: &mut Value,
@@ -838,10 +776,10 @@ impl<T> ExtractionResultErrorContext<T> for Result<T, ExtractionErrorKind> {
mod test {
use serde_json::{json, Value};
use super::{PathComponent, TemplateParsingError, ValueTemplate};
use super::{InjectableParsingError, InjectableValue, PathComponent};
fn new_template(template: Value) -> Result<ValueTemplate, TemplateParsingError> {
ValueTemplate::new(template, "{{text}}", "{{..}}")
fn new_template(template: Value) -> Result<InjectableValue, InjectableParsingError> {
InjectableValue::new(template, "{{text}}", "{{..}}")
}
#[test]
@@ -853,7 +791,7 @@ mod test {
});
let error = new_template(template.clone()).unwrap_err();
assert!(matches!(error, TemplateParsingError::MissingPlaceholderString))
assert!(matches!(error, InjectableParsingError::MissingPlaceholderString))
}
#[test]
@@ -887,7 +825,7 @@ mod test {
});
match new_template(template.clone()) {
Err(TemplateParsingError::MultiplePlaceholderString(left, right)) => {
Err(InjectableParsingError::MultiplePlaceholderString(left, right)) => {
assert_eq!(
left,
vec![PathComponent::MapKey("titi".into()), PathComponent::ArrayIndex(3)]

View File

@@ -0,0 +1,282 @@
//! Exposes types to manipulate JSON values
//!
//! - [`JsonTemplate`]: renders JSON values by rendering its strings as [`Template`]s.
//! - [`InjectableValue`]: Describes a JSON value containing placeholders,
//! then allows to inject values instead of the placeholder to produce new concrete JSON values,
//! or extract sub-values at the placeholder location from concrete JSON values.
//!
//! The module also exposes foundational types to work with JSON paths:
//!
//! - [`ValuePath`] is made of [`PathComponent`]s to indicate the location of a sub-value inside of a JSON value.
//! - [`inject_value`] is a primitive that replaces the sub-value at the described location by an injected value.
#![warn(rustdoc::broken_intra_doc_links)]
#![warn(missing_docs)]
use bumpalo::Bump;
use liquid::{Parser, Template};
use serde_json::{Map, Value};
use crate::prompt::ParseableDocument;
use crate::update::new::document::Document;
mod injectable_value;
pub use injectable_value::InjectableValue;
/// Represents a JSON [`Value`] where each string is rendered as a [`Template`].
#[derive(Debug)]
pub struct JsonTemplate {
value: Value,
templates: Vec<TemplateAtPath>,
}
impl Clone for JsonTemplate {
fn clone(&self) -> Self {
Self::new(self.value.clone()).unwrap()
}
}
struct TemplateAtPath {
template: Template,
path: ValuePath,
}
impl std::fmt::Debug for TemplateAtPath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TemplateAtPath")
.field("template", &&"template")
.field("path", &self.path)
.finish()
}
}
/// Error that can occur either when parsing the templates in the value, or when trying to render them.
#[derive(Debug)]
pub struct Error {
template_error: liquid::Error,
path: ValuePath,
}
impl Error {
/// Produces an error message when the error happened at rendering time.
pub fn rendering_error(&self, root: &str) -> String {
format!(
"in `{}`, error while rendering template: {}",
path_with_root(root, self.path.iter()),
&self.template_error
)
}
/// Produces an error message when the error happened at parsing time.
pub fn parsing(&self, root: &str) -> String {
format!(
"in `{}`, error while parsing template: {}",
path_with_root(root, self.path.iter()),
&self.template_error
)
}
}
impl JsonTemplate {
/// Creates a new `JsonTemplate` by parsing all strings inside the value as templates.
///
/// # Error
///
/// - If any of the strings contains a template that cannot be parsed.
pub fn new(value: Value) -> Result<Self, Error> {
let templates = build_templates(&value)?;
Ok(Self { value, templates })
}
/// Renders this value by replacing all its strings with the rendered version of the template they represent from the given context.
///
/// # Error
///
/// - If any of the strings contains a template that cannot be rendered with the given context.
pub fn render(&self, context: &dyn liquid::ObjectView) -> Result<Value, Error> {
let mut rendered = self.value.clone();
for TemplateAtPath { template, path } in &self.templates {
let injected_value =
template.render(context).map_err(|err| error_with_path(err, path.clone()))?;
inject_value(&mut rendered, path, Value::String(injected_value));
}
Ok(rendered)
}
/// Renders this value by replacing all its strings with the rendered version of the template they represent from the contents of the given document.
///
/// # Error
///
/// - If any of the strings contains a template that cannot be rendered with the given document.
pub fn render_document<'a, 'doc, D: Document<'a> + std::fmt::Debug>(
&self,
document: D,
doc_alloc: &'doc Bump,
) -> Result<Value, Error> {
let document = ParseableDocument::new(document, doc_alloc);
let context = crate::prompt::Context::without_fields(&document);
self.render(&context)
}
/// Renders this value by replacing all its strings with the rendered version of the template they represent from the contents of the search query.
///
/// # Error
///
/// - If any of the strings contains a template that cannot be rendered from the contents of the search query
pub fn render_search(&self, q: Option<&str>, media: Option<&Value>) -> Result<Value, Error> {
let search_data = match (q, media) {
(None, None) => liquid::object!({}),
(None, Some(media)) => liquid::object!({ "media": media }),
(Some(q), None) => liquid::object!({"q": q}),
(Some(q), Some(media)) => liquid::object!({"q": q, "media": media}),
};
self.render(&search_data)
}
/// The JSON value representing the underlying template
pub fn template(&self) -> &Value {
&self.value
}
}
fn build_templates(value: &Value) -> Result<Vec<TemplateAtPath>, Error> {
let mut current_path = ValuePath::new();
let mut templates = Vec::new();
let compiler = liquid::ParserBuilder::with_stdlib().build().unwrap();
parse_value(value, &mut current_path, &mut templates, &compiler)?;
Ok(templates)
}
fn error_with_path(template_error: liquid::Error, path: ValuePath) -> Error {
Error { template_error, path }
}
fn parse_value(
value: &Value,
current_path: &mut ValuePath,
templates: &mut Vec<TemplateAtPath>,
compiler: &Parser,
) -> Result<(), Error> {
match value {
Value::String(template) => {
let template = compiler
.parse(template)
.map_err(|err| error_with_path(err, current_path.clone()))?;
templates.push(TemplateAtPath { template, path: current_path.clone() });
}
Value::Array(values) => {
parse_array(values, current_path, templates, compiler)?;
}
Value::Object(map) => {
parse_object(map, current_path, templates, compiler)?;
}
_ => {}
}
Ok(())
}
fn parse_object(
map: &Map<String, Value>,
current_path: &mut ValuePath,
templates: &mut Vec<TemplateAtPath>,
compiler: &Parser,
) -> Result<(), Error> {
for (key, value) in map {
current_path.push(PathComponent::MapKey(key.clone()));
parse_value(value, current_path, templates, compiler)?;
current_path.pop();
}
Ok(())
}
fn parse_array(
values: &[Value],
current_path: &mut ValuePath,
templates: &mut Vec<TemplateAtPath>,
compiler: &Parser,
) -> Result<(), Error> {
for (index, value) in values.iter().enumerate() {
current_path.push(PathComponent::ArrayIndex(index));
parse_value(value, current_path, templates, compiler)?;
current_path.pop();
}
Ok(())
}
/// A list of [`PathComponent`]s describing a path to a value inside a JSON value.
///
/// The empty list refers to the root value.
pub type ValuePath = Vec<PathComponent>;
/// Component of a path to a Value
#[derive(Debug, Clone)]
pub enum PathComponent {
/// A key inside of an object
MapKey(String),
/// An index inside of an array
ArrayIndex(usize),
}
impl PartialEq for PathComponent {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::MapKey(l0), Self::MapKey(r0)) => l0 == r0,
(Self::ArrayIndex(l0), Self::ArrayIndex(r0)) => l0 == r0,
_ => false,
}
}
}
impl Eq for PathComponent {}
/// Builds a string representation of a path, preprending the name of the root value.
pub fn path_with_root<'a>(
root: &str,
path: impl IntoIterator<Item = &'a PathComponent> + 'a,
) -> String {
use std::fmt::Write as _;
let mut res = format!("`{root}");
for component in path.into_iter() {
match component {
PathComponent::MapKey(key) => {
let _ = write!(&mut res, ".{key}");
}
PathComponent::ArrayIndex(index) => {
let _ = write!(&mut res, "[{index}]");
}
}
}
res.push('`');
res
}
/// Modifies `rendered` to replace the sub-value at the `injection_path` location by the `injected_value`.
///
/// # Panics
///
/// - if the provided `injection_path` cannot be traversed in `rendered`.
pub fn inject_value(
rendered: &mut Value,
injection_path: &Vec<PathComponent>,
injected_value: Value,
) {
let mut current_value = rendered;
for injection_component in injection_path {
current_value = match injection_component {
PathComponent::MapKey(key) => current_value.get_mut(key).unwrap(),
PathComponent::ArrayIndex(index) => current_value.get_mut(index).unwrap(),
}
}
*current_value = injected_value;
}
fn format_value(value: &Value) -> String {
match value {
Value::Array(array) => format!("an array of size {}", array.len()),
Value::Object(object) => {
format!("an object with {} field(s)", object.len())
}
value => value.to_string(),
}
}

View File

@@ -15,15 +15,20 @@ use utoipa::ToSchema;
use self::error::{EmbedError, NewEmbedderError};
use crate::progress::{EmbedderStats, Progress};
use crate::prompt::{Prompt, PromptData};
use crate::vector::composite::SubEmbedderOptions;
use crate::vector::json_template::JsonTemplate;
use crate::ThreadPoolNoAbort;
pub mod composite;
pub mod db;
pub mod error;
pub mod extractor;
pub mod hf;
pub mod json_template;
pub mod manual;
pub mod openai;
pub mod parsed_vectors;
pub mod session;
pub mod settings;
pub mod ollama;
@@ -60,7 +65,7 @@ impl ArroyWrapper {
rtxn: &'a RoTxn<'a>,
db: arroy::Database<D>,
) -> impl Iterator<Item = Result<arroy::Reader<'a, D>, arroy::Error>> + 'a {
arroy_db_range_for_embedder(self.embedder_index).map_while(move |index| {
arroy_store_range_for_embedder(self.embedder_index).filter_map(move |index| {
match arroy::Reader::open(rtxn, index, db) {
Ok(reader) => match reader.is_empty(rtxn) {
Ok(false) => Some(Ok(reader)),
@@ -73,12 +78,57 @@ impl ArroyWrapper {
})
}
pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> {
let first_id = arroy_db_range_for_embedder(self.embedder_index).next().unwrap();
/// The item ids that are present in the store specified by its id.
///
/// The ids are accessed via a lambda to avoid lifetime shenanigans.
pub fn items_in_store<F, O>(
&self,
rtxn: &RoTxn,
store_id: u8,
with_items: F,
) -> Result<O, arroy::Error>
where
F: FnOnce(&RoaringBitmap) -> O,
{
if self.quantized {
Ok(arroy::Reader::open(rtxn, first_id, self.quantized_db())?.dimensions())
self._items_in_store(rtxn, self.quantized_db(), store_id, with_items)
} else {
Ok(arroy::Reader::open(rtxn, first_id, self.angular_db())?.dimensions())
self._items_in_store(rtxn, self.angular_db(), store_id, with_items)
}
}
fn _items_in_store<D: arroy::Distance, F, O>(
&self,
rtxn: &RoTxn,
db: arroy::Database<D>,
store_id: u8,
with_items: F,
) -> Result<O, arroy::Error>
where
F: FnOnce(&RoaringBitmap) -> O,
{
let index = arroy_store_for_embedder(self.embedder_index, store_id);
let reader = arroy::Reader::open(rtxn, index, db);
match reader {
Ok(reader) => Ok(with_items(reader.item_ids())),
Err(arroy::Error::MissingMetadata(_)) => Ok(with_items(&RoaringBitmap::new())),
Err(err) => Err(err),
}
}
pub fn dimensions(&self, rtxn: &RoTxn) -> Result<Option<usize>, arroy::Error> {
if self.quantized {
Ok(self
.readers(rtxn, self.quantized_db())
.next()
.transpose()?
.map(|reader| reader.dimensions()))
} else {
Ok(self
.readers(rtxn, self.angular_db())
.next()
.transpose()?
.map(|reader| reader.dimensions()))
}
}
@@ -93,13 +143,13 @@ impl ArroyWrapper {
arroy_memory: Option<usize>,
cancel: &(impl Fn() -> bool + Sync + Send),
) -> Result<(), arroy::Error> {
for index in arroy_db_range_for_embedder(self.embedder_index) {
for index in arroy_store_range_for_embedder(self.embedder_index) {
if self.quantized {
let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
if writer.need_build(wtxn)? {
writer.builder(rng).build(wtxn)?
} else if writer.is_empty(wtxn)? {
break;
continue;
}
} else {
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
@@ -124,7 +174,7 @@ impl ArroyWrapper {
.cancel(cancel)
.build(wtxn)?;
} else if writer.is_empty(wtxn)? {
break;
continue;
}
}
}
@@ -143,7 +193,7 @@ impl ArroyWrapper {
) -> Result<(), arroy::Error> {
let dimension = embeddings.dimension();
for (index, vector) in
arroy_db_range_for_embedder(self.embedder_index).zip(embeddings.iter())
arroy_store_range_for_embedder(self.embedder_index).zip(embeddings.iter())
{
if self.quantized {
arroy::Writer::new(self.quantized_db(), index, dimension)
@@ -179,7 +229,7 @@ impl ArroyWrapper {
) -> Result<(), arroy::Error> {
let dimension = vector.len();
for index in arroy_db_range_for_embedder(self.embedder_index) {
for index in arroy_store_range_for_embedder(self.embedder_index) {
let writer = arroy::Writer::new(db, index, dimension);
if !writer.contains_item(wtxn, item_id)? {
writer.add_item(wtxn, item_id, vector)?;
@@ -189,6 +239,38 @@ impl ArroyWrapper {
Ok(())
}
/// Add a vector associated with a document in store specified by its id.
///
/// Any existing vector associated with the document in the store will be replaced by the new vector.
pub fn add_item_in_store(
&self,
wtxn: &mut RwTxn,
item_id: arroy::ItemId,
store_id: u8,
vector: &[f32],
) -> Result<(), arroy::Error> {
if self.quantized {
self._add_item_in_store(wtxn, self.quantized_db(), item_id, store_id, vector)
} else {
self._add_item_in_store(wtxn, self.angular_db(), item_id, store_id, vector)
}
}
fn _add_item_in_store<D: arroy::Distance>(
&self,
wtxn: &mut RwTxn,
db: arroy::Database<D>,
item_id: arroy::ItemId,
store_id: u8,
vector: &[f32],
) -> Result<(), arroy::Error> {
let dimension = vector.len();
let index = arroy_store_for_embedder(self.embedder_index, store_id);
let writer = arroy::Writer::new(db, index, dimension);
writer.add_item(wtxn, item_id, vector)
}
/// Delete all embeddings from a specific `item_id`
pub fn del_items(
&self,
@@ -196,24 +278,84 @@ impl ArroyWrapper {
dimension: usize,
item_id: arroy::ItemId,
) -> Result<(), arroy::Error> {
for index in arroy_db_range_for_embedder(self.embedder_index) {
for index in arroy_store_range_for_embedder(self.embedder_index) {
if self.quantized {
let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
if !writer.del_item(wtxn, item_id)? {
break;
}
writer.del_item(wtxn, item_id)?;
} else {
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
if !writer.del_item(wtxn, item_id)? {
break;
}
writer.del_item(wtxn, item_id)?;
}
}
Ok(())
}
/// Delete one item.
/// Removes the item specified by its id from the store specified by its id.
///
/// Returns whether the item was removed.
///
/// # Warning
///
/// - This function will silently fail to remove the item if used against an arroy database that was never built.
pub fn del_item_in_store(
&self,
wtxn: &mut RwTxn,
item_id: arroy::ItemId,
store_id: u8,
dimensions: usize,
) -> Result<bool, arroy::Error> {
if self.quantized {
self._del_item_in_store(wtxn, self.quantized_db(), item_id, store_id, dimensions)
} else {
self._del_item_in_store(wtxn, self.angular_db(), item_id, store_id, dimensions)
}
}
fn _del_item_in_store<D: arroy::Distance>(
&self,
wtxn: &mut RwTxn,
db: arroy::Database<D>,
item_id: arroy::ItemId,
store_id: u8,
dimensions: usize,
) -> Result<bool, arroy::Error> {
let index = arroy_store_for_embedder(self.embedder_index, store_id);
let writer = arroy::Writer::new(db, index, dimensions);
writer.del_item(wtxn, item_id)
}
/// Removes all items from the store specified by its id.
///
/// # Warning
///
/// - This function will silently fail to remove the items if used against an arroy database that was never built.
pub fn clear_store(
&self,
wtxn: &mut RwTxn,
store_id: u8,
dimensions: usize,
) -> Result<(), arroy::Error> {
if self.quantized {
self._clear_store(wtxn, self.quantized_db(), store_id, dimensions)
} else {
self._clear_store(wtxn, self.angular_db(), store_id, dimensions)
}
}
fn _clear_store<D: arroy::Distance>(
&self,
wtxn: &mut RwTxn,
db: arroy::Database<D>,
store_id: u8,
dimensions: usize,
) -> Result<(), arroy::Error> {
let index = arroy_store_for_embedder(self.embedder_index, store_id);
let writer = arroy::Writer::new(db, index, dimensions);
writer.clear(wtxn)
}
/// Delete one item from its value.
pub fn del_item(
&self,
wtxn: &mut RwTxn,
@@ -235,54 +377,31 @@ impl ArroyWrapper {
vector: &[f32],
) -> Result<bool, arroy::Error> {
let dimension = vector.len();
let mut deleted_index = None;
for index in arroy_db_range_for_embedder(self.embedder_index) {
for index in arroy_store_range_for_embedder(self.embedder_index) {
let writer = arroy::Writer::new(db, index, dimension);
let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
// uses invariant: vectors are packed in the first writers.
break;
continue;
};
if candidate == vector {
writer.del_item(wtxn, item_id)?;
deleted_index = Some(index);
return writer.del_item(wtxn, item_id);
}
}
// 🥲 enforce invariant: vectors are packed in the first writers.
if let Some(deleted_index) = deleted_index {
let mut last_index_with_a_vector = None;
for index in
arroy_db_range_for_embedder(self.embedder_index).skip(deleted_index as usize)
{
let writer = arroy::Writer::new(db, index, dimension);
let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
break;
};
last_index_with_a_vector = Some((index, candidate));
}
if let Some((last_index, vector)) = last_index_with_a_vector {
let writer = arroy::Writer::new(db, last_index, dimension);
writer.del_item(wtxn, item_id)?;
let writer = arroy::Writer::new(db, deleted_index, dimension);
writer.add_item(wtxn, item_id, &vector)?;
}
}
Ok(deleted_index.is_some())
Ok(false)
}
pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
for index in arroy_db_range_for_embedder(self.embedder_index) {
for index in arroy_store_range_for_embedder(self.embedder_index) {
if self.quantized {
let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
if writer.is_empty(wtxn)? {
break;
continue;
}
writer.clear(wtxn)?;
} else {
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
if writer.is_empty(wtxn)? {
break;
continue;
}
writer.clear(wtxn)?;
}
@@ -296,17 +415,17 @@ impl ArroyWrapper {
dimension: usize,
item: arroy::ItemId,
) -> Result<bool, arroy::Error> {
for index in arroy_db_range_for_embedder(self.embedder_index) {
for index in arroy_store_range_for_embedder(self.embedder_index) {
let contains = if self.quantized {
let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
if writer.is_empty(rtxn)? {
break;
continue;
}
writer.contains_item(rtxn, item)?
} else {
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
if writer.is_empty(rtxn)? {
break;
continue;
}
writer.contains_item(rtxn, item)?
};
@@ -345,13 +464,14 @@ impl ArroyWrapper {
let reader = reader?;
let mut searcher = reader.nns(limit);
if let Some(filter) = filter {
if reader.item_ids().is_disjoint(filter) {
continue;
}
searcher.candidates(filter);
}
if let Some(mut ret) = searcher.by_item(rtxn, item)? {
results.append(&mut ret);
} else {
break;
}
}
results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
@@ -386,6 +506,9 @@ impl ArroyWrapper {
let reader = reader?;
let mut searcher = reader.nns(limit);
if let Some(filter) = filter {
if reader.item_ids().is_disjoint(filter) {
continue;
}
searcher.candidates(filter);
}
@@ -404,16 +527,12 @@ impl ArroyWrapper {
for reader in self.readers(rtxn, self.quantized_db()) {
if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
vectors.push(vec);
} else {
break;
}
}
} else {
for reader in self.readers(rtxn, self.angular_db()) {
if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
vectors.push(vec);
} else {
break;
}
}
}
@@ -465,6 +584,7 @@ pub struct ArroyStats {
pub documents: RoaringBitmap,
}
/// One or multiple embeddings stored consecutively in a flat vector.
#[derive(Debug, PartialEq)]
pub struct Embeddings<F> {
data: Vec<F>,
dimension: usize,
@@ -615,15 +735,43 @@ impl EmbeddingConfig {
}
}
/// Map of embedder configurations.
///
/// Each configuration is mapped to a name.
/// Map of runtime embedder data.
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>);
pub struct RuntimeEmbedders(HashMap<String, Arc<RuntimeEmbedder>>);
impl EmbeddingConfigs {
pub struct RuntimeEmbedder {
pub embedder: Arc<Embedder>,
pub document_template: Prompt,
fragments: Vec<RuntimeFragment>,
pub is_quantized: bool,
}
impl RuntimeEmbedder {
pub fn new(
embedder: Arc<Embedder>,
document_template: Prompt,
mut fragments: Vec<RuntimeFragment>,
is_quantized: bool,
) -> Self {
fragments.sort_unstable_by(|left, right| left.name.cmp(&right.name));
Self { embedder, document_template, fragments, is_quantized }
}
/// The runtime fragments sorted by name.
pub fn fragments(&self) -> &[RuntimeFragment] {
self.fragments.as_slice()
}
}
pub struct RuntimeFragment {
pub name: String,
pub id: u8,
pub template: JsonTemplate,
}
impl RuntimeEmbedders {
/// Create the map from its internal component.s
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>) -> Self {
pub fn new(data: HashMap<String, Arc<RuntimeEmbedder>>) -> Self {
Self(data)
}
@@ -632,24 +780,31 @@ impl EmbeddingConfigs {
}
/// Get an embedder configuration and template from its name.
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
self.0.get(name).cloned()
pub fn get(&self, name: &str) -> Option<&Arc<RuntimeEmbedder>> {
self.0.get(name)
}
pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
pub fn inner_as_ref(&self) -> &HashMap<String, Arc<RuntimeEmbedder>> {
&self.0
}
pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
pub fn into_inner(self) -> HashMap<String, Arc<RuntimeEmbedder>> {
self.0
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
impl IntoIterator for EmbeddingConfigs {
type Item = (String, (Arc<Embedder>, Arc<Prompt>, bool));
impl IntoIterator for RuntimeEmbedders {
type Item = (String, Arc<RuntimeEmbedder>);
type IntoIter =
std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>, bool)>;
type IntoIter = std::collections::hash_map::IntoIter<String, Arc<RuntimeEmbedder>>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
@@ -667,6 +822,27 @@ pub enum EmbedderOptions {
Composite(composite::EmbedderOptions),
}
impl EmbedderOptions {
pub fn fragment(&self, name: &str) -> Option<&serde_json::Value> {
match &self {
EmbedderOptions::HuggingFace(_)
| EmbedderOptions::OpenAi(_)
| EmbedderOptions::Ollama(_)
| EmbedderOptions::UserProvided(_) => None,
EmbedderOptions::Rest(embedder_options) => {
embedder_options.indexing_fragments.get(name)
}
EmbedderOptions::Composite(embedder_options) => {
if let SubEmbedderOptions::Rest(embedder_options) = &embedder_options.index {
embedder_options.indexing_fragments.get(name)
} else {
None
}
}
}
}
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::HuggingFace(Default::default())
@@ -707,6 +883,17 @@ impl Embedder {
#[tracing::instrument(level = "debug", skip_all, target = "search")]
pub fn embed_search(
&self,
query: SearchQuery<'_>,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
match query {
SearchQuery::Text(text) => self.embed_search_text(text, deadline),
SearchQuery::Media { q, media } => self.embed_search_media(q, media, deadline),
}
}
pub fn embed_search_text(
&self,
text: &str,
deadline: Option<Instant>,
@@ -728,10 +915,7 @@ impl Embedder {
.pop()
.ok_or_else(EmbedError::missing_embedding),
Embedder::UserProvided(embedder) => embedder.embed_one(text),
Embedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline, None)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
Embedder::Rest(embedder) => embedder.embed_one(SearchQuery::Text(text), deadline, None),
Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline, None),
}?;
@@ -742,6 +926,18 @@ impl Embedder {
Ok(embedding)
}
pub fn embed_search_media(
&self,
q: Option<&str>,
media: Option<&serde_json::Value>,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
let Embedder::Rest(embedder) = self else {
return Err(EmbedError::rest_media_not_a_rest());
};
embedder.embed_one(SearchQuery::Media { q, media }, deadline, None)
}
/// Embed multiple chunks of texts.
///
/// Each chunk is composed of one or multiple texts.
@@ -786,6 +982,26 @@ impl Embedder {
}
}
pub fn embed_index_ref_fragments(
&self,
fragments: &[serde_json::Value],
threads: &ThreadPoolNoAbort,
embedder_stats: &EmbedderStats,
) -> std::result::Result<Vec<Embedding>, EmbedError> {
if let Embedder::Rest(embedder) = self {
embedder.embed_index_ref(fragments, threads, embedder_stats)
} else {
let Embedder::Composite(embedder) = self else {
unimplemented!("embedding fragments is only available for rest embedders")
};
let crate::vector::composite::SubEmbedder::Rest(embedder) = &embedder.index else {
unimplemented!("embedding fragments is only available for rest embedders")
};
embedder.embed_index_ref(fragments, threads, embedder_stats)
}
}
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
pub fn chunk_count_hint(&self) -> usize {
match self {
@@ -857,6 +1073,12 @@ impl Embedder {
}
}
#[derive(Clone, Copy)]
pub enum SearchQuery<'a> {
Text(&'a str),
Media { q: Option<&'a str>, media: Option<&'a serde_json::Value> },
}
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.
@@ -986,8 +1208,11 @@ pub const fn is_cuda_enabled() -> bool {
cfg!(feature = "cuda")
}
pub fn arroy_db_range_for_embedder(embedder_id: u8) -> impl Iterator<Item = u16> {
let embedder_id = (embedder_id as u16) << 8;
(0..=u8::MAX).map(move |k| embedder_id | (k as u16))
fn arroy_store_range_for_embedder(embedder_id: u8) -> impl Iterator<Item = u16> {
(0..=u8::MAX).map(move |store_id| arroy_store_for_embedder(embedder_id, store_id))
}
fn arroy_store_for_embedder(embedder_id: u8, store_id: u8) -> u16 {
let embedder_id = (embedder_id as u16) << 8;
embedder_id | (store_id as u16)
}

View File

@@ -71,6 +71,8 @@ impl EmbedderOptions {
request,
response,
headers: Default::default(),
indexing_fragments: Default::default(),
search_fragments: Default::default(),
})
}
}

View File

@@ -201,6 +201,8 @@ impl Embedder {
]
}),
headers: Default::default(),
indexing_fragments: Default::default(),
search_fragments: Default::default(),
},
cache_cap,
super::rest::ConfigurationSource::OpenAi,

View File

@@ -6,9 +6,8 @@ use serde_json::value::RawValue;
use serde_json::{from_slice, Value};
use super::Embedding;
use crate::index::IndexEmbeddingConfig;
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
use crate::{DocumentId, FieldId, InternalError, UserError};
use crate::{FieldId, InternalError, UserError};
#[derive(serde::Serialize, Debug)]
#[serde(untagged)]
@@ -151,7 +150,8 @@ impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor {
regenerate = Some(value);
}
Ok(Some("embeddings")) => {
let value: &RawValue = match map.next_value() {
let value: &RawValue = match map.next_value::<&RawValue>() {
Ok(value) if value.get() == RawValue::NULL.get() => continue,
Ok(value) => value,
Err(error) => {
return Ok(Err(RawVectorsError::DeserializeEmbeddings {
@@ -374,8 +374,7 @@ pub struct ParsedVectorsDiff {
impl ParsedVectorsDiff {
pub fn new(
docid: DocumentId,
embedders_configs: &[IndexEmbeddingConfig],
regenerate_for_embedders: impl Iterator<Item = String>,
documents_diff: &KvReader<FieldId>,
old_vectors_fid: Option<FieldId>,
new_vectors_fid: Option<FieldId>,
@@ -396,10 +395,8 @@ impl ParsedVectorsDiff {
}
}
.flatten().map_or(BTreeMap::default(), |del| del.into_iter().map(|(name, vec)| (name, VectorState::Inline(vec))).collect());
for embedding_config in embedders_configs {
if embedding_config.user_provided.contains(docid) {
old.entry(embedding_config.name.to_string()).or_insert(VectorState::Manual);
}
for name in regenerate_for_embedders {
old.entry(name).or_insert(VectorState::Generated);
}
let new = 'new: {

View File

@@ -6,11 +6,13 @@ use rand::Rng;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use rayon::slice::ParallelSlice as _;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate;
use super::json_template::{InjectableValue, JsonTemplate};
use super::{
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, SearchQuery,
REQUEST_PARALLELISM,
};
use crate::error::FaultSource;
use crate::progress::EmbedderStats;
@@ -88,19 +90,61 @@ struct EmbedderData {
bearer: Option<String>,
headers: BTreeMap<String, String>,
url: String,
request: Request,
request: RequestData,
response: Response,
configuration_source: ConfigurationSource,
}
#[derive(Debug)]
pub enum RequestData {
Single(Request),
FromFragments(RequestFromFragments),
}
impl RequestData {
pub fn new(
request: Value,
indexing_fragments: BTreeMap<String, Value>,
search_fragments: BTreeMap<String, Value>,
) -> Result<Self, NewEmbedderError> {
Ok(if indexing_fragments.is_empty() && search_fragments.is_empty() {
RequestData::Single(Request::new(request)?)
} else {
for (name, value) in indexing_fragments {
JsonTemplate::new(value).map_err(|error| {
NewEmbedderError::rest_could_not_parse_template(
error.parsing(&format!(".indexingFragments.{name}")),
)
})?;
}
RequestData::FromFragments(RequestFromFragments::new(request, search_fragments)?)
})
}
fn input_type(&self) -> InputType {
match self {
RequestData::Single(request) => request.input_type(),
RequestData::FromFragments(request_from_fragments) => {
request_from_fragments.input_type()
}
}
}
fn has_fragments(&self) -> bool {
matches!(self, RequestData::FromFragments(_))
}
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct EmbedderOptions {
pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
pub dimensions: Option<usize>,
pub url: String,
pub request: serde_json::Value,
pub response: serde_json::Value,
pub request: Value,
pub search_fragments: BTreeMap<String, Value>,
pub indexing_fragments: BTreeMap<String, Value>,
pub response: Value,
pub headers: BTreeMap<String, String>,
}
@@ -138,7 +182,12 @@ impl Embedder {
.timeout(std::time::Duration::from_secs(30))
.build();
let request = Request::new(options.request)?;
let request = RequestData::new(
options.request,
options.indexing_fragments,
options.search_fragments,
)?;
let response = Response::new(options.response, &request)?;
let data = EmbedderData {
@@ -188,7 +237,7 @@ impl Embedder {
embedder_stats: Option<&EmbedderStats>,
) -> Result<Vec<Embedding>, EmbedError>
where
S: AsRef<str> + Serialize,
S: Serialize,
{
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats)
}
@@ -231,9 +280,9 @@ impl Embedder {
}
}
pub(crate) fn embed_index_ref(
pub(crate) fn embed_index_ref<S: Serialize + Sync>(
&self,
texts: &[&str],
texts: &[S],
threads: &ThreadPoolNoAbort,
embedder_stats: &EmbedderStats,
) -> Result<Vec<Embedding>, EmbedError> {
@@ -287,9 +336,44 @@ impl Embedder {
pub(super) fn cache(&self) -> &EmbeddingCache {
&self.cache
}
pub(crate) fn embed_one(
&self,
query: SearchQuery,
deadline: Option<Instant>,
embedder_stats: Option<&EmbedderStats>,
) -> Result<Embedding, EmbedError> {
let mut embeddings = match (&self.data.request, query) {
(RequestData::Single(_), SearchQuery::Text(text)) => {
embed(&self.data, &[text], 1, Some(self.dimensions), deadline, embedder_stats)
}
(RequestData::Single(_), SearchQuery::Media { q: _, media: _ }) => {
return Err(EmbedError::rest_media_not_a_fragment())
}
(RequestData::FromFragments(request_from_fragments), SearchQuery::Text(q)) => {
let fragment = request_from_fragments.render_search_fragment(Some(q), None)?;
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
}
(
RequestData::FromFragments(request_from_fragments),
SearchQuery::Media { q, media },
) => {
let fragment = request_from_fragments.render_search_fragment(q, media)?;
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
}
}?;
// unwrap: checked by `expected_count`
Ok(embeddings.pop().unwrap())
}
}
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
if data.request.has_fragments() {
return Err(NewEmbedderError::rest_cannot_infer_dimensions_for_fragment());
}
let v = embed(data, ["test"].as_slice(), 1, None, None, None)
.map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
@@ -307,6 +391,13 @@ fn embed<S>(
where
S: Serialize,
{
if inputs.is_empty() {
if expected_count != 0 {
return Err(EmbedError::rest_response_embedding_count(expected_count, 0));
}
return Ok(Vec::new());
}
let request = data.client.post(&data.url);
let request = if let Some(bearer) = &data.bearer {
request.set("Authorization", bearer)
@@ -318,7 +409,12 @@ where
request = request.set(header.as_str(), value.as_str());
}
let body = data.request.inject_texts(inputs);
let body = match &data.request {
RequestData::Single(request) => request.inject_texts(inputs),
RequestData::FromFragments(request_from_fragments) => {
request_from_fragments.request_from_fragments(inputs).expect("inputs was empty")
}
};
for attempt in 0..10 {
if let Some(embedder_stats) = &embedder_stats {
@@ -426,7 +522,7 @@ fn response_to_embedding(
expected_count: usize,
expected_dimensions: Option<usize>,
) -> Result<Vec<Embedding>, Retry> {
let response: serde_json::Value = response
let response: Value = response
.into_json()
.map_err(EmbedError::rest_response_deserialization)
.map_err(Retry::retry_later)?;
@@ -455,21 +551,24 @@ fn response_to_embedding(
}
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
pub(super) const REQUEST_FRAGMENT_PLACEHOLDER: &str = "{{fragment}}";
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
#[derive(Debug)]
pub struct Request {
template: ValueTemplate,
template: InjectableValue,
}
impl Request {
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
pub fn new(template: Value) -> Result<Self, NewEmbedderError> {
let template = match InjectableValue::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER)
{
Ok(template) => template,
Err(error) => {
let message =
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
let message = format!("{message}\n - Note: this template is using a document template, and so expects to contain the placeholder {REQUEST_PLACEHOLDER:?} rather than {REQUEST_FRAGMENT_PLACEHOLDER:?}");
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
@@ -485,42 +584,120 @@ impl Request {
}
}
pub fn inject_texts<S: Serialize>(
&self,
texts: impl IntoIterator<Item = S>,
) -> serde_json::Value {
pub fn inject_texts<S: Serialize>(&self, texts: impl IntoIterator<Item = S>) -> Value {
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
}
}
#[derive(Debug)]
pub struct Response {
template: ValueTemplate,
pub struct RequestFromFragments {
search_fragments: BTreeMap<String, JsonTemplate>,
request: InjectableValue,
}
impl Response {
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
{
impl RequestFromFragments {
pub fn new(
request: Value,
search_fragments: impl IntoIterator<Item = (String, Value)>,
) -> Result<Self, NewEmbedderError> {
let request = match InjectableValue::new(
request,
REQUEST_FRAGMENT_PLACEHOLDER,
REPEAT_PLACEHOLDER,
) {
Ok(template) => template,
Err(error) => {
let message =
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
let message = error.error_message(
"request",
REQUEST_FRAGMENT_PLACEHOLDER,
REPEAT_PLACEHOLDER,
);
let message = format!("{message}\n - Note: this template is using fragments, and so expects to contain the placeholder {REQUEST_FRAGMENT_PLACEHOLDER:?} rathern than {REQUEST_PLACEHOLDER:?}");
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
match (template.has_array_value(), request.template.has_array_value()) {
let search_fragments: Result<_, NewEmbedderError> = search_fragments
.into_iter()
.map(|(name, value)| {
let json_template = JsonTemplate::new(value).map_err(|error| {
NewEmbedderError::rest_could_not_parse_template(
error.parsing(&format!(".searchFragments.{name}")),
)
})?;
Ok((name, json_template))
})
.collect();
Ok(Self { request, search_fragments: search_fragments? })
}
fn input_type(&self) -> InputType {
if self.request.has_array_value() {
InputType::TextArray
} else {
InputType::Text
}
}
pub fn render_search_fragment(
&self,
q: Option<&str>,
media: Option<&Value>,
) -> Result<Value, EmbedError> {
let mut it = self.search_fragments.iter().filter_map(|(name, template)| {
let render = template.render_search(q, media).ok()?;
Some((name, render))
});
let Some((name, fragment)) = it.next() else {
return Err(EmbedError::rest_search_matches_no_fragment(q, media));
};
if let Some((second_name, _)) = it.next() {
return Err(EmbedError::rest_search_matches_multiple_fragments(
name,
second_name,
q,
media,
));
}
Ok(fragment)
}
pub fn request_from_fragments<'a, S: Serialize + 'a>(
&self,
fragments: impl IntoIterator<Item = &'a S>,
) -> Option<Value> {
self.request.inject(fragments.into_iter().map(|fragment| serde_json::json!(fragment))).ok()
}
}
#[derive(Debug)]
pub struct Response {
template: InjectableValue,
}
impl Response {
pub fn new(template: Value, request: &RequestData) -> Result<Self, NewEmbedderError> {
let template =
match InjectableValue::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) {
Ok(template) => template,
Err(error) => {
let message =
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
match (template.has_array_value(), request.input_type() == InputType::TextArray) {
(true, true) | (false, false) => Ok(Self {template}),
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
}
}
pub fn extract_embeddings(
&self,
response: serde_json::Value,
) -> Result<Vec<Embedding>, EmbedError> {
pub fn extract_embeddings(&self, response: Value) -> Result<Vec<Embedding>, EmbedError> {
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
Ok(extracted_values) => extracted_values,
Err(error) => {

View File

@@ -0,0 +1,177 @@
use bumpalo::collections::Vec as BVec;
use bumpalo::Bump;
use serde_json::Value;
use super::{EmbedError, Embedder, Embedding};
use crate::progress::EmbedderStats;
use crate::{DocumentId, Result, ThreadPoolNoAbort};
type ExtractorId = u8;
#[derive(Clone, Copy)]
pub struct Metadata<'doc> {
pub docid: DocumentId,
pub external_docid: &'doc str,
pub extractor_id: ExtractorId,
}
pub struct EmbeddingResponse<'doc> {
pub metadata: Metadata<'doc>,
pub embedding: Option<Embedding>,
}
pub trait OnEmbed<'doc> {
type ErrorMetadata;
fn process_embedding_response(&mut self, response: EmbeddingResponse<'doc>);
fn process_embedding_error(
&mut self,
error: EmbedError,
embedder_name: &'doc str,
unused_vectors_distribution: &Self::ErrorMetadata,
metadata: BVec<'doc, Metadata<'doc>>,
) -> crate::Error;
}
pub struct EmbedSession<'doc, C, I> {
// requests
inputs: BVec<'doc, I>,
metadata: BVec<'doc, Metadata<'doc>>,
threads: &'doc ThreadPoolNoAbort,
embedder: &'doc Embedder,
embedder_name: &'doc str,
embedder_stats: &'doc EmbedderStats,
on_embed: C,
}
pub trait Input: Sized {
fn embed_ref(
inputs: &[Self],
embedder: &Embedder,
threads: &ThreadPoolNoAbort,
embedder_stats: &EmbedderStats,
) -> std::result::Result<Vec<Embedding>, EmbedError>;
}
impl Input for &'_ str {
fn embed_ref(
inputs: &[Self],
embedder: &Embedder,
threads: &ThreadPoolNoAbort,
embedder_stats: &EmbedderStats,
) -> std::result::Result<Vec<Embedding>, EmbedError> {
embedder.embed_index_ref(inputs, threads, embedder_stats)
}
}
impl Input for Value {
fn embed_ref(
inputs: &[Value],
embedder: &Embedder,
threads: &ThreadPoolNoAbort,
embedder_stats: &EmbedderStats,
) -> std::result::Result<Vec<Embedding>, EmbedError> {
embedder.embed_index_ref_fragments(inputs, threads, embedder_stats)
}
}
impl<'doc, C: OnEmbed<'doc>, I: Input> EmbedSession<'doc, C, I> {
#[allow(clippy::too_many_arguments)]
pub fn new(
embedder: &'doc Embedder,
embedder_name: &'doc str,
threads: &'doc ThreadPoolNoAbort,
doc_alloc: &'doc Bump,
embedder_stats: &'doc EmbedderStats,
on_embed: C,
) -> Self {
let capacity = embedder.prompt_count_in_chunk_hint() * embedder.chunk_count_hint();
let texts = BVec::with_capacity_in(capacity, doc_alloc);
let ids = BVec::with_capacity_in(capacity, doc_alloc);
Self {
inputs: texts,
metadata: ids,
embedder,
threads,
embedder_name,
embedder_stats,
on_embed,
}
}
pub fn request_embedding(
&mut self,
metadata: Metadata<'doc>,
rendered: I,
unused_vectors_distribution: &C::ErrorMetadata,
) -> Result<()> {
if self.inputs.len() < self.inputs.capacity() {
self.inputs.push(rendered);
self.metadata.push(metadata);
return Ok(());
}
self.embed_chunks(unused_vectors_distribution)
}
pub fn drain(mut self, unused_vectors_distribution: &C::ErrorMetadata) -> Result<C> {
self.embed_chunks(unused_vectors_distribution)?;
Ok(self.on_embed)
}
#[allow(clippy::too_many_arguments)]
fn embed_chunks(&mut self, unused_vectors_distribution: &C::ErrorMetadata) -> Result<()> {
if self.inputs.is_empty() {
return Ok(());
}
let res = match I::embed_ref(
self.inputs.as_slice(),
self.embedder,
self.threads,
self.embedder_stats,
) {
Ok(embeddings) => {
for (metadata, embedding) in self.metadata.iter().copied().zip(embeddings) {
self.on_embed.process_embedding_response(EmbeddingResponse {
metadata,
embedding: Some(embedding),
});
}
Ok(())
}
Err(error) => {
// reset metadata and inputs, and send metadata to the error processing.
let doc_alloc = self.metadata.bump();
let metadata = std::mem::replace(
&mut self.metadata,
BVec::with_capacity_in(self.inputs.capacity(), doc_alloc),
);
self.inputs.clear();
return Err(self.on_embed.process_embedding_error(
error,
self.embedder_name,
unused_vectors_distribution,
metadata,
));
}
};
self.inputs.clear();
self.metadata.clear();
res
}
pub(crate) fn embedder_name(&self) -> &'doc str {
self.embedder_name
}
pub(crate) fn doc_alloc(&self) -> &'doc Bump {
self.inputs.bump()
}
pub(crate) fn on_embed_mut(&mut self) -> &mut C {
&mut self.on_embed
}
}

View File

@@ -2,6 +2,8 @@ use std::collections::BTreeMap;
use std::num::NonZeroUsize;
use deserr::Deserr;
use either::Either;
use itertools::Itertools;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
@@ -229,6 +231,35 @@ pub struct EmbeddingSettings {
/// - 🏗️ When modified for sources `ollama` and `rest`, embeddings are always regenerated
pub url: Setting<String>,
/// Template fragments that will be reassembled and sent to the remote embedder at indexing time.
///
/// # Availability
///
/// - This parameter is available for sources `rest`.
///
/// # 🔄 Reindexing
///
/// - 🏗️ When a fragment is deleted by passing `null` to its name, the corresponding embeddings are removed from documents.
/// - 🏗️ When a fragment is modified, the corresponding embeddings are regenerated if their rendered version changes.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<BTreeMap<String, serde_json::Value>>)]
pub indexing_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
/// Template fragments that will be reassembled and sent to the remote embedder at search time.
///
/// # Availability
///
/// - This parameter is available for sources `rest`.
///
/// # 🔄 Reindexing
///
/// - 🌱 Changing the value of this parameter never regenerates embeddings
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<BTreeMap<String, serde_json::Value>>)]
pub search_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<serde_json::Value>)]
@@ -483,6 +514,36 @@ pub struct SubEmbeddingSettings {
/// - 🌱 When modified for source `openAi`, embeddings are never regenerated
/// - 🏗️ When modified for sources `ollama` and `rest`, embeddings are always regenerated
pub url: Setting<String>,
/// Template fragments that will be reassembled and sent to the remote embedder at indexing time.
///
/// # Availability
///
/// - This parameter is available for sources `rest`.
///
/// # 🔄 Reindexing
///
/// - 🏗️ When a fragment is deleted by passing `null` to its name, the corresponding embeddings are removed from documents.
/// - 🏗️ When a fragment is modified, the corresponding embeddings are regenerated if their rendered version changes.
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<BTreeMap<String, serde_json::Value>>)]
pub indexing_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
/// Template fragments that will be reassembled and sent to the remote embedder at search time.
///
/// # Availability
///
/// - This parameter is available for sources `rest`.
///
/// # 🔄 Reindexing
///
/// - 🌱 Changing the value of this parameter never regenerates embeddings
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<BTreeMap<String, serde_json::Value>>)]
pub search_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<serde_json::Value>)]
@@ -554,17 +615,31 @@ pub struct SubEmbeddingSettings {
pub indexing_embedder: Setting<serde_json::Value>,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum EmbeddingValidationContext {
FullSettings,
SettingsPartialUpdate,
}
/// Indicates what action should take place during a reindexing operation for an embedder
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ReindexAction {
/// An indexing operation should take place for this embedder, keeping existing vectors
/// and checking whether the document template changed or not
RegeneratePrompts,
RegenerateFragments(Vec<(String, RegenerateFragment)>),
/// An indexing operation should take place for all documents for this embedder, removing existing vectors
/// (except userProvided ones)
FullReindex,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RegenerateFragment {
Update,
Remove,
Add,
}
pub enum SettingsDiff {
Remove,
Reindex { action: ReindexAction, updated_settings: EmbeddingSettings, quantize: bool },
@@ -577,6 +652,12 @@ pub struct EmbedderAction {
pub is_being_quantized: bool,
pub write_back: Option<WriteBackToDocuments>,
pub reindex: Option<ReindexAction>,
pub remove_fragments: Option<RemoveFragments>,
}
#[derive(Debug)]
pub struct RemoveFragments {
pub fragment_ids: Vec<u8>,
}
impl EmbedderAction {
@@ -592,6 +673,10 @@ impl EmbedderAction {
self.reindex.as_ref()
}
pub fn remove_fragments(&self) -> Option<&RemoveFragments> {
self.remove_fragments.as_ref()
}
pub fn with_is_being_quantized(mut self, quantize: bool) -> Self {
self.is_being_quantized = quantize;
self
@@ -603,11 +688,23 @@ impl EmbedderAction {
is_being_quantized: false,
write_back: Some(write_back),
reindex: None,
remove_fragments: None,
}
}
pub fn with_reindex(reindex: ReindexAction, was_quantized: bool) -> Self {
Self { was_quantized, is_being_quantized: false, write_back: None, reindex: Some(reindex) }
Self {
was_quantized,
is_being_quantized: false,
write_back: None,
reindex: Some(reindex),
remove_fragments: None,
}
}
pub fn with_remove_fragments(mut self, remove_fragments: RemoveFragments) -> Self {
self.remove_fragments = Some(remove_fragments);
self
}
}
@@ -634,6 +731,8 @@ impl SettingsDiff {
mut dimensions,
mut document_template,
mut url,
mut indexing_fragments,
mut search_fragments,
mut request,
mut response,
mut search_embedder,
@@ -653,6 +752,8 @@ impl SettingsDiff {
dimensions: new_dimensions,
document_template: new_document_template,
url: new_url,
indexing_fragments: new_indexing_fragments,
search_fragments: new_search_fragments,
request: new_request,
response: new_response,
search_embedder: new_search_embedder,
@@ -684,6 +785,8 @@ impl SettingsDiff {
&mut document_template,
&mut document_template_max_bytes,
&mut url,
&mut indexing_fragments,
&mut search_fragments,
&mut request,
&mut response,
&mut headers,
@@ -696,6 +799,8 @@ impl SettingsDiff {
new_document_template,
new_document_template_max_bytes,
new_url,
new_indexing_fragments,
new_search_fragments,
new_request,
new_response,
new_headers,
@@ -722,6 +827,8 @@ impl SettingsDiff {
dimensions,
document_template,
url,
indexing_fragments,
search_fragments,
request,
response,
search_embedder,
@@ -769,6 +876,8 @@ impl SettingsDiff {
mut document_template,
mut document_template_max_bytes,
mut url,
mut indexing_fragments,
mut search_fragments,
mut request,
mut response,
mut headers,
@@ -794,6 +903,8 @@ impl SettingsDiff {
document_template: new_document_template,
document_template_max_bytes: new_document_template_max_bytes,
url: new_url,
indexing_fragments: new_indexing_fragments,
search_fragments: new_search_fragments,
request: new_request,
response: new_response,
headers: new_headers,
@@ -814,6 +925,8 @@ impl SettingsDiff {
&mut document_template,
&mut document_template_max_bytes,
&mut url,
&mut indexing_fragments,
&mut search_fragments,
&mut request,
&mut response,
&mut headers,
@@ -826,6 +939,8 @@ impl SettingsDiff {
new_document_template,
new_document_template_max_bytes,
new_url,
new_indexing_fragments,
new_search_fragments,
new_request,
new_response,
new_headers,
@@ -846,6 +961,8 @@ impl SettingsDiff {
dimensions,
document_template,
url,
indexing_fragments,
search_fragments,
request,
response,
headers,
@@ -875,6 +992,8 @@ impl SettingsDiff {
document_template: &mut Setting<String>,
document_template_max_bytes: &mut Setting<usize>,
url: &mut Setting<String>,
indexing_fragments: &mut Setting<BTreeMap<String, Option<Fragment>>>,
search_fragments: &mut Setting<BTreeMap<String, Option<Fragment>>>,
request: &mut Setting<serde_json::Value>,
response: &mut Setting<serde_json::Value>,
headers: &mut Setting<BTreeMap<String, String>>,
@@ -887,6 +1006,8 @@ impl SettingsDiff {
new_document_template: Setting<String>,
new_document_template_max_bytes: Setting<usize>,
new_url: Setting<String>,
new_indexing_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
new_search_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
new_request: Setting<serde_json::Value>,
new_response: Setting<serde_json::Value>,
new_headers: Setting<BTreeMap<String, String>>,
@@ -902,6 +1023,8 @@ impl SettingsDiff {
pooling,
dimensions,
url,
indexing_fragments,
search_fragments,
request,
response,
document_template,
@@ -941,6 +1064,105 @@ impl SettingsDiff {
}
}
}
*search_fragments = match (std::mem::take(search_fragments), new_search_fragments) {
(Setting::Set(search_fragments), Setting::Set(new_search_fragments)) => {
Setting::Set(
search_fragments
.into_iter()
.merge_join_by(new_search_fragments, |(left, _), (right, _)| {
left.cmp(right)
})
.map(|eob| {
match eob {
// merge fragments
itertools::EitherOrBoth::Both((name, _), (_, right)) => {
(name, right)
}
// unchanged fragment
itertools::EitherOrBoth::Left(left) => left,
// new fragment
itertools::EitherOrBoth::Right(right) => right,
}
})
.collect(),
)
}
(_, Setting::Reset) => Setting::Reset,
(left, Setting::NotSet) => left,
(Setting::NotSet | Setting::Reset, Setting::Set(new_search_fragments)) => {
Setting::Set(new_search_fragments)
}
};
let mut regenerate_fragments = Vec::new();
*indexing_fragments = match (std::mem::take(indexing_fragments), new_indexing_fragments) {
(Setting::Set(fragments), Setting::Set(new_fragments)) => {
Setting::Set(
fragments
.into_iter()
.merge_join_by(new_fragments, |(left, _), (right, _)| left.cmp(right))
.map(|eob| {
match eob {
// merge fragments
itertools::EitherOrBoth::Both(
(name, left),
(other_name, right),
) => {
if left == right {
(name, left)
} else {
match right {
Some(right) => {
regenerate_fragments
.push((other_name, RegenerateFragment::Update));
(name, Some(right))
}
None => {
regenerate_fragments
.push((other_name, RegenerateFragment::Remove));
(name, None)
}
}
}
}
// unchanged fragment
itertools::EitherOrBoth::Left(left) => left,
// new fragment
itertools::EitherOrBoth::Right((name, right)) => {
if right.is_some() {
regenerate_fragments
.push((name.clone(), RegenerateFragment::Add));
}
(name, right)
}
}
})
.collect(),
)
}
// remove all fragments => move to document template
(_, Setting::Reset) => {
ReindexAction::push_action(reindex_action, ReindexAction::FullReindex);
Setting::Reset
}
// add all fragments
(Setting::NotSet | Setting::Reset, Setting::Set(new_fragments)) => {
ReindexAction::push_action(reindex_action, ReindexAction::FullReindex);
Setting::Set(new_fragments)
}
// no change
(left, Setting::NotSet) => left,
};
if !regenerate_fragments.is_empty() {
regenerate_fragments.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
ReindexAction::push_action(
reindex_action,
ReindexAction::RegenerateFragments(regenerate_fragments),
);
}
if request.apply(new_request) {
ReindexAction::push_action(reindex_action, ReindexAction::FullReindex);
}
@@ -972,10 +1194,16 @@ impl SettingsDiff {
impl ReindexAction {
fn push_action(this: &mut Option<Self>, other: Self) {
*this = match (*this, other) {
(_, ReindexAction::FullReindex) => Some(ReindexAction::FullReindex),
(Some(ReindexAction::FullReindex), _) => Some(ReindexAction::FullReindex),
(_, ReindexAction::RegeneratePrompts) => Some(ReindexAction::RegeneratePrompts),
use ReindexAction::*;
*this = match (this.take(), other) {
(_, FullReindex) => Some(FullReindex),
(Some(FullReindex), _) => Some(FullReindex),
(_, RegenerateFragments(fragments)) => Some(RegenerateFragments(fragments)),
(Some(RegenerateFragments(fragments)), RegeneratePrompts) => {
Some(RegenerateFragments(fragments))
}
(Some(RegeneratePrompts), RegeneratePrompts) => Some(RegeneratePrompts),
(None, RegeneratePrompts) => Some(RegeneratePrompts),
}
}
}
@@ -988,6 +1216,8 @@ fn apply_default_for_source(
pooling: &mut Setting<OverridePooling>,
dimensions: &mut Setting<usize>,
url: &mut Setting<String>,
indexing_fragments: &mut Setting<BTreeMap<String, Option<Fragment>>>,
search_fragments: &mut Setting<BTreeMap<String, Option<Fragment>>>,
request: &mut Setting<serde_json::Value>,
response: &mut Setting<serde_json::Value>,
document_template: &mut Setting<String>,
@@ -1003,6 +1233,8 @@ fn apply_default_for_source(
*pooling = Setting::Reset;
*dimensions = Setting::NotSet;
*url = Setting::NotSet;
*indexing_fragments = Setting::NotSet;
*search_fragments = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*headers = Setting::NotSet;
@@ -1015,6 +1247,8 @@ fn apply_default_for_source(
*pooling = Setting::NotSet;
*dimensions = Setting::Reset;
*url = Setting::NotSet;
*indexing_fragments = Setting::NotSet;
*search_fragments = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*headers = Setting::NotSet;
@@ -1027,6 +1261,8 @@ fn apply_default_for_source(
*pooling = Setting::NotSet;
*dimensions = Setting::NotSet;
*url = Setting::Reset;
*indexing_fragments = Setting::NotSet;
*search_fragments = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*headers = Setting::NotSet;
@@ -1039,6 +1275,8 @@ fn apply_default_for_source(
*pooling = Setting::NotSet;
*dimensions = Setting::Reset;
*url = Setting::Reset;
*indexing_fragments = Setting::Reset;
*search_fragments = Setting::Reset;
*request = Setting::Reset;
*response = Setting::Reset;
*headers = Setting::Reset;
@@ -1051,6 +1289,8 @@ fn apply_default_for_source(
*pooling = Setting::NotSet;
*dimensions = Setting::Reset;
*url = Setting::NotSet;
*indexing_fragments = Setting::NotSet;
*search_fragments = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*document_template = Setting::NotSet;
@@ -1065,6 +1305,8 @@ fn apply_default_for_source(
*pooling = Setting::NotSet;
*dimensions = Setting::NotSet;
*url = Setting::NotSet;
*indexing_fragments = Setting::NotSet;
*search_fragments = Setting::NotSet;
*request = Setting::NotSet;
*response = Setting::NotSet;
*document_template = Setting::NotSet;
@@ -1131,6 +1373,8 @@ pub enum MetaEmbeddingSetting {
DocumentTemplate,
DocumentTemplateMaxBytes,
Url,
IndexingFragments,
SearchFragments,
Request,
Response,
Headers,
@@ -1153,6 +1397,8 @@ impl MetaEmbeddingSetting {
DocumentTemplate => "documentTemplate",
DocumentTemplateMaxBytes => "documentTemplateMaxBytes",
Url => "url",
IndexingFragments => "indexingFragments",
SearchFragments => "searchFragments",
Request => "request",
Response => "response",
Headers => "headers",
@@ -1176,6 +1422,8 @@ impl EmbeddingSettings {
dimensions: &Setting<usize>,
api_key: &Setting<String>,
url: &Setting<String>,
indexing_fragments: &Setting<BTreeMap<String, Option<Fragment>>>,
search_fragments: &Setting<BTreeMap<String, Option<Fragment>>>,
request: &Setting<serde_json::Value>,
response: &Setting<serde_json::Value>,
document_template: &Setting<String>,
@@ -1210,6 +1458,20 @@ impl EmbeddingSettings {
)?;
Self::check_setting(embedder_name, source, MetaEmbeddingSetting::ApiKey, context, api_key)?;
Self::check_setting(embedder_name, source, MetaEmbeddingSetting::Url, context, url)?;
Self::check_setting(
embedder_name,
source,
MetaEmbeddingSetting::IndexingFragments,
context,
indexing_fragments,
)?;
Self::check_setting(
embedder_name,
source,
MetaEmbeddingSetting::SearchFragments,
context,
search_fragments,
)?;
Self::check_setting(
embedder_name,
source,
@@ -1348,8 +1610,8 @@ impl EmbeddingSettings {
) => FieldStatus::Allowed,
(
OpenAi,
Revision | Pooling | Request | Response | Headers | SearchEmbedder
| IndexingEmbedder,
Revision | Pooling | IndexingFragments | SearchFragments | Request | Response
| Headers | SearchEmbedder | IndexingEmbedder,
_,
) => FieldStatus::Disallowed,
(
@@ -1359,8 +1621,8 @@ impl EmbeddingSettings {
) => FieldStatus::Allowed,
(
HuggingFace,
ApiKey | Dimensions | Url | Request | Response | Headers | SearchEmbedder
| IndexingEmbedder,
ApiKey | Dimensions | Url | IndexingFragments | SearchFragments | Request
| Response | Headers | SearchEmbedder | IndexingEmbedder,
_,
) => FieldStatus::Disallowed,
(Ollama, Model, _) => FieldStatus::Mandatory,
@@ -1371,8 +1633,8 @@ impl EmbeddingSettings {
) => FieldStatus::Allowed,
(
Ollama,
Revision | Pooling | Request | Response | Headers | SearchEmbedder
| IndexingEmbedder,
Revision | Pooling | IndexingFragments | SearchFragments | Request | Response
| Headers | SearchEmbedder | IndexingEmbedder,
_,
) => FieldStatus::Disallowed,
(UserProvided, Dimensions, _) => FieldStatus::Mandatory,
@@ -1386,6 +1648,8 @@ impl EmbeddingSettings {
| DocumentTemplate
| DocumentTemplateMaxBytes
| Url
| IndexingFragments
| SearchFragments
| Request
| Response
| Headers
@@ -1404,6 +1668,10 @@ impl EmbeddingSettings {
| Headers,
_,
) => FieldStatus::Allowed,
(Rest, IndexingFragments, NotNested | Indexing) => FieldStatus::Allowed,
(Rest, IndexingFragments, Search) => FieldStatus::Disallowed,
(Rest, SearchFragments, NotNested | Search) => FieldStatus::Allowed,
(Rest, SearchFragments, Indexing) => FieldStatus::Disallowed,
(Rest, Model | Revision | Pooling | SearchEmbedder | IndexingEmbedder, _) => {
FieldStatus::Disallowed
}
@@ -1419,6 +1687,8 @@ impl EmbeddingSettings {
| DocumentTemplate
| DocumentTemplateMaxBytes
| Url
| IndexingFragments
| SearchFragments
| Request
| Response
| Headers,
@@ -1512,6 +1782,11 @@ impl std::fmt::Display for EmbedderSource {
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)]
pub struct Fragment {
pub value: serde_json::Value,
}
impl EmbeddingSettings {
fn from_hugging_face(
super::hf::EmbedderOptions {
@@ -1534,6 +1809,8 @@ impl EmbeddingSettings {
document_template,
document_template_max_bytes,
url: Setting::NotSet,
indexing_fragments: Setting::NotSet,
search_fragments: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
@@ -1566,6 +1843,8 @@ impl EmbeddingSettings {
document_template,
document_template_max_bytes,
url: Setting::some_or_not_set(url),
indexing_fragments: Setting::NotSet,
search_fragments: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
@@ -1598,6 +1877,8 @@ impl EmbeddingSettings {
document_template,
document_template_max_bytes,
url: Setting::some_or_not_set(url),
indexing_fragments: Setting::NotSet,
search_fragments: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
@@ -1622,6 +1903,8 @@ impl EmbeddingSettings {
document_template: Setting::NotSet,
document_template_max_bytes: Setting::NotSet,
url: Setting::NotSet,
indexing_fragments: Setting::NotSet,
search_fragments: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
@@ -1638,6 +1921,8 @@ impl EmbeddingSettings {
dimensions,
url,
request,
indexing_fragments,
search_fragments,
response,
distribution,
headers,
@@ -1653,9 +1938,39 @@ impl EmbeddingSettings {
pooling: Setting::NotSet,
api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions),
document_template,
document_template_max_bytes,
document_template: if indexing_fragments.is_empty() && search_fragments.is_empty() {
document_template
} else {
Setting::NotSet
},
document_template_max_bytes: if indexing_fragments.is_empty()
&& search_fragments.is_empty()
{
document_template_max_bytes
} else {
Setting::NotSet
},
url: Setting::Set(url),
indexing_fragments: if indexing_fragments.is_empty() {
Setting::NotSet
} else {
Setting::Set(
indexing_fragments
.into_iter()
.map(|(name, fragment)| (name, Some(Fragment { value: fragment })))
.collect(),
)
},
search_fragments: if search_fragments.is_empty() {
Setting::NotSet
} else {
Setting::Set(
search_fragments
.into_iter()
.map(|(name, fragment)| (name, Some(Fragment { value: fragment })))
.collect(),
)
},
request: Setting::Set(request),
response: Setting::Set(response),
distribution: Setting::some_or_not_set(distribution),
@@ -1714,6 +2029,8 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
document_template: Setting::NotSet,
document_template_max_bytes: Setting::NotSet,
url: Setting::NotSet,
indexing_fragments: Setting::NotSet,
search_fragments: Setting::NotSet,
request: Setting::NotSet,
response: Setting::NotSet,
headers: Setting::NotSet,
@@ -1786,6 +2103,8 @@ impl From<EmbeddingSettings> for SubEmbeddingSettings {
document_template,
document_template_max_bytes,
url,
indexing_fragments,
search_fragments,
request,
response,
headers,
@@ -1804,6 +2123,8 @@ impl From<EmbeddingSettings> for SubEmbeddingSettings {
document_template,
document_template_max_bytes,
url,
indexing_fragments,
search_fragments,
request,
response,
headers,
@@ -1828,6 +2149,8 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
document_template,
document_template_max_bytes,
url,
indexing_fragments,
search_fragments,
request,
response,
distribution,
@@ -1879,6 +2202,8 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
EmbedderSource::Rest => SubEmbedderOptions::rest(
url.set().unwrap(),
api_key,
indexing_fragments,
search_fragments,
request.set().unwrap(),
response.set().unwrap(),
headers,
@@ -1922,6 +2247,8 @@ impl SubEmbedderOptions {
document_template: _,
document_template_max_bytes: _,
url,
indexing_fragments,
search_fragments,
request,
response,
headers,
@@ -1944,6 +2271,8 @@ impl SubEmbedderOptions {
EmbedderSource::Rest => Self::rest(
url.set().unwrap(),
api_key,
indexing_fragments,
search_fragments,
request.set().unwrap(),
response.set().unwrap(),
headers,
@@ -2010,9 +2339,13 @@ impl SubEmbedderOptions {
distribution: distribution.set(),
})
}
#[allow(clippy::too_many_arguments)]
fn rest(
url: String,
api_key: Setting<String>,
indexing_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
search_fragments: Setting<BTreeMap<String, Option<Fragment>>>,
request: serde_json::Value,
response: serde_json::Value,
headers: Setting<BTreeMap<String, String>>,
@@ -2027,6 +2360,22 @@ impl SubEmbedderOptions {
response,
distribution: distribution.set(),
headers: headers.set().unwrap_or_default(),
search_fragments: search_fragments
.set()
.unwrap_or_default()
.into_iter()
.filter_map(|(name, fragment)| {
Some((name, fragment.map(|fragment| fragment.value)?))
})
.collect(),
indexing_fragments: indexing_fragments
.set()
.unwrap_or_default()
.into_iter()
.filter_map(|(name, fragment)| {
Some((name, fragment.map(|fragment| fragment.value)?))
})
.collect(),
})
}
fn ollama(
@@ -2066,3 +2415,29 @@ impl From<SubEmbedderOptions> for EmbedderOptions {
}
}
}
pub(crate) fn fragments_from_settings(
setting: &Setting<EmbeddingSettings>,
) -> impl Iterator<Item = String> + '_ {
let Some(setting) = setting.as_ref().set() else { return Either::Left(None.into_iter()) };
let filter_map = |(name, fragment): (&String, &Option<Fragment>)| {
if fragment.is_some() {
Some(name.clone())
} else {
None
}
};
if let Some(setting) = setting.indexing_fragments.as_ref().set() {
Either::Right(setting.iter().filter_map(filter_map))
} else {
let Some(setting) = setting.indexing_embedder.as_ref().set() else {
return Either::Left(None.into_iter());
};
let Some(setting) = setting.indexing_fragments.as_ref().set() else {
return Either::Left(None.into_iter());
};
Either::Right(setting.iter().filter_map(filter_map))
}
}

View File

@@ -5,7 +5,7 @@ use milli::documents::mmap_from_objects;
use milli::progress::Progress;
use milli::update::new::indexer;
use milli::update::{IndexerConfig, Settings};
use milli::vector::EmbeddingConfigs;
use milli::vector::RuntimeEmbedders;
use milli::{FacetDistribution, FilterableAttributesRule, Index, Object, OrderBy};
use serde_json::{from_value, json};
@@ -35,7 +35,7 @@ fn test_facet_distribution_with_no_facet_values() {
let db_fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
let doc1: Object = from_value(

View File

@@ -10,7 +10,7 @@ use maplit::{btreemap, hashset};
use milli::progress::Progress;
use milli::update::new::indexer;
use milli::update::{IndexerConfig, Settings};
use milli::vector::EmbeddingConfigs;
use milli::vector::RuntimeEmbedders;
use milli::{
AscDesc, Criterion, DocumentId, FilterableAttributesRule, Index, Member, TermsMatchingStrategy,
};
@@ -74,7 +74,7 @@ pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index {
let db_fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
let mut file = tempfile::tempfile().unwrap();

View File

@@ -8,7 +8,7 @@ use maplit::hashset;
use milli::progress::Progress;
use milli::update::new::indexer;
use milli::update::{IndexerConfig, Settings};
use milli::vector::EmbeddingConfigs;
use milli::vector::RuntimeEmbedders;
use milli::{AscDesc, Criterion, Index, Member, Search, SearchResult, TermsMatchingStrategy};
use rand::Rng;
use Criterion::*;
@@ -288,7 +288,7 @@ fn criteria_ascdesc() {
let db_fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
let mut file = tempfile::tempfile().unwrap();

View File

@@ -6,7 +6,7 @@ use milli::documents::mmap_from_objects;
use milli::progress::Progress;
use milli::update::new::indexer;
use milli::update::{IndexerConfig, Settings};
use milli::vector::EmbeddingConfigs;
use milli::vector::RuntimeEmbedders;
use milli::{Criterion, Index, Object, Search, TermsMatchingStrategy};
use serde_json::from_value;
use tempfile::tempdir;
@@ -123,7 +123,7 @@ fn test_typo_disabled_on_word() {
let db_fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
let mut new_fields_ids_map = db_fields_ids_map.clone();
let embedders = EmbeddingConfigs::default();
let embedders = RuntimeEmbedders::default();
let mut indexer = indexer::DocumentOperation::new();
indexer.replace_documents(&documents).unwrap();