Take sort criteria from the request

Mubelotix
2025-06-25 16:41:08 +02:00
parent 6e0526090a
commit b05cb80803
3 changed files with 79 additions and 29 deletions

View File

@@ -1,6 +1,7 @@
 use std::collections::HashSet;
 use std::io::{ErrorKind, Seek as _};
 use std::marker::PhantomData;
+use std::str::FromStr;
 use actix_web::http::header::CONTENT_TYPE;
 use actix_web::web::Data;
@@ -18,12 +19,9 @@ use meilisearch_types::error::{Code, ResponseError};
 use meilisearch_types::heed::RoTxn;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli::facet::facet_sort_recursive::recursive_facet_sort;
-use meilisearch_types::milli::facet::{ascending_facet_sort, descending_facet_sort};
-use meilisearch_types::milli::heed_codec::facet::FacetGroupKeyCodec;
-use meilisearch_types::milli::heed_codec::BytesRefCodec;
 use meilisearch_types::milli::update::IndexDocumentsMethod;
 use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors;
-use meilisearch_types::milli::DocumentId;
+use meilisearch_types::milli::{AscDesc, DocumentId};
 use meilisearch_types::serde_cs::vec::CS;
 use meilisearch_types::star_or::OptionStarOrList;
 use meilisearch_types::tasks::KindWithContent;
@@ -46,6 +44,7 @@ use crate::extractors::authentication::policies::*;
 use crate::extractors::authentication::GuardedData;
 use crate::extractors::payload::Payload;
 use crate::extractors::sequential_extractor::SeqHandler;
+use crate::routes::indexes::search::fix_sort_query_parameters;
 use crate::routes::{
     get_task_id, is_dry_run, PaginationView, SummarizedTaskView, PAGINATION_DEFAULT_LIMIT,
 };
@@ -410,6 +409,8 @@ pub struct BrowseQueryGet {
     #[param(default, value_type = Option<String>, example = "popularity > 1000")]
     #[deserr(default, error = DeserrQueryParamError<InvalidDocumentFilter>)]
     filter: Option<String>,
+    #[deserr(default, error = DeserrQueryParamError<InvalidSearchSort>)]
+    sort: Option<String>, // TODO: change deser error
 }
 
 #[derive(Debug, Deserr, ToSchema)]
@@ -434,6 +435,9 @@ pub struct BrowseQuery {
     #[schema(default, value_type = Option<Value>, example = "popularity > 1000")]
     #[deserr(default, error = DeserrJsonError<InvalidDocumentFilter>)]
     filter: Option<Value>,
+    #[schema(default, value_type = Option<Vec<String>>, example = json!(["title:asc", "rating:desc"]))]
+    #[deserr(default, error = DeserrJsonError<InvalidSearchSort>)] // TODO: Change error
+    pub sort: Option<Vec<String>>,
 }
 
 /// Get documents with POST
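For context, the two routes now accept sort in different shapes: the GET route takes a single comma-separated string, while the POST body takes an array of expressions. A rough sketch of equivalent inputs (values are illustrative, and the POST path is assumed to be the existing `/documents/fetch` route):

```rust
// GET  /indexes/{uid}/documents?sort=title:asc,rating:desc
let get_sort: Option<String> = Some("title:asc,rating:desc".to_string());

// POST /indexes/{uid}/documents/fetch with body { "sort": ["title:asc", "rating:desc"] }
let post_sort: Option<Vec<String>> = Some(vec!["title:asc".into(), "rating:desc".into()]);
```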
@@ -575,7 +579,7 @@ pub async fn get_documents(
 ) -> Result<HttpResponse, ResponseError> {
     debug!(parameters = ?params, "Get documents GET");
 
-    let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids } =
+    let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids, sort } =
         params.into_inner();
 
     let filter = match filter {
@@ -586,15 +590,14 @@ pub async fn get_documents(
         None => None,
     };
 
-    let ids = ids.map(|ids| ids.into_iter().map(Into::into).collect());
-
     let query = BrowseQuery {
         offset: offset.0,
         limit: limit.0,
         fields: fields.merge_star_and_none(),
         retrieve_vectors: retrieve_vectors.0,
         filter,
-        ids,
+        ids: ids.map(|ids| ids.into_iter().map(Into::into).collect()),
+        sort: sort.map(|attr| fix_sort_query_parameters(&attr)),
     };
 
     analytics.publish(
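A minimal sketch of what the GET-to-POST normalization above is assumed to do: `fix_sort_query_parameters` (the helper already used by the search route) splits the comma-separated query value into one expression per element while keeping `_geoPoint(...)` arguments together:

```rust
// Assumed behavior of the existing helper, shown for illustration only.
let raw = "title:asc,_geoPoint(48.85,2.35):asc,rating:desc";
let fixed = fix_sort_query_parameters(raw);
assert_eq!(
    fixed,
    vec![
        "title:asc".to_string(),
        "_geoPoint(48.85,2.35):asc".to_string(),
        "rating:desc".to_string(),
    ]
);
```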
@@ -619,7 +622,7 @@ fn documents_by_query(
     query: BrowseQuery,
 ) -> Result<HttpResponse, ResponseError> {
     let index_uid = IndexUid::try_from(index_uid.into_inner())?;
-    let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids } = query;
+    let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids, sort } = query;
 
     let retrieve_vectors = RetrieveVectors::new(retrieve_vectors);
 
@@ -637,6 +640,22 @@ fn documents_by_query(
         None
     };
 
+    let sort_criteria = if let Some(sort) = &sort {
+        let sorts: Vec<_> =
+            match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() {
+                Ok(sorts) => sorts,
+                Err(asc_desc_error) => {
+                    return Err(milli::Error::from(milli::SortError::from(
+                        asc_desc_error,
+                    ))
+                    .into())
+                }
+            };
+        Some(sorts)
+    } else {
+        None
+    };
+
     let index = index_scheduler.index(&index_uid)?;
     let (total, documents) = retrieve_documents(
         &index,
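To make the parsing step above concrete, a small sketch assuming milli's `AscDesc`/`Member` behave as imported here: each `"field:asc"` / `"field:desc"` expression parses via `FromStr`, and a malformed expression produces the error that is wrapped into `milli::SortError` and returned to the client:

```rust
use std::str::FromStr;
use meilisearch_types::milli::{AscDesc, Member};

// "title:asc" parses into an ascending sort on the "title" field.
assert_eq!(
    AscDesc::from_str("title:asc").unwrap(),
    AscDesc::Asc(Member::Field("title".to_string()))
);
// An expression without an explicit direction is rejected.
assert!(AscDesc::from_str("title").is_err());
```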
@@ -647,6 +666,7 @@ fn documents_by_query(
         fields,
         retrieve_vectors,
         index_scheduler.features(),
+        sort_criteria,
     )?;
 
     let ret = PaginationView::new(offset, limit, total as usize, documents);
@@ -1505,6 +1525,7 @@ fn retrieve_documents<S: AsRef<str>>(
     attributes_to_retrieve: Option<Vec<S>>,
     retrieve_vectors: RetrieveVectors,
     features: RoFeatures,
+    sort_criteria: Option<Vec<AscDesc>>,
 ) -> Result<(u64, Vec<Document>), ResponseError> {
     let rtxn = index.read_txn()?;
     let filter = &filter;
@@ -1537,14 +1558,9 @@ fn retrieve_documents<S: AsRef<str>>(
         })?
     }
 
-    let fields = vec![(0, true)];
-    let number_db = index
-        .facet_id_f64_docids
-        .remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
-    let string_db = index
-        .facet_id_string_docids
-        .remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
-    candidates = recursive_facet_sort(&rtxn, number_db, string_db, &fields, candidates)?;
+    if let Some(sort) = sort_criteria {
+        candidates = recursive_facet_sort(index, &rtxn, &sort, candidates)?;
+    }
 
     let (it, number_of_documents) = {
         let number_of_documents = candidates.len();

View File

@@ -1,8 +1,8 @@
 use roaring::RoaringBitmap;
 use heed::Database;
-use crate::{facet::{ascending_facet_sort, descending_facet_sort}, heed_codec::{facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, BytesRefCodec}};
+use crate::{heed_codec::{facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, BytesRefCodec}, search::{facet::{ascending_facet_sort, descending_facet_sort}, new::check_sort_criteria}, AscDesc, Member};
 
-pub fn recursive_facet_sort<'t>(
+fn recursive_facet_sort_inner<'t>(
     rtxn: &'t heed::RoTxn<'t>,
     number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
     string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
@@ -53,7 +53,7 @@ pub fn recursive_facet_sort<'t>(
         if inner_candidates.len() <= 1 || fields.len() <= 1 {
             result |= inner_candidates;
         } else {
-            let inner_candidates = recursive_facet_sort(
+            let inner_candidates = recursive_facet_sort_inner(
                 rtxn,
                 number_db,
                 string_db,
@@ -66,3 +66,36 @@ pub fn recursive_facet_sort<'t>(
 
     Ok(result)
 }
+
+pub fn recursive_facet_sort<'t>(
+    index: &crate::Index,
+    rtxn: &'t heed::RoTxn<'t>,
+    sort: &[AscDesc],
+    candidates: RoaringBitmap,
+) -> crate::Result<RoaringBitmap> {
+    check_sort_criteria(index, rtxn, Some(sort))?;
+
+    let mut fields = Vec::new();
+    let fields_ids_map = index.fields_ids_map(rtxn)?;
+    for sort in sort {
+        let (field_id, ascending) = match sort {
+            AscDesc::Asc(Member::Field(field)) => (fields_ids_map.id(field), true),
+            AscDesc::Desc(Member::Field(field)) => (fields_ids_map.id(field), false),
+            AscDesc::Asc(Member::Geo(_)) => todo!(),
+            AscDesc::Desc(Member::Geo(_)) => todo!(),
+        };
+        if let Some(field_id) = field_id {
+            fields.push((field_id, ascending)); // FIXME: Should this return an error if the field is not found?
+        }
+    }
+
+    let number_db = index
+        .facet_id_f64_docids
+        .remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
+    let string_db = index
+        .facet_id_string_docids
+        .remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
+
+    let candidates = recursive_facet_sort_inner(rtxn, number_db, string_db, &fields, candidates)?;
+    Ok(candidates)
+}
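A hedged illustration of what the new wrapper builds before delegating to `recursive_facet_sort_inner`: each `AscDesc` is mapped to a `(field_id, ascending)` pair through the index's fields-ids map, and unknown fields are currently skipped (see the FIXME above). The field ids below are made up:

```rust
// Hypothetical fields_ids_map: "price" -> 2, "title" -> 5.
// sort = [Desc(Field("price")), Asc(Field("title"))] therefore becomes:
let fields: Vec<(u16, bool)> = vec![(2, false), (5, true)];
// These pairs drive recursive_facet_sort_inner over the f64 and string facet databases.
```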

View File

@@ -638,7 +638,7 @@ pub fn execute_vector_search(
     time_budget: TimeBudget,
     ranking_score_threshold: Option<f64>,
 ) -> Result<PartialSearchResult> {
-    check_sort_criteria(ctx, sort_criteria.as_ref())?;
+    check_sort_criteria(ctx.index, ctx.txn, sort_criteria.as_deref())?;
 
     // FIXME: input universe = universe & documents_with_vectors
     // for now if we're computing embeddings for ALL documents, we can assume that this is just universe
@@ -702,7 +702,7 @@ pub fn execute_search(
     ranking_score_threshold: Option<f64>,
     locales: Option<&Vec<Language>>,
 ) -> Result<PartialSearchResult> {
-    check_sort_criteria(ctx, sort_criteria.as_ref())?;
+    check_sort_criteria(ctx.index, ctx.txn, sort_criteria.as_deref())?;
 
     let mut used_negative_operator = false;
     let mut located_query_terms = None;
@@ -872,9 +872,10 @@ pub fn execute_search(
     })
 }
 
-fn check_sort_criteria(
-    ctx: &SearchContext<'_>,
-    sort_criteria: Option<&Vec<AscDesc>>,
+pub(crate) fn check_sort_criteria(
+    index: &Index,
+    rtxn: &RoTxn<'_>,
+    sort_criteria: Option<&[AscDesc]>,
 ) -> Result<()> {
     let sort_criteria = if let Some(sort_criteria) = sort_criteria {
         sort_criteria
@@ -888,19 +889,19 @@ fn check_sort_criteria(
 
     // We check that the sort ranking rule exists and throw an
     // error if we try to use it and that it doesn't.
-    let sort_ranking_rule_missing = !ctx.index.criteria(ctx.txn)?.contains(&crate::Criterion::Sort);
+    let sort_ranking_rule_missing = !index.criteria(rtxn)?.contains(&crate::Criterion::Sort);
     if sort_ranking_rule_missing {
         return Err(UserError::SortRankingRuleMissing.into());
     }
 
     // We check that we are allowed to use the sort criteria, we check
     // that they are declared in the sortable fields.
-    let sortable_fields = ctx.index.sortable_fields(ctx.txn)?;
+    let sortable_fields = index.sortable_fields(rtxn)?;
     for asc_desc in sort_criteria {
         match asc_desc.member() {
             Member::Field(ref field) if !crate::is_faceted(field, &sortable_fields) => {
                 let (valid_fields, hidden_fields) =
-                    ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?;
+                    index.remove_hidden_fields(rtxn, sortable_fields)?;
 
                 return Err(UserError::InvalidSortableAttribute {
                     field: field.to_string(),
@@ -911,7 +912,7 @@ fn check_sort_criteria(
             }
             Member::Geo(_) if !sortable_fields.contains(RESERVED_GEO_FIELD_NAME) => {
                 let (valid_fields, hidden_fields) =
-                    ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?;
+                    index.remove_hidden_fields(rtxn, sortable_fields)?;
 
                 return Err(UserError::InvalidSortableAttribute {
                     field: RESERVED_GEO_FIELD_NAME.to_string(),
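Because `check_sort_criteria` now takes the index and a read transaction directly instead of a `SearchContext`, it can be reused from the facet-sort wrapper above. A sketch of the two call-site shapes in this diff:

```rust
// In execute_search / execute_vector_search, where sort_criteria: Option<Vec<AscDesc>>:
check_sort_criteria(ctx.index, ctx.txn, sort_criteria.as_deref())?;

// In recursive_facet_sort, where sort: &[AscDesc]:
check_sort_criteria(index, rtxn, Some(sort))?;
```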