mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-30 15:36:28 +00:00 
			
		
		
		
	rest embedder: use json_template
This commit is contained in:
		| @@ -228,7 +228,9 @@ impl Embedder { | ||||
|             EmbedderOptions::UserProvided(options) => { | ||||
|                 Self::UserProvided(manual::Embedder::new(options)) | ||||
|             } | ||||
|             EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?), | ||||
|             EmbedderOptions::Rest(options) => { | ||||
|                 Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?) | ||||
|             } | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -4,6 +4,7 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use super::error::EmbedErrorKind; | ||||
| use super::json_template::ValueTemplate; | ||||
| use super::{ | ||||
|     DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, | ||||
| }; | ||||
| @@ -11,12 +12,18 @@ use crate::error::FaultSource; | ||||
| use crate::ThreadPoolNoAbort; | ||||
|  | ||||
| // retrying in case of failure | ||||
|  | ||||
| pub struct Retry { | ||||
|     pub error: EmbedError, | ||||
|     strategy: RetryStrategy, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy, PartialEq, Eq)] | ||||
| pub enum ConfigurationSource { | ||||
|     OpenAi, | ||||
|     Ollama, | ||||
|     User, | ||||
| } | ||||
|  | ||||
| pub enum RetryStrategy { | ||||
|     GiveUp, | ||||
|     Retry, | ||||
| @@ -63,10 +70,20 @@ impl Retry { | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Embedder { | ||||
|     client: ureq::Agent, | ||||
|     options: EmbedderOptions, | ||||
|     bearer: Option<String>, | ||||
|     data: EmbedderData, | ||||
|     dimensions: usize, | ||||
|     distribution: Option<DistributionShift>, | ||||
| } | ||||
|  | ||||
| /// All data needed to perform requests and parse responses | ||||
| #[derive(Debug)] | ||||
| struct EmbedderData { | ||||
|     client: ureq::Agent, | ||||
|     bearer: Option<String>, | ||||
|     url: String, | ||||
|     request: Request, | ||||
|     response: Response, | ||||
|     configuration_source: ConfigurationSource, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] | ||||
| @@ -75,29 +92,8 @@ pub struct EmbedderOptions { | ||||
|     pub distribution: Option<DistributionShift>, | ||||
|     pub dimensions: Option<usize>, | ||||
|     pub url: String, | ||||
|     pub query: serde_json::Value, | ||||
|     pub input_field: Vec<String>, | ||||
|     // path to the array of embeddings | ||||
|     pub path_to_embeddings: Vec<String>, | ||||
|     // shape of a single embedding | ||||
|     pub embedding_object: Vec<String>, | ||||
|     pub input_type: InputType, | ||||
| } | ||||
|  | ||||
| impl Default for EmbedderOptions { | ||||
|     fn default() -> Self { | ||||
|         Self { | ||||
|             url: Default::default(), | ||||
|             query: Default::default(), | ||||
|             input_field: vec!["input".into()], | ||||
|             path_to_embeddings: vec!["data".into()], | ||||
|             embedding_object: vec!["embedding".into()], | ||||
|             input_type: InputType::Text, | ||||
|             api_key: None, | ||||
|             distribution: None, | ||||
|             dimensions: None, | ||||
|         } | ||||
|     } | ||||
|     pub request: serde_json::Value, | ||||
|     pub response: serde_json::Value, | ||||
| } | ||||
|  | ||||
| impl std::hash::Hash for EmbedderOptions { | ||||
| @@ -106,26 +102,25 @@ impl std::hash::Hash for EmbedderOptions { | ||||
|         self.distribution.hash(state); | ||||
|         self.dimensions.hash(state); | ||||
|         self.url.hash(state); | ||||
|         // skip hashing the query | ||||
|         // skip hashing the request and response | ||||
|         // collisions in regular usage should be minimal, | ||||
|         // and the list is limited to 256 values anyway | ||||
|         self.input_field.hash(state); | ||||
|         self.path_to_embeddings.hash(state); | ||||
|         self.embedding_object.hash(state); | ||||
|         self.input_type.hash(state); | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)] | ||||
| #[serde(rename_all = "camelCase")] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| pub enum InputType { | ||||
| enum InputType { | ||||
|     Text, | ||||
|     TextArray, | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||
|     pub fn new( | ||||
|         options: EmbedderOptions, | ||||
|         configuration_source: ConfigurationSource, | ||||
|     ) -> Result<Self, NewEmbedderError> { | ||||
|         let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}")); | ||||
|  | ||||
|         let client = ureq::AgentBuilder::new() | ||||
| @@ -133,28 +128,40 @@ impl Embedder { | ||||
|             .max_idle_connections_per_host(REQUEST_PARALLELISM * 2) | ||||
|             .build(); | ||||
|  | ||||
|         let request = Request::new(options.request)?; | ||||
|         let response = Response::new(options.response, &request)?; | ||||
|  | ||||
|         let data = EmbedderData { | ||||
|             client, | ||||
|             bearer, | ||||
|             url: options.url, | ||||
|             request, | ||||
|             response, | ||||
|             configuration_source, | ||||
|         }; | ||||
|  | ||||
|         let dimensions = if let Some(dimensions) = options.dimensions { | ||||
|             dimensions | ||||
|         } else { | ||||
|             infer_dimensions(&client, &options, bearer.as_deref())? | ||||
|             infer_dimensions(&data)? | ||||
|         }; | ||||
|  | ||||
|         Ok(Self { client, dimensions, options, bearer }) | ||||
|         Ok(Self { data, dimensions, distribution: options.distribution }) | ||||
|     } | ||||
|  | ||||
|     pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len()) | ||||
|         embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions)) | ||||
|     } | ||||
|  | ||||
|     pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError> | ||||
|     where | ||||
|         S: AsRef<str> + Serialize, | ||||
|     { | ||||
|         embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len()) | ||||
|         embed(&self.data, texts, texts.len(), Some(self.dimensions)) | ||||
|     } | ||||
|  | ||||
|     pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> { | ||||
|         let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?; | ||||
|         let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?; | ||||
|         // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error | ||||
|         Ok(embeddings.pop().unwrap()) | ||||
|     } | ||||
| @@ -179,7 +186,7 @@ impl Embedder { | ||||
|     } | ||||
|  | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|         match self.options.input_type { | ||||
|         match self.data.request.input_type() { | ||||
|             InputType::Text => 1, | ||||
|             InputType::TextArray => 10, | ||||
|         } | ||||
| @@ -190,87 +197,44 @@ impl Embedder { | ||||
|     } | ||||
|  | ||||
|     pub fn distribution(&self) -> Option<DistributionShift> { | ||||
|         self.options.distribution | ||||
|         self.distribution | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn infer_dimensions( | ||||
|     client: &ureq::Agent, | ||||
|     options: &EmbedderOptions, | ||||
|     bearer: Option<&str>, | ||||
| ) -> Result<usize, NewEmbedderError> { | ||||
|     let v = embed(client, options, bearer, ["test"].as_slice(), 1) | ||||
| fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> { | ||||
|     let v = embed(data, ["test"].as_slice(), 1, None) | ||||
|         .map_err(NewEmbedderError::could_not_determine_dimension)?; | ||||
|     // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error | ||||
|     Ok(v.first().unwrap().dimension()) | ||||
| } | ||||
|  | ||||
| fn embed<S>( | ||||
|     client: &ureq::Agent, | ||||
|     options: &EmbedderOptions, | ||||
|     bearer: Option<&str>, | ||||
|     data: &EmbedderData, | ||||
|     inputs: &[S], | ||||
|     expected_count: usize, | ||||
|     expected_dimension: Option<usize>, | ||||
| ) -> Result<Vec<Embeddings<f32>>, EmbedError> | ||||
| where | ||||
|     S: Serialize, | ||||
| { | ||||
|     let request = client.post(&options.url); | ||||
|     let request = | ||||
|         if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request }; | ||||
|     let request = data.client.post(&data.url); | ||||
|     let request = if let Some(bearer) = &data.bearer { | ||||
|         request.set("Authorization", bearer) | ||||
|     } else { | ||||
|         request | ||||
|     }; | ||||
|     let request = request.set("Content-Type", "application/json"); | ||||
|  | ||||
|     let input_value = match options.input_type { | ||||
|         InputType::Text => serde_json::json!(inputs.first()), | ||||
|         InputType::TextArray => serde_json::json!(inputs), | ||||
|     }; | ||||
|  | ||||
|     let body = match options.input_field.as_slice() { | ||||
|         [] => { | ||||
|             // inject input in body | ||||
|             input_value | ||||
|         } | ||||
|         [input] => { | ||||
|             let mut body = options.query.clone(); | ||||
|  | ||||
|             body.as_object_mut() | ||||
|                 .ok_or_else(|| { | ||||
|                     EmbedError::rest_not_an_object( | ||||
|                         options.query.clone(), | ||||
|                         options.input_field.clone(), | ||||
|                     ) | ||||
|                 })? | ||||
|                 .insert(input.clone(), input_value); | ||||
|             body | ||||
|         } | ||||
|         [path @ .., input] => { | ||||
|             let mut body = options.query.clone(); | ||||
|  | ||||
|             let mut current_value = &mut body; | ||||
|             for component in path { | ||||
|                 current_value = current_value | ||||
|                     .as_object_mut() | ||||
|                     .ok_or_else(|| { | ||||
|                         EmbedError::rest_not_an_object( | ||||
|                             options.query.clone(), | ||||
|                             options.input_field.clone(), | ||||
|                         ) | ||||
|                     })? | ||||
|                     .entry(component.clone()) | ||||
|                     .or_insert(serde_json::json!({})); | ||||
|             } | ||||
|  | ||||
|             current_value.as_object_mut().unwrap().insert(input.clone(), input_value); | ||||
|             body | ||||
|         } | ||||
|     }; | ||||
|     let body = data.request.inject_texts(inputs); | ||||
|  | ||||
|     for attempt in 0..10 { | ||||
|         let response = request.clone().send_json(&body); | ||||
|         let result = check_response(response); | ||||
|         let result = check_response(response, data.configuration_source); | ||||
|  | ||||
|         let retry_duration = match result { | ||||
|             Ok(response) => return response_to_embedding(response, options, expected_count), | ||||
|             Ok(response) => { | ||||
|                 return response_to_embedding(response, data, expected_count, expected_dimension) | ||||
|             } | ||||
|             Err(retry) => { | ||||
|                 tracing::warn!("Failed: {}", retry.error); | ||||
|                 retry.into_duration(attempt) | ||||
| @@ -288,13 +252,16 @@ where | ||||
|     } | ||||
|  | ||||
|     let response = request.send_json(&body); | ||||
|     let result = check_response(response); | ||||
|     result | ||||
|         .map_err(Retry::into_error) | ||||
|         .and_then(|response| response_to_embedding(response, options, expected_count)) | ||||
|     let result = check_response(response, data.configuration_source); | ||||
|     result.map_err(Retry::into_error).and_then(|response| { | ||||
|         response_to_embedding(response, data, expected_count, expected_dimension) | ||||
|     }) | ||||
| } | ||||
|  | ||||
| fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> { | ||||
| fn check_response( | ||||
|     response: Result<ureq::Response, ureq::Error>, | ||||
|     configuration_source: ConfigurationSource, | ||||
| ) -> Result<ureq::Response, Retry> { | ||||
|     match response { | ||||
|         Ok(response) => Ok(response), | ||||
|         Err(ureq::Error::Status(code, response)) => { | ||||
| @@ -302,7 +269,10 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq: | ||||
|             Err(match code { | ||||
|                 401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)), | ||||
|                 429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)), | ||||
|                 400 => Retry::give_up(EmbedError::rest_bad_request(error_response)), | ||||
|                 400 => Retry::give_up(EmbedError::rest_bad_request( | ||||
|                     error_response, | ||||
|                     configuration_source, | ||||
|                 )), | ||||
|                 500..=599 => { | ||||
|                     Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response)) | ||||
|                 } | ||||
| @@ -320,68 +290,111 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq: | ||||
|  | ||||
| fn response_to_embedding( | ||||
|     response: ureq::Response, | ||||
|     options: &EmbedderOptions, | ||||
|     data: &EmbedderData, | ||||
|     expected_count: usize, | ||||
|     expected_dimensions: Option<usize>, | ||||
| ) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|     let response: serde_json::Value = | ||||
|         response.into_json().map_err(EmbedError::rest_response_deserialization)?; | ||||
|  | ||||
|     let mut current_value = &response; | ||||
|     for component in &options.path_to_embeddings { | ||||
|         let component = component.as_ref(); | ||||
|         current_value = current_value.get(component).ok_or_else(|| { | ||||
|             EmbedError::rest_response_missing_embeddings( | ||||
|                 response.clone(), | ||||
|                 component, | ||||
|                 &options.path_to_embeddings, | ||||
|             ) | ||||
|         })?; | ||||
|     } | ||||
|  | ||||
|     let embeddings = match options.input_type { | ||||
|         InputType::Text => { | ||||
|             for component in &options.embedding_object { | ||||
|                 current_value = current_value.get(component).ok_or_else(|| { | ||||
|                     EmbedError::rest_response_missing_embeddings( | ||||
|                         response.clone(), | ||||
|                         component, | ||||
|                         &options.embedding_object, | ||||
|                     ) | ||||
|                 })?; | ||||
|             } | ||||
|             let embeddings = current_value.to_owned(); | ||||
|             let embeddings: Embedding = | ||||
|                 serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?; | ||||
|  | ||||
|             vec![Embeddings::from_single_embedding(embeddings)] | ||||
|         } | ||||
|         InputType::TextArray => { | ||||
|             let empty = vec![]; | ||||
|             let values = current_value.as_array().unwrap_or(&empty); | ||||
|             let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count); | ||||
|             for value in values { | ||||
|                 let mut current_value = value; | ||||
|                 for component in &options.embedding_object { | ||||
|                     current_value = current_value.get(component).ok_or_else(|| { | ||||
|                         EmbedError::rest_response_missing_embeddings( | ||||
|                             response.clone(), | ||||
|                             component, | ||||
|                             &options.embedding_object, | ||||
|                         ) | ||||
|                     })?; | ||||
|                 } | ||||
|                 let embedding = current_value.to_owned(); | ||||
|                 let embedding: Embedding = | ||||
|                     serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?; | ||||
|                 embeddings.push(Embeddings::from_single_embedding(embedding)); | ||||
|             } | ||||
|             embeddings | ||||
|         } | ||||
|     }; | ||||
|     let embeddings = data.response.extract_embeddings(response)?; | ||||
|  | ||||
|     if embeddings.len() != expected_count { | ||||
|         return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); | ||||
|     } | ||||
|  | ||||
|     if let Some(dimensions) = expected_dimensions { | ||||
|         for embedding in &embeddings { | ||||
|             if embedding.dimension() != dimensions { | ||||
|                 return Err(EmbedError::rest_unexpected_dimension( | ||||
|                     dimensions, | ||||
|                     embedding.dimension(), | ||||
|                 )); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(embeddings) | ||||
| } | ||||
|  | ||||
| pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}"; | ||||
| pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}"; | ||||
| pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}"; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Request { | ||||
|     template: ValueTemplate, | ||||
| } | ||||
|  | ||||
| impl Request { | ||||
|     pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> { | ||||
|         let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) { | ||||
|             Ok(template) => template, | ||||
|             Err(error) => { | ||||
|                 let message = | ||||
|                     error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER); | ||||
|                 return Err(NewEmbedderError::rest_could_not_parse_template(message)); | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         Ok(Self { template }) | ||||
|     } | ||||
|  | ||||
|     fn input_type(&self) -> InputType { | ||||
|         if self.template.has_array_value() { | ||||
|             InputType::TextArray | ||||
|         } else { | ||||
|             InputType::Text | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn inject_texts<S: Serialize>( | ||||
|         &self, | ||||
|         texts: impl IntoIterator<Item = S>, | ||||
|     ) -> serde_json::Value { | ||||
|         self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap() | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Response { | ||||
|     template: ValueTemplate, | ||||
| } | ||||
|  | ||||
| impl Response { | ||||
|     pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> { | ||||
|         let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) | ||||
|         { | ||||
|             Ok(template) => template, | ||||
|             Err(error) => { | ||||
|                 let message = | ||||
|                     error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER); | ||||
|                 return Err(NewEmbedderError::rest_could_not_parse_template(message)); | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         match (template.has_array_value(), request.template.has_array_value()) { | ||||
|             (true, true) | (false, false) => Ok(Self {template}), | ||||
|             (true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())), | ||||
|             (false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn extract_embeddings( | ||||
|         &self, | ||||
|         response: serde_json::Value, | ||||
|     ) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         let extracted_values: Vec<Embedding> = match self.template.extract(response) { | ||||
|             Ok(extracted_values) => extracted_values, | ||||
|             Err(error) => { | ||||
|                 let error_message = | ||||
|                     error.error_message("response", "{{embedding}}", "an array of numbers"); | ||||
|                 return Err(EmbedError::rest_extraction_error(error_message)); | ||||
|             } | ||||
|         }; | ||||
|         let embeddings: Vec<Embeddings<f32>> = | ||||
|             extracted_values.into_iter().map(Embeddings::from_single_embedding).collect(); | ||||
|  | ||||
|         Ok(embeddings) | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user