mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-12-08 21:55:42 +00:00
add support of models XLMRoberta
This commit is contained in:
committed by
Louis Dureuil
parent
57b94b411f
commit
fe46af7ded
@@ -2,6 +2,7 @@ use candle_core::Tensor;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
|
||||
use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
|
||||
use candle_transformers::models::xlm_roberta::{Config as XlmRobertaConfig, XLMRobertaModel};
|
||||
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
@@ -89,6 +90,7 @@ impl Default for EmbedderOptions {
|
||||
enum ModelKind {
|
||||
Bert(BertModel),
|
||||
Modern(ModernBert),
|
||||
XlmRoberta(XLMRobertaModel),
|
||||
}
|
||||
|
||||
/// Perform embedding of documents and queries
|
||||
@@ -304,7 +306,8 @@ impl Embedder {
|
||||
};
|
||||
|
||||
let is_modern = has_arch("modernbert");
|
||||
tracing::debug!(is_modern, model_type, "detected HF architecture");
|
||||
let is_xlm_roberta = has_arch("xlm-roberta") || has_arch("xlm_roberta");
|
||||
tracing::debug!(is_modern, is_xlm_roberta, model_type, "detected HF architecture");
|
||||
|
||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||
@@ -340,6 +343,19 @@ impl Embedder {
|
||||
)
|
||||
})?;
|
||||
ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
|
||||
} else if is_xlm_roberta {
|
||||
let config: XlmRobertaConfig =
|
||||
serde_json::from_str(&config_str).map_err(|inner| {
|
||||
NewEmbedderError::deserialize_config(
|
||||
options.model.clone(),
|
||||
config_str.clone(),
|
||||
config_filename.clone(),
|
||||
inner,
|
||||
)
|
||||
})?;
|
||||
ModelKind::XlmRoberta(
|
||||
XLMRobertaModel::new(&config, vb).map_err(NewEmbedderError::load_model)?,
|
||||
)
|
||||
} else {
|
||||
let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
|
||||
NewEmbedderError::deserialize_config(
|
||||
@@ -451,6 +467,19 @@ impl Embedder {
|
||||
let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
|
||||
model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
|
||||
}
|
||||
ModelKind::XlmRoberta(model) => {
|
||||
let mut mask_vec = tokens.get_attention_mask().to_vec();
|
||||
if mask_vec.len() > self.max_len {
|
||||
mask_vec.truncate(self.max_len);
|
||||
}
|
||||
let mask = Tensor::new(mask_vec.as_slice(), &self.device)
|
||||
.map_err(EmbedError::tensor_shape)?;
|
||||
let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
|
||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||
model
|
||||
.forward(&token_ids, &mask, &token_type_ids, None, None, None)
|
||||
.map_err(EmbedError::model_forward)?
|
||||
}
|
||||
};
|
||||
|
||||
let embedding = Self::pooling(embeddings, self.pooling)?;
|
||||
|
||||
Reference in New Issue
Block a user