mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-07-27 00:31:02 +00:00
321 lines
9.4 KiB
Rust
321 lines
9.4 KiB
Rust
mod context;
|
|
mod document;
|
|
pub(crate) mod error;
|
|
mod fields;
|
|
|
|
use std::cell::RefCell;
|
|
use std::convert::TryFrom;
|
|
use std::fmt::Debug;
|
|
use std::num::NonZeroUsize;
|
|
|
|
use bumpalo::Bump;
|
|
pub(crate) use document::{Document, ParseableDocument};
|
|
use error::{NewPromptError, RenderPromptError};
|
|
pub use fields::{BorrowedFields, OwnedFields};
|
|
use heed::RoTxn;
|
|
use liquid::model::Value as LiquidValue;
|
|
use liquid::ValueView;
|
|
|
|
pub use self::context::Context;
|
|
use crate::fields_ids_map::metadata::FieldIdMapWithMetadata;
|
|
use crate::prompt::document::JsonDocument;
|
|
use crate::update::del_add::DelAdd;
|
|
use crate::update::new::document::DocumentFromDb;
|
|
use crate::{GlobalFieldsIdsMap, Index, MetadataBuilder};
|
|
|
|
pub struct Prompt {
|
|
template: liquid::Template,
|
|
template_text: String,
|
|
max_bytes: Option<NonZeroUsize>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct PromptData {
|
|
pub template: String,
|
|
pub max_bytes: Option<NonZeroUsize>,
|
|
}
|
|
|
|
impl From<Prompt> for PromptData {
|
|
fn from(value: Prompt) -> Self {
|
|
Self { template: value.template_text, max_bytes: value.max_bytes }
|
|
}
|
|
}
|
|
|
|
impl TryFrom<PromptData> for Prompt {
|
|
type Error = NewPromptError;
|
|
|
|
fn try_from(value: PromptData) -> Result<Self, Self::Error> {
|
|
Prompt::new(value.template, value.max_bytes)
|
|
}
|
|
}
|
|
|
|
impl Clone for Prompt {
|
|
fn clone(&self) -> Self {
|
|
let template_text = self.template_text.clone();
|
|
Self {
|
|
template: new_template(&template_text).unwrap(),
|
|
template_text,
|
|
max_bytes: self.max_bytes,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Parses `text` into a Liquid template using a parser built with
/// `ParserBuilder::with_stdlib()`.
///
/// Returns the parser's error when `text` is not a valid template.
fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> {
    let parser = liquid::ParserBuilder::with_stdlib().build().unwrap();
    parser.parse(text)
}
|
|
|
|
fn default_template() -> liquid::Template {
|
|
new_template(default_template_text()).unwrap()
|
|
}
|
|
|
|
/// Returns the source text of the default template.
///
/// The template loops over `fields` and, for every field that is searchable
/// and whose value is not `nil`, emits one `name: value` line.
pub fn default_template_text() -> &'static str {
    // Each trailing backslash joins the next line and skips its leading
    // whitespace, so the template is a single line apart from the explicit `\n`.
    "{% for field in fields %}\
    {% if field.is_searchable and field.value != nil %}\
    {{ field.name }}: {{ field.value }}\n\
    {% endif %}\
    {% endfor %}"
}
|
|
|
|
/// Returns the default byte limit for rendered prompts (400 bytes).
pub fn default_max_bytes() -> NonZeroUsize {
    // Evaluated at compile time, so a zero default can never reach runtime.
    const DEFAULT_MAX_BYTES: NonZeroUsize = match NonZeroUsize::new(400) {
        Some(max_bytes) => max_bytes,
        None => panic!("the default max bytes must be non-zero"),
    };
    DEFAULT_MAX_BYTES
}
|
|
|
|
impl Default for Prompt {
|
|
fn default() -> Self {
|
|
Self {
|
|
template: default_template(),
|
|
template_text: default_template_text().into(),
|
|
max_bytes: Some(default_max_bytes()),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for PromptData {
|
|
fn default() -> Self {
|
|
Self { template: default_template_text().into(), max_bytes: Some(default_max_bytes()) }
|
|
}
|
|
}
|
|
|
|
impl Prompt {
|
|
pub fn new(template: String, max_bytes: Option<NonZeroUsize>) -> Result<Self, NewPromptError> {
|
|
let this = Self {
|
|
template: liquid::ParserBuilder::with_stdlib()
|
|
.build()
|
|
.unwrap()
|
|
.parse(&template)
|
|
.map_err(NewPromptError::cannot_parse_template)?,
|
|
template_text: template,
|
|
max_bytes,
|
|
};
|
|
|
|
Ok(this)
|
|
}
|
|
|
|
pub fn render_document<
|
|
'a, // lifetime of the borrow of the document
|
|
'doc, // lifetime of the allocator, will live for an entire chunk of documents
|
|
>(
|
|
&self,
|
|
external_docid: &str,
|
|
document: impl crate::update::new::document::Document<'a> + Debug,
|
|
field_id_map: &RefCell<GlobalFieldsIdsMap>,
|
|
doc_alloc: &'doc Bump,
|
|
) -> Result<&'doc str, RenderPromptError> {
|
|
let document = ParseableDocument::new(document, doc_alloc);
|
|
let fields = BorrowedFields::new(&document, field_id_map, doc_alloc);
|
|
let context = Context::new(&document, &fields);
|
|
let mut rendered = bumpalo::collections::Vec::with_capacity_in(
|
|
self.max_bytes.unwrap_or_else(default_max_bytes).get(),
|
|
doc_alloc,
|
|
);
|
|
self.template.render_to(&mut rendered, &context).map_err(|liquid_error| {
|
|
RenderPromptError::missing_context_with_external_docid(
|
|
external_docid.to_owned(),
|
|
liquid_error,
|
|
)
|
|
})?;
|
|
Ok(std::str::from_utf8(rendered.into_bump_slice())
|
|
.expect("render can only write UTF-8 because all inputs and processing preserve utf-8"))
|
|
}
|
|
|
|
pub fn render_kvdeladd(
|
|
&self,
|
|
document: &obkv::KvReaderU16,
|
|
side: DelAdd,
|
|
field_id_map: &FieldIdMapWithMetadata,
|
|
) -> Result<String, RenderPromptError> {
|
|
let document = Document::new(document, side, field_id_map.as_fields_ids_map());
|
|
let fields = OwnedFields::new(&document, field_id_map);
|
|
let context = Context::new(&document, &fields);
|
|
|
|
let mut rendered =
|
|
self.template.render(&context).map_err(RenderPromptError::missing_context)?;
|
|
if let Some(max_bytes) = self.max_bytes {
|
|
truncate(&mut rendered, max_bytes.get());
|
|
}
|
|
Ok(rendered)
|
|
}
|
|
}
|
|
|
|
/// Shortens `s` in place so that it is at most `max_bytes` bytes long, without
/// splitting a UTF-8 character.
///
/// When `max_bytes` falls inside a multi-byte character, the cut moves back to
/// the start of that character, so the result may be shorter than `max_bytes`.
/// Strings that already fit are left untouched.
fn truncate(s: &mut String, max_bytes: usize) {
    if s.len() <= max_bytes {
        return;
    }

    // Walk back from `max_bytes` to the nearest character boundary. Index 0 is
    // always a boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes).rev().find(|&index| s.is_char_boundary(index)).unwrap_or(0);
    s.truncate(cut);
}
|
|
|
|
pub fn get_inline_document_fields(
|
|
index: &Index,
|
|
rtxn: &RoTxn<'_>,
|
|
inline_doc: &serde_json::Value,
|
|
) -> Result<LiquidValue, crate::Error> {
|
|
let fid_map_with_meta = index.fields_ids_map_with_metadata(rtxn)?;
|
|
let inline_doc = JsonDocument::new(inline_doc);
|
|
let fields = OwnedFields::new(&inline_doc, &fid_map_with_meta);
|
|
|
|
Ok(fields.to_value())
|
|
}
|
|
|
|
pub fn get_document(
|
|
index: &Index,
|
|
rtxn: &RoTxn<'_>,
|
|
external_id: &str,
|
|
with_fields: bool,
|
|
) -> Result<Option<(LiquidValue, Option<LiquidValue>)>, crate::Error> {
|
|
let Some(internal_id) = index.external_documents_ids().get(rtxn, external_id)? else {
|
|
return Ok(None);
|
|
};
|
|
|
|
let fid_map = index.fields_ids_map(rtxn)?;
|
|
let Some(document_from_db) = DocumentFromDb::new(internal_id, rtxn, index, &fid_map)? else {
|
|
return Ok(None);
|
|
};
|
|
|
|
let doc_alloc = Bump::new();
|
|
let parseable_document = ParseableDocument::new(document_from_db, &doc_alloc);
|
|
|
|
if with_fields {
|
|
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
|
|
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
|
|
let fields = OwnedFields::new(&parseable_document, &fid_map_with_meta);
|
|
|
|
Ok(Some((parseable_document.to_value(), Some(fields.to_value()))))
|
|
} else {
|
|
Ok(Some((parseable_document.to_value(), None)))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use super::Prompt;
    use crate::error::FaultSource;
    use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
    use crate::prompt::truncate;

    #[test]
    fn default_template() {
        // does not panic
        Prompt::default();
    }

    #[test]
    fn empty_template() {
        Prompt::new("".into(), None).unwrap();
    }

    #[test]
    fn template_ok() {
        Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None).unwrap();
    }

    #[test]
    fn template_syntax() {
        assert!(matches!(
            Prompt::new("{{doc.title: {{doc.overview}}".into(), None),
            Err(NewPromptError {
                kind: NewPromptErrorKind::CannotParseTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    #[ignore] // See <https://github.com/meilisearch/meilisearch/pull/5593> for explanation
    fn template_missing_doc() {
        assert!(matches!(
            Prompt::new("{{title}}: {{overview}}".into(), None),
            Err(NewPromptError {
                kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    fn template_nested_doc() {
        Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None).unwrap();
    }

    #[test]
    fn template_fields() {
        Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None).unwrap();
    }

    #[test]
    fn template_fields_ok() {
        Prompt::new(
            "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(),
            None,
        )
        .unwrap();
    }

    #[test]
    #[ignore] // See <https://github.com/meilisearch/meilisearch/pull/5593> for explanation
    fn template_fields_invalid() {
        assert!(matches!(
            // intentionally garbled field
            Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None),
            Err(NewPromptError {
                kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    // Truncation must never split a multi-byte character: a limit that falls
    // inside a character moves the cut back to the start of that character.
    #[test]
    fn template_truncation() {
        let mut s = "インテル ザー ビーグル".to_string();

        truncate(&mut s, 42);
        assert_eq!(s, "インテル ザー ビーグル");

        assert_eq!(s.len(), 32);
        truncate(&mut s, 32);
        assert_eq!(s, "インテル ザー ビーグル");

        truncate(&mut s, 31);
        assert_eq!(s, "インテル ザー ビーグ");
        truncate(&mut s, 30);
        assert_eq!(s, "インテル ザー ビーグ");
        truncate(&mut s, 28);
        assert_eq!(s, "インテル ザー ビー");
        truncate(&mut s, 26);
        assert_eq!(s, "インテル ザー ビー");
        truncate(&mut s, 25);
        assert_eq!(s, "インテル ザー ビ");

        assert_eq!("イ".len(), 3);
        truncate(&mut s, 3);
        assert_eq!(s, "イ");
        truncate(&mut s, 2);
        assert_eq!(s, "");
    }
}
|