mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-31 07:56:28 +00:00 
			
		
		
		
	Allow overriding pooling method
This commit is contained in:
		| @@ -34,6 +34,30 @@ pub struct EmbedderOptions { | ||||
|     pub model: String, | ||||
|     pub revision: Option<String>, | ||||
|     pub distribution: Option<DistributionShift>, | ||||
|     #[serde(default)] | ||||
|     pub pooling: OverridePooling, | ||||
| } | ||||
|  | ||||
/// User-facing choice of how the pooling method of a Hugging Face embedder is selected.
///
/// The value is stored in `EmbedderOptions::pooling` and consumed by
/// `Pooling::override_with`, which either keeps the pooling method that was
/// determined for the model or replaces it with a forced one.
///
/// With `rename_all = "camelCase"` the variants are (de)serialized as
/// `useModel`, `forceCls` and `forceMean`.
///
/// The derived `Default` is `ForceMean` (see the `#[default]` marker below), which is
/// also what a missing `pooling` field deserializes to because the field is tagged
/// `#[serde(default)]`. The `EmbedderOptions` initializer visible in this diff sets
/// `UseModel` explicitly instead.
/// NOTE(review): presumably the `ForceMean` default keeps options that were stored
/// before this field existed on mean pooling; confirm against the upgrade path.
| #[derive( | ||||
|     Debug, | ||||
|     Clone, | ||||
|     Copy, | ||||
|     Default, | ||||
|     Hash, | ||||
|     PartialEq, | ||||
|     Eq, | ||||
|     serde::Deserialize, | ||||
|     serde::Serialize, | ||||
|     utoipa::ToSchema, | ||||
|     deserr::Deserr, | ||||
| )] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| #[serde(rename_all = "camelCase")] | ||||
| pub enum OverridePooling { | ||||
      /// Do not override: keep whatever pooling method was selected for the model.
|     UseModel, | ||||
      /// Force `Pooling::Cls`, regardless of what was selected for the model.
|     ForceCls, | ||||
      /// Force `Pooling::Mean`, regardless of what was selected for the model.
      /// This is the value produced by `OverridePooling::default()`.
|     #[default] | ||||
|     ForceMean, | ||||
| } | ||||
|  | ||||
| impl EmbedderOptions { | ||||
| @@ -42,6 +66,7 @@ impl EmbedderOptions { | ||||
|             model: "BAAI/bge-base-en-v1.5".to_string(), | ||||
|             revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), | ||||
|             distribution: None, | ||||
|             pooling: OverridePooling::UseModel, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -95,6 +120,15 @@ pub enum Pooling { | ||||
|     MeanSqrtLen, | ||||
|     LastToken, | ||||
| } | ||||
| impl Pooling { | ||||
      /// Applies the user-requested pooling override to `self` in place.
      ///
      /// - `OverridePooling::UseModel`: `self` is left untouched.
      /// - `OverridePooling::ForceCls`: `self` becomes `Pooling::Cls`.
      /// - `OverridePooling::ForceMean`: `self` becomes `Pooling::Mean`.
      ///
      /// In the embedder construction shown in this diff, this is called right after the
      /// pooling method has been read from the pooling configuration file (or defaulted
      /// when there is none), so the override always has the last word.
|     fn override_with(&mut self, pooling: OverridePooling) { | ||||
|         match pooling { | ||||
              // Nothing to do: the value already held by `self` stays in effect.
|             OverridePooling::UseModel => {} | ||||
|             OverridePooling::ForceCls => *self = Pooling::Cls, | ||||
|             OverridePooling::ForceMean => *self = Pooling::Mean, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<PoolingConfig> for Pooling { | ||||
|     fn from(value: PoolingConfig) -> Self { | ||||
| @@ -151,7 +185,7 @@ impl Embedder { | ||||
|                 } | ||||
|                 Err(error) => return Err(NewEmbedderError::api_get(error)), | ||||
|             }; | ||||
|             let pooling: Pooling = match pooling { | ||||
|             let mut pooling: Pooling = match pooling { | ||||
|                 Some(pooling_filename) => { | ||||
|                     let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| { | ||||
|                         NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner) | ||||
| @@ -170,6 +204,8 @@ impl Embedder { | ||||
|                 None => Pooling::default(), | ||||
|             }; | ||||
|  | ||||
|             pooling.override_with(options.pooling); | ||||
|  | ||||
|             (config, tokenizer, weights, source, pooling) | ||||
|         }; | ||||
|  | ||||
|   | ||||
| @@ -6,6 +6,7 @@ use roaring::RoaringBitmap; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use utoipa::ToSchema; | ||||
|  | ||||
| use super::hf::OverridePooling; | ||||
| use super::{ollama, openai, DistributionShift}; | ||||
| use crate::prompt::{default_max_bytes, PromptData}; | ||||
| use crate::update::Setting; | ||||
| @@ -30,6 +31,10 @@ pub struct EmbeddingSettings { | ||||
|     pub revision: Setting<String>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     #[schema(value_type = Option<OverridePooling>)] | ||||
|     pub pooling: Setting<OverridePooling>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     #[schema(value_type = Option<String>)] | ||||
|     pub api_key: Setting<String>, | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
| @@ -164,6 +169,7 @@ impl SettingsDiff { | ||||
|                     mut source, | ||||
|                     mut model, | ||||
|                     mut revision, | ||||
|                     mut pooling, | ||||
|                     mut api_key, | ||||
|                     mut dimensions, | ||||
|                     mut document_template, | ||||
| @@ -180,6 +186,7 @@ impl SettingsDiff { | ||||
|                     source: new_source, | ||||
|                     model: new_model, | ||||
|                     revision: new_revision, | ||||
|                     pooling: new_pooling, | ||||
|                     api_key: new_api_key, | ||||
|                     dimensions: new_dimensions, | ||||
|                     document_template: new_document_template, | ||||
| @@ -210,6 +217,7 @@ impl SettingsDiff { | ||||
|                         &source, | ||||
|                         &mut model, | ||||
|                         &mut revision, | ||||
|                         &mut pooling, | ||||
|                         &mut dimensions, | ||||
|                         &mut url, | ||||
|                         &mut request, | ||||
| @@ -225,6 +233,9 @@ impl SettingsDiff { | ||||
|                 if revision.apply(new_revision) { | ||||
|                     ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); | ||||
|                 } | ||||
|                 if pooling.apply(new_pooling) { | ||||
|                     ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); | ||||
|                 } | ||||
|                 if dimensions.apply(new_dimensions) { | ||||
|                     match source { | ||||
|                         // regenerate on dimensions change in OpenAI since truncation is supported | ||||
| @@ -290,6 +301,7 @@ impl SettingsDiff { | ||||
|                     source, | ||||
|                     model, | ||||
|                     revision, | ||||
|                     pooling, | ||||
|                     api_key, | ||||
|                     dimensions, | ||||
|                     document_template, | ||||
| @@ -338,6 +350,7 @@ fn apply_default_for_source( | ||||
|     source: &Setting<EmbedderSource>, | ||||
|     model: &mut Setting<String>, | ||||
|     revision: &mut Setting<String>, | ||||
|     pooling: &mut Setting<OverridePooling>, | ||||
|     dimensions: &mut Setting<usize>, | ||||
|     url: &mut Setting<String>, | ||||
|     request: &mut Setting<serde_json::Value>, | ||||
| @@ -350,6 +363,7 @@ fn apply_default_for_source( | ||||
|         Setting::Set(EmbedderSource::HuggingFace) => { | ||||
|             *model = Setting::Reset; | ||||
|             *revision = Setting::Reset; | ||||
|             *pooling = Setting::Reset; | ||||
|             *dimensions = Setting::NotSet; | ||||
|             *url = Setting::NotSet; | ||||
|             *request = Setting::NotSet; | ||||
| @@ -359,6 +373,7 @@ fn apply_default_for_source( | ||||
|         Setting::Set(EmbedderSource::Ollama) => { | ||||
|             *model = Setting::Reset; | ||||
|             *revision = Setting::NotSet; | ||||
|             *pooling = Setting::NotSet; | ||||
|             *dimensions = Setting::Reset; | ||||
|             *url = Setting::NotSet; | ||||
|             *request = Setting::NotSet; | ||||
| @@ -368,6 +383,7 @@ fn apply_default_for_source( | ||||
|         Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => { | ||||
|             *model = Setting::Reset; | ||||
|             *revision = Setting::NotSet; | ||||
|             *pooling = Setting::NotSet; | ||||
|             *dimensions = Setting::NotSet; | ||||
|             *url = Setting::Reset; | ||||
|             *request = Setting::NotSet; | ||||
| @@ -377,6 +393,7 @@ fn apply_default_for_source( | ||||
|         Setting::Set(EmbedderSource::Rest) => { | ||||
|             *model = Setting::NotSet; | ||||
|             *revision = Setting::NotSet; | ||||
|             *pooling = Setting::NotSet; | ||||
|             *dimensions = Setting::Reset; | ||||
|             *url = Setting::Reset; | ||||
|             *request = Setting::Reset; | ||||
| @@ -386,6 +403,7 @@ fn apply_default_for_source( | ||||
|         Setting::Set(EmbedderSource::UserProvided) => { | ||||
|             *model = Setting::NotSet; | ||||
|             *revision = Setting::NotSet; | ||||
|             *pooling = Setting::NotSet; | ||||
|             *dimensions = Setting::Reset; | ||||
|             *url = Setting::NotSet; | ||||
|             *request = Setting::NotSet; | ||||
| @@ -419,6 +437,7 @@ impl EmbeddingSettings { | ||||
|     pub const SOURCE: &'static str = "source"; | ||||
|     pub const MODEL: &'static str = "model"; | ||||
|     pub const REVISION: &'static str = "revision"; | ||||
|     pub const POOLING: &'static str = "pooling"; | ||||
|     pub const API_KEY: &'static str = "apiKey"; | ||||
|     pub const DIMENSIONS: &'static str = "dimensions"; | ||||
|     pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; | ||||
| @@ -446,6 +465,7 @@ impl EmbeddingSettings { | ||||
|                 &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] | ||||
|             } | ||||
|             Self::REVISION => &[EmbedderSource::HuggingFace], | ||||
|             Self::POOLING => &[EmbedderSource::HuggingFace], | ||||
|             Self::API_KEY => { | ||||
|                 &[EmbedderSource::OpenAi, EmbedderSource::Ollama, EmbedderSource::Rest] | ||||
|             } | ||||
| @@ -500,6 +520,7 @@ impl EmbeddingSettings { | ||||
|                 Self::SOURCE, | ||||
|                 Self::MODEL, | ||||
|                 Self::REVISION, | ||||
|                 Self::POOLING, | ||||
|                 Self::DOCUMENT_TEMPLATE, | ||||
|                 Self::DOCUMENT_TEMPLATE_MAX_BYTES, | ||||
|                 Self::DISTRIBUTION, | ||||
| @@ -592,10 +613,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 model, | ||||
|                 revision, | ||||
|                 distribution, | ||||
|                 pooling, | ||||
|             }) => Self { | ||||
|                 source: Setting::Set(EmbedderSource::HuggingFace), | ||||
|                 model: Setting::Set(model), | ||||
|                 revision: Setting::some_or_not_set(revision), | ||||
|                 pooling: Setting::Set(pooling), | ||||
|                 api_key: Setting::NotSet, | ||||
|                 dimensions: Setting::NotSet, | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
| @@ -617,6 +640,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 source: Setting::Set(EmbedderSource::OpenAi), | ||||
|                 model: Setting::Set(embedding_model.name().to_owned()), | ||||
|                 revision: Setting::NotSet, | ||||
|                 pooling: Setting::NotSet, | ||||
|                 api_key: Setting::some_or_not_set(api_key), | ||||
|                 dimensions: Setting::some_or_not_set(dimensions), | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
| @@ -638,6 +662,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 source: Setting::Set(EmbedderSource::Ollama), | ||||
|                 model: Setting::Set(embedding_model), | ||||
|                 revision: Setting::NotSet, | ||||
|                 pooling: Setting::NotSet, | ||||
|                 api_key: Setting::some_or_not_set(api_key), | ||||
|                 dimensions: Setting::some_or_not_set(dimensions), | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
| @@ -656,6 +681,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 source: Setting::Set(EmbedderSource::UserProvided), | ||||
|                 model: Setting::NotSet, | ||||
|                 revision: Setting::NotSet, | ||||
|                 pooling: Setting::NotSet, | ||||
|                 api_key: Setting::NotSet, | ||||
|                 dimensions: Setting::Set(dimensions), | ||||
|                 document_template: Setting::NotSet, | ||||
| @@ -679,6 +705,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|                 source: Setting::Set(EmbedderSource::Rest), | ||||
|                 model: Setting::NotSet, | ||||
|                 revision: Setting::NotSet, | ||||
|                 pooling: Setting::NotSet, | ||||
|                 api_key: Setting::some_or_not_set(api_key), | ||||
|                 dimensions: Setting::some_or_not_set(dimensions), | ||||
|                 document_template: Setting::Set(prompt.template), | ||||
| @@ -701,6 +728,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig { | ||||
|             source, | ||||
|             model, | ||||
|             revision, | ||||
|             pooling, | ||||
|             api_key, | ||||
|             dimensions, | ||||
|             document_template, | ||||
| @@ -764,6 +792,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig { | ||||
|                     if let Some(revision) = revision.set() { | ||||
|                         options.revision = Some(revision); | ||||
|                     } | ||||
|                     if let Some(pooling) = pooling.set() { | ||||
|                         options.pooling = pooling; | ||||
|                     } | ||||
|                     options.distribution = distribution.set(); | ||||
|                     this.embedder_options = super::EmbedderOptions::HuggingFace(options); | ||||
|                 } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user