diff --git a/Cargo.lock b/Cargo.lock index aea49a0a4..41947aa5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -470,8 +470,7 @@ dependencies = [ [[package]] name = "async-openai" version = "0.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14d76e2f5af19477d6254415acc95ba97c6cc6f3b1e3cb4676b7f0fab8194298" +source = "git+https://github.com/meilisearch/async-openai?branch=optional-type-function#dd328d4c35ca24c30284c8aff616541ac82eb47a" dependencies = [ "async-openai-macros", "backoff", @@ -496,8 +495,7 @@ dependencies = [ [[package]] name = "async-openai-macros" version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398" +source = "git+https://github.com/meilisearch/async-openai?branch=optional-type-function#dd328d4c35ca24c30284c8aff616541ac82eb47a" dependencies = [ "proc-macro2", "quote", diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index f7469e7ac..398b62dad 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -112,7 +112,7 @@ utoipa = { version = "5.3.1", features = [ "openapi_extensions", ] } utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] } -async-openai = "0.28.1" +async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "optional-type-function" } actix-web-lab = { version = "0.24.1", default-features = false } [dev-dependencies] diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 8d07342c8..7869b677d 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -86,7 +86,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts: "description": prompts.search_index_uid_param, }, "q": { - "type": ["string", "null"], + // Unfortunately, Mistral does not support an array of types, here. + // "type": ["string", "null"], + "type": "string", "description": prompts.search_q_param, } }, @@ -269,7 +271,6 @@ async fn streamed_chat( 'main: loop { let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); - while let Some(result) = response.next().await { match result { Ok(resp) => { @@ -306,12 +307,14 @@ async fn streamed_chat( function.as_ref().unwrap(); global_tool_calls .entry(*index) + .and_modify(|call| { + call.append(arguments.as_ref().unwrap()); + }) .or_insert_with(|| Call { id: id.as_ref().unwrap().clone(), function_name: name.as_ref().unwrap().clone(), arguments: arguments.as_ref().unwrap().clone(), - }) - .append(arguments.as_ref().unwrap()); + }); } } None if !global_tool_calls.is_empty() => { @@ -322,7 +325,7 @@ async fn streamed_chat( .into_values() .map(|call| ChatCompletionMessageToolCall { id: call.id, - r#type: ChatCompletionToolType::Function, + r#type: Some(ChatCompletionToolType::Function), function: FunctionCall { name: call.function_name, arguments: call.arguments, @@ -400,8 +403,9 @@ async fn streamed_chat( None => (), } } - Err(_err) => { + Err(err) => { // writeln!(lock, "error: {err}").unwrap(); + tracing::error!("{err:?}"); } } }