mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-12-14 16:36:57 +00:00
add support of models XLMRoberta
This commit is contained in:
committed by
Louis Dureuil
parent
57b94b411f
commit
fe46af7ded
@@ -2,6 +2,7 @@ use candle_core::Tensor;
|
|||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
|
use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
|
||||||
use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
|
use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
|
||||||
|
use candle_transformers::models::xlm_roberta::{Config as XlmRobertaConfig, XLMRobertaModel};
|
||||||
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
use hf_hub::{Repo, RepoType};
|
use hf_hub::{Repo, RepoType};
|
||||||
@@ -89,6 +90,7 @@ impl Default for EmbedderOptions {
|
|||||||
enum ModelKind {
|
enum ModelKind {
|
||||||
Bert(BertModel),
|
Bert(BertModel),
|
||||||
Modern(ModernBert),
|
Modern(ModernBert),
|
||||||
|
XlmRoberta(XLMRobertaModel),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Perform embedding of documents and queries
|
/// Perform embedding of documents and queries
|
||||||
@@ -304,7 +306,8 @@ impl Embedder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let is_modern = has_arch("modernbert");
|
let is_modern = has_arch("modernbert");
|
||||||
tracing::debug!(is_modern, model_type, "detected HF architecture");
|
let is_xlm_roberta = has_arch("xlm-roberta") || has_arch("xlm_roberta");
|
||||||
|
tracing::debug!(is_modern, is_xlm_roberta, model_type, "detected HF architecture");
|
||||||
|
|
||||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||||
@@ -340,6 +343,19 @@ impl Embedder {
|
|||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
|
ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
|
||||||
|
} else if is_xlm_roberta {
|
||||||
|
let config: XlmRobertaConfig =
|
||||||
|
serde_json::from_str(&config_str).map_err(|inner| {
|
||||||
|
NewEmbedderError::deserialize_config(
|
||||||
|
options.model.clone(),
|
||||||
|
config_str.clone(),
|
||||||
|
config_filename.clone(),
|
||||||
|
inner,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
ModelKind::XlmRoberta(
|
||||||
|
XLMRobertaModel::new(&config, vb).map_err(NewEmbedderError::load_model)?,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
|
let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
|
||||||
NewEmbedderError::deserialize_config(
|
NewEmbedderError::deserialize_config(
|
||||||
@@ -451,6 +467,19 @@ impl Embedder {
|
|||||||
let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
|
let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
|
||||||
model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
|
model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
|
||||||
}
|
}
|
||||||
|
ModelKind::XlmRoberta(model) => {
|
||||||
|
let mut mask_vec = tokens.get_attention_mask().to_vec();
|
||||||
|
if mask_vec.len() > self.max_len {
|
||||||
|
mask_vec.truncate(self.max_len);
|
||||||
|
}
|
||||||
|
let mask = Tensor::new(mask_vec.as_slice(), &self.device)
|
||||||
|
.map_err(EmbedError::tensor_shape)?;
|
||||||
|
let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
|
||||||
|
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||||
|
model
|
||||||
|
.forward(&token_ids, &mask, &token_type_ids, None, None, None)
|
||||||
|
.map_err(EmbedError::model_forward)?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let embedding = Self::pooling(embeddings, self.pooling)?;
|
let embedding = Self::pooling(embeddings, self.pooling)?;
|
||||||
|
|||||||
Reference in New Issue
Block a user