Correctly support tenant tokens and filters

This commit is contained in:
Clément Renault 2025-05-20 16:15:49 +02:00 committed by Kerollmops
parent 0fcc7e1377
commit 1159af1219
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
2 changed files with 72 additions and 34 deletions

View File

@ -4,6 +4,7 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use actix_web::http::header::AUTHORIZATION;
use actix_web::web::Data;
use actix_web::FromRequest;
pub use error::AuthenticationError;
@ -94,36 +95,44 @@ impl<P: Policy + 'static, D: 'static + Clone> FromRequest for GuardedData<P, D>
_payload: &mut actix_web::dev::Payload,
) -> Self::Future {
match req.app_data::<Data<AuthController>>().cloned() {
Some(auth) => match req
.headers()
.get("Authorization")
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
{
Some(mut type_token) => match type_token.next() {
Some("Bearer") => {
// TODO: find a less hardcoded way?
let index = req.match_info().get("index_uid");
match type_token.next() {
Some(token) => Box::pin(Self::auth_bearer(
auth,
token.to_string(),
index.map(String::from),
req.app_data::<D>().cloned(),
)),
None => Box::pin(err(AuthenticationError::InvalidToken.into())),
}
}
_otherwise => {
Box::pin(err(AuthenticationError::MissingAuthorizationHeader.into()))
}
},
None => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
Some(auth) => match extract_token_from_request(req) {
Ok(Some(token)) => {
// TODO: find a less hardcoded way?
let index = req.match_info().get("index_uid");
Box::pin(Self::auth_bearer(
auth,
token.to_string(),
index.map(String::from),
req.app_data::<D>().cloned(),
))
}
Ok(None) => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
Err(e) => Box::pin(err(e.into())),
},
None => Box::pin(err(AuthenticationError::IrretrievableState.into())),
}
}
}
pub fn extract_token_from_request(
req: &actix_web::HttpRequest,
) -> Result<Option<&str>, AuthenticationError> {
match req
.headers()
.get(AUTHORIZATION)
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
{
Some(mut type_token) => match type_token.next() {
Some("Bearer") => match type_token.next() {
Some(token) => Ok(Some(token)),
None => Err(AuthenticationError::InvalidToken),
},
_otherwise => Err(AuthenticationError::MissingAuthorizationHeader),
},
None => Ok(None),
}
}
pub trait Policy {
fn authenticate(
auth: Data<AuthController>,

View File

@ -3,7 +3,7 @@ use std::mem;
use std::time::Duration;
use actix_web::web::{self, Data};
use actix_web::{Either, HttpResponse, Responder};
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig;
use async_openai::types::{
@ -18,6 +18,7 @@ use async_openai::types::{
use async_openai::Client;
use futures::StreamExt;
use index_scheduler::IndexScheduler;
use meilisearch_auth::AuthController;
use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::IndexEmbeddingConfig;
@ -31,7 +32,7 @@ use tokio::sync::mpsc::error::SendError;
use super::settings::chat::{ChatPrompts, ChatSettings};
use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::GuardedData;
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::routes::indexes::search::search_kind;
use crate::search::{
@ -48,6 +49,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
/// Get a chat completion
async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
) -> impl Responder {
@ -61,9 +64,13 @@ async fn chat(
);
if chat_completion.stream.unwrap_or(false) {
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
Either::Right(
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
} else {
Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await)
Either::Left(
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
}
}
@ -115,7 +122,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
/// Process search request and return formatted results
async fn process_search_request(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
search_queue: &web::Data<SearchQueue>,
auth_token: &str,
index_uid: String,
q: Option<String>,
) -> Result<(Index, String), ResponseError> {
@ -129,8 +138,14 @@ async fn process_search_request(
..Default::default()
};
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
auth_ctrl,
auth_token,
Some(index_uid.as_str()),
)?;
// Tenant token search_rules.
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
add_search_rules(&mut query.filter, search_rules);
}
@ -176,6 +191,8 @@ async fn process_search_request(
async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> {
@ -193,6 +210,7 @@ async fn non_streamed_chat(
}
let client = Client::with_config(config);
let auth_token = extract_token_from_request(&req)?.unwrap();
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let mut response;
@ -219,9 +237,15 @@ async fn non_streamed_chat(
let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap();
let (_, text) =
process_search_request(&index_scheduler, &search_queue, index_uid, q)
.await?;
let (_, text) = process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
auth_token,
index_uid,
q,
)
.await?;
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
@ -246,9 +270,11 @@ async fn non_streamed_chat(
async fn streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> impl Responder {
) -> Result<impl Responder, ResponseError> {
let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(),
@ -262,6 +288,7 @@ async fn streamed_chat(
config = config.with_api_base(base_api);
}
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let (tx, rx) = tokio::sync::mpsc::channel(10);
@ -354,7 +381,9 @@ async fn streamed_chat(
let result = process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
&auth_token,
index_uid,
q,
)
@ -417,7 +446,7 @@ async fn streamed_chat(
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
});
Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
}
/// The structure used to aggregate the function calls to make.