mirror of https://github.com/meilisearch/meilisearch.git
	Remove some settings
@@ -23,7 +23,7 @@ use super::{Embedding, Embeddings};
 )]
 #[serde(deny_unknown_fields, rename_all = "camelCase")]
 #[deserr(rename_all = camelCase, deny_unknown_fields)]
-pub enum WeightSource {
+enum WeightSource {
     #[default]
     Safetensors,
     Pytorch,
@@ -33,20 +33,13 @@ pub enum WeightSource {
 pub struct EmbedderOptions {
     pub model: String,
     pub revision: Option<String>,
-    pub weight_source: WeightSource,
-    pub normalize_embeddings: bool,
 }
 
 impl EmbedderOptions {
     pub fn new() -> Self {
         Self {
-            //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
             model: "BAAI/bge-base-en-v1.5".to_string(),
-            //revision: Some("refs/pr/21".to_string()),
-            revision: None,
-            //weight_source: Default::default(),
-            weight_source: WeightSource::Pytorch,
-            normalize_embeddings: true,
+            revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
         }
     }
 }
@@ -82,20 +75,21 @@ impl Embedder {
             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
             None => Repo::model(options.model.clone()),
         };
-        let (config_filename, tokenizer_filename, weights_filename) = {
+        let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
             let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
             let api = api.repo(repo);
             let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
             let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
-            let weights = match options.weight_source {
-                WeightSource::Pytorch => {
-                    api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
-                }
-                WeightSource::Safetensors => {
-                    api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
-                }
+            let (weights, source) = {
+                api.get("pytorch_model.bin")
+                    .map(|filename| (filename, WeightSource::Pytorch))
+                    .or_else(|_| {
+                        api.get("model.safetensors")
+                            .map(|filename| (filename, WeightSource::Safetensors))
+                    })
+                    .map_err(NewEmbedderError::api_get)?
             };
-            (config, tokenizer, weights)
+            (config, tokenizer, weights, source)
         };
 
         let config = std::fs::read_to_string(&config_filename)
@@ -106,7 +100,7 @@ impl Embedder {
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
 
-        let vb = match options.weight_source {
+        let vb = match weight_source {
             WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
                 .map_err(NewEmbedderError::pytorch_weight)?,
             WeightSource::Safetensors => unsafe {
@@ -168,12 +162,6 @@ impl Embedder {
         let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
             .map_err(EmbedError::tensor_shape)?;
 
-        let embeddings: Tensor = if self.options.normalize_embeddings {
-            normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
-        } else {
-            embeddings
-        };
-
         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
         Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
     }
@@ -197,7 +185,3 @@ impl Embedder {
         self.dimensions
     }
 }
-
-fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
-    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
-}
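Why the weightSource setting could go away: instead of asking the user which checkpoint format to download, the embedder now probes the Hub itself, requesting pytorch_model.bin first, falling back to model.safetensors, and remembering which one succeeded so the matching VarBuilder is used later. The sketch below shows that fallback pattern in isolation; resolve_weights and the fetch closure are hypothetical stand-ins for the ApiRepo::get calls in the hunk above, not code from this commit.

use std::path::PathBuf;

// Which checkpoint format was actually found on the Hub.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WeightSource {
    Safetensors,
    Pytorch,
}

// Try the PyTorch checkpoint first, then fall back to safetensors, and return
// the downloaded path together with the detected format. `fetch` stands in for
// `ApiRepo::get` from the `hf_hub` crate.
fn resolve_weights<E>(
    fetch: impl Fn(&str) -> Result<PathBuf, E>,
) -> Result<(PathBuf, WeightSource), E> {
    fetch("pytorch_model.bin")
        .map(|path| (path, WeightSource::Pytorch))
        .or_else(|_| fetch("model.safetensors").map(|path| (path, WeightSource::Safetensors)))
}

fn main() {
    // Fake fetcher for illustration: pretend only the safetensors file exists.
    let fetch = |name: &str| -> Result<PathBuf, String> {
        if name == "model.safetensors" {
            Ok(PathBuf::from("/tmp/model.safetensors"))
        } else {
            Err(format!("{name} not found"))
        }
    };
    let (path, source) = resolve_weights(fetch).unwrap();
    assert_eq!(source, WeightSource::Safetensors);
    println!("using {} ({:?})", path.display(), source);
}

The trade-off is one less user-facing knob, at the cost of a possible extra failed request when a repository only ships safetensors weights.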
@@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize};
 
 use crate::prompt::PromptData;
 use crate::update::Setting;
-use crate::vector::hf::WeightSource;
 use crate::vector::EmbeddingConfig;
 
 #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
@@ -204,26 +203,13 @@ pub struct HfEmbedderSettings {
     #[serde(default, skip_serializing_if = "Setting::is_not_set")]
     #[deserr(default)]
     pub revision: Setting<String>,
-    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
-    #[deserr(default)]
-    pub weight_source: Setting<WeightSource>,
-    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
-    #[deserr(default)]
-    pub normalize_embeddings: Setting<bool>,
 }
 
 impl HfEmbedderSettings {
     pub fn apply(&mut self, new: Self) {
-        let HfEmbedderSettings {
-            model,
-            revision,
-            weight_source,
-            normalize_embeddings: normalize_embedding,
-        } = new;
+        let HfEmbedderSettings { model, revision } = new;
         self.model.apply(model);
         self.revision.apply(revision);
-        self.weight_source.apply(weight_source);
-        self.normalize_embeddings.apply(normalize_embedding);
     }
 }
 
@@ -232,15 +218,13 @@ impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
         Self {
             model: Setting::Set(value.model),
             revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
-            weight_source: Setting::Set(value.weight_source),
-            normalize_embeddings: Setting::Set(value.normalize_embeddings),
         }
     }
 }
 
 impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
     fn from(value: HfEmbedderSettings) -> Self {
-        let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value;
+        let HfEmbedderSettings { model, revision } = value;
         let mut this = Self::default();
         if let Some(model) = model.set() {
             this.model = model;
@@ -248,12 +232,6 @@ impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
         if let Some(revision) = revision.set() {
             this.revision = Some(revision);
         }
-        if let Some(weight_source) = weight_source.set() {
-            this.weight_source = weight_source;
-        }
-        if let Some(normalize_embeddings) = normalize_embeddings.set() {
-            this.normalize_embeddings = normalize_embeddings;
-        }
         this
     }
 }
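On the settings side, only model and revision remain user-configurable after this commit. The following is a compilable sketch of the resulting shape, reconstructed from the hunks above; the Setting type here is a minimal stand-in for milli's crate::update::Setting (assumed merge semantics: a NotSet value leaves the existing one untouched; the real type has more variants plus serde/deserr integration), so treat it as an illustration rather than the crate's actual API.

// Minimal stand-in for milli's `Setting<T>` (assumption: only Set / NotSet,
// and `apply` keeps the old value when the new one is NotSet).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
enum Setting<T> {
    Set(T),
    #[default]
    NotSet,
}

impl<T> Setting<T> {
    fn apply(&mut self, new: Self) {
        if let Setting::Set(_) = new {
            *self = new;
        }
    }

    fn set(self) -> Option<T> {
        match self {
            Setting::Set(value) => Some(value),
            Setting::NotSet => None,
        }
    }
}

// Post-change Hugging Face embedder settings: only the model name and an
// optional revision are exposed (serde/deserr attributes elided).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct HfEmbedderSettings {
    model: Setting<String>,
    revision: Setting<String>,
}

impl HfEmbedderSettings {
    fn apply(&mut self, new: Self) {
        let HfEmbedderSettings { model, revision } = new;
        self.model.apply(model);
        self.revision.apply(revision);
    }
}

fn main() {
    let mut current = HfEmbedderSettings {
        model: Setting::Set("BAAI/bge-base-en-v1.5".to_string()),
        revision: Setting::NotSet,
    };
    // A partial update that only pins a revision leaves the model untouched.
    current.apply(HfEmbedderSettings {
        model: Setting::NotSet,
        revision: Setting::Set("617ca489d9e86b49b8167676d8220688b99db36e".to_string()),
    });
    assert_eq!(current.model, Setting::Set("BAAI/bge-base-en-v1.5".to_string()));
    assert_eq!(
        current.revision.set().as_deref(),
        Some("617ca489d9e86b49b8167676d8220688b99db36e")
    );
}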