diff --git a/crates/milli/src/vector/embedder/hf.rs b/crates/milli/src/vector/embedder/hf.rs
index 3c51da5bb..d4d4a7bad 100644
--- a/crates/milli/src/vector/embedder/hf.rs
+++ b/crates/milli/src/vector/embedder/hf.rs
@@ -2,6 +2,7 @@ use candle_core::Tensor;
 use candle_nn::VarBuilder;
 use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
 use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
+use candle_transformers::models::xlm_roberta::{Config as XlmRobertaConfig, XLMRobertaModel};
 // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
 use hf_hub::api::sync::Api;
 use hf_hub::{Repo, RepoType};
@@ -89,6 +90,7 @@ impl Default for EmbedderOptions {
 enum ModelKind {
     Bert(BertModel),
     Modern(ModernBert),
+    XlmRoberta(XLMRobertaModel),
 }
 
 /// Perform embedding of documents and queries
@@ -304,7 +306,8 @@ impl Embedder {
         };
 
         let is_modern = has_arch("modernbert");
-        tracing::debug!(is_modern, model_type, "detected HF architecture");
+        let is_xlm_roberta = has_arch("xlm-roberta") || has_arch("xlm_roberta");
+        tracing::debug!(is_modern, is_xlm_roberta, model_type, "detected HF architecture");
 
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
@@ -340,6 +343,19 @@ impl Embedder {
                     )
                 })?;
             ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
+        } else if is_xlm_roberta {
+            let config: XlmRobertaConfig =
+                serde_json::from_str(&config_str).map_err(|inner| {
+                    NewEmbedderError::deserialize_config(
+                        options.model.clone(),
+                        config_str.clone(),
+                        config_filename.clone(),
+                        inner,
+                    )
+                })?;
+            ModelKind::XlmRoberta(
+                XLMRobertaModel::new(&config, vb).map_err(NewEmbedderError::load_model)?,
+            )
         } else {
             let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
                 NewEmbedderError::deserialize_config(
@@ -451,6 +467,19 @@ impl Embedder {
                 let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
                 model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
             }
+            ModelKind::XlmRoberta(model) => {
+                let mut mask_vec = tokens.get_attention_mask().to_vec();
+                if mask_vec.len() > self.max_len {
+                    mask_vec.truncate(self.max_len);
+                }
+                let mask = Tensor::new(mask_vec.as_slice(), &self.device)
+                    .map_err(EmbedError::tensor_shape)?;
+                let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
+                let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
+                model
+                    .forward(&token_ids, &mask, &token_type_ids, None, None, None)
+                    .map_err(EmbedError::model_forward)?
+            }
         };
 
         let embedding = Self::pooling(embeddings, self.pooling)?;
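
Note on the detection step: `has_arch` is probed with both the hyphenated and the snake_case spelling. The helper itself lives in `hf.rs` and is not shown in this patch, so the following is a minimal stand-in sketching the assumed behavior (lower-cased substring match over `architectures` and `model_type`); the config values are typical of Hub XLM-RoBERTa configs such as `FacebookAI/xlm-roberta-base`, not taken from this diff:

```rust
use serde_json::json;

fn main() {
    // Representative slice of a Hub XLM-RoBERTa config.json (assumed values).
    let config = json!({
        "architectures": ["XLMRobertaForMaskedLM"],
        "model_type": "xlm-roberta"
    });

    // Hypothetical stand-in for the has_arch() helper referenced above:
    // lower-cased substring match over `architectures` and `model_type`.
    let has_arch = |needle: &str| -> bool {
        let mut haystacks: Vec<String> = config["architectures"]
            .as_array()
            .into_iter()
            .flatten()
            .filter_map(|v| v.as_str())
            .map(str::to_lowercase)
            .collect();
        if let Some(model_type) = config["model_type"].as_str() {
            haystacks.push(model_type.to_lowercase());
        }
        haystacks.iter().any(|h| h.contains(needle))
    };

    // Probing both spellings covers hyphenated `model_type` values
    // ("xlm-roberta") as well as snake_case identifiers ("xlm_roberta").
    assert!(has_arch("xlm-roberta") || has_arch("xlm_roberta"));
}
```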
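
Note on the forward pass: unlike the `Bert` and `Modern` arms, `XLMRobertaModel` also takes `token_type_ids`, which are all zeros for single-segment input, and the attention mask must be truncated to `self.max_len` so its shape matches the already-truncated `token_ids`. A minimal sketch of that tensor plumbing using only `candle_core` (the token values and `max_len` below are invented for illustration):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let max_len = 4; // stand-in for self.max_len

    // Pretend the tokenizer produced 6 token ids but the model caps at 4.
    let mut token_ids: Vec<u32> = vec![0, 152, 9842, 7, 11, 2];
    let mut mask_vec: Vec<u32> = vec![1; token_ids.len()];
    token_ids.truncate(max_len);
    if mask_vec.len() > max_len {
        mask_vec.truncate(max_len); // keep the mask aligned with token_ids
    }

    // Batch dimension of 1, matching the per-text embedding path above.
    let token_ids = Tensor::stack(&[Tensor::new(token_ids.as_slice(), &device)?], 0)?;
    let mask = Tensor::stack(&[Tensor::new(mask_vec.as_slice(), &device)?], 0)?;
    // Single-segment input: token type ids are all zeros, shaped like token_ids.
    let token_type_ids = token_ids.zeros_like()?;

    assert_eq!(token_ids.dims(), &[1, 4]);
    assert_eq!(mask.dims(), token_ids.dims());
    assert_eq!(token_type_ids.dims(), token_ids.dims());
    Ok(())
}
```

After the forward call, the new arm feeds its output through the existing `Self::pooling`, so XLM-RoBERTa embeddings reuse the same pooling path as the BERT and ModernBERT arms.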