mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-11-22 04:36:32 +00:00
Support modernbert architecture in hugging face embedder
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
use candle_core::Tensor;
|
use candle_core::Tensor;
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
|
||||||
|
use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
|
||||||
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
use hf_hub::{Repo, RepoType};
|
use hf_hub::{Repo, RepoType};
|
||||||
@@ -84,14 +85,21 @@ impl Default for EmbedderOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Model held by the embedder: either of the two architectures the loader can build.
enum ModelKind {
|
||||||
|
/// BERT model (`candle_transformers::models::bert::BertModel`).
Bert(BertModel),
|
||||||
|
/// ModernBERT model (`candle_transformers::models::modernbert::ModernBert`).
Modern(ModernBert),
|
||||||
|
}
|
||||||
|
|
||||||
/// Perform embedding of documents and queries
|
/// Perform embedding of documents and queries
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
model: BertModel,
|
/// Loaded model; the variant (BERT or ModernBERT) is selected at construction time
/// from the architecture detected in the model's config file.
model: ModelKind,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
options: EmbedderOptions,
|
options: EmbedderOptions,
|
||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
pooling: Pooling,
|
pooling: Pooling,
|
||||||
cache: EmbeddingCache,
|
cache: EmbeddingCache,
|
||||||
|
/// Device on which `embed_one` creates its input tensors (token ids and attention mask).
device: candle_core::Device,
|
||||||
|
/// Maximum number of token ids kept per input; `embed_one` truncates longer inputs
/// to this length. Read from `max_position_embeddings` in the config, falling back to 512.
max_len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for Embedder {
|
impl std::fmt::Debug for Embedder {
|
||||||
@@ -220,16 +228,34 @@ impl Embedder {
|
|||||||
(config, tokenizer, weights, source, pooling)
|
(config, tokenizer, weights, source, pooling)
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = std::fs::read_to_string(&config_filename)
|
let config_str = std::fs::read_to_string(&config_filename)
|
||||||
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
|
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
|
||||||
let config: Config = serde_json::from_str(&config).map_err(|inner| {
|
|
||||||
NewEmbedderError::deserialize_config(
|
let cfg_val: serde_json::Value = match serde_json::from_str(&config_str) {
|
||||||
options.model.clone(),
|
Ok(v) => v,
|
||||||
config,
|
Err(inner) => {
|
||||||
config_filename,
|
return Err(NewEmbedderError::deserialize_config(
|
||||||
inner,
|
options.model.clone(),
|
||||||
)
|
config_str.clone(),
|
||||||
})?;
|
config_filename.clone(),
|
||||||
|
inner,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_type = cfg_val.get("model_type").and_then(|v| v.as_str()).unwrap_or_default();
|
||||||
|
let arch_arr = cfg_val.get("architectures").and_then(|v| v.as_array());
|
||||||
|
let has_arch = |needle: &str| {
|
||||||
|
model_type.eq_ignore_ascii_case(needle)
|
||||||
|
|| arch_arr.is_some_and(|arr| {
|
||||||
|
arr.iter().filter_map(|v| v.as_str()).any(|s| s.to_lowercase().contains(needle))
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
let is_modern = has_arch("modernbert");
|
||||||
|
tracing::debug!(is_modern, model_type, "detected HF architecture");
|
||||||
|
// default to BERT otherwise
|
||||||
|
|
||||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||||
|
|
||||||
@@ -244,7 +270,31 @@ impl Embedder {
|
|||||||
|
|
||||||
tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
|
tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
|
||||||
|
|
||||||
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
|
// max length from config, fallback to 512
|
||||||
|
let max_len =
|
||||||
|
cfg_val.get("max_position_embeddings").and_then(|v| v.as_u64()).unwrap_or(512) as usize;
|
||||||
|
|
||||||
|
let model = if is_modern {
|
||||||
|
let config: ModernConfig = serde_json::from_str(&config_str).map_err(|inner| {
|
||||||
|
NewEmbedderError::deserialize_config(
|
||||||
|
options.model.clone(),
|
||||||
|
config_str.clone(),
|
||||||
|
config_filename.clone(),
|
||||||
|
inner,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
|
||||||
|
} else {
|
||||||
|
let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
|
||||||
|
NewEmbedderError::deserialize_config(
|
||||||
|
options.model.clone(),
|
||||||
|
config_str.clone(),
|
||||||
|
config_filename.clone(),
|
||||||
|
inner,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
ModelKind::Bert(BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?)
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(pp) = tokenizer.get_padding_mut() {
|
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||||
@@ -263,6 +313,8 @@ impl Embedder {
|
|||||||
dimensions: 0,
|
dimensions: 0,
|
||||||
pooling,
|
pooling,
|
||||||
cache: EmbeddingCache::new(cache_cap),
|
cache: EmbeddingCache::new(cache_cap),
|
||||||
|
device,
|
||||||
|
max_len,
|
||||||
};
|
};
|
||||||
|
|
||||||
let embeddings = this
|
let embeddings = this
|
||||||
@@ -321,15 +373,29 @@ impl Embedder {
|
|||||||
pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
|
pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
|
||||||
let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
|
let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
|
||||||
let token_ids = tokens.get_ids();
|
let token_ids = tokens.get_ids();
|
||||||
let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
|
|
||||||
let token_ids =
|
let token_ids =
|
||||||
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
if token_ids.len() > self.max_len { &token_ids[..self.max_len] } else { token_ids };
|
||||||
|
let token_ids = Tensor::new(token_ids, &self.device).map_err(EmbedError::tensor_shape)?;
|
||||||
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
|
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
|
||||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
|
||||||
let embeddings = self
|
let embeddings = match &self.model {
|
||||||
.model
|
ModelKind::Bert(model) => {
|
||||||
.forward(&token_ids, &token_type_ids, None)
|
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||||
.map_err(EmbedError::model_forward)?;
|
model
|
||||||
|
.forward(&token_ids, &token_type_ids, None)
|
||||||
|
.map_err(EmbedError::model_forward)?
|
||||||
|
}
|
||||||
|
ModelKind::Modern(model) => {
|
||||||
|
let mut mask_vec = tokens.get_attention_mask().to_vec();
|
||||||
|
if mask_vec.len() > self.max_len {
|
||||||
|
mask_vec.truncate(self.max_len);
|
||||||
|
}
|
||||||
|
let mask = Tensor::new(mask_vec.as_slice(), &self.device)
|
||||||
|
.map_err(EmbedError::tensor_shape)?;
|
||||||
|
let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
|
||||||
|
model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let embedding = Self::pooling(embeddings, self.pooling)?;
|
let embedding = Self::pooling(embeddings, self.pooling)?;
|
||||||
|
|
||||||
|
|||||||
@@ -550,9 +550,9 @@ pub struct DeserializePoolingConfig {
|
|||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
|
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
|
||||||
if architectures.is_empty() {
|
if architectures.is_empty() {
|
||||||
"\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
|
"\n - Note: only models with architecture \"BertModel\" or \"ModernBert\" are supported.".to_string()
|
||||||
} else {
|
} else {
|
||||||
format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
|
format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` or `\"ModernBert\"` are supported.")
|
||||||
})]
|
})]
|
||||||
pub struct UnsupportedModel {
|
pub struct UnsupportedModel {
|
||||||
pub model_name: String,
|
pub model_name: String,
|
||||||
|
|||||||
Reference in New Issue
Block a user