mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-12-24 13:26:57 +00:00
Compare commits
6 Commits
prototype-
...
prototype-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
295840d07a | ||
|
|
c0c3bddda8 | ||
|
|
10b5fcd4ba | ||
|
|
8113d4a52e | ||
|
|
5964289284 | ||
|
|
6b81854d48 |
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -470,8 +470,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "async-openai"
|
||||
version = "0.28.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14d76e2f5af19477d6254415acc95ba97c6cc6f3b1e3cb4676b7f0fab8194298"
|
||||
source = "git+https://github.com/meilisearch/async-openai?branch=optional-type-function#603f1d17bb4530c45fb9a6e93294ab715a7af869"
|
||||
dependencies = [
|
||||
"async-openai-macros",
|
||||
"backoff",
|
||||
@@ -496,8 +495,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "async-openai-macros"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
|
||||
source = "git+https://github.com/meilisearch/async-openai?branch=optional-type-function#603f1d17bb4530c45fb9a6e93294ab715a7af869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
||||
@@ -212,7 +212,7 @@ impl IndexScheduler {
|
||||
#[cfg(test)]
|
||||
run_loop_iteration: self.run_loop_iteration.clone(),
|
||||
features: self.features.clone(),
|
||||
chat_settings: self.chat_settings.clone(),
|
||||
chat_settings: self.chat_settings,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -870,12 +870,12 @@ impl IndexScheduler {
|
||||
|
||||
pub fn chat_settings(&self) -> Result<Option<serde_json::Value>> {
|
||||
let rtxn = self.env.read_txn().map_err(Error::HeedTransaction)?;
|
||||
self.chat_settings.get(&rtxn, &"main").map_err(Into::into)
|
||||
self.chat_settings.get(&rtxn, "main").map_err(Into::into)
|
||||
}
|
||||
|
||||
pub fn put_chat_settings(&self, settings: &serde_json::Value) -> Result<()> {
|
||||
let mut wtxn = self.env.write_txn().map_err(Error::HeedTransaction)?;
|
||||
self.chat_settings.put(&mut wtxn, &"main", &settings)?;
|
||||
self.chat_settings.put(&mut wtxn, "main", settings)?;
|
||||
wtxn.commit().map_err(Error::HeedTransaction)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ utoipa = { version = "5.3.1", features = [
|
||||
"openapi_extensions",
|
||||
] }
|
||||
utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] }
|
||||
async-openai = "0.28.1"
|
||||
async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "optional-type-function" }
|
||||
actix-web-lab = { version = "0.24.1", default-features = false }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -42,7 +42,7 @@ use crate::search_queue::SearchQueue;
|
||||
const EMBEDDER_NAME: &str = "openai";
|
||||
|
||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(web::resource("").route(web::post().to(chat)));
|
||||
cfg.service(web::resource("/completions").route(web::post().to(chat)));
|
||||
}
|
||||
|
||||
/// Get a chat completion
|
||||
@@ -86,7 +86,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
|
||||
"description": prompts.search_index_uid_param,
|
||||
},
|
||||
"q": {
|
||||
"type": ["string", "null"],
|
||||
// Unfortunately, Mistral does not support an array of types, here.
|
||||
// "type": ["string", "null"],
|
||||
"type": "string",
|
||||
"description": prompts.search_q_param,
|
||||
}
|
||||
},
|
||||
@@ -186,10 +188,9 @@ async fn non_streamed_chat(
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
// We cannot change the endpoint
|
||||
// if let Some(endpoint) = chat_settings.endpoint.as_ref() {
|
||||
// config.with_api_base(&endpoint);
|
||||
// }
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
let client = Client::with_config(config);
|
||||
|
||||
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
|
||||
@@ -257,10 +258,9 @@ async fn streamed_chat(
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
// We cannot change the endpoint
|
||||
// if let Some(endpoint) = chat_settings.endpoint.as_ref() {
|
||||
// config.with_api_base(&endpoint);
|
||||
// }
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
|
||||
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
|
||||
|
||||
@@ -268,30 +268,27 @@ async fn streamed_chat(
|
||||
let _join_handle = Handle::current().spawn(async move {
|
||||
let client = Client::with_config(config.clone());
|
||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||
let mut finish_reason = None;
|
||||
|
||||
'main: loop {
|
||||
'main: while finish_reason.map_or(true, |fr| fr == FinishReason::ToolCalls) {
|
||||
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||
|
||||
while let Some(result) = response.next().await {
|
||||
match result {
|
||||
Ok(resp) => {
|
||||
let delta = &resp.choices[0].delta;
|
||||
let choice = &resp.choices[0];
|
||||
finish_reason = choice.finish_reason;
|
||||
|
||||
#[allow(deprecated)]
|
||||
let ChatCompletionStreamResponseDelta {
|
||||
content,
|
||||
// Using deprecated field but keeping for compatibility
|
||||
#[allow(deprecated)]
|
||||
function_call: _,
|
||||
function_call: _,
|
||||
ref tool_calls,
|
||||
role: _,
|
||||
refusal: _,
|
||||
} = delta;
|
||||
} = &choice.delta;
|
||||
|
||||
if content.is_none() && tool_calls.is_none() && global_tool_calls.is_empty()
|
||||
{
|
||||
break 'main;
|
||||
}
|
||||
|
||||
if let Some(_) = content {
|
||||
if content.is_some() {
|
||||
tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await.unwrap()
|
||||
}
|
||||
|
||||
@@ -308,12 +305,14 @@ async fn streamed_chat(
|
||||
function.as_ref().unwrap();
|
||||
global_tool_calls
|
||||
.entry(*index)
|
||||
.and_modify(|call| {
|
||||
call.append(arguments.as_ref().unwrap());
|
||||
})
|
||||
.or_insert_with(|| Call {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
})
|
||||
.append(arguments.as_ref().unwrap());
|
||||
});
|
||||
}
|
||||
}
|
||||
None if !global_tool_calls.is_empty() => {
|
||||
@@ -321,10 +320,10 @@ async fn streamed_chat(
|
||||
|
||||
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(&mut global_tool_calls)
|
||||
.into_iter()
|
||||
.map(|(_, call)| ChatCompletionMessageToolCall {
|
||||
.into_values()
|
||||
.map(|call| ChatCompletionMessageToolCall {
|
||||
id: call.id,
|
||||
r#type: ChatCompletionToolType::Function,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall {
|
||||
name: call.function_name,
|
||||
arguments: call.arguments,
|
||||
@@ -342,7 +341,7 @@ async fn streamed_chat(
|
||||
|
||||
for call in meili_calls {
|
||||
tx.send(Event::Data(
|
||||
sse::Data::new_json(&json!({
|
||||
sse::Data::new_json(json!({
|
||||
"object": "chat.completion.tool.call",
|
||||
"tool": call,
|
||||
}))
|
||||
@@ -380,7 +379,7 @@ async fn streamed_chat(
|
||||
);
|
||||
|
||||
tx.send(Event::Data(
|
||||
sse::Data::new_json(&json!({
|
||||
sse::Data::new_json(json!({
|
||||
"object": "chat.completion.tool.output",
|
||||
"tool": ChatCompletionRequestMessage::Tool(
|
||||
ChatCompletionRequestToolMessage {
|
||||
@@ -402,8 +401,10 @@ async fn streamed_chat(
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
Err(_err) => {
|
||||
Err(err) => {
|
||||
// writeln!(lock, "error: {err}").unwrap();
|
||||
tracing::error!("{err:?}");
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ async fn patch_settings(
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct ChatSettings {
|
||||
pub source: String,
|
||||
pub endpoint: Option<String>,
|
||||
pub base_api: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub prompts: ChatPrompts,
|
||||
pub indexes: BTreeMap<String, ChatIndexSettings>,
|
||||
@@ -95,7 +95,7 @@ impl Default for ChatSettings {
|
||||
fn default() -> Self {
|
||||
ChatSettings {
|
||||
source: "openai".to_string(),
|
||||
endpoint: None,
|
||||
base_api: None,
|
||||
api_key: None,
|
||||
prompts: ChatPrompts {
|
||||
system: DEFAULT_SYSTEM_MESSAGE.to_string(),
|
||||
|
||||
Reference in New Issue
Block a user