mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-06 04:05:37 +00:00
Factorize the code a bit more and support reporting errors
This commit is contained in:
parent
fa139ee601
commit
5d8cdb075b
@ -2,13 +2,15 @@ use std::cell::RefCell;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Write as _;
|
use std::fmt::Write as _;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
use std::ops::ControlFlow;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use actix_web::web::{self, Data};
|
use actix_web::web::{self, Data};
|
||||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||||
use actix_web_lab::sse::{self, Event, Sse};
|
use actix_web_lab::sse::{self, Event, Sse};
|
||||||
use async_openai::config::OpenAIConfig;
|
use async_openai::config::{Config, OpenAIConfig};
|
||||||
|
use async_openai::error::OpenAIError;
|
||||||
use async_openai::types::{
|
use async_openai::types::{
|
||||||
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
|
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
|
||||||
@ -40,6 +42,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::runtime::Handle;
|
use tokio::runtime::Handle;
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
|
use tokio::sync::mpsc::Sender;
|
||||||
|
|
||||||
use super::settings::chat::{ChatPrompts, GlobalChatSettings};
|
use super::settings::chat::{ChatPrompts, GlobalChatSettings};
|
||||||
use crate::error::MeilisearchHttpError;
|
use crate::error::MeilisearchHttpError;
|
||||||
@ -83,11 +86,11 @@ async fn chat(
|
|||||||
|
|
||||||
if chat_completion.stream.unwrap_or(false) {
|
if chat_completion.stream.unwrap_or(false) {
|
||||||
Either::Right(
|
Either::Right(
|
||||||
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
Either::Left(
|
Either::Left(
|
||||||
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
non_streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -327,8 +330,8 @@ async fn process_search_request(
|
|||||||
async fn non_streamed_chat(
|
async fn non_streamed_chat(
|
||||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||||
auth_ctrl: web::Data<AuthController>,
|
auth_ctrl: web::Data<AuthController>,
|
||||||
req: HttpRequest,
|
|
||||||
search_queue: web::Data<SearchQueue>,
|
search_queue: web::Data<SearchQueue>,
|
||||||
|
req: HttpRequest,
|
||||||
mut chat_completion: CreateChatCompletionRequest,
|
mut chat_completion: CreateChatCompletionRequest,
|
||||||
) -> Result<HttpResponse, ResponseError> {
|
) -> Result<HttpResponse, ResponseError> {
|
||||||
let filters = index_scheduler.filters();
|
let filters = index_scheduler.filters();
|
||||||
@ -420,8 +423,8 @@ async fn non_streamed_chat(
|
|||||||
async fn streamed_chat(
|
async fn streamed_chat(
|
||||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||||
auth_ctrl: web::Data<AuthController>,
|
auth_ctrl: web::Data<AuthController>,
|
||||||
req: HttpRequest,
|
|
||||||
search_queue: web::Data<SearchQueue>,
|
search_queue: web::Data<SearchQueue>,
|
||||||
|
req: HttpRequest,
|
||||||
mut chat_completion: CreateChatCompletionRequest,
|
mut chat_completion: CreateChatCompletionRequest,
|
||||||
) -> Result<impl Responder, ResponseError> {
|
) -> Result<impl Responder, ResponseError> {
|
||||||
let filters = index_scheduler.filters();
|
let filters = index_scheduler.filters();
|
||||||
@ -441,354 +444,285 @@ async fn streamed_chat(
|
|||||||
|
|
||||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||||
let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } =
|
let function_support =
|
||||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||||
|
let tx = SseEventSender(tx);
|
||||||
let _join_handle = Handle::current().spawn(async move {
|
let _join_handle = Handle::current().spawn(async move {
|
||||||
let client = Client::with_config(config.clone());
|
let client = Client::with_config(config.clone());
|
||||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||||
let mut finish_reason = None;
|
|
||||||
|
|
||||||
// Limit the number of internal calls to satisfy the search requests of the LLM
|
// Limit the number of internal calls to satisfy the search requests of the LLM
|
||||||
'main: for _ in 0..20 {
|
for _ in 0..20 {
|
||||||
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
let output = run_conversation(
|
||||||
while let Some(result) = response.next().await {
|
&index_scheduler,
|
||||||
match result {
|
&auth_ctrl,
|
||||||
Ok(resp) => {
|
&search_queue,
|
||||||
let choice = &resp.choices[0];
|
&auth_token,
|
||||||
finish_reason = choice.finish_reason;
|
&client,
|
||||||
|
&chat_settings,
|
||||||
|
&mut chat_completion,
|
||||||
|
&tx,
|
||||||
|
&mut global_tool_calls,
|
||||||
|
function_support,
|
||||||
|
);
|
||||||
|
|
||||||
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } =
|
match output.await {
|
||||||
&choice.delta;
|
Ok(ControlFlow::Continue(())) => (),
|
||||||
|
Ok(ControlFlow::Break(_finish_reason)) => break,
|
||||||
match tool_calls {
|
// If the connection is closed we must stop
|
||||||
Some(tool_calls) => {
|
Err(SendError(_)) => return,
|
||||||
for chunk in tool_calls {
|
|
||||||
let ChatCompletionMessageToolCallChunk {
|
|
||||||
index,
|
|
||||||
id,
|
|
||||||
r#type: _,
|
|
||||||
function,
|
|
||||||
} = chunk;
|
|
||||||
let FunctionCallStream { name, arguments } =
|
|
||||||
function.as_ref().unwrap();
|
|
||||||
|
|
||||||
global_tool_calls
|
|
||||||
.entry(*index)
|
|
||||||
.and_modify(|call| {
|
|
||||||
if call.is_internal() {
|
|
||||||
call.append(arguments.as_ref().unwrap())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.or_insert_with(|| {
|
|
||||||
if name.as_ref().map_or(false, |n| {
|
|
||||||
n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
|
||||||
}) {
|
|
||||||
Call::Internal {
|
|
||||||
id: id.as_ref().unwrap().clone(),
|
|
||||||
function_name: name.as_ref().unwrap().clone(),
|
|
||||||
arguments: arguments.as_ref().unwrap().clone(),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Call::External { _id: id.as_ref().unwrap().clone() }
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if global_tool_calls.get(index).map_or(false, Call::is_external)
|
|
||||||
{
|
|
||||||
todo!("Support forwarding external tool calls");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
if !global_tool_calls.is_empty() {
|
|
||||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
|
||||||
mem::take(&mut global_tool_calls)
|
|
||||||
.into_values()
|
|
||||||
.flat_map(|call| match call {
|
|
||||||
Call::Internal {
|
|
||||||
id,
|
|
||||||
function_name: name,
|
|
||||||
arguments,
|
|
||||||
} => Some(ChatCompletionMessageToolCall {
|
|
||||||
id,
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: FunctionCall { name, arguments },
|
|
||||||
}),
|
|
||||||
Call::External { _id: _ } => None,
|
|
||||||
})
|
|
||||||
.partition(|call| {
|
|
||||||
call.function.name
|
|
||||||
== MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
|
||||||
});
|
|
||||||
|
|
||||||
chat_completion.messages.push(
|
|
||||||
ChatCompletionRequestAssistantMessageArgs::default()
|
|
||||||
.tool_calls(meili_calls.clone())
|
|
||||||
.build()
|
|
||||||
.unwrap()
|
|
||||||
.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(
|
|
||||||
other_calls.is_empty(),
|
|
||||||
"We do not support external tool forwarding for now"
|
|
||||||
);
|
|
||||||
|
|
||||||
for call in meili_calls {
|
|
||||||
if report_progress {
|
|
||||||
let call = MeiliSearchProgress {
|
|
||||||
call_id: call.id.to_string(),
|
|
||||||
function_name: call.function.name.clone(),
|
|
||||||
function_arguments: call
|
|
||||||
.function
|
|
||||||
.arguments
|
|
||||||
.clone(),
|
|
||||||
};
|
|
||||||
let resp = call.create_response(resp.clone());
|
|
||||||
// Send the event of "we are doing a search"
|
|
||||||
if let Err(SendError(_)) = tx
|
|
||||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if append_to_conversation {
|
|
||||||
// Ask the front-end user to append this tool *call* to the conversation
|
|
||||||
let call = MeiliAppendConversationMessage(ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: None,
|
|
||||||
refusal: None,
|
|
||||||
name: None,
|
|
||||||
audio: None,
|
|
||||||
tool_calls: Some(vec![
|
|
||||||
ChatCompletionMessageToolCall {
|
|
||||||
id: call.id.clone(),
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: FunctionCall {
|
|
||||||
name: call.function.name.clone(),
|
|
||||||
arguments: call.function.arguments.clone(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]),
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
));
|
|
||||||
let resp = call.create_response(resp.clone());
|
|
||||||
if let Err(SendError(_)) = tx
|
|
||||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let result =
|
|
||||||
match serde_json::from_str(&call.function.arguments) {
|
|
||||||
Ok(SearchInIndexParameters { index_uid, q }) => {
|
|
||||||
process_search_request(
|
|
||||||
&index_scheduler,
|
|
||||||
auth_ctrl.clone(),
|
|
||||||
&search_queue,
|
|
||||||
&auth_token,
|
|
||||||
index_uid,
|
|
||||||
q,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(|e| e.to_string())
|
|
||||||
}
|
|
||||||
Err(err) => Err(err.to_string()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let text = match result {
|
|
||||||
Ok((_index, documents, text)) => {
|
|
||||||
if report_sources {
|
|
||||||
let call = MeiliSearchSources {
|
|
||||||
call_id: call.id.to_string(),
|
|
||||||
sources: documents,
|
|
||||||
};
|
|
||||||
let resp = call.create_response(resp.clone());
|
|
||||||
// Send the event of "we are doing a search"
|
|
||||||
if let Err(SendError(_)) = tx
|
|
||||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
text
|
|
||||||
},
|
|
||||||
Err(err) => err,
|
|
||||||
};
|
|
||||||
|
|
||||||
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
|
||||||
tool_call_id: call.id.clone(),
|
|
||||||
content: ChatCompletionRequestToolMessageContent::Text(
|
|
||||||
format!(
|
|
||||||
"{}\n\n{text}",
|
|
||||||
chat_settings
|
|
||||||
.prompts
|
|
||||||
.as_ref()
|
|
||||||
.unwrap()
|
|
||||||
.pre_query
|
|
||||||
.as_ref()
|
|
||||||
.unwrap()
|
|
||||||
),
|
|
||||||
),
|
|
||||||
});
|
|
||||||
|
|
||||||
if append_to_conversation {
|
|
||||||
// Ask the front-end user to append this tool *output* to the conversation
|
|
||||||
let tool = MeiliAppendConversationMessage(tool.clone());
|
|
||||||
let resp = tool.create_response(resp.clone());
|
|
||||||
if let Err(SendError(_)) = tx
|
|
||||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
chat_completion.messages.push(tool);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if let Err(SendError(_)) = tx
|
|
||||||
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
// tracing::error!("{err:?}");
|
|
||||||
// if let Err(SendError(_)) = tx
|
|
||||||
// .send(Event::Data(
|
|
||||||
// sse::Data::new_json(&json!({
|
|
||||||
// "object": "chat.completion.error",
|
|
||||||
// "tool": err.to_string(),
|
|
||||||
// }))
|
|
||||||
// .unwrap(),
|
|
||||||
// ))
|
|
||||||
// .await
|
|
||||||
// {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
break 'main;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We must stop if the finish reason is not something we can solve with Meilisearch
|
|
||||||
if finish_reason.map_or(true, |fr| fr != FinishReason::ToolCalls) {
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
|
let _ = tx.stop().await;
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
/// Updates the chat completion with the new messages, streams the LLM tokens,
|
||||||
/// Provides information about the current Meilisearch search operation.
|
/// and report progress and errors.
|
||||||
struct MeiliSearchProgress {
|
async fn run_conversation<C: Config>(
|
||||||
/// The call ID to track the sources of the search.
|
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||||
pub call_id: String,
|
auth_ctrl: &web::Data<AuthController>,
|
||||||
/// The name of the function we are executing.
|
search_queue: &web::Data<SearchQueue>,
|
||||||
pub function_name: String,
|
auth_token: &str,
|
||||||
/// The arguments of the function we are executing, encoded in JSON.
|
client: &Client<C>,
|
||||||
pub function_arguments: String,
|
chat_settings: &GlobalChatSettings,
|
||||||
}
|
chat_completion: &mut CreateChatCompletionRequest,
|
||||||
|
tx: &SseEventSender,
|
||||||
|
global_tool_calls: &mut HashMap<u32, Call>,
|
||||||
|
function_support: FunctionSupport,
|
||||||
|
) -> Result<ControlFlow<Option<FinishReason>, ()>, SendError<Event>> {
|
||||||
|
let mut finish_reason = None;
|
||||||
|
|
||||||
impl MeiliSearchProgress {
|
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||||
fn create_response(
|
while let Some(result) = response.next().await {
|
||||||
&self,
|
match result {
|
||||||
mut resp: CreateChatCompletionStreamResponse,
|
Ok(resp) => {
|
||||||
) -> CreateChatCompletionStreamResponse {
|
let choice = &resp.choices[0];
|
||||||
let call_text = serde_json::to_string(self).unwrap();
|
finish_reason = choice.finish_reason;
|
||||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
|
||||||
index: 0,
|
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta;
|
||||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
match tool_calls {
|
||||||
function: Some(FunctionCallStream {
|
Some(tool_calls) => {
|
||||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
for chunk in tool_calls {
|
||||||
arguments: Some(call_text),
|
let ChatCompletionMessageToolCallChunk {
|
||||||
}),
|
index,
|
||||||
};
|
id,
|
||||||
resp.choices[0] = ChatChoiceStream {
|
r#type: _,
|
||||||
index: 0,
|
function,
|
||||||
delta: ChatCompletionStreamResponseDelta {
|
} = chunk;
|
||||||
content: None,
|
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
|
||||||
function_call: None,
|
|
||||||
tool_calls: Some(vec![tool_call]),
|
global_tool_calls
|
||||||
role: Some(Role::Assistant),
|
.entry(*index)
|
||||||
refusal: None,
|
.and_modify(|call| {
|
||||||
},
|
if call.is_internal() {
|
||||||
finish_reason: None,
|
call.append(arguments.as_ref().unwrap())
|
||||||
logprobs: None,
|
}
|
||||||
};
|
})
|
||||||
resp
|
.or_insert_with(|| {
|
||||||
|
if name
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |n| n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||||
|
{
|
||||||
|
Call::Internal {
|
||||||
|
id: id.as_ref().unwrap().clone(),
|
||||||
|
function_name: name.as_ref().unwrap().clone(),
|
||||||
|
arguments: arguments.as_ref().unwrap().clone(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Call::External
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if global_tool_calls.get(index).map_or(false, Call::is_external) {
|
||||||
|
todo!("Support forwarding external tool calls");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
if !global_tool_calls.is_empty() {
|
||||||
|
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
||||||
|
mem::take(global_tool_calls)
|
||||||
|
.into_values()
|
||||||
|
.flat_map(|call| match call {
|
||||||
|
Call::Internal { id, function_name: name, arguments } => {
|
||||||
|
Some(ChatCompletionMessageToolCall {
|
||||||
|
id,
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: FunctionCall { name, arguments },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Call::External => None,
|
||||||
|
})
|
||||||
|
.partition(|call| {
|
||||||
|
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||||
|
});
|
||||||
|
|
||||||
|
chat_completion.messages.push(
|
||||||
|
ChatCompletionRequestAssistantMessageArgs::default()
|
||||||
|
.tool_calls(meili_calls.clone())
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
other_calls.is_empty(),
|
||||||
|
"We do not support external tool forwarding for now"
|
||||||
|
);
|
||||||
|
|
||||||
|
handle_meili_tools(
|
||||||
|
&index_scheduler,
|
||||||
|
&auth_ctrl,
|
||||||
|
&search_queue,
|
||||||
|
&auth_token,
|
||||||
|
chat_settings,
|
||||||
|
tx,
|
||||||
|
meili_calls,
|
||||||
|
chat_completion,
|
||||||
|
&resp,
|
||||||
|
function_support,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
} else {
|
||||||
|
tx.forward_response(&resp).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
if function_support.report_errors {
|
||||||
|
tx.report_error(err).await?;
|
||||||
|
}
|
||||||
|
return Ok(ControlFlow::Break(None));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We must stop if the finish reason is not something we can solve with Meilisearch
|
||||||
|
match finish_reason {
|
||||||
|
Some(FinishReason::ToolCalls) => Ok(ControlFlow::Continue(())),
|
||||||
|
otherwise => Ok(ControlFlow::Break(otherwise)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
async fn handle_meili_tools(
|
||||||
/// Provides sources of the search.
|
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||||
struct MeiliSearchSources {
|
auth_ctrl: &web::Data<AuthController>,
|
||||||
/// The call ID to track the original search associated to those sources.
|
search_queue: &web::Data<SearchQueue>,
|
||||||
pub call_id: String,
|
auth_token: &str,
|
||||||
/// The documents associated with the search (call_id).
|
chat_settings: &GlobalChatSettings,
|
||||||
/// Only the displayed attributes of the documents are returned.
|
tx: &SseEventSender,
|
||||||
pub sources: Vec<Document>,
|
meili_calls: Vec<ChatCompletionMessageToolCall>,
|
||||||
}
|
chat_completion: &mut CreateChatCompletionRequest,
|
||||||
|
resp: &CreateChatCompletionStreamResponse,
|
||||||
|
FunctionSupport { report_progress, report_sources, append_to_conversation, .. }: FunctionSupport,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
for call in meili_calls {
|
||||||
|
if report_progress {
|
||||||
|
tx.report_search_progress(
|
||||||
|
resp.clone(),
|
||||||
|
&call.id,
|
||||||
|
&call.function.name,
|
||||||
|
&call.function.arguments,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
impl MeiliSearchSources {
|
if append_to_conversation {
|
||||||
fn create_response(
|
tx.append_tool_call_conversation_message(
|
||||||
&self,
|
resp.clone(),
|
||||||
mut resp: CreateChatCompletionStreamResponse,
|
call.id.clone(),
|
||||||
) -> CreateChatCompletionStreamResponse {
|
call.function.name.clone(),
|
||||||
let call_text = serde_json::to_string(self).unwrap();
|
call.function.arguments.clone(),
|
||||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
)
|
||||||
index: 0,
|
.await?;
|
||||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
}
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: Some(FunctionCallStream {
|
let result = match serde_json::from_str(&call.function.arguments) {
|
||||||
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
Ok(SearchInIndexParameters { index_uid, q }) => process_search_request(
|
||||||
arguments: Some(call_text),
|
&index_scheduler,
|
||||||
}),
|
auth_ctrl.clone(),
|
||||||
|
&search_queue,
|
||||||
|
&auth_token,
|
||||||
|
index_uid,
|
||||||
|
q,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string()),
|
||||||
|
Err(err) => Err(err.to_string()),
|
||||||
};
|
};
|
||||||
resp.choices[0] = ChatChoiceStream {
|
|
||||||
index: 0,
|
let text = match result {
|
||||||
delta: ChatCompletionStreamResponseDelta {
|
Ok((_index, documents, text)) => {
|
||||||
content: None,
|
if report_sources {
|
||||||
function_call: None,
|
tx.report_sources(resp.clone(), &call.id, &documents).await?;
|
||||||
tool_calls: Some(vec![tool_call]),
|
}
|
||||||
role: Some(Role::Assistant),
|
|
||||||
refusal: None,
|
text
|
||||||
},
|
}
|
||||||
finish_reason: None,
|
Err(err) => err,
|
||||||
logprobs: None,
|
|
||||||
};
|
};
|
||||||
resp
|
|
||||||
|
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
||||||
|
tool_call_id: call.id.clone(),
|
||||||
|
content: ChatCompletionRequestToolMessageContent::Text(format!(
|
||||||
|
"{}\n\n{text}",
|
||||||
|
chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
|
||||||
|
if append_to_conversation {
|
||||||
|
tx.append_conversation_message(resp.clone(), &tool).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
chat_completion.messages.push(tool);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
|
pub struct SseEventSender(Sender<Event>);
|
||||||
|
|
||||||
impl MeiliAppendConversationMessage {
|
impl SseEventSender {
|
||||||
fn create_response(
|
/// Ask the front-end user to append this tool *call* to the conversation
|
||||||
|
pub async fn append_tool_call_conversation_message(
|
||||||
|
&self,
|
||||||
|
resp: CreateChatCompletionStreamResponse,
|
||||||
|
call_id: String,
|
||||||
|
function_name: String,
|
||||||
|
function_arguments: String,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
let message =
|
||||||
|
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
||||||
|
content: None,
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
audio: None,
|
||||||
|
tool_calls: Some(vec![ChatCompletionMessageToolCall {
|
||||||
|
id: call_id,
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: FunctionCall { name: function_name, arguments: function_arguments },
|
||||||
|
}]),
|
||||||
|
function_call: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
self.append_conversation_message(resp, &message).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ask the front-end user to append this tool to the conversation
|
||||||
|
pub async fn append_conversation_message(
|
||||||
&self,
|
&self,
|
||||||
mut resp: CreateChatCompletionStreamResponse,
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
) -> CreateChatCompletionStreamResponse {
|
message: &ChatCompletionRequestMessage,
|
||||||
let call_text = serde_json::to_string(&self.0).unwrap();
|
) -> Result<(), SendError<Event>> {
|
||||||
|
let call_text = serde_json::to_string(message).unwrap();
|
||||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
index: 0,
|
index: 0,
|
||||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
@ -798,6 +732,7 @@ impl MeiliAppendConversationMessage {
|
|||||||
arguments: Some(call_text),
|
arguments: Some(call_text),
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
resp.choices[0] = ChatChoiceStream {
|
resp.choices[0] = ChatChoiceStream {
|
||||||
index: 0,
|
index: 0,
|
||||||
delta: ChatCompletionStreamResponseDelta {
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
@ -810,7 +745,132 @@ impl MeiliAppendConversationMessage {
|
|||||||
finish_reason: None,
|
finish_reason: None,
|
||||||
logprobs: None,
|
logprobs: None,
|
||||||
};
|
};
|
||||||
resp
|
|
||||||
|
self.send_json(&resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn report_search_progress(
|
||||||
|
&self,
|
||||||
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
|
call_id: &str,
|
||||||
|
function_name: &str,
|
||||||
|
function_arguments: &str,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
/// Provides information about the current Meilisearch search operation.
|
||||||
|
struct MeiliSearchProgress<'a> {
|
||||||
|
/// The call ID to track the sources of the search.
|
||||||
|
call_id: &'a str,
|
||||||
|
/// The name of the function we are executing.
|
||||||
|
function_name: &'a str,
|
||||||
|
/// The arguments of the function we are executing, encoded in JSON.
|
||||||
|
function_arguments: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
||||||
|
let call_text = serde_json::to_string(&progress).unwrap();
|
||||||
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
|
index: 0,
|
||||||
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: Some(FunctionCallStream {
|
||||||
|
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||||
|
arguments: Some(call_text),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
resp.choices[0] = ChatChoiceStream {
|
||||||
|
index: 0,
|
||||||
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
|
content: None,
|
||||||
|
function_call: None,
|
||||||
|
tool_calls: Some(vec![tool_call]),
|
||||||
|
role: Some(Role::Assistant),
|
||||||
|
refusal: None,
|
||||||
|
},
|
||||||
|
finish_reason: None,
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.send_json(&resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn report_sources(
|
||||||
|
&self,
|
||||||
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
|
call_id: &str,
|
||||||
|
documents: &[Document],
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
/// Provides sources of the search.
|
||||||
|
struct MeiliSearchSources<'a> {
|
||||||
|
/// The call ID to track the original search associated to those sources.
|
||||||
|
call_id: &'a str,
|
||||||
|
/// The documents associated with the search (call_id).
|
||||||
|
/// Only the displayed attributes of the documents are returned.
|
||||||
|
sources: &'a [Document],
|
||||||
|
}
|
||||||
|
|
||||||
|
let sources = MeiliSearchSources { call_id, sources: documents };
|
||||||
|
let call_text = serde_json::to_string(&sources).unwrap();
|
||||||
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
|
index: 0,
|
||||||
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: Some(FunctionCallStream {
|
||||||
|
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
||||||
|
arguments: Some(call_text),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
resp.choices[0] = ChatChoiceStream {
|
||||||
|
index: 0,
|
||||||
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
|
content: None,
|
||||||
|
function_call: None,
|
||||||
|
tool_calls: Some(vec![tool_call]),
|
||||||
|
role: Some(Role::Assistant),
|
||||||
|
refusal: None,
|
||||||
|
},
|
||||||
|
finish_reason: None,
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.send_json(&resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn report_error(&self, error: OpenAIError) -> Result<(), SendError<Event>> {
|
||||||
|
tracing::error!("OpenAI Error: {}", error);
|
||||||
|
|
||||||
|
let (error_code, message) = match error {
|
||||||
|
OpenAIError::Reqwest(e) => ("internal_reqwest_error", e.to_string()),
|
||||||
|
OpenAIError::ApiError(api_error) => ("llm_api_issue", api_error.to_string()),
|
||||||
|
OpenAIError::JSONDeserialize(error) => ("internal_json_deserialize", error.to_string()),
|
||||||
|
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
|
||||||
|
OpenAIError::StreamError(error) => ("llm_api_stream_error", error.to_string()),
|
||||||
|
OpenAIError::InvalidArgument(error) => ("internal_invalid_argument", error.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.send_json(&json!({
|
||||||
|
"error_code": error_code,
|
||||||
|
"message": message,
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn forward_response(
|
||||||
|
&self,
|
||||||
|
resp: &CreateChatCompletionStreamResponse,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
self.send_json(resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop(self) -> Result<(), SendError<Event>> {
|
||||||
|
self.0.send(Event::Data(sse::Data::new("[DONE]"))).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
|
||||||
|
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -822,7 +882,7 @@ enum Call {
|
|||||||
Internal { id: String, function_name: String, arguments: String },
|
Internal { id: String, function_name: String, arguments: String },
|
||||||
/// Tool calls that we track but only to know that its not our functions.
|
/// Tool calls that we track but only to know that its not our functions.
|
||||||
/// We return the function calls as-is to the end-user.
|
/// We return the function calls as-is to the end-user.
|
||||||
External { _id: String },
|
External,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Call {
|
impl Call {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user