Mirror of https://github.com/meilisearch/meilisearch.git (synced 2025-11-04 01:46:28 +00:00)
Merge #5418

5418: Cache embeddings in search r=Kerollmops a=dureuill

# Pull Request

## Related issue

TBD

## What does this PR do?

- Adds a cache for embeddings produced in search.
- The cache is disabled by default, and can be enabled following the instructions [here](https://github.com/orgs/meilisearch/discussions/818).
- Had to accommodate the `timeout` test for OpenAI, which uses a mock that simulates a timeout on subsequent responses: since the test was reusing the same query, the cache would kick in and no request would be made to the mock, meaning no timeout any longer and so a failing test 😅
- `Embedder::embed_search` now accepts a reference instead of an owned `String`.

## Manual testing

- I created 4 indexes on a fresh DB with the same settings (one embedder from OpenAI).
- I sent 1/4 of movies.json to each index.
- I sent a federated search request against all 4 indexes, with the same query for each index, using the embedder of each index.

Results:

- The first call took 400ms to 1s. Before this change, it took in the 3s range.
- Any repeated call with the same query took in the range of 25ms.
- Looking at the details at trace log level, I can see that the first index that needs the embedding spends most of the 400ms in `embed_one`. The other indexes report that the query text is found in the cache, and they each take a few µs.

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
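To make the mechanism concrete before reading the diff, here is a minimal, self-contained sketch of the read-through pattern this PR introduces, using the same `lru = "0.13.0"` crate the diff adds. The names (`QueryCache`, `embed_uncached`) are illustrative stand-ins, not Meilisearch's actual types; the real implementation is the `EmbeddingCache` struct further down.

```rust
use std::num::NonZeroUsize;
use std::sync::Mutex;

type Embedding = Vec<f32>;

struct QueryCache {
    /// `None` when the capacity is 0, i.e. the cache is disabled.
    data: Option<Mutex<lru::LruCache<String, Embedding>>>,
}

impl QueryCache {
    fn new(cap: usize) -> Self {
        Self { data: NonZeroUsize::new(cap).map(lru::LruCache::new).map(Mutex::new) }
    }

    fn embed(&self, query: &str) -> Embedding {
        // 1. Serve from the cache when possible.
        if let Some(data) = &self.data {
            if let Some(hit) = data.lock().unwrap().get(query) {
                return hit.clone();
            }
        }
        // 2. Otherwise compute the embedding, then fill the cache for the next caller.
        let embedding = embed_uncached(query);
        if let Some(data) = &self.data {
            data.lock().unwrap().put(query.to_owned(), embedding.clone());
        }
        embedding
    }
}

/// Stand-in for the real model or network call.
fn embed_uncached(query: &str) -> Embedding {
    query.bytes().map(f32::from).collect()
}

fn main() {
    let cache = QueryCache::new(10);
    let first = cache.embed("Intel the beagle best doggo"); // computed
    let again = cache.embed("Intel the beagle best doggo"); // served from the cache
    assert_eq!(first, again);
}
```

A capacity of 0 yields `data: None`, which disables caching entirely; this mirrors how the new `MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES` option defaults to 0 in the diff below.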
Cargo.lock (generated):

@@ -3498,6 +3498,15 @@ version = "0.4.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
 
+[[package]]
+name = "lru"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "227748d55f2f0ab4735d87fd623798cb6b664512fe979705f829c9f81c934465"
+dependencies = [
+ "hashbrown 0.15.2",
+]
+
 [[package]]
 name = "lzma-rs"
 version = "0.3.0"

@@ -3778,6 +3787,7 @@ dependencies = [
  "json-depth-checker",
  "levenshtein_automata",
  "liquid",
+ "lru",
  "maplit",
  "md5",
  "meili-snap",

@@ -125,6 +125,10 @@ pub struct IndexSchedulerOptions {
     pub instance_features: InstanceTogglableFeatures,
     /// The experimental features enabled for this instance.
     pub auto_upgrade: bool,
+    /// The maximal number of entries in the search query cache of an embedder.
+    ///
+    /// 0 disables the cache.
+    pub embedding_cache_cap: usize,
 }
 
 /// Structure which holds meilisearch's indexes and schedules the tasks

@@ -156,6 +160,11 @@ pub struct IndexScheduler {
     /// The Authorization header to send to the webhook URL.
     pub(crate) webhook_authorization_header: Option<String>,
 
+    /// A map to retrieve the runtime representation of an embedder depending on its configuration.
+    ///
+    /// This map may return the same embedder object for two different indexes or embedder settings,
+    /// but it will only do this if the embedder configuration options are the same, leading
+    /// to the same embeddings for the same input text.
     embedders: Arc<RwLock<HashMap<EmbedderOptions, Arc<Embedder>>>>,
 
     // ================= test

@@ -818,7 +827,7 @@ impl IndexScheduler {
 
                     // add missing embedder
                     let embedder = Arc::new(
-                        Embedder::new(embedder_options.clone())
+                        Embedder::new(embedder_options.clone(), self.scheduler.embedding_cache_cap)
                             .map_err(meilisearch_types::milli::vector::Error::from)
                             .map_err(|err| {
                                 Error::from_milli(err.into(), Some(index_uid.clone()))

@@ -76,6 +76,11 @@ pub struct Scheduler {
 
     /// The path to the version file of Meilisearch.
     pub(crate) version_file_path: PathBuf,
+
+    /// The maximal number of entries in the search query cache of an embedder.
+    ///
+    /// 0 disables the cache.
+    pub(crate) embedding_cache_cap: usize,
 }
 
 impl Scheduler {

@@ -90,6 +95,7 @@ impl Scheduler {
             snapshots_path: self.snapshots_path.clone(),
             auth_env: self.auth_env.clone(),
             version_file_path: self.version_file_path.clone(),
+            embedding_cache_cap: self.embedding_cache_cap,
         }
     }
 

@@ -105,6 +111,7 @@ impl Scheduler {
             snapshots_path: options.snapshots_path.clone(),
             auth_env,
             version_file_path: options.version_file_path.clone(),
+            embedding_cache_cap: options.embedding_cache_cap,
         }
     }
 }

@@ -104,10 +104,9 @@ fn import_vectors() {
 
         let configs = index_scheduler.embedders("doggos".to_string(), configs).unwrap();
         let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap();
-        let beagle_embed =
-            hf_embedder.embed_search(S("Intel the beagle best doggo"), None).unwrap();
-        let lab_embed = hf_embedder.embed_search(S("Max the lab best doggo"), None).unwrap();
-        let patou_embed = hf_embedder.embed_search(S("kefir the patou best doggo"), None).unwrap();
+        let beagle_embed = hf_embedder.embed_search("Intel the beagle best doggo", None).unwrap();
+        let lab_embed = hf_embedder.embed_search("Max the lab best doggo", None).unwrap();
+        let patou_embed = hf_embedder.embed_search("kefir the patou best doggo", None).unwrap();
         (fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed)
     };
 

@@ -112,6 +112,7 @@ impl IndexScheduler {
             batched_tasks_size_limit: u64::MAX,
             instance_features: Default::default(),
             auto_upgrade: true, // Don't cost much and will ensure the happy path works
+            embedding_cache_cap: 10,
         };
         let version = configuration(&mut options).unwrap_or_else(|| {
             (

@@ -199,6 +199,7 @@ struct Infos {
     experimental_network: bool,
     experimental_get_task_documents_route: bool,
     experimental_composite_embedders: bool,
+    experimental_embedding_cache_entries: usize,
     gpu_enabled: bool,
     db_path: bool,
     import_dump: bool,

@@ -246,6 +247,7 @@ impl Infos {
             experimental_reduce_indexing_memory_usage,
             experimental_max_number_of_batched_tasks,
             experimental_limit_batched_tasks_total_size,
+            experimental_embedding_cache_entries,
             http_addr,
             master_key: _,
             env,

@@ -312,6 +314,7 @@ impl Infos {
             experimental_network: network,
             experimental_get_task_documents_route: get_task_documents_route,
             experimental_composite_embedders: composite_embedders,
+            experimental_embedding_cache_entries,
             gpu_enabled: meilisearch_types::milli::vector::is_cuda_enabled(),
             db_path: db_path != PathBuf::from("./data.ms"),
             import_dump: import_dump.is_some(),

@@ -233,6 +233,7 @@ pub fn setup_meilisearch(opt: &Opt) -> anyhow::Result<(Arc<IndexScheduler>, Arc<
         index_count: DEFAULT_INDEX_COUNT,
         instance_features: opt.to_instance_features(),
         auto_upgrade: opt.experimental_dumpless_upgrade,
+        embedding_cache_cap: opt.experimental_embedding_cache_entries,
     };
     let bin_major: u32 = VERSION_MAJOR.parse().unwrap();
     let bin_minor: u32 = VERSION_MINOR.parse().unwrap();

@@ -63,7 +63,8 @@ const MEILI_EXPERIMENTAL_MAX_NUMBER_OF_BATCHED_TASKS: &str =
     "MEILI_EXPERIMENTAL_MAX_NUMBER_OF_BATCHED_TASKS";
 const MEILI_EXPERIMENTAL_LIMIT_BATCHED_TASKS_TOTAL_SIZE: &str =
     "MEILI_EXPERIMENTAL_LIMIT_BATCHED_TASKS_SIZE";
-
+const MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES: &str =
+    "MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES";
 const DEFAULT_CONFIG_FILE_PATH: &str = "./config.toml";
 const DEFAULT_DB_PATH: &str = "./data.ms";
 const DEFAULT_HTTP_ADDR: &str = "localhost:7700";

@@ -446,6 +447,14 @@ pub struct Opt {
     #[serde(default = "default_limit_batched_tasks_total_size")]
     pub experimental_limit_batched_tasks_total_size: u64,
 
+    /// Enables experimental caching of search query embeddings. The value represents the maximal number of entries in the cache of each
+    /// distinct embedder.
+    ///
+    /// For more information, see <https://github.com/orgs/meilisearch/discussions/818>.
+    #[clap(long, env = MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES, default_value_t = default_embedding_cache_entries())]
+    #[serde(default = "default_embedding_cache_entries")]
+    pub experimental_embedding_cache_entries: usize,
+
     #[serde(flatten)]
     #[clap(flatten)]
     pub indexer_options: IndexerOpts,

@@ -549,6 +558,7 @@ impl Opt {
             experimental_reduce_indexing_memory_usage,
             experimental_max_number_of_batched_tasks,
             experimental_limit_batched_tasks_total_size,
+            experimental_embedding_cache_entries,
         } = self;
         export_to_env_if_not_present(MEILI_DB_PATH, db_path);
         export_to_env_if_not_present(MEILI_HTTP_ADDR, http_addr);

@@ -641,6 +651,10 @@ impl Opt {
             MEILI_EXPERIMENTAL_LIMIT_BATCHED_TASKS_TOTAL_SIZE,
             experimental_limit_batched_tasks_total_size.to_string(),
         );
+        export_to_env_if_not_present(
+            MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES,
+            experimental_embedding_cache_entries.to_string(),
+        );
         indexer_options.export_to_env();
     }
 

@@ -948,6 +962,10 @@ fn default_limit_batched_tasks_total_size() -> u64 {
     u64::MAX
 }
 
+fn default_embedding_cache_entries() -> usize {
+    0
+}
+
 fn default_snapshot_dir() -> PathBuf {
     PathBuf::from(DEFAULT_SNAPSHOT_DIR)
 }

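For reference, the flag wiring above follows the standard clap derive pattern: the field name yields `--experimental-embedding-cache-entries`, and the `env` attribute binds `MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES`. A hedged, standalone sketch (assuming clap 4 with the `derive` and `env` features; a string literal replaces the const so the example is self-contained):

```rust
use clap::Parser;

fn default_embedding_cache_entries() -> usize {
    0 // cache disabled by default, matching the diff
}

#[derive(Debug, Parser)]
struct Opt {
    /// Maximal number of cached search query embeddings per distinct embedder.
    /// 0 disables the cache.
    #[clap(
        long,
        env = "MEILI_EXPERIMENTAL_EMBEDDING_CACHE_ENTRIES",
        default_value_t = default_embedding_cache_entries()
    )]
    experimental_embedding_cache_entries: usize,
}

fn main() {
    // Accepts `--experimental-embedding-cache-entries 100` or the env var above.
    let opt = Opt::parse();
    println!("embedding cache entries: {}", opt.experimental_embedding_cache_entries);
}
```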
@@ -916,7 +916,7 @@ fn prepare_search<'t>(
                     let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
 
                     embedder
-                        .embed_search(query.q.clone().unwrap(), Some(deadline))
+                        .embed_search(query.q.as_ref().unwrap(), Some(deadline))
                         .map_err(milli::vector::Error::from)
                         .map_err(milli::Error::from)?
                 }

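The `query.q.as_ref().unwrap()` change above is the caller-side half of the new `embed_search(&str, ...)` signature: `Option::as_ref` converts `&Option<String>` into `Option<&String>`, which deref-coerces to `&str`, so the query is borrowed instead of cloned on every search. A minimal illustration; the `embed_search` stub here is hypothetical:

```rust
// Hypothetical stub standing in for `Embedder::embed_search(&str, ...)`.
fn embed_search(text: &str) -> usize {
    text.len()
}

fn main() {
    let q: Option<String> = Some("grand chien de berger des montagnes".to_owned());
    // Before: `embed_search(q.clone().unwrap())` allocated a fresh String per call.
    // After: `as_ref` borrows the existing one; `&String` deref-coerces to `&str`.
    let n = embed_search(q.as_ref().unwrap());
    assert_eq!(n, 35);
}
```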
@@ -1995,7 +1995,7 @@ async fn timeout() {
 
     let (response, code) = index
         .search_post(json!({
-            "q": "grand chien de berger des montagnes",
+            "q": "grand chien de berger des montagnes foil the cache",
             "hybrid": {"semanticRatio": 0.99, "embedder": "default"}
         }))
         .await;

@@ -110,6 +110,7 @@ utoipa = { version = "5.3.1", features = [
     "time",
     "openapi_extensions",
 ] }
+lru = "0.13.0"
 
 [dev-dependencies]
 mimalloc = { version = "0.1.43", default-features = false }

@@ -203,7 +203,7 @@ impl<'a> Search<'a> {
 
                 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
 
-                match embedder.embed_search(query, Some(deadline)) {
+                match embedder.embed_search(&query, Some(deadline)) {
                     Ok(embedding) => embedding,
                     Err(error) => {
                         tracing::error!(error=%error, "Embedding failed");

@@ -2806,8 +2806,9 @@ mod tests {
             embedding_configs.pop().unwrap();
         insta::assert_snapshot!(embedder_name, @"manual");
         insta::assert_debug_snapshot!(user_provided, @"RoaringBitmap<[0, 1, 2]>");
-        let embedder =
-            std::sync::Arc::new(crate::vector::Embedder::new(embedder.embedder_options).unwrap());
+        let embedder = std::sync::Arc::new(
+            crate::vector::Embedder::new(embedder.embedder_options, 0).unwrap(),
+        );
         let res = index
             .search(&rtxn)
             .semantic(embedder_name, embedder, false, Some([0.0, 1.0, 2.0].to_vec()))

@@ -1628,7 +1628,8 @@ fn embedders(embedding_configs: Vec<IndexEmbeddingConfig>) -> Result<EmbeddingCo
                 let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?);
 
                 let embedder = Arc::new(
-                    Embedder::new(embedder_options.clone())
+                    // cache_cap: no cache needed for indexing purposes
+                    Embedder::new(embedder_options.clone(), 0)
                         .map_err(crate::vector::Error::from)
                         .map_err(crate::Error::from)?,
                 );

@@ -4,7 +4,8 @@ use arroy::Distance;
 
 use super::error::CompositeEmbedderContainsHuggingFace;
 use super::{
-    hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, NewEmbedderError,
+    hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
+    NewEmbedderError,
 };
 use crate::ThreadPoolNoAbort;
 

@@ -58,9 +59,11 @@ pub struct EmbedderOptions {
 impl Embedder {
     pub fn new(
         EmbedderOptions { search, index }: EmbedderOptions,
+        cache_cap: usize,
     ) -> Result<Self, NewEmbedderError> {
-        let search = SubEmbedder::new(search)?;
-        let index = SubEmbedder::new(index)?;
+        let search = SubEmbedder::new(search, cache_cap)?;
+        // cache is only used at search
+        let index = SubEmbedder::new(index, 0)?;
 
         // check dimensions
         if search.dimensions() != index.dimensions() {

@@ -118,19 +121,28 @@ impl Embedder {
 }
 
 impl SubEmbedder {
-    pub fn new(options: SubEmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
+    pub fn new(
+        options: SubEmbedderOptions,
+        cache_cap: usize,
+    ) -> std::result::Result<Self, NewEmbedderError> {
         Ok(match options {
             SubEmbedderOptions::HuggingFace(options) => {
-                Self::HuggingFace(hf::Embedder::new(options)?)
+                Self::HuggingFace(hf::Embedder::new(options, cache_cap)?)
             }
-            SubEmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
-            SubEmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
+            SubEmbedderOptions::OpenAi(options) => {
+                Self::OpenAi(openai::Embedder::new(options, cache_cap)?)
+            }
+            SubEmbedderOptions::Ollama(options) => {
+                Self::Ollama(ollama::Embedder::new(options, cache_cap)?)
+            }
             SubEmbedderOptions::UserProvided(options) => {
                 Self::UserProvided(manual::Embedder::new(options))
             }
-            SubEmbedderOptions::Rest(options) => {
-                Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?)
-            }
+            SubEmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(
+                options,
+                cache_cap,
+                rest::ConfigurationSource::User,
+            )?),
         })
     }

@@ -148,6 +160,27 @@ impl SubEmbedder {
         }
     }
 
+    pub fn embed_one(
+        &self,
+        text: &str,
+        deadline: Option<Instant>,
+    ) -> std::result::Result<Embedding, EmbedError> {
+        match self {
+            SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
+            SubEmbedder::OpenAi(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            SubEmbedder::Ollama(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
+            SubEmbedder::Rest(embedder) => embedder
+                .embed_ref(&[text], deadline)?
+                .pop()
+                .ok_or_else(EmbedError::missing_embedding),
+        }
+    }
+
     /// Embed multiple chunks of texts.
     ///
     /// Each chunk is composed of one or multiple texts.

@@ -233,6 +266,16 @@ impl SubEmbedder {
             SubEmbedder::Rest(embedder) => embedder.distribution(),
         }
     }
+
+    pub(super) fn cache(&self) -> Option<&EmbeddingCache> {
+        match self {
+            SubEmbedder::HuggingFace(embedder) => Some(embedder.cache()),
+            SubEmbedder::OpenAi(embedder) => Some(embedder.cache()),
+            SubEmbedder::UserProvided(_) => None,
+            SubEmbedder::Ollama(embedder) => Some(embedder.cache()),
+            SubEmbedder::Rest(embedder) => Some(embedder.cache()),
+        }
+    }
 }
 
 fn check_similarity(

@@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};
 
 pub use super::error::{EmbedError, Error, NewEmbedderError};
-use super::{DistributionShift, Embedding};
+use super::{DistributionShift, Embedding, EmbeddingCache};
 
 #[derive(
     Debug,

@@ -84,6 +84,7 @@ pub struct Embedder {
     options: EmbedderOptions,
     dimensions: usize,
     pooling: Pooling,
+    cache: EmbeddingCache,
 }
 
 impl std::fmt::Debug for Embedder {

@@ -149,7 +150,10 @@ impl From<PoolingConfig> for Pooling {
 }
 
 impl Embedder {
-    pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
+    pub fn new(
+        options: EmbedderOptions,
+        cache_cap: usize,
+    ) -> std::result::Result<Self, NewEmbedderError> {
         let device = match candle_core::Device::cuda_if_available(0) {
             Ok(device) => device,
             Err(error) => {

@@ -245,7 +249,14 @@ impl Embedder {
             tokenizer.with_padding(Some(pp));
         }
 
-        let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
+        let mut this = Self {
+            model,
+            tokenizer,
+            options,
+            dimensions: 0,
+            pooling,
+            cache: EmbeddingCache::new(cache_cap),
+        };
 
         let embeddings = this
             .embed(vec!["test".into()])

@@ -355,4 +366,8 @@ impl Embedder {
     pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
         texts.iter().map(|text| self.embed_one(text)).collect()
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        &self.cache
+    }
 }

@@ -1,5 +1,6 @@
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::num::NonZeroUsize;
+use std::sync::{Arc, Mutex};
 use std::time::Instant;
 
 use arroy::distances::{BinaryQuantizedCosine, Cosine};

@@ -551,6 +552,46 @@ pub enum Embedder {
     Composite(composite::Embedder),
 }
 
+#[derive(Debug)]
+struct EmbeddingCache {
+    data: Option<Mutex<lru::LruCache<String, Embedding>>>,
+}
+
+impl EmbeddingCache {
+    const MAX_TEXT_LEN: usize = 2000;
+
+    pub fn new(cap: usize) -> Self {
+        let data = NonZeroUsize::new(cap).map(lru::LruCache::new).map(Mutex::new);
+        Self { data }
+    }
+
+    /// Get the embedding corresponding to `text`, if any is present in the cache.
+    pub fn get(&self, text: &str) -> Option<Embedding> {
+        let data = self.data.as_ref()?;
+        if text.len() > Self::MAX_TEXT_LEN {
+            return None;
+        }
+        let mut cache = data.lock().unwrap();
+
+        cache.get(text).cloned()
+    }
+
+    /// Puts a new embedding for the specified `text`
+    pub fn put(&self, text: String, embedding: Embedding) {
+        let Some(data) = self.data.as_ref() else {
+            return;
+        };
+        if text.len() > Self::MAX_TEXT_LEN {
+            return;
+        }
+        tracing::trace!(text, "embedding added to cache");
+
+        let mut cache = data.lock().unwrap();
+
+        cache.put(text, embedding);
+    }
+}
+
 /// Configuration for an embedder.
 #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
 pub struct EmbeddingConfig {

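Two details of `EmbeddingCache` worth noting: a capacity of 0 makes `data` `None`, so `get` and `put` become no-ops, and texts longer than `MAX_TEXT_LEN` (2000 bytes) bypass the cache in both directions. The eviction policy itself comes from the `lru` crate; a standalone illustration of its least-recently-used semantics:

```rust
use std::num::NonZeroUsize;

fn main() {
    // Capacity of 2: inserting a third entry evicts the least recently used one.
    let mut cache = lru::LruCache::new(NonZeroUsize::new(2).unwrap());
    cache.put("a", vec![0.0_f32]);
    cache.put("b", vec![1.0_f32]);
    cache.get("a"); // touching "a" makes "b" the least recently used entry
    cache.put("c", vec![2.0_f32]); // evicts "b"
    assert!(cache.get("b").is_none());
    assert!(cache.get("a").is_some() && cache.get("c").is_some());
}
```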
@@ -629,19 +670,30 @@ impl Default for EmbedderOptions {
 
 impl Embedder {
     /// Spawns a new embedder built from its options.
-    pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
+    pub fn new(
+        options: EmbedderOptions,
+        cache_cap: usize,
+    ) -> std::result::Result<Self, NewEmbedderError> {
         Ok(match options {
-            EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
-            EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
-            EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
+            EmbedderOptions::HuggingFace(options) => {
+                Self::HuggingFace(hf::Embedder::new(options, cache_cap)?)
+            }
+            EmbedderOptions::OpenAi(options) => {
+                Self::OpenAi(openai::Embedder::new(options, cache_cap)?)
+            }
+            EmbedderOptions::Ollama(options) => {
+                Self::Ollama(ollama::Embedder::new(options, cache_cap)?)
+            }
             EmbedderOptions::UserProvided(options) => {
                 Self::UserProvided(manual::Embedder::new(options))
             }
-            EmbedderOptions::Rest(options) => {
-                Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?)
-            }
+            EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(
+                options,
+                cache_cap,
+                rest::ConfigurationSource::User,
+            )?),
             EmbedderOptions::Composite(options) => {
-                Self::Composite(composite::Embedder::new(options)?)
+                Self::Composite(composite::Embedder::new(options, cache_cap)?)
             }
         })
     }

@@ -651,19 +703,35 @@ impl Embedder {
     #[tracing::instrument(level = "debug", skip_all, target = "search")]
     pub fn embed_search(
         &self,
-        text: String,
+        text: &str,
         deadline: Option<Instant>,
     ) -> std::result::Result<Embedding, EmbedError> {
-        let texts = vec![text];
-        let mut embedding = match self {
-            Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => embedder.embed(&texts, deadline),
-            Embedder::Ollama(embedder) => embedder.embed(&texts, deadline),
-            Embedder::UserProvided(embedder) => embedder.embed(&texts),
-            Embedder::Rest(embedder) => embedder.embed(texts, deadline),
-            Embedder::Composite(embedder) => embedder.search.embed(texts, deadline),
+        if let Some(cache) = self.cache() {
+            if let Some(embedding) = cache.get(text) {
+                tracing::trace!(text, "embedding found in cache");
+                return Ok(embedding);
+            }
+        }
+
+        let embedding = match self {
+            Embedder::HuggingFace(embedder) => embedder.embed_one(text),
+            Embedder::OpenAi(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            Embedder::Ollama(embedder) => {
+                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
+            }
+            Embedder::UserProvided(embedder) => embedder.embed_one(text),
+            Embedder::Rest(embedder) => embedder
+                .embed_ref(&[text], deadline)?
+                .pop()
+                .ok_or_else(EmbedError::missing_embedding),
+            Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline),
         }?;
-        let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?;
+
+        if let Some(cache) = self.cache() {
+            cache.put(text.to_owned(), embedding.clone());
+        }
 
         Ok(embedding)
     }

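One consequence of the lookup-then-fill shape above: the mutex is held only for the `get` and `put` calls, never across the embedding request itself, so two concurrent searches that miss on the same cold query will both compute it, and the second `put` simply overwrites the first. That looks like a deliberate trade-off (no lock held during a multi-hundred-millisecond network call) rather than a bug; a sketch of the race, with illustrative names:

```rust
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let cache = Arc::new(Mutex::new(lru::LruCache::<String, Vec<f32>>::new(
        NonZeroUsize::new(10).unwrap(),
    )));
    let handles: Vec<_> = (0..2)
        .map(|_| {
            let cache = Arc::clone(&cache);
            thread::spawn(move || {
                // Lock held only for the lookup, not while embedding, so both
                // threads can observe a miss here and compute independently.
                let hit = cache.lock().unwrap().get("query").cloned();
                let embedding = hit.unwrap_or_else(|| vec![0.5, 1.5] /* embed */);
                cache.lock().unwrap().put("query".to_owned(), embedding);
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    assert!(cache.lock().unwrap().get("query").is_some());
}
```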
@@ -759,6 +827,17 @@ impl Embedder {
             Embedder::Composite(embedder) => embedder.index.uses_document_template(),
         }
     }
+
+    fn cache(&self) -> Option<&EmbeddingCache> {
+        match self {
+            Embedder::HuggingFace(embedder) => Some(embedder.cache()),
+            Embedder::OpenAi(embedder) => Some(embedder.cache()),
+            Embedder::UserProvided(_) => None,
+            Embedder::Ollama(embedder) => Some(embedder.cache()),
+            Embedder::Rest(embedder) => Some(embedder.cache()),
+            Embedder::Composite(embedder) => embedder.search.cache(),
+        }
+    }
 }
 
 /// Describes the mean and sigma of distribution of embedding similarity in the embedding space.

@@ -5,7 +5,7 @@ use rayon::slice::ParallelSlice as _;
 
 use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::{DistributionShift, REQUEST_PARALLELISM};
+use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::Embedding;
 use crate::ThreadPoolNoAbort;

@@ -75,9 +75,10 @@ impl EmbedderOptions {
 }
 
 impl Embedder {
-    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
+    pub fn new(options: EmbedderOptions, cache_cap: usize) -> Result<Self, NewEmbedderError> {
         let rest_embedder = match RestEmbedder::new(
             options.into_rest_embedder_config()?,
+            cache_cap,
             super::rest::ConfigurationSource::Ollama,
         ) {
             Ok(embedder) => embedder,

@@ -182,6 +183,10 @@ impl Embedder {
     pub fn distribution(&self) -> Option<DistributionShift> {
         self.rest_embedder.distribution()
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        self.rest_embedder.cache()
+    }
 }
 
 fn get_ollama_path() -> String {

@@ -7,7 +7,7 @@ use rayon::slice::ParallelSlice as _;
 
 use super::error::{EmbedError, NewEmbedderError};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::{DistributionShift, REQUEST_PARALLELISM};
+use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::error::EmbedErrorKind;
 use crate::vector::Embedding;

@@ -176,7 +176,7 @@ pub struct Embedder {
 }
 
 impl Embedder {
-    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
+    pub fn new(options: EmbedderOptions, cache_cap: usize) -> Result<Self, NewEmbedderError> {
         let mut inferred_api_key = Default::default();
         let api_key = options.api_key.as_ref().unwrap_or_else(|| {
             inferred_api_key = infer_api_key();

@@ -201,6 +201,7 @@ impl Embedder {
                 }),
                 headers: Default::default(),
             },
+            cache_cap,
             super::rest::ConfigurationSource::OpenAi,
         )?;
 

@@ -318,6 +319,10 @@ impl Embedder {
     pub fn distribution(&self) -> Option<DistributionShift> {
         self.options.distribution()
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        self.rest_embedder.cache()
+    }
 }
 
 impl fmt::Debug for Embedder {

@@ -9,7 +9,9 @@ use serde::{Deserialize, Serialize};
 
 use super::error::EmbedErrorKind;
 use super::json_template::ValueTemplate;
-use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM};
+use super::{
+    DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
+};
 use crate::error::FaultSource;
 use crate::ThreadPoolNoAbort;
 

@@ -75,6 +77,7 @@ pub struct Embedder {
     data: EmbedderData,
     dimensions: usize,
     distribution: Option<DistributionShift>,
+    cache: EmbeddingCache,
 }
 
 /// All data needed to perform requests and parse responses

@@ -123,6 +126,7 @@ enum InputType {
 impl Embedder {
     pub fn new(
         options: EmbedderOptions,
+        cache_cap: usize,
         configuration_source: ConfigurationSource,
     ) -> Result<Self, NewEmbedderError> {
         let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));

@@ -152,7 +156,12 @@ impl Embedder {
             infer_dimensions(&data)?
         };
 
-        Ok(Self { data, dimensions, distribution: options.distribution })
+        Ok(Self {
+            data,
+            dimensions,
+            distribution: options.distribution,
+            cache: EmbeddingCache::new(cache_cap),
+        })
     }
 
     pub fn embed(

@@ -256,6 +265,10 @@ impl Embedder {
     pub fn distribution(&self) -> Option<DistributionShift> {
         self.distribution
     }
+
+    pub(super) fn cache(&self) -> &EmbeddingCache {
+        &self.cache
+    }
 }
 
 fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {