mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-31 07:56:28 +00:00 
			
		
		
		
	Add near-complete support for tools on the streaming route
This commit is contained in:
		| @@ -6,14 +6,16 @@ use actix_web::{Either, HttpResponse, Responder}; | ||||
| use actix_web_lab::sse::{self, Event}; | ||||
| use async_openai::config::OpenAIConfig; | ||||
| use async_openai::types::{ | ||||
|     ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs, | ||||
|     ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, | ||||
|     ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta, | ||||
|     ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason, | ||||
|     FunctionCallStream, FunctionObjectArgs, | ||||
|     ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, | ||||
|     ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, | ||||
|     ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, | ||||
|     ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType, | ||||
|     CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream, | ||||
|     FunctionObjectArgs, | ||||
| }; | ||||
| use async_openai::Client; | ||||
| use futures::StreamExt; | ||||
| use futures_util::stream; | ||||
| use index_scheduler::IndexScheduler; | ||||
| use meilisearch_types::error::ResponseError; | ||||
| use meilisearch_types::keys::actions; | ||||
| @@ -23,6 +25,7 @@ use meilisearch_types::milli::vector::EmbeddingConfig; | ||||
| use meilisearch_types::{Document, Index}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use serde_json::json; | ||||
| use tokio::runtime::Handle; | ||||
|  | ||||
| use crate::extractors::authentication::policies::ActionPolicy; | ||||
| use crate::extractors::authentication::GuardedData; | ||||
| @@ -297,26 +300,25 @@ async fn streamed_chat( | ||||
|  | ||||
|     let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base | ||||
|     let client = Client::with_config(config); | ||||
|     let response = client.chat().create_stream(chat_completion).await.unwrap(); | ||||
|     let response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); | ||||
|     let mut global_tool_calls = HashMap::<u32, Call>::new(); | ||||
|     actix_web_lab::sse::Sse::from_stream(response.map(move |response| { | ||||
|         response.map(|mut r| { | ||||
|             let delta = &r.choices[0].delta; | ||||
|     actix_web_lab::sse::Sse::from_stream(response.flat_map(move |response| match response { | ||||
|         Ok(resp) => { | ||||
|             let delta = &resp.choices[0].delta; | ||||
|             let ChatCompletionStreamResponseDelta { | ||||
|                 ref content, | ||||
|                 ref function_call, | ||||
|                 content: _, | ||||
|                 function_call: _, | ||||
|                 ref tool_calls, | ||||
|                 ref role, | ||||
|                 ref refusal, | ||||
|                 role: _, | ||||
|                 refusal: _, | ||||
|             } = delta; | ||||
|  | ||||
|             match tool_calls { | ||||
|                 Some(tool_calls) => { | ||||
|                     for chunk in tool_calls { | ||||
|                         let ChatCompletionMessageToolCallChunk { index, id, r#type, function } = | ||||
|                         let ChatCompletionMessageToolCallChunk { index, id, r#type: _, function } = | ||||
|                             chunk; | ||||
|                         let FunctionCallStream { ref name, ref arguments } = | ||||
|                             function.as_ref().unwrap(); | ||||
|                         let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); | ||||
|                         global_tool_calls | ||||
|                             .entry(*index) | ||||
|                             .or_insert_with(|| Call { | ||||
| @@ -326,15 +328,120 @@ async fn streamed_chat( | ||||
|                             }) | ||||
|                             .append(arguments.as_ref().unwrap()); | ||||
|                     } | ||||
|                     stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]) | ||||
|                 } | ||||
|                 None if !global_tool_calls.is_empty() => { | ||||
|                     dbg!(&global_tool_calls); | ||||
|                 } | ||||
|                 None => (), | ||||
|             } | ||||
|  | ||||
|             Event::Data(sse::Data::new_json(r).unwrap()) | ||||
|         }) | ||||
|                     let (meili_calls, other_calls): (Vec<_>, Vec<_>) = | ||||
|                         mem::take(&mut global_tool_calls) | ||||
|                             .into_iter() | ||||
|                             .map(|(_, call)| ChatCompletionMessageToolCall { | ||||
|                                 id: call.id, | ||||
|                                 r#type: ChatCompletionToolType::Function, | ||||
|                                 function: FunctionCall { | ||||
|                                     name: call.function_name, | ||||
|                                     arguments: call.arguments, | ||||
|                                 }, | ||||
|                             }) | ||||
|                             .partition(|call| call.function.name == "searchInIndex"); | ||||
|  | ||||
|                     chat_completion.messages.push( | ||||
|                         ChatCompletionRequestAssistantMessageArgs::default() | ||||
|                             .tool_calls(meili_calls.clone()) | ||||
|                             .build() | ||||
|                             .unwrap() | ||||
|                             .into(), | ||||
|                     ); | ||||
|  | ||||
|                     for call in meili_calls { | ||||
|                         let SearchInIndexParameters { index_uid, q } = | ||||
|                             serde_json::from_str(&call.function.arguments).unwrap(); | ||||
|  | ||||
|                         let mut query = SearchQuery { | ||||
|                             q, | ||||
|                             hybrid: Some(HybridQuery { | ||||
|                                 semantic_ratio: SemanticRatio::default(), | ||||
|                                 embedder: EMBEDDER_NAME.to_string(), | ||||
|                             }), | ||||
|                             limit: 20, | ||||
|                             ..Default::default() | ||||
|                         }; | ||||
|  | ||||
|                         // Tenant token search_rules. | ||||
|                         if let Some(search_rules) = | ||||
|                             index_scheduler.filters().get_index_search_rules(&index_uid) | ||||
|                         { | ||||
|                             add_search_rules(&mut query.filter, search_rules); | ||||
|                         } | ||||
|  | ||||
|                         // TBD | ||||
|                         // let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query); | ||||
|  | ||||
|                         let index = index_scheduler.index(&index_uid).unwrap(); | ||||
|                         let search_kind = search_kind( | ||||
|                             &query, | ||||
|                             index_scheduler.get_ref(), | ||||
|                             index_uid.to_string(), | ||||
|                             &index, | ||||
|                         ) | ||||
|                         .unwrap(); | ||||
|  | ||||
|                         // let permit = search_queue.try_get_search_permit().await?; | ||||
|                         let features = index_scheduler.features(); | ||||
|                         let index_cloned = index.clone(); | ||||
|                         // let search_result = tokio::task::spawn_blocking(move || { | ||||
|                         let search_result = perform_search( | ||||
|                             index_uid.to_string(), | ||||
|                             &index_cloned, | ||||
|                             query, | ||||
|                             search_kind, | ||||
|                             RetrieveVectors::new(false), | ||||
|                             features, | ||||
|                         ); | ||||
|                         // }) | ||||
|                         // .await; | ||||
|                         // permit.drop().await; | ||||
|  | ||||
|                         // let search_result = search_result.unwrap(); | ||||
|                         if let Ok(ref search_result) = search_result { | ||||
|                             // aggregate.succeed(search_result); | ||||
|                             if search_result.degraded { | ||||
|                                 MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); | ||||
|                             } | ||||
|                         } | ||||
|                         // analytics.publish(aggregate, &req); | ||||
|  | ||||
|                         let search_result = search_result.unwrap(); | ||||
|                         let formatted = format_documents( | ||||
|                             &index, | ||||
|                             search_result.hits.into_iter().map(|doc| doc.document), | ||||
|                         ); | ||||
|                         let text = formatted.join("\n"); | ||||
|                         chat_completion.messages.push(ChatCompletionRequestMessage::Tool( | ||||
|                             ChatCompletionRequestToolMessage { | ||||
|                                 tool_call_id: call.id, | ||||
|                                 content: ChatCompletionRequestToolMessageContent::Text(text), | ||||
|                             }, | ||||
|                         )); | ||||
|                     } | ||||
|  | ||||
|                     let response = Handle::current().block_on(async { | ||||
|                         client.chat().create_stream(chat_completion.clone()).await.unwrap() | ||||
|                     }); | ||||
|  | ||||
|                     // stream::iter(vec![ | ||||
|                     //     Ok(Event::Data(sse::Data::new_json(json!({ "text": "Hello" })).unwrap())), | ||||
|                     //     Ok(Event::Data(sse::Data::new_json(json!({ "text": " world" })).unwrap())), | ||||
|                     //     Ok(Event::Data(sse::Data::new_json(json!({ "text": " !" })).unwrap())), | ||||
|                     // ]) | ||||
|  | ||||
|                     response | ||||
|                 } | ||||
|                 None => stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]), | ||||
|             } | ||||
|         } | ||||
|         Err(err) => stream::iter(vec![Err(err)]), | ||||
|     })) | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user