mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 13:06:27 +00:00 
			
		
		
		
	Add RestEmbedder
This commit is contained in:
		| @@ -83,6 +83,32 @@ pub enum EmbedErrorKind { | ||||
|     OllamaModelNotFoundError(OllamaError), | ||||
|     #[error("received unhandled HTTP status code {0} from Ollama")] | ||||
|     OllamaUnhandledStatusCode(u16), | ||||
|     #[error("error serializing template context: {0}")] | ||||
|     RestTemplateContextSerialization(liquid::Error), | ||||
|     #[error( | ||||
|         "error rendering request template: {0}. Hint: available variable in the context: {{{{input}}}}'" | ||||
|     )] | ||||
|     RestTemplateError(liquid::Error), | ||||
|     #[error("error deserialization the response body as JSON: {0}")] | ||||
|     RestResponseDeserialization(std::io::Error), | ||||
|     #[error("component `{0}` not found in path `{1}` in response: `{2}`")] | ||||
|     RestResponseMissingEmbeddings(String, String, String), | ||||
|     #[error("expected a response parseable as a vector or an array of vectors: {0}")] | ||||
|     RestResponseFormat(serde_json::Error), | ||||
|     #[error("expected a response containing {0} embeddings, got only {1}")] | ||||
|     RestResponseEmbeddingCount(usize, usize), | ||||
|     #[error("could not authenticate against embedding server: {0:?}")] | ||||
|     RestUnauthorized(Option<String>), | ||||
|     #[error("sent too many requests to embedding server: {0:?}")] | ||||
|     RestTooManyRequests(Option<String>), | ||||
|     #[error("sent a bad request to embedding server: {0:?}")] | ||||
|     RestBadRequest(Option<String>), | ||||
|     #[error("received internal error from embedding server: {0:?}")] | ||||
|     RestInternalServerError(u16, Option<String>), | ||||
|     #[error("received HTTP {0} from embedding server: {0:?}")] | ||||
|     RestOtherStatusCode(u16, Option<String>), | ||||
|     #[error("could not reach embedding server: {0}")] | ||||
|     RestNetwork(ureq::Transport), | ||||
| } | ||||
|  | ||||
| impl EmbedError { | ||||
| @@ -161,6 +187,89 @@ impl EmbedError { | ||||
|     pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestTemplateContextSerialization(error), | ||||
|             fault: FaultSource::Bug, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestResponseDeserialization(error), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>( | ||||
|         response: serde_json::Value, | ||||
|         component: &str, | ||||
|         response_field: &[S], | ||||
|     ) -> EmbedError { | ||||
|         let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect(); | ||||
|         let response_field = response_field.join("."); | ||||
|  | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestResponseMissingEmbeddings( | ||||
|                 component.to_owned(), | ||||
|                 response_field, | ||||
|                 serde_json::to_string_pretty(&response).unwrap_or_default(), | ||||
|             ), | ||||
|             fault: FaultSource::Undecided, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestTooManyRequests(error_response), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_internal_server_error( | ||||
|         code: u16, | ||||
|         error_response: Option<String>, | ||||
|     ) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestInternalServerError(code, error_response), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError { | ||||
|         Self { | ||||
|             kind: EmbedErrorKind::RestOtherStatusCode(code, error_response), | ||||
|             fault: FaultSource::Undecided, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
|   | ||||
| @@ -11,6 +11,7 @@ pub mod openai; | ||||
| pub mod settings; | ||||
|  | ||||
| pub mod ollama; | ||||
| pub mod rest; | ||||
|  | ||||
| pub use self::error::Error; | ||||
|  | ||||
|   | ||||
							
								
								
									
										185
									
								
								milli/src/vector/rest.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								milli/src/vector/rest.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,185 @@ | ||||
| use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; | ||||
|  | ||||
| use super::openai::Retry; | ||||
| use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError}; | ||||
| use crate::VectorOrArrayOfVectors; | ||||
|  | ||||
| pub struct Embedder { | ||||
|     client: ureq::Agent, | ||||
|     options: EmbedderOptions, | ||||
|     bearer: Option<String>, | ||||
|     dimensions: usize, | ||||
| } | ||||
|  | ||||
| pub struct EmbedderOptions { | ||||
|     api_key: Option<String>, | ||||
|     distribution: Option<DistributionShift>, | ||||
|     dimensions: Option<usize>, | ||||
|     url: String, | ||||
|     query: liquid::Template, | ||||
|     response_field: Vec<String>, | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||
|         let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer: {api_key}")); | ||||
|  | ||||
|         let client = ureq::agent(); | ||||
|  | ||||
|         let dimensions = if let Some(dimensions) = options.dimensions { | ||||
|             dimensions | ||||
|         } else { | ||||
|             infer_dimensions(&client, &options, bearer.as_deref())? | ||||
|         }; | ||||
|  | ||||
|         Ok(Self { client, dimensions, options, bearer }) | ||||
|     } | ||||
|  | ||||
|     pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice()) | ||||
|     } | ||||
|  | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         threads | ||||
|             .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk))) | ||||
|             .collect() | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|         10 | ||||
|     } | ||||
|  | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|         10 | ||||
|     } | ||||
|  | ||||
|     pub fn dimensions(&self) -> usize { | ||||
|         self.dimensions | ||||
|     } | ||||
|  | ||||
|     pub fn distribution(&self) -> Option<DistributionShift> { | ||||
|         self.options.distribution | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn infer_dimensions( | ||||
|     client: &ureq::Agent, | ||||
|     options: &EmbedderOptions, | ||||
|     bearer: Option<&str>, | ||||
| ) -> Result<usize, NewEmbedderError> { | ||||
|     let v = embed(client, options, bearer, ["test"].as_slice()) | ||||
|         .map_err(NewEmbedderError::could_not_determine_dimension)?; | ||||
|     // unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error | ||||
|     Ok(v.first().unwrap().dimension()) | ||||
| } | ||||
|  | ||||
| fn embed<S>( | ||||
|     client: &ureq::Agent, | ||||
|     options: &EmbedderOptions, | ||||
|     bearer: Option<&str>, | ||||
|     inputs: &[S], | ||||
| ) -> Result<Vec<Embeddings<f32>>, EmbedError> | ||||
| where | ||||
|     S: serde::Serialize, | ||||
| { | ||||
|     let request = client.post(&options.url); | ||||
|     let request = | ||||
|         if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request }; | ||||
|     let request = request.set("Content-Type", "application/json"); | ||||
|  | ||||
|     let body = options | ||||
|         .query | ||||
|         .render( | ||||
|             &liquid::to_object(&serde_json::json!({ | ||||
|                 "input": inputs, | ||||
|             })) | ||||
|             .map_err(EmbedError::rest_template_context_serialization)?, | ||||
|         ) | ||||
|         .map_err(EmbedError::rest_template_render)?; | ||||
|  | ||||
|     for attempt in 0..7 { | ||||
|         let response = request.send_string(&body); | ||||
|         let result = check_response(response); | ||||
|  | ||||
|         let retry_duration = match result { | ||||
|             Ok(response) => { | ||||
|                 return response_to_embedding(response, &options.response_field, inputs.len()) | ||||
|             } | ||||
|             Err(retry) => { | ||||
|                 tracing::warn!("Failed: {}", retry.error); | ||||
|                 retry.into_duration(attempt) | ||||
|             } | ||||
|         }?; | ||||
|  | ||||
|         let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute | ||||
|         tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis()); | ||||
|         std::thread::sleep(retry_duration); | ||||
|     } | ||||
|  | ||||
|     let response = request.send_string(&body); | ||||
|     let result = check_response(response); | ||||
|     result | ||||
|         .map_err(Retry::into_error) | ||||
|         .and_then(|response| response_to_embedding(response, &options.response_field, inputs.len())) | ||||
| } | ||||
|  | ||||
| fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> { | ||||
|     match response { | ||||
|         Ok(response) => Ok(response), | ||||
|         Err(ureq::Error::Status(code, response)) => { | ||||
|             let error_response: Option<String> = response.into_string().ok(); | ||||
|             Err(match code { | ||||
|                 401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)), | ||||
|                 429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)), | ||||
|                 400 => Retry::give_up(EmbedError::rest_bad_request(error_response)), | ||||
|                 500..=599 => { | ||||
|                     Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response)) | ||||
|                 } | ||||
|                 x => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)), | ||||
|             }) | ||||
|         } | ||||
|         Err(ureq::Error::Transport(transport)) => { | ||||
|             Err(Retry::retry_later(EmbedError::rest_network(transport))) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn response_to_embedding<S: AsRef<str>>( | ||||
|     response: ureq::Response, | ||||
|     response_field: &[S], | ||||
|     expected_count: usize, | ||||
| ) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|     let response: serde_json::Value = | ||||
|         response.into_json().map_err(EmbedError::rest_response_deserialization)?; | ||||
|  | ||||
|     let mut current_value = &response; | ||||
|     for component in response_field { | ||||
|         let component = component.as_ref(); | ||||
|         let current_value = current_value.get(component).ok_or_else(|| { | ||||
|             EmbedError::rest_response_missing_embeddings(response, component, response_field) | ||||
|         })?; | ||||
|     } | ||||
|  | ||||
|     let embeddings = current_value.to_owned(); | ||||
|  | ||||
|     let embeddings: VectorOrArrayOfVectors = | ||||
|         serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?; | ||||
|  | ||||
|     let embeddings = embeddings.into_array_of_vectors(); | ||||
|  | ||||
|     let embeddings: Vec<Embeddings<f32>> = embeddings | ||||
|         .into_iter() | ||||
|         .flatten() | ||||
|         .map(|embedding| Embeddings::from_single_embedding(embedding)) | ||||
|         .collect(); | ||||
|  | ||||
|     if embeddings.len() != expected_count { | ||||
|         return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); | ||||
|     } | ||||
|  | ||||
|     Ok(embeddings) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user