diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 2715cff72..f794ba19c 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -76,12 +76,23 @@ async fn chat( } /// Setup search tool in chat completion request -fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: &ChatPrompts) { +fn setup_search_tool( + index_scheduler: &Data, + filters: &meilisearch_auth::AuthFilter, + chat_completion: &mut CreateChatCompletionRequest, + prompts: &ChatPrompts, +) -> Result<(), ResponseError> { let tools = chat_completion.tools.get_or_insert_default(); if tools.iter().find(|t| t.function.name == SEARCH_IN_INDEX_FUNCTION_NAME).is_some() { panic!("{SEARCH_IN_INDEX_FUNCTION_NAME} function already set"); } + let index_uids: Vec<_> = index_scheduler + .index_names()? + .into_iter() + .filter(|index_uid| filters.is_index_authorized(&index_uid)) + .collect(); + let tool = ChatCompletionToolArgs::default() .r#type(ChatCompletionToolType::Function) .function( @@ -93,7 +104,7 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: "properties": { "index_uid": { "type": "string", - "enum": ["main"], + "enum": index_uids, "description": prompts.search_index_uid_param, }, "q": { @@ -120,6 +131,8 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: name: None, }), ); + + Ok(()) } /// Process search request and return formatted results @@ -199,6 +212,8 @@ async fn non_streamed_chat( search_queue: web::Data, mut chat_completion: CreateChatCompletionRequest, ) -> Result { + let filters = index_scheduler.filters(); + let chat_settings = match index_scheduler.chat_settings().unwrap() { Some(value) => serde_json::from_value(value).unwrap(), None => ChatSettings::default(), @@ -214,7 +229,7 @@ async fn non_streamed_chat( let client = Client::with_config(config); let auth_token = extract_token_from_request(&req)?.unwrap(); - setup_search_tool(&mut chat_completion, &chat_settings.prompts); + setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; let mut response; loop { @@ -279,6 +294,8 @@ async fn streamed_chat( search_queue: web::Data, mut chat_completion: CreateChatCompletionRequest, ) -> Result { + let filters = index_scheduler.filters(); + let chat_settings = match index_scheduler.chat_settings().unwrap() { Some(value) => serde_json::from_value(value).unwrap(), None => ChatSettings::default(), @@ -293,7 +310,7 @@ async fn streamed_chat( } let auth_token = extract_token_from_request(&req)?.unwrap().to_string(); - setup_search_tool(&mut chat_completion, &chat_settings.prompts); + setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?; let (tx, rx) = tokio::sync::mpsc::channel(10); let _join_handle = Handle::current().spawn(async move {