Mirror of https://github.com/meilisearch/meilisearch.git, synced 2025-10-31 07:56:28 +00:00

	Don't use a runtime in extract_embedder, use it only for OpenAI
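The commit inverts where the async runtime lives: extract_embeddings no longer builds a Tokio runtime and block_on's every embedder, and the OpenAI embedder instead constructs a current-thread runtime inside its now-synchronous embed_chunks. A minimal sketch of the resulting pattern, with embed_one as an illustrative stand-in for the per-chunk request (only the tokio and futures calls mirror real APIs):

use futures::future::try_join_all;

// Illustrative stand-in for the OpenAI embedder's per-chunk request.
async fn embed_one(chunk: Vec<String>) -> Result<Vec<f32>, std::io::Error> {
    Ok(vec![0.0; chunk.len()]) // placeholder embedding
}

// Synchronous entry point: build a small current-thread runtime on the
// spot and drive all chunk futures to completion with block_on.
fn embed_chunks(chunks: Vec<Vec<String>>) -> Result<Vec<Vec<f32>>, std::io::Error> {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_io()   // the HTTP client needs the I/O driver
        .enable_time() // retry sleeps need the timer
        .build()?;
    rt.block_on(try_join_all(chunks.into_iter().map(embed_one)))
}

The HuggingFace and user-provided embedders are synchronous (their match arms below carry no .await), so only the OpenAI path pays for the runtime.
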
@@ -339,9 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     indexer: GrenadParameters,
     embedder: Arc<Embedder>,
 ) -> Result<grenad::Reader<BufReader<File>>> {
-    let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
-
-    let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
+    let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
     let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
 
     // docid, state with embedding
@@ -375,11 +373,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
         current_chunk_ids.push(docid);
 
         if chunks.len() == chunks.capacity() {
-            let chunked_embeds = rt
-                .block_on(
-                    embedder
-                        .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
-                )
+            let chunked_embeds = embedder
+                .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
                 .map_err(crate::vector::Error::from)
                 .map_err(crate::Error::from)?;
 
@@ -396,8 +391,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
 
     // send last chunk
     if !chunks.is_empty() {
-        let chunked_embeds = rt
-            .block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
+        let chunked_embeds = embedder
+            .embed_chunks(std::mem::take(&mut chunks))
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
         for (docid, embeddings) in chunks_ids
@@ -410,13 +405,15 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     }
 
     if !current_chunk.is_empty() {
-        let embeds = rt
-            .block_on(embedder.embed(std::mem::take(&mut current_chunk)))
+        let embeds = embedder
+            .embed_chunks(vec![std::mem::take(&mut current_chunk)])
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
 
-        for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
-            state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+        if let Some(embeds) = embeds.first() {
+            for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
+                state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+            }
         }
     }
 

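Note the shape change in the last hunk above: the trailing partial chunk used to go through embed(...), which returned one Vec<Embeddings<f32>>, while embed_chunks(vec![chunk]) returns one such vector per submitted chunk. Since exactly one chunk goes in, embeds.first() recovers the old shape. A minimal sketch of the nesting, with Vec<f32> standing in for Embeddings<f32>:

fn main() {
    // embed_chunks(vec![chunk]) yields one inner Vec per chunk submitted.
    let embeds: Vec<Vec<Vec<f32>>> = vec![vec![vec![0.1, 0.2], vec![0.3, 0.4]]];
    let current_chunk_ids: Vec<u32> = vec![7, 8];

    // Exactly one chunk went in, so first() is its embeddings, in order.
    if let Some(chunk_embeds) = embeds.first() {
        for (docid, embedding) in current_chunk_ids.iter().zip(chunk_embeds.iter()) {
            println!("docid {docid} -> {} dims", embedding.len());
        }
    }
}
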
@@ -67,6 +67,10 @@ pub enum EmbedErrorKind {
     OpenAiUnhandledStatusCode(u16),
     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
     ManualEmbed(String),
+    #[error("could not initialize asynchronous runtime: {0}")]
+    OpenAiRuntimeInit(std::io::Error),
+    #[error("initializing web client for sending embedding requests failed: {0}")]
+    InitWebClient(reqwest::Error),
 }
 
 impl EmbedError {
@@ -117,6 +121,14 @@ impl EmbedError {
     pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
         Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
     }
+
+    pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
+        Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
+    }
+
+    pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
+        Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
+    }
 }
 
 #[derive(Debug, thiserror::Error)]
@@ -183,10 +195,6 @@ impl NewEmbedderError {
         }
     }
 
-    pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
-        Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
-    }
-
     pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
         Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
     }
@@ -237,8 +245,6 @@ pub enum NewEmbedderErrorKind {
     #[error("loading model failed: {0}")]
     LoadModel(candle_core::Error),
     // openai
-    #[error("initializing web client for sending embedding requests failed: {0}")]
-    InitWebClient(reqwest::Error),
     #[error("The API key passed to Authorization error was in an invalid format: {0}")]
     InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
 }
 

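Both new variants land in EmbedErrorKind (an embed-time error) rather than NewEmbedderErrorKind (a construction-time error) because the web client, and now the runtime, are created inside each embed_chunks call instead of once in Embedder::new. A sketch of the kind-plus-fault pattern these constructors imply; the FaultSource shape is an assumption reconstructed from the diff, not copied source:

// Assumed blame tag carried alongside every error kind.
#[derive(Debug, Clone, Copy)]
pub enum FaultSource {
    User,    // bad input or configuration
    Runtime, // environment or transport failure
}

#[derive(Debug, thiserror::Error)]
#[error("{kind}")]
pub struct EmbedError {
    pub kind: EmbedErrorKind,
    pub fault: FaultSource,
}

#[derive(Debug, thiserror::Error)]
pub enum EmbedErrorKind {
    #[error("could not initialize asynchronous runtime: {0}")]
    OpenAiRuntimeInit(std::io::Error),
    #[error("initializing web client for sending embedding requests failed: {0}")]
    InitWebClient(reqwest::Error),
}

impl EmbedError {
    pub fn openai_runtime_init(inner: std::io::Error) -> Self {
        Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
    }
}
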
@@ -163,18 +163,24 @@ impl Embedder {
     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => embedder.embed(texts).await,
+            Embedder::OpenAi(embedder) => {
+                let client = embedder.new_client()?;
+                embedder.embed(texts, &client).await
+            }
             Embedder::UserProvided(embedder) => embedder.embed(texts),
         }
     }
 
-    pub async fn embed_chunks(
+    /// # Panics
+    ///
+    /// - if called from an asynchronous context
+    pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
-            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
+            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
         }
     }
 

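The new # Panics note exists because embed_chunks now builds a runtime and calls block_on internally, and Tokio panics when a thread that is already driving a runtime tries to block like that. A small self-contained demonstration of the constraint, using the real tokio::runtime::Handle::try_current API:

use tokio::runtime::Handle;

// True when no Tokio runtime is driving the current thread, i.e. when a
// blocking, block_on-based call like the new embed_chunks is safe.
fn safe_to_block() -> bool {
    Handle::try_current().is_err()
}

fn main() {
    assert!(safe_to_block()); // plain main(): no runtime, blocking is fine

    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        // In here a nested block_on would panic with "Cannot start a
        // runtime from within a runtime".
        assert!(!safe_to_block());
    });
}
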
@@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings};
 
 #[derive(Debug)]
 pub struct Embedder {
-    client: reqwest::Client,
+    headers: reqwest::header::HeaderMap,
     tokenizer: tiktoken_rs::CoreBPE,
     options: EmbedderOptions,
 }
@@ -95,6 +95,13 @@ impl EmbedderOptions {
 }
 
 impl Embedder {
+    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
+        reqwest::ClientBuilder::new()
+            .default_headers(self.headers.clone())
+            .build()
+            .map_err(EmbedError::openai_initialize_web_client)
+    }
+
     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
         let mut headers = reqwest::header::HeaderMap::new();
         let mut inferred_api_key = Default::default();
@@ -111,25 +118,25 @@ impl Embedder {
             reqwest::header::CONTENT_TYPE,
             reqwest::header::HeaderValue::from_static("application/json"),
         );
-        let client = reqwest::ClientBuilder::new()
-            .default_headers(headers)
-            .build()
-            .map_err(NewEmbedderError::openai_initialize_web_client)?;
 
         // looking at the code it is very unclear that this can actually fail.
         let tokenizer = tiktoken_rs::cl100k_base().unwrap();
 
-        Ok(Self { options, client, tokenizer })
+        Ok(Self { options, headers, tokenizer })
     }
 
-    pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+    pub async fn embed(
+        &self,
+        texts: Vec<String>,
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
         let mut tokenized = false;
 
         for attempt in 0..7 {
             let result = if tokenized {
-                self.try_embed_tokenized(&texts).await
+                self.try_embed_tokenized(&texts, client).await
             } else {
-                self.try_embed(&texts).await
+                self.try_embed(&texts, client).await
             };
 
             let retry_duration = match result {
@@ -145,9 +152,9 @@ impl Embedder {
         }
 
         let result = if tokenized {
-            self.try_embed_tokenized(&texts).await
+            self.try_embed_tokenized(&texts, client).await
         } else {
-            self.try_embed(&texts).await
+            self.try_embed(&texts, client).await
         };
 
         result.map_err(Retry::into_error)
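The retry loop threaded through above keeps its shape: up to seven attempts, each failure converted into a wait via retry.into_duration(attempt), then one final un-retried attempt whose error is surfaced as-is. A minimal sketch of that control flow; the Retry type and backoff policy here are stand-ins, not meilisearch's exact implementation:

use std::time::Duration;

// Stand-in for the crate's Retry error.
struct Retry;

impl Retry {
    // Assumed policy: exponential backoff keyed on the attempt number.
    fn into_duration(self, attempt: u32) -> Duration {
        Duration::from_millis(100 * 2u64.pow(attempt))
    }
}

// Fails the first three attempts, then succeeds.
async fn call_once(attempt: u32) -> Result<&'static str, Retry> {
    if attempt < 3 { Err(Retry) } else { Ok("embedding") }
}

async fn with_retries() -> Result<&'static str, Retry> {
    for attempt in 0..7 {
        match call_once(attempt).await {
            Ok(v) => return Ok(v),
            Err(retry) => tokio::time::sleep(retry.into_duration(attempt)).await,
        }
    }
    // One last try; its error is final rather than retried.
    call_once(7).await
}
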
@@ -225,13 +232,13 @@ impl Embedder {
     async fn try_embed<S: AsRef<str> + serde::Serialize>(
         &self,
         texts: &[S],
+        client: &reqwest::Client,
     ) -> Result<Vec<Embeddings<f32>>, Retry> {
         for text in texts {
             log::trace!("Received prompt: {}", text.as_ref())
         }
         let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -256,7 +263,11 @@ impl Embedder {
             .collect())
     }
 
-    async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
+    async fn try_embed_tokenized(
+        &self,
+        text: &[String],
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, Retry> {
         pub const OVERLAP_SIZE: usize = 200;
         let mut all_embeddings = Vec::with_capacity(text.len());
         for text in text {
@@ -264,7 +275,7 @@ impl Embedder {
             let encoded = self.tokenizer.encode_ordinary(text.as_str());
             let len = encoded.len();
             if len < max_token_count {
-                all_embeddings.append(&mut self.try_embed(&[text]).await?);
+                all_embeddings.append(&mut self.try_embed(&[text], client).await?);
                 continue;
             }
 
@@ -273,22 +284,26 @@ impl Embedder {
                 Embeddings::new(self.options.embedding_model.dimensions());
             while tokens.len() > max_token_count {
                 let window = &tokens[..max_token_count];
-                embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
+                embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
 
                 tokens = &tokens[max_token_count - OVERLAP_SIZE..];
             }
 
             // end of text
-            embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
+            embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
 
             all_embeddings.push(embeddings_for_prompt);
         }
         Ok(all_embeddings)
     }
 
-    async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         for attempt in 0..9 {
-            let duration = match self.try_embed_tokens(tokens).await {
+            let duration = match self.try_embed_tokens(tokens, client).await {
                 Ok(embedding) => return Ok(embedding),
                 Err(retry) => retry.into_duration(attempt),
             }
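The tokenized path above splits an over-long prompt into windows of at most max_token_count tokens, stepping forward by max_token_count - OVERLAP_SIZE so consecutive windows share 200 tokens of context, then embeds each window separately. A minimal sketch of just the windowing arithmetic, with sizes shrunk for readability:

// Overlapping windows over a token stream, mirroring the loop above:
// full windows while the remainder is too long, then one final window.
fn windows(tokens: &[usize], max_len: usize, overlap: usize) -> Vec<&[usize]> {
    let mut out = Vec::new();
    let mut rest = tokens;
    while rest.len() > max_len {
        out.push(&rest[..max_len]);
        rest = &rest[max_len - overlap..]; // re-include `overlap` tokens
    }
    out.push(rest); // end of text
    out
}

fn main() {
    let tokens: Vec<usize> = (0..10).collect();
    // Max 4 tokens per window, 1 token of overlap between windows:
    // prints [0, 1, 2, 3], then [3, 4, 5, 6], then [6, 7, 8, 9].
    for w in windows(&tokens, 4, 1) {
        println!("{w:?}");
    }
}
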
@@ -297,14 +312,19 @@ impl Embedder {
             tokio::time::sleep(duration).await;
         }
 
-        self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
+        self.try_embed_tokens(tokens, client)
+            .await
+            .map_err(|retry| Retry::give_up(retry.into_error()))
     }
 
-    async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn try_embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         let request =
             OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -322,12 +342,19 @@ impl Embedder {
         Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
     }
 
-    pub async fn embed_chunks(
+    pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
-            .await
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_io()
+            .enable_time()
+            .build()
+            .map_err(EmbedError::openai_runtime_init)?;
+        let client = self.new_client()?;
+        rt.block_on(futures::future::try_join_all(
+            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
+        ))
     }
 
     pub fn chunk_count_hint(&self) -> usize {

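The struct now keeps only the prepared HeaderMap and builds a fresh reqwest::Client per embed_chunks call, plausibly so the client and its connection pool never outlive the short-lived runtime that drives them. A sketch of that per-call construction under that assumption; the header values are illustrative, not copied from meilisearch:

use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};

// Mirrors the new_client() idea: rebuild the client from stored headers
// each time, so it is always created next to the runtime that uses it.
fn build_client(api_key: &str) -> Result<reqwest::Client, reqwest::Error> {
    let mut headers = HeaderMap::new();
    headers.insert(
        AUTHORIZATION,
        HeaderValue::from_str(&format!("Bearer {api_key}")).expect("ASCII api key"),
    );
    headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
    reqwest::ClientBuilder::new().default_headers(headers).build()
}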