Introduce support for Azure, Gemini, and vLLM

Clément Renault 2025-06-06 12:08:37 +02:00
parent 4dfb89168b
commit 70670c3be4
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
8 changed files with 261 additions and 18 deletions
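
Before the per-file diffs, a sketch of what the change enables at the API surface: four new optional fields (orgId, projectId, apiVersion, deploymentId) travel through the chat workspace settings. The payloads below are illustrative only; the camelCase casing, the source string values, and the settings route shape are assumptions inferred from the structs in this commit, not confirmed by this page.

use serde_json::json;

fn main() {
    // Azure OpenAI has no fixed base URL, so the resource endpoint, API
    // version, and deployment must all come from the settings.
    let azure = json!({
        "source": "azureOpenAi", // assumed serialization of AzureOpenAi
        "baseApi": "https://my-resource.openai.azure.com",
        "apiVersion": "2024-02-01",
        "deploymentId": "my-gpt-4o-deployment",
        "apiKey": "<azure-key>"
    });

    // Gemini reuses the OpenAI-compatible path and ships a default base URL.
    let gemini = json!({ "source": "gemini", "apiKey": "<gemini-key>" });

    println!("{azure:#}\n{gemini:#}");
}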

Cargo.lock (generated)
View File

@@ -3788,6 +3788,7 @@ dependencies = [
 "rustls",
 "rustls-pemfile",
 "rustls-pki-types",
+"secrecy",
 "segment",
 "serde",
 "serde_json",

View File

@@ -391,6 +391,10 @@ EditDocumentsByFunctionError , InvalidRequest , BAD_REQU
 InvalidSettingsIndexChat , InvalidRequest , BAD_REQUEST ;
 // Experimental features - Chat Completions
 ChatWorkspaceNotFound , InvalidRequest , NOT_FOUND ;
+InvalidChatCompletionOrgId , InvalidRequest , BAD_REQUEST ;
+InvalidChatCompletionProjectId , InvalidRequest , BAD_REQUEST ;
+InvalidChatCompletionApiVersion , InvalidRequest , BAD_REQUEST ;
+InvalidChatCompletionDeploymentId , InvalidRequest , BAD_REQUEST ;
 InvalidChatCompletionSource , InvalidRequest , BAD_REQUEST ;
 InvalidChatCompletionBaseApi , InvalidRequest , BAD_REQUEST ;
 InvalidChatCompletionApiKey , InvalidRequest , BAD_REQUEST ;

View File

@@ -51,6 +51,14 @@ pub struct Network {
 pub struct ChatCompletionSettings {
     pub source: ChatCompletionSource,
     #[serde(default)]
+    pub org_id: Option<String>,
+    #[serde(default)]
+    pub project_id: Option<String>,
+    #[serde(default)]
+    pub api_version: Option<String>,
+    #[serde(default)]
+    pub deployment_id: Option<String>,
+    #[serde(default)]
     pub base_api: Option<String>,
     #[serde(default)]
     pub api_key: Option<String>,
@@ -88,6 +96,43 @@ impl ChatCompletionSettings {
 pub enum ChatCompletionSource {
     #[default]
     OpenAi,
+    AzureOpenAi,
+    Mistral,
+    Gemini,
+    VLlm,
+}
+
+impl ChatCompletionSource {
+    pub fn system_role(&self, model: &str) -> &'static str {
+        match self {
+            ChatCompletionSource::OpenAi if Self::old_openai_model(model) => "system",
+            ChatCompletionSource::OpenAi => "developer",
+            ChatCompletionSource::AzureOpenAi if Self::old_openai_model(model) => "system",
+            ChatCompletionSource::AzureOpenAi => "developer",
+            ChatCompletionSource::Mistral => "system",
+            ChatCompletionSource::Gemini => "system",
+            ChatCompletionSource::VLlm => "system",
+        }
+    }
+
+    /// Returns true if the model is an old OpenAI model.
+    ///
+    /// Old OpenAI models use the system role while new ones use the developer role.
+    fn old_openai_model(model: &str) -> bool {
+        ["gpt-3.5", "gpt-4", "gpt-4.1", "gpt-4.5", "gpt-4o", "chatgpt-4o"].iter().any(|old| {
+            model.starts_with(old) && model.chars().nth(old.len()).is_none_or(|last| last == '-')
+        })
+    }
+
+    pub fn base_url(&self) -> Option<&'static str> {
+        use ChatCompletionSource::*;
+        match self {
+            OpenAi => Some("https://api.openai.com/v1/"),
+            Mistral => Some("https://api.mistral.ai/v1/"),
+            Gemini => Some("https://generativelanguage.googleapis.com/v1beta/openai/"),
+            AzureOpenAi | VLlm => None,
+        }
+    }
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
@@ -111,3 +156,85 @@ impl Default for ChatCompletionPrompts {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    const ALL_OPENAI_MODELS_OLDINESS: &[(&str, bool)] = &[
+        ("gpt-4-0613", true),
+        ("gpt-4", true),
+        ("gpt-3.5-turbo", true),
+        ("gpt-4o-audio-preview-2025-06-03", true),
+        ("gpt-4.1-nano", true),
+        ("gpt-4o-realtime-preview-2025-06-03", true),
+        ("gpt-3.5-turbo-instruct", true),
+        ("gpt-3.5-turbo-instruct-0914", true),
+        ("gpt-4-1106-preview", true),
+        ("gpt-3.5-turbo-1106", true),
+        ("gpt-4-0125-preview", true),
+        ("gpt-4-turbo-preview", true),
+        ("gpt-3.5-turbo-0125", true),
+        ("gpt-4-turbo", true),
+        ("gpt-4-turbo-2024-04-09", true),
+        ("gpt-4o", true),
+        ("gpt-4o-2024-05-13", true),
+        ("gpt-4o-mini-2024-07-18", true),
+        ("gpt-4o-mini", true),
+        ("gpt-4o-2024-08-06", true),
+        ("chatgpt-4o-latest", true),
+        ("gpt-4o-realtime-preview-2024-10-01", true),
+        ("gpt-4o-audio-preview-2024-10-01", true),
+        ("gpt-4o-audio-preview", true),
+        ("gpt-4o-realtime-preview", true),
+        ("gpt-4o-realtime-preview-2024-12-17", true),
+        ("gpt-4o-audio-preview-2024-12-17", true),
+        ("gpt-4o-mini-realtime-preview-2024-12-17", true),
+        ("gpt-4o-mini-audio-preview-2024-12-17", true),
+        ("gpt-4o-mini-realtime-preview", true),
+        ("gpt-4o-mini-audio-preview", true),
+        ("gpt-4o-2024-11-20", true),
+        ("gpt-4.5-preview", true),
+        ("gpt-4.5-preview-2025-02-27", true),
+        ("gpt-4o-search-preview-2025-03-11", true),
+        ("gpt-4o-search-preview", true),
+        ("gpt-4o-mini-search-preview-2025-03-11", true),
+        ("gpt-4o-mini-search-preview", true),
+        ("gpt-4o-transcribe", true),
+        ("gpt-4o-mini-transcribe", true),
+        ("gpt-4o-mini-tts", true),
+        ("gpt-4.1-2025-04-14", true),
+        ("gpt-4.1", true),
+        ("gpt-4.1-mini-2025-04-14", true),
+        ("gpt-4.1-mini", true),
+        ("gpt-4.1-nano-2025-04-14", true),
+        ("gpt-3.5-turbo-16k", true),
+        //
+        // new models
+        ("o1-preview-2024-09-12", false),
+        ("o1-preview", false),
+        ("o1-mini-2024-09-12", false),
+        ("o1-mini", false),
+        ("o1-2024-12-17", false),
+        ("o1", false),
+        ("o3-mini", false),
+        ("o3-mini-2025-01-31", false),
+        ("o1-pro-2025-03-19", false),
+        ("o1-pro", false),
+        ("o3-2025-04-16", false),
+        ("o4-mini-2025-04-16", false),
+        ("o3", false),
+        ("o4-mini", false),
+    ];
+
+    #[test]
+    fn old_openai_models() {
+        for (name, is_old) in ALL_OPENAI_MODELS_OLDINESS.iter().copied() {
+            assert_eq!(
+                ChatCompletionSource::old_openai_model(name),
+                is_old,
+                "Model {name} is not considered old"
+            );
+        }
+    }
+}
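
Together, system_role and base_url give callers the per-provider endpoint and the role of the leading instruction message. A minimal usage sketch, not part of the commit, assuming a caller-supplied fallback URL covers the Azure/vLLM case:

use meilisearch_types::features::ChatCompletionSource;

/// Picks the endpoint and the role of the leading instruction message.
fn endpoint_and_role<'a>(
    source: &ChatCompletionSource,
    model: &str,
    fallback_base_api: &'a str, // e.g. the workspace's baseApi setting
) -> (&'a str, &'static str) {
    // AzureOpenAi and VLlm return None from base_url(), so fall back.
    (source.base_url().unwrap_or(fallback_base_api), source.system_role(model))
}

fn main() {
    let src = ChatCompletionSource::OpenAi;
    // gpt-4o* matches an "old model" prefix, so it keeps the system role...
    assert_eq!(endpoint_and_role(&src, "gpt-4o-mini", "").1, "system");
    // ...while o3-mini does not, so it gets the developer role.
    assert_eq!(endpoint_and_role(&src, "o3-mini", "").1, "developer");
}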

View File

@@ -114,6 +114,7 @@ utoipa = { version = "5.3.1", features = [
 ] }
 utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] }
 async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling" }
+secrecy = "0.10.3"
 actix-web-lab = { version = "0.24.1", default-features = false }
 
 [dev-dependencies]

View File

@@ -7,7 +7,6 @@ use std::time::Duration;
 use actix_web::web::{self, Data};
 use actix_web::{Either, HttpRequest, HttpResponse, Responder};
 use actix_web_lab::sse::{Event, Sse};
-use async_openai::config::{Config, OpenAIConfig};
 use async_openai::types::{
     ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
     ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
@@ -35,6 +34,7 @@ use serde_json::json;
 use tokio::runtime::Handle;
 use tokio::sync::mpsc::error::SendError;
 
+use super::config::Config;
 use super::errors::StreamErrorEvent;
 use super::utils::format_documents;
 use super::{
@@ -312,15 +312,8 @@ async fn non_streamed_chat(
         }
     };
 
-    let mut config = OpenAIConfig::default();
-    if let Some(api_key) = chat_settings.api_key.as_ref() {
-        config = config.with_api_key(api_key);
-    }
-    if let Some(base_api) = chat_settings.base_api.as_ref() {
-        config = config.with_api_base(base_api);
-    }
-
+    let config = Config::new(&chat_settings);
     let client = Client::with_config(config);
     let auth_token = extract_token_from_request(&req)?.unwrap();
 
     // TODO do function support later
@@ -413,14 +406,7 @@ async fn streamed_chat(
     };
     drop(rtxn);
 
-    let mut config = OpenAIConfig::default();
-    if let Some(api_key) = chat_settings.api_key.as_ref() {
-        config = config.with_api_key(api_key);
-    }
-    if let Some(base_api) = chat_settings.base_api.as_ref() {
-        config = config.with_api_base(base_api);
-    }
-
+    let config = Config::new(&chat_settings);
     let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
     let function_support =
         setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
@@ -465,7 +451,7 @@ async fn streamed_chat(
 /// Updates the chat completion with the new messages, streams the LLM tokens,
 /// and reports progress and errors.
 #[allow(clippy::too_many_arguments)]
-async fn run_conversation<C: Config>(
+async fn run_conversation<C: async_openai::config::Config>(
     index_scheduler: &GuardedData<
         ActionPolicy<{ actions::CHAT_COMPLETIONS }>,
         Data<IndexScheduler>,

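One subtlety in the last hunk: the bound on run_conversation is spelled out as async_openai::config::Config because the new use super::config::Config; import would otherwise shadow the trait name in this module.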
View File

@@ -0,0 +1,87 @@
+use async_openai::config::{AzureConfig, OpenAIConfig};
+use meilisearch_types::features::ChatCompletionSettings as DbChatSettings;
+use reqwest::header::HeaderMap;
+use secrecy::SecretString;
+
+#[derive(Debug, Clone)]
+pub enum Config {
+    OpenAiCompatible(OpenAIConfig),
+    AzureOpenAiCompatible(AzureConfig),
+}
+
+impl Config {
+    pub fn new(chat_settings: &DbChatSettings) -> Self {
+        use meilisearch_types::features::ChatCompletionSource::*;
+        match chat_settings.source {
+            OpenAi | Mistral | Gemini | VLlm => {
+                let mut config = OpenAIConfig::default();
+                if let Some(org_id) = chat_settings.org_id.as_ref() {
+                    config = config.with_org_id(org_id);
+                }
+                if let Some(project_id) = chat_settings.project_id.as_ref() {
+                    config = config.with_project_id(project_id);
+                }
+                if let Some(api_key) = chat_settings.api_key.as_ref() {
+                    config = config.with_api_key(api_key);
+                }
+                if let Some(base_api) = chat_settings.base_api.as_ref() {
+                    config = config.with_api_base(base_api);
+                }
+                Self::OpenAiCompatible(config)
+            }
+            AzureOpenAi => {
+                let mut config = AzureConfig::default();
+                if let Some(version) = chat_settings.api_version.as_ref() {
+                    config = config.with_api_version(version);
+                }
+                if let Some(deployment_id) = chat_settings.deployment_id.as_ref() {
+                    config = config.with_deployment_id(deployment_id);
+                }
+                if let Some(api_key) = chat_settings.api_key.as_ref() {
+                    config = config.with_api_key(api_key);
+                }
+                if let Some(base_api) = chat_settings.base_api.as_ref() {
+                    config = config.with_api_base(base_api);
+                }
+                Self::AzureOpenAiCompatible(config)
+            }
+        }
+    }
+}
+
+impl async_openai::config::Config for Config {
+    fn headers(&self) -> HeaderMap {
+        match self {
+            Config::OpenAiCompatible(config) => config.headers(),
+            Config::AzureOpenAiCompatible(config) => config.headers(),
+        }
+    }
+
+    fn url(&self, path: &str) -> String {
+        match self {
+            Config::OpenAiCompatible(config) => config.url(path),
+            Config::AzureOpenAiCompatible(config) => config.url(path),
+        }
+    }
+
+    fn query(&self) -> Vec<(&str, &str)> {
+        match self {
+            Config::OpenAiCompatible(config) => config.query(),
+            Config::AzureOpenAiCompatible(config) => config.query(),
+        }
+    }
+
+    fn api_base(&self) -> &str {
+        match self {
+            Config::OpenAiCompatible(config) => config.api_base(),
+            Config::AzureOpenAiCompatible(config) => config.api_base(),
+        }
+    }
+
+    fn api_key(&self) -> &SecretString {
+        match self {
+            Config::OpenAiCompatible(config) => config.api_key(),
+            Config::AzureOpenAiCompatible(config) => config.api_key(),
+        }
+    }
+}
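
Since the enum implements async_openai::config::Config, it plugs straight into the generic client, which is all the handler changes above amount to. A hypothetical helper, not in the commit, assuming the same imports as this file:

use async_openai::Client;

// One constructor call now covers every supported provider.
fn build_client(chat_settings: &DbChatSettings) -> Client<Config> {
    Client::with_config(Config::new(chat_settings))
}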

View File

@@ -18,6 +18,7 @@ use crate::extractors::authentication::GuardedData;
 use crate::routes::PAGINATION_DEFAULT_LIMIT;
 
 pub mod chat_completions;
+mod config;
 mod errors;
 pub mod settings;
 mod utils;

View File

@@ -109,6 +109,26 @@ async fn patch_settings(
             Setting::Reset => DbChatCompletionSource::default(),
             Setting::NotSet => old_settings.source,
         },
+        org_id: match new.org_id {
+            Setting::Set(new_org_id) => Some(new_org_id),
+            Setting::Reset => None,
+            Setting::NotSet => old_settings.org_id,
+        },
+        project_id: match new.project_id {
+            Setting::Set(new_project_id) => Some(new_project_id),
+            Setting::Reset => None,
+            Setting::NotSet => old_settings.project_id,
+        },
+        api_version: match new.api_version {
+            Setting::Set(new_api_version) => Some(new_api_version),
+            Setting::Reset => None,
+            Setting::NotSet => old_settings.api_version,
+        },
+        deployment_id: match new.deployment_id {
+            Setting::Set(new_deployment_id) => Some(new_deployment_id),
+            Setting::Reset => None,
+            Setting::NotSet => old_settings.deployment_id,
+        },
         base_api: match new.base_api {
             Setting::Set(new_base_api) => Some(new_base_api),
             Setting::Reset => None,
@@ -171,6 +191,22 @@ pub struct GlobalChatSettings {
     #[schema(value_type = Option<ChatCompletionSource>)]
     pub source: Setting<ChatCompletionSource>,
     #[serde(default)]
+    #[deserr(default, error = DeserrJsonError<InvalidChatCompletionOrgId>)]
+    #[schema(value_type = Option<String>, example = json!("dcba4321..."))]
+    pub org_id: Setting<String>,
+    #[serde(default)]
+    #[deserr(default, error = DeserrJsonError<InvalidChatCompletionProjectId>)]
+    #[schema(value_type = Option<String>, example = json!("4321dcba..."))]
+    pub project_id: Setting<String>,
+    #[serde(default)]
+    #[deserr(default, error = DeserrJsonError<InvalidChatCompletionApiVersion>)]
+    #[schema(value_type = Option<String>, example = json!("2024-02-01"))]
+    pub api_version: Setting<String>,
+    #[serde(default)]
+    #[deserr(default, error = DeserrJsonError<InvalidChatCompletionDeploymentId>)]
+    #[schema(value_type = Option<String>, example = json!("1234abcd..."))]
+    pub deployment_id: Setting<String>,
+    #[serde(default)]
     #[deserr(default, error = DeserrJsonError<InvalidChatCompletionBaseApi>)]
     #[schema(value_type = Option<String>, example = json!("https://api.mistral.ai/v1"))]
     pub base_api: Setting<String>,
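
Every field merged in patch_settings above follows the same three-way rule. Distilled into a self-contained toy (the real Setting type lives in meilisearch_types; this stand-in only mirrors the semantics):

// Stand-in for meilisearch_types' Setting enum, for illustration only.
#[derive(Debug, PartialEq)]
enum Setting<T> {
    Set(T),
    Reset,
    NotSet,
}

// Set overwrites, Reset clears, NotSet keeps whatever was stored.
fn merge<T>(new: Setting<T>, old: Option<T>) -> Option<T> {
    match new {
        Setting::Set(value) => Some(value),
        Setting::Reset => None,
        Setting::NotSet => old,
    }
}

fn main() {
    assert_eq!(merge(Setting::Set("2024-02-01"), None), Some("2024-02-01"));
    assert_eq!(merge(Setting::Reset, Some("2024-02-01")), None);
    assert_eq!(merge(Setting::NotSet, Some("2024-02-01")), Some("2024-02-01"));
}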