chage tensor names

This commit is contained in:
Hayato Sakaguchi
2025-11-09 01:06:04 +09:00
parent d6eca83cfa
commit 9f7172f6ab
3 changed files with 125 additions and 112 deletions

181
Cargo.lock generated
View File

@@ -310,6 +310,7 @@ dependencies = [
"const-random", "const-random",
"getrandom 0.3.3", "getrandom 0.3.3",
"once_cell", "once_cell",
"serde",
"version_check", "version_check",
"zerocopy", "zerocopy",
] ]
@@ -492,7 +493,7 @@ dependencies = [
"backoff", "backoff",
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
"derive_builder 0.20.2", "derive_builder",
"eventsource-stream", "eventsource-stream",
"futures", "futures",
"rand 0.8.5", "rand 0.8.5",
@@ -945,7 +946,7 @@ dependencies = [
"rand 0.9.2", "rand 0.9.2",
"rand_distr", "rand_distr",
"rayon", "rayon",
"safetensors", "safetensors 0.4.5",
"thiserror 1.0.69", "thiserror 1.0.69",
"ug", "ug",
"ug-cuda", "ug-cuda",
@@ -972,7 +973,7 @@ dependencies = [
"half", "half",
"num-traits", "num-traits",
"rayon", "rayon",
"safetensors", "safetensors 0.4.5",
"serde", "serde",
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
@@ -1052,6 +1053,15 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "castaway"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
dependencies = [
"rustversion",
]
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.37" version = "1.2.37"
@@ -1215,7 +1225,7 @@ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
"clap_lex", "clap_lex",
"strsim 0.11.1", "strsim",
] ]
[[package]] [[package]]
@@ -1254,6 +1264,21 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "compact_str"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
dependencies = [
"castaway",
"cfg-if",
"itoa",
"rustversion",
"ryu",
"serde",
"static_assertions",
]
[[package]] [[package]]
name = "concat-arrays" name = "concat-arrays"
version = "0.1.2" version = "0.1.2"
@@ -1512,38 +1537,14 @@ dependencies = [
"libloading", "libloading",
] ]
[[package]]
name = "darling"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
dependencies = [
"darling_core 0.14.4",
"darling_macro 0.14.4",
]
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.20.11" version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [ dependencies = [
"darling_core 0.20.11", "darling_core",
"darling_macro 0.20.11", "darling_macro",
]
[[package]]
name = "darling_core"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.10.0",
"syn 1.0.109",
] ]
[[package]] [[package]]
@@ -1556,28 +1557,17 @@ dependencies = [
"ident_case", "ident_case",
"proc-macro2", "proc-macro2",
"quote", "quote",
"strsim 0.11.1", "strsim",
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "darling_macro"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
dependencies = [
"darling_core 0.14.4",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "darling_macro" name = "darling_macro"
version = "0.20.11" version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [ dependencies = [
"darling_core 0.20.11", "darling_core",
"quote", "quote",
"syn 2.0.106", "syn 2.0.106",
] ]
@@ -1587,6 +1577,9 @@ name = "dary_heap"
version = "0.3.7" version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "deadpool" name = "deadpool"
@@ -1642,34 +1635,13 @@ dependencies = [
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "derive_builder"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
dependencies = [
"derive_builder_macro 0.12.0",
]
[[package]] [[package]]
name = "derive_builder" name = "derive_builder"
version = "0.20.2" version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [ dependencies = [
"derive_builder_macro 0.20.2", "derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
dependencies = [
"darling 0.14.4",
"proc-macro2",
"quote",
"syn 1.0.109",
] ]
[[package]] [[package]]
@@ -1678,29 +1650,19 @@ version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [ dependencies = [
"darling 0.20.11", "darling",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "derive_builder_macro"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
dependencies = [
"derive_builder_core 0.12.0",
"syn 1.0.109",
]
[[package]] [[package]]
name = "derive_builder_macro" name = "derive_builder_macro"
version = "0.20.2" version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [ dependencies = [
"derive_builder_core 0.20.2", "derive_builder_core",
"syn 2.0.106", "syn 2.0.106",
] ]
@@ -1739,7 +1701,7 @@ dependencies = [
"serde-cs", "serde-cs",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"strsim 0.11.1", "strsim",
] ]
[[package]] [[package]]
@@ -3246,7 +3208,7 @@ dependencies = [
"convert_case 0.8.0", "convert_case 0.8.0",
"crossbeam-channel", "crossbeam-channel",
"csv", "csv",
"derive_builder 0.20.2", "derive_builder",
"dump", "dump",
"enum-iterator", "enum-iterator",
"file-store", "file-store",
@@ -3413,15 +3375,6 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.13.0" version = "0.13.0"
@@ -3767,7 +3720,7 @@ dependencies = [
"bincode 2.0.1", "bincode 2.0.1",
"byteorder", "byteorder",
"csv", "csv",
"derive_builder 0.20.2", "derive_builder",
"encoding", "encoding",
"encoding_rs", "encoding_rs",
"encoding_rs_io", "encoding_rs_io",
@@ -4291,6 +4244,7 @@ dependencies = [
"roaring 0.10.12", "roaring 0.10.12",
"rstar", "rstar",
"rustc-hash 2.1.1", "rustc-hash 2.1.1",
"safetensors 0.6.2",
"serde", "serde",
"serde_json", "serde_json",
"slice-group-by", "slice-group-by",
@@ -5372,12 +5326,12 @@ dependencies = [
[[package]] [[package]]
name = "rayon-cond" name = "rayon-cond"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f"
dependencies = [ dependencies = [
"either", "either",
"itertools 0.11.0", "itertools 0.14.0",
"rayon", "rayon",
] ]
@@ -5798,6 +5752,16 @@ dependencies = [
"serde_json", "serde_json",
] ]
[[package]]
name = "safetensors"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@@ -6279,12 +6243,6 @@ dependencies = [
"indexmap", "indexmap",
] ]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.11.1" version = "0.11.1"
@@ -6610,21 +6568,24 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokenizers" name = "tokenizers"
version = "0.15.2" version = "0.22.1"
source = "git+https://github.com/huggingface/tokenizers.git?tag=v0.15.2#701a73b869602b5639589d197e805349cdba3223" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6475a27088c98ea96d00b39a9ddfb63780d1ad4cceb6f48374349a96ab2b7842"
dependencies = [ dependencies = [
"ahash 0.8.12",
"aho-corasick", "aho-corasick",
"derive_builder 0.12.0", "compact_str",
"dary_heap",
"derive_builder",
"esaxx-rs", "esaxx-rs",
"getrandom 0.2.16", "getrandom 0.3.3",
"itertools 0.12.1", "itertools 0.14.0",
"lazy_static",
"log", "log",
"macro_rules_attribute", "macro_rules_attribute",
"monostate", "monostate",
"onig", "onig",
"paste", "paste",
"rand 0.8.5", "rand 0.9.2",
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
@@ -6632,7 +6593,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled", "spm_precompiled",
"thiserror 1.0.69", "thiserror 2.0.16",
"unicode-normalization-alignments", "unicode-normalization-alignments",
"unicode-segmentation", "unicode-segmentation",
"unicode_categories", "unicode_categories",
@@ -6994,7 +6955,7 @@ dependencies = [
"num-traits", "num-traits",
"num_cpus", "num_cpus",
"rayon", "rayon",
"safetensors", "safetensors 0.4.5",
"serde", "serde",
"thiserror 1.0.69", "thiserror 1.0.69",
"tracing", "tracing",
@@ -7224,7 +7185,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b2bf58be11fc9414104c6d3a2e464163db5ef74b12296bda593cac37b6e4777" checksum = "6b2bf58be11fc9414104c6d3a2e464163db5ef74b12296bda593cac37b6e4777"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"derive_builder 0.20.2", "derive_builder",
"rustversion", "rustversion",
"vergen-lib", "vergen-lib",
] ]
@@ -7236,7 +7197,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f6ee511ec45098eabade8a0750e76eec671e7fb2d9360c563911336bea9cac1" checksum = "4f6ee511ec45098eabade8a0750e76eec671e7fb2d9360c563911336bea9cac1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"derive_builder 0.20.2", "derive_builder",
"git2", "git2",
"rustversion", "rustversion",
"time", "time",
@@ -7251,7 +7212,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b07e6010c0f3e59fcb164e0163834597da68d1f864e2b8ca49f74de01e9c166" checksum = "9b07e6010c0f3e59fcb164e0163834597da68d1f864e2b8ca49f74de01e9c166"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"derive_builder 0.20.2", "derive_builder",
"rustversion", "rustversion",
] ]

View File

@@ -74,12 +74,13 @@ csv = "1.3.1"
candle-core = { version = "0.9.1" } candle-core = { version = "0.9.1" }
candle-transformers = { version = "0.9.1" } candle-transformers = { version = "0.9.1" }
candle-nn = { version = "0.9.1" } candle-nn = { version = "0.9.1" }
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.15.2", version = "0.15.2", default-features = false, features = [ tokenizers = { version = "0.22.1", default-features = false, features = [
"onig", "onig",
] } ] }
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default-features = false, features = [ hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default-features = false, features = [
"online", "online",
] } ] }
safetensors = "0.6.2"
tiktoken-rs = "0.7.0" tiktoken-rs = "0.7.0"
liquid = "0.26.11" liquid = "0.26.11"
rhai = { version = "1.22.2", features = [ rhai = { version = "1.22.2", features = [

View File

@@ -6,6 +6,7 @@ use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert
use hf_hub::api::sync::Api; use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType}; use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer}; use tokenizers::{PaddingParams, Tokenizer};
use safetensors::SafeTensors;
use super::EmbeddingCache; use super::EmbeddingCache;
use crate::vector::error::{EmbedError, NewEmbedderError}; use crate::vector::error::{EmbedError, NewEmbedderError};
@@ -113,6 +114,51 @@ impl std::fmt::Debug for Embedder {
} }
} }
// some models do not have the "model." prefix in their safetensors weights
fn change_tensor_names(weights_path: &std::path::Path) -> Result<std::path::PathBuf, NewEmbedderError> {
let data = std::fs::read(weights_path)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Io(e)))?;
let tensors = SafeTensors::deserialize(&data)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
let names = tensors.names();
let has_model_prefix = names.iter().any(|n| n.starts_with("model."));
if has_model_prefix {
return Ok(weights_path.to_path_buf());
}
let fixed_path = weights_path.with_extension("fixed.safetensors");
if fixed_path.exists() {
return Ok(fixed_path);
}
let mut new_tensors = vec![];
for name in names {
let tensor_view = tensors.tensor(name)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
let new_name = format!("model.{}", name);
let data_offset = tensor_view.data();
let shape = tensor_view.shape();
let dtype = tensor_view.dtype();
new_tensors.push((new_name, shape.to_vec(), dtype, data_offset));
}
use safetensors::tensor::TensorView;
let views: Vec<(&str, TensorView)> = new_tensors.iter().map(|(name, shape, dtype, data)| {
(name.as_str(), TensorView::new(*dtype, shape.clone(), *data).unwrap())
}).collect();
safetensors::serialize_to_file(views, None, &fixed_path)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
Ok(fixed_path)
}
#[derive(Clone, Copy, serde::Deserialize)] #[derive(Clone, Copy, serde::Deserialize)]
struct PoolingConfig { struct PoolingConfig {
#[serde(default)] #[serde(default)]
@@ -254,11 +300,16 @@ impl Embedder {
let is_modern = has_arch("modernbert"); let is_modern = has_arch("modernbert");
tracing::debug!(is_modern, model_type, "detected HF architecture"); tracing::debug!(is_modern, model_type, "detected HF architecture");
// default to BERT otherwise
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
let weights_filename = if is_modern && weight_source == WeightSource::Safetensors {
change_tensor_names(&weights_filename)?
} else {
weights_filename
};
let vb = match weight_source { let vb = match weight_source {
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
.map_err(NewEmbedderError::pytorch_weight)?, .map_err(NewEmbedderError::pytorch_weight)?,