Correctly support document templates on the chat API

This commit is contained in:
Clément Renault 2025-05-21 15:32:34 +02:00 committed by Kerollmops
parent 39ecea5e9e
commit 97d74bd2b9
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
8 changed files with 72 additions and 52 deletions

1
Cargo.lock generated
View File

@ -3592,6 +3592,7 @@ dependencies = [
"brotli 6.0.0", "brotli 6.0.0",
"bstr", "bstr",
"build-info", "build-info",
"bumpalo",
"byte-unit", "byte-unit",
"bytes", "bytes",
"cargo_toml", "cargo_toml",

View File

@ -32,6 +32,7 @@ async-trait = "0.1.85"
bstr = "1.11.3" bstr = "1.11.3"
byte-unit = { version = "5.1.6", features = ["serde"] } byte-unit = { version = "5.1.6", features = ["serde"] }
bytes = "1.9.0" bytes = "1.9.0"
bumpalo = "3.16.0"
clap = { version = "4.5.24", features = ["derive", "env"] } clap = { version = "4.5.24", features = ["derive", "env"] }
crossbeam-channel = "0.5.15" crossbeam-channel = "0.5.15"
deserr = { version = "0.6.3", features = ["actix-web"] } deserr = { version = "0.6.3", features = ["actix-web"] }

View File

@ -1,5 +1,7 @@
use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::mem; use std::mem;
use std::sync::RwLock;
use std::time::Duration; use std::time::Duration;
use actix_web::web::{self, Data}; use actix_web::web::{self, Data};
@ -16,27 +18,33 @@ use async_openai::types::{
FunctionObjectArgs, FunctionObjectArgs,
}; };
use async_openai::Client; use async_openai::Client;
use bumpalo::Bump;
use futures::StreamExt; use futures::StreamExt;
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use meilisearch_auth::AuthController; use meilisearch_auth::AuthController;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::heed::RoTxn;
use meilisearch_types::keys::actions; use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::IndexEmbeddingConfig; use meilisearch_types::milli::index::ChatConfig;
use meilisearch_types::milli::prompt::PromptData; use meilisearch_types::milli::prompt::{Prompt, PromptData};
use meilisearch_types::milli::vector::EmbeddingConfig; use meilisearch_types::milli::update::new::document::DocumentFromDb;
use meilisearch_types::{Document, Index}; use meilisearch_types::milli::{
use serde::{Deserialize, Serialize}; DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
};
use meilisearch_types::Index;
use serde::Deserialize;
use serde_json::json; use serde_json::json;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use super::settings::chat::{ChatPrompts, GlobalChatSettings}; use super::settings::chat::{ChatPrompts, GlobalChatSettings};
use crate::error::MeilisearchHttpError;
use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _}; use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::routes::indexes::search::search_kind; use crate::routes::indexes::search::search_kind;
use crate::search::{ use crate::search::{
add_search_rules, perform_search, HybridQuery, RetrieveVectors, SearchQuery, SemanticRatio, add_search_rules, prepare_search, search_from_kind, HybridQuery, SearchQuery, SemanticRatio,
}; };
use crate::search_queue::SearchQueue; use crate::search_queue::SearchQueue;
@ -175,15 +183,22 @@ async fn process_search_request(
let permit = search_queue.try_get_search_permit().await?; let permit = search_queue.try_get_search_permit().await?;
let features = index_scheduler.features(); let features = index_scheduler.features();
let index_cloned = index.clone(); let index_cloned = index.clone();
let search_result = tokio::task::spawn_blocking(move || { let search_result = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> {
perform_search( let rtxn = index_cloned.read_txn()?;
index_uid.to_string(), let time_budget = match index_cloned
&index_cloned, .search_cutoff(&rtxn)
query, .map_err(|e| MeilisearchHttpError::from_milli(e, Some(index_uid.clone())))?
search_kind, {
RetrieveVectors::new(false), Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)),
features, None => TimeBudget::default(),
) };
let (search, _is_finite_pagination, _max_total_hits, _offset) =
prepare_search(&index_cloned, &rtxn, &query, &search_kind, time_budget, features)?;
search_from_kind(index_uid, search_kind, search)
.map(|(search_results, _)| search_results)
.map_err(ResponseError::from)
}) })
.await; .await;
permit.drop().await; permit.drop().await;
@ -198,9 +213,11 @@ async fn process_search_request(
// analytics.publish(aggregate, &req); // analytics.publish(aggregate, &req);
let search_result = search_result?; let search_result = search_result?;
let formatted = let rtxn = index.read_txn()?;
format_documents(&index, search_result.hits.into_iter().map(|doc| doc.document)); let render_alloc = Bump::new();
let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?;
let text = formatted.join("\n"); let text = formatted.join("\n");
drop(rtxn);
Ok((index, text)) Ok((index, text))
} }
@ -506,31 +523,36 @@ struct SearchInIndexParameters {
q: Option<String>, q: Option<String>,
} }
fn format_documents(index: &Index, documents: impl Iterator<Item = Document>) -> Vec<String> { fn format_documents<'t, 'doc>(
let rtxn = index.read_txn().unwrap(); rtxn: &RoTxn<'t>,
let IndexEmbeddingConfig { name: _, config, user_provided: _ } = index index: &Index,
.embedding_configs(&rtxn) doc_alloc: &'doc Bump,
.unwrap() internal_docids: Vec<DocumentId>,
) -> Result<Vec<&'doc str>, ResponseError> {
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
let prompt = Prompt::new(template, max_bytes).unwrap();
let fid_map = index.fields_ids_map(rtxn)?;
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
let global = RwLock::new(fid_map_with_meta);
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
let external_ids: Vec<String> = index
.external_id_of(rtxn, internal_docids.iter().copied())?
.into_iter() .into_iter()
.find(|conf| conf.name == EMBEDDER_NAME) .collect::<Result<_, _>>()?;
.unwrap();
let EmbeddingConfig { let mut renders = Vec::new();
embedder_options: _, for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
prompt: PromptData { template, max_bytes: _ }, let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
quantized: _, Some(doc) => doc,
} = config; None => continue,
};
#[derive(Serialize)] let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap();
struct Doc<T: Serialize> { renders.push(text);
doc: T,
} }
let template = liquid::ParserBuilder::with_stdlib().build().unwrap().parse(&template).unwrap(); Ok(renders)
documents
.map(|doc| {
let object = liquid::to_object(&Doc { doc }).unwrap();
template.render(&object).unwrap()
})
.collect()
} }

View File

@ -1,5 +1,3 @@
use std::collections::BTreeMap;
use actix_web::web::{self, Data}; use actix_web::web::{self, Data};
use actix_web::HttpResponse; use actix_web::HttpResponse;
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
@ -51,7 +49,6 @@ pub struct GlobalChatSettings {
pub base_api: Option<String>, pub base_api: Option<String>,
pub api_key: Option<String>, pub api_key: Option<String>,
pub prompts: ChatPrompts, pub prompts: ChatPrompts,
pub indexes: BTreeMap<String, ChatIndexSettings>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -105,7 +102,6 @@ impl Default for GlobalChatSettings {
.to_string(), .to_string(),
pre_query: "".to_string(), pre_query: "".to_string(),
}, },
indexes: BTreeMap::new(),
} }
} }
} }

View File

@ -882,7 +882,7 @@ pub fn add_search_rules(filter: &mut Option<Value>, rules: IndexSearchRules) {
} }
} }
fn prepare_search<'t>( pub fn prepare_search<'t>(
index: &'t Index, index: &'t Index,
rtxn: &'t RoTxn, rtxn: &'t RoTxn,
query: &'t SearchQuery, query: &'t SearchQuery,

View File

@ -32,13 +32,13 @@ impl ExternalDocumentsIds {
&self, &self,
rtxn: &RoTxn<'_>, rtxn: &RoTxn<'_>,
external_id: A, external_id: A,
) -> heed::Result<Option<u32>> { ) -> heed::Result<Option<DocumentId>> {
self.0.get(rtxn, external_id.as_ref()) self.0.get(rtxn, external_id.as_ref())
} }
/// An helper function to debug this type, returns an `HashMap` of both, /// An helper function to debug this type, returns an `HashMap` of both,
/// soft and hard fst maps, combined. /// soft and hard fst maps, combined.
pub fn to_hash_map(&self, rtxn: &RoTxn<'_>) -> heed::Result<HashMap<String, u32>> { pub fn to_hash_map(&self, rtxn: &RoTxn<'_>) -> heed::Result<HashMap<String, DocumentId>> {
let mut map = HashMap::default(); let mut map = HashMap::default();
for result in self.0.iter(rtxn)? { for result in self.0.iter(rtxn)? {
let (external, internal) = result?; let (external, internal) = result?;

View File

@ -7,6 +7,7 @@ use crate::FieldId;
mod global; mod global;
pub mod metadata; pub mod metadata;
pub use global::GlobalFieldsIdsMap; pub use global::GlobalFieldsIdsMap;
pub use metadata::{FieldIdMapWithMetadata, MetadataBuilder};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldsIdsMap { pub struct FieldsIdsMap {

View File

@ -52,18 +52,19 @@ pub use search::new::{
}; };
use serde_json::Value; use serde_json::Value;
pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
pub use {charabia as tokenizer, heed, rhai}; pub use {arroy, charabia as tokenizer, heed, rhai};
pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError}; pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError};
pub use self::attribute_patterns::AttributePatterns; pub use self::attribute_patterns::{AttributePatterns, PatternMatch};
pub use self::attribute_patterns::PatternMatch;
pub use self::criterion::{default_criteria, Criterion, CriterionError}; pub use self::criterion::{default_criteria, Criterion, CriterionError};
pub use self::error::{ pub use self::error::{
Error, FieldIdMapMissingEntry, InternalError, SerializationError, UserError, Error, FieldIdMapMissingEntry, InternalError, SerializationError, UserError,
}; };
pub use self::external_documents_ids::ExternalDocumentsIds; pub use self::external_documents_ids::ExternalDocumentsIds;
pub use self::fieldids_weights_map::FieldidsWeightsMap; pub use self::fieldids_weights_map::FieldidsWeightsMap;
pub use self::fields_ids_map::{FieldsIdsMap, GlobalFieldsIdsMap}; pub use self::fields_ids_map::{
FieldIdMapWithMetadata, FieldsIdsMap, GlobalFieldsIdsMap, MetadataBuilder,
};
pub use self::filterable_attributes_rules::{ pub use self::filterable_attributes_rules::{
FilterFeatures, FilterableAttributesFeatures, FilterableAttributesPatterns, FilterFeatures, FilterableAttributesFeatures, FilterableAttributesPatterns,
FilterableAttributesRule, FilterableAttributesRule,
@ -84,8 +85,6 @@ pub use self::search::{
}; };
pub use self::update::ChannelCongestion; pub use self::update::ChannelCongestion;
pub use arroy;
pub type Result<T> = std::result::Result<T, error::Error>; pub type Result<T> = std::result::Result<T, error::Error>;
pub type Attribute = u32; pub type Attribute = u32;