Factorise a bit the code

This commit is contained in:
Clément Renault 2025-05-15 15:39:38 +02:00 committed by Kerollmops
parent d94f16b1d2
commit e603e221d5
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -55,6 +55,14 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("").route(web::post().to(chat))); cfg.service(web::resource("").route(web::post().to(chat)));
} }
/// Creates OpenAI client with API key
fn create_openai_client() -> Client<OpenAIConfig> {
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
let config = OpenAIConfig::default().with_api_key(&api_key);
Client::with_config(config)
}
/// Get a chat completion /// Get a chat completion
async fn chat( async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
@ -77,36 +85,13 @@ async fn chat(
} }
} }
async fn non_streamed_chat( /// Setup search tool in chat completion request
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>, fn setup_search_tool(
search_queue: web::Data<SearchQueue>, chat_completion: &mut CreateChatCompletionRequest,
mut chat_completion: CreateChatCompletionRequest, search_in_index_description: &str,
) -> Result<HttpResponse, ResponseError> { search_in_index_q_param_description: &str,
let api_key = std::env::var("MEILI_OPENAI_API_KEY") search_in_index_index_description: &str,
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); ) {
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
let client = Client::with_config(config);
let rtxn = index_scheduler.read_txn().unwrap();
let search_in_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION)
.to_string();
let search_in_index_q_param_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-q-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION)
.to_string();
let search_in_index_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-index-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION)
.to_string();
drop(rtxn);
let mut response;
loop {
let tools = chat_completion.tools.get_or_insert_default(); let tools = chat_completion.tools.get_or_insert_default();
tools.push( tools.push(
ChatCompletionToolArgs::default() ChatCompletionToolArgs::default()
@ -114,7 +99,7 @@ async fn non_streamed_chat(
.function( .function(
FunctionObjectArgs::default() FunctionObjectArgs::default()
.name("searchInIndex") .name("searchInIndex")
.description(&search_in_index_description) .description(search_in_index_description)
.parameters(json!({ .parameters(json!({
"type": "object", "type": "object",
"properties": { "properties": {
@ -138,28 +123,15 @@ async fn non_streamed_chat(
.build() .build()
.unwrap(), .unwrap(),
); );
response = client.chat().create(chat_completion.clone()).await.unwrap(); }
let choice = &mut response.choices[0];
match choice.finish_reason {
Some(FinishReason::ToolCalls) => {
let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default();
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
tool_calls.into_iter().partition(|call| call.function.name == "searchInIndex");
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
for call in meili_calls {
let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap();
/// Process search request and return formatted results
async fn process_search_request(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
search_queue: &web::Data<SearchQueue>,
index_uid: String,
q: Option<String>,
) -> Result<(Index, String), ResponseError> {
let mut query = SearchQuery { let mut query = SearchQuery {
q, q,
hybrid: Some(HybridQuery { hybrid: Some(HybridQuery {
@ -171,9 +143,7 @@ async fn non_streamed_chat(
}; };
// Tenant token search_rules. // Tenant token search_rules.
if let Some(search_rules) = if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
index_scheduler.filters().get_index_search_rules(&index_uid)
{
add_search_rules(&mut query.filter, search_rules); add_search_rules(&mut query.filter, search_rules);
} }
@ -181,12 +151,8 @@ async fn non_streamed_chat(
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query); // let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
let index = index_scheduler.index(&index_uid)?; let index = index_scheduler.index(&index_uid)?;
let search_kind = search_kind( let search_kind =
&query, search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?;
index_scheduler.get_ref(),
index_uid.to_string(),
&index,
)?;
let permit = search_queue.try_get_search_permit().await?; let permit = search_queue.try_get_search_permit().await?;
let features = index_scheduler.features(); let features = index_scheduler.features();
@ -214,11 +180,90 @@ async fn non_streamed_chat(
// analytics.publish(aggregate, &req); // analytics.publish(aggregate, &req);
let search_result = search_result?; let search_result = search_result?;
let formatted = format_documents( let formatted =
&index, format_documents(&index, search_result.hits.into_iter().map(|doc| doc.document));
search_result.hits.into_iter().map(|doc| doc.document),
);
let text = formatted.join("\n"); let text = formatted.join("\n");
Ok((index, text))
}
/// Get prompt descriptions from index scheduler
fn get_prompt_descriptions(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
) -> (String, String, String) {
let rtxn = index_scheduler.read_txn().unwrap();
let search_in_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION)
.to_string();
let search_in_index_q_param_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-q-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION)
.to_string();
let search_in_index_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-index-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION)
.to_string();
drop(rtxn);
(
search_in_index_description,
search_in_index_q_param_description,
search_in_index_index_description,
)
}
async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> {
let client = create_openai_client();
let (
search_in_index_description,
search_in_index_q_param_description,
search_in_index_index_description,
) = get_prompt_descriptions(&index_scheduler);
let mut response;
loop {
setup_search_tool(
&mut chat_completion,
&search_in_index_description,
&search_in_index_q_param_description,
&search_in_index_index_description,
);
response = client.chat().create(chat_completion.clone()).await.unwrap();
let choice = &mut response.choices[0];
match choice.finish_reason {
Some(FinishReason::ToolCalls) => {
let tool_calls = mem::take(&mut choice.message.tool_calls).unwrap_or_default();
let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
tool_calls.into_iter().partition(|call| call.function.name == "searchInIndex");
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
for call in meili_calls {
let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap();
let (_, text) =
process_search_request(&index_scheduler, &search_queue, index_uid, q)
.await?;
chat_completion.messages.push(ChatCompletionRequestMessage::Tool( chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage { ChatCompletionRequestToolMessage {
tool_call_id: call.id, tool_call_id: call.id,
@ -245,63 +290,22 @@ async fn streamed_chat(
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest, mut chat_completion: CreateChatCompletionRequest,
) -> impl Responder { ) -> impl Responder {
let api_key = std::env::var("MEILI_OPENAI_API_KEY") let (
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); search_in_index_description,
search_in_index_q_param_description,
search_in_index_index_description,
) = get_prompt_descriptions(&index_scheduler);
let rtxn = index_scheduler.read_txn().unwrap(); setup_search_tool(
let search_in_index_description = index_scheduler &mut chat_completion,
.chat_prompts(&rtxn, "searchInIndex-description") &search_in_index_description,
.unwrap() &search_in_index_q_param_description,
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION) &search_in_index_index_description,
.to_string();
let search_in_index_q_param_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-q-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION)
.to_string();
let search_in_index_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-index-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION)
.to_string();
drop(rtxn);
let tools = chat_completion.tools.get_or_insert_default();
tools.push(
ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function)
.function(
FunctionObjectArgs::default()
.name("searchInIndex")
.description(&search_in_index_description)
.parameters(json!({
"type": "object",
"properties": {
"index_uid": {
"type": "string",
"enum": ["main"],
"description": search_in_index_index_description,
},
"q": {
"type": ["string", "null"],
"description": search_in_index_q_param_description,
}
},
"required": ["index_uid", "q"],
"additionalProperties": false,
}))
.strict(true)
.build()
.unwrap(),
)
.build()
.unwrap(),
); );
let (tx, rx) = tokio::sync::mpsc::channel(10); let (tx, rx) = tokio::sync::mpsc::channel(10);
let _join_handle = Handle::current().spawn(async move { let _join_handle = Handle::current().spawn(async move {
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let client = create_openai_client();
let client = Client::with_config(config);
let mut global_tool_calls = HashMap::<u32, Call>::new(); let mut global_tool_calls = HashMap::<u32, Call>::new();
'main: loop { 'main: loop {
@ -313,6 +317,8 @@ async fn streamed_chat(
let delta = &resp.choices[0].delta; let delta = &resp.choices[0].delta;
let ChatCompletionStreamResponseDelta { let ChatCompletionStreamResponseDelta {
content, content,
// Using deprecated field but keeping for compatibility
#[allow(deprecated)]
function_call: _, function_call: _,
ref tool_calls, ref tool_calls,
role: _, role: _,
@ -352,7 +358,7 @@ async fn streamed_chat(
None if !global_tool_calls.is_empty() => { None if !global_tool_calls.is_empty() => {
// dbg!(&global_tool_calls); // dbg!(&global_tool_calls);
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls) mem::take(&mut global_tool_calls)
.into_iter() .into_iter()
.map(|(_, call)| ChatCompletionMessageToolCall { .map(|(_, call)| ChatCompletionMessageToolCall {
@ -387,67 +393,23 @@ async fn streamed_chat(
let SearchInIndexParameters { index_uid, q } = let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap(); serde_json::from_str(&call.function.arguments).unwrap();
let mut query = SearchQuery { let result = process_search_request(
&index_scheduler,
&search_queue,
index_uid,
q, q,
hybrid: Some(HybridQuery {
semantic_ratio: SemanticRatio::default(),
embedder: EMBEDDER_NAME.to_string(),
}),
limit: 20,
..Default::default()
};
// Tenant token search_rules.
if let Some(search_rules) =
index_scheduler.filters().get_index_search_rules(&index_uid)
{
add_search_rules(&mut query.filter, search_rules);
}
// TBD
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
let index = index_scheduler.index(&index_uid).unwrap();
let search_kind = search_kind(
&query,
index_scheduler.get_ref(),
index_uid.to_string(),
&index,
) )
.unwrap();
let permit =
search_queue.try_get_search_permit().await.unwrap();
let features = index_scheduler.features();
let index_cloned = index.clone();
let search_result = tokio::task::spawn_blocking(move || {
perform_search(
index_uid.to_string(),
&index_cloned,
query,
search_kind,
RetrieveVectors::new(false),
features,
)
})
.await; .await;
permit.drop().await;
let search_result = search_result.unwrap(); // Handle potential errors more explicitly
if let Ok(ref search_result) = search_result { if let Err(err) = &result {
// aggregate.succeed(search_result); // Log the error or handle it as needed
if search_result.degraded { eprintln!("Error processing search request: {:?}", err);
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); continue;
} }
}
// analytics.publish(aggregate, &req);
let search_result = search_result.unwrap(); let (_, text) = result.unwrap();
let formatted = format_documents(
&index,
search_result.hits.into_iter().map(|doc| doc.document),
);
let text = formatted.join("\n");
let tool = ChatCompletionRequestMessage::Tool( let tool = ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage { ChatCompletionRequestToolMessage {
tool_call_id: call.id, tool_call_id: call.id,
@ -515,7 +477,7 @@ fn format_documents(index: &Index, documents: impl Iterator<Item = Document>) ->
let EmbeddingConfig { let EmbeddingConfig {
embedder_options: _, embedder_options: _,
prompt: PromptData { template, max_bytes }, prompt: PromptData { template, max_bytes: _ },
quantized: _, quantized: _,
} = config; } = config;