Introduce a customized A* algorithm.

This custom algo lazily compute the intersections between words, to avoid too much set operations and database reads
This commit is contained in:
Kerollmops
2020-06-14 12:51:54 +02:00
parent 69285b22d3
commit a8cda248b4
5 changed files with 262 additions and 65 deletions

View File

@ -1,7 +1,7 @@
use std::cmp;
use std::time::Instant;
use pathfinding::directed::astar::astar_bag;
use crate::iter_shortest_paths::astar_bag;
const ONE_ATTRIBUTE: u32 = 1000;
const MAX_DISTANCE: u32 = 8;
@ -37,6 +37,8 @@ enum Node {
position: u32,
// The total accumulated proximity until this node, used for skipping nodes.
acc_proximity: u32,
// The parent position from the above layer.
parent_position: u32,
},
}
@ -44,35 +46,29 @@ impl Node {
// TODO we must skip the successors that have already been seen
// TODO we must skip the successors that doesn't return any documents
// this way we are able to skip entire paths
fn successors<F>(
&self,
positions: &[Vec<u32>],
best_proximity: u32,
mut contains_documents: F,
) -> Vec<(Node, u32)>
where F: FnMut((usize, u32), (usize, u32)) -> bool,
{
fn successors(&self, positions: &[Vec<u32>], best_proximity: u32) -> Vec<(Node, u32)> {
match self {
Node::Uninit => {
positions[0].iter().map(|p| {
(Node::Init { layer: 0, position: *p, acc_proximity: 0 }, 0)
(Node::Init { layer: 0, position: *p, acc_proximity: 0, parent_position: 0 }, 0)
}).collect()
},
// We reached the highest layer
n @ Node::Init { .. } if n.is_complete(positions) => vec![],
Node::Init { layer, position, acc_proximity } => {
Node::Init { layer, position, acc_proximity, .. } => {
positions[layer + 1].iter().filter_map(|p| {
let proximity = positions_proximity(*position, *p);
let node = Node::Init { layer: layer + 1, position: *p, acc_proximity: acc_proximity + proximity };
if (contains_documents)((*layer, *position), (layer + 1, *p)) {
// We do not produce the nodes we have already seen in previous iterations loops.
if node.is_complete(positions) && acc_proximity + proximity < best_proximity {
None
} else {
Some((node, proximity))
}
} else {
let node = Node::Init {
layer: layer + 1,
position: *p,
acc_proximity: acc_proximity + proximity,
parent_position: *position,
};
// We do not produce the nodes we have already seen in previous iterations loops.
if node.is_complete(positions) && acc_proximity + proximity < best_proximity {
None
} else {
Some((node, proximity))
}
}).collect()
}
@ -92,6 +88,35 @@ impl Node {
Node::Init { position, .. } => Some(*position),
}
}
fn proximity(&self) -> u32 {
match self {
Node::Uninit => 0,
Node::Init { layer, position, acc_proximity, parent_position } => {
if layer.checked_sub(1).is_some() {
acc_proximity + positions_proximity(*position, *parent_position)
} else {
0
}
},
}
}
fn is_reachable<F>(&self, mut contains_documents: F) -> bool
where F: FnMut((usize, u32), (usize, u32)) -> bool,
{
match self {
Node::Uninit => true,
Node::Init { layer, position, parent_position, .. } => {
match layer.checked_sub(1) {
Some(parent_layer) => {
(contains_documents)((parent_layer, *parent_position), (*layer, *position))
},
None => true,
}
},
}
}
}
pub struct BestProximity<F> {
@ -102,7 +127,7 @@ pub struct BestProximity<F> {
impl<F> BestProximity<F> {
pub fn new(positions: Vec<Vec<u32>>, contains_documents: F) -> BestProximity<F> {
let best_proximity = positions.len() as u32 - 1;
let best_proximity = (positions.len() as u32).saturating_sub(1);
BestProximity { positions, best_proximity, contains_documents }
}
}
@ -121,9 +146,12 @@ where F: FnMut((usize, u32), (usize, u32)) -> bool + Copy,
let result = astar_bag(
&Node::Uninit, // start
|n| n.successors(&self.positions, self.best_proximity, self.contains_documents),
|n| n.successors(&self.positions, self.best_proximity),
|_| 0, // heuristic
|n| n.is_complete(&self.positions), // success
|n| { // success
let c = n.is_complete(&self.positions) && n.proximity() >= self.best_proximity;
if n.is_reachable(self.contains_documents) { Some(c) } else { None }
},
);
eprintln!("BestProximity::next() took {:.02?}", before.elapsed());

204
src/iter_shortest_paths.rs Normal file
View File

@ -0,0 +1,204 @@
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::hash::Hash;
use std::usize;
use indexmap::map::Entry::{Occupied, Vacant};
use indexmap::IndexMap;
pub fn astar_bag<N, FN, IN, FH, FS>(
start: &N,
mut successors: FN,
mut heuristic: FH,
mut success: FS,
) -> Option<(AstarSolution<N>, u32)>
where
N: Eq + Hash + Clone,
FN: FnMut(&N) -> IN,
IN: IntoIterator<Item = (N, u32)>,
FH: FnMut(&N) -> u32,
FS: FnMut(&N) -> Option<bool>,
{
let mut to_see = BinaryHeap::new();
let mut min_cost = None;
let mut sinks = HashSet::new();
to_see.push(SmallestCostHolder {
estimated_cost: heuristic(start),
cost: 0,
index: 0,
});
let mut parents: IndexMap<N, (HashSet<usize>, u32)> = IndexMap::new();
parents.insert(start.clone(), (HashSet::new(), 0));
while let Some(SmallestCostHolder { cost, index, estimated_cost, .. }) = to_see.pop() {
if let Some(min_cost) = min_cost {
if estimated_cost > min_cost {
break;
}
}
let successors = {
let (node, &(_, c)) = parents.get_index(index).unwrap();
// We check that the node is even reachable and if so if it is an answer.
// If this node is unreachable we skip it.
match success(node) {
Some(success) => if success {
min_cost = Some(cost);
sinks.insert(index);
},
None => continue,
}
// We may have inserted a node several time into the binary heap if we found
// a better way to access it. Ensure that we are currently dealing with the
// best path and discard the others.
if cost > c {
continue;
}
successors(node)
};
for (successor, move_cost) in successors {
let new_cost = cost + move_cost;
let h; // heuristic(&successor)
let n; // index for successor
match parents.entry(successor) {
Vacant(e) => {
h = heuristic(e.key());
n = e.index();
let mut p = HashSet::new();
p.insert(index);
e.insert((p, new_cost));
}
Occupied(mut e) => {
if e.get().1 > new_cost {
h = heuristic(e.key());
n = e.index();
let s = e.get_mut();
s.0.clear();
s.0.insert(index);
s.1 = new_cost;
} else {
if e.get().1 == new_cost {
// New parent with an identical cost, this is not
// considered as an insertion.
e.get_mut().0.insert(index);
}
continue;
}
}
}
to_see.push(SmallestCostHolder {
estimated_cost: new_cost + h,
cost: new_cost,
index: n,
});
}
}
min_cost.map(|cost| {
let parents = parents
.into_iter()
.map(|(k, (ps, _))| (k, ps.into_iter().collect()))
.collect();
(
AstarSolution {
sinks: sinks.into_iter().collect(),
parents,
current: vec![],
terminated: false,
},
cost,
)
})
}
struct SmallestCostHolder<K> {
estimated_cost: K,
cost: K,
index: usize,
}
impl<K: PartialEq> PartialEq for SmallestCostHolder<K> {
fn eq(&self, other: &Self) -> bool {
self.estimated_cost.eq(&other.estimated_cost) && self.cost.eq(&other.cost)
}
}
impl<K: PartialEq> Eq for SmallestCostHolder<K> {}
impl<K: Ord> PartialOrd for SmallestCostHolder<K> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<K: Ord> Ord for SmallestCostHolder<K> {
fn cmp(&self, other: &Self) -> Ordering {
match other.estimated_cost.cmp(&self.estimated_cost) {
Ordering::Equal => self.cost.cmp(&other.cost),
s => s,
}
}
}
/// Iterator structure created by the `astar_bag` function.
#[derive(Clone)]
pub struct AstarSolution<N> {
sinks: Vec<usize>,
parents: Vec<(N, Vec<usize>)>,
current: Vec<Vec<usize>>,
terminated: bool,
}
impl<N: Clone + Eq + Hash> AstarSolution<N> {
fn complete(&mut self) {
loop {
let ps = match self.current.last() {
None => self.sinks.clone(),
Some(last) => {
let &top = last.last().unwrap();
self.parents(top).clone()
}
};
if ps.is_empty() {
break;
}
self.current.push(ps);
}
}
fn next_vec(&mut self) {
while self.current.last().map(Vec::len) == Some(1) {
self.current.pop();
}
self.current.last_mut().map(Vec::pop);
}
fn node(&self, i: usize) -> &N {
&self.parents[i].0
}
fn parents(&self, i: usize) -> &Vec<usize> {
&self.parents[i].1
}
}
impl<N: Clone + Eq + Hash> Iterator for AstarSolution<N> {
type Item = Vec<N>;
fn next(&mut self) -> Option<Self::Item> {
if self.terminated {
return None;
}
self.complete();
let path = self
.current
.iter()
.rev()
.map(|v| v.last().cloned().unwrap())
.map(|i| self.node(i).clone())
.collect::<Vec<_>>();
self.next_vec();
self.terminated = self.current.is_empty();
Some(path)
}
}

View File

@ -1,4 +1,5 @@
mod best_proximity;
mod iter_shortest_paths;
mod query_tokens;
use std::borrow::Cow;