diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 6c6c97761..733b8ff65 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -27,7 +27,7 @@ use meilisearch_types::{Document, Index}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::runtime::Handle; -use tracing::error; +use tokio::sync::mpsc::error::SendError; use super::settings::chat::{ChatPrompts, ChatSettings}; use crate::extractors::authentication::policies::ActionPolicy; @@ -289,7 +289,9 @@ async fn streamed_chat( } = &choice.delta; if content.is_some() { - tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await.unwrap() + if let Err(SendError(_)) = tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await { + return; + } } match tool_calls { @@ -305,9 +307,7 @@ async fn streamed_chat( function.as_ref().unwrap(); global_tool_calls .entry(*index) - .and_modify(|call| { - call.append(arguments.as_ref().unwrap()); - }) + .and_modify(|call| call.append(arguments.as_ref().unwrap())) .or_insert_with(|| Call { id: id.as_ref().unwrap().clone(), function_name: name.as_ref().unwrap().clone(), @@ -316,8 +316,6 @@ async fn streamed_chat( } } None if !global_tool_calls.is_empty() => { - // dbg!(&global_tool_calls); - let (meili_calls, _other_calls): (Vec<_>, Vec<_>) = mem::take(&mut global_tool_calls) .into_values() @@ -340,15 +338,16 @@ async fn streamed_chat( ); for call in meili_calls { - tx.send(Event::Data( + if let Err(SendError(_)) = tx.send(Event::Data( sse::Data::new_json(json!({ "object": "chat.completion.tool.call", "tool": call, })) .unwrap(), )) - .await - .unwrap(); + .await { + return; + } let SearchInIndexParameters { index_uid, q } = serde_json::from_str(&call.function.arguments).unwrap(); @@ -361,41 +360,40 @@ async fn streamed_chat( ) .await; + let is_error = result.is_err(); let text = match result { Ok((_, text)) => text, - Err(err) => { - error!("Error processing search request: {err:?}"); - continue; - } + Err(err) => err.to_string(), }; - let tool = ChatCompletionRequestMessage::Tool( - ChatCompletionRequestToolMessage { - tool_call_id: call.id.clone(), - content: ChatCompletionRequestToolMessageContent::Text( - format!("{}\n\n{text}", chat_settings.prompts.pre_query), - ), - }, - ); + let tool = ChatCompletionRequestToolMessage { + tool_call_id: call.id.clone(), + content: ChatCompletionRequestToolMessageContent::Text( + format!("{}\n\n{text}", chat_settings.prompts.pre_query), + ), + }; - tx.send(Event::Data( + if let Err(SendError(_)) = tx.send(Event::Data( sse::Data::new_json(json!({ - "object": "chat.completion.tool.output", - "tool": ChatCompletionRequestMessage::Tool( - ChatCompletionRequestToolMessage { - tool_call_id: call.id, - content: ChatCompletionRequestToolMessageContent::Text( - text, - ), - }, - ), + "object": if is_error { + "chat.completion.tool.error" + } else { + "chat.completion.tool.output" + }, + "tool": ChatCompletionRequestToolMessage { + tool_call_id: call.id, + content: ChatCompletionRequestToolMessageContent::Text( + text, + ), + }, })) .unwrap(), )) - .await - .unwrap(); + .await { + return; + } - chat_completion.messages.push(tool); + chat_completion.messages.push(ChatCompletionRequestMessage::Tool(tool)); } } None => (),