Compare commits

...

4 Commits

Author SHA1 Message Date
680868ef77 Use weak ptr in panic handler 2024-06-19 12:22:46 +02:00
c668043c4f Merge #4617
4617: Destructure `EmbedderOptions` so we don't miss some options r=dureuill a=dureuill

# Pull Request

## Related issue
#4595 was caused by the code not destructuring the embedder options.


## What does this PR do?
This PR adds the missing `url` parameter for ollama, and makes sure similar issue cannot happen in the future



Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2024-05-02 14:55:32 +00:00
5a305bfdea Remove unused struct 2024-05-02 16:14:37 +02:00
f4dd73ec8c Destructure EmbedderOptions so we don't miss some options 2024-05-02 15:39:36 +02:00
3 changed files with 56 additions and 26 deletions

View File

@ -367,12 +367,6 @@ async fn get_version(
})
}
#[derive(Serialize)]
struct KeysResponse {
private: Option<String>,
public: Option<String>,
}
pub async fn get_health(
index_scheduler: Data<IndexScheduler>,
auth_controller: Data<AuthController>,

View File

@ -34,7 +34,7 @@ impl ThreadPoolNoAbort {
}
#[derive(Error, Debug)]
#[error("A panic occured. Read the logs to find more information about it")]
#[error("A panic occurred. Read the logs to find more information about it")]
pub struct PanicCatched;
#[derive(Default)]
@ -61,9 +61,28 @@ impl ThreadPoolNoAbortBuilder {
pub fn build(mut self) -> Result<ThreadPoolNoAbort, rayon::ThreadPoolBuildError> {
let pool_catched_panic = Arc::new(AtomicBool::new(false));
self.0 = self.0.panic_handler({
let catched_panic = pool_catched_panic.clone();
move |_result| catched_panic.store(true, Ordering::SeqCst)
let catched_panic = Arc::downgrade(&pool_catched_panic);
move |_result| {
if let Some(catched_panic) = catched_panic.upgrade() {
catched_panic.store(true, Ordering::SeqCst)
}
}
});
Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic })
}
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use crate::ThreadPoolNoAbortBuilder;
#[test]
fn drop_pool() {
let pool = ThreadPoolNoAbortBuilder::new().num_threads(10).build().unwrap();
let caught_panic = Arc::downgrade(&pool.pool_catched_panic);
drop(pool);
assert_eq!(caught_panic.strong_count(), 0);
}
}

View File

@ -301,10 +301,14 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self {
let EmbeddingConfig { embedder_options, prompt } = value;
match embedder_options {
super::EmbedderOptions::HuggingFace(options) => Self {
super::EmbedderOptions::HuggingFace(super::hf::EmbedderOptions {
model,
revision,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::HuggingFace),
model: Setting::Set(options.model),
revision: options.revision.map(Setting::Set).unwrap_or_default(),
model: Setting::Set(model),
revision: revision.map(Setting::Set).unwrap_or_default(),
api_key: Setting::NotSet,
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
@ -314,14 +318,19 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::OpenAi(options) => Self {
super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions {
api_key,
embedding_model,
dimensions,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::OpenAi),
model: Setting::Set(options.embedding_model.name().to_owned()),
model: Setting::Set(embedding_model.name().to_owned()),
revision: Setting::NotSet,
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
api_key: api_key.map(Setting::Set).unwrap_or_default(),
dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
query: Setting::NotSet,
@ -329,29 +338,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::Ollama(options) => Self {
super::EmbedderOptions::Ollama(super::ollama::EmbedderOptions {
embedding_model,
url,
api_key,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::Ollama),
model: Setting::Set(options.embedding_model.to_owned()),
model: Setting::Set(embedding_model),
revision: Setting::NotSet,
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
api_key: api_key.map(Setting::Set).unwrap_or_default(),
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
url: url.map(Setting::Set).unwrap_or_default(),
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::UserProvided(options) => Self {
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::UserProvided),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::Set(options.dimensions),
dimensions: Setting::Set(dimensions),
document_template: Setting::NotSet,
url: Setting::NotSet,
query: Setting::NotSet,
@ -359,7 +376,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key,