feat: Added euclidean_distance
This commit is contained in:
@@ -44,7 +44,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod cosine_tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use ndarray::*;
|
use ndarray::*;
|
||||||
|
|
||||||
|
|||||||
95
src/euclidean.rs
Normal file
95
src/euclidean.rs
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
#[cfg(feature = "ndarray_15")]
|
||||||
|
use crate::ndarray_15_extra::Pow;
|
||||||
|
use ndarray::{ArrayBase, Ix1};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
|
||||||
|
pub enum EuclideanDistanceError {
|
||||||
|
#[error(
|
||||||
|
"Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}"
|
||||||
|
)]
|
||||||
|
InvalidVectors { lhs: usize, rhs: usize },
|
||||||
|
}
|
||||||
|
pub trait EuclideanDistance<T, Rhs = Self> {
|
||||||
|
/// Computes the euclidean distance between two vectors.
|
||||||
|
///
|
||||||
|
/// A `Result` containing the euclidean distance as a `f64`, or an error if the vectors are invalid.
|
||||||
|
fn euclidean_distance(&self, rhs: Rhs) -> Result<T, EuclideanDistanceError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S1, S2, T> EuclideanDistance<T, ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix1>
|
||||||
|
where
|
||||||
|
S1: ndarray::Data<Elem = T>,
|
||||||
|
S2: ndarray::Data<Elem = T>,
|
||||||
|
T: num::traits::Float + core::iter::Sum + 'static,
|
||||||
|
{
|
||||||
|
fn euclidean_distance(&self, rhs: ArrayBase<S2, Ix1>) -> Result<T, EuclideanDistanceError> {
|
||||||
|
if self.len() != rhs.len() {
|
||||||
|
return Err(EuclideanDistanceError::InvalidVectors {
|
||||||
|
lhs: self.len(),
|
||||||
|
rhs: rhs.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
debug_assert!(
|
||||||
|
self.iter().all(|&x| x.is_finite()),
|
||||||
|
"LHS vector contains non-finite values"
|
||||||
|
);
|
||||||
|
debug_assert!(
|
||||||
|
rhs.iter().all(|&x| x.is_finite()),
|
||||||
|
"RHS vector contains non-finite values"
|
||||||
|
);
|
||||||
|
// let numerator = self.dot(&rhs);
|
||||||
|
// let denominator = self.powi(2).sum().sqrt() * rhs.powi(2).sum().sqrt();
|
||||||
|
// Ok(numerator / denominator)
|
||||||
|
Ok(self
|
||||||
|
.iter()
|
||||||
|
.zip(rhs.iter())
|
||||||
|
.map(|(lhs, rhs)| (*lhs - *rhs).powi(2))
|
||||||
|
.sum::<T>()
|
||||||
|
.sqrt())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use ndarray::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_same_vectors() {
|
||||||
|
let a = array![1.0, 2.0, 3.0];
|
||||||
|
let b = array![1.0, 2.0, 3.0];
|
||||||
|
assert_eq!(a.euclidean_distance(b).unwrap(), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_orthogonal_vectors() {
|
||||||
|
let a = array![1.0, 0.0, 0.0];
|
||||||
|
let b = array![0.0, 1.0, 0.0];
|
||||||
|
assert_eq!(a.euclidean_distance(b).unwrap(), 2.0_f64.sqrt());
|
||||||
|
}
|
||||||
|
|
||||||
|
// #[test]
|
||||||
|
// fn test_invalid_vectors() {
|
||||||
|
// let a = array![1.0, 2.0];
|
||||||
|
// let b = array![1.0, 2.0, 3.0];
|
||||||
|
// assert!(matches!(
|
||||||
|
// a.euclidean_distance(b),
|
||||||
|
// Err(EuclideanDistanceError::InvalidVectors { lhs: 2, rhs: 3 })
|
||||||
|
// ));
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// #[test]
|
||||||
|
// fn test_zero_vector() {
|
||||||
|
// let a = array![0.0, 0.0, 0.0];
|
||||||
|
// let b = array![1.0, 2.0, 3.0];
|
||||||
|
// let similarity = a.euclidean_distance(b);
|
||||||
|
// assert!(similarity.is_ok_and(|item: f64| item.is_nan()));
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// #[test]
|
||||||
|
// fn test_different_ndarray_types() {
|
||||||
|
// let a = array![1.0, 2.0, 3.0];
|
||||||
|
// let b = array![1.0, 2.0, 3.0];
|
||||||
|
// assert_eq!(a.euclidean_distance(b.view()).unwrap(), 1.0);
|
||||||
|
// }
|
||||||
|
}
|
||||||
@@ -4,3 +4,5 @@ pub mod ndarray_15_extra;
|
|||||||
|
|
||||||
mod cosine;
|
mod cosine;
|
||||||
pub use cosine::{CosineSimilarity, CosineSimilarityError};
|
pub use cosine::{CosineSimilarity, CosineSimilarityError};
|
||||||
|
mod euclidean;
|
||||||
|
pub use euclidean::{EuclideanDistance, EuclideanDistanceError};
|
||||||
|
|||||||
Reference in New Issue
Block a user