diff --git a/Cargo.lock b/Cargo.lock index 4e897e580..8a3942d5b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3788,6 +3788,7 @@ dependencies = [ "rustls", "rustls-pemfile", "rustls-pki-types", + "secrecy", "segment", "serde", "serde_json", diff --git a/crates/meilisearch-types/src/error.rs b/crates/meilisearch-types/src/error.rs index 11bad977d..27f7580e2 100644 --- a/crates/meilisearch-types/src/error.rs +++ b/crates/meilisearch-types/src/error.rs @@ -391,6 +391,10 @@ EditDocumentsByFunctionError , InvalidRequest , BAD_REQU InvalidSettingsIndexChat , InvalidRequest , BAD_REQUEST ; // Experimental features - Chat Completions ChatWorkspaceNotFound , InvalidRequest , NOT_FOUND ; +InvalidChatCompletionOrgId , InvalidRequest , BAD_REQUEST ; +InvalidChatCompletionProjectId , InvalidRequest , BAD_REQUEST ; +InvalidChatCompletionApiVersion , InvalidRequest , BAD_REQUEST ; +InvalidChatCompletionDeploymentId , InvalidRequest , BAD_REQUEST ; InvalidChatCompletionSource , InvalidRequest , BAD_REQUEST ; InvalidChatCompletionBaseApi , InvalidRequest , BAD_REQUEST ; InvalidChatCompletionApiKey , InvalidRequest , BAD_REQUEST ; diff --git a/crates/meilisearch-types/src/features.rs b/crates/meilisearch-types/src/features.rs index 95706fb46..9bcd58347 100644 --- a/crates/meilisearch-types/src/features.rs +++ b/crates/meilisearch-types/src/features.rs @@ -51,6 +51,14 @@ pub struct Network { pub struct ChatCompletionSettings { pub source: ChatCompletionSource, #[serde(default)] + pub org_id: Option, + #[serde(default)] + pub project_id: Option, + #[serde(default)] + pub api_version: Option, + #[serde(default)] + pub deployment_id: Option, + #[serde(default)] pub base_api: Option, #[serde(default)] pub api_key: Option, @@ -88,6 +96,43 @@ impl ChatCompletionSettings { pub enum ChatCompletionSource { #[default] OpenAi, + AzureOpenAi, + Mistral, + Gemini, + VLlm, +} + +impl ChatCompletionSource { + pub fn system_role(&self, model: &str) -> &'static str { + match self { + ChatCompletionSource::OpenAi if Self::old_openai_model(model) => "system", + ChatCompletionSource::OpenAi => "developer", + ChatCompletionSource::AzureOpenAi if Self::old_openai_model(model) => "system", + ChatCompletionSource::AzureOpenAi => "developer", + ChatCompletionSource::Mistral => "system", + ChatCompletionSource::Gemini => "system", + ChatCompletionSource::VLlm => "system", + } + } + + /// Returns true if the model is an old OpenAI model. + /// + /// Old OpenAI models use the system role while new ones use the developer role. + fn old_openai_model(model: &str) -> bool { + ["gpt-3.5", "gpt-4", "gpt-4.1", "gpt-4.5", "gpt-4o", "chatgpt-4o"].iter().any(|old| { + model.starts_with(old) && model.chars().nth(old.len()).is_none_or(|last| last == '-') + }) + } + + pub fn base_url(&self) -> Option<&'static str> { + use ChatCompletionSource::*; + match self { + OpenAi => Some("https://api.openai.com/v1/"), + Mistral => Some("https://api.mistral.ai/v1/"), + Gemini => Some("https://generativelanguage.googleapis.com/v1beta/openai/"), + AzureOpenAi | VLlm => None, + } + } } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] @@ -111,3 +156,85 @@ impl Default for ChatCompletionPrompts { } } } + +#[cfg(test)] +mod tests { + use super::*; + + const ALL_OPENAI_MODELS_OLDINESS: &[(&str, bool)] = &[ + ("gpt-4-0613", true), + ("gpt-4", true), + ("gpt-3.5-turbo", true), + ("gpt-4o-audio-preview-2025-06-03", true), + ("gpt-4.1-nano", true), + ("gpt-4o-realtime-preview-2025-06-03", true), + ("gpt-3.5-turbo-instruct", true), + ("gpt-3.5-turbo-instruct-0914", true), + ("gpt-4-1106-preview", true), + ("gpt-3.5-turbo-1106", true), + ("gpt-4-0125-preview", true), + ("gpt-4-turbo-preview", true), + ("gpt-3.5-turbo-0125", true), + ("gpt-4-turbo", true), + ("gpt-4-turbo-2024-04-09", true), + ("gpt-4o", true), + ("gpt-4o-2024-05-13", true), + ("gpt-4o-mini-2024-07-18", true), + ("gpt-4o-mini", true), + ("gpt-4o-2024-08-06", true), + ("chatgpt-4o-latest", true), + ("gpt-4o-realtime-preview-2024-10-01", true), + ("gpt-4o-audio-preview-2024-10-01", true), + ("gpt-4o-audio-preview", true), + ("gpt-4o-realtime-preview", true), + ("gpt-4o-realtime-preview-2024-12-17", true), + ("gpt-4o-audio-preview-2024-12-17", true), + ("gpt-4o-mini-realtime-preview-2024-12-17", true), + ("gpt-4o-mini-audio-preview-2024-12-17", true), + ("gpt-4o-mini-realtime-preview", true), + ("gpt-4o-mini-audio-preview", true), + ("gpt-4o-2024-11-20", true), + ("gpt-4.5-preview", true), + ("gpt-4.5-preview-2025-02-27", true), + ("gpt-4o-search-preview-2025-03-11", true), + ("gpt-4o-search-preview", true), + ("gpt-4o-mini-search-preview-2025-03-11", true), + ("gpt-4o-mini-search-preview", true), + ("gpt-4o-transcribe", true), + ("gpt-4o-mini-transcribe", true), + ("gpt-4o-mini-tts", true), + ("gpt-4.1-2025-04-14", true), + ("gpt-4.1", true), + ("gpt-4.1-mini-2025-04-14", true), + ("gpt-4.1-mini", true), + ("gpt-4.1-nano-2025-04-14", true), + ("gpt-3.5-turbo-16k", true), + // + // new models + ("o1-preview-2024-09-12", false), + ("o1-preview", false), + ("o1-mini-2024-09-12", false), + ("o1-mini", false), + ("o1-2024-12-17", false), + ("o1", false), + ("o3-mini", false), + ("o3-mini-2025-01-31", false), + ("o1-pro-2025-03-19", false), + ("o1-pro", false), + ("o3-2025-04-16", false), + ("o4-mini-2025-04-16", false), + ("o3", false), + ("o4-mini", false), + ]; + + #[test] + fn old_openai_models() { + for (name, is_old) in ALL_OPENAI_MODELS_OLDINESS.iter().copied() { + assert_eq!( + ChatCompletionSource::old_openai_model(name), + is_old, + "Model {name} is not considered old" + ); + } + } +} diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index deea9f803..a40b63a24 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -114,6 +114,7 @@ utoipa = { version = "5.3.1", features = [ ] } utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] } async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling" } +secrecy = "0.10.3" actix-web-lab = { version = "0.24.1", default-features = false } [dev-dependencies] diff --git a/crates/meilisearch/src/routes/chats/chat_completions.rs b/crates/meilisearch/src/routes/chats/chat_completions.rs index 6f6b50d1c..f588541fa 100644 --- a/crates/meilisearch/src/routes/chats/chat_completions.rs +++ b/crates/meilisearch/src/routes/chats/chat_completions.rs @@ -7,7 +7,6 @@ use std::time::Duration; use actix_web::web::{self, Data}; use actix_web::{Either, HttpRequest, HttpResponse, Responder}; use actix_web_lab::sse::{Event, Sse}; -use async_openai::config::{Config, OpenAIConfig}; use async_openai::types::{ ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, @@ -35,6 +34,7 @@ use serde_json::json; use tokio::runtime::Handle; use tokio::sync::mpsc::error::SendError; +use super::config::Config; use super::errors::StreamErrorEvent; use super::utils::format_documents; use super::{ @@ -312,15 +312,8 @@ async fn non_streamed_chat( } }; - let mut config = OpenAIConfig::default(); - if let Some(api_key) = chat_settings.api_key.as_ref() { - config = config.with_api_key(api_key); - } - if let Some(base_api) = chat_settings.base_api.as_ref() { - config = config.with_api_base(base_api); - } + let config = Config::new(&chat_settings); let client = Client::with_config(config); - let auth_token = extract_token_from_request(&req)?.unwrap(); // TODO do function support later let _function_support = @@ -413,14 +406,7 @@ async fn streamed_chat( }; drop(rtxn); - let mut config = OpenAIConfig::default(); - if let Some(api_key) = chat_settings.api_key.as_ref() { - config = config.with_api_key(api_key); - } - if let Some(base_api) = chat_settings.base_api.as_ref() { - config = config.with_api_base(base_api); - } - + let config = Config::new(&chat_settings); let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); let function_support = setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; @@ -465,7 +451,7 @@ async fn streamed_chat( /// Updates the chat completion with the new messages, streams the LLM tokens, /// and report progress and errors. #[allow(clippy::too_many_arguments)] -async fn run_conversation( +async fn run_conversation( index_scheduler: &GuardedData< ActionPolicy<{ actions::CHAT_COMPLETIONS }>, Data, diff --git a/crates/meilisearch/src/routes/chats/config.rs b/crates/meilisearch/src/routes/chats/config.rs new file mode 100644 index 000000000..9babbd8c9 --- /dev/null +++ b/crates/meilisearch/src/routes/chats/config.rs @@ -0,0 +1,87 @@ +use async_openai::config::{AzureConfig, OpenAIConfig}; +use meilisearch_types::features::ChatCompletionSettings as DbChatSettings; +use reqwest::header::HeaderMap; +use secrecy::SecretString; + +#[derive(Debug, Clone)] +pub enum Config { + OpenAiCompatible(OpenAIConfig), + AzureOpenAiCompatible(AzureConfig), +} + +impl Config { + pub fn new(chat_settings: &DbChatSettings) -> Self { + use meilisearch_types::features::ChatCompletionSource::*; + match chat_settings.source { + OpenAi | Mistral | Gemini | VLlm => { + let mut config = OpenAIConfig::default(); + if let Some(org_id) = chat_settings.org_id.as_ref() { + config = config.with_org_id(org_id); + } + if let Some(project_id) = chat_settings.project_id.as_ref() { + config = config.with_project_id(project_id); + } + if let Some(api_key) = chat_settings.api_key.as_ref() { + config = config.with_api_key(api_key); + } + if let Some(base_api) = chat_settings.base_api.as_ref() { + config = config.with_api_base(base_api); + } + Self::OpenAiCompatible(config) + } + AzureOpenAi => { + let mut config = AzureConfig::default(); + if let Some(version) = chat_settings.api_version.as_ref() { + config = config.with_api_version(version); + } + if let Some(deployment_id) = chat_settings.deployment_id.as_ref() { + config = config.with_deployment_id(deployment_id); + } + if let Some(api_key) = chat_settings.api_key.as_ref() { + config = config.with_api_key(api_key); + } + if let Some(base_api) = chat_settings.base_api.as_ref() { + config = config.with_api_base(base_api); + } + Self::AzureOpenAiCompatible(config) + } + } + } +} + +impl async_openai::config::Config for Config { + fn headers(&self) -> HeaderMap { + match self { + Config::OpenAiCompatible(config) => config.headers(), + Config::AzureOpenAiCompatible(config) => config.headers(), + } + } + + fn url(&self, path: &str) -> String { + match self { + Config::OpenAiCompatible(config) => config.url(path), + Config::AzureOpenAiCompatible(config) => config.url(path), + } + } + + fn query(&self) -> Vec<(&str, &str)> { + match self { + Config::OpenAiCompatible(config) => config.query(), + Config::AzureOpenAiCompatible(config) => config.query(), + } + } + + fn api_base(&self) -> &str { + match self { + Config::OpenAiCompatible(config) => config.api_base(), + Config::AzureOpenAiCompatible(config) => config.api_base(), + } + } + + fn api_key(&self) -> &SecretString { + match self { + Config::OpenAiCompatible(config) => config.api_key(), + Config::AzureOpenAiCompatible(config) => config.api_key(), + } + } +} diff --git a/crates/meilisearch/src/routes/chats/mod.rs b/crates/meilisearch/src/routes/chats/mod.rs index 35afd69c0..ddaf4d80d 100644 --- a/crates/meilisearch/src/routes/chats/mod.rs +++ b/crates/meilisearch/src/routes/chats/mod.rs @@ -18,6 +18,7 @@ use crate::extractors::authentication::GuardedData; use crate::routes::PAGINATION_DEFAULT_LIMIT; pub mod chat_completions; +mod config; mod errors; pub mod settings; mod utils; diff --git a/crates/meilisearch/src/routes/chats/settings.rs b/crates/meilisearch/src/routes/chats/settings.rs index 0bb25f30d..c7b89d5bb 100644 --- a/crates/meilisearch/src/routes/chats/settings.rs +++ b/crates/meilisearch/src/routes/chats/settings.rs @@ -109,6 +109,26 @@ async fn patch_settings( Setting::Reset => DbChatCompletionSource::default(), Setting::NotSet => old_settings.source, }, + org_id: match new.org_id { + Setting::Set(new_org_id) => Some(new_org_id), + Setting::Reset => None, + Setting::NotSet => old_settings.org_id, + }, + project_id: match new.project_id { + Setting::Set(new_project_id) => Some(new_project_id), + Setting::Reset => None, + Setting::NotSet => old_settings.project_id, + }, + api_version: match new.api_version { + Setting::Set(new_api_version) => Some(new_api_version), + Setting::Reset => None, + Setting::NotSet => old_settings.api_version, + }, + deployment_id: match new.deployment_id { + Setting::Set(new_deployment_id) => Some(new_deployment_id), + Setting::Reset => None, + Setting::NotSet => old_settings.deployment_id, + }, base_api: match new.base_api { Setting::Set(new_base_api) => Some(new_base_api), Setting::Reset => None, @@ -171,6 +191,22 @@ pub struct GlobalChatSettings { #[schema(value_type = Option)] pub source: Setting, #[serde(default)] + #[deserr(default, error = DeserrJsonError)] + #[schema(value_type = Option, example = json!("dcba4321..."))] + pub org_id: Setting, + #[serde(default)] + #[deserr(default, error = DeserrJsonError)] + #[schema(value_type = Option, example = json!("4321dcba..."))] + pub project_id: Setting, + #[serde(default)] + #[deserr(default, error = DeserrJsonError)] + #[schema(value_type = Option, example = json!("2024-02-01"))] + pub api_version: Setting, + #[serde(default)] + #[deserr(default, error = DeserrJsonError)] + #[schema(value_type = Option, example = json!("1234abcd..."))] + pub deployment_id: Setting, + #[serde(default)] #[deserr(default, error = DeserrJsonError)] #[schema(value_type = Option, example = json!("https://api.mistral.ai/v1"))] pub base_api: Setting,