mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-30 23:46:28 +00:00 
			
		
		
		
	Fix multiple embeddings in the Hugging Face (hf) embedder
This commit is contained in:
		| @@ -255,34 +255,8 @@ impl Embedder { | ||||
|         Ok(this) | ||||
|     } | ||||
|  | ||||
|     pub fn embed(&self, mut texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> { | ||||
|         let tokens = match texts.len() { | ||||
|             1 => vec![self | ||||
|                 .tokenizer | ||||
|                 .encode(texts.pop().unwrap(), true) | ||||
|                 .map_err(EmbedError::tokenize)?], | ||||
|             _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?, | ||||
|         }; | ||||
|         let token_ids = tokens | ||||
|             .iter() | ||||
|             .map(|tokens| { | ||||
|                 let mut tokens = tokens.get_ids().to_vec(); | ||||
|                 tokens.truncate(512); | ||||
|                 Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) | ||||
|             }) | ||||
|             .collect::<Result<Vec<_>, EmbedError>>()?; | ||||
|  | ||||
|         let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; | ||||
|         let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; | ||||
|         let embeddings = self | ||||
|             .model | ||||
|             .forward(&token_ids, &token_type_ids, None) | ||||
|             .map_err(EmbedError::model_forward)?; | ||||
|  | ||||
|         let embeddings = Self::pooling(embeddings, self.pooling)?; | ||||
|  | ||||
|         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; | ||||
|         Ok(embeddings) | ||||
|     pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> { | ||||
|         texts.into_iter().map(|text| self.embed_one(&text)).collect() | ||||
|     } | ||||
|  | ||||
|     fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user