mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-24 20:46:27 +00:00 
			
		
		
		
	Expose REST embedder to the API
This commit is contained in:
		| @@ -2646,6 +2646,12 @@ mod tests { | ||||
|                         api_key: Setting::NotSet, | ||||
|                         dimensions: Setting::Set(3), | ||||
|                         document_template: Setting::NotSet, | ||||
|                         url: Setting::NotSet, | ||||
|                         query: Setting::NotSet, | ||||
|                         input_field: Setting::NotSet, | ||||
|                         path_to_embeddings: Setting::NotSet, | ||||
|                         embedding_object: Setting::NotSet, | ||||
|                         input_type: Setting::NotSet, | ||||
|                     }), | ||||
|                 ); | ||||
|                 settings.set_embedder_settings(embedders); | ||||
|   | ||||
| @@ -1140,6 +1140,12 @@ fn validate_prompt( | ||||
|             api_key, | ||||
|             dimensions, | ||||
|             document_template: Setting::Set(template), | ||||
|             url, | ||||
|             query, | ||||
|             input_field, | ||||
|             path_to_embeddings, | ||||
|             embedding_object, | ||||
|             input_type, | ||||
|         }) => { | ||||
|             // validate | ||||
|             let template = crate::prompt::Prompt::new(template) | ||||
| @@ -1153,6 +1159,12 @@ fn validate_prompt( | ||||
|                 api_key, | ||||
|                 dimensions, | ||||
|                 document_template: Setting::Set(template), | ||||
|                 url, | ||||
|                 query, | ||||
|                 input_field, | ||||
|                 path_to_embeddings, | ||||
|                 embedding_object, | ||||
|                 input_type, | ||||
|             })) | ||||
|         } | ||||
|         new => Ok(new), | ||||
| @@ -1165,8 +1177,20 @@ pub fn validate_embedding_settings( | ||||
| ) -> Result<Setting<EmbeddingSettings>> { | ||||
|     let settings = validate_prompt(name, settings)?; | ||||
|     let Setting::Set(settings) = settings else { return Ok(settings) }; | ||||
|     let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = | ||||
|         settings; | ||||
|     let EmbeddingSettings { | ||||
|         source, | ||||
|         model, | ||||
|         revision, | ||||
|         api_key, | ||||
|         dimensions, | ||||
|         document_template, | ||||
|         url, | ||||
|         query, | ||||
|         input_field, | ||||
|         path_to_embeddings, | ||||
|         embedding_object, | ||||
|         input_type, | ||||
|     } = settings; | ||||
|  | ||||
|     if let Some(0) = dimensions.set() { | ||||
|         return Err(crate::error::UserError::InvalidSettingsDimensions { | ||||
| @@ -1183,11 +1207,25 @@ pub fn validate_embedding_settings( | ||||
|             api_key, | ||||
|             dimensions, | ||||
|             document_template, | ||||
|             url, | ||||
|             query, | ||||
|             input_field, | ||||
|             path_to_embeddings, | ||||
|             embedding_object, | ||||
|             input_type, | ||||
|         })); | ||||
|     }; | ||||
|     match inferred_source { | ||||
|         EmbedderSource::OpenAi => { | ||||
|             check_unset(&revision, "revision", inferred_source, name)?; | ||||
|  | ||||
|             check_unset(&url, "url", inferred_source, name)?; | ||||
|             check_unset(&query, "query", inferred_source, name)?; | ||||
|             check_unset(&input_field, "inputField", inferred_source, name)?; | ||||
|             check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; | ||||
|             check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; | ||||
|             check_unset(&input_type, "inputType", inferred_source, name)?; | ||||
|  | ||||
|             if let Setting::Set(model) = &model { | ||||
|                 let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str()) | ||||
|                     .ok_or(crate::error::UserError::InvalidOpenAiModel { | ||||
| @@ -1224,10 +1262,24 @@ pub fn validate_embedding_settings( | ||||
|             check_set(&model, "model", inferred_source, name)?; | ||||
|             check_unset(&api_key, "apiKey", inferred_source, name)?; | ||||
|             check_unset(&revision, "revision", inferred_source, name)?; | ||||
|  | ||||
|             check_unset(&url, "url", inferred_source, name)?; | ||||
|             check_unset(&query, "query", inferred_source, name)?; | ||||
|             check_unset(&input_field, "inputField", inferred_source, name)?; | ||||
|             check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; | ||||
|             check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; | ||||
|             check_unset(&input_type, "inputType", inferred_source, name)?; | ||||
|         } | ||||
|         EmbedderSource::HuggingFace => { | ||||
|             check_unset(&api_key, "apiKey", inferred_source, name)?; | ||||
|             check_unset(&dimensions, "dimensions", inferred_source, name)?; | ||||
|  | ||||
|             check_unset(&url, "url", inferred_source, name)?; | ||||
|             check_unset(&query, "query", inferred_source, name)?; | ||||
|             check_unset(&input_field, "inputField", inferred_source, name)?; | ||||
|             check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; | ||||
|             check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; | ||||
|             check_unset(&input_type, "inputType", inferred_source, name)?; | ||||
|         } | ||||
|         EmbedderSource::UserProvided => { | ||||
|             check_unset(&model, "model", inferred_source, name)?; | ||||
| @@ -1235,6 +1287,18 @@ pub fn validate_embedding_settings( | ||||
|             check_unset(&api_key, "apiKey", inferred_source, name)?; | ||||
|             check_unset(&document_template, "documentTemplate", inferred_source, name)?; | ||||
|             check_set(&dimensions, "dimensions", inferred_source, name)?; | ||||
|  | ||||
|             check_unset(&url, "url", inferred_source, name)?; | ||||
|             check_unset(&query, "query", inferred_source, name)?; | ||||
|             check_unset(&input_field, "inputField", inferred_source, name)?; | ||||
|             check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; | ||||
|             check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; | ||||
|             check_unset(&input_type, "inputType", inferred_source, name)?; | ||||
|         } | ||||
|         EmbedderSource::Rest => { | ||||
|             check_unset(&model, "model", inferred_source, name)?; | ||||
|             check_unset(&revision, "revision", inferred_source, name)?; | ||||
|             check_set(&url, "url", inferred_source, name)?; | ||||
|         } | ||||
|     } | ||||
|     Ok(Setting::Set(EmbeddingSettings { | ||||
| @@ -1244,6 +1308,12 @@ pub fn validate_embedding_settings( | ||||
|         api_key, | ||||
|         dimensions, | ||||
|         document_template, | ||||
|         url, | ||||
|         query, | ||||
|         input_field, | ||||
|         path_to_embeddings, | ||||
|         embedding_object, | ||||
|         input_type, | ||||
|     })) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -194,7 +194,10 @@ impl Embedder { | ||||
|  | ||||
|     pub fn distribution(&self) -> Option<DistributionShift> { | ||||
|         if self.options.model == "BAAI/bge-base-en-v1.5" { | ||||
|             Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 }) | ||||
|             Some(DistributionShift { | ||||
|                 current_mean: ordered_float::OrderedFloat(0.85), | ||||
|                 current_sigma: ordered_float::OrderedFloat(0.1), | ||||
|             }) | ||||
|         } else { | ||||
|             None | ||||
|         } | ||||
|   | ||||
| @@ -1,6 +1,9 @@ | ||||
| use std::collections::HashMap; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use ordered_float::OrderedFloat; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use self::error::{EmbedError, NewEmbedderError}; | ||||
| use crate::prompt::{Prompt, PromptData}; | ||||
|  | ||||
| @@ -104,7 +107,10 @@ pub enum Embedder { | ||||
|     OpenAi(openai::Embedder), | ||||
|     /// An embedder based on the user providing the embeddings in the documents and queries. | ||||
|     UserProvided(manual::Embedder), | ||||
|     /// An embedder based on making embedding queries against an <https://ollama.com> embedding server. | ||||
|     Ollama(ollama::Embedder), | ||||
|     /// An embedder based on making embedding queries against a generic JSON/REST embedding server. | ||||
|     Rest(rest::Embedder), | ||||
| } | ||||
|  | ||||
| /// Configuration for an embedder. | ||||
| @@ -175,6 +181,7 @@ pub enum EmbedderOptions { | ||||
|     OpenAi(openai::EmbedderOptions), | ||||
|     Ollama(ollama::EmbedderOptions), | ||||
|     UserProvided(manual::EmbedderOptions), | ||||
|     Rest(rest::EmbedderOptions), | ||||
| } | ||||
|  | ||||
| impl Default for EmbedderOptions { | ||||
| @@ -209,6 +216,7 @@ impl Embedder { | ||||
|             EmbedderOptions::UserProvided(options) => { | ||||
|                 Self::UserProvided(manual::Embedder::new(options)) | ||||
|             } | ||||
|             EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?), | ||||
|         }) | ||||
|     } | ||||
|  | ||||
| @@ -224,6 +232,7 @@ impl Embedder { | ||||
|             Embedder::OpenAi(embedder) => embedder.embed(texts), | ||||
|             Embedder::Ollama(embedder) => embedder.embed(texts), | ||||
|             Embedder::UserProvided(embedder) => embedder.embed(texts), | ||||
|             Embedder::Rest(embedder) => embedder.embed(texts), | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -240,6 +249,7 @@ impl Embedder { | ||||
|             Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads), | ||||
|             Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads), | ||||
|             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), | ||||
|             Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads), | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -250,6 +260,7 @@ impl Embedder { | ||||
|             Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), | ||||
|             Embedder::Ollama(embedder) => embedder.chunk_count_hint(), | ||||
|             Embedder::UserProvided(_) => 1, | ||||
|             Embedder::Rest(embedder) => embedder.chunk_count_hint(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -260,6 +271,7 @@ impl Embedder { | ||||
|             Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), | ||||
|             Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(), | ||||
|             Embedder::UserProvided(_) => 1, | ||||
|             Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -270,6 +282,7 @@ impl Embedder { | ||||
|             Embedder::OpenAi(embedder) => embedder.dimensions(), | ||||
|             Embedder::Ollama(embedder) => embedder.dimensions(), | ||||
|             Embedder::UserProvided(embedder) => embedder.dimensions(), | ||||
|             Embedder::Rest(embedder) => embedder.dimensions(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -280,6 +293,7 @@ impl Embedder { | ||||
|             Embedder::OpenAi(embedder) => embedder.distribution(), | ||||
|             Embedder::Ollama(embedder) => embedder.distribution(), | ||||
|             Embedder::UserProvided(_embedder) => None, | ||||
|             Embedder::Rest(embedder) => embedder.distribution(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -288,17 +302,47 @@ impl Embedder { | ||||
| /// | ||||
| /// The intended use is to make the similarity score more comparable to the regular ranking score. | ||||
| /// This allows to correct effects where results are too "packed" around a certain value. | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)] | ||||
| #[serde(from = "DistributionShiftSerializable")] | ||||
| #[serde(into = "DistributionShiftSerializable")] | ||||
| pub struct DistributionShift { | ||||
|     /// Value where the results are "packed". | ||||
|     /// | ||||
|     /// Similarity scores are translated so that they are packed around 0.5 instead | ||||
|     pub current_mean: f32, | ||||
|     pub current_mean: OrderedFloat<f32>, | ||||
|  | ||||
|     /// standard deviation of a similarity score. | ||||
|     /// | ||||
|     /// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed. | ||||
|     pub current_sigma: f32, | ||||
|     pub current_sigma: OrderedFloat<f32>, | ||||
| } | ||||
|  | ||||
| #[derive(Serialize, Deserialize)] | ||||
| struct DistributionShiftSerializable { | ||||
|     current_mean: f32, | ||||
|     current_sigma: f32, | ||||
| } | ||||
|  | ||||
| impl From<DistributionShift> for DistributionShiftSerializable { | ||||
|     fn from( | ||||
|         DistributionShift { | ||||
|             current_mean: OrderedFloat(current_mean), | ||||
|             current_sigma: OrderedFloat(current_sigma), | ||||
|         }: DistributionShift, | ||||
|     ) -> Self { | ||||
|         Self { current_mean, current_sigma } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<DistributionShiftSerializable> for DistributionShift { | ||||
|     fn from( | ||||
|         DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable, | ||||
|     ) -> Self { | ||||
|         Self { | ||||
|             current_mean: OrderedFloat(current_mean), | ||||
|             current_sigma: OrderedFloat(current_sigma), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl DistributionShift { | ||||
| @@ -307,11 +351,13 @@ impl DistributionShift { | ||||
|         if sigma <= 0.0 { | ||||
|             None | ||||
|         } else { | ||||
|             Some(Self { current_mean: mean, current_sigma: sigma }) | ||||
|             Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn shift(&self, score: f32) -> f32 { | ||||
|         let current_mean = self.current_mean.0; | ||||
|         let current_sigma = self.current_sigma.0; | ||||
|         // <https://math.stackexchange.com/a/2894689> | ||||
|         // We're somewhat abusively mapping the distribution of distances to a gaussian. | ||||
|         // The parameters we're given is the mean and sigma of the native result distribution. | ||||
| @@ -321,9 +367,9 @@ impl DistributionShift { | ||||
|         let target_sigma = 0.4; | ||||
|  | ||||
|         // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive. | ||||
|         let factor = target_sigma / self.current_sigma; | ||||
|         let factor = target_sigma / current_sigma; | ||||
|         // a*mu1 + b = mu2 => b = mu2 - a*mu1 | ||||
|         let offset = target_mean - (factor * self.current_mean); | ||||
|         let offset = target_mean - (factor * current_mean); | ||||
|  | ||||
|         let mut score = factor * score + offset; | ||||
|  | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| use ordered_float::OrderedFloat; | ||||
| use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; | ||||
|  | ||||
| use super::error::{EmbedError, NewEmbedderError}; | ||||
| @@ -110,15 +111,18 @@ impl EmbeddingModel { | ||||
|  | ||||
|     fn distribution(&self) -> Option<DistributionShift> { | ||||
|         match self { | ||||
|             EmbeddingModel::TextEmbeddingAda002 => { | ||||
|                 Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) | ||||
|             } | ||||
|             EmbeddingModel::TextEmbedding3Large => { | ||||
|                 Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 }) | ||||
|             } | ||||
|             EmbeddingModel::TextEmbedding3Small => { | ||||
|                 Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 }) | ||||
|             } | ||||
|             EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift { | ||||
|                 current_mean: OrderedFloat(0.90), | ||||
|                 current_sigma: OrderedFloat(0.08), | ||||
|             }), | ||||
|             EmbeddingModel::TextEmbedding3Large => Some(DistributionShift { | ||||
|                 current_mean: OrderedFloat(0.70), | ||||
|                 current_sigma: OrderedFloat(0.1), | ||||
|             }), | ||||
|             EmbeddingModel::TextEmbedding3Small => Some(DistributionShift { | ||||
|                 current_mean: OrderedFloat(0.75), | ||||
|                 current_sigma: OrderedFloat(0.1), | ||||
|             }), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| use deserr::Deserr; | ||||
| use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; | ||||
| use serde::Serialize; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use super::{ | ||||
|     DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, | ||||
| @@ -64,7 +65,7 @@ pub struct Embedder { | ||||
|     dimensions: usize, | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] | ||||
| pub struct EmbedderOptions { | ||||
|     pub api_key: Option<String>, | ||||
|     pub distribution: Option<DistributionShift>, | ||||
| @@ -79,7 +80,41 @@ pub struct EmbedderOptions { | ||||
|     pub input_type: InputType, | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| impl Default for EmbedderOptions { | ||||
|     fn default() -> Self { | ||||
|         Self { | ||||
|             url: Default::default(), | ||||
|             query: Default::default(), | ||||
|             input_field: vec!["input".into()], | ||||
|             path_to_embeddings: vec!["data".into()], | ||||
|             embedding_object: vec!["embedding".into()], | ||||
|             input_type: InputType::Text, | ||||
|             api_key: None, | ||||
|             distribution: None, | ||||
|             dimensions: None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::hash::Hash for EmbedderOptions { | ||||
|     fn hash<H: std::hash::Hasher>(&self, state: &mut H) { | ||||
|         self.api_key.hash(state); | ||||
|         self.distribution.hash(state); | ||||
|         self.dimensions.hash(state); | ||||
|         self.url.hash(state); | ||||
|         // skip hashing the query | ||||
|         // collisions in regular usage should be minimal, | ||||
|         // and the list is limited to 256 values anyway | ||||
|         self.input_field.hash(state); | ||||
|         self.path_to_embeddings.hash(state); | ||||
|         self.embedding_object.hash(state); | ||||
|         self.input_type.hash(state); | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)] | ||||
| #[serde(rename_all = "camelCase")] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| pub enum InputType { | ||||
|     Text, | ||||
|     TextArray, | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| use deserr::Deserr; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use super::rest::InputType; | ||||
| use super::{ollama, openai}; | ||||
| use crate::prompt::PromptData; | ||||
| use crate::update::Setting; | ||||
| @@ -29,6 +30,24 @@ pub struct EmbeddingSettings { | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub document_template: Setting<String>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub url: Setting<String>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub query: Setting<serde_json::Value>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub input_field: Setting<Vec<String>>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub path_to_embeddings: Setting<Vec<String>>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub embedding_object: Setting<Vec<String>>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub input_type: Setting<InputType>, | ||||
| } | ||||
|  | ||||
| pub fn check_unset<T>( | ||||
| @@ -75,20 +94,42 @@ impl EmbeddingSettings { | ||||
|     pub const DIMENSIONS: &'static str = "dimensions"; | ||||
|     pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; | ||||
|  | ||||
|     pub const URL: &'static str = "url"; | ||||
|     pub const QUERY: &'static str = "query"; | ||||
|     pub const INPUT_FIELD: &'static str = "inputField"; | ||||
|     pub const PATH_TO_EMBEDDINGS: &'static str = "pathToEmbeddings"; | ||||
|     pub const EMBEDDING_OBJECT: &'static str = "embeddingObject"; | ||||
|     pub const INPUT_TYPE: &'static str = "inputType"; | ||||
|  | ||||
|     pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { | ||||
|         match field { | ||||
|             Self::SOURCE => { | ||||
|                 &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided] | ||||
|             } | ||||
|             Self::SOURCE => &[ | ||||
|                 EmbedderSource::HuggingFace, | ||||
|                 EmbedderSource::OpenAi, | ||||
|                 EmbedderSource::UserProvided, | ||||
|                 EmbedderSource::Rest, | ||||
|                 EmbedderSource::Ollama, | ||||
|             ], | ||||
|             Self::MODEL => { | ||||
|                 &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] | ||||
|             } | ||||
|             Self::REVISION => &[EmbedderSource::HuggingFace], | ||||
|             Self::API_KEY => &[EmbedderSource::OpenAi], | ||||
|             Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided], | ||||
|             Self::DOCUMENT_TEMPLATE => { | ||||
|                 &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] | ||||
|             Self::API_KEY => &[EmbedderSource::OpenAi, EmbedderSource::Rest], | ||||
|             Self::DIMENSIONS => { | ||||
|                 &[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Rest] | ||||
|             } | ||||
|             Self::DOCUMENT_TEMPLATE => &[ | ||||
|                 EmbedderSource::HuggingFace, | ||||
|                 EmbedderSource::OpenAi, | ||||
|                 EmbedderSource::Ollama, | ||||
|                 EmbedderSource::Rest, | ||||
|             ], | ||||
|             Self::URL => &[EmbedderSource::Rest], | ||||
|             Self::QUERY => &[EmbedderSource::Rest], | ||||
|             Self::INPUT_FIELD => &[EmbedderSource::Rest], | ||||
|             Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], | ||||
|             Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest], | ||||
|             Self::INPUT_TYPE => &[EmbedderSource::Rest], | ||||
|             _other => unreachable!("unknown field"), | ||||
|         } | ||||
|     } | ||||
| @@ -107,6 +148,18 @@ impl EmbeddingSettings { | ||||
|             } | ||||
|             EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE], | ||||
|             EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], | ||||
|             EmbedderSource::Rest => &[ | ||||
|                 Self::SOURCE, | ||||
|                 Self::API_KEY, | ||||
|                 Self::DIMENSIONS, | ||||
|                 Self::DOCUMENT_TEMPLATE, | ||||
|                 Self::URL, | ||||
|                 Self::QUERY, | ||||
|                 Self::INPUT_FIELD, | ||||
|                 Self::PATH_TO_EMBEDDINGS, | ||||
|                 Self::EMBEDDING_OBJECT, | ||||
|                 Self::INPUT_TYPE, | ||||
|             ], | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -141,6 +194,7 @@ pub enum EmbedderSource { | ||||
|     HuggingFace, | ||||
|     Ollama, | ||||
|     UserProvided, | ||||
|     Rest, | ||||
| } | ||||
|  | ||||
| impl std::fmt::Display for EmbedderSource { | ||||
| @@ -150,6 +204,7 @@ impl std::fmt::Display for EmbedderSource { | ||||
|             EmbedderSource::HuggingFace => "huggingFace", | ||||
|             EmbedderSource::UserProvided => "userProvided", | ||||
|             EmbedderSource::Ollama => "ollama", | ||||
|             EmbedderSource::Rest => "rest", | ||||
|         }; | ||||
|         f.write_str(s) | ||||
|     } | ||||
| @@ -157,8 +212,20 @@ impl std::fmt::Display for EmbedderSource { | ||||
|  | ||||
| impl EmbeddingSettings { | ||||
|     pub fn apply(&mut self, new: Self) { | ||||
|         let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = | ||||
|             new; | ||||
|         let EmbeddingSettings { | ||||
|             source, | ||||
|             model, | ||||
|             revision, | ||||
|             api_key, | ||||
|             dimensions, | ||||
|             document_template, | ||||
|             url, | ||||
|             query, | ||||
|             input_field, | ||||
|             path_to_embeddings, | ||||
|             embedding_object, | ||||
|             input_type, | ||||
|         } = new; | ||||
|         let old_source = self.source; | ||||
|         self.source.apply(source); | ||||
|         // Reinitialize the whole setting object on a source change | ||||
| @@ -170,6 +237,12 @@ impl EmbeddingSettings { | ||||
|                 api_key, | ||||
|                 dimensions, | ||||
|                 document_template, | ||||
|                 url, | ||||
|                 query, | ||||
|                 input_field, | ||||
|                 path_to_embeddings, | ||||
|                 embedding_object, | ||||
|                 input_type, | ||||
|             }; | ||||
|             return; | ||||
|         } | ||||
| @@ -179,6 +252,13 @@ impl EmbeddingSettings { | ||||
|         self.api_key.apply(api_key); | ||||
|         self.dimensions.apply(dimensions); | ||||
|         self.document_template.apply(document_template); | ||||
|  | ||||
|         self.url.apply(url); | ||||
|         self.query.apply(query); | ||||
|         self.input_field.apply(input_field); | ||||
|         self.path_to_embeddings.apply(path_to_embeddings); | ||||
|         self.embedding_object.apply(embedding_object); | ||||
|         self.input_type.apply(input_type); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -193,6 +273,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 api_key: Setting::NotSet, | ||||
|                 dimensions: Setting::NotSet, | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
|                 url: Setting::NotSet, | ||||
|                 query: Setting::NotSet, | ||||
|                 input_field: Setting::NotSet, | ||||
|                 path_to_embeddings: Setting::NotSet, | ||||
|                 embedding_object: Setting::NotSet, | ||||
|                 input_type: Setting::NotSet, | ||||
|             }, | ||||
|             super::EmbedderOptions::OpenAi(options) => Self { | ||||
|                 source: Setting::Set(EmbedderSource::OpenAi), | ||||
| @@ -201,6 +287,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 api_key: options.api_key.map(Setting::Set).unwrap_or_default(), | ||||
|                 dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(), | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
|                 url: Setting::NotSet, | ||||
|                 query: Setting::NotSet, | ||||
|                 input_field: Setting::NotSet, | ||||
|                 path_to_embeddings: Setting::NotSet, | ||||
|                 embedding_object: Setting::NotSet, | ||||
|                 input_type: Setting::NotSet, | ||||
|             }, | ||||
|             super::EmbedderOptions::Ollama(options) => Self { | ||||
|                 source: Setting::Set(EmbedderSource::Ollama), | ||||
| @@ -209,6 +301,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 api_key: Setting::NotSet, | ||||
|                 dimensions: Setting::NotSet, | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
|                 url: Setting::NotSet, | ||||
|                 query: Setting::NotSet, | ||||
|                 input_field: Setting::NotSet, | ||||
|                 path_to_embeddings: Setting::NotSet, | ||||
|                 embedding_object: Setting::NotSet, | ||||
|                 input_type: Setting::NotSet, | ||||
|             }, | ||||
|             super::EmbedderOptions::UserProvided(options) => Self { | ||||
|                 source: Setting::Set(EmbedderSource::UserProvided), | ||||
| @@ -217,6 +315,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 api_key: Setting::NotSet, | ||||
|                 dimensions: Setting::Set(options.dimensions), | ||||
|                 document_template: Setting::NotSet, | ||||
|                 url: Setting::NotSet, | ||||
|                 query: Setting::NotSet, | ||||
|                 input_field: Setting::NotSet, | ||||
|                 path_to_embeddings: Setting::NotSet, | ||||
|                 embedding_object: Setting::NotSet, | ||||
|                 input_type: Setting::NotSet, | ||||
|             }, | ||||
|             super::EmbedderOptions::Rest(super::rest::EmbedderOptions { | ||||
|                 api_key, | ||||
|                 // TODO: support distribution | ||||
|                 distribution: _, | ||||
|                 dimensions, | ||||
|                 url, | ||||
|                 query, | ||||
|                 input_field, | ||||
|                 path_to_embeddings, | ||||
|                 embedding_object, | ||||
|                 input_type, | ||||
|             }) => Self { | ||||
|                 source: Setting::Set(EmbedderSource::Rest), | ||||
|                 model: Setting::NotSet, | ||||
|                 revision: Setting::NotSet, | ||||
|                 api_key: api_key.map(Setting::Set).unwrap_or_default(), | ||||
|                 dimensions: dimensions.map(Setting::Set).unwrap_or_default(), | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
|                 url: Setting::Set(url), | ||||
|                 query: Setting::Set(query), | ||||
|                 input_field: Setting::Set(input_field), | ||||
|                 path_to_embeddings: Setting::Set(path_to_embeddings), | ||||
|                 embedding_object: Setting::Set(embedding_object), | ||||
|                 input_type: Setting::Set(input_type), | ||||
|             }, | ||||
|         } | ||||
|     } | ||||
| @@ -225,8 +354,20 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
| impl From<EmbeddingSettings> for EmbeddingConfig { | ||||
|     fn from(value: EmbeddingSettings) -> Self { | ||||
|         let mut this = Self::default(); | ||||
|         let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = | ||||
|             value; | ||||
|         let EmbeddingSettings { | ||||
|             source, | ||||
|             model, | ||||
|             revision, | ||||
|             api_key, | ||||
|             dimensions, | ||||
|             document_template, | ||||
|             url, | ||||
|             query, | ||||
|             input_field, | ||||
|             path_to_embeddings, | ||||
|             embedding_object, | ||||
|             input_type, | ||||
|         } = value; | ||||
|         if let Some(source) = source.set() { | ||||
|             match source { | ||||
|                 EmbedderSource::OpenAi => { | ||||
| @@ -274,6 +415,26 @@ impl From<EmbeddingSettings> for EmbeddingConfig { | ||||
|                             dimensions: dimensions.set().unwrap(), | ||||
|                         }); | ||||
|                 } | ||||
|                 EmbedderSource::Rest => { | ||||
|                     let embedder_options = super::rest::EmbedderOptions::default(); | ||||
|  | ||||
|                     this.embedder_options = | ||||
|                         super::EmbedderOptions::Rest(super::rest::EmbedderOptions { | ||||
|                             api_key: api_key.set(), | ||||
|                             distribution: None, | ||||
|                             dimensions: dimensions.set(), | ||||
|                             url: url.set().unwrap(), | ||||
|                             query: query.set().unwrap_or(embedder_options.query), | ||||
|                             input_field: input_field.set().unwrap_or(embedder_options.input_field), | ||||
|                             path_to_embeddings: path_to_embeddings | ||||
|                                 .set() | ||||
|                                 .unwrap_or(embedder_options.path_to_embeddings), | ||||
|                             embedding_object: embedding_object | ||||
|                                 .set() | ||||
|                                 .unwrap_or(embedder_options.embedding_object), | ||||
|                             input_type: input_type.set().unwrap_or(embedder_options.input_type), | ||||
|                         }) | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user