Report the sources

This commit is contained in:
Clément Renault 2025-05-27 11:48:12 +02:00 committed by Kerollmops
parent f635a8f9f3
commit fa139ee601
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -32,9 +32,10 @@ use meilisearch_types::milli::prompt::{Prompt, PromptData};
use meilisearch_types::milli::update::new::document::DocumentFromDb;
use meilisearch_types::milli::update::Setting;
use meilisearch_types::milli::{
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
all_obkv_to_json, obkv_to_json, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap,
MetadataBuilder, TimeBudget,
};
use meilisearch_types::Index;
use meilisearch_types::{Document, Index};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::runtime::Handle;
@ -55,6 +56,8 @@ use crate::search_queue::SearchQueue;
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources";
const MEILI_REPORT_ERRORS_NAME: &str = "_meiliReportErrors";
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
pub fn configure(cfg: &mut web::ServiceConfig) {
@ -93,10 +96,16 @@ async fn chat(
pub struct FunctionSupport {
/// Defines if we can call the _meiliSearchProgress function
/// to inform the front-end about what we are searching for.
progress: bool,
report_progress: bool,
/// Defines if we can call the _meiliSearchSources function
/// to inform the front-end about the sources of the search.
report_sources: bool,
/// Defines if we can call the _meiliAppendConversationMessage
/// function to provide the messages to append into the conversation.
append_to_conversation: bool,
/// Defines if we can call the _meiliReportErrors function
/// to inform the front-end about potential errors.
report_errors: bool,
}
/// Setup search tool in chat completion request
@ -112,18 +121,28 @@ fn setup_search_tool(
}
// Remove internal tools used for front-end notifications as they should be hidden from the LLM.
let mut progress = false;
let mut report_progress = false;
let mut report_sources = false;
let mut append_to_conversation = false;
let mut report_errors = false;
tools.retain(|tool| {
match tool.function.name.as_str() {
MEILI_SEARCH_PROGRESS_NAME => {
progress = true;
report_progress = true;
false
}
MEILI_SEARCH_SOURCES_NAME => {
report_sources = true;
false
}
MEILI_APPEND_CONVERSATION_MESSAGE_NAME => {
append_to_conversation = true;
false
}
MEILI_REPORT_ERRORS_NAME => {
report_errors = true;
false
}
_ => true, // keep other tools
}
});
@ -188,7 +207,7 @@ fn setup_search_tool(
}),
);
Ok(FunctionSupport { progress, append_to_conversation })
Ok(FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors })
}
/// Process search request and return formatted results
@ -199,7 +218,7 @@ async fn process_search_request(
auth_token: &str,
index_uid: String,
q: Option<String>,
) -> Result<(Index, String), ResponseError> {
) -> Result<(Index, Vec<Document>, String), ResponseError> {
// TBD
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
@ -276,22 +295,33 @@ async fn process_search_request(
permit.drop().await;
let output = output?;
if let Ok((_, ref search_result)) = output {
let mut documents = Vec::new();
if let Ok((ref rtxn, ref search_result)) = output {
// aggregate.succeed(search_result);
if search_result.degraded {
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
}
let fields_ids_map = index.fields_ids_map(rtxn)?;
let displayed_fields = index.displayed_fields_ids(rtxn)?;
for &document_id in &search_result.documents_ids {
let obkv = index.document(rtxn, document_id)?;
let document = match displayed_fields {
Some(ref fields) => obkv_to_json(fields, &fields_ids_map, obkv)?,
None => all_obkv_to_json(obkv, &fields_ids_map)?,
};
documents.push(document);
}
}
// analytics.publish(aggregate, &req);
let (rtxn, search_result) = output?;
// let rtxn = index.read_txn()?;
let render_alloc = Bump::new();
let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?;
let text = formatted.join("\n");
drop(rtxn);
Ok((index, text))
Ok((index, documents, text))
}
async fn non_streamed_chat(
@ -319,7 +349,7 @@ async fn non_streamed_chat(
let auth_token = extract_token_from_request(&req)?.unwrap();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { progress, append_to_conversation } =
let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let mut response;
@ -359,7 +389,7 @@ async fn non_streamed_chat(
};
let text = match result {
Ok((_, text)) => text,
Ok((_, documents, text)) => text,
Err(err) => err,
};
@ -411,7 +441,7 @@ async fn streamed_chat(
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { progress, append_to_conversation } =
let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10);
@ -507,8 +537,9 @@ async fn streamed_chat(
);
for call in meili_calls {
if progress {
if report_progress {
let call = MeiliSearchProgress {
call_id: call.id.to_string(),
function_name: call.function.name.clone(),
function_arguments: call
.function
@ -573,7 +604,24 @@ async fn streamed_chat(
};
let text = match result {
Ok((_, text)) => text,
Ok((_index, documents, text)) => {
if report_sources {
let call = MeiliSearchSources {
call_id: call.id.to_string(),
sources: documents,
};
let resp = call.create_response(resp.clone());
// Send the event of "we are doing a search"
if let Err(SendError(_)) = tx
.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
{
return;
}
}
text
},
Err(err) => err,
};
@ -651,8 +699,10 @@ async fn streamed_chat(
}
#[derive(Debug, Clone, Serialize)]
/// Give context about what Meilisearch is doing.
/// Provides information about the current Meilisearch search operation.
struct MeiliSearchProgress {
/// The call ID to track the sources of the search.
pub call_id: String,
/// The name of the function we are executing.
pub function_name: String,
/// The arguments of the function we are executing, encoded in JSON.
@ -690,6 +740,47 @@ impl MeiliSearchProgress {
}
}
#[derive(Debug, Clone, Serialize)]
/// Provides sources of the search.
struct MeiliSearchSources {
/// The call ID to track the original search associated to those sources.
pub call_id: String,
/// The documents associated with the search (call_id).
/// Only the displayed attributes of the documents are returned.
pub sources: Vec<Document>,
}
impl MeiliSearchSources {
fn create_response(
&self,
mut resp: CreateChatCompletionStreamResponse,
) -> CreateChatCompletionStreamResponse {
let call_text = serde_json::to_string(self).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
arguments: Some(call_text),
}),
};
resp.choices[0] = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
tool_calls: Some(vec![tool_call]),
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: None,
logprobs: None,
};
resp
}
}
struct MeiliAppendConversationMessage(pub ChatCompletionRequestMessage);
impl MeiliAppendConversationMessage {