From 091a75ac9e3cec688f5b322d2bb37c2aaf7f69a4 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 30 Jun 2025 12:23:40 +0530 Subject: [PATCH] feat: Add sqrt for ndarray_15_extra and add some preliminary support for cosine similarity for matrices --- src/cosine.rs | 80 +++++++++++++++++++++++++++++++++++++++-- src/euclidean.rs | 4 +-- src/ndarray_15_extra.rs | 17 +++++++++ 3 files changed, 97 insertions(+), 4 deletions(-) diff --git a/src/cosine.rs b/src/cosine.rs index be5a99f..c272929 100644 --- a/src/cosine.rs +++ b/src/cosine.rs @@ -1,5 +1,5 @@ #[cfg(feature = "ndarray_15")] -use crate::ndarray_15_extra::Pow; +use crate::ndarray_15_extra::*; use ndarray::{ArrayBase, Ix1}; #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] @@ -8,12 +8,20 @@ pub enum CosineSimilarityError { "Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}" )] InvalidVectors { lhs: usize, rhs: usize }, + // #[error( + // "Invalid matrices: Matrices must have the same shape for similarity calculation. LHS: {}x{}, RHS: {}x{}", lhs.0, lhs.1, rhs.0, rhs.1 + // )] + // InvalidMatrices { + // lhs: (usize, usize), + // rhs: (usize, usize), + // }, } pub trait CosineSimilarity { /// Computes the cosine similarity between two vectors. /// /// A `Result` containing the cosine similarity as a `f64`, or an error if the vectors are invalid. - fn cosine_similarity(&self, rhs: Rhs) -> Result; + type Output; + fn cosine_similarity(&self, rhs: Rhs) -> Result; } impl CosineSimilarity> for ArrayBase @@ -22,6 +30,7 @@ where S2: ndarray::Data, T: num::traits::Float + 'static, { + type Output = T; fn cosine_similarity(&self, rhs: ArrayBase) -> Result { if self.len() != rhs.len() { return Err(CosineSimilarityError::InvalidVectors { @@ -43,6 +52,54 @@ where } } +impl CosineSimilarity> for &ArrayBase +where + S1: ndarray::Data, + S2: ndarray::Data, + T: num::traits::Float + 'static, +{ + type Output = T; + fn cosine_similarity(&self, rhs: ArrayBase) -> Result { + (*self).cosine_similarity(rhs) + } +} + +// impl CosineSimilarity> for ArrayBase +// where +// S1: ndarray::Data, +// S2: ndarray::Data, +// T: num::traits::Float + 'static, +// T: core::fmt::Debug, +// { +// type Output = Array; +// fn cosine_similarity( +// &self, +// rhs: ArrayBase, +// ) -> Result { +// if self.dim() != rhs.dim() { +// return Err(CosineSimilarityError::InvalidMatrices { +// lhs: self.dim(), +// rhs: rhs.dim(), +// }); +// } +// debug_assert!( +// self.iter().all(|&x| x.is_finite()), +// "LHS matrix contains non-finite values" +// ); +// debug_assert!( +// rhs.iter().all(|&x| x.is_finite()), +// "RHS matrix contains non-finite values" +// ); +// let numerator = self.dot(&rhs.t()); +// let lhs_norm = self.powi(2).sum().sqrt(); +// let rhs_norm = rhs.powi(2).sum().sqrt(); +// dbg!(&lhs_norm, &rhs_norm); +// +// let denominator = lhs_norm * rhs_norm.t(); +// Ok(numerator / denominator) +// } +// } + #[cfg(test)] mod tests { use super::*; @@ -93,4 +150,23 @@ mod tests { let b = array![1.0, 2.0, 3.0]; assert_eq!(a.cosine_similarity(b.view()).unwrap(), 1.0); } + + // #[test] + // fn test_similarity_with_same_matrices() { + // let a = array![[1.0, 2.0], [3.0, 4.0]]; + // let b = array![[1.0, 2.0], [3.0, 4.0]]; + // assert_eq!( + // a.cosine_similarity(b).unwrap(), + // array![[1.0, 1.0], [1.0, 1.0]] + // ); + // } + // #[test] + // fn test_similarity_with_matrices() { + // let a = array![[1.0, 2.0], [3.0, 4.0]]; + // let b = array![[5.0, 6.0], [7.0, 8.0]]; + // assert_eq!( + // a.cosine_similarity(b).unwrap(), + // array![[0.2358, 0.3191], [0.5410, 0.7353]] + // ); + // } } diff --git a/src/euclidean.rs b/src/euclidean.rs index 3413bbd..9c5bab6 100644 --- a/src/euclidean.rs +++ b/src/euclidean.rs @@ -1,5 +1,5 @@ -#[cfg(feature = "ndarray_15")] -use crate::ndarray_15_extra::Pow; +// #[cfg(feature = "ndarray_15")] +// use crate::ndarray_15_extra::*; use ndarray::{ArrayBase, Ix1}; #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] diff --git a/src/ndarray_15_extra.rs b/src/ndarray_15_extra.rs index 5a1f4d3..41b654c 100644 --- a/src/ndarray_15_extra.rs +++ b/src/ndarray_15_extra.rs @@ -13,3 +13,20 @@ where self.mapv(|x| x.powi(rhs)) } } + +pub trait Sqrt { + type Output; + fn sqrt(&self) -> Self::Output; +} + +impl Sqrt for ndarray::ArrayBase +where + S: ndarray::Data, + T: num::Float, + D: ndarray::Dimension, +{ + type Output = ndarray::Array; + fn sqrt(&self) -> Self::Output { + self.mapv(|x| x.sqrt()) + } +}