mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-30 23:46:28 +00:00 
			
		
		
		
	Fix after upgrading candle
This commit is contained in:
		| @@ -163,8 +163,10 @@ impl Embedder { | ||||
|  | ||||
|         let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; | ||||
|         let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; | ||||
|         let embeddings = | ||||
|             self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; | ||||
|         let embeddings = self | ||||
|             .model | ||||
|             .forward(&token_ids, &token_type_ids, None) | ||||
|             .map_err(EmbedError::model_forward)?; | ||||
|  | ||||
|         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) | ||||
|         let (_n_sentence, n_tokens, _hidden_size) = | ||||
| @@ -185,8 +187,10 @@ impl Embedder { | ||||
|             Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?; | ||||
|         let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?; | ||||
|         let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; | ||||
|         let embeddings = | ||||
|             self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; | ||||
|         let embeddings = self | ||||
|             .model | ||||
|             .forward(&token_ids, &token_type_ids, None) | ||||
|             .map_err(EmbedError::model_forward)?; | ||||
|  | ||||
|         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) | ||||
|         let (_n_sentence, n_tokens, _hidden_size) = | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| use std::fmt; | ||||
| use std::time::Instant; | ||||
|  | ||||
| use ordered_float::OrderedFloat; | ||||
| @@ -168,7 +169,6 @@ fn infer_api_key() -> String { | ||||
|         .unwrap_or_default() | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Embedder { | ||||
|     tokenizer: tiktoken_rs::CoreBPE, | ||||
|     rest_embedder: RestEmbedder, | ||||
| @@ -302,3 +302,13 @@ impl Embedder { | ||||
|         self.options.distribution() | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for Embedder { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         f.debug_struct("Embedder") | ||||
|             .field("tokenizer", &"CoreBPE") | ||||
|             .field("rest_embedder", &self.rest_embedder) | ||||
|             .field("options", &self.options) | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -175,7 +175,7 @@ impl Embedder { | ||||
|  | ||||
|     pub fn embed_tokens( | ||||
|         &self, | ||||
|         tokens: &[usize], | ||||
|         tokens: &[u32], | ||||
|         deadline: Option<Instant>, | ||||
|     ) -> Result<Embedding, EmbedError> { | ||||
|         let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user