From 669d1bf5681811462987813f8717e930b5a0a39b Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 23 Jun 2025 15:18:36 +0530 Subject: [PATCH] feat: initial commit with cosine similarity --- .envrc | 1 + .github/workflows/build.yaml | 59 +++++++++ .github/workflows/docs.yaml | 38 ++++++ .gitignore | 3 + Cargo.lock | 192 ++++++++++++++++++++++++++++ Cargo.toml | 12 ++ deny.toml | 236 +++++++++++++++++++++++++++++++++++ flake.lock | 136 ++++++++++++++++++++ flake.nix | 148 ++++++++++++++++++++++ src/cosine.rs | 93 ++++++++++++++ src/lib.rs | 2 + 11 files changed, 920 insertions(+) create mode 100644 .envrc create mode 100644 .github/workflows/build.yaml create mode 100644 .github/workflows/docs.yaml create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 deny.toml create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 src/cosine.rs create mode 100644 src/lib.rs diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..4f98b75 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,59 @@ +name: build + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +env: + CARGO_TERM_COLOR: always + +jobs: + checks-matrix: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + - uses: DeterminateSystems/nix-installer-action@main + - id: set-matrix + name: Generate Nix Matrix + run: | + set -Eeu + matrix="$(nix eval --json '.#githubActions.matrix')" + echo "matrix=$matrix" >> "$GITHUB_OUTPUT" + + checks-build: + needs: checks-matrix + runs-on: ${{ matrix.os }} + strategy: + matrix: ${{fromJSON(needs.checks-matrix.outputs.matrix)}} + steps: + - uses: actions/checkout@v4 + - uses: DeterminateSystems/nix-installer-action@main + - run: 
nix build -L '.#${{ matrix.attr }}' + + codecov: + runs-on: ubuntu-latest + permissions: + id-token: "write" + contents: "read" + + steps: + - uses: actions/checkout@v4 + - uses: DeterminateSystems/nix-installer-action@main + + - name: Run codecov + run: nix build .#checks.x86_64-linux.ndarray-math-llvm-cov + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v4.0.1 + with: + flags: unittests + name: codecov-ndarray-math + fail_ci_if_error: true + token: ${{ secrets.CODECOV_TOKEN }} + files: ./result + verbose: true + diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 0000000..cce0ccb --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,38 @@ +name: docs + +on: + push: + branches: [ master ] + +env: + CARGO_TERM_COLOR: always + +jobs: + docs: + runs-on: ubuntu-latest + permissions: + id-token: "write" + contents: "read" + pages: "write" + + steps: + - uses: actions/checkout@v4 + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: DeterminateSystems/flake-checker-action@main + + - name: Generate docs + run: nix build .#checks.x86_64-linux.ndarray-math-docs + + - name: Setup Pages + uses: actions/configure-pages@v5 + + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: result/share/doc + + - name: Deploy to gh-pages + id: deployment + uses: actions/deploy-pages@v4 + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6210698 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/result +/target +.direnv diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..45485d0 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,192 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray-math" +version = "0.1.0" +dependencies = [ + "ndarray", + "num", + "thiserror", +] + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = 
"num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "proc-macro2" +version = "1.0.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "syn" +version = "2.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..820494b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "ndarray-math" +version = "0.1.0" +edition = "2024" +license = "MIT" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ndarray = "0.16" +num = "0.4.3" +thiserror = "2.0.12" diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..5578e66 --- /dev/null +++ b/deny.toml @@ -0,0 +1,236 @@ +# This template contains all of the possible sections and their default values + +# Note that all fields that take a lint level have these possible values: +# * deny - An error will be produced and the check will fail +# * warn - A warning will be produced, but the check will not fail +# * allow - No warning or error will be produced, though in some cases a note +# will be + +# The values provided in this template are the default values that will be used +# when any section or field is not specified in your own configuration + +# Root options + +# The graph table configures how the dependency graph is constructed and 
thus +# which crates the checks are performed against +[graph] +# If 1 or more target triples (and optionally, target_features) are specified, +# only the specified targets will be checked when running `cargo deny check`. +# This means, if a particular package is only ever used as a target specific +# dependency, such as, for example, the `nix` crate only being used via the +# `target_family = "unix"` configuration, that only having windows targets in +# this list would mean the nix crate, as well as any of its exclusive +# dependencies not shared by any other crates, would be ignored, as the target +# list here is effectively saying which targets you are building for. +targets = [ + # The triple can be any string, but only the target triples built in to + # rustc (as of 1.40) can be checked against actual config expressions + #"x86_64-unknown-linux-musl", + # You can also specify which target_features you promise are enabled for a + # particular target. target_features are currently not validated against + # the actual valid features supported by the target architecture. + #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, +] +# When creating the dependency graph used as the source of truth when checks are +# executed, this field can be used to prune crates from the graph, removing them +# from the view of cargo-deny. This is an extremely heavy hammer, as if a crate +# is pruned from the graph, all of its dependencies will also be pruned unless +# they are connected to another crate in the graph that hasn't been pruned, +# so it should be used with care. The identifiers are [Package ID Specifications] +# (https://doc.rust-lang.org/cargo/reference/pkgid-spec.html) +#exclude = [] +# If true, metadata will be collected with `--all-features`. 
Note that this can't +# be toggled off if true, if you want to conditionally enable `--all-features` it +# is recommended to pass `--all-features` on the cmd line instead +all-features = false +# If true, metadata will be collected with `--no-default-features`. The same +# caveat with `all-features` applies +no-default-features = false +# If set, these feature will be enabled when collecting metadata. If `--features` +# is specified on the cmd line they will take precedence over this option. +#features = [] + +# The output table provides options for how/if diagnostics are outputted +[output] +# When outputting inclusion graphs in diagnostics that include features, this +# option can be used to specify the depth at which feature edges will be added. +# This option is included since the graphs can be quite large and the addition +# of features from the crate(s) to all of the graph roots can be far too verbose. +# This option can be overridden via `--feature-depth` on the cmd line +feature-depth = 1 + +# This section is considered when running `cargo deny check advisories` +# More documentation for the advisories section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html +[advisories] +# The path where the advisory databases are cloned/fetched into +#db-path = "$CARGO_HOME/advisory-dbs" +# The url(s) of the advisory databases to use +#db-urls = ["https://github.com/rustsec/advisory-db"] +# A list of advisory IDs to ignore. Note that ignored advisories will still +# output a note when they are encountered. +ignore = [ + #"RUSTSEC-0000-0000", + #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, + #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish + #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" }, +] +# If this is true, then cargo deny will use the git executable to fetch advisory database. 
+# If this is false, then it uses a built-in git library. +# Setting this to true can be helpful if you have special authentication requirements that cargo-deny does not support. +# See Git Authentication for more information about setting up git authentication. +#git-fetch-with-cli = true + +# This section is considered when running `cargo deny check licenses` +# More documentation for the licenses section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html +[licenses] +# List of explicitly allowed licenses +# See https://spdx.org/licenses/ for list of possible licenses +# [possible values: any SPDX 3.11 short identifier (+ optional exception)]. +allow = [ + "MIT", + "Apache-2.0", + "Unicode-3.0", + #"Apache-2.0 WITH LLVM-exception", +] +# The confidence threshold for detecting a license from license text. +# The higher the value, the more closely the license text must be to the +# canonical license text of a valid SPDX license file. +# [possible values: any between 0.0 and 1.0]. +confidence-threshold = 0.8 +# Allow 1 or more licenses on a per-crate basis, so that particular licenses +# aren't accepted for every possible crate as with the normal allow list +exceptions = [ + # Each entry is the crate and version constraint, and its specific allow + # list + #{ allow = ["Zlib"], crate = "adler32" }, +] + +# Some crates don't have (easily) machine readable licensing information, +# adding a clarification entry for it allows you to manually specify the +# licensing information +#[[licenses.clarify]] +# The package spec the clarification applies to +#crate = "ring" +# The SPDX expression for the license requirements of the crate +#expression = "MIT AND ISC AND OpenSSL" +# One or more files in the crate's source used as the "source of truth" for +# the license expression. 
If the contents match, the clarification will be used +# when running the license check, otherwise the clarification will be ignored +# and the crate will be checked normally, which may produce warnings or errors +# depending on the rest of your configuration +#license-files = [ +# Each entry is a crate relative path, and the (opaque) hash of its contents +#{ path = "LICENSE", hash = 0xbd0eed23 } +#] + +[licenses.private] +# If true, ignores workspace crates that aren't published, or are only +# published to private registries. +# To see how to mark a crate as unpublished (to the official registry), +# visit https://doc.rust-lang.org/cargo/reference/manifest.html#the-publish-field. +ignore = false +# One or more private registries that you might publish crates to, if a crate +# is only published to private registries, and ignore is true, the crate will +# not have its license(s) checked +registries = [ + #"https://sekretz.com/registry +] + +# This section is considered when running `cargo deny check bans`. +# More documentation about the 'bans' section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html +[bans] +# Lint level for when multiple versions of the same crate are detected +multiple-versions = "warn" +# Lint level for when a crate version requirement is `*` +wildcards = "allow" +# The graph highlighting used when creating dotgraphs for crates +# with multiple versions +# * lowest-version - The path to the lowest versioned duplicate is highlighted +# * simplest-path - The path to the version with the fewest edges is highlighted +# * all - Both lowest-version and simplest-path are used +highlight = "all" +# The default lint level for `default` features for crates that are members of +# the workspace that is being checked. This can be overridden by allowing/denying +# `default` on a crate-by-crate basis if desired. 
+workspace-default-features = "allow" +# The default lint level for `default` features for external crates that are not +# members of the workspace. This can be overridden by allowing/denying `default` +# on a crate-by-crate basis if desired. +external-default-features = "allow" +# List of crates that are allowed. Use with care! +allow = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" }, +] +# List of crates to deny +deny = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is banned" }, + # Wrapper crates can optionally be specified to allow the crate when it + # is a direct dependency of the otherwise banned crate + #{ crate = "ansi_term@0.11.0", wrappers = ["this-crate-directly-depends-on-ansi_term"] }, +] + +# List of features to allow/deny +# Each entry the name of a crate and a version range. If version is +# not specified, all versions will be matched. +#[[bans.features]] +#crate = "reqwest" +# Features to not allow +#deny = ["json"] +# Features to allow +#allow = [ +# "rustls", +# "__rustls", +# "__tls", +# "hyper-rustls", +# "rustls", +# "rustls-pemfile", +# "rustls-tls-webpki-roots", +# "tokio-rustls", +# "webpki-roots", +#] +# If true, the allowed features must exactly match the enabled feature set. If +# this is set there is no point setting `deny` +#exact = true + +# Certain crates/versions that will be skipped when doing duplicate detection. +skip = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason why it can't be updated/removed" }, +] +# Similarly to `skip` allows you to skip certain crates during duplicate +# detection. Unlike skip, it also includes the entire tree of transitive +# dependencies starting at the specified crate, up to a certain depth, which is +# by default infinite. 
+skip-tree = [ + #"ansi_term@0.11.0", # will be skipped along with _all_ of its direct and transitive dependencies + #{ crate = "ansi_term@0.11.0", depth = 20 }, +] + +# This section is considered when running `cargo deny check sources`. +# More documentation about the 'sources' section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html +[sources] +# Lint level for what to happen when a crate from a crate registry that is not +# in the allow list is encountered +unknown-registry = "warn" +# Lint level for what to happen when a crate from a git repository that is not +# in the allow list is encountered +unknown-git = "warn" +# List of URLs for allowed crate registries. Defaults to the crates.io index +# if not specified. If it is specified but empty, no registries are allowed. +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +# List of URLs for allowed Git repositories +allow-git = [] + +[sources.allow-org] +# github.com organizations to allow git sources for +github = [] +# gitlab.com organizations to allow git sources for +gitlab = [] +# bitbucket.org organizations to allow git sources for +bitbucket = [] diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..cbf8dd4 --- /dev/null +++ b/flake.lock @@ -0,0 +1,136 @@ +{ + "nodes": { + "advisory-db": { + "flake": false, + "locked": { + "lastModified": 1750151065, + "narHash": "sha256-il+CAqChFIB82xP6bO43dWlUVs+NlG7a4g8liIP5HcI=", + "owner": "rustsec", + "repo": "advisory-db", + "rev": "7573f55ba337263f61167dbb0ea926cdc7c8eb5d", + "type": "github" + }, + "original": { + "owner": "rustsec", + "repo": "advisory-db", + "type": "github" + } + }, + "crane": { + "locked": { + "lastModified": 1750266157, + "narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=", + "owner": "ipetkov", + "repo": "crane", + "rev": "e37c943371b73ed87faf33f7583860f81f1d5a48", + "type": "github" + }, + "original": { + "owner": "ipetkov", + "repo": "crane", + "type": 
"github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-github-actions": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1737420293, + "narHash": "sha256-F1G5ifvqTpJq7fdkT34e/Jy9VCyzd5XfJ9TO8fHhJWE=", + "owner": "nix-community", + "repo": "nix-github-actions", + "rev": "f4158fa080ef4503c8f4c820967d946c2af31ec9", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nix-github-actions", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1750134718, + "narHash": "sha256-v263g4GbxXv87hMXMCpjkIxd/viIF7p3JpJrwgKdNiI=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "9e83b64f727c88a7711a2c463a7b16eedb69a84c", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "advisory-db": "advisory-db", + "crane": "crane", + "flake-utils": "flake-utils", + "nix-github-actions": "nix-github-actions", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + } + }, + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1750214276, + "narHash": "sha256-1kniuhH70q4TAC/xIvjFYH46aHiLrbIlcr6fdrRwO1A=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "f9b2b2b1327ff6beab4662b8ea41689e0a57b8d4", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": 
"da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..156695f --- /dev/null +++ b/flake.nix @@ -0,0 +1,148 @@ +{ + description = "A simple rust flake using rust-overlay and craneLib"; + + inputs = { + nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + crane.url = "github:ipetkov/crane"; + nix-github-actions = { + url = "github:nix-community/nix-github-actions"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + advisory-db = { + url = "github:rustsec/advisory-db"; + flake = false; + }; + }; + + outputs = { + self, + crane, + flake-utils, + nixpkgs, + rust-overlay, + advisory-db, + nix-github-actions, + ... + }: + flake-utils.lib.eachDefaultSystem ( + system: let + pkgs = import nixpkgs { + inherit system; + overlays = [ + rust-overlay.overlays.default + ]; + }; + inherit (pkgs) lib; + cargoToml = builtins.fromTOML (builtins.readFile ./Cargo.toml); + name = cargoToml.package.name; + + stableToolchain = pkgs.rust-bin.stable.latest.default; + stableToolchainWithLLvmTools = stableToolchain.override { + extensions = ["rust-src" "llvm-tools"]; + }; + stableToolchainWithRustAnalyzer = stableToolchain.override { + extensions = ["rust-src" "rust-analyzer"]; + }; + craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain; + craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools; + + src = let + filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts; + sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"]; + in + lib.cleanSourceWith { + filter = sourceFilters; + src = ./.; + }; + commonArgs = + { 
+ inherit src; + pname = name; + stdenv = pkgs.clangStdenv; + doCheck = false; + # LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; + # nativeBuildInputs = with pkgs; [ + # cmake + # llvmPackages.libclang.lib + # ]; + buildInputs = with pkgs; + [] + ++ (lib.optionals pkgs.stdenv.isDarwin [ + libiconv + apple-sdk_13 + ]); + } + // (lib.optionalAttrs pkgs.stdenv.isLinux { + # BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include"; + }); + cargoArtifacts = craneLib.buildPackage commonArgs; + in { + checks = + { + "${name}-clippy" = craneLib.cargoClippy (commonArgs + // { + inherit cargoArtifacts; + cargoClippyExtraArgs = "--all-targets -- --deny warnings"; + }); + "${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;}); + "${name}-fmt" = craneLib.cargoFmt {inherit src;}; + "${name}-toml-fmt" = craneLib.taploFmt { + src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"]; + }; + # Audit dependencies + "${name}-audit" = craneLib.cargoAudit { + inherit src advisory-db; + }; + + # Audit licenses + "${name}-deny" = craneLib.cargoDeny { + inherit src; + }; + "${name}-nextest" = craneLib.cargoNextest (commonArgs + // { + inherit cargoArtifacts; + partitions = 1; + partitionType = "count"; + }); + } + // lib.optionalAttrs (!pkgs.stdenv.isDarwin) { + "${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;}); + }; + + packages = let + pkg = craneLib.buildPackage ( + commonArgs + // {inherit cargoArtifacts;} + ); + in { + "${name}" = pkg; + default = pkg; + }; + + devShells = { + default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs + // { + packages = with pkgs; + [ + stableToolchainWithRustAnalyzer + cargo-nextest + cargo-deny + ] + ++ (lib.optionals pkgs.stdenv.isDarwin [ + apple-sdk_13 + ]); + }); + }; + } + ) + // { + githubActions = nix-github-actions.lib.mkGithubMatrix { + checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks; + }; + }; +} diff --git 
a/src/cosine.rs b/src/cosine.rs
new file mode 100644
index 0000000..9799a97
--- /dev/null
+++ b/src/cosine.rs
@@ -0,0 +1,118 @@
+use ndarray::{ArrayBase, Ix1};
+
+/// Error returned when a cosine similarity cannot be computed.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
+pub enum CosineSimilarityError {
+    /// The two vectors have different lengths; `lhs` and `rhs` carry the
+    /// lengths of the left-hand and right-hand vectors respectively.
+    #[error(
+        "Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}"
+    )]
+    InvalidVectors { lhs: usize, rhs: usize },
+}
+
+/// Cosine similarity between `Self` and a right-hand side of type `Rhs`,
+/// yielding a scalar of type `T`.
+pub trait CosineSimilarity<T, Rhs = Self> {
+    /// Computes the cosine similarity between `self` and `rhs`.
+    ///
+    /// # Errors
+    ///
+    /// Returns a [`CosineSimilarityError`] if the vectors are invalid.
+    fn cosine_similarity(&self, rhs: Rhs) -> Result<T, CosineSimilarityError>;
+}
+
+impl<S1, S2, T> CosineSimilarity<T, ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix1>
+where
+    S1: ndarray::Data<Elem = T>,
+    S2: ndarray::Data<Elem = T>,
+    T: num::traits::Float + 'static,
+{
+    /// Returns `dot(self, rhs) / (norm(self) * norm(rhs))`, where `norm` is
+    /// the Euclidean norm.
+    ///
+    /// If either vector has zero norm the denominator is zero, so the result
+    /// is `Ok(NaN)` for `f32`/`f64` rather than an error.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`CosineSimilarityError::InvalidVectors`] when `self` and `rhs`
+    /// have different lengths.
+    ///
+    /// # Panics
+    ///
+    /// With debug assertions enabled, panics if either vector contains a
+    /// non-finite value.
+    fn cosine_similarity(&self, rhs: ArrayBase<S2, Ix1>) -> Result<T, CosineSimilarityError> {
+        if self.len() != rhs.len() {
+            return Err(CosineSimilarityError::InvalidVectors {
+                lhs: self.len(),
+                rhs: rhs.len(),
+            });
+        }
+        debug_assert!(
+            self.iter().all(|&x| x.is_finite()),
+            "LHS vector contains non-finite values"
+        );
+        debug_assert!(
+            rhs.iter().all(|&x| x.is_finite()),
+            "RHS vector contains non-finite values"
+        );
+        let numerator = self.dot(&rhs);
+        // Product of the two Euclidean norms.
+        let denominator = self.powi(2).sum().sqrt() * rhs.powi(2).sum().sqrt();
+        Ok(numerator / denominator)
+    }
+}
+
+#[cfg(test)]
+mod cosine_tests {
+    use super::*;
+    use ndarray::*;
+
+    #[test]
+    fn test_same_vectors() {
+        let a = array![1.0, 2.0, 3.0];
+        let b = array![1.0, 2.0, 3.0];
+        assert_eq!(a.cosine_similarity(b).unwrap(), 1.0);
+    }
+
+    #[test]
+    fn test_orthogonal_vectors() {
+        let a = array![1.0, 0.0, 0.0];
+        let b = array![0.0, 1.0, 0.0];
+        assert_eq!(a.cosine_similarity(b).unwrap(), 0.0);
+    }
+
+    #[test]
+    fn test_opposite_vectors() {
+        let a = array![1.0, 2.0, 3.0];
+        let b = array![-1.0, -2.0, -3.0];
+        assert_eq!(a.cosine_similarity(b).unwrap(), -1.0);
+    }
+
+    #[test]
+    fn test_invalid_vectors() {
+        let a = array![1.0, 2.0];
+        let b = array![1.0, 2.0, 3.0];
+        assert!(matches!(
+            a.cosine_similarity(b),
+            Err(CosineSimilarityError::InvalidVectors { lhs: 2, rhs: 3 })
+        ));
+    }
+
+    #[test]
+    fn test_zero_vector() {
+        let a = array![0.0, 0.0, 0.0];
+        let b = array![1.0, 2.0, 3.0];
+        let similarity = a.cosine_similarity(b);
+        assert!(similarity.is_ok_and(|item: f64| item.is_nan()));
+    }
+
+    #[test]
+    fn test_different_ndarray_types() {
+        let a = array![1.0, 2.0, 3.0];
+        let b = array![1.0, 2.0, 3.0];
+        assert_eq!(a.cosine_similarity(b.view()).unwrap(), 1.0);
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..1cbdf13
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,2 @@
+mod cosine;
+pub use cosine::{CosineSimilarity, CosineSimilarityError};