#[cfg(feature = "ndarray_15")] use crate::ndarray_15_extra::Pow; use ndarray::{ArrayBase, Ix1}; #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] pub enum CosineSimilarityError { #[error( "Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}" )] InvalidVectors { lhs: usize, rhs: usize }, } pub trait CosineSimilarity { /// Computes the cosine similarity between two vectors. /// /// A `Result` containing the cosine similarity as a `f64`, or an error if the vectors are invalid. fn cosine_similarity(&self, rhs: Rhs) -> Result; } impl CosineSimilarity> for ArrayBase where S1: ndarray::Data, S2: ndarray::Data, T: num::traits::Float + 'static, { fn cosine_similarity(&self, rhs: ArrayBase) -> Result { if self.len() != rhs.len() { return Err(CosineSimilarityError::InvalidVectors { lhs: self.len(), rhs: rhs.len(), }); } debug_assert!( self.iter().all(|&x| x.is_finite()), "LHS vector contains non-finite values" ); debug_assert!( rhs.iter().all(|&x| x.is_finite()), "RHS vector contains non-finite values" ); let numerator = self.dot(&rhs); let denominator = self.powi(2).sum().sqrt() * rhs.powi(2).sum().sqrt(); Ok(numerator / denominator) } } #[cfg(test)] mod cosine_tests { use super::*; use ndarray::*; #[test] fn test_same_vectors() { let a = array![1.0, 2.0, 3.0]; let b = array![1.0, 2.0, 3.0]; assert_eq!(a.cosine_similarity(b).unwrap(), 1.0); } #[test] fn test_orthogonal_vectors() { let a = array![1.0, 0.0, 0.0]; let b = array![0.0, 1.0, 0.0]; assert_eq!(a.cosine_similarity(b).unwrap(), 0.0); } #[test] fn test_opposite_vectors() { let a = array![1.0, 2.0, 3.0]; let b = array![-1.0, -2.0, -3.0]; assert_eq!(a.cosine_similarity(b).unwrap(), -1.0); } #[test] fn test_invalid_vectors() { let a = array![1.0, 2.0]; let b = array![1.0, 2.0, 3.0]; assert!(matches!( a.cosine_similarity(b), Err(CosineSimilarityError::InvalidVectors { lhs: 2, rhs: 3 }) )); } #[test] fn test_zero_vector() { let a = array![0.0, 0.0, 0.0]; let b = array![1.0, 2.0, 3.0]; let similarity = a.cosine_similarity(b); assert!(similarity.is_ok_and(|item: f64| item.is_nan())); } #[test] fn test_different_ndarray_types() { let a = array![1.0, 2.0, 3.0]; let b = array![1.0, 2.0, 3.0]; assert_eq!(a.cosine_similarity(b.view()).unwrap(), 1.0); } }