diff --git a/crates/meilisearch/src/routes/indexes/documents.rs b/crates/meilisearch/src/routes/indexes/documents.rs index 99ca2b7df..d91f43d21 100644 --- a/crates/meilisearch/src/routes/indexes/documents.rs +++ b/crates/meilisearch/src/routes/indexes/documents.rs @@ -1,6 +1,7 @@ use std::collections::HashSet; use std::io::{ErrorKind, Seek as _}; use std::marker::PhantomData; +use std::str::FromStr; use actix_web::http::header::CONTENT_TYPE; use actix_web::web::Data; @@ -18,12 +19,9 @@ use meilisearch_types::error::{Code, ResponseError}; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::facet::facet_sort_recursive::recursive_facet_sort; -use meilisearch_types::milli::facet::{ascending_facet_sort, descending_facet_sort}; -use meilisearch_types::milli::heed_codec::facet::FacetGroupKeyCodec; -use meilisearch_types::milli::heed_codec::BytesRefCodec; use meilisearch_types::milli::update::IndexDocumentsMethod; use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors; -use meilisearch_types::milli::DocumentId; +use meilisearch_types::milli::{AscDesc, DocumentId}; use meilisearch_types::serde_cs::vec::CS; use meilisearch_types::star_or::OptionStarOrList; use meilisearch_types::tasks::KindWithContent; @@ -46,6 +44,7 @@ use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::extractors::payload::Payload; use crate::extractors::sequential_extractor::SeqHandler; +use crate::routes::indexes::search::fix_sort_query_parameters; use crate::routes::{ get_task_id, is_dry_run, PaginationView, SummarizedTaskView, PAGINATION_DEFAULT_LIMIT, }; @@ -410,6 +409,8 @@ pub struct BrowseQueryGet { #[param(default, value_type = Option, example = "popularity > 1000")] #[deserr(default, error = DeserrQueryParamError)] filter: Option, + #[deserr(default, error = DeserrQueryParamError)] + sort: Option, // TODO: change deser error } #[derive(Debug, Deserr, ToSchema)] @@ -434,6 +435,9 @@ pub struct BrowseQuery { #[schema(default, value_type = Option, example = "popularity > 1000")] #[deserr(default, error = DeserrJsonError)] filter: Option, + #[schema(default, value_type = Option>, example = json!(["title:asc", "rating:desc"]))] + #[deserr(default, error = DeserrJsonError)] // TODO: Change error + pub sort: Option>, } /// Get documents with POST @@ -575,7 +579,7 @@ pub async fn get_documents( ) -> Result { debug!(parameters = ?params, "Get documents GET"); - let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids } = + let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter, ids, sort } = params.into_inner(); let filter = match filter { @@ -586,15 +590,14 @@ pub async fn get_documents( None => None, }; - let ids = ids.map(|ids| ids.into_iter().map(Into::into).collect()); - let query = BrowseQuery { offset: offset.0, limit: limit.0, fields: fields.merge_star_and_none(), retrieve_vectors: retrieve_vectors.0, filter, - ids, + ids: ids.map(|ids| ids.into_iter().map(Into::into).collect()), + sort: sort.map(|attr| fix_sort_query_parameters(&attr)), }; analytics.publish( @@ -619,7 +622,7 @@ fn documents_by_query( query: BrowseQuery, ) -> Result { let index_uid = IndexUid::try_from(index_uid.into_inner())?; - let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids } = query; + let BrowseQuery { offset, limit, fields, retrieve_vectors, filter, ids, sort } = query; let retrieve_vectors = RetrieveVectors::new(retrieve_vectors); @@ -637,6 +640,22 @@ fn documents_by_query( None }; + let sort_criteria = if let Some(sort) = &sort { + let sorts: Vec<_> = + match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() { + Ok(sorts) => sorts, + Err(asc_desc_error) => { + return Err(milli::Error::from(milli::SortError::from( + asc_desc_error, + )) + .into()) + } + }; + Some(sorts) + } else { + None + }; + let index = index_scheduler.index(&index_uid)?; let (total, documents) = retrieve_documents( &index, @@ -647,6 +666,7 @@ fn documents_by_query( fields, retrieve_vectors, index_scheduler.features(), + sort_criteria, )?; let ret = PaginationView::new(offset, limit, total as usize, documents); @@ -1505,6 +1525,7 @@ fn retrieve_documents>( attributes_to_retrieve: Option>, retrieve_vectors: RetrieveVectors, features: RoFeatures, + sort_criteria: Option>, ) -> Result<(u64, Vec), ResponseError> { let rtxn = index.read_txn()?; let filter = &filter; @@ -1537,14 +1558,9 @@ fn retrieve_documents>( })? } - let fields = vec![(0, true)]; - let number_db = index - .facet_id_f64_docids - .remap_key_type::>(); - let string_db = index - .facet_id_string_docids - .remap_key_type::>(); - candidates = recursive_facet_sort(&rtxn, number_db, string_db, &fields, candidates)?; + if let Some(sort) = sort_criteria { + candidates = recursive_facet_sort(index, &rtxn, &sort, candidates)?; + } let (it, number_of_documents) = { let number_of_documents = candidates.len(); diff --git a/crates/milli/src/facet/facet_sort_recursive.rs b/crates/milli/src/facet/facet_sort_recursive.rs index a6bbad906..c0fd6ca6f 100644 --- a/crates/milli/src/facet/facet_sort_recursive.rs +++ b/crates/milli/src/facet/facet_sort_recursive.rs @@ -1,8 +1,8 @@ use roaring::RoaringBitmap; use heed::Database; -use crate::{facet::{ascending_facet_sort, descending_facet_sort}, heed_codec::{facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, BytesRefCodec}}; +use crate::{heed_codec::{facet::{FacetGroupKeyCodec, FacetGroupValueCodec}, BytesRefCodec}, search::{facet::{ascending_facet_sort, descending_facet_sort}, new::check_sort_criteria}, AscDesc, Member}; -pub fn recursive_facet_sort<'t>( +fn recursive_facet_sort_inner<'t>( rtxn: &'t heed::RoTxn<'t>, number_db: Database, FacetGroupValueCodec>, string_db: Database, FacetGroupValueCodec>, @@ -53,7 +53,7 @@ pub fn recursive_facet_sort<'t>( if inner_candidates.len() <= 1 || fields.len() <= 1 { result |= inner_candidates; } else { - let inner_candidates = recursive_facet_sort( + let inner_candidates = recursive_facet_sort_inner( rtxn, number_db, string_db, @@ -66,3 +66,36 @@ pub fn recursive_facet_sort<'t>( Ok(result) } + +pub fn recursive_facet_sort<'t>( + index: &crate::Index, + rtxn: &'t heed::RoTxn<'t>, + sort: &[AscDesc], + candidates: RoaringBitmap, +) -> crate::Result { + check_sort_criteria(index, rtxn, Some(sort))?; + + let mut fields = Vec::new(); + let fields_ids_map = index.fields_ids_map(rtxn)?; + for sort in sort { + let (field_id, ascending) = match sort { + AscDesc::Asc(Member::Field(field)) => (fields_ids_map.id(field), true), + AscDesc::Desc(Member::Field(field)) => (fields_ids_map.id(field), false), + AscDesc::Asc(Member::Geo(_)) => todo!(), + AscDesc::Desc(Member::Geo(_)) => todo!(), + }; + if let Some(field_id) = field_id { + fields.push((field_id, ascending)); // FIXME: Should this return an error if the field is not found? + } + } + + let number_db = index + .facet_id_f64_docids + .remap_key_type::>(); + let string_db = index + .facet_id_string_docids + .remap_key_type::>(); + + let candidates = recursive_facet_sort_inner(rtxn, number_db, string_db, &fields, candidates)?; + Ok(candidates) +} diff --git a/crates/milli/src/search/new/mod.rs b/crates/milli/src/search/new/mod.rs index a65b4076b..5cb4c9fd5 100644 --- a/crates/milli/src/search/new/mod.rs +++ b/crates/milli/src/search/new/mod.rs @@ -638,7 +638,7 @@ pub fn execute_vector_search( time_budget: TimeBudget, ranking_score_threshold: Option, ) -> Result { - check_sort_criteria(ctx, sort_criteria.as_ref())?; + check_sort_criteria(ctx.index, ctx.txn, sort_criteria.as_deref())?; // FIXME: input universe = universe & documents_with_vectors // for now if we're computing embeddings for ALL documents, we can assume that this is just universe @@ -702,7 +702,7 @@ pub fn execute_search( ranking_score_threshold: Option, locales: Option<&Vec>, ) -> Result { - check_sort_criteria(ctx, sort_criteria.as_ref())?; + check_sort_criteria(ctx.index, ctx.txn, sort_criteria.as_deref())?; let mut used_negative_operator = false; let mut located_query_terms = None; @@ -872,9 +872,10 @@ pub fn execute_search( }) } -fn check_sort_criteria( - ctx: &SearchContext<'_>, - sort_criteria: Option<&Vec>, +pub(crate) fn check_sort_criteria( + index: &Index, + rtxn: &RoTxn<'_>, + sort_criteria: Option<&[AscDesc]>, ) -> Result<()> { let sort_criteria = if let Some(sort_criteria) = sort_criteria { sort_criteria @@ -888,19 +889,19 @@ fn check_sort_criteria( // We check that the sort ranking rule exists and throw an // error if we try to use it and that it doesn't. - let sort_ranking_rule_missing = !ctx.index.criteria(ctx.txn)?.contains(&crate::Criterion::Sort); + let sort_ranking_rule_missing = !index.criteria(rtxn)?.contains(&crate::Criterion::Sort); if sort_ranking_rule_missing { return Err(UserError::SortRankingRuleMissing.into()); } // We check that we are allowed to use the sort criteria, we check // that they are declared in the sortable fields. - let sortable_fields = ctx.index.sortable_fields(ctx.txn)?; + let sortable_fields = index.sortable_fields(rtxn)?; for asc_desc in sort_criteria { match asc_desc.member() { Member::Field(ref field) if !crate::is_faceted(field, &sortable_fields) => { let (valid_fields, hidden_fields) = - ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?; + index.remove_hidden_fields(rtxn, sortable_fields)?; return Err(UserError::InvalidSortableAttribute { field: field.to_string(), @@ -911,7 +912,7 @@ fn check_sort_criteria( } Member::Geo(_) if !sortable_fields.contains(RESERVED_GEO_FIELD_NAME) => { let (valid_fields, hidden_fields) = - ctx.index.remove_hidden_fields(ctx.txn, sortable_fields)?; + index.remove_hidden_fields(rtxn, sortable_fields)?; return Err(UserError::InvalidSortableAttribute { field: RESERVED_GEO_FIELD_NAME.to_string(),