Make it compatible with the Mistral API

This commit is contained in:
Clément Renault 2025-05-16 14:33:53 +02:00 committed by Kerollmops
parent 64fe283abc
commit 116fbb4c24
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
3 changed files with 13 additions and 11 deletions

6
Cargo.lock generated
View File

@ -470,8 +470,7 @@ dependencies = [
[[package]] [[package]]
name = "async-openai" name = "async-openai"
version = "0.28.1" version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/meilisearch/async-openai?branch=optional-type-function#dd328d4c35ca24c30284c8aff616541ac82eb47a"
checksum = "14d76e2f5af19477d6254415acc95ba97c6cc6f3b1e3cb4676b7f0fab8194298"
dependencies = [ dependencies = [
"async-openai-macros", "async-openai-macros",
"backoff", "backoff",
@ -496,8 +495,7 @@ dependencies = [
[[package]] [[package]]
name = "async-openai-macros" name = "async-openai-macros"
version = "0.1.0" version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/meilisearch/async-openai?branch=optional-type-function#dd328d4c35ca24c30284c8aff616541ac82eb47a"
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",

View File

@ -112,7 +112,7 @@ utoipa = { version = "5.3.1", features = [
"openapi_extensions", "openapi_extensions",
] } ] }
utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] } utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] }
async-openai = "0.28.1" async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "optional-type-function" }
actix-web-lab = { version = "0.24.1", default-features = false } actix-web-lab = { version = "0.24.1", default-features = false }
[dev-dependencies] [dev-dependencies]

View File

@ -86,7 +86,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
"description": prompts.search_index_uid_param, "description": prompts.search_index_uid_param,
}, },
"q": { "q": {
"type": ["string", "null"], // Unfortunately, Mistral does not support an array of types, here.
// "type": ["string", "null"],
"type": "string",
"description": prompts.search_q_param, "description": prompts.search_q_param,
} }
}, },
@ -269,7 +271,6 @@ async fn streamed_chat(
'main: loop { 'main: loop {
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
while let Some(result) = response.next().await { while let Some(result) = response.next().await {
match result { match result {
Ok(resp) => { Ok(resp) => {
@ -306,12 +307,14 @@ async fn streamed_chat(
function.as_ref().unwrap(); function.as_ref().unwrap();
global_tool_calls global_tool_calls
.entry(*index) .entry(*index)
.and_modify(|call| {
call.append(arguments.as_ref().unwrap());
})
.or_insert_with(|| Call { .or_insert_with(|| Call {
id: id.as_ref().unwrap().clone(), id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(), function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(), arguments: arguments.as_ref().unwrap().clone(),
}) });
.append(arguments.as_ref().unwrap());
} }
} }
None if !global_tool_calls.is_empty() => { None if !global_tool_calls.is_empty() => {
@ -322,7 +325,7 @@ async fn streamed_chat(
.into_values() .into_values()
.map(|call| ChatCompletionMessageToolCall { .map(|call| ChatCompletionMessageToolCall {
id: call.id, id: call.id,
r#type: ChatCompletionToolType::Function, r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { function: FunctionCall {
name: call.function_name, name: call.function_name,
arguments: call.arguments, arguments: call.arguments,
@ -400,8 +403,9 @@ async fn streamed_chat(
None => (), None => (),
} }
} }
Err(_err) => { Err(err) => {
// writeln!(lock, "error: {err}").unwrap(); // writeln!(lock, "error: {err}").unwrap();
tracing::error!("{err:?}");
} }
} }
} }