Add embedding cache

This commit is contained in:
Louis Dureuil
2025-03-13 11:13:14 +01:00
parent d9111fe8ce
commit b08544e86d
8 changed files with 159 additions and 19 deletions

View File

@@ -9,7 +9,10 @@ use serde::{Deserialize, Serialize};
use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate;
use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM};
use super::{
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, CAP_PER_THREAD,
REQUEST_PARALLELISM,
};
use crate::error::FaultSource;
use crate::ThreadPoolNoAbort;
@@ -75,6 +78,7 @@ pub struct Embedder {
data: EmbedderData,
dimensions: usize,
distribution: Option<DistributionShift>,
cache: EmbeddingCache,
}
/// All data needed to perform requests and parse responses
@@ -152,7 +156,12 @@ impl Embedder {
infer_dimensions(&data)?
};
Ok(Self { data, dimensions, distribution: options.distribution })
Ok(Self {
data,
dimensions,
distribution: options.distribution,
cache: EmbeddingCache::new(CAP_PER_THREAD),
})
}
pub fn embed(
@@ -256,6 +265,10 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution
}
pub(super) fn cache(&self) -> &EmbeddingCache {
&self.cache
}
}
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {