From 7266aed77093a69b4e7aaeff0f55a46e889c8215 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Tue, 20 May 2025 16:15:49 +0200 Subject: [PATCH] Correctly support tenant tokens and filters --- .../src/extractors/authentication/mod.rs | 57 +++++++++++-------- crates/meilisearch/src/routes/chat.rs | 49 ++++++++++++---- 2 files changed, 72 insertions(+), 34 deletions(-) diff --git a/crates/meilisearch/src/extractors/authentication/mod.rs b/crates/meilisearch/src/extractors/authentication/mod.rs index 7c9f5892e..eb250190d 100644 --- a/crates/meilisearch/src/extractors/authentication/mod.rs +++ b/crates/meilisearch/src/extractors/authentication/mod.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use std::ops::Deref; use std::pin::Pin; +use actix_web::http::header::AUTHORIZATION; use actix_web::web::Data; use actix_web::FromRequest; pub use error::AuthenticationError; @@ -94,36 +95,44 @@ impl FromRequest for GuardedData _payload: &mut actix_web::dev::Payload, ) -> Self::Future { match req.app_data::>().cloned() { - Some(auth) => match req - .headers() - .get("Authorization") - .map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' ')) - { - Some(mut type_token) => match type_token.next() { - Some("Bearer") => { - // TODO: find a less hardcoded way? - let index = req.match_info().get("index_uid"); - match type_token.next() { - Some(token) => Box::pin(Self::auth_bearer( - auth, - token.to_string(), - index.map(String::from), - req.app_data::().cloned(), - )), - None => Box::pin(err(AuthenticationError::InvalidToken.into())), - } - } - _otherwise => { - Box::pin(err(AuthenticationError::MissingAuthorizationHeader.into())) - } - }, - None => Box::pin(Self::auth_token(auth, req.app_data::().cloned())), + Some(auth) => match extract_token_from_request(req) { + Ok(Some(token)) => { + // TODO: find a less hardcoded way? + let index = req.match_info().get("index_uid"); + Box::pin(Self::auth_bearer( + auth, + token.to_string(), + index.map(String::from), + req.app_data::().cloned(), + )) + } + Ok(None) => Box::pin(Self::auth_token(auth, req.app_data::().cloned())), + Err(e) => Box::pin(err(e.into())), }, None => Box::pin(err(AuthenticationError::IrretrievableState.into())), } } } +pub fn extract_token_from_request( + req: &actix_web::HttpRequest, +) -> Result, AuthenticationError> { + match req + .headers() + .get(AUTHORIZATION) + .map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' ')) + { + Some(mut type_token) => match type_token.next() { + Some("Bearer") => match type_token.next() { + Some(token) => Ok(Some(token)), + None => Err(AuthenticationError::InvalidToken), + }, + _otherwise => Err(AuthenticationError::MissingAuthorizationHeader), + }, + None => Ok(None), + } +} + pub trait Policy { fn authenticate( auth: Data, diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 5ddcb6088..31e089231 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -3,7 +3,7 @@ use std::mem; use std::time::Duration; use actix_web::web::{self, Data}; -use actix_web::{Either, HttpResponse, Responder}; +use actix_web::{Either, HttpRequest, HttpResponse, Responder}; use actix_web_lab::sse::{self, Event, Sse}; use async_openai::config::OpenAIConfig; use async_openai::types::{ @@ -18,6 +18,7 @@ use async_openai::types::{ use async_openai::Client; use futures::StreamExt; use index_scheduler::IndexScheduler; +use meilisearch_auth::AuthController; use meilisearch_types::error::ResponseError; use meilisearch_types::keys::actions; use meilisearch_types::milli::index::IndexEmbeddingConfig; @@ -31,7 +32,7 @@ use tokio::sync::mpsc::error::SendError; use super::settings::chat::{ChatPrompts, ChatSettings}; use crate::extractors::authentication::policies::ActionPolicy; -use crate::extractors::authentication::GuardedData; +use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _}; use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; use crate::routes::indexes::search::search_kind; use crate::search::{ @@ -48,6 +49,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) { /// Get a chat completion async fn chat( index_scheduler: GuardedData, Data>, + auth_ctrl: web::Data, + req: HttpRequest, search_queue: web::Data, web::Json(chat_completion): web::Json, ) -> impl Responder { @@ -61,9 +64,13 @@ async fn chat( ); if chat_completion.stream.unwrap_or(false) { - Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await) + Either::Right( + streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await, + ) } else { - Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await) + Either::Left( + non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await, + ) } } @@ -115,7 +122,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: /// Process search request and return formatted results async fn process_search_request( index_scheduler: &GuardedData, Data>, + auth_ctrl: web::Data, search_queue: &web::Data, + auth_token: &str, index_uid: String, q: Option, ) -> Result<(Index, String), ResponseError> { @@ -129,8 +138,14 @@ async fn process_search_request( ..Default::default() }; + let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate( + auth_ctrl, + auth_token, + Some(index_uid.as_str()), + )?; + // Tenant token search_rules. - if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { + if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) { add_search_rules(&mut query.filter, search_rules); } @@ -176,6 +191,8 @@ async fn process_search_request( async fn non_streamed_chat( index_scheduler: GuardedData, Data>, + auth_ctrl: web::Data, + req: HttpRequest, search_queue: web::Data, mut chat_completion: CreateChatCompletionRequest, ) -> Result { @@ -193,6 +210,7 @@ async fn non_streamed_chat( } let client = Client::with_config(config); + let auth_token = extract_token_from_request(&req)?.unwrap(); setup_search_tool(&mut chat_completion, &chat_settings.prompts); let mut response; @@ -219,9 +237,15 @@ async fn non_streamed_chat( let SearchInIndexParameters { index_uid, q } = serde_json::from_str(&call.function.arguments).unwrap(); - let (_, text) = - process_search_request(&index_scheduler, &search_queue, index_uid, q) - .await?; + let (_, text) = process_search_request( + &index_scheduler, + auth_ctrl.clone(), + &search_queue, + auth_token, + index_uid, + q, + ) + .await?; chat_completion.messages.push(ChatCompletionRequestMessage::Tool( ChatCompletionRequestToolMessage { @@ -246,9 +270,11 @@ async fn non_streamed_chat( async fn streamed_chat( index_scheduler: GuardedData, Data>, + auth_ctrl: web::Data, + req: HttpRequest, search_queue: web::Data, mut chat_completion: CreateChatCompletionRequest, -) -> impl Responder { +) -> Result { let chat_settings = match index_scheduler.chat_settings().unwrap() { Some(value) => serde_json::from_value(value).unwrap(), None => ChatSettings::default(), @@ -262,6 +288,7 @@ async fn streamed_chat( config = config.with_api_base(base_api); } + let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); setup_search_tool(&mut chat_completion, &chat_settings.prompts); let (tx, rx) = tokio::sync::mpsc::channel(10); @@ -354,7 +381,9 @@ async fn streamed_chat( let result = process_search_request( &index_scheduler, + auth_ctrl.clone(), &search_queue, + &auth_token, index_uid, q, ) @@ -417,7 +446,7 @@ async fn streamed_chat( let _ = tx.send(Event::Data(sse::Data::new("[DONE]"))); }); - Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)) + Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))) } /// The structure used to aggregate the function calls to make.