mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-04 19:25:32 +00:00
Correctly support tenant tokens and filters
This commit is contained in:
parent
0fcc7e1377
commit
1159af1219
@ -4,6 +4,7 @@ use std::marker::PhantomData;
|
||||
use std::ops::Deref;
|
||||
use std::pin::Pin;
|
||||
|
||||
use actix_web::http::header::AUTHORIZATION;
|
||||
use actix_web::web::Data;
|
||||
use actix_web::FromRequest;
|
||||
pub use error::AuthenticationError;
|
||||
@ -94,36 +95,44 @@ impl<P: Policy + 'static, D: 'static + Clone> FromRequest for GuardedData<P, D>
|
||||
_payload: &mut actix_web::dev::Payload,
|
||||
) -> Self::Future {
|
||||
match req.app_data::<Data<AuthController>>().cloned() {
|
||||
Some(auth) => match req
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
|
||||
{
|
||||
Some(mut type_token) => match type_token.next() {
|
||||
Some("Bearer") => {
|
||||
// TODO: find a less hardcoded way?
|
||||
let index = req.match_info().get("index_uid");
|
||||
match type_token.next() {
|
||||
Some(token) => Box::pin(Self::auth_bearer(
|
||||
auth,
|
||||
token.to_string(),
|
||||
index.map(String::from),
|
||||
req.app_data::<D>().cloned(),
|
||||
)),
|
||||
None => Box::pin(err(AuthenticationError::InvalidToken.into())),
|
||||
}
|
||||
}
|
||||
_otherwise => {
|
||||
Box::pin(err(AuthenticationError::MissingAuthorizationHeader.into()))
|
||||
}
|
||||
},
|
||||
None => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
|
||||
Some(auth) => match extract_token_from_request(req) {
|
||||
Ok(Some(token)) => {
|
||||
// TODO: find a less hardcoded way?
|
||||
let index = req.match_info().get("index_uid");
|
||||
Box::pin(Self::auth_bearer(
|
||||
auth,
|
||||
token.to_string(),
|
||||
index.map(String::from),
|
||||
req.app_data::<D>().cloned(),
|
||||
))
|
||||
}
|
||||
Ok(None) => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
|
||||
Err(e) => Box::pin(err(e.into())),
|
||||
},
|
||||
None => Box::pin(err(AuthenticationError::IrretrievableState.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_token_from_request(
|
||||
req: &actix_web::HttpRequest,
|
||||
) -> Result<Option<&str>, AuthenticationError> {
|
||||
match req
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
|
||||
{
|
||||
Some(mut type_token) => match type_token.next() {
|
||||
Some("Bearer") => match type_token.next() {
|
||||
Some(token) => Ok(Some(token)),
|
||||
None => Err(AuthenticationError::InvalidToken),
|
||||
},
|
||||
_otherwise => Err(AuthenticationError::MissingAuthorizationHeader),
|
||||
},
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Policy {
|
||||
fn authenticate(
|
||||
auth: Data<AuthController>,
|
||||
|
@ -3,7 +3,7 @@ use std::mem;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::web::{self, Data};
|
||||
use actix_web::{Either, HttpResponse, Responder};
|
||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{self, Event, Sse};
|
||||
use async_openai::config::OpenAIConfig;
|
||||
use async_openai::types::{
|
||||
@ -18,6 +18,7 @@ use async_openai::types::{
|
||||
use async_openai::Client;
|
||||
use futures::StreamExt;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_auth::AuthController;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::index::IndexEmbeddingConfig;
|
||||
@ -31,7 +32,7 @@ use tokio::sync::mpsc::error::SendError;
|
||||
|
||||
use super::settings::chat::{ChatPrompts, ChatSettings};
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::GuardedData;
|
||||
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
|
||||
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
||||
use crate::routes::indexes::search::search_kind;
|
||||
use crate::search::{
|
||||
@ -48,6 +49,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
/// Get a chat completion
|
||||
async fn chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
|
||||
) -> impl Responder {
|
||||
@ -61,9 +64,13 @@ async fn chat(
|
||||
);
|
||||
|
||||
if chat_completion.stream.unwrap_or(false) {
|
||||
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
|
||||
Either::Right(
|
||||
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
||||
)
|
||||
} else {
|
||||
Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await)
|
||||
Either::Left(
|
||||
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -115,7 +122,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
|
||||
/// Process search request and return formatted results
|
||||
async fn process_search_request(
|
||||
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
index_uid: String,
|
||||
q: Option<String>,
|
||||
) -> Result<(Index, String), ResponseError> {
|
||||
@ -129,8 +138,14 @@ async fn process_search_request(
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
|
||||
auth_ctrl,
|
||||
auth_token,
|
||||
Some(index_uid.as_str()),
|
||||
)?;
|
||||
|
||||
// Tenant token search_rules.
|
||||
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
|
||||
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
|
||||
add_search_rules(&mut query.filter, search_rules);
|
||||
}
|
||||
|
||||
@ -176,6 +191,8 @@ async fn process_search_request(
|
||||
|
||||
async fn non_streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
@ -193,6 +210,7 @@ async fn non_streamed_chat(
|
||||
}
|
||||
let client = Client::with_config(config);
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
|
||||
|
||||
let mut response;
|
||||
@ -219,9 +237,15 @@ async fn non_streamed_chat(
|
||||
let SearchInIndexParameters { index_uid, q } =
|
||||
serde_json::from_str(&call.function.arguments).unwrap();
|
||||
|
||||
let (_, text) =
|
||||
process_search_request(&index_scheduler, &search_queue, index_uid, q)
|
||||
.await?;
|
||||
let (_, text) = process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
)
|
||||
.await?;
|
||||
|
||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
||||
ChatCompletionRequestToolMessage {
|
||||
@ -246,9 +270,11 @@ async fn non_streamed_chat(
|
||||
|
||||
async fn streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> impl Responder {
|
||||
) -> Result<impl Responder, ResponseError> {
|
||||
let chat_settings = match index_scheduler.chat_settings().unwrap() {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => ChatSettings::default(),
|
||||
@ -262,6 +288,7 @@ async fn streamed_chat(
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
@ -354,7 +381,9 @@ async fn streamed_chat(
|
||||
|
||||
let result = process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
&auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
)
|
||||
@ -417,7 +446,7 @@ async fn streamed_chat(
|
||||
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
|
||||
});
|
||||
|
||||
Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))
|
||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||
}
|
||||
|
||||
/// The structure used to aggregate the function calls to make.
|
||||
|
Loading…
x
Reference in New Issue
Block a user