mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-06-06 20:25:40 +00:00
Factorise a bit the code
This commit is contained in:
parent
d94f16b1d2
commit
e603e221d5
@ -55,6 +55,14 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||||||
cfg.service(web::resource("").route(web::post().to(chat)));
|
cfg.service(web::resource("").route(web::post().to(chat)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates OpenAI client with API key
|
||||||
|
fn create_openai_client() -> Client<OpenAIConfig> {
|
||||||
|
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
|
||||||
|
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
|
||||||
|
let config = OpenAIConfig::default().with_api_key(&api_key);
|
||||||
|
Client::with_config(config)
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a chat completion
|
/// Get a chat completion
|
||||||
async fn chat(
|
async fn chat(
|
||||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
||||||
@ -77,16 +85,112 @@ async fn chat(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn non_streamed_chat(
|
/// Setup search tool in chat completion request
|
||||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
fn setup_search_tool(
|
||||||
search_queue: web::Data<SearchQueue>,
|
chat_completion: &mut CreateChatCompletionRequest,
|
||||||
mut chat_completion: CreateChatCompletionRequest,
|
search_in_index_description: &str,
|
||||||
) -> Result<HttpResponse, ResponseError> {
|
search_in_index_q_param_description: &str,
|
||||||
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
|
search_in_index_index_description: &str,
|
||||||
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
|
) {
|
||||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
let tools = chat_completion.tools.get_or_insert_default();
|
||||||
let client = Client::with_config(config);
|
tools.push(
|
||||||
|
ChatCompletionToolArgs::default()
|
||||||
|
.r#type(ChatCompletionToolType::Function)
|
||||||
|
.function(
|
||||||
|
FunctionObjectArgs::default()
|
||||||
|
.name("searchInIndex")
|
||||||
|
.description(search_in_index_description)
|
||||||
|
.parameters(json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"index_uid": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["main"],
|
||||||
|
"description": search_in_index_index_description,
|
||||||
|
},
|
||||||
|
"q": {
|
||||||
|
"type": ["string", "null"],
|
||||||
|
"description": search_in_index_q_param_description,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["index_uid", "q"],
|
||||||
|
"additionalProperties": false,
|
||||||
|
}))
|
||||||
|
.strict(true)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process search request and return formatted results
|
||||||
|
async fn process_search_request(
|
||||||
|
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
||||||
|
search_queue: &web::Data<SearchQueue>,
|
||||||
|
index_uid: String,
|
||||||
|
q: Option<String>,
|
||||||
|
) -> Result<(Index, String), ResponseError> {
|
||||||
|
let mut query = SearchQuery {
|
||||||
|
q,
|
||||||
|
hybrid: Some(HybridQuery {
|
||||||
|
semantic_ratio: SemanticRatio::default(),
|
||||||
|
embedder: EMBEDDER_NAME.to_string(),
|
||||||
|
}),
|
||||||
|
limit: 20,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Tenant token search_rules.
|
||||||
|
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
|
||||||
|
add_search_rules(&mut query.filter, search_rules);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TBD
|
||||||
|
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
||||||
|
|
||||||
|
let index = index_scheduler.index(&index_uid)?;
|
||||||
|
let search_kind =
|
||||||
|
search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
|
||||||
|
|
||||||
|
let permit = search_queue.try_get_search_permit().await?;
|
||||||
|
let features = index_scheduler.features();
|
||||||
|
let index_cloned = index.clone();
|
||||||
|
let search_result = tokio::task::spawn_blocking(move || {
|
||||||
|
perform_search(
|
||||||
|
index_uid.to_string(),
|
||||||
|
&index_cloned,
|
||||||
|
query,
|
||||||
|
search_kind,
|
||||||
|
RetrieveVectors::new(false),
|
||||||
|
features,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
permit.drop().await;
|
||||||
|
|
||||||
|
let search_result = search_result?;
|
||||||
|
if let Ok(ref search_result) = search_result {
|
||||||
|
// aggregate.succeed(search_result);
|
||||||
|
if search_result.degraded {
|
||||||
|
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// analytics.publish(aggregate, &req);
|
||||||
|
|
||||||
|
let search_result = search_result?;
|
||||||
|
let formatted =
|
||||||
|
format_documents(&index, search_result.hits.into_iter().map(|doc| doc.document));
|
||||||
|
let text = formatted.join("\n");
|
||||||
|
|
||||||
|
Ok((index, text))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get prompt descriptions from index scheduler
|
||||||
|
fn get_prompt_descriptions(
|
||||||
|
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
||||||
|
) -> (String, String, String) {
|
||||||
let rtxn = index_scheduler.read_txn().unwrap();
|
let rtxn = index_scheduler.read_txn().unwrap();
|
||||||
let search_in_index_description = index_scheduler
|
let search_in_index_description = index_scheduler
|
||||||
.chat_prompts(&rtxn, "searchInIndex-description")
|
.chat_prompts(&rtxn, "searchInIndex-description")
|
||||||
@ -105,39 +209,35 @@ async fn non_streamed_chat(
|
|||||||
.to_string();
|
.to_string();
|
||||||
drop(rtxn);
|
drop(rtxn);
|
||||||
|
|
||||||
|
(
|
||||||
|
search_in_index_description,
|
||||||
|
search_in_index_q_param_description,
|
||||||
|
search_in_index_index_description,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn non_streamed_chat(
|
||||||
|
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
|
||||||
|
search_queue: web::Data<SearchQueue>,
|
||||||
|
mut chat_completion: CreateChatCompletionRequest,
|
||||||
|
) -> Result<HttpResponse, ResponseError> {
|
||||||
|
let client = create_openai_client();
|
||||||
|
|
||||||
|
let (
|
||||||
|
search_in_index_description,
|
||||||
|
search_in_index_q_param_description,
|
||||||
|
search_in_index_index_description,
|
||||||
|
) = get_prompt_descriptions(&index_scheduler);
|
||||||
|
|
||||||
let mut response;
|
let mut response;
|
||||||
loop {
|
loop {
|
||||||
let tools = chat_completion.tools.get_or_insert_default();
|
setup_search_tool(
|
||||||
tools.push(
|
&mut chat_completion,
|
||||||
ChatCompletionToolArgs::default()
|
&search_in_index_description,
|
||||||
.r#type(ChatCompletionToolType::Function)
|
&search_in_index_q_param_description,
|
||||||
.function(
|
&search_in_index_index_description,
|
||||||
FunctionObjectArgs::default()
|
|
||||||
.name("searchInIndex")
|
|
||||||
.description(&search_in_index_description)
|
|
||||||
.parameters(json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"index_uid": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["main"],
|
|
||||||
"description": search_in_index_index_description,
|
|
||||||
},
|
|
||||||
"q": {
|
|
||||||
"type": ["string", "null"],
|
|
||||||
"description": search_in_index_q_param_description,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["index_uid", "q"],
|
|
||||||
"additionalProperties": false,
|
|
||||||
}))
|
|
||||||
.strict(true)
|
|
||||||
.build()
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
response = client.chat().create(chat_completion.clone()).await.unwrap();
|
response = client.chat().create(chat_completion.clone()).await.unwrap();
|
||||||
|
|
||||||
let choice = &mut response.choices[0];
|
let choice = &mut response.choices[0];
|
||||||
@ -160,65 +260,10 @@ async fn non_streamed_chat(
|
|||||||
let SearchInIndexParameters { index_uid, q } =
|
let SearchInIndexParameters { index_uid, q } =
|
||||||
serde_json::from_str(&call.function.arguments).unwrap();
|
serde_json::from_str(&call.function.arguments).unwrap();
|
||||||
|
|
||||||
let mut query = SearchQuery {
|
let (_, text) =
|
||||||
q,
|
process_search_request(&index_scheduler, &search_queue, index_uid, q)
|
||||||
hybrid: Some(HybridQuery {
|
.await?;
|
||||||
semantic_ratio: SemanticRatio::default(),
|
|
||||||
embedder: EMBEDDER_NAME.to_string(),
|
|
||||||
}),
|
|
||||||
limit: 20,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Tenant token search_rules.
|
|
||||||
if let Some(search_rules) =
|
|
||||||
index_scheduler.filters().get_index_search_rules(&index_uid)
|
|
||||||
{
|
|
||||||
add_search_rules(&mut query.filter, search_rules);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TBD
|
|
||||||
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
|
||||||
|
|
||||||
let index = index_scheduler.index(&index_uid)?;
|
|
||||||
let search_kind = search_kind(
|
|
||||||
&query,
|
|
||||||
index_scheduler.get_ref(),
|
|
||||||
index_uid.to_string(),
|
|
||||||
&index,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let permit = search_queue.try_get_search_permit().await?;
|
|
||||||
let features = index_scheduler.features();
|
|
||||||
let index_cloned = index.clone();
|
|
||||||
let search_result = tokio::task::spawn_blocking(move || {
|
|
||||||
perform_search(
|
|
||||||
index_uid.to_string(),
|
|
||||||
&index_cloned,
|
|
||||||
query,
|
|
||||||
search_kind,
|
|
||||||
RetrieveVectors::new(false),
|
|
||||||
features,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
permit.drop().await;
|
|
||||||
|
|
||||||
let search_result = search_result?;
|
|
||||||
if let Ok(ref search_result) = search_result {
|
|
||||||
// aggregate.succeed(search_result);
|
|
||||||
if search_result.degraded {
|
|
||||||
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// analytics.publish(aggregate, &req);
|
|
||||||
|
|
||||||
let search_result = search_result?;
|
|
||||||
let formatted = format_documents(
|
|
||||||
&index,
|
|
||||||
search_result.hits.into_iter().map(|doc| doc.document),
|
|
||||||
);
|
|
||||||
let text = formatted.join("\n");
|
|
||||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
||||||
ChatCompletionRequestToolMessage {
|
ChatCompletionRequestToolMessage {
|
||||||
tool_call_id: call.id,
|
tool_call_id: call.id,
|
||||||
@ -245,63 +290,22 @@ async fn streamed_chat(
|
|||||||
search_queue: web::Data<SearchQueue>,
|
search_queue: web::Data<SearchQueue>,
|
||||||
mut chat_completion: CreateChatCompletionRequest,
|
mut chat_completion: CreateChatCompletionRequest,
|
||||||
) -> impl Responder {
|
) -> impl Responder {
|
||||||
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
|
let (
|
||||||
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
|
search_in_index_description,
|
||||||
|
search_in_index_q_param_description,
|
||||||
|
search_in_index_index_description,
|
||||||
|
) = get_prompt_descriptions(&index_scheduler);
|
||||||
|
|
||||||
let rtxn = index_scheduler.read_txn().unwrap();
|
setup_search_tool(
|
||||||
let search_in_index_description = index_scheduler
|
&mut chat_completion,
|
||||||
.chat_prompts(&rtxn, "searchInIndex-description")
|
&search_in_index_description,
|
||||||
.unwrap()
|
&search_in_index_q_param_description,
|
||||||
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION)
|
&search_in_index_index_description,
|
||||||
.to_string();
|
|
||||||
let search_in_index_q_param_description = index_scheduler
|
|
||||||
.chat_prompts(&rtxn, "searchInIndex-q-param-description")
|
|
||||||
.unwrap()
|
|
||||||
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION)
|
|
||||||
.to_string();
|
|
||||||
let search_in_index_index_description = index_scheduler
|
|
||||||
.chat_prompts(&rtxn, "searchInIndex-index-param-description")
|
|
||||||
.unwrap()
|
|
||||||
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION)
|
|
||||||
.to_string();
|
|
||||||
drop(rtxn);
|
|
||||||
|
|
||||||
let tools = chat_completion.tools.get_or_insert_default();
|
|
||||||
tools.push(
|
|
||||||
ChatCompletionToolArgs::default()
|
|
||||||
.r#type(ChatCompletionToolType::Function)
|
|
||||||
.function(
|
|
||||||
FunctionObjectArgs::default()
|
|
||||||
.name("searchInIndex")
|
|
||||||
.description(&search_in_index_description)
|
|
||||||
.parameters(json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"index_uid": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["main"],
|
|
||||||
"description": search_in_index_index_description,
|
|
||||||
},
|
|
||||||
"q": {
|
|
||||||
"type": ["string", "null"],
|
|
||||||
"description": search_in_index_q_param_description,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["index_uid", "q"],
|
|
||||||
"additionalProperties": false,
|
|
||||||
}))
|
|
||||||
.strict(true)
|
|
||||||
.build()
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||||
let _join_handle = Handle::current().spawn(async move {
|
let _join_handle = Handle::current().spawn(async move {
|
||||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
let client = create_openai_client();
|
||||||
let client = Client::with_config(config);
|
|
||||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||||
|
|
||||||
'main: loop {
|
'main: loop {
|
||||||
@ -313,7 +317,9 @@ async fn streamed_chat(
|
|||||||
let delta = &resp.choices[0].delta;
|
let delta = &resp.choices[0].delta;
|
||||||
let ChatCompletionStreamResponseDelta {
|
let ChatCompletionStreamResponseDelta {
|
||||||
content,
|
content,
|
||||||
function_call: _,
|
// Using deprecated field but keeping for compatibility
|
||||||
|
#[allow(deprecated)]
|
||||||
|
function_call: _,
|
||||||
ref tool_calls,
|
ref tool_calls,
|
||||||
role: _,
|
role: _,
|
||||||
refusal: _,
|
refusal: _,
|
||||||
@ -352,7 +358,7 @@ async fn streamed_chat(
|
|||||||
None if !global_tool_calls.is_empty() => {
|
None if !global_tool_calls.is_empty() => {
|
||||||
// dbg!(&global_tool_calls);
|
// dbg!(&global_tool_calls);
|
||||||
|
|
||||||
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
|
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
|
||||||
mem::take(&mut global_tool_calls)
|
mem::take(&mut global_tool_calls)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(_, call)| ChatCompletionMessageToolCall {
|
.map(|(_, call)| ChatCompletionMessageToolCall {
|
||||||
@ -387,67 +393,23 @@ async fn streamed_chat(
|
|||||||
let SearchInIndexParameters { index_uid, q } =
|
let SearchInIndexParameters { index_uid, q } =
|
||||||
serde_json::from_str(&call.function.arguments).unwrap();
|
serde_json::from_str(&call.function.arguments).unwrap();
|
||||||
|
|
||||||
let mut query = SearchQuery {
|
let result = process_search_request(
|
||||||
|
&index_scheduler,
|
||||||
|
&search_queue,
|
||||||
|
index_uid,
|
||||||
q,
|
q,
|
||||||
hybrid: Some(HybridQuery {
|
|
||||||
semantic_ratio: SemanticRatio::default(),
|
|
||||||
embedder: EMBEDDER_NAME.to_string(),
|
|
||||||
}),
|
|
||||||
limit: 20,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Tenant token search_rules.
|
|
||||||
if let Some(search_rules) =
|
|
||||||
index_scheduler.filters().get_index_search_rules(&index_uid)
|
|
||||||
{
|
|
||||||
add_search_rules(&mut query.filter, search_rules);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TBD
|
|
||||||
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
|
|
||||||
|
|
||||||
let index = index_scheduler.index(&index_uid).unwrap();
|
|
||||||
let search_kind = search_kind(
|
|
||||||
&query,
|
|
||||||
index_scheduler.get_ref(),
|
|
||||||
index_uid.to_string(),
|
|
||||||
&index,
|
|
||||||
)
|
)
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let permit =
|
|
||||||
search_queue.try_get_search_permit().await.unwrap();
|
|
||||||
let features = index_scheduler.features();
|
|
||||||
let index_cloned = index.clone();
|
|
||||||
let search_result = tokio::task::spawn_blocking(move || {
|
|
||||||
perform_search(
|
|
||||||
index_uid.to_string(),
|
|
||||||
&index_cloned,
|
|
||||||
query,
|
|
||||||
search_kind,
|
|
||||||
RetrieveVectors::new(false),
|
|
||||||
features,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.await;
|
.await;
|
||||||
permit.drop().await;
|
|
||||||
|
|
||||||
let search_result = search_result.unwrap();
|
// Handle potential errors more explicitly
|
||||||
if let Ok(ref search_result) = search_result {
|
if let Err(err) = &result {
|
||||||
// aggregate.succeed(search_result);
|
// Log the error or handle it as needed
|
||||||
if search_result.degraded {
|
eprintln!("Error processing search request: {:?}", err);
|
||||||
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
continue;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// analytics.publish(aggregate, &req);
|
|
||||||
|
|
||||||
let search_result = search_result.unwrap();
|
let (_, text) = result.unwrap();
|
||||||
let formatted = format_documents(
|
|
||||||
&index,
|
|
||||||
search_result.hits.into_iter().map(|doc| doc.document),
|
|
||||||
);
|
|
||||||
let text = formatted.join("\n");
|
|
||||||
let tool = ChatCompletionRequestMessage::Tool(
|
let tool = ChatCompletionRequestMessage::Tool(
|
||||||
ChatCompletionRequestToolMessage {
|
ChatCompletionRequestToolMessage {
|
||||||
tool_call_id: call.id,
|
tool_call_id: call.id,
|
||||||
@ -515,7 +477,7 @@ fn format_documents(index: &Index, documents: impl Iterator<Item = Document>) ->
|
|||||||
|
|
||||||
let EmbeddingConfig {
|
let EmbeddingConfig {
|
||||||
embedder_options: _,
|
embedder_options: _,
|
||||||
prompt: PromptData { template, max_bytes },
|
prompt: PromptData { template, max_bytes: _ },
|
||||||
quantized: _,
|
quantized: _,
|
||||||
} = config;
|
} = config;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user