From ae89825b37bb6fcb823c9aa0dbc186eb1dd0abfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 14 May 2025 11:18:21 +0200 Subject: [PATCH] Implement a first version of a streamed chat API --- Cargo.lock | 281 +++++++++++++++----------- crates/meilisearch/Cargo.toml | 1 + crates/meilisearch/src/routes/chat.rs | 38 +++- 3 files changed, 200 insertions(+), 120 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb23b4f5d..aea49a0a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,7 +27,7 @@ checksum = "f9e772b3bcafe335042b5db010ab7c09013dad6eac4915c91d8d50902769f331" dependencies = [ "actix-utils", "actix-web", - "derive_more", + "derive_more 0.99.17", "futures-util", "log", "once_cell", @@ -36,24 +36,24 @@ dependencies = [ [[package]] name = "actix-http" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d48f96fc3003717aeb9856ca3d02a8c7de502667ad76eeacd830b48d2e91fac4" +checksum = "44dfe5c9e0004c623edc65391dfd51daa201e7e30ebd9c9bedf873048ec32bc2" dependencies = [ "actix-codec", "actix-rt", "actix-service", "actix-tls", "actix-utils", - "ahash 0.8.11", "base64 0.22.1", "bitflags 2.9.0", - "brotli", + "brotli 8.0.1", "bytes", "bytestring", - "derive_more", + "derive_more 2.0.1", "encoding_rs", "flate2", + "foldhash", "futures-core", "h2 0.3.26", "http 0.2.11", @@ -65,7 +65,7 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "sha1", "smallvec", "tokio", @@ -92,6 +92,7 @@ dependencies = [ "bytestring", "cfg-if", "http 0.2.11", + "regex", "regex-lite", "serde", "tracing", @@ -187,7 +188,7 @@ dependencies = [ "bytestring", "cfg-if", "cookie", - "derive_more", + "derive_more 0.99.17", "encoding_rs", "futures-core", "futures-util", @@ -220,6 +221,43 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "actix-web-lab" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33034dd88446a5deb20e42156dbfe43d07e0499345db3ae65b3f51854190531" +dependencies = [ + "actix-http", + "actix-router", + "actix-service", + "actix-utils", + "actix-web", + "ahash 0.8.11", + "arc-swap", + "bytes", + "bytestring", + "csv", + "derive_more 2.0.1", + "form_urlencoded", + "futures-core", + "futures-util", + "http 0.2.11", + "impl-more", + "itertools 0.14.0", + "local-channel", + "mime", + "pin-project-lite", + "regex", + "serde", + "serde_html_form", + "serde_json", + "serde_path_to_error", + "tokio", + "tokio-stream", + "tracing", + "url", +] + [[package]] name = "addr2line" version = "0.20.0" @@ -241,12 +279,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "adler32" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" - [[package]] name = "aes" version = "0.8.4" @@ -391,6 +423,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayvec" version = "0.7.4" @@ -556,7 +594,7 @@ dependencies = [ "milli", "mimalloc", "rand 0.8.5", - "rand_chacha", + "rand_chacha 0.3.1", "reqwest", "roaring", "serde_json", @@ -708,7 +746,18 @@ checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", - "brotli-decompressor", + "brotli-decompressor 4.0.1", +] + +[[package]] +name = "brotli" +version = "8.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor 5.0.0", ] [[package]] @@ -721,6 +770,16 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "brotli-decompressor" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bstr" version = "1.11.3" @@ -1261,15 +1320,6 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" -[[package]] -name = "core2" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" -dependencies = [ - "memchr", -] - [[package]] name = "cpufeatures" version = "0.2.12" @@ -1499,12 +1549,6 @@ dependencies = [ "syn 2.0.87", ] -[[package]] -name = "dary_heap" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" - [[package]] name = "deadpool" version = "0.10.0" @@ -1634,6 +1678,27 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_more" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "unicode-xid", +] + [[package]] name = "deserr" version = "0.6.3" @@ -2842,32 +2907,9 @@ dependencies = [ [[package]] name = "impl-more" -version = "0.1.6" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d" - -[[package]] -name = "include-flate" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df49c16750695486c1f34de05da5b7438096156466e7f76c38fcdf285cf0113e" -dependencies = [ - "include-flate-codegen", - "lazy_static", - "libflate", -] - -[[package]] -name = "include-flate-codegen" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c5b246c6261be723b85c61ecf87804e8ea4a35cb68be0ff282ed84b95ffe7d7" -dependencies = [ - "libflate", - "proc-macro2", - "quote", - "syn 2.0.87", -] +checksum = "e8a5a9a0ff0086c7a148acb942baaabeadf9504d10400b5a05645853729b9cd2" [[package]] name = "index-scheduler" @@ -3044,27 +3086,18 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" -[[package]] -name = "jieba-macros" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c676b32a471d3cfae8dac2ad2f8334cd52e53377733cca8c1fb0a5062fec192" -dependencies = [ - "phf_codegen", -] - [[package]] name = "jieba-rs" -version = "0.7.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d1bcad6332969e4d48ee568d430e14ee6dea70740c2549d005d87677ebefb0c" +checksum = "c1e2b0210dc78b49337af9e49d7ae41a39dceac6e5985613f1cf7763e2f76a25" dependencies = [ "cedarwood", + "derive_builder 0.20.2", "fxhash", - "include-flate", - "jieba-macros", "lazy_static", "phf", + "phf_codegen", "regex", ] @@ -3156,30 +3189,6 @@ version = "0.2.171" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" -[[package]] -name = "libflate" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45d9dfdc14ea4ef0900c1cddbc8dcd553fbaacd8a4a282cf4018ae9dd04fb21e" -dependencies = [ - "adler32", - "core2", - "crc32fast", - "dary_heap", - "libflate_lz77", -] - -[[package]] -name = "libflate_lz77" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" -dependencies = [ - "core2", - "hashbrown 0.14.3", - "rle-decode-fast", -] - [[package]] name = "libgit2-sys" version = "0.17.0+1.8.1" @@ -3243,9 +3252,9 @@ dependencies = [ [[package]] name = "lindera" -version = "0.42.3" +version = "0.42.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fa3936dbcfc54b90a53da68ec8fe209656cfa691147f951944f48c61dcde317" +checksum = "73b6ee48fa4ffaff0b34a0f56e8fe9e3a9f38ff097d7ffe11a189acac242efbf" dependencies = [ "anyhow", "bincode", @@ -3273,9 +3282,9 @@ dependencies = [ [[package]] name = "lindera-cc-cedict" -version = "0.42.3" +version = "0.42.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a4720c69e32b278614eefb8181e0ef78907fa115d947edaeaedb1150785b902" +checksum = "88fb51b5730fd63b1baf677fb19ce3f3f00616a3fbaf430f923b676dce5fab39" dependencies = [ "bincode", "byteorder", @@ -3286,9 +3295,9 @@ dependencies = [ [[package]] name = "lindera-dictionary" -version = "0.42.3" +version = "0.42.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b123ac54a74c9418616c96d0d7cf5eb8fbf372211c07032d1e174c94e40ff030" +checksum = "d5dafa44610860d21f66dbfee1ad387fd127824b204137b540ada4c1a744b19c" dependencies = [ "anyhow", "bincode", @@ -3314,9 +3323,9 @@ dependencies = [ [[package]] name = "lindera-ipadic" -version = "0.42.3" +version = "0.42.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71c3786e6cf65dd1e8537c3c35637f887289bf83687f6fbcac3a6679bfa33265" +checksum = "d273907fdf1c14a8244a370afd7ac79126337ad450d25888b1613aee17b1262a" dependencies = [ "bincode", "byteorder", @@ -3327,9 +3336,9 @@ dependencies = [ [[package]] name = "lindera-ipadic-neologd" -version = "0.42.3" +version = "0.42.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42646cc30bf8ceabf3db1154358329e1031f2af25ca1721ddba8ee3666881a08" +checksum = "9d4371fbd6dc3ac5cc76990ed41061c553635f67953771159e4061d7f568d14f" dependencies = [ "bincode", "byteorder", @@ -3340,9 +3349,9 @@ dependencies = [ [[package]] name = "lindera-ko-dic" -version = "0.42.3" +version = "0.42.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10f94a00fc5931636c10d2e6af4cfa43fbf95f8a529caa45d10600f3cb2853c9" +checksum = "03f35d8e54e6d5f73e9f76da0fedfa336fa60a6d2ac7f7dcc8bcd15e338db291" dependencies = [ "bincode", "byteorder", @@ -3353,9 +3362,9 @@ dependencies = [ [[package]] name = "lindera-unidic" -version = "0.42.3" +version = "0.42.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5933014ca145351d59bb50a6e509a53af1f89ceda687fe9efd6d534e6b59a27" +checksum = "661aa828cf6af7ccd1c0c1142c087fd048af5f83776ccec6af9f9c56448bc626" dependencies = [ "bincode", "byteorder", @@ -3578,10 +3587,11 @@ dependencies = [ "actix-rt", "actix-utils", "actix-web", + "actix-web-lab", "anyhow", "async-openai", "async-trait", - "brotli", + "brotli 6.0.0", "bstr", "build-info", "byte-unit", @@ -4673,7 +4683,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", + "rand_chacha 0.3.1", "rand_core 0.6.4", ] @@ -4683,6 +4693,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ + "rand_chacha 0.9.0", "rand_core 0.9.3", ] @@ -4696,6 +4707,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", +] + [[package]] name = "rand_core" version = "0.6.4" @@ -4710,6 +4731,9 @@ name = "rand_core" version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.1", +] [[package]] name = "rand_distr" @@ -4983,12 +5007,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "rle-decode-fast" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" - [[package]] name = "roaring" version = "0.10.10" @@ -5290,6 +5308,19 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "serde_html_form" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" +dependencies = [ + "form_urlencoded", + "indexmap", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.140" @@ -5303,6 +5334,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -5930,9 +5971,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.45.0" +version = "1.45.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2513ca694ef9ede0fb23fe71a4ee4107cb102b9dc1930f6d0fd77aae068ae165" +checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", "bytes", @@ -6288,6 +6329,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unicode_categories" version = "0.1.1" diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index a0ce49193..f7469e7ac 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -113,6 +113,7 @@ utoipa = { version = "5.3.1", features = [ ] } utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] } async-openai = "0.28.1" +actix-web-lab = { version = "0.24.1", default-features = false } [dev-dependencies] actix-rt = "2.10.0" diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 8f0552561..ad46d91c8 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -1,7 +1,8 @@ use std::mem; use actix_web::web::{self, Data}; -use actix_web::HttpResponse; +use actix_web::{Either, HttpResponse, Responder}; +use actix_web_lab::sse::{self, Event}; use async_openai::config::OpenAIConfig; use async_openai::types::{ ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, @@ -10,6 +11,7 @@ use async_openai::types::{ FunctionObjectArgs, }; use async_openai::Client; +use futures::StreamExt; use index_scheduler::IndexScheduler; use meilisearch_types::error::ResponseError; use meilisearch_types::keys::actions; @@ -53,10 +55,22 @@ async fn chat( index_scheduler: GuardedData, Data>, search_queue: web::Data, web::Json(mut chat_completion): web::Json, -) -> Result { +) -> impl Responder { // To enable later on, when the feature will be experimental // index_scheduler.features().check_chat("Using the /chat route")?; + if chat_completion.stream.unwrap_or(false) { + Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await) + } else { + Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await) + } +} + +async fn non_streamed_chat( + index_scheduler: GuardedData, Data>, + search_queue: web::Data, + mut chat_completion: CreateChatCompletionRequest, +) -> Result { let api_key = std::env::var("MEILI_OPENAI_API_KEY") .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base @@ -119,7 +133,7 @@ async fn chat( .build() .unwrap(), ); - response = dbg!(client.chat().create(chat_completion.clone()).await.unwrap()); + response = client.chat().create(chat_completion.clone()).await.unwrap(); let choice = &mut response.choices[0]; match choice.finish_reason { @@ -221,6 +235,24 @@ async fn chat( Ok(HttpResponse::Ok().json(response)) } +async fn streamed_chat( + index_scheduler: GuardedData, Data>, + search_queue: web::Data, + mut chat_completion: CreateChatCompletionRequest, +) -> impl Responder { + assert!(chat_completion.stream.unwrap_or(false)); + + let api_key = std::env::var("MEILI_OPENAI_API_KEY") + .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); + let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base + let client = Client::with_config(config); + let response = client.chat().create_stream(chat_completion).await.unwrap(); + actix_web_lab::sse::Sse::from_stream(response.map(|response| { + response + .map(|mut r| Event::Data(sse::Data::new_json(r.choices.pop().unwrap().delta).unwrap())) + })) +} + #[derive(Deserialize)] struct SearchInIndexParameters { /// The index uid to search in.