Support ModernBERT architecture in Hugging Face embedder

Hayato Sakaguchi
2025-11-08 20:53:47 +09:00
parent a9d6e86077
commit d6eca83cfa
2 changed files with 87 additions and 21 deletions

View File

@@ -1,6 +1,7 @@
 use candle_core::Tensor;
 use candle_nn::VarBuilder;
-use candle_transformers::models::bert::{BertModel, Config, DTYPE};
+use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
+use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
 // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
 use hf_hub::api::sync::Api;
 use hf_hub::{Repo, RepoType};
@@ -84,14 +85,21 @@ impl Default for EmbedderOptions {
     }
 }

+enum ModelKind {
+    Bert(BertModel),
+    Modern(ModernBert),
+}
+
 /// Perform embedding of documents and queries
 pub struct Embedder {
-    model: BertModel,
+    model: ModelKind,
     tokenizer: Tokenizer,
     options: EmbedderOptions,
     dimensions: usize,
     pooling: Pooling,
     cache: EmbeddingCache,
+    device: candle_core::Device,
+    max_len: usize,
 }

 impl std::fmt::Debug for Embedder {
@@ -220,16 +228,34 @@ impl Embedder {
             (config, tokenizer, weights, source, pooling)
         };

-        let config = std::fs::read_to_string(&config_filename)
+        let config_str = std::fs::read_to_string(&config_filename)
             .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
-        let config: Config = serde_json::from_str(&config).map_err(|inner| {
-            NewEmbedderError::deserialize_config(
+        let cfg_val: serde_json::Value = match serde_json::from_str(&config_str) {
+            Ok(v) => v,
+            Err(inner) => {
+                return Err(NewEmbedderError::deserialize_config(
                     options.model.clone(),
-                config,
-                config_filename,
+                    config_str.clone(),
+                    config_filename.clone(),
                     inner,
-            )
-        })?;
+                ));
+            }
+        };
+
+        let model_type = cfg_val.get("model_type").and_then(|v| v.as_str()).unwrap_or_default();
+        let arch_arr = cfg_val.get("architectures").and_then(|v| v.as_array());
+        let has_arch = |needle: &str| {
+            model_type.eq_ignore_ascii_case(needle)
+                || arch_arr.is_some_and(|arr| {
+                    arr.iter().filter_map(|v| v.as_str()).any(|s| s.to_lowercase().contains(needle))
+                })
+        };
+        let is_modern = has_arch("modernbert");
+        tracing::debug!(is_modern, model_type, "detected HF architecture");
+        // default to BERT otherwise
+
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
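
The detection keys off two optional fields of the model's config.json: "model_type" and "architectures". A standalone sketch of the same rule, runnable with only the serde_json crate (the config values below are illustrative, not taken from a real model card):

// Sketch: classify a config.json the same way the patch does.
fn main() {
    let config_str = r#"{
        "model_type": "modernbert",
        "architectures": ["ModernBertModel"],
        "max_position_embeddings": 8192
    }"#;
    let cfg_val: serde_json::Value = serde_json::from_str(config_str).unwrap();
    let model_type = cfg_val.get("model_type").and_then(|v| v.as_str()).unwrap_or_default();
    let arch_arr = cfg_val.get("architectures").and_then(|v| v.as_array());
    let has_arch = |needle: &str| {
        model_type.eq_ignore_ascii_case(needle)
            || arch_arr.is_some_and(|arr| {
                arr.iter().filter_map(|v| v.as_str()).any(|s| s.to_lowercase().contains(needle))
            })
    };
    // Matches either model_type "modernbert" or any architecture string
    // containing "modernbert", case-insensitively.
    assert!(has_arch("modernbert"));
}

Hugging Face configs typically spell the architecture as "ModernBertModel", which is why the check lowercases and substring-matches rather than comparing for equality.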
@@ -244,7 +270,31 @@ impl Embedder {
         tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");

-        let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
+        // max length from config, fallback to 512
+        let max_len =
+            cfg_val.get("max_position_embeddings").and_then(|v| v.as_u64()).unwrap_or(512) as usize;
+
+        let model = if is_modern {
+            let config: ModernConfig = serde_json::from_str(&config_str).map_err(|inner| {
+                NewEmbedderError::deserialize_config(
+                    options.model.clone(),
+                    config_str.clone(),
+                    config_filename.clone(),
+                    inner,
+                )
+            })?;
+            ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
+        } else {
+            let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
+                NewEmbedderError::deserialize_config(
+                    options.model.clone(),
+                    config_str.clone(),
+                    config_filename.clone(),
+                    inner,
+                )
+            })?;
+            ModelKind::Bert(BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?)
+        };

         if let Some(pp) = tokenizer.get_padding_mut() {
             pp.strategy = tokenizers::PaddingStrategy::BatchLongest
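
Note the two-phase parse above: the raw JSON is read once into an untyped serde_json::Value to pick the architecture and the max length, then the same string is re-parsed into the matching typed candle config. A sketch of the pattern in isolation, where DemoConfig is a hypothetical stand-in for candle's BertConfig or ModernConfig (assumes the serde and serde_json crates):

use serde::Deserialize;

// Hypothetical stand-in for a typed model config.
#[derive(Deserialize)]
struct DemoConfig {
    hidden_size: usize,
}

fn load(config_str: &str) -> serde_json::Result<(DemoConfig, usize)> {
    // First pass: untyped probe for the routing fields.
    let probe: serde_json::Value = serde_json::from_str(config_str)?;
    // Max length from config, falling back to the classic BERT limit of 512.
    let max_len =
        probe.get("max_position_embeddings").and_then(|v| v.as_u64()).unwrap_or(512) as usize;
    // Second pass: deserialize into the concrete type selected by the probe.
    let config: DemoConfig = serde_json::from_str(config_str)?;
    Ok((config, max_len))
}

Parsing the string twice keeps detection decoupled from the typed configs, and the cost is negligible since this runs once per embedder construction.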
@@ -263,6 +313,8 @@ impl Embedder {
             dimensions: 0,
             pooling,
             cache: EmbeddingCache::new(cache_cap),
+            device,
+            max_len,
         };

         let embeddings = this
@@ -321,15 +373,29 @@ impl Embedder {
     pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
         let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
         let token_ids = tokens.get_ids();
-        let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
         let token_ids =
-            Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
+            if token_ids.len() > self.max_len { &token_ids[..self.max_len] } else { token_ids };
+        let token_ids = Tensor::new(token_ids, &self.device).map_err(EmbedError::tensor_shape)?;
         let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
+        let embeddings = match &self.model {
+            ModelKind::Bert(model) => {
                 let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
-        let embeddings = self
-            .model
+                model
                     .forward(&token_ids, &token_type_ids, None)
-            .map_err(EmbedError::model_forward)?;
+                    .map_err(EmbedError::model_forward)?
+            }
+            ModelKind::Modern(model) => {
+                let mut mask_vec = tokens.get_attention_mask().to_vec();
+                if mask_vec.len() > self.max_len {
+                    mask_vec.truncate(self.max_len);
+                }
+                let mask = Tensor::new(mask_vec.as_slice(), &self.device)
+                    .map_err(EmbedError::tensor_shape)?;
+                let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
+                model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
+            }
+        };
         let embedding = Self::pooling(embeddings, self.pooling)?;
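
In the Modern arm, the token ids and the attention mask are truncated to self.max_len independently, and both must end up with the same length before Tensor::stack, or the forward pass would see mismatched shapes. A hypothetical helper, not part of this patch, that makes the invariant explicit:

// Hypothetical helper: cut token ids and attention mask together so
// they always stay the same length.
fn truncate_aligned(ids: &[u32], mask: &[u32], max_len: usize) -> (Vec<u32>, Vec<u32>) {
    debug_assert_eq!(ids.len(), mask.len());
    let n = ids.len().min(max_len);
    (ids[..n].to_vec(), mask[..n].to_vec())
}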

View File

@@ -550,9 +550,9 @@ pub struct DeserializePoolingConfig {
 #[derive(Debug, thiserror::Error)]
 #[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
     if architectures.is_empty() {
-        "\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
+        "\n - Note: only models with architecture \"BertModel\" or \"ModernBert\" are supported.".to_string()
     } else {
-        format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
+        format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` or `\"ModernBert\"` are supported.")
     })]
 pub struct UnsupportedModel {
     pub model_name: String,