From 20525376815e6c0d6962744701f954bf9f03f7cb Mon Sep 17 00:00:00 2001 From: Mubelotix Date: Mon, 7 Jul 2025 15:28:35 +0200 Subject: [PATCH] Implement core filter logic --- crates/milli/src/index.rs | 2 +- crates/milli/src/search/facet/filter.rs | 15 ++- .../milli/src/search/facet/filter_vector.rs | 123 ++++++++++++++++++ crates/milli/src/search/facet/mod.rs | 1 + .../src/update/index_documents/transform.rs | 2 +- 5 files changed, 138 insertions(+), 5 deletions(-) create mode 100644 crates/milli/src/search/facet/filter_vector.rs diff --git a/crates/milli/src/index.rs b/crates/milli/src/index.rs index b2ec992ba..2751498bf 100644 --- a/crates/milli/src/index.rs +++ b/crates/milli/src/index.rs @@ -1776,7 +1776,7 @@ impl Index { embedder_info.embedder_id, config.config.quantized(), ); - let embeddings = reader.item_vectors(rtxn, docid)?; + let embeddings = reader.item_vectors(rtxn, docid)?; // MARKER res.insert( config.name.to_owned(), (embeddings, embedder_info.embedding_status.must_regenerate(docid)), diff --git a/crates/milli/src/search/facet/filter.rs b/crates/milli/src/search/facet/filter.rs index c3eba8031..f80d1681f 100644 --- a/crates/milli/src/search/facet/filter.rs +++ b/crates/milli/src/search/facet/filter.rs @@ -10,7 +10,7 @@ use memchr::memmem::Finder; use roaring::{MultiOps, RoaringBitmap}; use serde_json::Value; -use super::facet_range_search; +use super::{facet_range_search, filter_vector::VectorFilter}; use crate::constants::RESERVED_GEO_FIELD_NAME; use crate::error::{Error, UserError}; use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features}; @@ -234,8 +234,11 @@ impl<'a> Filter<'a> { pub fn evaluate(&self, rtxn: &heed::RoTxn<'_>, index: &Index) -> Result { // to avoid doing this for each recursive call we're going to do it ONCE ahead of time let fields_ids_map = index.fields_ids_map(rtxn)?; - let filterable_attributes_rules = index.filterable_attributes_rules(rtxn)?; + let filterable_attributes_rules = dbg!(index.filterable_attributes_rules(rtxn)?); + for fid in self.condition.fids(MAX_FILTER_DEPTH) { + println!("{fid:?}"); + let attribute = fid.value(); if matching_features(attribute, &filterable_attributes_rules) .is_some_and(|(_, features)| features.is_filterable()) @@ -542,7 +545,13 @@ impl<'a> Filter<'a> { .union() } FilterCondition::Condition { fid, op } => { - let Some(field_id) = field_ids_map.id(fid.value()) else { + let value = fid.value(); + if VectorFilter::matches(value, op) { + let vector_filter = VectorFilter::parse(value)?; + return vector_filter.evaluate(rtxn, index, universe); + } + + let Some(field_id) = field_ids_map.id(value) else { return Ok(RoaringBitmap::new()); }; let Some((rule_index, features)) = diff --git a/crates/milli/src/search/facet/filter_vector.rs b/crates/milli/src/search/facet/filter_vector.rs new file mode 100644 index 000000000..701ab561c --- /dev/null +++ b/crates/milli/src/search/facet/filter_vector.rs @@ -0,0 +1,123 @@ +use filter_parser::Condition; +use roaring::RoaringBitmap; + +use crate::error::{Error, UserError}; +use crate::vector::{ArroyStats, ArroyWrapper}; +use crate::{Index, Result}; + +pub(super) struct VectorFilter<'a> { + embedder_name: &'a str, + fragment_name: Option<&'a str>, + user_provided: bool, + // TODO: not_user_provided: bool, +} + +impl<'a> VectorFilter<'a> { + pub(super) fn matches(value: &str, op: &Condition) -> bool { + matches!(op, Condition::Exists) && value.starts_with("_vectors.") + } + + /// Parses a vector filter string. + /// + /// Valid formats: + /// - `_vectors.{embedder_name}` + /// - `_vectors.{embedder_name}.userProvided` + /// - `_vectors.{embedder_name}.fragments.{fragment_name}` + /// - `_vectors.{embedder_name}.fragments.{fragment_name}.userProvided` + pub(super) fn parse(s: &'a str) -> Result { + let mut split = s.split('.').peekable(); + + if split.next() != Some("_vectors") { + return Err(Error::UserError(UserError::InvalidFilter(String::from( + "Vector filter must start with '_vectors'", + )))); + } + + let embedder_name = split.next().ok_or_else(|| { + Error::UserError(UserError::InvalidFilter(String::from( + "Vector filter must contain an embedder name", + ))) + })?; + + let mut fragment_name = None; + if split.peek() == Some(&"fragments") { + split.next(); + + fragment_name = Some(split.next().ok_or_else(|| { + Error::UserError(UserError::InvalidFilter( + String::from("Vector filter is inconsistent: either specify a fragment name or remove the 'fragments' part"), + )) + })?); + } + + let mut user_provided = false; + if split.peek() == Some(&"userProvided") || split.peek() == Some(&"user_provided") { + split.next(); + user_provided = true; + } + + if let Some(next) = split.next() { + return Err(Error::UserError(UserError::InvalidFilter(format!( + "Unexpected part in vector filter: '{next}'" + )))); + } + + Ok(Self { embedder_name, fragment_name, user_provided }) + } + + pub(super) fn evaluate( + &self, + rtxn: &heed::RoTxn<'_>, + index: &Index, + universe: Option<&RoaringBitmap>, + ) -> Result { + let index_embedding_configs = index.embedding_configs(); + let embedding_configs = index_embedding_configs.embedding_configs(rtxn)?; + + let Some(embedder_config) = + embedding_configs.iter().find(|config| config.name == self.embedder_name) + else { + return Ok(RoaringBitmap::new()); + }; + let Some(embedder_info) = + index_embedding_configs.embedder_info(rtxn, self.embedder_name)? + else { + return Ok(RoaringBitmap::new()); + }; + + let arroy_wrapper = ArroyWrapper::new( + index.vector_arroy, + embedder_info.embedder_id, + embedder_config.config.quantized(), + ); + + let mut docids = if let Some(fragment_name) = self.fragment_name { + let Some(fragment_config) = embedder_config + .fragments + .as_slice() + .iter() + .find(|fragment| fragment.name == fragment_name) + else { + return Ok(RoaringBitmap::new()); + }; + + arroy_wrapper.items_in_store(rtxn, fragment_config.id, |bitmap| bitmap.clone())? + } else { + let mut stats = ArroyStats::default(); + arroy_wrapper.aggregate_stats(rtxn, &mut stats)?; + stats.documents + }; + + // FIXME: performance + if self.user_provided { + let user_provided_docsids = embedder_info.embedding_status.user_provided_docids(); + docids &= user_provided_docsids; + } + + if let Some(universe) = universe { + docids &= universe; + } + + Ok(docids) + } +} diff --git a/crates/milli/src/search/facet/mod.rs b/crates/milli/src/search/facet/mod.rs index a5e65c95d..fac85df59 100644 --- a/crates/milli/src/search/facet/mod.rs +++ b/crates/milli/src/search/facet/mod.rs @@ -17,6 +17,7 @@ mod facet_range_search; mod facet_sort_ascending; mod facet_sort_descending; mod filter; +mod filter_vector; mod search; fn facet_extreme_value<'t>( diff --git a/crates/milli/src/update/index_documents/transform.rs b/crates/milli/src/update/index_documents/transform.rs index e07483aff..d69768d4b 100644 --- a/crates/milli/src/update/index_documents/transform.rs +++ b/crates/milli/src/update/index_documents/transform.rs @@ -966,7 +966,7 @@ impl<'a, 'i> Transform<'a, 'i> { // some user provided, remove only the ids that are not user provided let to_delete = arroy.items_in_store(wtxn, *fragment_id, |items| { items - infos.embedding_status.user_provided_docids() - })?; + })?; // MARKER for to_delete in to_delete { arroy.del_item_in_store(wtxn, to_delete, *fragment_id, dimensions)?;