mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-11-04 01:46:28 +00:00 
			
		
		
		
	Introduce an optimized version of the euclidean distance function
This commit is contained in:
		
				
					committed by
					
						
						Clément Renault
					
				
			
			
				
	
			
			
			
						parent
						
							268a9ef416
						
					
				
				
					commit
					5816008139
				
			@@ -1,6 +1,13 @@
 | 
				
			|||||||
use serde::{Deserialize, Serialize};
 | 
					use serde::{Deserialize, Serialize};
 | 
				
			||||||
use space::Metric;
 | 
					use space::Metric;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[cfg(any(
 | 
				
			||||||
 | 
					    target_arch = "x86",
 | 
				
			||||||
 | 
					    target_arch = "x86_64",
 | 
				
			||||||
 | 
					    all(target_arch = "aarch64", target_feature = "neon")
 | 
				
			||||||
 | 
					))]
 | 
				
			||||||
 | 
					const MIN_DIM_SIZE_SIMD: usize = 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
 | 
					#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
 | 
				
			||||||
pub struct DotProduct;
 | 
					pub struct DotProduct;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -26,9 +33,58 @@ impl Metric<Vec<f32>> for Euclidean {
 | 
				
			|||||||
    type Unit = u32;
 | 
					    type Unit = u32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
 | 
					    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
 | 
				
			||||||
 | 
					        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
 | 
				
			||||||
 | 
					                let squared = unsafe { squared_euclid_neon(&a, &b) };
 | 
				
			||||||
 | 
					                let dist = squared.sqrt();
 | 
				
			||||||
 | 
					                debug_assert!(!dist.is_nan());
 | 
				
			||||||
 | 
					                return dist.to_bits();
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum();
 | 
					        let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum();
 | 
				
			||||||
        let dist = squared.sqrt();
 | 
					        let dist = squared.sqrt();
 | 
				
			||||||
        debug_assert!(!dist.is_nan());
 | 
					        debug_assert!(!dist.is_nan());
 | 
				
			||||||
        dist.to_bits()
 | 
					        dist.to_bits()
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[cfg(target_feature = "neon")]
 | 
				
			||||||
 | 
					use std::arch::aarch64::*;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[cfg(target_feature = "neon")]
 | 
				
			||||||
 | 
					pub(crate) unsafe fn squared_euclid_neon(v1: &[f32], v2: &[f32]) -> f32 {
 | 
				
			||||||
 | 
					    let n = v1.len();
 | 
				
			||||||
 | 
					    let m = n - (n % 16);
 | 
				
			||||||
 | 
					    let mut ptr1: *const f32 = v1.as_ptr();
 | 
				
			||||||
 | 
					    let mut ptr2: *const f32 = v2.as_ptr();
 | 
				
			||||||
 | 
					    let mut sum1 = vdupq_n_f32(0.);
 | 
				
			||||||
 | 
					    let mut sum2 = vdupq_n_f32(0.);
 | 
				
			||||||
 | 
					    let mut sum3 = vdupq_n_f32(0.);
 | 
				
			||||||
 | 
					    let mut sum4 = vdupq_n_f32(0.);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let mut i: usize = 0;
 | 
				
			||||||
 | 
					    while i < m {
 | 
				
			||||||
 | 
					        let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2));
 | 
				
			||||||
 | 
					        sum1 = vfmaq_f32(sum1, sub1, sub1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
 | 
				
			||||||
 | 
					        sum2 = vfmaq_f32(sum2, sub2, sub2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
 | 
				
			||||||
 | 
					        sum3 = vfmaq_f32(sum3, sub3, sub3);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
 | 
				
			||||||
 | 
					        sum4 = vfmaq_f32(sum4, sub4, sub4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ptr1 = ptr1.add(16);
 | 
				
			||||||
 | 
					        ptr2 = ptr2.add(16);
 | 
				
			||||||
 | 
					        i += 16;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
 | 
				
			||||||
 | 
					    for i in 0..n - m {
 | 
				
			||||||
 | 
					        result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    result
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user