mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 13:06:27 +00:00 
			
		
		
		
	Introduce the ThreadPoolNoAbort wrapper
This commit is contained in:
		| @@ -6,7 +6,6 @@ use std::num::ParseIntError; | ||||
| use std::ops::Deref; | ||||
| use std::path::PathBuf; | ||||
| use std::str::FromStr; | ||||
| use std::sync::atomic::{AtomicBool, Ordering}; | ||||
| use std::sync::Arc; | ||||
| use std::{env, fmt, fs}; | ||||
|  | ||||
| @@ -14,6 +13,7 @@ use byte_unit::{Byte, ByteError}; | ||||
| use clap::Parser; | ||||
| use meilisearch_types::features::InstanceTogglableFeatures; | ||||
| use meilisearch_types::milli::update::IndexerConfig; | ||||
| use meilisearch_types::milli::ThreadPoolNoAbortBuilder; | ||||
| use rustls::server::{ | ||||
|     AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ServerSessionMemoryCache, | ||||
| }; | ||||
| @@ -667,23 +667,15 @@ impl TryFrom<&IndexerOpts> for IndexerConfig { | ||||
|     type Error = anyhow::Error; | ||||
|  | ||||
|     fn try_from(other: &IndexerOpts) -> Result<Self, Self::Error> { | ||||
|         let pool_panic_catched = Arc::new(AtomicBool::new(false)); | ||||
|         let thread_pool = rayon::ThreadPoolBuilder::new() | ||||
|         let thread_pool = ThreadPoolNoAbortBuilder::new() | ||||
|             .thread_name(|index| format!("indexing-thread:{index}")) | ||||
|             .num_threads(*other.max_indexing_threads) | ||||
|             .panic_handler({ | ||||
|                 // TODO What should we do with this Box<dyn Any + Send>. | ||||
|                 //      So, let's just set a value to true to cancel the task with a message for now. | ||||
|                 let panic_cathed = pool_panic_catched.clone(); | ||||
|                 move |_result| panic_cathed.store(true, Ordering::SeqCst) | ||||
|             }) | ||||
|             .build()?; | ||||
|  | ||||
|         Ok(Self { | ||||
|             log_every_n: Some(DEFAULT_LOG_EVERY_N), | ||||
|             max_memory: other.max_indexing_memory.map(|b| b.get_bytes() as usize), | ||||
|             thread_pool: Some(thread_pool), | ||||
|             pool_panic_catched, | ||||
|             max_positions_per_attributes: None, | ||||
|             skip_index_budget: other.skip_index_budget, | ||||
|             ..Default::default() | ||||
|   | ||||
| @@ -9,6 +9,7 @@ use serde_json::Value; | ||||
| use thiserror::Error; | ||||
|  | ||||
| use crate::documents::{self, DocumentsBatchCursorError}; | ||||
| use crate::thread_pool_no_abort::PanicCatched; | ||||
| use crate::{CriterionError, DocumentId, FieldId, Object, SortError}; | ||||
|  | ||||
| pub fn is_reserved_keyword(keyword: &str) -> bool { | ||||
| @@ -49,8 +50,8 @@ pub enum InternalError { | ||||
|     InvalidDatabaseTyping, | ||||
|     #[error(transparent)] | ||||
|     RayonThreadPool(#[from] ThreadPoolBuildError), | ||||
|     #[error("A panic occured. Read the logs to find more information about it")] | ||||
|     PanicInThreadPool, | ||||
|     #[error(transparent)] | ||||
|     PanicInThreadPool(#[from] PanicCatched), | ||||
|     #[error(transparent)] | ||||
|     SerdeJson(#[from] serde_json::Error), | ||||
|     #[error(transparent)] | ||||
|   | ||||
| @@ -21,6 +21,7 @@ pub mod prompt; | ||||
| pub mod proximity; | ||||
| pub mod score_details; | ||||
| mod search; | ||||
| mod thread_pool_no_abort; | ||||
| pub mod update; | ||||
| pub mod vector; | ||||
|  | ||||
| @@ -42,6 +43,7 @@ pub use search::new::{ | ||||
|     SearchLogger, VisualSearchLogger, | ||||
| }; | ||||
| use serde_json::Value; | ||||
| pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; | ||||
| pub use {charabia as tokenizer, heed}; | ||||
|  | ||||
| pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError}; | ||||
|   | ||||
							
								
								
									
										69
									
								
								milli/src/thread_pool_no_abort.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								milli/src/thread_pool_no_abort.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| use std::sync::atomic::{AtomicBool, Ordering}; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use rayon::{ThreadPool, ThreadPoolBuilder}; | ||||
| use thiserror::Error; | ||||
|  | ||||
| /// A rayon [`ThreadPool`] wrapper that records panics caught in the pool | ||||
| /// and reports them through the `Result` returned by [`ThreadPoolNoAbort::install`]. | ||||
| /// | ||||
| /// Values of this type are produced by [`ThreadPoolNoAbortBuilder::build`]. | ||||
| #[derive(Debug)] | ||||
| pub struct ThreadPoolNoAbort { | ||||
|     /// The wrapped rayon pool on which operations are run. | ||||
|     thread_pool: ThreadPool, | ||||
|     /// Set to true if the thread pool caught a panic. | ||||
|     /// Written by the pool's panic handler and reset to false by `install`. | ||||
|     pool_catched_panic: Arc<AtomicBool>, | ||||
| } | ||||
|  | ||||
| impl ThreadPoolNoAbort { | ||||
|     /// Runs `op` on the wrapped thread pool and returns its output. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Returns [`PanicCatched`] if the pool's panic flag is set once `op` has | ||||
|     /// returned. The flag is cleared by this check and the output of `op` is dropped. | ||||
|     // NOTE(review): the flag is stored per pool, not per call, so overlapping `install` | ||||
|     // calls on the same pool may see each other's panics -- confirm this is acceptable. | ||||
|     pub fn install<OP, R>(&self, op: OP) -> Result<R, PanicCatched> | ||||
|     where | ||||
|         OP: FnOnce() -> R + Send, | ||||
|         R: Send, | ||||
|     { | ||||
|         let output = self.thread_pool.install(op); | ||||
|         // While resetting the pool panic catcher we return an error if we caught one. | ||||
|         // `swap` reads the flag and sets it back to false in one atomic operation. | ||||
|         if self.pool_catched_panic.swap(false, Ordering::SeqCst) { | ||||
|             Err(PanicCatched) | ||||
|         } else { | ||||
|             Ok(output) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Returns the thread count reported by the wrapped pool's `current_num_threads`. | ||||
|     pub fn current_num_threads(&self) -> usize { | ||||
|         self.thread_pool.current_num_threads() | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Error returned by [`ThreadPoolNoAbort::install`] when the pool's panic flag was set. | ||||
| /// | ||||
| /// It is a unit struct and carries no data about the panic itself; | ||||
| /// its message directs the reader to the logs. | ||||
| #[derive(Error, Debug)] | ||||
| #[error("A panic occured. Read the logs to find more information about it")] | ||||
| pub struct PanicCatched; | ||||
|  | ||||
| /// Builder for [`ThreadPoolNoAbort`], implemented as a newtype around rayon's | ||||
| /// [`ThreadPoolBuilder`]. | ||||
| /// | ||||
| /// Only the options forwarded by the methods of this type can be configured. | ||||
| #[derive(Default)] | ||||
| pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder); | ||||
|  | ||||
| impl ThreadPoolNoAbortBuilder { | ||||
|     /// Creates a builder wrapping a default [`ThreadPoolBuilder`]. | ||||
|     pub fn new() -> ThreadPoolNoAbortBuilder { | ||||
|         ThreadPoolNoAbortBuilder::default() | ||||
|     } | ||||
|  | ||||
|     /// Forwards `closure` to the wrapped builder's `thread_name`. | ||||
|     /// | ||||
|     /// The closure receives a `usize` and returns the name to use for a thread. | ||||
|     // NOTE(review): presumably the `usize` is the thread's index in the pool -- see | ||||
|     // rayon's `ThreadPoolBuilder::thread_name` documentation to confirm. | ||||
|     pub fn thread_name<F>(mut self, closure: F) -> Self | ||||
|     where | ||||
|         F: FnMut(usize) -> String + 'static, | ||||
|     { | ||||
|         self.0 = self.0.thread_name(closure); | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Forwards `num_threads` to the wrapped builder's `num_threads`. | ||||
|     pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder { | ||||
|         self.0 = self.0.num_threads(num_threads); | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Builds the pool and wires its panic handler to a fresh panic flag. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Returns the [`rayon::ThreadPoolBuildError`] produced by the wrapped builder's | ||||
|     /// `build` when it fails. | ||||
|     pub fn build(mut self) -> Result<ThreadPoolNoAbort, rayon::ThreadPoolBuildError> { | ||||
|         // The flag starts at false and is shared between the handler and the wrapper. | ||||
|         let pool_catched_panic = Arc::new(AtomicBool::new(false)); | ||||
|         self.0 = self.0.panic_handler({ | ||||
|             let catched_panic = pool_catched_panic.clone(); | ||||
|             // The panic payload (`_result`) is ignored; only the flag is raised. | ||||
|             move |_result| catched_panic.store(true, Ordering::SeqCst) | ||||
|         }); | ||||
|         Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic }) | ||||
|     } | ||||
| } | ||||
| @@ -19,7 +19,7 @@ use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; | ||||
| use crate::update::index_documents::helpers::try_split_at; | ||||
| use crate::update::settings::InnerIndexSettingsDiff; | ||||
| use crate::vector::Embedder; | ||||
| use crate::{DocumentId, InternalError, Result, VectorOrArrayOfVectors}; | ||||
| use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, VectorOrArrayOfVectors}; | ||||
|  | ||||
| /// The length of the elements that are always in the buffer when inserting new values. | ||||
| const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); | ||||
| @@ -347,7 +347,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>( | ||||
|     prompt_reader: grenad::Reader<R>, | ||||
|     indexer: GrenadParameters, | ||||
|     embedder: Arc<Embedder>, | ||||
|     request_threads: &rayon::ThreadPool, | ||||
|     request_threads: &ThreadPoolNoAbort, | ||||
| ) -> Result<grenad::Reader<BufReader<File>>> { | ||||
|     puffin::profile_function!(); | ||||
|     let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism | ||||
|   | ||||
| @@ -31,7 +31,7 @@ use self::extract_word_position_docids::extract_word_position_docids; | ||||
| use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters}; | ||||
| use super::{helpers, TypedChunk}; | ||||
| use crate::update::settings::InnerIndexSettingsDiff; | ||||
| use crate::{FieldId, Result}; | ||||
| use crate::{FieldId, Result, ThreadPoolNoAbortBuilder}; | ||||
|  | ||||
| /// Extract data for each databases from obkv documents in parallel. | ||||
| /// Send data in grenad file over provided Sender. | ||||
| @@ -229,7 +229,7 @@ fn send_original_documents_data( | ||||
|     let documents_chunk_cloned = original_documents_chunk.clone(); | ||||
|     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); | ||||
|  | ||||
|     let request_threads = rayon::ThreadPoolBuilder::new() | ||||
|     let request_threads = ThreadPoolNoAbortBuilder::new() | ||||
|         .num_threads(crate::vector::REQUEST_PARALLELISM) | ||||
|         .thread_name(|index| format!("embedding-request-{index}")) | ||||
|         .build()?; | ||||
|   | ||||
| @@ -8,7 +8,6 @@ use std::collections::{HashMap, HashSet}; | ||||
| use std::io::{Read, Seek}; | ||||
| use std::num::NonZeroU32; | ||||
| use std::result::Result as StdResult; | ||||
| use std::sync::atomic::Ordering; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use crossbeam_channel::{Receiver, Sender}; | ||||
| @@ -34,6 +33,7 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters}; | ||||
| pub use self::transform::{Transform, TransformOutput}; | ||||
| use crate::documents::{obkv_to_object, DocumentsBatchReader}; | ||||
| use crate::error::{Error, InternalError, UserError}; | ||||
| use crate::thread_pool_no_abort::ThreadPoolNoAbortBuilder; | ||||
| pub use crate::update::index_documents::helpers::CursorClonableMmap; | ||||
| use crate::update::{ | ||||
|     IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, | ||||
| @@ -297,17 +297,13 @@ where | ||||
|         let settings_diff = Arc::new(settings_diff); | ||||
|  | ||||
|         let backup_pool; | ||||
|         let pool_catched_panic = self.indexer_config.pool_panic_catched.clone(); | ||||
|         let pool = match self.indexer_config.thread_pool { | ||||
|             Some(ref pool) => pool, | ||||
|             None => { | ||||
|                 // We initialize a backup pool with the default | ||||
|                 // settings if none have already been set. | ||||
|                 let mut pool_builder = rayon::ThreadPoolBuilder::new(); | ||||
|                 pool_builder = pool_builder.panic_handler({ | ||||
|                     let catched_panic = pool_catched_panic.clone(); | ||||
|                     move |_result| catched_panic.store(true, Ordering::SeqCst) | ||||
|                 }); | ||||
|                 #[allow(unused_mut)] | ||||
|                 let mut pool_builder = ThreadPoolNoAbortBuilder::new(); | ||||
|  | ||||
|                 #[cfg(test)] | ||||
|                 { | ||||
| @@ -538,12 +534,7 @@ where | ||||
|             } | ||||
|  | ||||
|             Ok(()) | ||||
|         })?; | ||||
|  | ||||
|         // While reseting the pool panic catcher we return an error if we catched one. | ||||
|         if pool_catched_panic.swap(false, Ordering::SeqCst) { | ||||
|             return Err(InternalError::PanicInThreadPool.into()); | ||||
|         } | ||||
|         }).map_err(InternalError::from)??; | ||||
|  | ||||
|         // We write the field distribution into the main database | ||||
|         self.index.put_field_distribution(self.wtxn, &field_distribution)?; | ||||
| @@ -572,12 +563,8 @@ where | ||||
|                     writer.build(wtxn, &mut rng, None)?; | ||||
|                 } | ||||
|                 Result::Ok(()) | ||||
|             })?; | ||||
|  | ||||
|             // While reseting the pool panic catcher we return an error if we catched one. | ||||
|             if pool_catched_panic.swap(false, Ordering::SeqCst) { | ||||
|                 return Err(InternalError::PanicInThreadPool.into()); | ||||
|             } | ||||
|             }) | ||||
|             .map_err(InternalError::from)??; | ||||
|         } | ||||
|  | ||||
|         self.execute_prefix_databases( | ||||
|   | ||||
| @@ -1,8 +1,6 @@ | ||||
| use std::sync::atomic::AtomicBool; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use grenad::CompressionType; | ||||
| use rayon::ThreadPool; | ||||
|  | ||||
| use crate::thread_pool_no_abort::ThreadPoolNoAbort; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct IndexerConfig { | ||||
| @@ -12,10 +10,7 @@ pub struct IndexerConfig { | ||||
|     pub max_memory: Option<usize>, | ||||
|     pub chunk_compression_type: CompressionType, | ||||
|     pub chunk_compression_level: Option<u32>, | ||||
|     pub thread_pool: Option<ThreadPool>, | ||||
|     /// Set to true if the thread pool catched a panic | ||||
|     /// and we must abort the task | ||||
|     pub pool_panic_catched: Arc<AtomicBool>, | ||||
|     pub thread_pool: Option<ThreadPoolNoAbort>, | ||||
|     pub max_positions_per_attributes: Option<u32>, | ||||
|     pub skip_index_budget: bool, | ||||
| } | ||||
| @@ -30,7 +25,6 @@ impl Default for IndexerConfig { | ||||
|             chunk_compression_type: CompressionType::None, | ||||
|             chunk_compression_level: None, | ||||
|             thread_pool: None, | ||||
|             pool_panic_catched: Arc::default(), | ||||
|             max_positions_per_attributes: None, | ||||
|             skip_index_budget: false, | ||||
|         } | ||||
|   | ||||
| @@ -3,6 +3,7 @@ use std::path::PathBuf; | ||||
| use hf_hub::api::sync::ApiError; | ||||
|  | ||||
| use crate::error::FaultSource; | ||||
| use crate::PanicCatched; | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("Error while generating embeddings: {inner}")] | ||||
| @@ -80,6 +81,8 @@ pub enum EmbedErrorKind { | ||||
|     OpenAiUnexpectedDimension(usize, usize), | ||||
|     #[error("no embedding was produced")] | ||||
|     MissingEmbedding, | ||||
|     #[error(transparent)] | ||||
|     PanicInThreadPool(#[from] PanicCatched), | ||||
| } | ||||
|  | ||||
| impl EmbedError { | ||||
|   | ||||
| @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use self::error::{EmbedError, NewEmbedderError}; | ||||
| use crate::prompt::{Prompt, PromptData}; | ||||
| use crate::ThreadPoolNoAbort; | ||||
|  | ||||
| pub mod error; | ||||
| pub mod hf; | ||||
| @@ -254,7 +255,7 @@ impl Embedder { | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|         threads: &ThreadPoolNoAbort, | ||||
|     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         match self { | ||||
|             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), | ||||
|   | ||||
| @@ -3,6 +3,8 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; | ||||
| use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; | ||||
| use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; | ||||
| use super::{DistributionShift, Embeddings}; | ||||
| use crate::error::FaultSource; | ||||
| use crate::ThreadPoolNoAbort; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Embedder { | ||||
| @@ -71,11 +73,16 @@ impl Embedder { | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|         threads: &ThreadPoolNoAbort, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         threads.install(move || { | ||||
|         threads | ||||
|             .install(move || { | ||||
|                 text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() | ||||
|             }) | ||||
|             .map_err(|error| EmbedError { | ||||
|                 kind: EmbedErrorKind::PanicInThreadPool(error), | ||||
|                 fault: FaultSource::Bug, | ||||
|             })? | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|   | ||||
| @@ -4,7 +4,9 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; | ||||
| use super::error::{EmbedError, NewEmbedderError}; | ||||
| use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; | ||||
| use super::{DistributionShift, Embeddings}; | ||||
| use crate::error::FaultSource; | ||||
| use crate::vector::error::EmbedErrorKind; | ||||
| use crate::ThreadPoolNoAbort; | ||||
|  | ||||
| #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] | ||||
| pub struct EmbedderOptions { | ||||
| @@ -241,11 +243,16 @@ impl Embedder { | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|         threads: &ThreadPoolNoAbort, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         threads.install(move || { | ||||
|         threads | ||||
|             .install(move || { | ||||
|                 text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() | ||||
|             }) | ||||
|             .map_err(|error| EmbedError { | ||||
|                 kind: EmbedErrorKind::PanicInThreadPool(error), | ||||
|                 fault: FaultSource::Bug, | ||||
|             })? | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|   | ||||
| @@ -2,9 +2,12 @@ use deserr::Deserr; | ||||
| use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use super::error::EmbedErrorKind; | ||||
| use super::{ | ||||
|     DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, | ||||
| }; | ||||
| use crate::error::FaultSource; | ||||
| use crate::ThreadPoolNoAbort; | ||||
|  | ||||
| // retrying in case of failure | ||||
|  | ||||
| @@ -158,11 +161,16 @@ impl Embedder { | ||||
|     pub fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|         threads: &rayon::ThreadPool, | ||||
|         threads: &ThreadPoolNoAbort, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         threads.install(move || { | ||||
|         threads | ||||
|             .install(move || { | ||||
|                 text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() | ||||
|             }) | ||||
|             .map_err(|error| EmbedError { | ||||
|                 kind: EmbedErrorKind::PanicInThreadPool(error), | ||||
|                 fault: FaultSource::Bug, | ||||
|             })? | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|   | ||||
| @@ -217,9 +217,7 @@ fn add_memory_samples( | ||||
|     memory_counters: &mut Option<MemoryCounterHandles>, | ||||
|     last_memory: &mut MemoryStats, | ||||
| ) -> Option<MemoryStats> { | ||||
|     let Some(stats) = memory else { | ||||
|         return None; | ||||
|     }; | ||||
|     let stats = memory?; | ||||
|  | ||||
|     let memory_counters = | ||||
|         memory_counters.get_or_insert_with(|| MemoryCounterHandles::new(profile, main)); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user