From 5e8a004b1f2cebb4b06430771228bc64ca321fd5 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 23 Jun 2025 15:38:28 +0530 Subject: [PATCH] feat: Added ndarray_15 feature --- Cargo.lock | 16 +++++++++++++++- Cargo.toml | 6 +++++- flake.nix | 1 + src/cosine.rs | 3 +++ src/lib.rs | 4 ++++ src/ndarray_15_extra.rs | 15 +++++++++++++++ 6 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 src/ndarray_15_extra.rs diff --git a/Cargo.lock b/Cargo.lock index 45485d0..1545c18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,6 +18,19 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -37,7 +50,8 @@ dependencies = [ name = "ndarray-math" version = "0.1.0" dependencies = [ - "ndarray", + "ndarray 0.15.6", + "ndarray 0.16.1", "num", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index 820494b..a5de811 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,10 @@ license = "MIT" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ndarray = "0.16" +ndarray = { version = "0.16" } +ndarray_15 = { package = "ndarray", version = "0.15.6", optional = true } num = "0.4.3" thiserror = "2.0.12" + +[features] +ndarray_15 = ["dep:ndarray_15"] diff --git a/flake.nix b/flake.nix index 156695f..3fa7bbc 100644 --- a/flake.nix +++ b/flake.nix @@ -132,6 +132,7 @@ stableToolchainWithRustAnalyzer cargo-nextest cargo-deny + cargo-hack ] ++ (lib.optionals pkgs.stdenv.isDarwin [ apple-sdk_13 diff --git a/src/cosine.rs b/src/cosine.rs index 9799a97..6c5d3a1 100644 --- a/src/cosine.rs +++ b/src/cosine.rs @@ -1,4 +1,7 @@ +#[cfg(feature = "ndarray_15")] +use crate::ndarray_15_extra::Pow; use ndarray::{ArrayBase, Ix1}; + #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] pub enum CosineSimilarityError { #[error( diff --git a/src/lib.rs b/src/lib.rs index 1cbdf13..cce7204 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,6 @@ +#[cfg(feature = "ndarray_15")] +extern crate ndarray_15 as ndarray; +pub mod ndarray_15_extra; + mod cosine; pub use cosine::{CosineSimilarity, CosineSimilarityError}; diff --git a/src/ndarray_15_extra.rs b/src/ndarray_15_extra.rs new file mode 100644 index 0000000..5a1f4d3 --- /dev/null +++ b/src/ndarray_15_extra.rs @@ -0,0 +1,15 @@ +pub trait Pow { + type Output; + fn powi(&self, rhs: i32) -> Self::Output; +} +impl Pow for ndarray::ArrayBase +where + S: ndarray::Data, + T: num::Float, + D: ndarray::Dimension, +{ + type Output = ndarray::Array; + fn powi(&self, rhs: i32) -> Self::Output { + self.mapv(|x| x.powi(rhs)) + } +}