feat: initial commit with cosine similarity
This commit is contained in:
93
src/cosine.rs
Normal file
93
src/cosine.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use ndarray::{ArrayBase, Ix1};
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
|
||||
pub enum CosineSimilarityError {
|
||||
#[error(
|
||||
"Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}"
|
||||
)]
|
||||
InvalidVectors { lhs: usize, rhs: usize },
|
||||
}
|
||||
pub trait CosineSimilarity<T, Rhs = Self> {
|
||||
/// Computes the cosine similarity between two vectors.
|
||||
///
|
||||
/// A `Result` containing the cosine similarity as a `f64`, or an error if the vectors are invalid.
|
||||
fn cosine_similarity(&self, rhs: Rhs) -> Result<T, CosineSimilarityError>;
|
||||
}
|
||||
|
||||
impl<S1, S2, T> CosineSimilarity<T, ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix1>
|
||||
where
|
||||
S1: ndarray::Data<Elem = T>,
|
||||
S2: ndarray::Data<Elem = T>,
|
||||
T: num::traits::Float + 'static,
|
||||
{
|
||||
fn cosine_similarity(&self, rhs: ArrayBase<S2, Ix1>) -> Result<T, CosineSimilarityError> {
|
||||
if self.len() != rhs.len() {
|
||||
return Err(CosineSimilarityError::InvalidVectors {
|
||||
lhs: self.len(),
|
||||
rhs: rhs.len(),
|
||||
});
|
||||
}
|
||||
debug_assert!(
|
||||
self.iter().all(|&x| x.is_finite()),
|
||||
"LHS vector contains non-finite values"
|
||||
);
|
||||
debug_assert!(
|
||||
rhs.iter().all(|&x| x.is_finite()),
|
||||
"RHS vector contains non-finite values"
|
||||
);
|
||||
let numerator = self.dot(&rhs);
|
||||
let denominator = self.powi(2).sum().sqrt() * rhs.powi(2).sum().sqrt();
|
||||
Ok(numerator / denominator)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod cosine_tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance used when comparing computed similarities, so the
    /// tests do not depend on bit-exact floating-point rounding.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    /// Identical vectors point in the same direction: similarity is 1.
    #[test]
    fn test_same_vectors() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![1.0, 2.0, 3.0];
        assert_close(a.cosine_similarity(b).unwrap(), 1.0);
    }

    /// Orthogonal vectors have a zero dot product: similarity is 0.
    #[test]
    fn test_orthogonal_vectors() {
        let a = array![1.0, 0.0, 0.0];
        let b = array![0.0, 1.0, 0.0];
        assert_close(a.cosine_similarity(b).unwrap(), 0.0);
    }

    /// Vectors pointing in opposite directions: similarity is -1.
    #[test]
    fn test_opposite_vectors() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![-1.0, -2.0, -3.0];
        assert_close(a.cosine_similarity(b).unwrap(), -1.0);
    }

    /// Mismatched lengths are rejected, and the error reports both lengths.
    #[test]
    fn test_invalid_vectors() {
        let a = array![1.0, 2.0];
        let b = array![1.0, 2.0, 3.0];
        assert_eq!(
            a.cosine_similarity(b),
            Err(CosineSimilarityError::InvalidVectors { lhs: 2, rhs: 3 })
        );
    }

    /// A zero-magnitude operand makes the denominator zero; the result is
    /// `Ok(NaN)` rather than an error.
    #[test]
    fn test_zero_vector() {
        let a = array![0.0, 0.0, 0.0];
        let b = array![1.0, 2.0, 3.0];
        let similarity = a.cosine_similarity(b);
        assert!(similarity.is_ok_and(|item: f64| item.is_nan()));
    }

    /// An owned array can be compared against a view of another array.
    #[test]
    fn test_different_ndarray_types() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![1.0, 2.0, 3.0];
        assert_close(a.cosine_similarity(b.view()).unwrap(), 1.0);
    }
}
|
||||
Reference in New Issue
Block a user