From 0f05c0eb6fbae67a2555a9f04c90b3e5fe5c2b83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 14 May 2025 11:18:21 +0200 Subject: [PATCH] Implement a first version of a streamed chat API --- Cargo.lock | 158 +++++++++++++++++++++++--- crates/meilisearch/Cargo.toml | 1 + crates/meilisearch/src/routes/chat.rs | 38 ++++++- 3 files changed, 180 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c205bbb3d..6ce6b5186 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,7 +27,7 @@ checksum = "f9e772b3bcafe335042b5db010ab7c09013dad6eac4915c91d8d50902769f331" dependencies = [ "actix-utils", "actix-web", - "derive_more", + "derive_more 0.99.17", "futures-util", "log", "once_cell", @@ -36,24 +36,24 @@ dependencies = [ [[package]] name = "actix-http" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d48f96fc3003717aeb9856ca3d02a8c7de502667ad76eeacd830b48d2e91fac4" +checksum = "44dfe5c9e0004c623edc65391dfd51daa201e7e30ebd9c9bedf873048ec32bc2" dependencies = [ "actix-codec", "actix-rt", "actix-service", "actix-tls", "actix-utils", - "ahash 0.8.11", "base64 0.22.1", "bitflags 2.9.0", - "brotli", + "brotli 8.0.1", "bytes", "bytestring", - "derive_more", + "derive_more 2.0.1", "encoding_rs", "flate2", + "foldhash", "futures-core", "h2 0.3.26", "http 0.2.11", @@ -65,7 +65,7 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "sha1", "smallvec", "tokio", @@ -92,6 +92,7 @@ dependencies = [ "bytestring", "cfg-if", "http 0.2.11", + "regex", "regex-lite", "serde", "tracing", @@ -187,7 +188,7 @@ dependencies = [ "bytestring", "cfg-if", "cookie", - "derive_more", + "derive_more 0.99.17", "encoding_rs", "futures-core", "futures-util", @@ -220,6 +221,43 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "actix-web-lab" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33034dd88446a5deb20e42156dbfe43d07e0499345db3ae65b3f51854190531" +dependencies = [ + "actix-http", + "actix-router", + "actix-service", + "actix-utils", + "actix-web", + "ahash 0.8.11", + "arc-swap", + "bytes", + "bytestring", + "csv", + "derive_more 2.0.1", + "form_urlencoded", + "futures-core", + "futures-util", + "http 0.2.11", + "impl-more", + "itertools 0.14.0", + "local-channel", + "mime", + "pin-project-lite", + "regex", + "serde", + "serde_html_form", + "serde_json", + "serde_path_to_error", + "tokio", + "tokio-stream", + "tracing", + "url", +] + [[package]] name = "addr2line" version = "0.20.0" @@ -391,6 +429,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayvec" version = "0.7.4" @@ -556,7 +600,7 @@ dependencies = [ "milli", "mimalloc", "rand 0.8.5", - "rand_chacha", + "rand_chacha 0.3.1", "reqwest", "roaring", "serde_json", @@ -708,7 +752,18 @@ checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", - "brotli-decompressor", + "brotli-decompressor 4.0.1", +] + +[[package]] +name = "brotli" +version = "8.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor 5.0.0", ] [[package]] @@ -721,6 +776,16 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "brotli-decompressor" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bstr" version = "1.11.3" @@ -1634,6 +1699,27 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_more" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "unicode-xid", +] + [[package]] name = "deserr" version = "0.6.3" @@ -2847,9 +2933,9 @@ dependencies = [ [[package]] name = "impl-more" -version = "0.1.6" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d" +checksum = "e8a5a9a0ff0086c7a148acb942baaabeadf9504d10400b5a05645853729b9cd2" [[package]] name = "include-flate" @@ -3593,10 +3679,11 @@ dependencies = [ "actix-rt", "actix-utils", "actix-web", + "actix-web-lab", "anyhow", "async-openai", "async-trait", - "brotli", + "brotli 6.0.0", "bstr", "build-info", "byte-unit", @@ -4688,7 +4775,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", + "rand_chacha 0.3.1", "rand_core 0.6.4", ] @@ -4698,6 +4785,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ + "rand_chacha 0.9.0", "rand_core 0.9.3", ] @@ -4711,6 +4799,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", +] + [[package]] name = "rand_core" version = "0.6.4" @@ -4725,6 +4823,9 @@ name = "rand_core" version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.1", +] [[package]] name = "rand_distr" @@ -5304,6 +5405,19 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "serde_html_form" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" +dependencies = [ + "form_urlencoded", + "indexmap", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.140" @@ -5317,6 +5431,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -6320,6 +6444,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unicode_categories" version = "0.1.1" diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index a0ce49193..f7469e7ac 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -113,6 +113,7 @@ utoipa = { version = "5.3.1", features = [ ] } utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] } async-openai = "0.28.1" +actix-web-lab = { version = "0.24.1", default-features = false } [dev-dependencies] actix-rt = "2.10.0" diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 8f0552561..ad46d91c8 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -1,7 +1,8 @@ use std::mem; use actix_web::web::{self, Data}; -use actix_web::HttpResponse; +use actix_web::{Either, HttpResponse, Responder}; +use actix_web_lab::sse::{self, Event}; use async_openai::config::OpenAIConfig; use async_openai::types::{ ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, @@ -10,6 +11,7 @@ use async_openai::types::{ FunctionObjectArgs, }; use async_openai::Client; +use futures::StreamExt; use index_scheduler::IndexScheduler; use meilisearch_types::error::ResponseError; use meilisearch_types::keys::actions; @@ -53,10 +55,22 @@ async fn chat( index_scheduler: GuardedData, Data>, search_queue: web::Data, web::Json(mut chat_completion): web::Json, -) -> Result { +) -> impl Responder { // To enable later on, when the feature will be experimental // index_scheduler.features().check_chat("Using the /chat route")?; + if chat_completion.stream.unwrap_or(false) { + Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await) + } else { + Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await) + } +} + +async fn non_streamed_chat( + index_scheduler: GuardedData, Data>, + search_queue: web::Data, + mut chat_completion: CreateChatCompletionRequest, +) -> Result { let api_key = std::env::var("MEILI_OPENAI_API_KEY") .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base @@ -119,7 +133,7 @@ async fn chat( .build() .unwrap(), ); - response = dbg!(client.chat().create(chat_completion.clone()).await.unwrap()); + response = client.chat().create(chat_completion.clone()).await.unwrap(); let choice = &mut response.choices[0]; match choice.finish_reason { @@ -221,6 +235,24 @@ async fn chat( Ok(HttpResponse::Ok().json(response)) } +async fn streamed_chat( + index_scheduler: GuardedData, Data>, + search_queue: web::Data, + mut chat_completion: CreateChatCompletionRequest, +) -> impl Responder { + assert!(chat_completion.stream.unwrap_or(false)); + + let api_key = std::env::var("MEILI_OPENAI_API_KEY") + .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); + let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base + let client = Client::with_config(config); + let response = client.chat().create_stream(chat_completion).await.unwrap(); + actix_web_lab::sse::Sse::from_stream(response.map(|response| { + response + .map(|mut r| Event::Data(sse::Data::new_json(r.choices.pop().unwrap().delta).unwrap())) + })) +} + #[derive(Deserialize)] struct SearchInIndexParameters { /// The index uid to search in.