mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-26 05:26:27 +00:00 
			
		
		
		
	Refactor the paths_of_cost algorithm
Support conditions that require certain nodes to be skipped
This commit is contained in:
		| @@ -9,141 +9,202 @@ use crate::search::new::query_graph::QueryNode; | ||||
| use crate::search::new::small_bitmap::SmallBitmap; | ||||
| use crate::Result; | ||||
|  | ||||
| impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> { | ||||
|     pub fn visit_paths_of_cost( | ||||
|         &mut self, | ||||
|         from: Interned<QueryNode>, | ||||
|         cost: u16, | ||||
|         all_distances: &MappedInterner<QueryNode, Vec<u16>>, | ||||
|         dead_ends_cache: &mut DeadEndsCache<G::Condition>, | ||||
|         mut visit: impl FnMut( | ||||
|             &[Interned<G::Condition>], | ||||
|             &mut Self, | ||||
|             &mut DeadEndsCache<G::Condition>, | ||||
|         ) -> Result<ControlFlow<()>>, | ||||
|     ) -> Result<()> { | ||||
|         let _ = self.visit_paths_of_cost_rec( | ||||
|             from, | ||||
|             cost, | ||||
|             all_distances, | ||||
|             dead_ends_cache, | ||||
|             &mut visit, | ||||
|             &mut vec![], | ||||
|             &mut SmallBitmap::for_interned_values_in(&self.conditions_interner), | ||||
|             dead_ends_cache.forbidden.clone(), | ||||
|         )?; | ||||
| type VisitFn<'f, G> = &'f mut dyn FnMut( | ||||
|     &[Interned<<G as RankingRuleGraphTrait>::Condition>], | ||||
|     &mut RankingRuleGraph<G>, | ||||
|     &mut DeadEndsCache<<G as RankingRuleGraphTrait>::Condition>, | ||||
| ) -> Result<ControlFlow<()>>; | ||||
|  | ||||
| struct VisitorContext<'a, G: RankingRuleGraphTrait> { | ||||
|     graph: &'a mut RankingRuleGraph<G>, | ||||
|     all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>, | ||||
|     dead_ends_cache: &'a mut DeadEndsCache<G::Condition>, | ||||
| } | ||||
|  | ||||
| struct VisitorState<G: RankingRuleGraphTrait> { | ||||
|     remaining_cost: u64, | ||||
|  | ||||
|     path: Vec<Interned<G::Condition>>, | ||||
|  | ||||
|     visited_conditions: SmallBitmap<G::Condition>, | ||||
|     visited_nodes: SmallBitmap<QueryNode>, | ||||
|  | ||||
|     forbidden_conditions: SmallBitmap<G::Condition>, | ||||
|     forbidden_conditions_to_nodes: SmallBitmap<QueryNode>, | ||||
| } | ||||
|  | ||||
| pub struct PathVisitor<'a, G: RankingRuleGraphTrait> { | ||||
|     state: VisitorState<G>, | ||||
|     ctx: VisitorContext<'a, G>, | ||||
| } | ||||
| impl<'a, G: RankingRuleGraphTrait> PathVisitor<'a, G> { | ||||
|     pub fn new( | ||||
|         cost: u64, | ||||
|         graph: &'a mut RankingRuleGraph<G>, | ||||
|         all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>, | ||||
|         dead_ends_cache: &'a mut DeadEndsCache<G::Condition>, | ||||
|     ) -> Self { | ||||
|         Self { | ||||
|             state: VisitorState { | ||||
|                 remaining_cost: cost, | ||||
|                 path: vec![], | ||||
|                 visited_conditions: SmallBitmap::for_interned_values_in(&graph.conditions_interner), | ||||
|                 visited_nodes: SmallBitmap::for_interned_values_in(&graph.query_graph.nodes), | ||||
|                 forbidden_conditions: SmallBitmap::for_interned_values_in( | ||||
|                     &graph.conditions_interner, | ||||
|                 ), | ||||
|                 forbidden_conditions_to_nodes: SmallBitmap::for_interned_values_in( | ||||
|                     &graph.query_graph.nodes, | ||||
|                 ), | ||||
|             }, | ||||
|             ctx: VisitorContext { graph, all_costs_from_node, dead_ends_cache }, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn visit_paths(mut self, visit: VisitFn<G>) -> Result<()> { | ||||
|         let _ = | ||||
|             self.state.visit_node(self.ctx.graph.query_graph.root_node, visit, &mut self.ctx)?; | ||||
|         Ok(()) | ||||
|     } | ||||
|     pub fn visit_paths_of_cost_rec( | ||||
| } | ||||
|  | ||||
| impl<G: RankingRuleGraphTrait> VisitorState<G> { | ||||
|     fn visit_node( | ||||
|         &mut self, | ||||
|         from: Interned<QueryNode>, | ||||
|         cost: u16, | ||||
|         all_distances: &MappedInterner<QueryNode, Vec<u16>>, | ||||
|         dead_ends_cache: &mut DeadEndsCache<G::Condition>, | ||||
|         visit: &mut impl FnMut( | ||||
|             &[Interned<G::Condition>], | ||||
|             &mut Self, | ||||
|             &mut DeadEndsCache<G::Condition>, | ||||
|         ) -> Result<ControlFlow<()>>, | ||||
|         prev_conditions: &mut Vec<Interned<G::Condition>>, | ||||
|         cur_path: &mut SmallBitmap<G::Condition>, | ||||
|         mut forbidden_conditions: SmallBitmap<G::Condition>, | ||||
|     ) -> Result<bool> { | ||||
|         from_node: Interned<QueryNode>, | ||||
|         visit: VisitFn<G>, | ||||
|         ctx: &mut VisitorContext<G>, | ||||
|     ) -> Result<ControlFlow<(), bool>> { | ||||
|         let mut any_valid = false; | ||||
|  | ||||
|         let edges = self.edges_of_node.get(from).clone(); | ||||
|         'edges_loop: for edge_idx in edges.iter() { | ||||
|             let Some(edge) = self.edges_store.get(edge_idx).as_ref() else { continue }; | ||||
|             if cost < edge.cost as u16 { | ||||
|         let edges = ctx.graph.edges_of_node.get(from_node).clone(); | ||||
|         for edge_idx in edges.iter() { | ||||
|             let Some(edge) = ctx.graph.edges_store.get(edge_idx).clone() else { continue }; | ||||
|  | ||||
|             if self.remaining_cost < edge.cost as u64 { | ||||
|                 continue; | ||||
|             } | ||||
|             let next_any_valid = match edge.condition { | ||||
|                 None => { | ||||
|                     if edge.dest_node == self.query_graph.end_node { | ||||
|                         any_valid = true; | ||||
|                         let control_flow = visit(prev_conditions, self, dead_ends_cache)?; | ||||
|                         match control_flow { | ||||
|                             ControlFlow::Continue(_) => {} | ||||
|                             ControlFlow::Break(_) => return Ok(true), | ||||
|                         } | ||||
|                         true | ||||
|                     } else { | ||||
|                         self.visit_paths_of_cost_rec( | ||||
|                             edge.dest_node, | ||||
|                             cost - edge.cost as u16, | ||||
|                             all_distances, | ||||
|                             dead_ends_cache, | ||||
|                             visit, | ||||
|                             prev_conditions, | ||||
|                             cur_path, | ||||
|                             forbidden_conditions.clone(), | ||||
|                         )? | ||||
|                     } | ||||
|                 } | ||||
|                 Some(condition) => { | ||||
|                     if forbidden_conditions.contains(condition) | ||||
|                         || all_distances | ||||
|                             .get(edge.dest_node) | ||||
|                             .iter() | ||||
|                             .all(|next_cost| *next_cost != cost - edge.cost as u16) | ||||
|                     { | ||||
|                         continue; | ||||
|                     } | ||||
|                     cur_path.insert(condition); | ||||
|                     prev_conditions.push(condition); | ||||
|                     let mut new_forbidden_conditions = forbidden_conditions.clone(); | ||||
|                     if let Some(next_forbidden) = | ||||
|                         dead_ends_cache.forbidden_conditions_after_prefix(prev_conditions) | ||||
|                     { | ||||
|                         new_forbidden_conditions.union(&next_forbidden); | ||||
|                     } | ||||
|  | ||||
|                     let next_any_valid = if edge.dest_node == self.query_graph.end_node { | ||||
|                         any_valid = true; | ||||
|                         let control_flow = visit(prev_conditions, self, dead_ends_cache)?; | ||||
|                         match control_flow { | ||||
|                             ControlFlow::Continue(_) => {} | ||||
|                             ControlFlow::Break(_) => return Ok(true), | ||||
|                         } | ||||
|                         true | ||||
|                     } else { | ||||
|                         self.visit_paths_of_cost_rec( | ||||
|                             edge.dest_node, | ||||
|                             cost - edge.cost as u16, | ||||
|                             all_distances, | ||||
|                             dead_ends_cache, | ||||
|                             visit, | ||||
|                             prev_conditions, | ||||
|                             cur_path, | ||||
|                             new_forbidden_conditions, | ||||
|                         )? | ||||
|                     }; | ||||
|                     cur_path.remove(condition); | ||||
|                     prev_conditions.pop(); | ||||
|                     next_any_valid | ||||
|                 } | ||||
|             self.remaining_cost -= edge.cost as u64; | ||||
|             let cf = match edge.condition { | ||||
|                 Some(condition) => self.visit_condition( | ||||
|                     condition, | ||||
|                     edge.dest_node, | ||||
|                     &edge.nodes_to_skip, | ||||
|                     visit, | ||||
|                     ctx, | ||||
|                 )?, | ||||
|                 None => self.visit_no_condition(edge.dest_node, &edge.nodes_to_skip, visit, ctx)?, | ||||
|             }; | ||||
|             any_valid |= next_any_valid; | ||||
|             self.remaining_cost += edge.cost as u64; | ||||
|  | ||||
|             let ControlFlow::Continue(next_any_valid) = cf else { | ||||
|                 return Ok(ControlFlow::Break(())); | ||||
|             }; | ||||
|             if next_any_valid { | ||||
|                 forbidden_conditions = | ||||
|                     dead_ends_cache.forbidden_conditions_for_all_prefixes_up_to(prev_conditions); | ||||
|                 if cur_path.intersects(&forbidden_conditions) { | ||||
|                     break 'edges_loop; | ||||
|                 self.forbidden_conditions = ctx | ||||
|                     .dead_ends_cache | ||||
|                     .forbidden_conditions_for_all_prefixes_up_to(self.path.iter().copied()); | ||||
|                 if self.visited_conditions.intersects(&self.forbidden_conditions) { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             any_valid |= next_any_valid; | ||||
|         } | ||||
|  | ||||
|         Ok(any_valid) | ||||
|         Ok(ControlFlow::Continue(any_valid)) | ||||
|     } | ||||
|  | ||||
|     pub fn initialize_distances_with_necessary_edges(&self) -> MappedInterner<QueryNode, Vec<u16>> { | ||||
|         let mut distances_to_end = self.query_graph.nodes.map(|_| vec![]); | ||||
|     fn visit_no_condition( | ||||
|         &mut self, | ||||
|         dest_node: Interned<QueryNode>, | ||||
|         edge_forbidden_nodes: &SmallBitmap<QueryNode>, | ||||
|         visit: VisitFn<G>, | ||||
|         ctx: &mut VisitorContext<G>, | ||||
|     ) -> Result<ControlFlow<(), bool>> { | ||||
|         if ctx | ||||
|             .all_costs_from_node | ||||
|             .get(dest_node) | ||||
|             .iter() | ||||
|             .all(|next_cost| *next_cost != self.remaining_cost) | ||||
|         { | ||||
|             return Ok(ControlFlow::Continue(false)); | ||||
|         } | ||||
|         if dest_node == ctx.graph.query_graph.end_node { | ||||
|             let control_flow = visit(&self.path, ctx.graph, ctx.dead_ends_cache)?; | ||||
|             match control_flow { | ||||
|                 ControlFlow::Continue(_) => Ok(ControlFlow::Continue(true)), | ||||
|                 ControlFlow::Break(_) => Ok(ControlFlow::Break(())), | ||||
|             } | ||||
|         } else { | ||||
|             let old_fbct = self.forbidden_conditions_to_nodes.clone(); | ||||
|             self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes); | ||||
|             let cf = self.visit_node(dest_node, visit, ctx)?; | ||||
|             self.forbidden_conditions_to_nodes = old_fbct; | ||||
|             Ok(cf) | ||||
|         } | ||||
|     } | ||||
|     fn visit_condition( | ||||
|         &mut self, | ||||
|         condition: Interned<G::Condition>, | ||||
|         dest_node: Interned<QueryNode>, | ||||
|         edge_forbidden_nodes: &SmallBitmap<QueryNode>, | ||||
|         visit: VisitFn<G>, | ||||
|         ctx: &mut VisitorContext<G>, | ||||
|     ) -> Result<ControlFlow<(), bool>> { | ||||
|         assert!(dest_node != ctx.graph.query_graph.end_node); | ||||
|  | ||||
|         if self.forbidden_conditions_to_nodes.contains(dest_node) | ||||
|             || edge_forbidden_nodes.intersects(&self.visited_nodes) | ||||
|         { | ||||
|             return Ok(ControlFlow::Continue(false)); | ||||
|         } | ||||
|         if self.forbidden_conditions.contains(condition) { | ||||
|             return Ok(ControlFlow::Continue(false)); | ||||
|         } | ||||
|  | ||||
|         if ctx | ||||
|             .all_costs_from_node | ||||
|             .get(dest_node) | ||||
|             .iter() | ||||
|             .all(|next_cost| *next_cost != self.remaining_cost) | ||||
|         { | ||||
|             return Ok(ControlFlow::Continue(false)); | ||||
|         } | ||||
|  | ||||
|         self.path.push(condition); | ||||
|         self.visited_nodes.insert(dest_node); | ||||
|         self.visited_conditions.insert(condition); | ||||
|  | ||||
|         let old_fc = self.forbidden_conditions.clone(); | ||||
|         if let Some(next_forbidden) = | ||||
|             ctx.dead_ends_cache.forbidden_conditions_after_prefix(self.path.iter().copied()) | ||||
|         { | ||||
|             self.forbidden_conditions.union(&next_forbidden); | ||||
|         } | ||||
|         let old_fctn = self.forbidden_conditions_to_nodes.clone(); | ||||
|         self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes); | ||||
|  | ||||
|         let cf = self.visit_node(dest_node, visit, ctx)?; | ||||
|  | ||||
|         self.forbidden_conditions_to_nodes = old_fctn; | ||||
|         self.forbidden_conditions = old_fc; | ||||
|  | ||||
|         self.visited_conditions.remove(condition); | ||||
|         self.visited_nodes.remove(dest_node); | ||||
|         self.path.pop(); | ||||
|  | ||||
|         Ok(cf) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> { | ||||
|     pub fn find_all_costs_to_end(&self) -> MappedInterner<QueryNode, Vec<u64>> { | ||||
|         let mut costs_to_end = self.query_graph.nodes.map(|_| vec![]); | ||||
|         let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len()); | ||||
|  | ||||
|         let mut node_stack = VecDeque::new(); | ||||
|  | ||||
|         *distances_to_end.get_mut(self.query_graph.end_node) = vec![0]; | ||||
|         *costs_to_end.get_mut(self.query_graph.end_node) = vec![0]; | ||||
|  | ||||
|         for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() { | ||||
|             node_stack.push_back(prev_node); | ||||
| @@ -151,22 +212,22 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> { | ||||
|         } | ||||
|  | ||||
|         while let Some(cur_node) = node_stack.pop_front() { | ||||
|             let mut self_distances = BTreeSet::<u16>::new(); | ||||
|             let mut self_costs = BTreeSet::<u64>::new(); | ||||
|  | ||||
|             let cur_node_edges = &self.edges_of_node.get(cur_node); | ||||
|             for edge_idx in cur_node_edges.iter() { | ||||
|                 let edge = self.edges_store.get(edge_idx).as_ref().unwrap(); | ||||
|                 let succ_node = edge.dest_node; | ||||
|                 let succ_distances = distances_to_end.get(succ_node); | ||||
|                 for succ_distance in succ_distances { | ||||
|                     self_distances.insert(edge.cost as u16 + succ_distance); | ||||
|                 let succ_costs = costs_to_end.get(succ_node); | ||||
|                 for succ_distance in succ_costs { | ||||
|                     self_costs.insert(edge.cost as u64 + succ_distance); | ||||
|                 } | ||||
|             } | ||||
|             let distances_to_end_cur_node = distances_to_end.get_mut(cur_node); | ||||
|             for cost in self_distances.iter() { | ||||
|                 distances_to_end_cur_node.push(*cost); | ||||
|             let costs_to_end_cur_node = costs_to_end.get_mut(cur_node); | ||||
|             for cost in self_costs.iter() { | ||||
|                 costs_to_end_cur_node.push(*cost); | ||||
|             } | ||||
|             *distances_to_end.get_mut(cur_node) = self_distances.into_iter().collect(); | ||||
|             *costs_to_end.get_mut(cur_node) = self_costs.into_iter().collect(); | ||||
|             for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() { | ||||
|                 if !enqueued.contains(prev_node) { | ||||
|                     node_stack.push_back(prev_node); | ||||
| @@ -174,6 +235,6 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> { | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         distances_to_end | ||||
|         costs_to_end | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -36,12 +36,12 @@ impl<T> DeadEndsCache<T> { | ||||
|     } | ||||
|     pub fn forbidden_conditions_for_all_prefixes_up_to( | ||||
|         &mut self, | ||||
|         prefix: &[Interned<T>], | ||||
|         prefix: impl Iterator<Item = Interned<T>>, | ||||
|     ) -> SmallBitmap<T> { | ||||
|         let mut forbidden = self.forbidden.clone(); | ||||
|         let mut cursor = self; | ||||
|         for c in prefix.iter() { | ||||
|             if let Some(next) = cursor.advance(*c) { | ||||
|         for c in prefix { | ||||
|             if let Some(next) = cursor.advance(c) { | ||||
|                 cursor = next; | ||||
|                 forbidden.union(&cursor.forbidden); | ||||
|             } else { | ||||
| @@ -52,11 +52,11 @@ impl<T> DeadEndsCache<T> { | ||||
|     } | ||||
|     pub fn forbidden_conditions_after_prefix( | ||||
|         &mut self, | ||||
|         prefix: &[Interned<T>], | ||||
|         prefix: impl Iterator<Item = Interned<T>>, | ||||
|     ) -> Option<SmallBitmap<T>> { | ||||
|         let mut cursor = self; | ||||
|         for c in prefix.iter() { | ||||
|             if let Some(next) = cursor.advance(*c) { | ||||
|         for c in prefix { | ||||
|             if let Some(next) = cursor.advance(c) { | ||||
|                 cursor = next; | ||||
|             } else { | ||||
|                 return None; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user