Small commit to add hybrid search and autoembedding

This commit is contained in:
Louis Dureuil
2023-11-15 15:46:37 +01:00
parent 21bcf32109
commit 13c2c6c16b
42 changed files with 4045 additions and 246 deletions

View File

@ -1,3 +1,6 @@
use std::cmp::Ordering;
use itertools::Itertools;
use serde::Serialize;
use crate::distance_between_two_points;
@ -12,9 +15,24 @@ pub enum ScoreDetails {
ExactAttribute(ExactAttribute),
ExactWords(ExactWords),
Sort(Sort),
Vector(Vector),
GeoSort(GeoSort),
}
#[derive(Clone, Copy)]
pub enum ScoreValue<'a> {
Score(f64),
Sort(&'a Sort),
GeoSort(&'a GeoSort),
}
enum RankOrValue<'a> {
Rank(Rank),
Sort(&'a Sort),
GeoSort(&'a GeoSort),
Score(f64),
}
impl ScoreDetails {
pub fn local_score(&self) -> Option<f64> {
self.rank().map(Rank::local_score)
@ -31,11 +49,55 @@ impl ScoreDetails {
ScoreDetails::ExactWords(details) => Some(details.rank()),
ScoreDetails::Sort(_) => None,
ScoreDetails::GeoSort(_) => None,
ScoreDetails::Vector(_) => None,
}
}
pub fn global_score<'a>(details: impl Iterator<Item = &'a Self>) -> f64 {
Rank::global_score(details.filter_map(Self::rank))
pub fn global_score<'a>(details: impl Iterator<Item = &'a Self> + 'a) -> f64 {
Self::score_values(details)
.find_map(|x| {
let ScoreValue::Score(score) = x else {
return None;
};
Some(score)
})
.unwrap_or(1.0f64)
}
pub fn score_values<'a>(
details: impl Iterator<Item = &'a Self> + 'a,
) -> impl Iterator<Item = ScoreValue<'a>> + 'a {
details
.map(ScoreDetails::rank_or_value)
.coalesce(|left, right| match (left, right) {
(RankOrValue::Rank(left), RankOrValue::Rank(right)) => {
Ok(RankOrValue::Rank(Rank::merge(left, right)))
}
(left, right) => Err((left, right)),
})
.map(|rank_or_value| match rank_or_value {
RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()),
RankOrValue::Sort(s) => ScoreValue::Sort(s),
RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g),
RankOrValue::Score(s) => ScoreValue::Score(s),
})
}
fn rank_or_value(&self) -> RankOrValue<'_> {
match self {
ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()),
ScoreDetails::Typo(t) => RankOrValue::Rank(t.rank()),
ScoreDetails::Proximity(p) => RankOrValue::Rank(*p),
ScoreDetails::Fid(f) => RankOrValue::Rank(*f),
ScoreDetails::Position(p) => RankOrValue::Rank(*p),
ScoreDetails::ExactAttribute(e) => RankOrValue::Rank(e.rank()),
ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()),
ScoreDetails::Sort(sort) => RankOrValue::Sort(sort),
ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort),
ScoreDetails::Vector(vector) => RankOrValue::Score(
vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64),
),
}
}
/// Panics
@ -181,6 +243,19 @@ impl ScoreDetails {
details_map.insert(sort, sort_details);
order += 1;
}
ScoreDetails::Vector(s) => {
let vector = format!("vectorSort({:?})", s.target_vector);
let value = s.value_similarity.as_ref().map(|(v, _)| v);
let similarity = s.value_similarity.as_ref().map(|(_, s)| s);
let details = serde_json::json!({
"order": order,
"value": value,
"similarity": similarity,
});
details_map.insert(vector, details);
order += 1;
}
}
}
details_map
@ -297,15 +372,21 @@ impl Rank {
pub fn global_score(details: impl Iterator<Item = Self>) -> f64 {
let mut rank = Rank { rank: 1, max_rank: 1 };
for inner_rank in details {
rank.rank -= 1;
rank.rank *= inner_rank.max_rank;
rank.max_rank *= inner_rank.max_rank;
rank.rank += inner_rank.rank;
rank = Rank::merge(rank, inner_rank);
}
rank.local_score()
}
pub fn merge(mut outer: Rank, inner: Rank) -> Rank {
outer.rank = outer.rank.saturating_sub(1);
outer.rank *= inner.max_rank;
outer.max_rank *= inner.max_rank;
outer.rank += inner.rank;
outer
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
@ -335,13 +416,78 @@ pub struct Sort {
pub value: serde_json::Value,
}
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
impl PartialOrd for Sort {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.field_name != other.field_name {
return None;
}
if self.ascending != other.ascending {
return None;
}
match (&self.value, &other.value) {
(serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal),
(serde_json::Value::Null, _) => Some(Ordering::Less),
(_, serde_json::Value::Null) => Some(Ordering::Greater),
// numbers are always before strings
(serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater),
(serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less),
(serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
// FIXME: unwrap permitted here?
let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?;
// 12 < 42, and when ascending, we want to see 12 first, so the smallest.
// Hence, when ascending, smaller is better
Some(if self.ascending { order.reverse() } else { order })
}
(serde_json::Value::String(left), serde_json::Value::String(right)) => {
let order = left.cmp(right);
// Taking e.g. "a" and "z"
// "a" < "z", and when ascending, we want to see "a" first, so the smallest.
// Hence, when ascending, smaller is better
Some(if self.ascending { order.reverse() } else { order })
}
_ => None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoSort {
pub target_point: [f64; 2],
pub ascending: bool,
pub value: Option<[f64; 2]>,
}
impl PartialOrd for GeoSort {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.target_point != other.target_point {
return None;
}
if self.ascending != other.ascending {
return None;
}
Some(match (self.distance(), other.distance()) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(Some(left), Some(right)) => {
let order = left.partial_cmp(&right)?;
if self.ascending {
// when ascending, the one with the smallest distance has the best score
order.reverse()
} else {
order
}
}
})
}
}
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct Vector {
pub target_vector: Vec<f32>,
pub value_similarity: Option<(Vec<f32>, f32)>,
}
impl GeoSort {
pub fn distance(&self) -> Option<f64> {
self.value.map(|value| distance_between_two_points(&self.target_point, &value))