From cfed5051c55ded5ca0b84c13d0e9501d36de1cd5 Mon Sep 17 00:00:00 2001
From: uttarayan21
Date: Sat, 28 Jun 2025 17:13:26 +0530
Subject: [PATCH] feat: Added euclidean_distance

---
 src/cosine.rs    |  2 +-
 src/euclidean.rs | 96 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/lib.rs       |  2 +
 3 files changed, 99 insertions(+), 1 deletion(-)
 create mode 100644 src/euclidean.rs

diff --git a/src/cosine.rs b/src/cosine.rs
index 6c5d3a1..be5a99f 100644
--- a/src/cosine.rs
+++ b/src/cosine.rs
@@ -44,7 +44,7 @@ where
 }
 
 #[cfg(test)]
-mod cosine_tests {
+mod tests {
     use super::*;
     use ndarray::*;
 
diff --git a/src/euclidean.rs b/src/euclidean.rs
new file mode 100644
index 0000000..3413bbd
--- /dev/null
+++ b/src/euclidean.rs
@@ -0,0 +1,96 @@
+#[cfg(feature = "ndarray_15")]
+use crate::ndarray_15_extra::Pow;
+use ndarray::{ArrayBase, Ix1};
+
+/// Error returned when a euclidean distance cannot be computed.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
+pub enum EuclideanDistanceError {
+    /// The two vectors do not have the same number of elements.
+    #[error(
+        "Invalid vectors: Vectors must have the same length for distance calculation. LHS: {lhs}, RHS: {rhs}"
+    )]
+    InvalidVectors { lhs: usize, rhs: usize },
+}
+
+/// Euclidean (L2) distance between `self` and a right-hand-side vector.
+pub trait EuclideanDistance<T, Rhs = Self> {
+    /// Computes the euclidean distance between two vectors.
+    ///
+    /// Returns the distance as a `T`, or an error if the vectors are invalid.
+    fn euclidean_distance(&self, rhs: Rhs) -> Result<T, EuclideanDistanceError>;
+}
+
+impl<S1, S2, T> EuclideanDistance<T, ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix1>
+where
+    S1: ndarray::Data<Elem = T>,
+    S2: ndarray::Data<Elem = T>,
+    T: num::traits::Float + core::iter::Sum + 'static,
+{
+    fn euclidean_distance(&self, rhs: ArrayBase<S2, Ix1>) -> Result<T, EuclideanDistanceError> {
+        if self.len() != rhs.len() {
+            return Err(EuclideanDistanceError::InvalidVectors {
+                lhs: self.len(),
+                rhs: rhs.len(),
+            });
+        }
+        debug_assert!(
+            self.iter().all(|&x| x.is_finite()),
+            "LHS vector contains non-finite values"
+        );
+        debug_assert!(
+            rhs.iter().all(|&x| x.is_finite()),
+            "RHS vector contains non-finite values"
+        );
+        // Lengths were checked above, so `zip` pairs every element.
+        Ok(self
+            .iter()
+            .zip(rhs.iter())
+            .map(|(&l, &r)| (l - r).powi(2))
+            .sum::<T>()
+            .sqrt())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ndarray::*;
+
+    #[test]
+    fn test_same_vectors() {
+        let a = array![1.0, 2.0, 3.0];
+        let b = array![1.0, 2.0, 3.0];
+        assert_eq!(a.euclidean_distance(b).unwrap(), 0.0);
+    }
+
+    #[test]
+    fn test_orthogonal_vectors() {
+        let a = array![1.0, 0.0, 0.0];
+        let b = array![0.0, 1.0, 0.0];
+        assert_eq!(a.euclidean_distance(b).unwrap(), 2.0_f64.sqrt());
+    }
+
+    #[test]
+    fn test_invalid_vectors() {
+        let a = array![1.0, 2.0];
+        let b = array![1.0, 2.0, 3.0];
+        assert!(matches!(
+            a.euclidean_distance(b),
+            Err(EuclideanDistanceError::InvalidVectors { lhs: 2, rhs: 3 })
+        ));
+    }
+
+    #[test]
+    fn test_zero_vector() {
+        let a = array![0.0, 0.0, 0.0];
+        let b = array![1.0, 2.0, 3.0];
+        assert_eq!(a.euclidean_distance(b).unwrap(), 14.0_f64.sqrt());
+    }
+
+    #[test]
+    fn test_different_ndarray_types() {
+        let a = array![1.0, 2.0, 3.0];
+        let b = array![1.0, 2.0, 3.0];
+        assert_eq!(a.euclidean_distance(b.view()).unwrap(), 0.0);
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index cce7204..29cdb15 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,3 +4,5 @@ pub mod ndarray_15_extra;
 
 mod cosine;
 pub use cosine::{CosineSimilarity, CosineSimilarityError};
+mod euclidean;
+pub use euclidean::{EuclideanDistance, EuclideanDistanceError};