mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-11-04 09:56:28 +00:00 
			
		
		
		
	support vectors or array of vectors
This commit is contained in:
		@@ -6,6 +6,7 @@ use super::vector_document::{
 | 
			
		||||
    MergedVectorDocument, VectorDocumentFromDb, VectorDocumentFromVersions,
 | 
			
		||||
};
 | 
			
		||||
use crate::documents::FieldIdMapper;
 | 
			
		||||
use crate::vector::EmbeddingConfigs;
 | 
			
		||||
use crate::{DocumentId, Index, Result};
 | 
			
		||||
 | 
			
		||||
pub enum DocumentChange<'doc> {
 | 
			
		||||
@@ -94,8 +95,9 @@ impl<'doc> Insertion<'doc> {
 | 
			
		||||
    pub fn inserted_vectors(
 | 
			
		||||
        &self,
 | 
			
		||||
        doc_alloc: &'doc Bump,
 | 
			
		||||
        embedders: &'doc EmbeddingConfigs,
 | 
			
		||||
    ) -> Result<Option<VectorDocumentFromVersions<'doc>>> {
 | 
			
		||||
        VectorDocumentFromVersions::new(&self.new, doc_alloc)
 | 
			
		||||
        VectorDocumentFromVersions::new(&self.new, doc_alloc, embedders)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -165,8 +167,9 @@ impl<'doc> Update<'doc> {
 | 
			
		||||
    pub fn updated_vectors(
 | 
			
		||||
        &self,
 | 
			
		||||
        doc_alloc: &'doc Bump,
 | 
			
		||||
        embedders: &'doc EmbeddingConfigs,
 | 
			
		||||
    ) -> Result<Option<VectorDocumentFromVersions<'doc>>> {
 | 
			
		||||
        VectorDocumentFromVersions::new(&self.new, doc_alloc)
 | 
			
		||||
        VectorDocumentFromVersions::new(&self.new, doc_alloc, embedders)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn merged_vectors<Mapper: FieldIdMapper>(
 | 
			
		||||
@@ -175,11 +178,14 @@ impl<'doc> Update<'doc> {
 | 
			
		||||
        index: &'doc Index,
 | 
			
		||||
        mapper: &'doc Mapper,
 | 
			
		||||
        doc_alloc: &'doc Bump,
 | 
			
		||||
        embedders: &'doc EmbeddingConfigs,
 | 
			
		||||
    ) -> Result<Option<MergedVectorDocument<'doc>>> {
 | 
			
		||||
        if self.has_deletion {
 | 
			
		||||
            MergedVectorDocument::without_db(&self.new, doc_alloc)
 | 
			
		||||
            MergedVectorDocument::without_db(&self.new, doc_alloc, embedders)
 | 
			
		||||
        } else {
 | 
			
		||||
            MergedVectorDocument::with_db(self.docid, index, rtxn, mapper, &self.new, doc_alloc)
 | 
			
		||||
            MergedVectorDocument::with_db(
 | 
			
		||||
                self.docid, index, rtxn, mapper, &self.new, doc_alloc, embedders,
 | 
			
		||||
            )
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -93,7 +93,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
 | 
			
		||||
                        context.db_fields_ids_map,
 | 
			
		||||
                        &context.doc_alloc,
 | 
			
		||||
                    )?;
 | 
			
		||||
                    let new_vectors = update.updated_vectors(&context.doc_alloc)?;
 | 
			
		||||
                    let new_vectors = update.updated_vectors(&context.doc_alloc, self.embedders)?;
 | 
			
		||||
 | 
			
		||||
                    if let Some(new_vectors) = &new_vectors {
 | 
			
		||||
                        unused_vectors_distribution.append(new_vectors);
 | 
			
		||||
@@ -118,7 +118,12 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
 | 
			
		||||
                            if let Some(embeddings) = new_vectors.embeddings {
 | 
			
		||||
                                chunks.set_vectors(
 | 
			
		||||
                                    update.docid(),
 | 
			
		||||
                                    embeddings.into_vec().map_err(UserError::SerdeJson)?,
 | 
			
		||||
                                    embeddings
 | 
			
		||||
                                        .into_vec(&context.doc_alloc, embedder_name)
 | 
			
		||||
                                        .map_err(|error| UserError::InvalidVectorsEmbedderConf {
 | 
			
		||||
                                            document_id: update.external_document_id().to_string(),
 | 
			
		||||
                                            error,
 | 
			
		||||
                                        })?,
 | 
			
		||||
                                );
 | 
			
		||||
                            } else if new_vectors.regenerate {
 | 
			
		||||
                                let new_rendered = prompt.render_document(
 | 
			
		||||
@@ -177,7 +182,8 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                DocumentChange::Insertion(insertion) => {
 | 
			
		||||
                    let new_vectors = insertion.inserted_vectors(&context.doc_alloc)?;
 | 
			
		||||
                    let new_vectors =
 | 
			
		||||
                        insertion.inserted_vectors(&context.doc_alloc, self.embedders)?;
 | 
			
		||||
                    if let Some(new_vectors) = &new_vectors {
 | 
			
		||||
                        unused_vectors_distribution.append(new_vectors);
 | 
			
		||||
                    }
 | 
			
		||||
@@ -194,7 +200,14 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
 | 
			
		||||
                            if let Some(embeddings) = new_vectors.embeddings {
 | 
			
		||||
                                chunks.set_vectors(
 | 
			
		||||
                                    insertion.docid(),
 | 
			
		||||
                                    embeddings.into_vec().map_err(UserError::SerdeJson)?,
 | 
			
		||||
                                    embeddings
 | 
			
		||||
                                        .into_vec(&context.doc_alloc, embedder_name)
 | 
			
		||||
                                        .map_err(|error| UserError::InvalidVectorsEmbedderConf {
 | 
			
		||||
                                            document_id: insertion
 | 
			
		||||
                                                .external_document_id()
 | 
			
		||||
                                                .to_string(),
 | 
			
		||||
                                            error,
 | 
			
		||||
                                        })?,
 | 
			
		||||
                                );
 | 
			
		||||
                            } else if new_vectors.regenerate {
 | 
			
		||||
                                let rendered = prompt.render_document(
 | 
			
		||||
 
 | 
			
		||||
@@ -326,3 +326,294 @@ pub fn match_component<'de, 'indexer: 'de>(
 | 
			
		||||
    }
 | 
			
		||||
    ControlFlow::Continue(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct DeserrRawValue<'a> {
 | 
			
		||||
    value: &'a RawValue,
 | 
			
		||||
    alloc: &'a Bump,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> DeserrRawValue<'a> {
 | 
			
		||||
    pub fn new_in(value: &'a RawValue, alloc: &'a Bump) -> Self {
 | 
			
		||||
        Self { value, alloc }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct DeserrRawVec<'a> {
 | 
			
		||||
    vec: raw_collections::RawVec<'a>,
 | 
			
		||||
    alloc: &'a Bump,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> deserr::Sequence for DeserrRawVec<'a> {
 | 
			
		||||
    type Value = DeserrRawValue<'a>;
 | 
			
		||||
 | 
			
		||||
    type Iter = DeserrRawVecIter<'a>;
 | 
			
		||||
 | 
			
		||||
    fn len(&self) -> usize {
 | 
			
		||||
        self.vec.len()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn into_iter(self) -> Self::Iter {
 | 
			
		||||
        DeserrRawVecIter { it: self.vec.into_iter(), alloc: self.alloc }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct DeserrRawVecIter<'a> {
 | 
			
		||||
    it: raw_collections::vec::iter::IntoIter<'a>,
 | 
			
		||||
    alloc: &'a Bump,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> Iterator for DeserrRawVecIter<'a> {
 | 
			
		||||
    type Item = DeserrRawValue<'a>;
 | 
			
		||||
 | 
			
		||||
    fn next(&mut self) -> Option<Self::Item> {
 | 
			
		||||
        let next = self.it.next()?;
 | 
			
		||||
        Some(DeserrRawValue { value: next, alloc: self.alloc })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct DeserrRawMap<'a> {
 | 
			
		||||
    map: raw_collections::RawMap<'a>,
 | 
			
		||||
    alloc: &'a Bump,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> deserr::Map for DeserrRawMap<'a> {
 | 
			
		||||
    type Value = DeserrRawValue<'a>;
 | 
			
		||||
 | 
			
		||||
    type Iter = DeserrRawMapIter<'a>;
 | 
			
		||||
 | 
			
		||||
    fn len(&self) -> usize {
 | 
			
		||||
        self.map.len()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn remove(&mut self, _key: &str) -> Option<Self::Value> {
 | 
			
		||||
        unimplemented!()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn into_iter(self) -> Self::Iter {
 | 
			
		||||
        DeserrRawMapIter { it: self.map.into_iter(), alloc: self.alloc }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct DeserrRawMapIter<'a> {
 | 
			
		||||
    it: raw_collections::map::iter::IntoIter<'a>,
 | 
			
		||||
    alloc: &'a Bump,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> Iterator for DeserrRawMapIter<'a> {
 | 
			
		||||
    type Item = (String, DeserrRawValue<'a>);
 | 
			
		||||
 | 
			
		||||
    fn next(&mut self) -> Option<Self::Item> {
 | 
			
		||||
        let (name, value) = self.it.next()?;
 | 
			
		||||
        Some((name.to_string(), DeserrRawValue { value, alloc: self.alloc }))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> deserr::IntoValue for DeserrRawValue<'a> {
 | 
			
		||||
    type Sequence = DeserrRawVec<'a>;
 | 
			
		||||
 | 
			
		||||
    type Map = DeserrRawMap<'a>;
 | 
			
		||||
 | 
			
		||||
    fn kind(&self) -> deserr::ValueKind {
 | 
			
		||||
        self.value.deserialize_any(DeserrKindVisitor).unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn into_value(self) -> deserr::Value<Self> {
 | 
			
		||||
        self.value.deserialize_any(DeserrRawValueVisitor { alloc: self.alloc }).unwrap()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct DeserrKindVisitor;
 | 
			
		||||
 | 
			
		||||
impl<'de> Visitor<'de> for DeserrKindVisitor {
 | 
			
		||||
    type Value = deserr::ValueKind;
 | 
			
		||||
 | 
			
		||||
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
 | 
			
		||||
        write!(formatter, "any value")
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_bool<E>(self, _v: bool) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::Boolean)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_i64<E>(self, _v: i64) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::NegativeInteger)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_u64<E>(self, _v: u64) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::Integer)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_f64<E>(self, _v: f64) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::Float)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_str<E>(self, _v: &str) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::String)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_none<E>(self) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::Null)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
 | 
			
		||||
    where
 | 
			
		||||
        D: serde::Deserializer<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        deserializer.deserialize_any(self)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_unit<E>(self) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::Null)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
 | 
			
		||||
    where
 | 
			
		||||
        D: serde::Deserializer<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        deserializer.deserialize_any(self)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_seq<A>(self, _seq: A) -> Result<Self::Value, A::Error>
 | 
			
		||||
    where
 | 
			
		||||
        A: serde::de::SeqAccess<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::Sequence)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_map<A>(self, _map: A) -> Result<Self::Value, A::Error>
 | 
			
		||||
    where
 | 
			
		||||
        A: serde::de::MapAccess<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::ValueKind::Map)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct DeserrRawValueVisitor<'a> {
 | 
			
		||||
    alloc: &'a Bump,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'de> Visitor<'de> for DeserrRawValueVisitor<'de> {
 | 
			
		||||
    type Value = deserr::Value<DeserrRawValue<'de>>;
 | 
			
		||||
 | 
			
		||||
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
 | 
			
		||||
        write!(formatter, "any value")
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::Boolean(v))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::NegativeInteger(v))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::Integer(v))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::Float(v))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::String(v.to_string()))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::String(v))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_none<E>(self) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::Null)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
 | 
			
		||||
    where
 | 
			
		||||
        D: serde::Deserializer<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        deserializer.deserialize_any(self)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_unit<E>(self) -> Result<Self::Value, E>
 | 
			
		||||
    where
 | 
			
		||||
        E: serde::de::Error,
 | 
			
		||||
    {
 | 
			
		||||
        Ok(deserr::Value::Null)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
 | 
			
		||||
    where
 | 
			
		||||
        D: serde::Deserializer<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        deserializer.deserialize_any(self)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
 | 
			
		||||
    where
 | 
			
		||||
        A: serde::de::SeqAccess<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        let mut raw_vec = raw_collections::RawVec::new_in(&self.alloc);
 | 
			
		||||
        while let Some(next) = seq.next_element()? {
 | 
			
		||||
            raw_vec.push(next);
 | 
			
		||||
        }
 | 
			
		||||
        Ok(deserr::Value::Sequence(DeserrRawVec { vec: raw_vec, alloc: self.alloc }))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
 | 
			
		||||
    where
 | 
			
		||||
        A: serde::de::MapAccess<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        let _ = map;
 | 
			
		||||
        Err(serde::de::Error::invalid_type(serde::de::Unexpected::Map, &self))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
 | 
			
		||||
    where
 | 
			
		||||
        A: serde::de::EnumAccess<'de>,
 | 
			
		||||
    {
 | 
			
		||||
        let _ = data;
 | 
			
		||||
        Err(serde::de::Error::invalid_type(serde::de::Unexpected::Enum, &self))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,29 +1,67 @@
 | 
			
		||||
use std::collections::BTreeSet;
 | 
			
		||||
 | 
			
		||||
use bumpalo::Bump;
 | 
			
		||||
use deserr::{Deserr, IntoValue};
 | 
			
		||||
use heed::RoTxn;
 | 
			
		||||
use raw_collections::RawMap;
 | 
			
		||||
use serde::Serialize;
 | 
			
		||||
use serde_json::value::RawValue;
 | 
			
		||||
 | 
			
		||||
use super::document::{Document, DocumentFromDb, DocumentFromVersions, Versions};
 | 
			
		||||
use super::indexer::de::DeserrRawValue;
 | 
			
		||||
use crate::documents::FieldIdMapper;
 | 
			
		||||
use crate::index::IndexEmbeddingConfig;
 | 
			
		||||
use crate::vector::parsed_vectors::RawVectors;
 | 
			
		||||
use crate::vector::Embedding;
 | 
			
		||||
use crate::vector::parsed_vectors::{
 | 
			
		||||
    RawVectors, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME,
 | 
			
		||||
};
 | 
			
		||||
use crate::vector::{Embedding, EmbeddingConfigs};
 | 
			
		||||
use crate::{DocumentId, Index, InternalError, Result, UserError};
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize)]
 | 
			
		||||
#[serde(untagged)]
 | 
			
		||||
pub enum Embeddings<'doc> {
 | 
			
		||||
    FromJson(&'doc RawValue),
 | 
			
		||||
    FromJsonExplicit(&'doc RawValue),
 | 
			
		||||
    FromJsonImplicityUserProvided(&'doc RawValue),
 | 
			
		||||
    FromDb(Vec<Embedding>),
 | 
			
		||||
}
 | 
			
		||||
impl<'doc> Embeddings<'doc> {
 | 
			
		||||
    pub fn into_vec(self) -> std::result::Result<Vec<Embedding>, serde_json::Error> {
 | 
			
		||||
    pub fn into_vec(
 | 
			
		||||
        self,
 | 
			
		||||
        doc_alloc: &'doc Bump,
 | 
			
		||||
        embedder_name: &str,
 | 
			
		||||
    ) -> std::result::Result<Vec<Embedding>, deserr::errors::JsonError> {
 | 
			
		||||
        match self {
 | 
			
		||||
            /// FIXME: this should be a VecOrArrayOfVec
 | 
			
		||||
            Embeddings::FromJson(value) => serde_json::from_str(value.get()),
 | 
			
		||||
            Embeddings::FromJsonExplicit(value) => {
 | 
			
		||||
                let vectors_ref = deserr::ValuePointerRef::Key {
 | 
			
		||||
                    key: RESERVED_VECTORS_FIELD_NAME,
 | 
			
		||||
                    prev: &deserr::ValuePointerRef::Origin,
 | 
			
		||||
                };
 | 
			
		||||
                let embedders_ref =
 | 
			
		||||
                    deserr::ValuePointerRef::Key { key: embedder_name, prev: &vectors_ref };
 | 
			
		||||
 | 
			
		||||
                let embeddings_ref =
 | 
			
		||||
                    deserr::ValuePointerRef::Key { key: "embeddings", prev: &embedders_ref };
 | 
			
		||||
 | 
			
		||||
                let v: VectorOrArrayOfVectors = VectorOrArrayOfVectors::deserialize_from_value(
 | 
			
		||||
                    DeserrRawValue::new_in(value, doc_alloc).into_value(),
 | 
			
		||||
                    embeddings_ref,
 | 
			
		||||
                )?;
 | 
			
		||||
                Ok(v.into_array_of_vectors().unwrap_or_default())
 | 
			
		||||
            }
 | 
			
		||||
            Embeddings::FromJsonImplicityUserProvided(value) => {
 | 
			
		||||
                let vectors_ref = deserr::ValuePointerRef::Key {
 | 
			
		||||
                    key: RESERVED_VECTORS_FIELD_NAME,
 | 
			
		||||
                    prev: &deserr::ValuePointerRef::Origin,
 | 
			
		||||
                };
 | 
			
		||||
                let embedders_ref =
 | 
			
		||||
                    deserr::ValuePointerRef::Key { key: embedder_name, prev: &vectors_ref };
 | 
			
		||||
 | 
			
		||||
                let v: VectorOrArrayOfVectors = VectorOrArrayOfVectors::deserialize_from_value(
 | 
			
		||||
                    DeserrRawValue::new_in(value, doc_alloc).into_value(),
 | 
			
		||||
                    embedders_ref,
 | 
			
		||||
                )?;
 | 
			
		||||
                Ok(v.into_array_of_vectors().unwrap_or_default())
 | 
			
		||||
            }
 | 
			
		||||
            Embeddings::FromDb(vec) => Ok(vec),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@@ -109,7 +147,7 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> {
 | 
			
		||||
                Ok((&*config_name, entry))
 | 
			
		||||
            })
 | 
			
		||||
            .chain(self.vectors_field.iter().flat_map(|map| map.iter()).map(|(name, value)| {
 | 
			
		||||
                Ok((name, entry_from_raw_value(value).map_err(InternalError::SerdeJson)?))
 | 
			
		||||
                Ok((name, entry_from_raw_value(value, false).map_err(InternalError::SerdeJson)?))
 | 
			
		||||
            }))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -122,7 +160,8 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> {
 | 
			
		||||
            }
 | 
			
		||||
            None => match self.vectors_field.as_ref().and_then(|obkv| obkv.get(key)) {
 | 
			
		||||
                Some(embedding_from_doc) => Some(
 | 
			
		||||
                    entry_from_raw_value(embedding_from_doc).map_err(InternalError::SerdeJson)?,
 | 
			
		||||
                    entry_from_raw_value(embedding_from_doc, false)
 | 
			
		||||
                        .map_err(InternalError::SerdeJson)?,
 | 
			
		||||
                ),
 | 
			
		||||
                None => None,
 | 
			
		||||
            },
 | 
			
		||||
@@ -132,26 +171,40 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> {
 | 
			
		||||
 | 
			
		||||
fn entry_from_raw_value(
 | 
			
		||||
    value: &RawValue,
 | 
			
		||||
    has_configured_embedder: bool,
 | 
			
		||||
) -> std::result::Result<VectorEntry<'_>, serde_json::Error> {
 | 
			
		||||
    let value: RawVectors = serde_json::from_str(value.get())?;
 | 
			
		||||
    Ok(VectorEntry {
 | 
			
		||||
        has_configured_embedder: false,
 | 
			
		||||
        embeddings: value.embeddings().map(Embeddings::FromJson),
 | 
			
		||||
        regenerate: value.must_regenerate(),
 | 
			
		||||
 | 
			
		||||
    Ok(match value {
 | 
			
		||||
        RawVectors::Explicit(raw_explicit_vectors) => VectorEntry {
 | 
			
		||||
            has_configured_embedder,
 | 
			
		||||
            embeddings: raw_explicit_vectors.embeddings.map(Embeddings::FromJsonExplicit),
 | 
			
		||||
            regenerate: raw_explicit_vectors.regenerate,
 | 
			
		||||
        },
 | 
			
		||||
        RawVectors::ImplicitlyUserProvided(value) => VectorEntry {
 | 
			
		||||
            has_configured_embedder,
 | 
			
		||||
            embeddings: Some(Embeddings::FromJsonImplicityUserProvided(value)),
 | 
			
		||||
            regenerate: false,
 | 
			
		||||
        },
 | 
			
		||||
    })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct VectorDocumentFromVersions<'doc> {
 | 
			
		||||
    vectors: RawMap<'doc>,
 | 
			
		||||
    embedders: &'doc EmbeddingConfigs,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'doc> VectorDocumentFromVersions<'doc> {
 | 
			
		||||
    pub fn new(versions: &Versions<'doc>, bump: &'doc Bump) -> Result<Option<Self>> {
 | 
			
		||||
    pub fn new(
 | 
			
		||||
        versions: &Versions<'doc>,
 | 
			
		||||
        bump: &'doc Bump,
 | 
			
		||||
        embedders: &'doc EmbeddingConfigs,
 | 
			
		||||
    ) -> Result<Option<Self>> {
 | 
			
		||||
        let document = DocumentFromVersions::new(versions);
 | 
			
		||||
        if let Some(vectors_field) = document.vectors_field()? {
 | 
			
		||||
            let vectors =
 | 
			
		||||
                RawMap::from_raw_value(vectors_field, bump).map_err(UserError::SerdeJson)?;
 | 
			
		||||
            Ok(Some(Self { vectors }))
 | 
			
		||||
            Ok(Some(Self { vectors, embedders }))
 | 
			
		||||
        } else {
 | 
			
		||||
            Ok(None)
 | 
			
		||||
        }
 | 
			
		||||
@@ -161,14 +214,16 @@ impl<'doc> VectorDocumentFromVersions<'doc> {
 | 
			
		||||
impl<'doc> VectorDocument<'doc> for VectorDocumentFromVersions<'doc> {
 | 
			
		||||
    fn iter_vectors(&self) -> impl Iterator<Item = Result<(&'doc str, VectorEntry<'doc>)>> {
 | 
			
		||||
        self.vectors.iter().map(|(embedder, vectors)| {
 | 
			
		||||
            let vectors = entry_from_raw_value(vectors).map_err(UserError::SerdeJson)?;
 | 
			
		||||
            let vectors = entry_from_raw_value(vectors, self.embedders.contains(embedder))
 | 
			
		||||
                .map_err(UserError::SerdeJson)?;
 | 
			
		||||
            Ok((embedder, vectors))
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn vectors_for_key(&self, key: &str) -> Result<Option<VectorEntry<'doc>>> {
 | 
			
		||||
        let Some(vectors) = self.vectors.get(key) else { return Ok(None) };
 | 
			
		||||
        let vectors = entry_from_raw_value(vectors).map_err(UserError::SerdeJson)?;
 | 
			
		||||
        let vectors = entry_from_raw_value(vectors, self.embedders.contains(key))
 | 
			
		||||
            .map_err(UserError::SerdeJson)?;
 | 
			
		||||
        Ok(Some(vectors))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -186,14 +241,19 @@ impl<'doc> MergedVectorDocument<'doc> {
 | 
			
		||||
        db_fields_ids_map: &'doc Mapper,
 | 
			
		||||
        versions: &Versions<'doc>,
 | 
			
		||||
        doc_alloc: &'doc Bump,
 | 
			
		||||
        embedders: &'doc EmbeddingConfigs,
 | 
			
		||||
    ) -> Result<Option<Self>> {
 | 
			
		||||
        let db = VectorDocumentFromDb::new(docid, index, rtxn, db_fields_ids_map, doc_alloc)?;
 | 
			
		||||
        let new_doc = VectorDocumentFromVersions::new(versions, doc_alloc)?;
 | 
			
		||||
        let new_doc = VectorDocumentFromVersions::new(versions, doc_alloc, embedders)?;
 | 
			
		||||
        Ok(if db.is_none() && new_doc.is_none() { None } else { Some(Self { new_doc, db }) })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn without_db(versions: &Versions<'doc>, doc_alloc: &'doc Bump) -> Result<Option<Self>> {
 | 
			
		||||
        let Some(new_doc) = VectorDocumentFromVersions::new(versions, doc_alloc)? else {
 | 
			
		||||
    pub fn without_db(
 | 
			
		||||
        versions: &Versions<'doc>,
 | 
			
		||||
        doc_alloc: &'doc Bump,
 | 
			
		||||
        embedders: &'doc EmbeddingConfigs,
 | 
			
		||||
    ) -> Result<Option<Self>> {
 | 
			
		||||
        let Some(new_doc) = VectorDocumentFromVersions::new(versions, doc_alloc, embedders)? else {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        };
 | 
			
		||||
        Ok(Some(Self { new_doc: Some(new_doc), db: None }))
 | 
			
		||||
 
 | 
			
		||||
@@ -316,6 +316,10 @@ impl EmbeddingConfigs {
 | 
			
		||||
        Self(data)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn contains(&self, name: &str) -> bool {
 | 
			
		||||
        self.0.contains_key(name)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Get an embedder configuration and template from its name.
 | 
			
		||||
    pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
 | 
			
		||||
        self.0.get(name).cloned()
 | 
			
		||||
 
 | 
			
		||||
@@ -84,7 +84,6 @@ impl<'doc> RawVectors<'doc> {
 | 
			
		||||
            RawVectors::Explicit(RawExplicitVectors { regenerate, .. }) => *regenerate,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn embeddings(&self) -> Option<&'doc RawValue> {
 | 
			
		||||
        match self {
 | 
			
		||||
            RawVectors::ImplicitlyUserProvided(embeddings) => Some(embeddings),
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user