feat: Added euclidean_distance

This commit is contained in:
uttarayan21
2025-06-28 17:13:26 +05:30
parent 5e8a004b1f
commit cfed5051c5
3 changed files with 98 additions and 1 deletions

View File

@@ -44,7 +44,7 @@ where
} }
#[cfg(test)] #[cfg(test)]
mod cosine_tests { mod tests {
use super::*; use super::*;
use ndarray::*; use ndarray::*;

95
src/euclidean.rs Normal file
View File

@@ -0,0 +1,95 @@
#[cfg(feature = "ndarray_15")]
use crate::ndarray_15_extra::Pow;
use ndarray::{ArrayBase, Ix1};
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum EuclideanDistanceError {
#[error(
"Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}"
)]
InvalidVectors { lhs: usize, rhs: usize },
}
pub trait EuclideanDistance<T, Rhs = Self> {
/// Computes the euclidean distance between two vectors.
///
/// A `Result` containing the euclidean distance as a `f64`, or an error if the vectors are invalid.
fn euclidean_distance(&self, rhs: Rhs) -> Result<T, EuclideanDistanceError>;
}
impl<S1, S2, T> EuclideanDistance<T, ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix1>
where
S1: ndarray::Data<Elem = T>,
S2: ndarray::Data<Elem = T>,
T: num::traits::Float + core::iter::Sum + 'static,
{
fn euclidean_distance(&self, rhs: ArrayBase<S2, Ix1>) -> Result<T, EuclideanDistanceError> {
if self.len() != rhs.len() {
return Err(EuclideanDistanceError::InvalidVectors {
lhs: self.len(),
rhs: rhs.len(),
});
}
debug_assert!(
self.iter().all(|&x| x.is_finite()),
"LHS vector contains non-finite values"
);
debug_assert!(
rhs.iter().all(|&x| x.is_finite()),
"RHS vector contains non-finite values"
);
// let numerator = self.dot(&rhs);
// let denominator = self.powi(2).sum().sqrt() * rhs.powi(2).sum().sqrt();
// Ok(numerator / denominator)
Ok(self
.iter()
.zip(rhs.iter())
.map(|(lhs, rhs)| (*lhs - *rhs).powi(2))
.sum::<T>()
.sqrt())
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::*;
#[test]
fn test_same_vectors() {
let a = array![1.0, 2.0, 3.0];
let b = array![1.0, 2.0, 3.0];
assert_eq!(a.euclidean_distance(b).unwrap(), 0.0);
}
#[test]
fn test_orthogonal_vectors() {
let a = array![1.0, 0.0, 0.0];
let b = array![0.0, 1.0, 0.0];
assert_eq!(a.euclidean_distance(b).unwrap(), 2.0_f64.sqrt());
}
// #[test]
// fn test_invalid_vectors() {
// let a = array![1.0, 2.0];
// let b = array![1.0, 2.0, 3.0];
// assert!(matches!(
// a.euclidean_distance(b),
// Err(EuclideanDistanceError::InvalidVectors { lhs: 2, rhs: 3 })
// ));
// }
//
// #[test]
// fn test_zero_vector() {
// let a = array![0.0, 0.0, 0.0];
// let b = array![1.0, 2.0, 3.0];
// let similarity = a.euclidean_distance(b);
// assert!(similarity.is_ok_and(|item: f64| item.is_nan()));
// }
//
// #[test]
// fn test_different_ndarray_types() {
// let a = array![1.0, 2.0, 3.0];
// let b = array![1.0, 2.0, 3.0];
// assert_eq!(a.euclidean_distance(b.view()).unwrap(), 1.0);
// }
}

View File

@@ -4,3 +4,5 @@ pub mod ndarray_15_extra;
mod cosine; mod cosine;
pub use cosine::{CosineSimilarity, CosineSimilarityError}; pub use cosine::{CosineSimilarity, CosineSimilarityError};
mod euclidean;
pub use euclidean::{EuclideanDistance, EuclideanDistanceError};