mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-31 07:56:28 +00:00 
			
		
		
		
	feat: add new models and ability to override dimensions
This commit is contained in:
		| @@ -17,6 +17,7 @@ pub struct Embedder { | |||||||
| pub struct EmbedderOptions { | pub struct EmbedderOptions { | ||||||
|     pub api_key: Option<String>, |     pub api_key: Option<String>, | ||||||
|     pub embedding_model: EmbeddingModel, |     pub embedding_model: EmbeddingModel, | ||||||
|  |     pub dimensions: Option<usize>, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive( | #[derive( | ||||||
| @@ -41,34 +42,54 @@ pub enum EmbeddingModel { | |||||||
|     #[serde(rename = "text-embedding-ada-002")] |     #[serde(rename = "text-embedding-ada-002")] | ||||||
|     #[deserr(rename = "text-embedding-ada-002")] |     #[deserr(rename = "text-embedding-ada-002")] | ||||||
|     TextEmbeddingAda002, |     TextEmbeddingAda002, | ||||||
|  |  | ||||||
|  |     #[serde(rename = "text-embedding-3-small")] | ||||||
|  |     #[deserr(rename = "text-embedding-3-small")] | ||||||
|  |     TextEmbedding3Small, | ||||||
|  |  | ||||||
|  |     #[serde(rename = "text-embedding-3-large")] | ||||||
|  |     #[deserr(rename = "text-embedding-3-large")] | ||||||
|  |     TextEmbedding3Large, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl EmbeddingModel { | impl EmbeddingModel { | ||||||
|     pub fn supported_models() -> &'static [&'static str] { |     pub fn supported_models() -> &'static [&'static str] { | ||||||
|         &["text-embedding-ada-002"] |         &["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"] | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn max_token(&self) -> usize { |     pub fn max_token(&self) -> usize { | ||||||
|         match self { |         match self { | ||||||
|             EmbeddingModel::TextEmbeddingAda002 => 8191, |             EmbeddingModel::TextEmbeddingAda002 => 8191, | ||||||
|  |             EmbeddingModel::TextEmbedding3Large => 8191, | ||||||
|  |             EmbeddingModel::TextEmbedding3Small => 8191, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn dimensions(&self) -> usize { |     pub fn dimensions(&self) -> usize { | ||||||
|         match self { |         match self { | ||||||
|             EmbeddingModel::TextEmbeddingAda002 => 1536, |             EmbeddingModel::TextEmbeddingAda002 => 1536, | ||||||
|  |  | ||||||
|  |             //Default value for the model | ||||||
|  |             EmbeddingModel::TextEmbedding3Large => 3072, | ||||||
|  |  | ||||||
|  |             //Default value for the model | ||||||
|  |             EmbeddingModel::TextEmbedding3Small => 1536, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn name(&self) -> &'static str { |     pub fn name(&self) -> &'static str { | ||||||
|         match self { |         match self { | ||||||
|             EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", |             EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", | ||||||
|  |             EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large", | ||||||
|  |             EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small", | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn from_name(name: &str) -> Option<Self> { |     pub fn from_name(name: &str) -> Option<Self> { | ||||||
|         match name { |         match name { | ||||||
|             "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), |             "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), | ||||||
|  |             "text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large), | ||||||
|  |             "text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small), | ||||||
|             _ => None, |             _ => None, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -78,6 +99,20 @@ impl EmbeddingModel { | |||||||
|             EmbeddingModel::TextEmbeddingAda002 => { |             EmbeddingModel::TextEmbeddingAda002 => { | ||||||
|                 Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) |                 Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) | ||||||
|             } |             } | ||||||
|  |             EmbeddingModel::TextEmbedding3Large => { | ||||||
|  |                 Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) | ||||||
|  |             } | ||||||
|  |             EmbeddingModel::TextEmbedding3Small => { | ||||||
|  |                 Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn is_optional_dimensions_supported(&self) -> bool { | ||||||
|  |         match self { | ||||||
|  |             EmbeddingModel::TextEmbeddingAda002 => false, | ||||||
|  |             EmbeddingModel::TextEmbedding3Large => true, | ||||||
|  |             EmbeddingModel::TextEmbedding3Small => true, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -86,11 +121,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; | |||||||
|  |  | ||||||
| impl EmbedderOptions { | impl EmbedderOptions { | ||||||
|     pub fn with_default_model(api_key: Option<String>) -> Self { |     pub fn with_default_model(api_key: Option<String>) -> Self { | ||||||
|         Self { api_key, embedding_model: Default::default() } |         Self { api_key, embedding_model: Default::default(), dimensions: None } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self { |     pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self { | ||||||
|         Self { api_key, embedding_model } |         Self { api_key, embedding_model, dimensions: None } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -237,7 +272,15 @@ impl Embedder { | |||||||
|         for text in texts { |         for text in texts { | ||||||
|             log::trace!("Received prompt: {}", text.as_ref()) |             log::trace!("Received prompt: {}", text.as_ref()) | ||||||
|         } |         } | ||||||
|         let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; |         let request = OpenAiRequest { | ||||||
|  |             model: self.options.embedding_model.name(), | ||||||
|  |             input: texts, | ||||||
|  |             dimension: if self.options.embedding_model.is_optional_dimensions_supported() { | ||||||
|  |                 self.options.dimensions.as_ref() | ||||||
|  |             } else { | ||||||
|  |                 None | ||||||
|  |             }, | ||||||
|  |         }; | ||||||
|         let response = client |         let response = client | ||||||
|             .post(OPENAI_EMBEDDINGS_URL) |             .post(OPENAI_EMBEDDINGS_URL) | ||||||
|             .json(&request) |             .json(&request) | ||||||
| @@ -366,7 +409,7 @@ impl Embedder { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn dimensions(&self) -> usize { |     pub fn dimensions(&self) -> usize { | ||||||
|         self.options.embedding_model.dimensions() |         self.options.dimensions.unwrap_or_else(|| self.options.embedding_model.dimensions()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn distribution(&self) -> Option<DistributionShift> { |     pub fn distribution(&self) -> Option<DistributionShift> { | ||||||
| @@ -431,6 +474,7 @@ impl Retry { | |||||||
| struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { | struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { | ||||||
|     model: &'a str, |     model: &'a str, | ||||||
|     input: &'a [S], |     input: &'a [S], | ||||||
|  |     #[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")] | ||||||
|  |     dimension: Option<&'a usize>, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug, Serialize)] | #[derive(Debug, Serialize)] | ||||||
|   | |||||||
| @@ -208,6 +208,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig { | |||||||
|                     if let Some(api_key) = api_key.set() { |                     if let Some(api_key) = api_key.set() { | ||||||
|                         options.api_key = Some(api_key); |                         options.api_key = Some(api_key); | ||||||
|                     } |                     } | ||||||
|  |                     if let Some(dimensions) = dimensions.set() { | ||||||
|  |                         options.dimensions = Some(dimensions); | ||||||
|  |                     } | ||||||
|                     this.embedder_options = super::EmbedderOptions::OpenAi(options); |                     this.embedder_options = super::EmbedderOptions::OpenAi(options); | ||||||
|                 } |                 } | ||||||
|                 EmbedderSource::HuggingFace => { |                 EmbedderSource::HuggingFace => { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user