mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-11-04 01:46:28 +00:00 
			
		
		
		
	Fix multiple embeddings in hf
This commit is contained in:
		@@ -255,34 +255,8 @@ impl Embedder {
 | 
			
		||||
        Ok(this)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn embed(&self, mut texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
 | 
			
		||||
        let tokens = match texts.len() {
 | 
			
		||||
            1 => vec![self
 | 
			
		||||
                .tokenizer
 | 
			
		||||
                .encode(texts.pop().unwrap(), true)
 | 
			
		||||
                .map_err(EmbedError::tokenize)?],
 | 
			
		||||
            _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
 | 
			
		||||
        };
 | 
			
		||||
        let token_ids = tokens
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|tokens| {
 | 
			
		||||
                let mut tokens = tokens.get_ids().to_vec();
 | 
			
		||||
                tokens.truncate(512);
 | 
			
		||||
                Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
 | 
			
		||||
            })
 | 
			
		||||
            .collect::<Result<Vec<_>, EmbedError>>()?;
 | 
			
		||||
 | 
			
		||||
        let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
 | 
			
		||||
        let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
 | 
			
		||||
        let embeddings = self
 | 
			
		||||
            .model
 | 
			
		||||
            .forward(&token_ids, &token_type_ids, None)
 | 
			
		||||
            .map_err(EmbedError::model_forward)?;
 | 
			
		||||
 | 
			
		||||
        let embeddings = Self::pooling(embeddings, self.pooling)?;
 | 
			
		||||
 | 
			
		||||
        let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
 | 
			
		||||
        Ok(embeddings)
 | 
			
		||||
    pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
 | 
			
		||||
        texts.into_iter().map(|text| self.embed_one(&text)).collect()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user