mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-26 05:26:27 +00:00 
			
		
		
		
	Update ollama and openai impls to use the rest embedder internally
This commit is contained in:
		| @@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>( | ||||
|     prompt_reader: grenad::Reader<R>, | ||||
|     indexer: GrenadParameters, | ||||
|     embedder: Arc<Embedder>, | ||||
|     request_threads: &rayon::ThreadPool, | ||||
| ) -> Result<grenad::Reader<BufReader<File>>> { | ||||
|     puffin::profile_function!(); | ||||
|     let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism | ||||
| @@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>( | ||||
|  | ||||
|         if chunks.len() == chunks.capacity() { | ||||
|             let chunked_embeds = embedder | ||||
|                 .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))) | ||||
|                 .embed_chunks( | ||||
|                     std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)), | ||||
|                     request_threads, | ||||
|                 ) | ||||
|                 .map_err(crate::vector::Error::from) | ||||
|                 .map_err(crate::Error::from)?; | ||||
|  | ||||
| @@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>( | ||||
|     // send last chunk | ||||
|     if !chunks.is_empty() { | ||||
|         let chunked_embeds = embedder | ||||
|             .embed_chunks(std::mem::take(&mut chunks)) | ||||
|             .embed_chunks(std::mem::take(&mut chunks), request_threads) | ||||
|             .map_err(crate::vector::Error::from) | ||||
|             .map_err(crate::Error::from)?; | ||||
|         for (docid, embeddings) in chunks_ids | ||||
| @@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>( | ||||
|  | ||||
|     if !current_chunk.is_empty() { | ||||
|         let embeds = embedder | ||||
|             .embed_chunks(vec![std::mem::take(&mut current_chunk)]) | ||||
|             .embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads) | ||||
|             .map_err(crate::vector::Error::from) | ||||
|             .map_err(crate::Error::from)?; | ||||
|  | ||||
|   | ||||
| @@ -238,7 +238,15 @@ fn send_original_documents_data( | ||||
|  | ||||
|     let documents_chunk_cloned = original_documents_chunk.clone(); | ||||
|     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); | ||||
|  | ||||
|     let request_threads = rayon::ThreadPoolBuilder::new() | ||||
|         .num_threads(crate::vector::REQUEST_PARALLELISM) | ||||
|         .thread_name(|index| format!("embedding-request-{index}")) | ||||
|         .build() | ||||
|         .unwrap(); | ||||
|  | ||||
|     rayon::spawn(move || { | ||||
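|         // FIXME: remove the `unwrap()` (presumably the one on the `request_threads` pool build above) and report the error instead | ||||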
|         for (name, (embedder, prompt)) in embedders { | ||||
|             let result = extract_vector_points( | ||||
|                 documents_chunk_cloned.clone(), | ||||
| @@ -249,7 +257,12 @@ fn send_original_documents_data( | ||||
|             ); | ||||
|             match result { | ||||
|                 Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { | ||||
|                     let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) { | ||||
|                     let embeddings = match extract_embeddings( | ||||
|                         prompts, | ||||
|                         indexer, | ||||
|                         embedder.clone(), | ||||
|                         &request_threads, | ||||
|                     ) { | ||||
|                         Ok(results) => Some(results), | ||||
|                         Err(error) => { | ||||
|                             let _ = lmdb_writer_sx_cloned.send(Err(error)); | ||||
|   | ||||
| @@ -2,9 +2,7 @@ use std::path::PathBuf; | ||||
|  | ||||
| use hf_hub::api::sync::ApiError; | ||||
|  | ||||
| use super::ollama::OllamaError; | ||||
| use crate::error::FaultSource; | ||||
| use crate::vector::openai::OpenAiError; | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("Error while generating embeddings: {inner}")] | ||||
| @@ -52,43 +50,12 @@ pub enum EmbedErrorKind { | ||||
|     TensorValue(candle_core::Error), | ||||
|     #[error("could not run model: {0}")] | ||||
|     ModelForward(candle_core::Error), | ||||
|     #[error("could not reach OpenAI: {0}")] | ||||
|     OpenAiNetwork(ureq::Transport), | ||||
|     #[error("unexpected response from OpenAI: {0}")] | ||||
|     OpenAiUnexpected(ureq::Error), | ||||
|     #[error("could not authenticate against OpenAI: {0:?}")] | ||||
|     OpenAiAuth(Option<OpenAiError>), | ||||
|     #[error("sent too many requests to OpenAI: {0:?}")] | ||||
|     OpenAiTooManyRequests(Option<OpenAiError>), | ||||
|     #[error("received internal error from OpenAI: {0:?}")] | ||||
|     OpenAiInternalServerError(Option<OpenAiError>), | ||||
|     #[error("sent too many tokens in a request to OpenAI: {0:?}")] | ||||
|     OpenAiTooManyTokens(Option<OpenAiError>), | ||||
|     #[error("received unhandled HTTP status code {0} from OpenAI")] | ||||
|     OpenAiUnhandledStatusCode(u16), | ||||
|     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] | ||||
|     ManualEmbed(String), | ||||
|     #[error("could not initialize asynchronous runtime: {0}")] | ||||
|     OpenAiRuntimeInit(std::io::Error), | ||||
|     #[error("initializing web client for sending embedding requests failed: {0}")] | ||||
|     InitWebClient(reqwest::Error), | ||||
|     // Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends. | ||||
|     #[error("unexpected response from Ollama: {0}")] | ||||
|     OllamaUnexpected(reqwest::Error), | ||||
|     #[error("sent too many requests to Ollama: {0}")] | ||||
|     OllamaTooManyRequests(OllamaError), | ||||
|     #[error("received internal error from Ollama: {0}")] | ||||
|     OllamaInternalServerError(OllamaError), | ||||
|     #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] | ||||
|     OllamaModelNotFoundError(OllamaError), | ||||
|     #[error("received unhandled HTTP status code {0} from Ollama")] | ||||
|     OllamaUnhandledStatusCode(u16), | ||||
|     #[error("error serializing template context: {0}")] | ||||
|     RestTemplateContextSerialization(liquid::Error), | ||||
|     #[error( | ||||
|         "error rendering request template: {0}. Hint: available variable in the context: {{{{input}}}}'" | ||||
|     )] | ||||
|     RestTemplateError(liquid::Error), | ||||
|     #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")] | ||||
|     OllamaModelNotFoundError(Option<String>), | ||||
|     #[error("error deserialization the response body as JSON: {0}")] | ||||
|     RestResponseDeserialization(std::io::Error), | ||||
|     #[error("component `{0}` not found in path `{1}` in response: `{2}`")] | ||||
| @@ -128,77 +95,14 @@ impl EmbedError { | ||||
|         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_network(inner: ureq::Transport) -> Self { | ||||
|         Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_unexpected(inner: ureq::Error) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_auth_error(inner: Option<OpenAiError>) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_too_many_requests(inner: Option<OpenAiError>) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_internal_server_error(inner: Option<OpenAiError>) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_too_many_tokens(inner: Option<OpenAiError>) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { | ||||
|         Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError { | ||||
|     pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestTemplateContextSerialization(error), | ||||
|             fault: FaultSource::Bug, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestResponseDeserialization(error), | ||||
| @@ -335,17 +239,6 @@ impl NewEmbedderError { | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { | ||||
|         Self { | ||||
|             kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), | ||||
|             fault: FaultSource::User, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| @@ -392,7 +285,4 @@ pub enum NewEmbedderErrorKind { | ||||
|     CouldNotDetermineDimension(EmbedError), | ||||
|     #[error("loading model failed: {0}")] | ||||
|     LoadModel(candle_core::Error), | ||||
|     // openai | ||||
|     #[error("The API key passed to Authorization error was in an invalid format: {0}")] | ||||
|     InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), | ||||
| } | ||||
|   | ||||
| @@ -17,6 +17,8 @@ pub use self::error::Error; | ||||
|  | ||||
| pub type Embedding = Vec<f32>; | ||||
|  | ||||
| pub const REQUEST_PARALLELISM: usize = 40; | ||||
|  | ||||
| /// One or multiple embeddings stored consecutively in a flat vector. | ||||
| pub struct Embeddings<F> { | ||||
|     data: Vec<F>, | ||||
| @@ -99,7 +101,7 @@ pub enum Embedder { | ||||
|     /// An embedder based on running local models, fetched from the Hugging Face Hub. | ||||
|     HuggingFace(hf::Embedder), | ||||
|     /// An embedder based on making embedding queries against the OpenAI API. | ||||
|     OpenAi(openai::sync::Embedder), | ||||
|     OpenAi(openai::Embedder), | ||||
|     /// An embedder based on the user providing the embeddings in the documents and queries. | ||||
|     UserProvided(manual::Embedder), | ||||
|     Ollama(ollama::Embedder), | ||||
| @@ -202,7 +204,7 @@ impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { | ||||
|         Ok(match options { | ||||
|             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), | ||||
|             EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?), | ||||
|             EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), | ||||
|             EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?), | ||||
|             EmbedderOptions::UserProvided(options) => { | ||||
|                 Self::UserProvided(manual::Embedder::new(options)) | ||||
| @@ -213,17 +215,14 @@ impl Embedder { | ||||
|     /// Embed one or multiple texts. | ||||
|     /// | ||||
|     /// Each text can be embedded as one or multiple embeddings. | ||||
|     pub async fn embed( | ||||
|     pub fn embed( | ||||
|         &self, | ||||
|         texts: Vec<String>, | ||||
|     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         match self { | ||||
|             Embedder::HuggingFace(embedder) => embedder.embed(texts), | ||||
|             Embedder::OpenAi(embedder) => embedder.embed(texts), | ||||
|             Embedder::Ollama(embedder) => { | ||||
|                 let client = embedder.new_client()?; | ||||
|                 embedder.embed(texts, &client).await | ||||
|             } | ||||
|             Embedder::Ollama(embedder) => embedder.embed(texts), | ||||
|             Embedder::UserProvided(embedder) => embedder.embed(texts), | ||||
|         } | ||||
|     } | ||||
| @@ -231,18 +230,15 @@ impl Embedder { | ||||
|     /// Embed multiple chunks of texts. | ||||
|     /// | ||||
|     /// Each chunk is composed of one or multiple texts. | ||||
|     /// | ||||
|     /// # Panics | ||||
|     /// | ||||
|     /// - if called from an asynchronous context | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         match self { | ||||
|             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), | ||||
|             Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks), | ||||
|             Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks), | ||||
|             Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads), | ||||
|             Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads), | ||||
|             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -1,293 +1,94 @@ | ||||
| // Copied from "openai.rs" with the sections I actually understand changed for Ollama. | ||||
| // The common components of the Ollama and OpenAI interfaces might need to be extracted. | ||||
| use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; | ||||
|  | ||||
| use std::fmt::Display; | ||||
|  | ||||
| use reqwest::StatusCode; | ||||
|  | ||||
| use super::error::{EmbedError, NewEmbedderError}; | ||||
| use super::openai::Retry; | ||||
| use super::{DistributionShift, Embedding, Embeddings}; | ||||
| use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; | ||||
| use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; | ||||
| use super::{DistributionShift, Embeddings}; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Embedder { | ||||
|     headers: reqwest::header::HeaderMap, | ||||
|     options: EmbedderOptions, | ||||
|     rest_embedder: RestEmbedder, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] | ||||
| pub struct EmbedderOptions { | ||||
|     pub embedding_model: EmbeddingModel, | ||||
| } | ||||
|  | ||||
| #[derive( | ||||
|     Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr, | ||||
| )] | ||||
| #[deserr(deny_unknown_fields)] | ||||
| pub struct EmbeddingModel { | ||||
|     name: String, | ||||
|     dimensions: usize, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, serde::Serialize)] | ||||
| struct OllamaRequest<'a> { | ||||
|     model: &'a str, | ||||
|     prompt: &'a str, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, serde::Deserialize)] | ||||
| struct OllamaResponse { | ||||
|     embedding: Embedding, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, serde::Deserialize)] | ||||
| pub struct OllamaError { | ||||
|     error: String, | ||||
| } | ||||
|  | ||||
| impl EmbeddingModel { | ||||
|     pub fn max_token(&self) -> usize { | ||||
|         // this might not be the same for all models | ||||
|         8192 | ||||
|     } | ||||
|  | ||||
|     pub fn default_dimensions(&self) -> usize { | ||||
|         // Dimensions for nomic-embed-text | ||||
|         768 | ||||
|     } | ||||
|  | ||||
|     pub fn name(&self) -> String { | ||||
|         self.name.clone() | ||||
|     } | ||||
|  | ||||
|     pub fn from_name(name: &str) -> Self { | ||||
|         Self { name: name.to_string(), dimensions: 0 } | ||||
|     } | ||||
|  | ||||
|     pub fn supports_overriding_dimensions(&self) -> bool { | ||||
|         false | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Default for EmbeddingModel { | ||||
|     fn default() -> Self { | ||||
|         Self { name: "nomic-embed-text".to_string(), dimensions: 0 } | ||||
|     } | ||||
|     pub embedding_model: String, | ||||
| } | ||||
|  | ||||
| impl EmbedderOptions { | ||||
|     pub fn with_default_model() -> Self { | ||||
|         Self { embedding_model: Default::default() } | ||||
|         Self { embedding_model: "nomic-embed-text".into() } | ||||
|     } | ||||
|  | ||||
|     pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self { | ||||
|     pub fn with_embedding_model(embedding_model: String) -> Self { | ||||
|         Self { embedding_model } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> { | ||||
|         reqwest::ClientBuilder::new() | ||||
|             .default_headers(self.headers.clone()) | ||||
|             .build() | ||||
|             .map_err(EmbedError::openai_initialize_web_client) | ||||
|     } | ||||
|  | ||||
|     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||
|         let mut headers = reqwest::header::HeaderMap::new(); | ||||
|         headers.insert( | ||||
|             reqwest::header::CONTENT_TYPE, | ||||
|             reqwest::header::HeaderValue::from_static("application/json"), | ||||
|         ); | ||||
|         let model = options.embedding_model.as_str(); | ||||
|         let rest_embedder = match RestEmbedder::new(RestEmbedderOptions { | ||||
|             api_key: None, | ||||
|             distribution: None, | ||||
|             dimensions: None, | ||||
|             url: get_ollama_path(), | ||||
|             query: serde_json::json!({ | ||||
|                 "model": model, | ||||
|             }), | ||||
|             input_field: vec!["prompt".to_owned()], | ||||
|             path_to_embeddings: Default::default(), | ||||
|             embedding_object: vec!["embedding".to_owned()], | ||||
|             input_type: super::rest::InputType::Text, | ||||
|         }) { | ||||
|             Ok(embedder) => embedder, | ||||
|             Err(NewEmbedderError { | ||||
|                 kind: | ||||
|                     NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError { | ||||
|                         kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error), | ||||
|                         fault: _, | ||||
|                     }), | ||||
|                 fault: _, | ||||
|             }) => { | ||||
|                 return Err(NewEmbedderError::could_not_determine_dimension( | ||||
|                     EmbedError::ollama_model_not_found(error), | ||||
|                 )) | ||||
|             } | ||||
|             Err(error) => return Err(error), | ||||
|         }; | ||||
|  | ||||
|         let mut embedder = Self { options, headers }; | ||||
|  | ||||
|         let rt = tokio::runtime::Builder::new_current_thread() | ||||
|             .enable_io() | ||||
|             .enable_time() | ||||
|             .build() | ||||
|             .map_err(EmbedError::openai_runtime_init) | ||||
|             .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; | ||||
|  | ||||
|         // Get dimensions from Ollama | ||||
|         let request = | ||||
|             OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" }; | ||||
|         // TODO: Refactor into shared error type | ||||
|         let client = embedder | ||||
|             .new_client() | ||||
|             .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; | ||||
|  | ||||
|         rt.block_on(async move { | ||||
|             let response = client | ||||
|                 .post(get_ollama_path()) | ||||
|                 .json(&request) | ||||
|                 .send() | ||||
|                 .await | ||||
|                 .map_err(EmbedError::ollama_unexpected) | ||||
|                 .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; | ||||
|  | ||||
|             // Process error in case model not found | ||||
|             let response = Self::check_response(response).await.map_err(|_err| { | ||||
|                 let e = EmbedError::ollama_model_not_found(OllamaError { | ||||
|                     error: format!("model: {}", embedder.options.embedding_model.name()), | ||||
|                 }); | ||||
|                 NewEmbedderError::ollama_could_not_determine_dimension(e) | ||||
|             })?; | ||||
|  | ||||
|             let response: OllamaResponse = response | ||||
|                 .json() | ||||
|                 .await | ||||
|                 .map_err(EmbedError::ollama_unexpected) | ||||
|                 .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; | ||||
|  | ||||
|             let embedding = Embeddings::from_single_embedding(response.embedding); | ||||
|  | ||||
|             embedder.options.embedding_model.dimensions = embedding.dimension(); | ||||
|  | ||||
|             tracing::info!( | ||||
|                 "ollama model {} with dimensionality {} added", | ||||
|                 embedder.options.embedding_model.name(), | ||||
|                 embedding.dimension() | ||||
|             ); | ||||
|  | ||||
|             Ok(embedder) | ||||
|         }) | ||||
|         Ok(Self { rest_embedder }) | ||||
|     } | ||||
|  | ||||
|     async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> { | ||||
|         if !response.status().is_success() { | ||||
|             // Not the same number of possible error cases covered as with OpenAI. | ||||
|             match response.status() { | ||||
|                 StatusCode::TOO_MANY_REQUESTS => { | ||||
|                     let error_response: OllamaError = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::ollama_unexpected) | ||||
|                         .map_err(Retry::retry_later)?; | ||||
|  | ||||
|                     return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests( | ||||
|                         OllamaError { error: error_response.error }, | ||||
|                     ))); | ||||
|     pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         match self.rest_embedder.embed(texts) { | ||||
|             Ok(embeddings) => Ok(embeddings), | ||||
|             Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { | ||||
|                 Err(EmbedError::ollama_model_not_found(error)) | ||||
|             } | ||||
|                 StatusCode::SERVICE_UNAVAILABLE => { | ||||
|                     let error_response: OllamaError = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::ollama_unexpected) | ||||
|                         .map_err(Retry::retry_later)?; | ||||
|                     return Err(Retry::retry_later(EmbedError::ollama_internal_server_error( | ||||
|                         OllamaError { error: error_response.error }, | ||||
|                     ))); | ||||
|             Err(error) => Err(error), | ||||
|         } | ||||
|                 StatusCode::NOT_FOUND => { | ||||
|                     let error_response: OllamaError = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::ollama_unexpected) | ||||
|                         .map_err(Retry::give_up)?; | ||||
|  | ||||
|                     return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError { | ||||
|                         error: error_response.error, | ||||
|                     }))); | ||||
|                 } | ||||
|                 code => { | ||||
|                     return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code( | ||||
|                         code.as_u16(), | ||||
|                     ))); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         Ok(response) | ||||
|     } | ||||
|  | ||||
|     pub async fn embed( | ||||
|         &self, | ||||
|         texts: Vec<String>, | ||||
|         client: &reqwest::Client, | ||||
|     ) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         // Ollama only embeds one document at a time. | ||||
|         let mut results = Vec::with_capacity(texts.len()); | ||||
|  | ||||
|         // The retry loop is inside the texts loop, might have to switch that around | ||||
|         for text in texts { | ||||
|             // Retries copied from openai.rs | ||||
|             for attempt in 0..7 { | ||||
|                 let retry_duration = match self.try_embed(&text, client).await { | ||||
|                     Ok(result) => { | ||||
|                         results.push(result); | ||||
|                         break; | ||||
|                     } | ||||
|                     Err(retry) => { | ||||
|                         tracing::warn!("Failed: {}", retry.error); | ||||
|                         retry.into_duration(attempt) | ||||
|                     } | ||||
|                 }?; | ||||
|                 tracing::warn!( | ||||
|                     "Attempt #{}, retrying after {}ms.", | ||||
|                     attempt, | ||||
|                     retry_duration.as_millis() | ||||
|                 ); | ||||
|                 tokio::time::sleep(retry_duration).await; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         Ok(results) | ||||
|     } | ||||
|  | ||||
|     async fn try_embed( | ||||
|         &self, | ||||
|         text: &str, | ||||
|         client: &reqwest::Client, | ||||
|     ) -> Result<Embeddings<f32>, Retry> { | ||||
|         let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text }; | ||||
|         let response = client | ||||
|             .post(get_ollama_path()) | ||||
|             .json(&request) | ||||
|             .send() | ||||
|             .await | ||||
|             .map_err(EmbedError::openai_network) | ||||
|             .map_err(Retry::retry_later)?; | ||||
|  | ||||
|         let response = Self::check_response(response).await?; | ||||
|  | ||||
|         let response: OllamaResponse = response | ||||
|             .json() | ||||
|             .await | ||||
|             .map_err(EmbedError::openai_unexpected) | ||||
|             .map_err(Retry::retry_later)?; | ||||
|  | ||||
|         tracing::trace!("response: {:?}", response.embedding); | ||||
|  | ||||
|         let embedding = Embeddings::from_single_embedding(response.embedding); | ||||
|         Ok(embedding) | ||||
|     } | ||||
|  | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         let rt = tokio::runtime::Builder::new_current_thread() | ||||
|             .enable_io() | ||||
|             .enable_time() | ||||
|             .build() | ||||
|             .map_err(EmbedError::openai_runtime_init)?; | ||||
|         let client = self.new_client()?; | ||||
|         rt.block_on(futures::future::try_join_all( | ||||
|             text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)), | ||||
|         )) | ||||
|         threads.install(move || { | ||||
|             text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     // Defaults copied from openai.rs | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|         10 | ||||
|         self.rest_embedder.chunk_count_hint() | ||||
|     } | ||||
|  | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|         10 | ||||
|         self.rest_embedder.prompt_count_in_chunk_hint() | ||||
|     } | ||||
|  | ||||
|     pub fn dimensions(&self) -> usize { | ||||
|         self.options.embedding_model.dimensions | ||||
|         self.rest_embedder.dimensions() | ||||
|     } | ||||
|  | ||||
|     pub fn distribution(&self) -> Option<DistributionShift> { | ||||
| @@ -295,12 +96,6 @@ impl Embedder { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Display for OllamaError { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         write!(f, "{}", self.error) | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn get_ollama_path() -> String { | ||||
|     // Important: a hostname alone is not enough; `MEILI_OLLAMA_URL` must contain the full URL of the embeddings endpoint. | ||||
|     std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string()) | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| use std::fmt::Display; | ||||
|  | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; | ||||
|  | ||||
| use super::error::{EmbedError, NewEmbedderError}; | ||||
| use super::{DistributionShift, Embedding, Embeddings}; | ||||
| use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; | ||||
| use super::{DistributionShift, Embeddings}; | ||||
| use crate::vector::error::EmbedErrorKind; | ||||
|  | ||||
| #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] | ||||
| pub struct EmbedderOptions { | ||||
| @@ -12,6 +12,32 @@ pub struct EmbedderOptions { | ||||
|     pub dimensions: Option<usize>, | ||||
| } | ||||
|  | ||||
| impl EmbedderOptions { | ||||
|     /// Number of dimensions of the embeddings produced with these options. | ||||
|     /// | ||||
|     /// The configured `dimensions` is only honored when the embedding model | ||||
|     /// supports overriding its dimensions; in every other case the model's | ||||
|     /// default dimensions are returned. | ||||
|     pub fn dimensions(&self) -> usize { | ||||
|         match self.dimensions { | ||||
|             Some(dimensions) if self.embedding_model.supports_overriding_dimensions() => { | ||||
|                 dimensions | ||||
|             } | ||||
|             _ => self.embedding_model.default_dimensions(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Builds the JSON object holding the `model` name and, when the model | ||||
|     /// supports overriding dimensions and a value is configured, `dimensions`. | ||||
|     pub fn query(&self) -> serde_json::Value { | ||||
|         let model = self.embedding_model.name(); | ||||
|  | ||||
|         // the override is dropped for models that cannot change their dimensions | ||||
|         let overridden = if self.embedding_model.supports_overriding_dimensions() { | ||||
|             self.dimensions | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
|  | ||||
|         match overridden { | ||||
|             Some(dimensions) => serde_json::json!({ | ||||
|                 "model": model, | ||||
|                 "dimensions": dimensions, | ||||
|             }), | ||||
|             None => serde_json::json!({ | ||||
|                 "model": model, | ||||
|             }), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive( | ||||
|     Debug, | ||||
|     Clone, | ||||
| @@ -117,261 +143,57 @@ impl EmbedderOptions { | ||||
|     } | ||||
| } | ||||
|  | ||||
| // retrying in case of failure | ||||
|  | ||||
| /// An `EmbedError` paired with the strategy deciding whether and when to retry. | ||||
| pub struct Retry { | ||||
|     /// The error raised by the failed attempt. | ||||
|     pub error: EmbedError, | ||||
|     // consulted by `into_duration` and `must_tokenize` | ||||
|     strategy: RetryStrategy, | ||||
| } | ||||
|  | ||||
| /// What to do after a failed attempt; interpreted by `Retry::into_duration`. | ||||
| pub enum RetryStrategy { | ||||
|     /// Do not retry: `into_duration` returns the error. | ||||
|     GiveUp, | ||||
|     /// Retry after a delay of `10^attempt` milliseconds. | ||||
|     Retry, | ||||
|     /// Retry after 1 millisecond; `must_tokenize` returns `true` for this strategy. | ||||
|     RetryTokenized, | ||||
|     /// Retry after a delay of `100 + 10^attempt` milliseconds. | ||||
|     RetryAfterRateLimit, | ||||
| } | ||||
|  | ||||
| impl Retry { | ||||
|     /// Wraps `error` so that no further attempt is made. | ||||
|     pub fn give_up(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::GiveUp } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` so that the request is retried after an exponential delay. | ||||
|     pub fn retry_later(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::Retry } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` so that the request is retried almost immediately, on tokenized input. | ||||
|     pub fn retry_tokenized(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::RetryTokenized } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` so that the request is retried after the rate-limit delay. | ||||
|     pub fn rate_limited(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::RetryAfterRateLimit } | ||||
|     } | ||||
|  | ||||
|     /// Converts this retry into the delay to wait before the next attempt. | ||||
|     /// | ||||
|     /// `attempt` is the number of the attempt that just failed; the delay grows | ||||
|     /// as `10^attempt` milliseconds. The computation saturates at `u64::MAX` | ||||
|     /// milliseconds instead of overflowing for large values of `attempt`. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Returns the wrapped error when the strategy is `RetryStrategy::GiveUp`. | ||||
|     pub fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> { | ||||
|         match self.strategy { | ||||
|             RetryStrategy::GiveUp => Err(self.error), | ||||
|             RetryStrategy::Retry => { | ||||
|                 Ok(tokio::time::Duration::from_millis(10u64.saturating_pow(attempt))) | ||||
|             } | ||||
|             RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), | ||||
|             RetryStrategy::RetryAfterRateLimit => Ok(tokio::time::Duration::from_millis( | ||||
|                 10u64.saturating_pow(attempt).saturating_add(100), | ||||
|             )), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Whether the next attempt must be made on the tokenized version of the input. | ||||
|     pub fn must_tokenize(&self) -> bool { | ||||
|         matches!(self.strategy, RetryStrategy::RetryTokenized) | ||||
|     } | ||||
|  | ||||
|     /// Discards the strategy and returns the wrapped error. | ||||
|     pub fn into_error(self) -> EmbedError { | ||||
|         self.error | ||||
|     } | ||||
| } | ||||
|  | ||||
| // openai api structs | ||||
|  | ||||
| /// JSON body of an embeddings request whose inputs are texts. | ||||
| #[derive(Debug, Serialize)] | ||||
| struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { | ||||
|     /// Name of the embedding model. | ||||
|     model: &'a str, | ||||
|     /// Texts to embed. | ||||
|     input: &'a [S], | ||||
|     /// Requested number of dimensions; left out of the JSON when `None`. | ||||
|     #[serde(skip_serializing_if = "Option::is_none")] | ||||
|     dimensions: Option<usize>, | ||||
| } | ||||
|  | ||||
| /// JSON body of an embeddings request whose input is a sequence of tokens. | ||||
| #[derive(Debug, Serialize)] | ||||
| struct OpenAiTokensRequest<'a> { | ||||
|     /// Name of the embedding model. | ||||
|     model: &'a str, | ||||
|     /// Tokens to embed. | ||||
|     input: &'a [usize], | ||||
|     /// Requested number of dimensions; left out of the JSON when `None`. | ||||
|     #[serde(skip_serializing_if = "Option::is_none")] | ||||
|     dimensions: Option<usize>, | ||||
| } | ||||
|  | ||||
| /// Body of a successful embeddings response. | ||||
| #[derive(Debug, Deserialize)] | ||||
| struct OpenAiResponse { | ||||
|     /// The returned embeddings. | ||||
|     data: Vec<OpenAiEmbedding>, | ||||
| } | ||||
|  | ||||
| /// Body of a response carrying an error status code. | ||||
| #[derive(Debug, Deserialize)] | ||||
| struct OpenAiErrorResponse { | ||||
|     /// Details of the error. | ||||
|     error: OpenAiError, | ||||
| } | ||||
|  | ||||
| /// Error details found in the `error` field of an error response. | ||||
| #[derive(Debug, Deserialize)] | ||||
| pub struct OpenAiError { | ||||
|     /// Error message. | ||||
|     message: String, | ||||
|     // type: String, | ||||
|     /// Optional error code; displayed in parentheses after the message. | ||||
|     code: Option<String>, | ||||
| } | ||||
|  | ||||
| /// Prints the message, followed by the code in parentheses when there is one. | ||||
| impl Display for OpenAiError { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         if let Some(code) = &self.code { | ||||
|             return write!(f, "{} ({})", self.message, code); | ||||
|         } | ||||
|         write!(f, "{}", self.message) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// One entry of the `data` array of an embeddings response. | ||||
| #[derive(Debug, Deserialize)] | ||||
| struct OpenAiEmbedding { | ||||
|     /// The embedding itself. | ||||
|     embedding: Embedding, | ||||
|     // object: String, | ||||
|     // index: usize, | ||||
| } | ||||
|  | ||||
| /// Reads the API key from the environment. | ||||
| /// | ||||
| /// `MEILI_OPENAI_API_KEY` takes precedence over `OPENAI_API_KEY`; an empty | ||||
| /// string is returned when neither variable can be read. | ||||
| fn infer_api_key() -> String { | ||||
|     match std::env::var("MEILI_OPENAI_API_KEY") { | ||||
|         Ok(api_key) => api_key, | ||||
|         Err(_) => std::env::var("OPENAI_API_KEY").unwrap_or_default(), | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub mod sync { | ||||
|     use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; | ||||
|  | ||||
|     use super::{ | ||||
|         EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, OpenAiRequest, | ||||
|         OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL, | ||||
|     }; | ||||
|     use crate::vector::DistributionShift; | ||||
|  | ||||
|     const REQUEST_PARALLELISM: usize = 10; | ||||
|  | ||||
|     #[derive(Debug)] | ||||
|     pub struct Embedder { | ||||
| #[derive(Debug)] | ||||
| pub struct Embedder { | ||||
|     tokenizer: tiktoken_rs::CoreBPE, | ||||
|         options: super::EmbedderOptions, | ||||
|         bearer: String, | ||||
|         threads: rayon::ThreadPool, | ||||
|     } | ||||
|     rest_embedder: RestEmbedder, | ||||
|     options: EmbedderOptions, | ||||
| } | ||||
|  | ||||
|     impl Embedder { | ||||
|         pub fn new(options: super::EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||
|         let mut inferred_api_key = Default::default(); | ||||
|         let api_key = options.api_key.as_ref().unwrap_or_else(|| { | ||||
|                 inferred_api_key = super::infer_api_key(); | ||||
|             inferred_api_key = infer_api_key(); | ||||
|             &inferred_api_key | ||||
|         }); | ||||
|             let bearer = format!("Bearer {api_key}"); | ||||
|  | ||||
|         let rest_embedder = RestEmbedder::new(RestEmbedderOptions { | ||||
|             api_key: Some(api_key.clone()), | ||||
|             distribution: options.embedding_model.distribution(), | ||||
|             dimensions: Some(options.dimensions()), | ||||
|             url: OPENAI_EMBEDDINGS_URL.to_owned(), | ||||
|             query: options.query(), | ||||
|             input_field: vec!["input".to_owned()], | ||||
|             input_type: crate::vector::rest::InputType::TextArray, | ||||
|             path_to_embeddings: vec!["data".to_owned()], | ||||
|             embedding_object: vec!["embedding".to_owned()], | ||||
|         })?; | ||||
|  | ||||
|         // looking at the code it is very unclear that this can actually fail. | ||||
|         let tokenizer = tiktoken_rs::cl100k_base().unwrap(); | ||||
|  | ||||
|             // FIXME: unwrap | ||||
|             let threads = rayon::ThreadPoolBuilder::new() | ||||
|                 .num_threads(REQUEST_PARALLELISM) | ||||
|                 .thread_name(|index| format!("embedder-chunk-{index}")) | ||||
|                 .build() | ||||
|                 .unwrap(); | ||||
|  | ||||
|             Ok(Self { options, bearer, tokenizer, threads }) | ||||
|         Ok(Self { options, rest_embedder, tokenizer }) | ||||
|     } | ||||
|  | ||||
|     pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|             let mut tokenized = false; | ||||
|  | ||||
|             let client = ureq::agent(); | ||||
|  | ||||
|             for attempt in 0..7 { | ||||
|                 let result = if tokenized { | ||||
|                     self.try_embed_tokenized(&texts, &client) | ||||
|                 } else { | ||||
|                     self.try_embed(&texts, &client) | ||||
|                 }; | ||||
|  | ||||
|                 let retry_duration = match result { | ||||
|                     Ok(embeddings) => return Ok(embeddings), | ||||
|                     Err(retry) => { | ||||
|                         tracing::warn!("Failed: {}", retry.error); | ||||
|                         tokenized |= retry.must_tokenize(); | ||||
|                         retry.into_duration(attempt) | ||||
|                     } | ||||
|                 }?; | ||||
|  | ||||
|                 let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute | ||||
|                 tracing::warn!( | ||||
|                     "Attempt #{}, retrying after {}ms.", | ||||
|                     attempt, | ||||
|                     retry_duration.as_millis() | ||||
|                 ); | ||||
|                 std::thread::sleep(retry_duration); | ||||
|             } | ||||
|  | ||||
|             let result = if tokenized { | ||||
|                 self.try_embed_tokenized(&texts, &client) | ||||
|             } else { | ||||
|                 self.try_embed(&texts, &client) | ||||
|             }; | ||||
|  | ||||
|             result.map_err(Retry::into_error) | ||||
|         } | ||||
|  | ||||
|         fn check_response( | ||||
|             response: Result<ureq::Response, ureq::Error>, | ||||
|         ) -> Result<ureq::Response, Retry> { | ||||
|             match response { | ||||
|                 Ok(response) => Ok(response), | ||||
|                 Err(ureq::Error::Status(code, response)) => { | ||||
|                     let error_response: Option<OpenAiErrorResponse> = response.into_json().ok(); | ||||
|                     let error = error_response.map(|response| response.error); | ||||
|                     Err(match code { | ||||
|                         401 => Retry::give_up(EmbedError::openai_auth_error(error)), | ||||
|                         429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)), | ||||
|                         400 => { | ||||
|                             tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); | ||||
|  | ||||
|                             Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error)) | ||||
|                         } | ||||
|                         500..=599 => { | ||||
|                             Retry::retry_later(EmbedError::openai_internal_server_error(error)) | ||||
|                         } | ||||
|                         x => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)), | ||||
|                     }) | ||||
|                 } | ||||
|                 Err(ureq::Error::Transport(transport)) => { | ||||
|                     Err(Retry::retry_later(EmbedError::openai_network(transport))) | ||||
|         match self.rest_embedder.embed_ref(&texts) { | ||||
|             Ok(embeddings) => Ok(embeddings), | ||||
|             Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => { | ||||
|                 tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); | ||||
|                 self.try_embed_tokenized(&texts) | ||||
|             } | ||||
|             Err(error) => Err(error), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|         fn try_embed<S: AsRef<str> + serde::Serialize>( | ||||
|             &self, | ||||
|             texts: &[S], | ||||
|             client: &ureq::Agent, | ||||
|         ) -> Result<Vec<Embeddings<f32>>, Retry> { | ||||
|             for text in texts { | ||||
|                 tracing::trace!("Received prompt: {}", text.as_ref()) | ||||
|             } | ||||
|             let request = OpenAiRequest { | ||||
|                 model: self.options.embedding_model.name(), | ||||
|                 input: texts, | ||||
|                 dimensions: self.overriden_dimensions(), | ||||
|             }; | ||||
|             let response = client | ||||
|                 .post(OPENAI_EMBEDDINGS_URL) | ||||
|                 .set("Authorization", &self.bearer) | ||||
|                 .send_json(&request); | ||||
|  | ||||
|             let response = Self::check_response(response)?; | ||||
|  | ||||
|             let response: OpenAiResponse = response | ||||
|                 .into_json() | ||||
|                 .map_err(EmbedError::openai_unexpected) | ||||
|                 .map_err(Retry::retry_later)?; | ||||
|  | ||||
|             tracing::trace!("response: {:?}", response.data); | ||||
|  | ||||
|             Ok(response | ||||
|                 .data | ||||
|                 .into_iter() | ||||
|                 .map(|data| Embeddings::from_single_embedding(data.embedding)) | ||||
|                 .collect()) | ||||
|         } | ||||
|  | ||||
|         fn try_embed_tokenized( | ||||
|             &self, | ||||
|             text: &[String], | ||||
|             client: &ureq::Agent, | ||||
|         ) -> Result<Vec<Embeddings<f32>>, Retry> { | ||||
|     fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         pub const OVERLAP_SIZE: usize = 200; | ||||
|         let mut all_embeddings = Vec::with_capacity(text.len()); | ||||
|         for text in text { | ||||
| @@ -379,7 +201,7 @@ pub mod sync { | ||||
|             let encoded = self.tokenizer.encode_ordinary(text.as_str()); | ||||
|             let len = encoded.len(); | ||||
|             if len < max_token_count { | ||||
|                     all_embeddings.append(&mut self.try_embed(&[text], client)?); | ||||
|                 all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
| @@ -387,94 +209,46 @@ pub mod sync { | ||||
|             let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); | ||||
|             while tokens.len() > max_token_count { | ||||
|                 let window = &tokens[..max_token_count]; | ||||
|                     embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap(); | ||||
|                 let embedding = self.rest_embedder.embed_tokens(window)?; | ||||
|                 /// FIXME: unwrap | ||||
|                 embeddings_for_prompt.append(embedding.into_inner()).unwrap(); | ||||
|  | ||||
|                 tokens = &tokens[max_token_count - OVERLAP_SIZE..]; | ||||
|             } | ||||
|  | ||||
|             // end of text | ||||
|                 embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap(); | ||||
|             let embedding = self.rest_embedder.embed_tokens(tokens)?; | ||||
|             /// FIXME: unwrap | ||||
|             embeddings_for_prompt.append(embedding.into_inner()).unwrap(); | ||||
|  | ||||
|             all_embeddings.push(embeddings_for_prompt); | ||||
|         } | ||||
|         Ok(all_embeddings) | ||||
|     } | ||||
|  | ||||
|         fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result<Embedding, Retry> { | ||||
|             for attempt in 0..9 { | ||||
|                 let duration = match self.try_embed_tokens(tokens, client) { | ||||
|                     Ok(embedding) => return Ok(embedding), | ||||
|                     Err(retry) => retry.into_duration(attempt), | ||||
|                 } | ||||
|                 .map_err(Retry::retry_later)?; | ||||
|  | ||||
|                 std::thread::sleep(duration); | ||||
|             } | ||||
|  | ||||
|             self.try_embed_tokens(tokens, client) | ||||
|                 .map_err(|retry| Retry::give_up(retry.into_error())) | ||||
|         } | ||||
|  | ||||
|         fn try_embed_tokens( | ||||
|             &self, | ||||
|             tokens: &[usize], | ||||
|             client: &ureq::Agent, | ||||
|         ) -> Result<Embedding, Retry> { | ||||
|             let request = OpenAiTokensRequest { | ||||
|                 model: self.options.embedding_model.name(), | ||||
|                 input: tokens, | ||||
|                 dimensions: self.overriden_dimensions(), | ||||
|             }; | ||||
|             let response = client | ||||
|                 .post(OPENAI_EMBEDDINGS_URL) | ||||
|                 .set("Authorization", &self.bearer) | ||||
|                 .send_json(&request); | ||||
|  | ||||
|             let response = Self::check_response(response)?; | ||||
|  | ||||
|             let mut response: OpenAiResponse = response | ||||
|                 .into_json() | ||||
|                 .map_err(EmbedError::openai_unexpected) | ||||
|                 .map_err(Retry::retry_later)?; | ||||
|  | ||||
|             Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) | ||||
|         } | ||||
|  | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|             self.threads | ||||
|                 .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk))) | ||||
|                 .collect() | ||||
|         threads.install(move || { | ||||
|             text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|             10 | ||||
|         self.rest_embedder.chunk_count_hint() | ||||
|     } | ||||
|  | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|             10 | ||||
|         self.rest_embedder.prompt_count_in_chunk_hint() | ||||
|     } | ||||
|  | ||||
|     pub fn dimensions(&self) -> usize { | ||||
|             if self.options.embedding_model.supports_overriding_dimensions() { | ||||
|                 self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions()) | ||||
|             } else { | ||||
|                 self.options.embedding_model.default_dimensions() | ||||
|             } | ||||
|         self.options.dimensions() | ||||
|     } | ||||
|  | ||||
|     pub fn distribution(&self) -> Option<DistributionShift> { | ||||
|         self.options.embedding_model.distribution() | ||||
|     } | ||||
|  | ||||
|         fn overriden_dimensions(&self) -> Option<usize> { | ||||
|             if self.options.embedding_model.supports_overriding_dimensions() { | ||||
|                 self.options.dimensions | ||||
|             } else { | ||||
|                 None | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,9 +1,62 @@ | ||||
| use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; | ||||
| use serde::Serialize; | ||||
|  | ||||
| use super::openai::Retry; | ||||
| use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError}; | ||||
| use crate::VectorOrArrayOfVectors; | ||||
| use super::{ | ||||
|     DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, | ||||
| }; | ||||
|  | ||||
| // retrying in case of failure | ||||
|  | ||||
| /// An `EmbedError` paired with the strategy deciding whether and when to retry. | ||||
| pub struct Retry { | ||||
|     /// The error raised by the failed attempt. | ||||
|     pub error: EmbedError, | ||||
|     // consulted by `into_duration` and `must_tokenize` | ||||
|     strategy: RetryStrategy, | ||||
| } | ||||
|  | ||||
| /// What to do after a failed attempt; interpreted by `Retry::into_duration`. | ||||
| pub enum RetryStrategy { | ||||
|     /// Do not retry: `into_duration` returns the error. | ||||
|     GiveUp, | ||||
|     /// Retry after a delay of `10^attempt` milliseconds. | ||||
|     Retry, | ||||
|     /// Retry after 1 millisecond; `must_tokenize` returns `true` for this strategy. | ||||
|     RetryTokenized, | ||||
|     /// Retry after a delay of `100 + 10^attempt` milliseconds. | ||||
|     RetryAfterRateLimit, | ||||
| } | ||||
|  | ||||
| impl Retry { | ||||
|     /// Wraps `error` so that no further attempt is made. | ||||
|     pub fn give_up(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::GiveUp } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` so that the request is retried after an exponential delay. | ||||
|     pub fn retry_later(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::Retry } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` so that the request is retried almost immediately, on tokenized input. | ||||
|     pub fn retry_tokenized(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::RetryTokenized } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` so that the request is retried after the rate-limit delay. | ||||
|     pub fn rate_limited(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::RetryAfterRateLimit } | ||||
|     } | ||||
|  | ||||
|     /// Converts this retry into the delay to wait before the next attempt. | ||||
|     /// | ||||
|     /// `attempt` is the number of the attempt that just failed; the delay grows | ||||
|     /// as `10^attempt` milliseconds. The computation saturates at `u64::MAX` | ||||
|     /// milliseconds instead of overflowing for large values of `attempt`. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Returns the wrapped error when the strategy is `RetryStrategy::GiveUp`. | ||||
|     pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> { | ||||
|         match self.strategy { | ||||
|             RetryStrategy::GiveUp => Err(self.error), | ||||
|             RetryStrategy::Retry => { | ||||
|                 Ok(std::time::Duration::from_millis(10u64.saturating_pow(attempt))) | ||||
|             } | ||||
|             RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)), | ||||
|             RetryStrategy::RetryAfterRateLimit => Ok(std::time::Duration::from_millis( | ||||
|                 10u64.saturating_pow(attempt).saturating_add(100), | ||||
|             )), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Whether the next attempt must be made on the tokenized version of the input. | ||||
|     pub fn must_tokenize(&self) -> bool { | ||||
|         matches!(self.strategy, RetryStrategy::RetryTokenized) | ||||
|     } | ||||
|  | ||||
|     /// Discards the strategy and returns the wrapped error. | ||||
|     pub fn into_error(self) -> EmbedError { | ||||
|         self.error | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Embedder { | ||||
|     client: ureq::Agent, | ||||
|     options: EmbedderOptions, | ||||
| @@ -11,20 +64,35 @@ pub struct Embedder { | ||||
|     dimensions: usize, | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct EmbedderOptions { | ||||
|     api_key: Option<String>, | ||||
|     distribution: Option<DistributionShift>, | ||||
|     dimensions: Option<usize>, | ||||
|     url: String, | ||||
|     query: liquid::Template, | ||||
|     response_field: Vec<String>, | ||||
|     pub api_key: Option<String>, | ||||
|     pub distribution: Option<DistributionShift>, | ||||
|     pub dimensions: Option<usize>, | ||||
|     pub url: String, | ||||
|     pub query: serde_json::Value, | ||||
|     pub input_field: Vec<String>, | ||||
|     // path to the array of embeddings | ||||
|     pub path_to_embeddings: Vec<String>, | ||||
|     // shape of a single embedding | ||||
|     pub embedding_object: Vec<String>, | ||||
|     pub input_type: InputType, | ||||
| } | ||||
|  | ||||
| /// Shape of the input value injected in the request body. | ||||
| #[derive(Debug)] | ||||
| pub enum InputType { | ||||
|     /// Only the first input is sent, as a single value. | ||||
|     Text, | ||||
|     /// All the inputs are sent together, as an array. | ||||
|     TextArray, | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||
|         let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer: {api_key}")); | ||||
|         let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}")); | ||||
|  | ||||
|         let client = ureq::agent(); | ||||
|         let client = ureq::AgentBuilder::new() | ||||
|             .max_idle_connections(REQUEST_PARALLELISM * 2) | ||||
|             .max_idle_connections_per_host(REQUEST_PARALLELISM * 2) | ||||
|             .build(); | ||||
|  | ||||
|         let dimensions = if let Some(dimensions) = options.dimensions { | ||||
|             dimensions | ||||
| @@ -36,7 +104,20 @@ impl Embedder { | ||||
|     } | ||||
|  | ||||
|     pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice()) | ||||
|         embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len()) | ||||
|     } | ||||
|  | ||||
|     /// Embeds borrowed `texts`, expecting one result per text. | ||||
|     /// | ||||
|     /// Errors are the ones returned by the free function `embed`. | ||||
|     pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError> | ||||
|     where | ||||
|         S: AsRef<str> + Serialize, | ||||
|     { | ||||
|         let bearer = self.bearer.as_deref(); | ||||
|         let expected_count = texts.len(); | ||||
|         embed(&self.client, &self.options, bearer, texts, expected_count) | ||||
|     } | ||||
|  | ||||
|     /// Embeds a sequence of tokens sent as the request input, expecting a single result. | ||||
|     /// | ||||
|     /// Errors are the ones returned by the free function `embed`. | ||||
|     pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> { | ||||
|         let bearer = self.bearer.as_deref(); | ||||
|         embed(&self.client, &self.options, bearer, tokens, 1).map(|mut embeddings| { | ||||
|             // unwrap: guaranteed that embeddings.len() == 1, otherwise `embed` terminated in error | ||||
|             embeddings.pop().unwrap() | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     pub fn embed_chunks( | ||||
| @@ -44,17 +125,20 @@ impl Embedder { | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         threads | ||||
|             .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk))) | ||||
|             .collect() | ||||
|         threads.install(move || { | ||||
|             text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|         10 | ||||
|         super::REQUEST_PARALLELISM | ||||
|     } | ||||
|  | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|         10 | ||||
|         match self.options.input_type { | ||||
|             InputType::Text => 1, | ||||
|             InputType::TextArray => 10, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn dimensions(&self) -> usize { | ||||
| @@ -71,9 +155,9 @@ fn infer_dimensions( | ||||
|     options: &EmbedderOptions, | ||||
|     bearer: Option<&str>, | ||||
| ) -> Result<usize, NewEmbedderError> { | ||||
|     let v = embed(client, options, bearer, ["test"].as_slice()) | ||||
|     let v = embed(client, options, bearer, ["test"].as_slice(), 1) | ||||
|         .map_err(NewEmbedderError::could_not_determine_dimension)?; | ||||
|     // unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error | ||||
|     // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error | ||||
|     Ok(v.first().unwrap().dimension()) | ||||
| } | ||||
|  | ||||
| @@ -82,33 +166,57 @@ fn embed<S>( | ||||
|     options: &EmbedderOptions, | ||||
|     bearer: Option<&str>, | ||||
|     inputs: &[S], | ||||
|     expected_count: usize, | ||||
| ) -> Result<Vec<Embeddings<f32>>, EmbedError> | ||||
| where | ||||
|     S: serde::Serialize, | ||||
|     S: Serialize, | ||||
| { | ||||
|     let request = client.post(&options.url); | ||||
|     let request = | ||||
|         if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request }; | ||||
|     let request = request.set("Content-Type", "application/json"); | ||||
|  | ||||
|     let body = options | ||||
|         .query | ||||
|         .render( | ||||
|             &liquid::to_object(&serde_json::json!({ | ||||
|                 "input": inputs, | ||||
|             })) | ||||
|             .map_err(EmbedError::rest_template_context_serialization)?, | ||||
|         ) | ||||
|         .map_err(EmbedError::rest_template_render)?; | ||||
|     let input_value = match options.input_type { | ||||
|         InputType::Text => serde_json::json!(inputs.first()), | ||||
|         InputType::TextArray => serde_json::json!(inputs), | ||||
|     }; | ||||
|  | ||||
|     let body = match options.input_field.as_slice() { | ||||
|         [] => { | ||||
|             // inject input in body | ||||
|             input_value | ||||
|         } | ||||
|         [input] => { | ||||
|             let mut body = options.query.clone(); | ||||
|  | ||||
|             /// FIXME unwrap | ||||
|             body.as_object_mut().unwrap().insert(input.clone(), input_value); | ||||
|             body | ||||
|         } | ||||
|         [path @ .., input] => { | ||||
|             let mut body = options.query.clone(); | ||||
|  | ||||
|             /// FIXME unwrap | ||||
|             let mut current_value = &mut body; | ||||
|             for component in path { | ||||
|                 current_value = current_value | ||||
|                     .as_object_mut() | ||||
|                     .unwrap() | ||||
|                     .entry(component.clone()) | ||||
|                     .or_insert(serde_json::json!({})); | ||||
|             } | ||||
|  | ||||
|             current_value.as_object_mut().unwrap().insert(input.clone(), input_value); | ||||
|             body | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     for attempt in 0..7 { | ||||
|         let response = request.send_string(&body); | ||||
|         let response = request.clone().send_json(&body); | ||||
|         let result = check_response(response); | ||||
|  | ||||
|         let retry_duration = match result { | ||||
|             Ok(response) => { | ||||
|                 return response_to_embedding(response, &options.response_field, inputs.len()) | ||||
|             } | ||||
|             Ok(response) => return response_to_embedding(response, options, expected_count), | ||||
|             Err(retry) => { | ||||
|                 tracing::warn!("Failed: {}", retry.error); | ||||
|                 retry.into_duration(attempt) | ||||
| @@ -120,11 +228,11 @@ where | ||||
|         std::thread::sleep(retry_duration); | ||||
|     } | ||||
|  | ||||
|     let response = request.send_string(&body); | ||||
|     let response = request.send_json(&body); | ||||
|     let result = check_response(response); | ||||
|     result | ||||
|         .map_err(Retry::into_error) | ||||
|         .and_then(|response| response_to_embedding(response, &options.response_field, inputs.len())) | ||||
|         .and_then(|response| response_to_embedding(response, options, expected_count)) | ||||
| } | ||||
|  | ||||
| fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> { | ||||
| @@ -139,7 +247,10 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq: | ||||
|                 500..=599 => { | ||||
|                     Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response)) | ||||
|                 } | ||||
|                 x => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)), | ||||
|                 402..=499 => { | ||||
|                     Retry::give_up(EmbedError::rest_other_status_code(code, error_response)) | ||||
|                 } | ||||
|                 _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)), | ||||
|             }) | ||||
|         } | ||||
|         Err(ureq::Error::Transport(transport)) => { | ||||
| @@ -148,34 +259,66 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq: | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn response_to_embedding<S: AsRef<str>>( | ||||
| fn response_to_embedding( | ||||
|     response: ureq::Response, | ||||
|     response_field: &[S], | ||||
|     options: &EmbedderOptions, | ||||
|     expected_count: usize, | ||||
| ) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|     let response: serde_json::Value = | ||||
|         response.into_json().map_err(EmbedError::rest_response_deserialization)?; | ||||
|  | ||||
|     let mut current_value = &response; | ||||
|     for component in response_field { | ||||
|     for component in &options.path_to_embeddings { | ||||
|         let component = component.as_ref(); | ||||
|         let current_value = current_value.get(component).ok_or_else(|| { | ||||
|             EmbedError::rest_response_missing_embeddings(response, component, response_field) | ||||
|         current_value = current_value.get(component).ok_or_else(|| { | ||||
|             EmbedError::rest_response_missing_embeddings( | ||||
|                 response.clone(), | ||||
|                 component, | ||||
|                 &options.path_to_embeddings, | ||||
|             ) | ||||
|         })?; | ||||
|     } | ||||
|  | ||||
|     let embeddings = match options.input_type { | ||||
|         InputType::Text => { | ||||
|             for component in &options.embedding_object { | ||||
|                 current_value = current_value.get(component).ok_or_else(|| { | ||||
|                     EmbedError::rest_response_missing_embeddings( | ||||
|                         response.clone(), | ||||
|                         component, | ||||
|                         &options.embedding_object, | ||||
|                     ) | ||||
|                 })?; | ||||
|             } | ||||
|             let embeddings = current_value.to_owned(); | ||||
|  | ||||
|     let embeddings: VectorOrArrayOfVectors = | ||||
|             let embeddings: Embedding = | ||||
|                 serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?; | ||||
|  | ||||
|     let embeddings = embeddings.into_array_of_vectors(); | ||||
|  | ||||
|     let embeddings: Vec<Embeddings<f32>> = embeddings | ||||
|         .into_iter() | ||||
|         .flatten() | ||||
|         .map(|embedding| Embeddings::from_single_embedding(embedding)) | ||||
|         .collect(); | ||||
|             vec![Embeddings::from_single_embedding(embeddings)] | ||||
|         } | ||||
|         InputType::TextArray => { | ||||
|             let empty = vec![]; | ||||
|             let values = current_value.as_array().unwrap_or(&empty); | ||||
|             let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count); | ||||
|             for value in values { | ||||
|                 let mut current_value = value; | ||||
|                 for component in &options.embedding_object { | ||||
|                     current_value = current_value.get(component).ok_or_else(|| { | ||||
|                         EmbedError::rest_response_missing_embeddings( | ||||
|                             response.clone(), | ||||
|                             component, | ||||
|                             &options.embedding_object, | ||||
|                         ) | ||||
|                     })?; | ||||
|                 } | ||||
|                 let embedding = current_value.to_owned(); | ||||
|                 let embedding: Embedding = | ||||
|                     serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?; | ||||
|                 embeddings.push(Embeddings::from_single_embedding(embedding)); | ||||
|             } | ||||
|             embeddings | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     if embeddings.len() != expected_count { | ||||
|         return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); | ||||
|   | ||||
| @@ -204,7 +204,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|             }, | ||||
|             super::EmbedderOptions::Ollama(options) => Self { | ||||
|                 source: Setting::Set(EmbedderSource::Ollama), | ||||
|                 model: Setting::Set(options.embedding_model.name().to_owned()), | ||||
|                 model: Setting::Set(options.embedding_model.to_owned()), | ||||
|                 revision: Setting::NotSet, | ||||
|                 api_key: Setting::NotSet, | ||||
|                 dimensions: Setting::NotSet, | ||||
| @@ -248,7 +248,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig { | ||||
|                     let mut options: ollama::EmbedderOptions = | ||||
|                         super::ollama::EmbedderOptions::with_default_model(); | ||||
|                     if let Some(model) = model.set() { | ||||
|                         options.embedding_model = super::ollama::EmbeddingModel::from_name(&model); | ||||
|                         options.embedding_model = model; | ||||
|                     } | ||||
|                     this.embedder_options = super::EmbedderOptions::Ollama(options); | ||||
|                 } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user