This commit is contained in:
Louis Dureuil
2025-11-10 16:56:13 +01:00
parent 9f7172f6ab
commit 33fa564a9c

View File

@@ -5,8 +5,8 @@ use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
use safetensors::SafeTensors;
use tokenizers::{PaddingParams, Tokenizer};
use super::EmbeddingCache;
use crate::vector::error::{EmbedError, NewEmbedderError};
@@ -115,7 +115,9 @@ impl std::fmt::Debug for Embedder {
}
// some models do not have the "model." prefix in their safetensors weights
fn change_tensor_names(weights_path: &std::path::Path) -> Result<std::path::PathBuf, NewEmbedderError> {
fn change_tensor_names(
weights_path: &std::path::Path,
) -> Result<std::path::PathBuf, NewEmbedderError> {
let data = std::fs::read(weights_path)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Io(e)))?;
@@ -137,8 +139,9 @@ fn change_tensor_names(weights_path: &std::path::Path) -> Result<std::path::Path
let mut new_tensors = vec![];
for name in names {
let tensor_view = tensors.tensor(name)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
let tensor_view = tensors.tensor(name).map_err(|e| {
NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string()))
})?;
let new_name = format!("model.{}", name);
let data_offset = tensor_view.data();
@@ -149,9 +152,12 @@ fn change_tensor_names(weights_path: &std::path::Path) -> Result<std::path::Path
}
use safetensors::tensor::TensorView;
let views: Vec<(&str, TensorView)> = new_tensors.iter().map(|(name, shape, dtype, data)| {
(name.as_str(), TensorView::new(*dtype, shape.clone(), *data).unwrap())
}).collect();
let views: Vec<(&str, TensorView)> = new_tensors
.iter()
.map(|(name, shape, dtype, data)| {
(name.as_str(), TensorView::new(*dtype, shape.clone(), *data).unwrap())
})
.collect();
safetensors::serialize_to_file(views, None, &fixed_path)
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;