mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-11-22 04:36:32 +00:00
rustfmt
This commit is contained in:
@@ -5,8 +5,8 @@ use candle_transformers::models::modernbert::{Config as ModernConfig, ModernBert
|
||||
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
use safetensors::SafeTensors;
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
use super::EmbeddingCache;
|
||||
use crate::vector::error::{EmbedError, NewEmbedderError};
|
||||
@@ -115,7 +115,9 @@ impl std::fmt::Debug for Embedder {
|
||||
}
|
||||
|
||||
// some models do not have the "model." prefix in their safetensors weights
|
||||
fn change_tensor_names(weights_path: &std::path::Path) -> Result<std::path::PathBuf, NewEmbedderError> {
|
||||
fn change_tensor_names(
|
||||
weights_path: &std::path::Path,
|
||||
) -> Result<std::path::PathBuf, NewEmbedderError> {
|
||||
let data = std::fs::read(weights_path)
|
||||
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Io(e)))?;
|
||||
|
||||
@@ -137,8 +139,9 @@ fn change_tensor_names(weights_path: &std::path::Path) -> Result<std::path::Path
|
||||
|
||||
let mut new_tensors = vec![];
|
||||
for name in names {
|
||||
let tensor_view = tensors.tensor(name)
|
||||
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
|
||||
let tensor_view = tensors.tensor(name).map_err(|e| {
|
||||
NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string()))
|
||||
})?;
|
||||
|
||||
let new_name = format!("model.{}", name);
|
||||
let data_offset = tensor_view.data();
|
||||
@@ -149,9 +152,12 @@ fn change_tensor_names(weights_path: &std::path::Path) -> Result<std::path::Path
|
||||
}
|
||||
|
||||
use safetensors::tensor::TensorView;
|
||||
let views: Vec<(&str, TensorView)> = new_tensors.iter().map(|(name, shape, dtype, data)| {
|
||||
(name.as_str(), TensorView::new(*dtype, shape.clone(), *data).unwrap())
|
||||
}).collect();
|
||||
let views: Vec<(&str, TensorView)> = new_tensors
|
||||
.iter()
|
||||
.map(|(name, shape, dtype, data)| {
|
||||
(name.as_str(), TensorView::new(*dtype, shape.clone(), *data).unwrap())
|
||||
})
|
||||
.collect();
|
||||
|
||||
safetensors::serialize_to_file(views, None, &fixed_path)
|
||||
.map_err(|e| NewEmbedderError::safetensor_weight(candle_core::Error::Msg(e.to_string())))?;
|
||||
|
||||
Reference in New Issue
Block a user