Stop the stream when the connexion stops and chnage the events

This commit is contained in:
Clément Renault 2025-05-20 12:05:51 +02:00
parent 4f919db344
commit 42c95cf3c4
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -27,7 +27,7 @@ use meilisearch_types::{Document, Index};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tracing::error; use tokio::sync::mpsc::error::SendError;
use super::settings::chat::{ChatPrompts, ChatSettings}; use super::settings::chat::{ChatPrompts, ChatSettings};
use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::policies::ActionPolicy;
@ -289,7 +289,9 @@ async fn streamed_chat(
} = &choice.delta; } = &choice.delta;
if content.is_some() { if content.is_some() {
tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await.unwrap() if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await {
return;
}
} }
match tool_calls { match tool_calls {
@ -305,9 +307,7 @@ async fn streamed_chat(
function.as_ref().unwrap(); function.as_ref().unwrap();
global_tool_calls global_tool_calls
.entry(*index) .entry(*index)
.and_modify(|call| { .and_modify(|call| call.append(arguments.as_ref().unwrap()))
call.append(arguments.as_ref().unwrap());
})
.or_insert_with(|| Call { .or_insert_with(|| Call {
id: id.as_ref().unwrap().clone(), id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(), function_name: name.as_ref().unwrap().clone(),
@ -316,8 +316,6 @@ async fn streamed_chat(
} }
} }
None if !global_tool_calls.is_empty() => { None if !global_tool_calls.is_empty() => {
// dbg!(&global_tool_calls);
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) = let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls) mem::take(&mut global_tool_calls)
.into_values() .into_values()
@ -340,15 +338,16 @@ async fn streamed_chat(
); );
for call in meili_calls { for call in meili_calls {
tx.send(Event::Data( if let Err(SendError(_)) = tx.send(Event::Data(
sse::Data::new_json(json!({ sse::Data::new_json(json!({
"object": "chat.completion.tool.call", "object": "chat.completion.tool.call",
"tool": call, "tool": call,
})) }))
.unwrap(), .unwrap(),
)) ))
.await .await {
.unwrap(); return;
}
let SearchInIndexParameters { index_uid, q } = let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap(); serde_json::from_str(&call.function.arguments).unwrap();
@ -361,41 +360,40 @@ async fn streamed_chat(
) )
.await; .await;
let is_error = result.is_err();
let text = match result { let text = match result {
Ok((_, text)) => text, Ok((_, text)) => text,
Err(err) => { Err(err) => err.to_string(),
error!("Error processing search request: {err:?}");
continue;
}
}; };
let tool = ChatCompletionRequestMessage::Tool( let tool = ChatCompletionRequestToolMessage {
ChatCompletionRequestToolMessage { tool_call_id: call.id.clone(),
tool_call_id: call.id.clone(), content: ChatCompletionRequestToolMessageContent::Text(
content: ChatCompletionRequestToolMessageContent::Text( format!("{}\n\n{text}", chat_settings.prompts.pre_query),
format!("{}\n\n{text}", chat_settings.prompts.pre_query), ),
), };
},
);
tx.send(Event::Data( if let Err(SendError(_)) = tx.send(Event::Data(
sse::Data::new_json(json!({ sse::Data::new_json(json!({
"object": "chat.completion.tool.output", "object": if is_error {
"tool": ChatCompletionRequestMessage::Tool( "chat.completion.tool.error"
ChatCompletionRequestToolMessage { } else {
tool_call_id: call.id, "chat.completion.tool.output"
content: ChatCompletionRequestToolMessageContent::Text( },
text, "tool": ChatCompletionRequestToolMessage {
), tool_call_id: call.id,
}, content: ChatCompletionRequestToolMessageContent::Text(
), text,
),
},
})) }))
.unwrap(), .unwrap(),
)) ))
.await .await {
.unwrap(); return;
}
chat_completion.messages.push(tool); chat_completion.messages.push(ChatCompletionRequestMessage::Tool(tool));
} }
} }
None => (), None => (),