diff --git a/src/cosine.rs b/src/cosine.rs index c272929..94d6b3a 100644 --- a/src/cosine.rs +++ b/src/cosine.rs @@ -52,15 +52,15 @@ where } } -impl CosineSimilarity> for &ArrayBase +impl CosineSimilarity> for ArrayBase where S1: ndarray::Data, S2: ndarray::Data, T: num::traits::Float + 'static, { type Output = T; - fn cosine_similarity(&self, rhs: ArrayBase) -> Result { - (*self).cosine_similarity(rhs) + fn cosine_similarity(&self, rhs: &ArrayBase) -> Result { + self.cosine_similarity(rhs.view()) } }