mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-30 23:46:28 +00:00 
			
		
		
		
	Support pooling
This commit is contained in:
		| @@ -262,6 +262,31 @@ impl NewEmbedderError { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn open_pooling_config( | ||||
|         pooling_config_filename: PathBuf, | ||||
|         inner: std::io::Error, | ||||
|     ) -> NewEmbedderError { | ||||
|         let open_config = OpenPoolingConfig { filename: pooling_config_filename, inner }; | ||||
|  | ||||
|         Self { | ||||
|             kind: NewEmbedderErrorKind::OpenPoolingConfig(open_config), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn deserialize_pooling_config( | ||||
|         model_name: String, | ||||
|         pooling_config_filename: PathBuf, | ||||
|         inner: serde_json::Error, | ||||
|     ) -> NewEmbedderError { | ||||
|         let deserialize_pooling_config = | ||||
|             DeserializePoolingConfig { model_name, filename: pooling_config_filename, inner }; | ||||
|         Self { | ||||
|             kind: NewEmbedderErrorKind::DeserializePoolingConfig(deserialize_pooling_config), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn open_tokenizer( | ||||
|         tokenizer_filename: PathBuf, | ||||
|         inner: Box<dyn std::error::Error + Send + Sync>, | ||||
| @@ -319,6 +344,13 @@ pub struct OpenConfig { | ||||
|     pub inner: std::io::Error, | ||||
| } | ||||
|  | ||||
| /// Error payload for a pooling configuration file that could not be read from disk. | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("could not open pooling config at {filename}: {inner}")] | ||||
| pub struct OpenPoolingConfig { | ||||
|     /// Path of the pooling configuration file that was being read. | ||||
|     pub filename: PathBuf, | ||||
|     /// Underlying I/O error returned by the read. | ||||
|     pub inner: std::io::Error, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")] | ||||
| pub struct DeserializeConfig { | ||||
| @@ -327,6 +359,14 @@ pub struct DeserializeConfig { | ||||
|     pub inner: serde_json::Error, | ||||
| } | ||||
|  | ||||
| /// Error payload for a pooling configuration file that was read but could not be | ||||
| /// deserialized from JSON. | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("for model '{model_name}', could not deserialize file at `{filename}` as a pooling config: {inner}")] | ||||
| pub struct DeserializePoolingConfig { | ||||
|     /// Name of the model whose pooling configuration was being loaded. | ||||
|     pub model_name: String, | ||||
|     /// Path of the file that failed to deserialize. | ||||
|     pub filename: PathBuf, | ||||
|     /// Underlying `serde_json` error. | ||||
|     pub inner: serde_json::Error, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("model `{model_name}` appears to be unsupported{}\n  - inner error: {inner}", | ||||
| if architectures.is_empty() { | ||||
| @@ -354,8 +394,12 @@ pub enum NewEmbedderErrorKind { | ||||
|     #[error(transparent)] | ||||
|     OpenConfig(OpenConfig), | ||||
|     #[error(transparent)] | ||||
|     OpenPoolingConfig(OpenPoolingConfig), | ||||
|     #[error(transparent)] | ||||
|     DeserializeConfig(DeserializeConfig), | ||||
|     #[error(transparent)] | ||||
|     DeserializePoolingConfig(DeserializePoolingConfig), | ||||
|     #[error(transparent)] | ||||
|     UnsupportedModel(UnsupportedModel), | ||||
|     #[error(transparent)] | ||||
|     OpenTokenizer(OpenTokenizer), | ||||
|   | ||||
| @@ -58,6 +58,7 @@ pub struct Embedder { | ||||
|     tokenizer: Tokenizer, | ||||
|     options: EmbedderOptions, | ||||
|     dimensions: usize, | ||||
|     pooling: Pooling, | ||||
| } | ||||
|  | ||||
| impl std::fmt::Debug for Embedder { | ||||
| @@ -66,10 +67,53 @@ impl std::fmt::Debug for Embedder { | ||||
|             .field("model", &self.options.model) | ||||
|             .field("tokenizer", &self.tokenizer) | ||||
|             .field("options", &self.options) | ||||
|             .field("pooling", &self.pooling) | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Flags deserialized from a model's `1_Pooling/config.json` file. | ||||
| /// | ||||
| /// Every flag is optional in the JSON document: `#[serde(default)]` makes a missing key | ||||
| /// read as `false`. | ||||
| #[derive(Clone, Copy, serde::Deserialize)] | ||||
| struct PoolingConfig { | ||||
|     /// Selects [`Pooling::Cls`] when set. | ||||
|     #[serde(default)] | ||||
|     pub pooling_mode_cls_token: bool, | ||||
|     /// Selects [`Pooling::Mean`] when set. | ||||
|     #[serde(default)] | ||||
|     pub pooling_mode_mean_tokens: bool, | ||||
|     /// Selects [`Pooling::Max`] when set. | ||||
|     #[serde(default)] | ||||
|     pub pooling_mode_max_tokens: bool, | ||||
|     /// Selects [`Pooling::MeanSqrtLen`] when set. | ||||
|     #[serde(default)] | ||||
|     pub pooling_mode_mean_sqrt_len_tokens: bool, | ||||
|     /// Selects [`Pooling::LastToken`] when set. | ||||
|     #[serde(default)] | ||||
|     pub pooling_mode_lasttoken: bool, | ||||
| } | ||||
|  | ||||
| /// Strategy used to reduce the per-token model output to a single vector per input. | ||||
| /// | ||||
| /// [`Pooling::Mean`] is the default, used when no pooling configuration is available or | ||||
| /// when the configuration enables no mode. | ||||
| #[derive(Debug, Clone, Copy, Default)] | ||||
| pub enum Pooling { | ||||
|     /// Sum over the token dimension divided by the number of tokens. | ||||
|     #[default] | ||||
|     Mean, | ||||
|     /// Vector at index 0 of the token dimension. | ||||
|     Cls, | ||||
|     /// Maximum over the token dimension. | ||||
|     Max, | ||||
|     /// Sum over the token dimension divided by the square root of the number of tokens. | ||||
|     MeanSqrtLen, | ||||
|     /// Vector at the last index of the token dimension. | ||||
|     LastToken, | ||||
| } | ||||
|  | ||||
| /// Picks a single pooling strategy out of the configuration flags. | ||||
| /// | ||||
| /// When several flags are enabled only the first match wins, in this order: CLS token, | ||||
| /// mean, last token, mean divided by sqrt(len), max. With no flag enabled the default | ||||
| /// strategy is used. | ||||
| impl From<PoolingConfig> for Pooling { | ||||
|     fn from(value: PoolingConfig) -> Self { | ||||
|         // Arms are tried top to bottom, which encodes the precedence between flags. | ||||
|         match value { | ||||
|             PoolingConfig { pooling_mode_cls_token: true, .. } => Self::Cls, | ||||
|             PoolingConfig { pooling_mode_mean_tokens: true, .. } => Self::Mean, | ||||
|             PoolingConfig { pooling_mode_lasttoken: true, .. } => Self::LastToken, | ||||
|             PoolingConfig { pooling_mode_mean_sqrt_len_tokens: true, .. } => Self::MeanSqrtLen, | ||||
|             PoolingConfig { pooling_mode_max_tokens: true, .. } => Self::Max, | ||||
|             _ => Self::default(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { | ||||
|         let device = match candle_core::Device::cuda_if_available(0) { | ||||
| @@ -83,7 +127,7 @@ impl Embedder { | ||||
|             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), | ||||
|             None => Repo::model(options.model.clone()), | ||||
|         }; | ||||
|         let (config_filename, tokenizer_filename, weights_filename, weight_source) = { | ||||
|         let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = { | ||||
|             let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; | ||||
|             let api = api.repo(repo); | ||||
|             let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; | ||||
| @@ -97,7 +141,36 @@ impl Embedder { | ||||
|                     }) | ||||
|                     .map_err(NewEmbedderError::api_get)? | ||||
|             }; | ||||
|             (config, tokenizer, weights, source) | ||||
|             let pooling = match api.get("1_Pooling/config.json") { | ||||
|                 Ok(pooling) => Some(pooling), | ||||
|                 Err(hf_hub::api::sync::ApiError::RequestError(error)) | ||||
|                     if matches!(*error, ureq::Error::Status(404, _,)) => | ||||
|                 { | ||||
|                     // ignore the error if the file simply doesn't exist | ||||
|                     None | ||||
|                 } | ||||
|                 Err(error) => return Err(NewEmbedderError::api_get(error)), | ||||
|             }; | ||||
|             let pooling: Pooling = match pooling { | ||||
|                 Some(pooling_filename) => { | ||||
|                     let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| { | ||||
|                         NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner) | ||||
|                     })?; | ||||
|  | ||||
|                     let pooling: PoolingConfig = | ||||
|                         serde_json::from_str(&pooling).map_err(|inner| { | ||||
|                             NewEmbedderError::deserialize_pooling_config( | ||||
|                                 options.model.clone(), | ||||
|                                 pooling_filename, | ||||
|                                 inner, | ||||
|                             ) | ||||
|                         })?; | ||||
|                     pooling.into() | ||||
|                 } | ||||
|                 None => Pooling::default(), | ||||
|             }; | ||||
|  | ||||
|             (config, tokenizer, weights, source, pooling) | ||||
|         }; | ||||
|  | ||||
|         let config = std::fs::read_to_string(&config_filename) | ||||
| @@ -122,6 +195,8 @@ impl Embedder { | ||||
|             }, | ||||
|         }; | ||||
|  | ||||
|         tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config"); | ||||
|  | ||||
|         let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; | ||||
|  | ||||
|         if let Some(pp) = tokenizer.get_padding_mut() { | ||||
| @@ -134,7 +209,7 @@ impl Embedder { | ||||
|             tokenizer.with_padding(Some(pp)); | ||||
|         } | ||||
|  | ||||
|         let mut this = Self { model, tokenizer, options, dimensions: 0 }; | ||||
|         let mut this = Self { model, tokenizer, options, dimensions: 0, pooling }; | ||||
|  | ||||
|         let embeddings = this | ||||
|             .embed(vec!["test".into()]) | ||||
| @@ -168,17 +243,53 @@ impl Embedder { | ||||
|             .forward(&token_ids, &token_type_ids, None) | ||||
|             .map_err(EmbedError::model_forward)?; | ||||
|  | ||||
|         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) | ||||
|         let (_n_sentence, n_tokens, _hidden_size) = | ||||
|             embeddings.dims3().map_err(EmbedError::tensor_shape)?; | ||||
|  | ||||
|         let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) | ||||
|             .map_err(EmbedError::tensor_shape)?; | ||||
|         let embeddings = Self::pooling(embeddings, self.pooling)?; | ||||
|  | ||||
|         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; | ||||
|         Ok(embeddings) | ||||
|     } | ||||
|  | ||||
|     /// Reduces the model output with the requested pooling strategy. | ||||
|     /// | ||||
|     /// Each [`Pooling`] variant is forwarded to its dedicated helper; the helper's result, | ||||
|     /// error included, is returned unchanged. | ||||
|     fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> { | ||||
|         let strategy: fn(Tensor) -> Result<Tensor, EmbedError> = match pooling { | ||||
|             Pooling::Cls => Self::cls_pooling, | ||||
|             Pooling::LastToken => Self::last_token_pooling, | ||||
|             Pooling::Max => Self::max_pooling, | ||||
|             Pooling::Mean => Self::mean_pooling, | ||||
|             Pooling::MeanSqrtLen => Self::mean_sqrt_pooling, | ||||
|         }; | ||||
|  | ||||
|         strategy(embeddings) | ||||
|     } | ||||
|  | ||||
|     /// Pools by keeping the vector at index 0 of dimension 1. | ||||
|     /// | ||||
|     /// A failed selection is reported through `EmbedError::tensor_value`. | ||||
|     fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> { | ||||
|         let first_token = embeddings.get_on_dim(1, 0); | ||||
|         first_token.map_err(EmbedError::tensor_value) | ||||
|     } | ||||
|  | ||||
|     /// Pools by summing over dimension 1 and dividing by the square root of its length. | ||||
|     /// | ||||
|     /// Fails with `EmbedError::tensor_shape` when `embeddings` is not a rank-3 tensor or when | ||||
|     /// the division fails, and with `EmbedError::tensor_value` when the sum fails. | ||||
|     fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> { | ||||
|         let (_, token_count, _) = embeddings.dims3().map_err(EmbedError::tensor_shape)?; | ||||
|  | ||||
|         let summed = embeddings.sum(1).map_err(EmbedError::tensor_value)?; | ||||
|         let divisor = (token_count as f64).sqrt(); | ||||
|  | ||||
|         (summed / divisor).map_err(EmbedError::tensor_shape) | ||||
|     } | ||||
|  | ||||
|     /// Pools by averaging over dimension 1: the sum is divided by that dimension's length, | ||||
|     /// so every position counts equally. | ||||
|     /// | ||||
|     /// Fails with `EmbedError::tensor_shape` when `embeddings` is not a rank-3 tensor or when | ||||
|     /// the division fails, and with `EmbedError::tensor_value` when the sum fails. | ||||
|     fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> { | ||||
|         let (_, token_count, _) = embeddings.dims3().map_err(EmbedError::tensor_shape)?; | ||||
|  | ||||
|         let summed = embeddings.sum(1).map_err(EmbedError::tensor_value)?; | ||||
|         let divisor = token_count as f64; | ||||
|  | ||||
|         (summed / divisor).map_err(EmbedError::tensor_shape) | ||||
|     } | ||||
|  | ||||
|     /// Pools by taking the maximum over dimension 1. | ||||
|     /// | ||||
|     /// A failed reduction is reported through `EmbedError::tensor_shape`. | ||||
|     fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> { | ||||
|         let pooled = embeddings.max(1); | ||||
|         pooled.map_err(EmbedError::tensor_shape) | ||||
|     } | ||||
|  | ||||
|     /// Pools by keeping the vector at the last index of dimension 1. | ||||
|     /// | ||||
|     /// Fails with `EmbedError::tensor_shape` when `embeddings` is not a rank-3 tensor, and | ||||
|     /// with `EmbedError::tensor_value` when the selection fails. | ||||
|     fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> { | ||||
|         let (_n_sentence, n_tokens, _hidden_size) = | ||||
|             embeddings.dims3().map_err(EmbedError::tensor_shape)?; | ||||
|  | ||||
|         // `n_tokens - 1` would underflow on an empty token dimension (panic in debug builds, | ||||
|         // wrap-around in release). Saturating keeps the index at 0 so that the selection | ||||
|         // itself rejects the empty dimension and the failure surfaces as an `EmbedError`. | ||||
|         // NOTE(review): for a padded batch the last position may hold a padding token | ||||
|         // rather than the last real token of a shorter input -- confirm against the | ||||
|         // tokenizer padding configuration. | ||||
|         let last_index = n_tokens.saturating_sub(1); | ||||
|  | ||||
|         embeddings.get_on_dim(1, last_index).map_err(EmbedError::tensor_value) | ||||
|     } | ||||
|  | ||||
|     pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> { | ||||
|         let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?; | ||||
|         let token_ids = tokens.get_ids(); | ||||
| @@ -192,11 +303,8 @@ impl Embedder { | ||||
|             .forward(&token_ids, &token_type_ids, None) | ||||
|             .map_err(EmbedError::model_forward)?; | ||||
|  | ||||
|         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) | ||||
|         let (_n_sentence, n_tokens, _hidden_size) = | ||||
|             embeddings.dims3().map_err(EmbedError::tensor_shape)?; | ||||
|         let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) | ||||
|             .map_err(EmbedError::tensor_shape)?; | ||||
|         let embedding = Self::pooling(embeddings, self.pooling)?; | ||||
|  | ||||
|         let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?; | ||||
|         let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?; | ||||
|         Ok(embedding) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user