Implement localized attributes settings

ManyTheFish
2024-07-23 14:51:36 +02:00
committed by Louis Dureuil
parent 90c0a6db7d
commit 04fa44e7eb
18 changed files with 405 additions and 209 deletions

View File

@@ -3,7 +3,7 @@ use std::fs::File;
use std::io::BufReader;
use std::{io, mem, str};
use charabia::{Language, SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
use charabia::{SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
use obkv::{KvReader, KvWriterU16};
use roaring::RoaringBitmap;
use serde_json::Value;
@@ -11,7 +11,7 @@ use serde_json::Value;
use super::helpers::{create_sorter, keep_latest_obkv, sorter_into_reader, GrenadParameters};
use crate::error::{InternalError, SerializationError};
use crate::update::del_add::{del_add_from_two_obkvs, DelAdd, KvReaderDelAdd};
use crate::update::settings::InnerIndexSettingsDiff;
use crate::update::settings::{InnerIndexSettings, InnerIndexSettingsDiff};
use crate::{FieldId, Result, MAX_POSITION_PER_ATTRIBUTE, MAX_WORD_LENGTH};
/// Extracts the word and positions where this word appear and
@@ -57,13 +57,9 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
.map(|s| s.iter().map(String::as_str).collect());
let old_dictionary: Option<Vec<_>> =
settings_diff.old.dictionary.as_ref().map(|s| s.iter().map(String::as_str).collect());
let mut del_builder = tokenizer_builder(
old_stop_words,
old_separators.as_deref(),
old_dictionary.as_deref(),
None,
);
let del_tokenizer = del_builder.build();
let del_builder =
tokenizer_builder(old_stop_words, old_separators.as_deref(), old_dictionary.as_deref());
let del_tokenizer = del_builder.into_tokenizer();
let new_stop_words = settings_diff.new.stop_words.as_ref();
let new_separators: Option<Vec<_>> = settings_diff
@@ -73,13 +69,9 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
.map(|s| s.iter().map(String::as_str).collect());
let new_dictionary: Option<Vec<_>> =
settings_diff.new.dictionary.as_ref().map(|s| s.iter().map(String::as_str).collect());
let mut add_builder = tokenizer_builder(
new_stop_words,
new_separators.as_deref(),
new_dictionary.as_deref(),
None,
);
let add_tokenizer = add_builder.build();
let add_builder =
tokenizer_builder(new_stop_words, new_separators.as_deref(), new_dictionary.as_deref());
let add_tokenizer = add_builder.into_tokenizer();
// iterate over documents.
let mut cursor = obkv_documents.into_cursor()?;
@@ -107,7 +99,7 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
// deletions
tokens_from_document(
&obkv,
&settings_diff.old.searchable_fields_ids,
&settings_diff.old,
&del_tokenizer,
max_positions_per_attributes,
DelAdd::Deletion,
@@ -118,7 +110,7 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
// additions
tokens_from_document(
&obkv,
&settings_diff.new.searchable_fields_ids,
&settings_diff.new,
&add_tokenizer,
max_positions_per_attributes,
DelAdd::Addition,
@@ -180,7 +172,6 @@ fn tokenizer_builder<'a>(
stop_words: Option<&'a fst::Set<Vec<u8>>>,
allowed_separators: Option<&'a [&str]>,
dictionary: Option<&'a [&str]>,
languages: Option<&'a Vec<Language>>,
) -> TokenizerBuilder<'a, Vec<u8>> {
let mut tokenizer_builder = TokenizerBuilder::new();
if let Some(stop_words) = stop_words {
@@ -193,17 +184,13 @@
tokenizer_builder.separators(separators);
}
if let Some(languages) = languages {
tokenizer_builder.allow_list(languages);
}
tokenizer_builder
}
/// Extract words mapped with their positions of a document.
fn tokens_from_document<'a>(
obkv: &KvReader<'a, FieldId>,
searchable_fields: &[FieldId],
settings: &InnerIndexSettings,
tokenizer: &Tokenizer<'_>,
max_positions_per_attributes: u32,
del_add: DelAdd,
@@ -213,7 +200,7 @@ fn tokens_from_document<'a>(
let mut document_writer = KvWriterU16::new(&mut buffers.obkv_buffer);
for (field_id, field_bytes) in obkv.iter() {
// if field is searchable.
if searchable_fields.as_ref().contains(&field_id) {
if settings.searchable_fields_ids.contains(&field_id) {
// extract deletion or addition only.
if let Some(field_bytes) = KvReaderDelAdd::new(field_bytes).get(del_add) {
// parse json.
@@ -228,7 +215,8 @@ fn tokens_from_document<'a>(
buffers.field_buffer.clear();
if let Some(field) = json_to_string(&value, &mut buffers.field_buffer) {
// create an iterator of token with their positions.
let tokens = process_tokens(tokenizer.tokenize(field))
let locales = settings.localized_searchable_fields_ids.locales(field_id);
let tokens = process_tokens(tokenizer.tokenize_with_allow_list(field, locales))
.take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
for (index, token) in tokens {

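Editor's note: the hunks above drop the builder-level allow list and instead resolve locales per field at tokenization time, via `localized_searchable_fields_ids.locales(field_id)` and charabia's `tokenize_with_allow_list`. Below is a minimal sketch of that flow, assuming a charabia version that exposes `into_tokenizer` and `tokenize_with_allow_list` as used above; `LocalizedFieldIds` is a hypothetical stand-in for the settings structure, whose definition is not part of this excerpt.

use charabia::{Language, TokenizerBuilder};

// Hypothetical stand-in for the settings structure behind `localized_searchable_fields_ids`;
// only the `locales(field_id)` lookup shape is taken from the diff above.
struct LocalizedFieldIds(Vec<(u16, Vec<Language>)>);

impl LocalizedFieldIds {
    /// Returns the locales configured for a field, or `None` to keep automatic detection.
    fn locales(&self, field_id: u16) -> Option<&[Language]> {
        self.0.iter().find(|(id, _)| *id == field_id).map(|(_, locales)| locales.as_slice())
    }
}

fn main() {
    let localized = LocalizedFieldIds(vec![(0, vec![Language::Cmn])]);
    let tokenizer = TokenizerBuilder::<Vec<u8>>::new().into_tokenizer();

    // Field 0 gets its configured allow list; fields without locales fall back to detection.
    let locales = localized.locales(0);
    for token in tokenizer.tokenize_with_allow_list("人人生而自由﹐在尊嚴和權利上一律平等。", locales) {
        println!("{}", token.lemma());
    }
}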
View File

@@ -5,6 +5,7 @@ use std::iter::FromIterator;
use std::{io, str};
use charabia::normalizer::{Normalize, NormalizerOption};
use charabia::{Language, StrDetection, Token};
use heed::types::SerdeJson;
use heed::BytesEncode;
@@ -26,10 +27,9 @@ use crate::{FieldId, Result, MAX_FACET_VALUE_LENGTH};
pub fn extract_facet_string_docids<R: io::Read + io::Seek>(
docid_fid_facet_string: grenad::Reader<R>,
indexer: GrenadParameters,
_settings_diff: &InnerIndexSettingsDiff,
settings_diff: &InnerIndexSettingsDiff,
) -> Result<(grenad::Reader<BufReader<File>>, grenad::Reader<BufReader<File>>)> {
let max_memory = indexer.max_memory_by_thread();
let options = NormalizerOption { lossy: true, ..Default::default() };
let mut facet_string_docids_sorter = create_sorter(
grenad::SortAlgorithm::Stable,
@@ -54,12 +54,8 @@ pub fn extract_facet_string_docids<R: io::Read + io::Seek>(
while let Some((key, deladd_original_value_bytes)) = cursor.move_on_next()? {
let deladd_reader = KvReaderDelAdd::new(deladd_original_value_bytes);
// nothing to do if we delete and re-add the value.
if deladd_reader.get(DelAdd::Deletion).is_some()
&& deladd_reader.get(DelAdd::Addition).is_some()
{
continue;
}
let is_same_value = deladd_reader.get(DelAdd::Deletion).is_some()
&& deladd_reader.get(DelAdd::Addition).is_some();
let (field_id_bytes, bytes) = try_split_array_at(key).unwrap();
let field_id = FieldId::from_be_bytes(field_id_bytes);
@@ -72,29 +68,66 @@ pub fn extract_facet_string_docids<R: io::Read + io::Seek>(
// Facet search normalization
{
let mut hyper_normalized_value = normalized_value.normalize(&options);
let normalized_truncated_facet: String;
if hyper_normalized_value.len() > MAX_FACET_VALUE_LENGTH {
normalized_truncated_facet = hyper_normalized_value
.char_indices()
.take_while(|(idx, _)| *idx < MAX_FACET_VALUE_LENGTH)
.map(|(_, c)| c)
.collect();
hyper_normalized_value = normalized_truncated_facet.into();
}
let locales = settings_diff.old.localized_faceted_fields_ids.locales(field_id);
let old_hyper_normalized_value = normalize_facet_string(normalized_value, locales);
let locales = settings_diff.new.localized_faceted_fields_ids.locales(field_id);
let new_hyper_normalized_value = normalize_facet_string(normalized_value, locales);
let set = BTreeSet::from_iter(std::iter::once(normalized_value));
buffer.clear();
let mut obkv = KvWriterDelAdd::new(&mut buffer);
for (deladd_key, _) in deladd_reader.iter() {
let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
obkv.insert(deladd_key, val)?;
}
obkv.finish()?;
// if the facet string is the same, we can put the deletion and addition in the same obkv.
if old_hyper_normalized_value == new_hyper_normalized_value {
// nothing to do if we delete and re-add the value.
if is_same_value {
continue;
}
let key = (field_id, hyper_normalized_value.as_ref());
let key_bytes = BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
buffer.clear();
let mut obkv = KvWriterDelAdd::new(&mut buffer);
for (deladd_key, _) in deladd_reader.iter() {
let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
obkv.insert(deladd_key, val)?;
}
obkv.finish()?;
let key: (u16, &str) = (field_id, new_hyper_normalized_value.as_ref());
let key_bytes = BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
} else {
// if the facet string is different, we need to insert the deletion and addition in different obkv because the related key is different.
// deletion
if deladd_reader.get(DelAdd::Deletion).is_some() {
// insert old value
let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
buffer.clear();
let mut obkv = KvWriterDelAdd::new(&mut buffer);
obkv.insert(DelAdd::Deletion, val)?;
obkv.finish()?;
let key: (u16, &str) = (field_id, old_hyper_normalized_value.as_ref());
let key_bytes =
BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
}
// addition
if deladd_reader.get(DelAdd::Addition).is_some() {
// insert new value
let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
buffer.clear();
let mut obkv = KvWriterDelAdd::new(&mut buffer);
obkv.insert(DelAdd::Addition, val)?;
obkv.finish()?;
let key: (u16, &str) = (field_id, new_hyper_normalized_value.as_ref());
let key_bytes =
BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
}
}
}
// nothing to do if we delete and re-add the value.
if is_same_value {
continue;
}
let key = FacetGroupKey { field_id, level: 0, left_bound: normalized_value };
@@ -112,3 +145,24 @@ pub fn extract_facet_string_docids<R: io::Read + io::Seek>(
let normalized = sorter_into_reader(normalized_facet_string_docids_sorter, indexer)?;
sorter_into_reader(facet_string_docids_sorter, indexer).map(|s| (s, normalized))
}
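Editor's note: the same-key / different-key branches above exist because changing a field's locales can change the hyper-normalized key itself, so a deletion and an addition of the same facet value may land under different keys. A condensed, hypothetical sketch of just that decision follows; `Side` and `plan` are illustrative names, and the real code writes obkv payloads into the grenad sorters shown above.

#[derive(Debug, PartialEq)]
enum Side {
    Deletion,
    Addition,
}

/// Decides under which normalized facet key each side of the update is recorded,
/// condensing the same-key / different-key branches above into one function.
fn plan(old_key: &str, new_key: &str, has_del: bool, has_add: bool) -> Vec<(String, Side)> {
    // Same key on both sides: a deletion immediately followed by a re-addition cancels out.
    if old_key == new_key && has_del && has_add {
        return Vec::new();
    }
    // Otherwise the deletion targets the old key and the addition the new one;
    // when the keys differ (e.g. the locales changed), both entries may be kept.
    let mut plan = Vec::new();
    if has_del {
        plan.push((old_key.to_string(), Side::Deletion));
    }
    if has_add {
        plan.push((new_key.to_string(), Side::Addition));
    }
    plan
}

fn main() {
    // Unchanged key: delete + re-add is a no-op.
    assert!(plan("key-a", "key-a", true, true).is_empty());
    // Changed key: both sides are kept, each under its own key.
    assert_eq!(
        plan("key-old", "key-new", true, true),
        vec![("key-old".to_string(), Side::Deletion), ("key-new".to_string(), Side::Addition)]
    );
}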
/// Normalizes the facet string and truncates it to the max length.
fn normalize_facet_string(facet_string: &str, locales: Option<&[Language]>) -> String {
let options = NormalizerOption { lossy: true, ..Default::default() };
let mut detection = StrDetection::new(facet_string, locales);
let token = Token {
lemma: std::borrow::Cow::Borrowed(facet_string),
script: detection.script(),
language: detection.language(),
..Default::default()
};
// truncate the facet string to the max length
token
.normalize(&options)
.lemma
.char_indices()
.take_while(|(idx, _)| *idx < MAX_FACET_VALUE_LENGTH)
.map(|(_, c)| c)
.collect()
}

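Editor's note: a small self-contained illustration of the truncation used in `normalize_facet_string`, since `char_indices` yields byte offsets rather than character counts. `truncate_on_char_start` is an illustrative helper name, and the explicit `max_bytes` parameter stands in for `MAX_FACET_VALUE_LENGTH`.

/// Keeps every character whose starting byte offset is below `max_bytes`,
/// mirroring the `char_indices`/`take_while` pattern of `normalize_facet_string`:
/// it never splits a character, but may overshoot the byte budget by one character.
fn truncate_on_char_start(s: &str, max_bytes: usize) -> String {
    s.char_indices().take_while(|(idx, _)| *idx < max_bytes).map(|(_, c)| c).collect()
}

fn main() {
    // Each CJK character below is 3 bytes in UTF-8, so their offsets are 0, 3, 6, 9, ...
    // With a 7-byte limit, the characters starting at offsets 0, 3 and 6 are kept.
    assert_eq!(truncate_on_char_start("関西国際空港", 7), "関西国");
    // For ASCII, offsets equal character counts, so the limit keeps `max_bytes` characters.
    assert_eq!(truncate_on_char_start("airport", 3), "air");
}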
View File

@@ -3388,44 +3388,6 @@ mod tests {
wtxn.commit().unwrap();
}
#[test]
#[cfg(feature = "all-tokenizations")]
fn stored_detected_script_and_language_should_not_return_deleted_documents() {
use charabia::{Language, Script};
let index = TempIndex::new();
let mut wtxn = index.write_txn().unwrap();
index
.add_documents_using_wtxn(
&mut wtxn,
documents!([
{ "id": "0", "title": "The quick (\"brown\") fox can't jump 32.3 feet, right? Brr, it's 29.3°F!" },
{ "id": "1", "title": "人人生而自由﹐在尊嚴和權利上一律平等。他們賦有理性和良心﹐並應以兄弟關係的精神互相對待。" },
{ "id": "2", "title": "הַשּׁוּעָל הַמָּהִיר (״הַחוּם״) לֹא יָכוֹל לִקְפֹּץ 9.94 מֶטְרִים, נָכוֹן? ברר, 1.5°C- בַּחוּץ!" },
{ "id": "3", "title": "関西国際空港限定トートバッグ すもももももももものうち" },
{ "id": "4", "title": "ภาษาไทยง่ายนิดเดียว" },
{ "id": "5", "title": "The quick 在尊嚴和權利上一律平等。" },
]))
.unwrap();
let key_cmn = (Script::Cj, Language::Cmn);
let cj_cmn_docs =
index.script_language_documents_ids(&wtxn, &key_cmn).unwrap().unwrap_or_default();
let mut expected_cj_cmn_docids = RoaringBitmap::new();
expected_cj_cmn_docids.push(1);
expected_cj_cmn_docids.push(5);
assert_eq!(cj_cmn_docs, expected_cj_cmn_docids);
delete_documents(&mut wtxn, &index, &["1"]);
wtxn.commit().unwrap();
let rtxn = index.read_txn().unwrap();
let cj_cmn_docs =
index.script_language_documents_ids(&rtxn, &key_cmn).unwrap().unwrap_or_default();
let mut expected_cj_cmn_docids = RoaringBitmap::new();
expected_cj_cmn_docids.push(5);
assert_eq!(cj_cmn_docs, expected_cj_cmn_docids);
}
#[test]
fn delete_words_exact_attributes() {
let index = TempIndex::new();