add support for XLMRoberta models

Quentin de Quelen
2025-11-29 13:58:24 +01:00
committed by Louis Dureuil
parent 57b94b411f
commit fe46af7ded


@@ -2,6 +2,7 @@ use candle_core::Tensor;
 use candle_nn::VarBuilder;
 use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
 use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
+use candle_transformers::models::xlm_roberta::{Config as XlmRobertaConfig, XLMRobertaModel};
 // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
 use hf_hub::api::sync::Api;
 use hf_hub::{Repo, RepoType};
@@ -89,6 +90,7 @@ impl Default for EmbedderOptions
 enum ModelKind {
     Bert(BertModel),
     Modern(ModernBert),
+    XlmRoberta(XLMRobertaModel),
 }
 
 /// Perform embedding of documents and queries
@@ -304,7 +306,8 @@ impl Embedder {
         };
         let is_modern = has_arch("modernbert");
-        tracing::debug!(is_modern, model_type, "detected HF architecture");
+        let is_xlm_roberta = has_arch("xlm-roberta") || has_arch("xlm_roberta");
+        tracing::debug!(is_modern, is_xlm_roberta, model_type, "detected HF architecture");
 
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
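
Note: has_arch is not part of this diff; from its call sites it is presumably a closure over the parsed config.json that checks a needle against the declared model type and architecture names. The sketch below is a hypothetical reconstruction (the JSON field names and the case-insensitive matching are assumptions, not Meilisearch's actual helper); both spellings are probed because the Hugging Face model_type uses a hyphen ("xlm-roberta") while other identifiers may normalize to an underscore:

use serde_json::Value;

fn detect_architectures(raw_config: &Value) -> (bool, bool) {
    // Hypothetical: match case-insensitively against model_type and the
    // architectures array of a Hugging Face config.json.
    let has_arch = |needle: &str| -> bool {
        let needle = needle.to_lowercase();
        let in_model_type = raw_config["model_type"]
            .as_str()
            .is_some_and(|s| s.to_lowercase().contains(&needle));
        let in_architectures = raw_config["architectures"]
            .as_array()
            .into_iter()
            .flatten()
            .filter_map(Value::as_str)
            .any(|s| s.to_lowercase().contains(&needle));
        in_model_type || in_architectures
    };
    let is_modern = has_arch("modernbert");
    let is_xlm_roberta = has_arch("xlm-roberta") || has_arch("xlm_roberta");
    (is_modern, is_xlm_roberta)
}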
@@ -340,6 +343,19 @@ impl Embedder {
                 )
             })?;
             ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
+        } else if is_xlm_roberta {
+            let config: XlmRobertaConfig =
+                serde_json::from_str(&config_str).map_err(|inner| {
+                    NewEmbedderError::deserialize_config(
+                        options.model.clone(),
+                        config_str.clone(),
+                        config_filename.clone(),
+                        inner,
+                    )
+                })?;
+            ModelKind::XlmRoberta(
+                XLMRobertaModel::new(&config, vb).map_err(NewEmbedderError::load_model)?,
+            )
         } else {
             let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
                 NewEmbedderError::deserialize_config(
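
For context, the same loading path can be exercised standalone with candle and hf-hub. This is a sketch under assumptions (revision "main", model.safetensors as the weight file, F32 weights, anyhow for error glue), not code from this commit:

use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config as XlmRobertaConfig, XLMRobertaModel};
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};

fn load_xlm_roberta(model_id: &str) -> anyhow::Result<XLMRobertaModel> {
    // Fetch config and weights from the hub, as the surrounding code does.
    let repo = Api::new()?.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        "main".to_string(), // assumed revision
    ));
    let config_str = std::fs::read_to_string(repo.get("config.json")?)?;
    let config: XlmRobertaConfig = serde_json::from_str(&config_str)?;
    // Memory-map the safetensors weights into a VarBuilder (F32 assumed).
    let weights = repo.get("model.safetensors")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &Device::Cpu)? };
    Ok(XLMRobertaModel::new(&config, vb)?)
}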
@@ -451,6 +467,19 @@ impl Embedder {
                 let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
                 model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
             }
+            ModelKind::XlmRoberta(model) => {
+                let mut mask_vec = tokens.get_attention_mask().to_vec();
+                if mask_vec.len() > self.max_len {
+                    mask_vec.truncate(self.max_len);
+                }
+                let mask = Tensor::new(mask_vec.as_slice(), &self.device)
+                    .map_err(EmbedError::tensor_shape)?;
+                let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
+                let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
+                model
+                    .forward(&token_ids, &mask, &token_type_ids, None, None, None)
+                    .map_err(EmbedError::model_forward)?
+            }
         };
 
         let embedding = Self::pooling(embeddings, self.pooling)?;
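
The zeroed token_type_ids reflect that XLM-RoBERTa, unlike BERT, is trained on a single segment (published configs commonly set type_vocab_size to 1), so every position carries segment id 0. A minimal sketch of the same tensor preparation with illustrative values (the token ids and window size below are made up for the example):

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Illustrative batch of one 4-token sequence.
    let token_ids = Tensor::new(&[[0u32, 581, 1010, 2]], &device)?;
    let mut mask_vec: Vec<u32> = vec![1, 1, 1, 1];
    let max_len = 512; // illustrative model window
    if mask_vec.len() > max_len {
        mask_vec.truncate(max_len); // clamp the mask to the model's window
    }
    // Lift the mask to a [1, seq_len] batch so it lines up with token_ids.
    let mask = Tensor::stack(&[Tensor::new(mask_vec.as_slice(), &device)?], 0)?;
    // Single-segment input: token types are all zero, shaped like token_ids.
    let token_type_ids = token_ids.zeros_like()?;
    assert_eq!(mask.dims2()?, (1, 4));
    assert_eq!(token_type_ids.dims2()?, (1, 4));
    Ok(())
}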