feat: initial commit with cosine similarity

This commit is contained in:
uttarayan21
2025-06-23 15:18:36 +05:30
commit 669d1bf568
11 changed files with 920 additions and 0 deletions

93
src/cosine.rs Normal file
View File

@@ -0,0 +1,93 @@
use ndarray::{ArrayBase, Ix1};
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum CosineSimilarityError {
#[error(
"Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}"
)]
InvalidVectors { lhs: usize, rhs: usize },
}
/// Cosine similarity between `Self` and a right-hand operand of type `Rhs`.
///
/// `T` is the scalar type of the returned similarity. `Rhs` defaults to `Self`
/// and is taken by value.
pub trait CosineSimilarity<T, Rhs = Self> {
/// Computes the cosine similarity between two vectors.
///
/// # Returns
///
/// `Ok` holding the cosine similarity as a value of type `T`.
///
/// # Errors
///
/// Returns a [`CosineSimilarityError`] if the vectors are invalid, e.g.
/// [`CosineSimilarityError::InvalidVectors`] when their lengths differ.
fn cosine_similarity(&self, rhs: Rhs) -> Result<T, CosineSimilarityError>;
}
/// Cosine similarity for one-dimensional `ndarray` arrays that share the
/// floating-point element type `T`.
///
/// The storage types `S1` and `S2` are independent, so an owned array can be
/// compared against a view (or any other `Data` storage) of the same element type.
impl<S1, S2, T> CosineSimilarity<T, ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix1>
where
    S1: ndarray::Data<Elem = T>,
    S2: ndarray::Data<Elem = T>,
    T: num::traits::Float + 'static,
{
    /// Returns the dot product of `self` and `rhs` divided by the product of
    /// their Euclidean norms.
    ///
    /// # Errors
    ///
    /// [`CosineSimilarityError::InvalidVectors`] when the two arrays differ in length.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if either array contains a
    /// non-finite component.
    ///
    /// # Notes
    ///
    /// A zero norm is not special-cased: the division is carried out as-is, so
    /// the result follows the float semantics of `T` (NaN for a zero `f64` vector).
    fn cosine_similarity(&self, rhs: ArrayBase<S2, Ix1>) -> Result<T, CosineSimilarityError> {
        // Guard clause: both operands must have the same number of elements.
        let (lhs_len, rhs_len) = (self.len(), rhs.len());
        if lhs_len != rhs_len {
            return Err(CosineSimilarityError::InvalidVectors {
                lhs: lhs_len,
                rhs: rhs_len,
            });
        }

        // Debug-only sanity checks; release builds skip these scans entirely.
        debug_assert!(
            self.iter().all(|&component| component.is_finite()),
            "LHS vector contains non-finite values"
        );
        debug_assert!(
            rhs.iter().all(|&component| component.is_finite()),
            "RHS vector contains non-finite values"
        );

        // Euclidean norm of each operand: sqrt(sum(x_i^2)).
        let lhs_norm = self.powi(2).sum().sqrt();
        let rhs_norm = rhs.powi(2).sum().sqrt();
        let dot_product = self.dot(&rhs);

        Ok(dot_product / (lhs_norm * rhs_norm))
    }
}
/// Unit tests for the `CosineSimilarity` implementation on 1-D `ndarray` arrays.
#[cfg(test)]
mod cosine_tests {
    use super::*;
    use ndarray::*;

    /// Absolute tolerance used when comparing computed similarities.
    ///
    /// The similarity involves two square roots and a division, so results are
    /// only guaranteed to be within rounding error of the mathematical value.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Identical vectors have similarity 1.
    #[test]
    fn test_same_vectors() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![1.0, 2.0, 3.0];
        assert_close(a.cosine_similarity(b).unwrap(), 1.0);
    }

    /// Orthogonal vectors have similarity 0.
    #[test]
    fn test_orthogonal_vectors() {
        let a = array![1.0, 0.0, 0.0];
        let b = array![0.0, 1.0, 0.0];
        assert_close(a.cosine_similarity(b).unwrap(), 0.0);
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn test_opposite_vectors() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![-1.0, -2.0, -3.0];
        assert_close(a.cosine_similarity(b).unwrap(), -1.0);
    }

    /// A length mismatch is reported with both lengths.
    #[test]
    fn test_invalid_vectors() {
        let a = array![1.0, 2.0];
        let b = array![1.0, 2.0, 3.0];
        assert!(matches!(
            a.cosine_similarity(b),
            Err(CosineSimilarityError::InvalidVectors { lhs: 2, rhs: 3 })
        ));
    }

    /// A zero vector has zero norm, so the division yields NaN rather than an error.
    #[test]
    fn test_zero_vector() {
        let a = array![0.0, 0.0, 0.0];
        let b = array![1.0, 2.0, 3.0];
        let similarity = a.cosine_similarity(b);
        assert!(similarity.is_ok_and(|item: f64| item.is_nan()));
    }

    /// An owned array can be compared against a view of another array.
    #[test]
    fn test_different_ndarray_types() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![1.0, 2.0, 3.0];
        assert_close(a.cosine_similarity(b.view()).unwrap(), 1.0);
    }
}