diff --git a/crates/milli/src/vector/embedder/hf.rs b/crates/milli/src/vector/embedder/hf.rs index 13ec46cca..b5a3f5c45 100644 --- a/crates/milli/src/vector/embedder/hf.rs +++ b/crates/milli/src/vector/embedder/hf.rs @@ -1,6 +1,7 @@ use candle_core::Tensor; use candle_nn::VarBuilder; -use candle_transformers::models::bert::{BertModel, Config, DTYPE}; +use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE}; +use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert}; // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself use hf_hub::api::sync::Api; use hf_hub::{Repo, RepoType}; @@ -84,14 +85,21 @@ impl Default for EmbedderOptions { } } +enum ModelKind { + Bert(BertModel), + Modern(ModernBert), +} + /// Perform embedding of documents and queries pub struct Embedder { - model: BertModel, + model: ModelKind, tokenizer: Tokenizer, options: EmbedderOptions, dimensions: usize, pooling: Pooling, cache: EmbeddingCache, + device: candle_core::Device, + max_len: usize, } impl std::fmt::Debug for Embedder { @@ -220,16 +228,34 @@ impl Embedder { (config, tokenizer, weights, source, pooling) }; - let config = std::fs::read_to_string(&config_filename) + let config_str = std::fs::read_to_string(&config_filename) .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; - let config: Config = serde_json::from_str(&config).map_err(|inner| { - NewEmbedderError::deserialize_config( - options.model.clone(), - config, - config_filename, - inner, - ) - })?; + + let cfg_val: serde_json::Value = match serde_json::from_str(&config_str) { + Ok(v) => v, + Err(inner) => { + return Err(NewEmbedderError::deserialize_config( + options.model.clone(), + config_str.clone(), + config_filename.clone(), + inner, + )); + } + }; + + let model_type = cfg_val.get("model_type").and_then(|v| v.as_str()).unwrap_or_default(); + let arch_arr = cfg_val.get("architectures").and_then(|v| v.as_array()); + let has_arch = |needle: &str| { + model_type.eq_ignore_ascii_case(needle) + || arch_arr.is_some_and(|arr| { + arr.iter().filter_map(|v| v.as_str()).any(|s| s.to_lowercase().contains(needle)) + }) + }; + + let is_modern = has_arch("modernbert"); + tracing::debug!(is_modern, model_type, "detected HF architecture"); + // default to BERT otherwise + let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; @@ -244,7 +270,31 @@ impl Embedder { tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config"); - let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; + // max length from config, fallback to 512 + let max_len = + cfg_val.get("max_position_embeddings").and_then(|v| v.as_u64()).unwrap_or(512) as usize; + + let model = if is_modern { + let config: ModernConfig = serde_json::from_str(&config_str).map_err(|inner| { + NewEmbedderError::deserialize_config( + options.model.clone(), + config_str.clone(), + config_filename.clone(), + inner, + ) + })?; + ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?) + } else { + let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| { + NewEmbedderError::deserialize_config( + options.model.clone(), + config_str.clone(), + config_filename.clone(), + inner, + ) + })?; + ModelKind::Bert(BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?) + }; if let Some(pp) = tokenizer.get_padding_mut() { pp.strategy = tokenizers::PaddingStrategy::BatchLongest @@ -263,6 +313,8 @@ impl Embedder { dimensions: 0, pooling, cache: EmbeddingCache::new(cache_cap), + device, + max_len, }; let embeddings = this @@ -321,15 +373,29 @@ impl Embedder { pub fn embed_one(&self, text: &str) -> std::result::Result { let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?; let token_ids = tokens.get_ids(); - let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids }; let token_ids = - Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?; + if token_ids.len() > self.max_len { &token_ids[..self.max_len] } else { token_ids }; + let token_ids = Tensor::new(token_ids, &self.device).map_err(EmbedError::tensor_shape)?; let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?; - let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; - let embeddings = self - .model - .forward(&token_ids, &token_type_ids, None) - .map_err(EmbedError::model_forward)?; + + let embeddings = match &self.model { + ModelKind::Bert(model) => { + let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; + model + .forward(&token_ids, &token_type_ids, None) + .map_err(EmbedError::model_forward)? + } + ModelKind::Modern(model) => { + let mut mask_vec = tokens.get_attention_mask().to_vec(); + if mask_vec.len() > self.max_len { + mask_vec.truncate(self.max_len); + } + let mask = Tensor::new(mask_vec.as_slice(), &self.device) + .map_err(EmbedError::tensor_shape)?; + let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?; + model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)? + } + }; let embedding = Self::pooling(embeddings, self.pooling)?; diff --git a/crates/milli/src/vector/error.rs b/crates/milli/src/vector/error.rs index b4b90b24b..f65d69057 100644 --- a/crates/milli/src/vector/error.rs +++ b/crates/milli/src/vector/error.rs @@ -550,9 +550,9 @@ pub struct DeserializePoolingConfig { #[derive(Debug, thiserror::Error)] #[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}", if architectures.is_empty() { - "\n - Note: only models with architecture \"BertModel\" are supported.".to_string() + "\n - Note: only models with architecture \"BertModel\" or \"ModernBert\" are supported.".to_string() } else { - format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.") + format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` or `\"ModernBert\"` are supported.") })] pub struct UnsupportedModel { pub model_name: String,