mirror of https://github.com/meilisearch/meilisearch.git

Support modernbert architecture in hugging face embedder
@@ -1,6 +1,7 @@
 use candle_core::Tensor;
 use candle_nn::VarBuilder;
-use candle_transformers::models::bert::{BertModel, Config, DTYPE};
+use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
+use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
 // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
 use hf_hub::api::sync::Api;
 use hf_hub::{Repo, RepoType};
@@ -84,14 +85,21 @@ impl Default for EmbedderOptions {
     }
 }
 
+enum ModelKind {
+    Bert(BertModel),
+    Modern(ModernBert),
+}
+
 /// Perform embedding of documents and queries
 pub struct Embedder {
-    model: BertModel,
+    model: ModelKind,
     tokenizer: Tokenizer,
     options: EmbedderOptions,
     dimensions: usize,
     pooling: Pooling,
     cache: EmbeddingCache,
+    device: candle_core::Device,
+    max_len: usize,
 }
 
 impl std::fmt::Debug for Embedder {
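A minimal, self-contained sketch of the enum-dispatch pattern this hunk introduces, with hypothetical stand-in types (the real variants wrap candle's BertModel and ModernBert). An enum fits here where a single trait object would be awkward, because the two forward() signatures differ: BERT takes token-type ids, ModernBert takes an attention mask.

    // Stand-ins for the candle models; hypothetical, for illustration only.
    struct Bert;
    struct Modern;

    enum ModelKind {
        Bert(Bert),
        Modern(Modern),
    }

    impl Bert {
        // BERT-style forward: token ids plus token-type ids.
        fn forward(&self, ids: &[u32], token_type_ids: &[u32]) -> usize {
            ids.len() + token_type_ids.len()
        }
    }

    impl Modern {
        // ModernBert-style forward: token ids plus an attention mask.
        fn forward(&self, ids: &[u32], mask: &[u32]) -> usize {
            ids.len() + mask.len()
        }
    }

    // A single match site supplies each model the arguments it expects.
    fn embed(model: &ModelKind, ids: &[u32]) -> usize {
        match model {
            ModelKind::Bert(m) => m.forward(ids, &vec![0; ids.len()]),
            ModelKind::Modern(m) => m.forward(ids, &vec![1; ids.len()]),
        }
    }

    fn main() {
        assert_eq!(embed(&ModelKind::Bert(Bert), &[1, 2, 3]), 6);
        assert_eq!(embed(&ModelKind::Modern(Modern), &[1, 2, 3]), 6);
    }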
@@ -220,16 +228,34 @@ impl Embedder {
             (config, tokenizer, weights, source, pooling)
         };
 
-        let config = std::fs::read_to_string(&config_filename)
+        let config_str = std::fs::read_to_string(&config_filename)
             .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
-        let config: Config = serde_json::from_str(&config).map_err(|inner| {
-            NewEmbedderError::deserialize_config(
-                options.model.clone(),
-                config,
-                config_filename,
-                inner,
-            )
-        })?;
+
+        let cfg_val: serde_json::Value = match serde_json::from_str(&config_str) {
+            Ok(v) => v,
+            Err(inner) => {
+                return Err(NewEmbedderError::deserialize_config(
+                    options.model.clone(),
+                    config_str.clone(),
+                    config_filename.clone(),
+                    inner,
+                ));
+            }
+        };
+
+        let model_type = cfg_val.get("model_type").and_then(|v| v.as_str()).unwrap_or_default();
+        let arch_arr = cfg_val.get("architectures").and_then(|v| v.as_array());
+        let has_arch = |needle: &str| {
+            model_type.eq_ignore_ascii_case(needle)
+                || arch_arr.is_some_and(|arr| {
+                    arr.iter().filter_map(|v| v.as_str()).any(|s| s.to_lowercase().contains(needle))
+                })
+        };
+
+        let is_modern = has_arch("modernbert");
+        tracing::debug!(is_modern, model_type, "detected HF architecture");
+        // default to BERT otherwise
+
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
 
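The detection above sniffs only config.json. A standalone sketch of the same logic, assuming serde_json as the sole dependency; the config excerpt is made up, though real ModernBERT configs carry these fields:

    fn main() {
        // Hypothetical config.json excerpt.
        let config_str = r#"{
            "model_type": "modernbert",
            "architectures": ["ModernBertModel"],
            "max_position_embeddings": 8192
        }"#;

        let cfg_val: serde_json::Value = serde_json::from_str(config_str).unwrap();

        // Same shape as the closure in the hunk above.
        let model_type = cfg_val.get("model_type").and_then(|v| v.as_str()).unwrap_or_default();
        let arch_arr = cfg_val.get("architectures").and_then(|v| v.as_array());
        let has_arch = |needle: &str| {
            model_type.eq_ignore_ascii_case(needle)
                || arch_arr.is_some_and(|arr| {
                    arr.iter().filter_map(|v| v.as_str()).any(|s| s.to_lowercase().contains(needle))
                })
        };

        assert!(has_arch("modernbert")); // matches both fields here
        assert!(!has_arch("roberta"));   // unknown names fall through to BERT
    }

Note the loose substring match: any architecture string containing "modernbert", compared case-insensitively, selects the ModernBert path; everything else defaults to BERT.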
@@ -244,7 +270,31 @@ impl Embedder {
 
         tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
 
-        let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
+        // max length from config, fallback to 512
+        let max_len =
+            cfg_val.get("max_position_embeddings").and_then(|v| v.as_u64()).unwrap_or(512) as usize;
+
+        let model = if is_modern {
+            let config: ModernConfig = serde_json::from_str(&config_str).map_err(|inner| {
+                NewEmbedderError::deserialize_config(
+                    options.model.clone(),
+                    config_str.clone(),
+                    config_filename.clone(),
+                    inner,
+                )
+            })?;
+            ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
+        } else {
+            let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
+                NewEmbedderError::deserialize_config(
+                    options.model.clone(),
+                    config_str.clone(),
+                    config_filename.clone(),
+                    inner,
+                )
+            })?;
+            ModelKind::Bert(BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?)
+        };
 
         if let Some(pp) = tokenizer.get_padding_mut() {
             pp.strategy = tokenizers::PaddingStrategy::BatchLongest
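This hunk ends up parsing config.json twice: first as an untyped serde_json::Value to pick the architecture and read max_position_embeddings, then again into the concrete candle config type. A hedged sketch of that two-phase pattern; BertLikeConfig is a hypothetical stand-in whose field names mirror Hugging Face config.json keys:

    use serde::Deserialize;

    // Hypothetical stand-in for candle's typed configs.
    #[derive(Debug, Deserialize)]
    struct BertLikeConfig {
        hidden_size: usize,
        #[serde(default = "default_max_pos")]
        max_position_embeddings: usize,
    }

    // Same fallback as the hunk above: BERT's classic 512-token context.
    fn default_max_pos() -> usize {
        512
    }

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let config_str = r#"{ "model_type": "bert", "hidden_size": 384 }"#;

        // Phase 1: untyped sniffing of the architecture.
        let cfg_val: serde_json::Value = serde_json::from_str(config_str)?;
        let is_modern =
            cfg_val.get("model_type").and_then(|v| v.as_str()) == Some("modernbert");

        // Phase 2: typed parse of the very same string.
        let config: BertLikeConfig = serde_json::from_str(config_str)?;
        assert!(!is_modern);
        assert_eq!(config.max_position_embeddings, 512);
        println!("{config:?}");
        Ok(())
    }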
@@ -263,6 +313,8 @@ impl Embedder {
             dimensions: 0,
             pooling,
             cache: EmbeddingCache::new(cache_cap),
+            device,
+            max_len,
         };
 
         let embeddings = this
@@ -321,15 +373,29 @@ impl Embedder {
     pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
         let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
         let token_ids = tokens.get_ids();
-        let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
-        let token_ids =
-            Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
+        let token_ids =
+            if token_ids.len() > self.max_len { &token_ids[..self.max_len] } else { token_ids };
+        let token_ids = Tensor::new(token_ids, &self.device).map_err(EmbedError::tensor_shape)?;
         let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
-        let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
-        let embeddings = self
-            .model
-            .forward(&token_ids, &token_type_ids, None)
-            .map_err(EmbedError::model_forward)?;
+
+        let embeddings = match &self.model {
+            ModelKind::Bert(model) => {
+                let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
+                model
+                    .forward(&token_ids, &token_type_ids, None)
+                    .map_err(EmbedError::model_forward)?
+            }
+            ModelKind::Modern(model) => {
+                let mut mask_vec = tokens.get_attention_mask().to_vec();
+                if mask_vec.len() > self.max_len {
+                    mask_vec.truncate(self.max_len);
+                }
+                let mask = Tensor::new(mask_vec.as_slice(), &self.device)
+                    .map_err(EmbedError::tensor_shape)?;
+                let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
+                model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
+            }
+        };
+
         let embedding = Self::pooling(embeddings, self.pooling)?;
 
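In the ModernBert arm above, the token ids and the attention mask must be truncated to the same max_len and both given a leading batch dimension before forward(). A minimal sketch of just that tensor shaping, assuming candle_core; the prepare helper is made up:

    use candle_core::{Device, Result, Tensor};

    // Truncate ids and mask to the same max_len, then add a batch dimension.
    fn prepare(ids: &[u32], mask: &[u32], max_len: usize, device: &Device) -> Result<(Tensor, Tensor)> {
        let ids = if ids.len() > max_len { &ids[..max_len] } else { ids };
        let mut mask_vec = mask.to_vec();
        mask_vec.truncate(max_len); // no-op when the mask is already short enough

        let ids = Tensor::new(ids, device)?; // shape: [seq_len]
        let ids = Tensor::stack(&[ids], 0)?; // shape: [1, seq_len]
        let mask = Tensor::new(mask_vec.as_slice(), device)?;
        let mask = Tensor::stack(&[mask], 0)?;
        Ok((ids, mask))
    }

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let (ids, mask) = prepare(&[101, 2054, 2003, 102], &[1, 1, 1, 1], 2, &device)?;
        assert_eq!(ids.dims(), &[1, 2]);
        assert_eq!(mask.dims(), &[1, 2]);
        Ok(())
    }

Tensor::stack(&[t], 0) is what turns a [seq_len] tensor into the [1, seq_len] batch-first layout both forward() paths expect.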
@@ -550,9 +550,9 @@ pub struct DeserializePoolingConfig {
 #[derive(Debug, thiserror::Error)]
 #[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
     if architectures.is_empty() {
-        "\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
+        "\n - Note: only models with architecture \"BertModel\" or \"ModernBert\" are supported.".to_string()
    } else {
-        format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
+        format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` or `\"ModernBert\"` are supported.")
    })]
 pub struct UnsupportedModel {
     pub model_name: String,