Support multiple indexes and not only main

This commit is contained in:
Clément Renault 2025-05-20 17:15:23 +02:00 committed by Kerollmops
parent 14c8e5cb56
commit d6f2bd9b57
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -76,12 +76,23 @@ async fn chat(
} }
/// Setup search tool in chat completion request /// Setup search tool in chat completion request
fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: &ChatPrompts) { fn setup_search_tool(
index_scheduler: &Data<IndexScheduler>,
filters: &meilisearch_auth::AuthFilter,
chat_completion: &mut CreateChatCompletionRequest,
prompts: &ChatPrompts,
) -> Result<(), ResponseError> {
let tools = chat_completion.tools.get_or_insert_default(); let tools = chat_completion.tools.get_or_insert_default();
if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() { if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set"); panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set");
} }
let index_uids: Vec<_> = index_scheduler
.index_names()?
.into_iter()
.filter(|index_uid| filters.is_index_authorized(&index_uid))
.collect();
let tool = ChatCompletionToolArgs::default() let tool = ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function) .r#type(ChatCompletionToolType::Function)
.function( .function(
@ -93,7 +104,7 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
"properties": { "properties": {
"index_uid": { "index_uid": {
"type": "string", "type": "string",
"enum": ["main"], "enum": index_uids,
"description": prompts.search_index_uid_param, "description": prompts.search_index_uid_param,
}, },
"q": { "q": {
@ -120,6 +131,8 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
name: None, name: None,
}), }),
); );
Ok(())
} }
/// Process search request and return formatted results /// Process search request and return formatted results
@ -199,6 +212,8 @@ async fn non_streamed_chat(
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest, mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> { ) -> Result<HttpResponse, ResponseError> {
let filters = index_scheduler.filters();
let chat_settings = match index_scheduler.chat_settings().unwrap() { let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(), Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(), None => ChatSettings::default(),
@ -214,7 +229,7 @@ async fn non_streamed_chat(
let client = Client::with_config(config); let client = Client::with_config(config);
let auth_token = extract_token_from_request(&req)?.unwrap(); let auth_token = extract_token_from_request(&req)?.unwrap();
setup_search_tool(&mut chat_completion, &chat_settings.prompts); setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
let mut response; let mut response;
loop { loop {
@ -279,6 +294,8 @@ async fn streamed_chat(
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest, mut chat_completion: CreateChatCompletionRequest,
) -> Result<impl Responder, ResponseError> { ) -> Result<impl Responder, ResponseError> {
let filters = index_scheduler.filters();
let chat_settings = match index_scheduler.chat_settings().unwrap() { let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(), Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(), None => ChatSettings::default(),
@ -293,7 +310,7 @@ async fn streamed_chat(
} }
let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
setup_search_tool(&mut chat_completion, &chat_settings.prompts); setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10); let (tx, rx) = tokio::sync::mpsc::channel(10);
let _join_handle = Handle::current().spawn(async move { let _join_handle = Handle::current().spawn(async move {