Merge pull request #5980 from hayatosc/feat/hugging-face-modernbert

Support ModernBERT architecture on `huggingface` embedder
This commit is contained in:
Louis Dureuil
2025-11-10 18:03:35 +00:00
committed by GitHub
4 changed files with 216 additions and 132 deletions

181
Cargo.lock generated
View File

@@ -310,6 +310,7 @@ dependencies = [
"const-random", "const-random",
"getrandom 0.3.3", "getrandom 0.3.3",
"once_cell", "once_cell",
"serde",
"version_check", "version_check",
"zerocopy", "zerocopy",
] ]
@@ -492,7 +493,7 @@ dependencies = [
"backoff", "backoff",
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
"derive_builder 0.20.2", "derive_builder",
"eventsource-stream", "eventsource-stream",
"futures", "futures",
"rand 0.8.5", "rand 0.8.5",
@@ -945,7 +946,7 @@ dependencies = [
"rand 0.9.2", "rand 0.9.2",
"rand_distr", "rand_distr",
"rayon", "rayon",
"safetensors", "safetensors 0.4.5",
"thiserror 1.0.69", "thiserror 1.0.69",
"ug", "ug",
"ug-cuda", "ug-cuda",
@@ -972,7 +973,7 @@ dependencies = [
"half", "half",
"num-traits", "num-traits",
"rayon", "rayon",
"safetensors", "safetensors 0.4.5",
"serde", "serde",
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
@@ -1052,6 +1053,15 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "castaway"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
dependencies = [
"rustversion",
]
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.37" version = "1.2.37"
@@ -1214,7 +1224,7 @@ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
"clap_lex", "clap_lex",
"strsim 0.11.1", "strsim",
] ]
[[package]] [[package]]
@@ -1253,6 +1263,21 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "compact_str"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
dependencies = [
"castaway",
"cfg-if",
"itoa",
"rustversion",
"ryu",
"serde",
"static_assertions",
]
[[package]] [[package]]
name = "concat-arrays" name = "concat-arrays"
version = "0.1.2" version = "0.1.2"
@@ -1511,38 +1536,14 @@ dependencies = [
"libloading", "libloading",
] ]
[[package]]
name = "darling"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
dependencies = [
"darling_core 0.14.4",
"darling_macro 0.14.4",
]
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.20.11" version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [ dependencies = [
"darling_core 0.20.11", "darling_core",
"darling_macro 0.20.11", "darling_macro",
]
[[package]]
name = "darling_core"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.10.0",
"syn 1.0.109",
] ]
[[package]] [[package]]
@@ -1555,28 +1556,17 @@ dependencies = [
"ident_case", "ident_case",
"proc-macro2", "proc-macro2",
"quote", "quote",
"strsim 0.11.1", "strsim",
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "darling_macro"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
dependencies = [
"darling_core 0.14.4",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "darling_macro" name = "darling_macro"
version = "0.20.11" version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [ dependencies = [
"darling_core 0.20.11", "darling_core",
"quote", "quote",
"syn 2.0.106", "syn 2.0.106",
] ]
@@ -1586,6 +1576,9 @@ name = "dary_heap"
version = "0.3.7" version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "deadpool" name = "deadpool"
@@ -1641,34 +1634,13 @@ dependencies = [
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "derive_builder"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
dependencies = [
"derive_builder_macro 0.12.0",
]
[[package]] [[package]]
name = "derive_builder" name = "derive_builder"
version = "0.20.2" version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [ dependencies = [
"derive_builder_macro 0.20.2", "derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
dependencies = [
"darling 0.14.4",
"proc-macro2",
"quote",
"syn 1.0.109",
] ]
[[package]] [[package]]
@@ -1677,29 +1649,19 @@ version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [ dependencies = [
"darling 0.20.11", "darling",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "derive_builder_macro"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
dependencies = [
"derive_builder_core 0.12.0",
"syn 1.0.109",
]
[[package]] [[package]]
name = "derive_builder_macro" name = "derive_builder_macro"
version = "0.20.2" version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [ dependencies = [
"derive_builder_core 0.20.2", "derive_builder_core",
"syn 2.0.106", "syn 2.0.106",
] ]
@@ -1738,7 +1700,7 @@ dependencies = [
"serde-cs", "serde-cs",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"strsim 0.11.1", "strsim",
] ]
[[package]] [[package]]
@@ -3245,7 +3207,7 @@ dependencies = [
"convert_case 0.8.0", "convert_case 0.8.0",
"crossbeam-channel", "crossbeam-channel",
"csv", "csv",
"derive_builder 0.20.2", "derive_builder",
"dump", "dump",
"enum-iterator", "enum-iterator",
"file-store", "file-store",
@@ -3412,15 +3374,6 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.13.0" version = "0.13.0"
@@ -3765,7 +3718,7 @@ dependencies = [
"bincode 2.0.1", "bincode 2.0.1",
"byteorder", "byteorder",
"csv", "csv",
"derive_builder 0.20.2", "derive_builder",
"encoding", "encoding",
"encoding_rs", "encoding_rs",
"encoding_rs_io", "encoding_rs_io",
@@ -4289,6 +4242,7 @@ dependencies = [
"roaring 0.10.12", "roaring 0.10.12",
"rstar", "rstar",
"rustc-hash 2.1.1", "rustc-hash 2.1.1",
"safetensors 0.6.2",
"serde", "serde",
"serde_json", "serde_json",
"slice-group-by", "slice-group-by",
@@ -5399,12 +5353,12 @@ dependencies = [
[[package]] [[package]]
name = "rayon-cond" name = "rayon-cond"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f"
dependencies = [ dependencies = [
"either", "either",
"itertools 0.11.0", "itertools 0.14.0",
"rayon", "rayon",
] ]
@@ -5825,6 +5779,16 @@ dependencies = [
"serde_json", "serde_json",
] ]
[[package]]
name = "safetensors"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@@ -6306,12 +6270,6 @@ dependencies = [
"indexmap", "indexmap",
] ]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.11.1" version = "0.11.1"
@@ -6637,21 +6595,24 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokenizers" name = "tokenizers"
version = "0.15.2" version = "0.22.1"
source = "git+https://github.com/huggingface/tokenizers.git?tag=v0.15.2#701a73b869602b5639589d197e805349cdba3223" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6475a27088c98ea96d00b39a9ddfb63780d1ad4cceb6f48374349a96ab2b7842"
dependencies = [ dependencies = [
"ahash 0.8.12",
"aho-corasick", "aho-corasick",
"derive_builder 0.12.0", "compact_str",
"dary_heap",
"derive_builder",
"esaxx-rs", "esaxx-rs",
"getrandom 0.2.16", "getrandom 0.3.3",
"itertools 0.12.1", "itertools 0.14.0",
"lazy_static",
"log", "log",
"macro_rules_attribute", "macro_rules_attribute",
"monostate", "monostate",
"onig", "onig",
"paste", "paste",
"rand 0.8.5", "rand 0.9.2",
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
@@ -6659,7 +6620,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled", "spm_precompiled",
"thiserror 1.0.69", "thiserror 2.0.16",
"unicode-normalization-alignments", "unicode-normalization-alignments",
"unicode-segmentation", "unicode-segmentation",
"unicode_categories", "unicode_categories",
@@ -7021,7 +6982,7 @@ dependencies = [
"num-traits", "num-traits",
"num_cpus", "num_cpus",
"rayon", "rayon",
"safetensors", "safetensors 0.4.5",
"serde", "serde",
"thiserror 1.0.69", "thiserror 1.0.69",
"tracing", "tracing",
@@ -7251,7 +7212,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b2bf58be11fc9414104c6d3a2e464163db5ef74b12296bda593cac37b6e4777" checksum = "6b2bf58be11fc9414104c6d3a2e464163db5ef74b12296bda593cac37b6e4777"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"derive_builder 0.20.2", "derive_builder",
"rustversion", "rustversion",
"vergen-lib", "vergen-lib",
] ]
@@ -7263,7 +7224,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f6ee511ec45098eabade8a0750e76eec671e7fb2d9360c563911336bea9cac1" checksum = "4f6ee511ec45098eabade8a0750e76eec671e7fb2d9360c563911336bea9cac1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"derive_builder 0.20.2", "derive_builder",
"git2", "git2",
"rustversion", "rustversion",
"time", "time",
@@ -7278,7 +7239,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b07e6010c0f3e59fcb164e0163834597da68d1f864e2b8ca49f74de01e9c166" checksum = "9b07e6010c0f3e59fcb164e0163834597da68d1f864e2b8ca49f74de01e9c166"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"derive_builder 0.20.2", "derive_builder",
"rustversion", "rustversion",
] ]

View File

@@ -74,12 +74,13 @@ csv = "1.3.1"
candle-core = { version = "0.9.1" } candle-core = { version = "0.9.1" }
candle-transformers = { version = "0.9.1" } candle-transformers = { version = "0.9.1" }
candle-nn = { version = "0.9.1" } candle-nn = { version = "0.9.1" }
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.15.2", version = "0.15.2", default-features = false, features = [ tokenizers = { version = "0.22.1", default-features = false, features = [
"onig", "onig",
] } ] }
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default-features = false, features = [ hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default-features = false, features = [
"online", "online",
] } ] }
safetensors = "0.6.2"
tiktoken-rs = "0.7.0" tiktoken-rs = "0.7.0"
liquid = "0.26.11" liquid = "0.26.11"
rhai = { version = "1.22.2", features = [ rhai = { version = "1.22.2", features = [

View File

@@ -1,9 +1,11 @@
use candle_core::Tensor; use candle_core::Tensor;
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE}; use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert};
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
use hf_hub::api::sync::Api; use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType}; use hf_hub::{Repo, RepoType};
use safetensors::SafeTensors;
use tokenizers::{PaddingParams, Tokenizer}; use tokenizers::{PaddingParams, Tokenizer};
use super::EmbeddingCache; use super::EmbeddingCache;
@@ -84,14 +86,21 @@ impl Default for EmbedderOptions {
} }
} }
enum ModelKind {
Bert(BertModel),
Modern(ModernBert),
}
/// Perform embedding of documents and queries /// Perform embedding of documents and queries
pub struct Embedder { pub struct Embedder {
model: BertModel, model: ModelKind,
tokenizer: Tokenizer, tokenizer: Tokenizer,
options: EmbedderOptions, options: EmbedderOptions,
dimensions: usize, dimensions: usize,
pooling: Pooling, pooling: Pooling,
cache: EmbeddingCache, cache: EmbeddingCache,
device: candle_core::Device,
max_len: usize,
} }
impl std::fmt::Debug for Embedder { impl std::fmt::Debug for Embedder {
@@ -101,10 +110,60 @@ impl std::fmt::Debug for Embedder {
.field("tokenizer", &self.tokenizer) .field("tokenizer", &self.tokenizer)
.field("options", &self.options) .field("options", &self.options)
.field("pooling", &self.pooling) .field("pooling", &self.pooling)
.field("device", &self.device)
.field("max_len", &self.max_len)
.finish() .finish()
} }
} }
// some models do not have the "model." prefix in their safetensors weights
fn change_tensor_names(
weights_path: &std::path::Path,
) -> Result<std::path::PathBuf, NewEmbedderError> {
let data = std::fs::read(weights_path)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Io(e)))?;
let tensors = SafeTensors::deserialize(&data)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
let names = tensors.names();
let has_model_prefix = names.iter().any(|n| n.starts_with("model."));
if has_model_prefix {
return Ok(weights_path.to_path_buf());
}
let fixed_path = weights_path.with_extension("fixed.safetensors");
if fixed_path.exists() {
return Ok(fixed_path);
}
let mut new_tensors = vec![];
for name in names {
let tensor_view = tensors.tensor(name).map_err(|e| {
NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string()))
})?;
let new_name = format!("model.{}", name);
let data_offset = tensor_view.data();
let shape = tensor_view.shape();
let dtype = tensor_view.dtype();
new_tensors.push((new_name, shape.to_vec(), dtype, data_offset));
}
use safetensors::tensor::TensorView;
let views = new_tensors.iter().map(|(name, shape, dtype, data)| {
(name.as_str(), TensorView::new(*dtype, shape.clone(), data).unwrap())
});
safetensors::serialize_to_file(views, None, &fixed_path)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
Ok(fixed_path)
}
#[derive(Clone, Copy, serde::Deserialize)] #[derive(Clone, Copy, serde::Deserialize)]
struct PoolingConfig { struct PoolingConfig {
#[serde(default)] #[serde(default)]
@@ -220,19 +279,42 @@ impl Embedder {
(config, tokenizer, weights, source, pooling) (config, tokenizer, weights, source, pooling)
}; };
let config = std::fs::read_to_string(&config_filename) let config_str = std::fs::read_to_string(&config_filename)
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
let config: Config = serde_json::from_str(&config).map_err(|inner| {
NewEmbedderError::deserialize_config( let cfg_val: serde_json::Value = match serde_json::from_str(&config_str) {
options.model.clone(), Ok(v) => v,
config, Err(inner) => {
config_filename, return Err(NewEmbedderError::deserialize_config(
inner, options.model.clone(),
) config_str.clone(),
})?; config_filename.clone(),
inner,
));
}
};
let model_type = cfg_val.get("model_type").and_then(|v| v.as_str()).unwrap_or_default();
let arch_arr = cfg_val.get("architectures").and_then(|v| v.as_array());
let has_arch = |needle: &str| {
model_type.eq_ignore_ascii_case(needle)
|| arch_arr.is_some_and(|arr| {
arr.iter().filter_map(|v| v.as_str()).any(|s| s.to_lowercase().contains(needle))
})
};
let is_modern = has_arch("modernbert");
tracing::debug!(is_modern, model_type, "detected HF architecture");
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
let weights_filename = if is_modern && weight_source == WeightSource::Safetensors {
change_tensor_names(&weights_filename)?
} else {
weights_filename
};
let vb = match weight_source { let vb = match weight_source {
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
.map_err(NewEmbedderError::pytorch_weight)?, .map_err(NewEmbedderError::pytorch_weight)?,
@@ -244,7 +326,31 @@ impl Embedder {
tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config"); tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; // max length from config, fallback to 512
let max_len =
cfg_val.get("max_position_embeddings").and_then(|v| v.as_u64()).unwrap_or(512) as usize;
let model = if is_modern {
let config: ModernConfig = serde_json::from_str(&config_str).map_err(|inner| {
NewEmbedderError::deserialize_config(
options.model.clone(),
config_str.clone(),
config_filename.clone(),
inner,
)
})?;
ModelKind::Modern(ModernBert::load(vb, &config).map_err(NewEmbedderError::load_model)?)
} else {
let config: BertConfig = serde_json::from_str(&config_str).map_err(|inner| {
NewEmbedderError::deserialize_config(
options.model.clone(),
config_str.clone(),
config_filename.clone(),
inner,
)
})?;
ModelKind::Bert(BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?)
};
if let Some(pp) = tokenizer.get_padding_mut() { if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest pp.strategy = tokenizers::PaddingStrategy::BatchLongest
@@ -263,6 +369,8 @@ impl Embedder {
dimensions: 0, dimensions: 0,
pooling, pooling,
cache: EmbeddingCache::new(cache_cap), cache: EmbeddingCache::new(cache_cap),
device,
max_len,
}; };
let embeddings = this let embeddings = this
@@ -321,15 +429,29 @@ impl Embedder {
pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> { pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?; let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
let token_ids = tokens.get_ids(); let token_ids = tokens.get_ids();
let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
let token_ids = let token_ids =
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?; if token_ids.len() > self.max_len { &token_ids[..self.max_len] } else { token_ids };
let token_ids = Tensor::new(token_ids, &self.device).map_err(EmbedError::tensor_shape)?;
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?; let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings = self let embeddings = match &self.model {
.model ModelKind::Bert(model) => {
.forward(&token_ids, &token_type_ids, None) let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
.map_err(EmbedError::model_forward)?; model
.forward(&token_ids, &token_type_ids, None)
.map_err(EmbedError::model_forward)?
}
ModelKind::Modern(model) => {
let mut mask_vec = tokens.get_attention_mask().to_vec();
if mask_vec.len() > self.max_len {
mask_vec.truncate(self.max_len);
}
let mask = Tensor::new(mask_vec.as_slice(), &self.device)
.map_err(EmbedError::tensor_shape)?;
let mask = Tensor::stack(&[mask], 0).map_err(EmbedError::tensor_shape)?;
model.forward(&token_ids, &mask).map_err(EmbedError::model_forward)?
}
};
let embedding = Self::pooling(embeddings, self.pooling)?; let embedding = Self::pooling(embeddings, self.pooling)?;

View File

@@ -550,9 +550,9 @@ pub struct DeserializePoolingConfig {
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}", #[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
if architectures.is_empty() { if architectures.is_empty() {
"\n - Note: only models with architecture \"BertModel\" are supported.".to_string() "\n - Note: only models with architecture \"BertModel\" or \"ModernBert\" are supported.".to_string()
} else { } else {
format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.") format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` or `\"ModernBert\"` are supported.")
})] })]
pub struct UnsupportedModel { pub struct UnsupportedModel {
pub model_name: String, pub model_name: String,