feat: initial commit with cosine similarity

This commit is contained in:
uttarayan21
2025-06-23 15:18:36 +05:30
commit 669d1bf568
11 changed files with 920 additions and 0 deletions

1
.envrc Normal file
View File

@@ -0,0 +1 @@
use flake

59
.github/workflows/build.yaml vendored Normal file
View File

@@ -0,0 +1,59 @@
name: build
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
env:
CARGO_TERM_COLOR: always
jobs:
checks-matrix:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
- uses: DeterminateSystems/nix-installer-action@main
- id: set-matrix
name: Generate Nix Matrix
run: |
set -Eeu
matrix="$(nix eval --json '.#githubActions.matrix')"
echo "matrix=$matrix" >> "$GITHUB_OUTPUT"
checks-build:
needs: checks-matrix
runs-on: ${{ matrix.os }}
strategy:
matrix: ${{fromJSON(needs.checks-matrix.outputs.matrix)}}
steps:
- uses: actions/checkout@v4
- uses: DeterminateSystems/nix-installer-action@main
- run: nix build -L '.#${{ matrix.attr }}'
codecov:
runs-on: ubuntu-latest
permissions:
id-token: "write"
contents: "read"
steps:
- uses: actions/checkout@v4
- uses: DeterminateSystems/nix-installer-action@main
- name: Run codecov
run: nix build .#checks.x86_64-linux.ndarray-math-llvm-cov
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
with:
flags: unittests
name: codecov-ndarray-math
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}
files: ./result
verbose: true

38
.github/workflows/docs.yaml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: docs
on:
push:
branches: [ master ]
env:
CARGO_TERM_COLOR: always
jobs:
docs:
runs-on: ubuntu-latest
permissions:
id-token: "write"
contents: "read"
pages: "write"
steps:
- uses: actions/checkout@v4
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: DeterminateSystems/flake-checker-action@main
- name: Generate docs
run: nix build .#checks.x86_64-linux.ndarray-math-docs
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: result/share/doc
- name: Deploy to gh-pages
id: deployment
uses: actions/deploy-pages@v4

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
/result
/target
.direnv

192
Cargo.lock generated Normal file
View File

@@ -0,0 +1,192 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "autocfg"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "matrixmultiply"
version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "ndarray"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]]
name = "ndarray-math"
version = "0.1.0"
dependencies = [
"ndarray",
"num",
"thiserror",
]
[[package]]
name = "num"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "portable-atomic"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
[[package]]
name = "portable-atomic-util"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
dependencies = [
"portable-atomic",
]
[[package]]
name = "proc-macro2"
version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "syn"
version = "2.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"

12
Cargo.toml Normal file
View File

@@ -0,0 +1,12 @@
[package]
name = "ndarray-math"
version = "0.1.0"
edition = "2024"
license = "MIT"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ndarray = "0.16"
num = "0.4.3"
thiserror = "2.0.12"

236
deny.toml Normal file
View File

@@ -0,0 +1,236 @@
# This template contains all of the possible sections and their default values
# Note that all fields that take a lint level have these possible values:
# * deny - An error will be produced and the check will fail
# * warn - A warning will be produced, but the check will not fail
# * allow - No warning or error will be produced, though in some cases a note
# will be
# The values provided in this template are the default values that will be used
# when any section or field is not specified in your own configuration
# Root options
# The graph table configures how the dependency graph is constructed and thus
# which crates the checks are performed against
[graph]
# If 1 or more target triples (and optionally, target_features) are specified,
# only the specified targets will be checked when running `cargo deny check`.
# This means, if a particular package is only ever used as a target specific
# dependency, such as, for example, the `nix` crate only being used via the
# `target_family = "unix"` configuration, that only having windows targets in
# this list would mean the nix crate, as well as any of its exclusive
# dependencies not shared by any other crates, would be ignored, as the target
# list here is effectively saying which targets you are building for.
targets = [
# The triple can be any string, but only the target triples built in to
# rustc (as of 1.40) can be checked against actual config expressions
#"x86_64-unknown-linux-musl",
# You can also specify which target_features you promise are enabled for a
# particular target. target_features are currently not validated against
# the actual valid features supported by the target architecture.
#{ triple = "wasm32-unknown-unknown", features = ["atomics"] },
]
# When creating the dependency graph used as the source of truth when checks are
# executed, this field can be used to prune crates from the graph, removing them
# from the view of cargo-deny. This is an extremely heavy hammer, as if a crate
# is pruned from the graph, all of its dependencies will also be pruned unless
# they are connected to another crate in the graph that hasn't been pruned,
# so it should be used with care. The identifiers are [Package ID Specifications]
# (https://doc.rust-lang.org/cargo/reference/pkgid-spec.html)
#exclude = []
# If true, metadata will be collected with `--all-features`. Note that this can't
# be toggled off if true, if you want to conditionally enable `--all-features` it
# is recommended to pass `--all-features` on the cmd line instead
all-features = false
# If true, metadata will be collected with `--no-default-features`. The same
# caveat with `all-features` applies
no-default-features = false
# If set, these feature will be enabled when collecting metadata. If `--features`
# is specified on the cmd line they will take precedence over this option.
#features = []
# The output table provides options for how/if diagnostics are outputted
[output]
# When outputting inclusion graphs in diagnostics that include features, this
# option can be used to specify the depth at which feature edges will be added.
# This option is included since the graphs can be quite large and the addition
# of features from the crate(s) to all of the graph roots can be far too verbose.
# This option can be overridden via `--feature-depth` on the cmd line
feature-depth = 1
# This section is considered when running `cargo deny check advisories`
# More documentation for the advisories section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html
[advisories]
# The path where the advisory databases are cloned/fetched into
#db-path = "$CARGO_HOME/advisory-dbs"
# The url(s) of the advisory databases to use
#db-urls = ["https://github.com/rustsec/advisory-db"]
# A list of advisory IDs to ignore. Note that ignored advisories will still
# output a note when they are encountered.
ignore = [
#"RUSTSEC-0000-0000",
#{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" },
#"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish
#{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" },
]
# If this is true, then cargo deny will use the git executable to fetch advisory database.
# If this is false, then it uses a built-in git library.
# Setting this to true can be helpful if you have special authentication requirements that cargo-deny does not support.
# See Git Authentication for more information about setting up git authentication.
#git-fetch-with-cli = true
# This section is considered when running `cargo deny check licenses`
# More documentation for the licenses section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html
[licenses]
# List of explicitly allowed licenses
# See https://spdx.org/licenses/ for list of possible licenses
# [possible values: any SPDX 3.11 short identifier (+ optional exception)].
allow = [
"MIT",
"Apache-2.0",
"Unicode-3.0",
#"Apache-2.0 WITH LLVM-exception",
]
# The confidence threshold for detecting a license from license text.
# The higher the value, the more closely the license text must be to the
# canonical license text of a valid SPDX license file.
# [possible values: any between 0.0 and 1.0].
confidence-threshold = 0.8
# Allow 1 or more licenses on a per-crate basis, so that particular licenses
# aren't accepted for every possible crate as with the normal allow list
exceptions = [
# Each entry is the crate and version constraint, and its specific allow
# list
#{ allow = ["Zlib"], crate = "adler32" },
]
# Some crates don't have (easily) machine readable licensing information,
# adding a clarification entry for it allows you to manually specify the
# licensing information
#[[licenses.clarify]]
# The package spec the clarification applies to
#crate = "ring"
# The SPDX expression for the license requirements of the crate
#expression = "MIT AND ISC AND OpenSSL"
# One or more files in the crate's source used as the "source of truth" for
# the license expression. If the contents match, the clarification will be used
# when running the license check, otherwise the clarification will be ignored
# and the crate will be checked normally, which may produce warnings or errors
# depending on the rest of your configuration
#license-files = [
# Each entry is a crate relative path, and the (opaque) hash of its contents
#{ path = "LICENSE", hash = 0xbd0eed23 }
#]
[licenses.private]
# If true, ignores workspace crates that aren't published, or are only
# published to private registries.
# To see how to mark a crate as unpublished (to the official registry),
# visit https://doc.rust-lang.org/cargo/reference/manifest.html#the-publish-field.
ignore = false
# One or more private registries that you might publish crates to, if a crate
# is only published to private registries, and ignore is true, the crate will
# not have its license(s) checked
registries = [
#"https://sekretz.com/registry
]
# This section is considered when running `cargo deny check bans`.
# More documentation about the 'bans' section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html
[bans]
# Lint level for when multiple versions of the same crate are detected
multiple-versions = "warn"
# Lint level for when a crate version requirement is `*`
wildcards = "allow"
# The graph highlighting used when creating dotgraphs for crates
# with multiple versions
# * lowest-version - The path to the lowest versioned duplicate is highlighted
# * simplest-path - The path to the version with the fewest edges is highlighted
# * all - Both lowest-version and simplest-path are used
highlight = "all"
# The default lint level for `default` features for crates that are members of
# the workspace that is being checked. This can be overridden by allowing/denying
# `default` on a crate-by-crate basis if desired.
workspace-default-features = "allow"
# The default lint level for `default` features for external crates that are not
# members of the workspace. This can be overridden by allowing/denying `default`
# on a crate-by-crate basis if desired.
external-default-features = "allow"
# List of crates that are allowed. Use with care!
allow = [
#"ansi_term@0.11.0",
#{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" },
]
# List of crates to deny
deny = [
#"ansi_term@0.11.0",
#{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is banned" },
# Wrapper crates can optionally be specified to allow the crate when it
# is a direct dependency of the otherwise banned crate
#{ crate = "ansi_term@0.11.0", wrappers = ["this-crate-directly-depends-on-ansi_term"] },
]
# List of features to allow/deny
# Each entry the name of a crate and a version range. If version is
# not specified, all versions will be matched.
#[[bans.features]]
#crate = "reqwest"
# Features to not allow
#deny = ["json"]
# Features to allow
#allow = [
# "rustls",
# "__rustls",
# "__tls",
# "hyper-rustls",
# "rustls",
# "rustls-pemfile",
# "rustls-tls-webpki-roots",
# "tokio-rustls",
# "webpki-roots",
#]
# If true, the allowed features must exactly match the enabled feature set. If
# this is set there is no point setting `deny`
#exact = true
# Certain crates/versions that will be skipped when doing duplicate detection.
skip = [
#"ansi_term@0.11.0",
#{ crate = "ansi_term@0.11.0", reason = "you can specify a reason why it can't be updated/removed" },
]
# Similarly to `skip` allows you to skip certain crates during duplicate
# detection. Unlike skip, it also includes the entire tree of transitive
# dependencies starting at the specified crate, up to a certain depth, which is
# by default infinite.
skip-tree = [
#"ansi_term@0.11.0", # will be skipped along with _all_ of its direct and transitive dependencies
#{ crate = "ansi_term@0.11.0", depth = 20 },
]
# This section is considered when running `cargo deny check sources`.
# More documentation about the 'sources' section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html
[sources]
# Lint level for what to happen when a crate from a crate registry that is not
# in the allow list is encountered
unknown-registry = "warn"
# Lint level for what to happen when a crate from a git repository that is not
# in the allow list is encountered
unknown-git = "warn"
# List of URLs for allowed crate registries. Defaults to the crates.io index
# if not specified. If it is specified but empty, no registries are allowed.
allow-registry = ["https://github.com/rust-lang/crates.io-index"]
# List of URLs for allowed Git repositories
allow-git = []
[sources.allow-org]
# github.com organizations to allow git sources for
github = []
# gitlab.com organizations to allow git sources for
gitlab = []
# bitbucket.org organizations to allow git sources for
bitbucket = []

136
flake.lock generated Normal file
View File

@@ -0,0 +1,136 @@
{
"nodes": {
"advisory-db": {
"flake": false,
"locked": {
"lastModified": 1750151065,
"narHash": "sha256-il+CAqChFIB82xP6bO43dWlUVs+NlG7a4g8liIP5HcI=",
"owner": "rustsec",
"repo": "advisory-db",
"rev": "7573f55ba337263f61167dbb0ea926cdc7c8eb5d",
"type": "github"
},
"original": {
"owner": "rustsec",
"repo": "advisory-db",
"type": "github"
}
},
"crane": {
"locked": {
"lastModified": 1750266157,
"narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=",
"owner": "ipetkov",
"repo": "crane",
"rev": "e37c943371b73ed87faf33f7583860f81f1d5a48",
"type": "github"
},
"original": {
"owner": "ipetkov",
"repo": "crane",
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nix-github-actions": {
"inputs": {
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1737420293,
"narHash": "sha256-F1G5ifvqTpJq7fdkT34e/Jy9VCyzd5XfJ9TO8fHhJWE=",
"owner": "nix-community",
"repo": "nix-github-actions",
"rev": "f4158fa080ef4503c8f4c820967d946c2af31ec9",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "nix-github-actions",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1750134718,
"narHash": "sha256-v263g4GbxXv87hMXMCpjkIxd/viIF7p3JpJrwgKdNiI=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "9e83b64f727c88a7711a2c463a7b16eedb69a84c",
"type": "github"
},
"original": {
"owner": "nixos",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"advisory-db": "advisory-db",
"crane": "crane",
"flake-utils": "flake-utils",
"nix-github-actions": "nix-github-actions",
"nixpkgs": "nixpkgs",
"rust-overlay": "rust-overlay"
}
},
"rust-overlay": {
"inputs": {
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1750214276,
"narHash": "sha256-1kniuhH70q4TAC/xIvjFYH46aHiLrbIlcr6fdrRwO1A=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "f9b2b2b1327ff6beab4662b8ea41689e0a57b8d4",
"type": "github"
},
"original": {
"owner": "oxalica",
"repo": "rust-overlay",
"type": "github"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

148
flake.nix Normal file
View File

@@ -0,0 +1,148 @@
{
description = "A simple rust flake using rust-overlay and craneLib";
inputs = {
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
crane.url = "github:ipetkov/crane";
nix-github-actions = {
url = "github:nix-community/nix-github-actions";
inputs.nixpkgs.follows = "nixpkgs";
};
rust-overlay = {
url = "github:oxalica/rust-overlay";
inputs.nixpkgs.follows = "nixpkgs";
};
advisory-db = {
url = "github:rustsec/advisory-db";
flake = false;
};
};
outputs = {
self,
crane,
flake-utils,
nixpkgs,
rust-overlay,
advisory-db,
nix-github-actions,
...
}:
flake-utils.lib.eachDefaultSystem (
system: let
pkgs = import nixpkgs {
inherit system;
overlays = [
rust-overlay.overlays.default
];
};
inherit (pkgs) lib;
cargoToml = builtins.fromTOML (builtins.readFile ./Cargo.toml);
name = cargoToml.package.name;
stableToolchain = pkgs.rust-bin.stable.latest.default;
stableToolchainWithLLvmTools = stableToolchain.override {
extensions = ["rust-src" "llvm-tools"];
};
stableToolchainWithRustAnalyzer = stableToolchain.override {
extensions = ["rust-src" "rust-analyzer"];
};
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
src = let
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
in
lib.cleanSourceWith {
filter = sourceFilters;
src = ./.;
};
commonArgs =
{
inherit src;
pname = name;
stdenv = pkgs.clangStdenv;
doCheck = false;
# LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
# nativeBuildInputs = with pkgs; [
# cmake
# llvmPackages.libclang.lib
# ];
buildInputs = with pkgs;
[]
++ (lib.optionals pkgs.stdenv.isDarwin [
libiconv
apple-sdk_13
]);
}
// (lib.optionalAttrs pkgs.stdenv.isLinux {
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
});
cargoArtifacts = craneLib.buildPackage commonArgs;
in {
checks =
{
"${name}-clippy" = craneLib.cargoClippy (commonArgs
// {
inherit cargoArtifacts;
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
});
"${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
"${name}-fmt" = craneLib.cargoFmt {inherit src;};
"${name}-toml-fmt" = craneLib.taploFmt {
src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"];
};
# Audit dependencies
"${name}-audit" = craneLib.cargoAudit {
inherit src advisory-db;
};
# Audit licenses
"${name}-deny" = craneLib.cargoDeny {
inherit src;
};
"${name}-nextest" = craneLib.cargoNextest (commonArgs
// {
inherit cargoArtifacts;
partitions = 1;
partitionType = "count";
});
}
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
};
packages = let
pkg = craneLib.buildPackage (
commonArgs
// {inherit cargoArtifacts;}
);
in {
"${name}" = pkg;
default = pkg;
};
devShells = {
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
// {
packages = with pkgs;
[
stableToolchainWithRustAnalyzer
cargo-nextest
cargo-deny
]
++ (lib.optionals pkgs.stdenv.isDarwin [
apple-sdk_13
]);
});
};
}
)
// {
githubActions = nix-github-actions.lib.mkGithubMatrix {
checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks;
};
};
}

93
src/cosine.rs Normal file
View File

@@ -0,0 +1,93 @@
use ndarray::{ArrayBase, Ix1};
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum CosineSimilarityError {
#[error(
"Invalid vectors: Vectors must have the same length for similarity calculation. LHS: {lhs}, RHS: {rhs}"
)]
InvalidVectors { lhs: usize, rhs: usize },
}
pub trait CosineSimilarity<T, Rhs = Self> {
/// Computes the cosine similarity between two vectors.
///
/// A `Result` containing the cosine similarity as a `f64`, or an error if the vectors are invalid.
fn cosine_similarity(&self, rhs: Rhs) -> Result<T, CosineSimilarityError>;
}
impl<S1, S2, T> CosineSimilarity<T, ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix1>
where
S1: ndarray::Data<Elem = T>,
S2: ndarray::Data<Elem = T>,
T: num::traits::Float + 'static,
{
fn cosine_similarity(&self, rhs: ArrayBase<S2, Ix1>) -> Result<T, CosineSimilarityError> {
if self.len() != rhs.len() {
return Err(CosineSimilarityError::InvalidVectors {
lhs: self.len(),
rhs: rhs.len(),
});
}
debug_assert!(
self.iter().all(|&x| x.is_finite()),
"LHS vector contains non-finite values"
);
debug_assert!(
rhs.iter().all(|&x| x.is_finite()),
"RHS vector contains non-finite values"
);
let numerator = self.dot(&rhs);
let denominator = self.powi(2).sum().sqrt() * rhs.powi(2).sum().sqrt();
Ok(numerator / denominator)
}
}
#[cfg(test)]
mod cosine_tests {
use super::*;
use ndarray::*;
#[test]
fn test_same_vectors() {
let a = array![1.0, 2.0, 3.0];
let b = array![1.0, 2.0, 3.0];
assert_eq!(a.cosine_similarity(b).unwrap(), 1.0);
}
#[test]
fn test_orthogonal_vectors() {
let a = array![1.0, 0.0, 0.0];
let b = array![0.0, 1.0, 0.0];
assert_eq!(a.cosine_similarity(b).unwrap(), 0.0);
}
#[test]
fn test_opposite_vectors() {
let a = array![1.0, 2.0, 3.0];
let b = array![-1.0, -2.0, -3.0];
assert_eq!(a.cosine_similarity(b).unwrap(), -1.0);
}
#[test]
fn test_invalid_vectors() {
let a = array![1.0, 2.0];
let b = array![1.0, 2.0, 3.0];
assert!(matches!(
a.cosine_similarity(b),
Err(CosineSimilarityError::InvalidVectors { lhs: 2, rhs: 3 })
));
}
#[test]
fn test_zero_vector() {
let a = array![0.0, 0.0, 0.0];
let b = array![1.0, 2.0, 3.0];
let similarity = a.cosine_similarity(b);
assert!(similarity.is_ok_and(|item: f64| item.is_nan()));
}
#[test]
fn test_different_ndarray_types() {
let a = array![1.0, 2.0, 3.0];
let b = array![1.0, 2.0, 3.0];
assert_eq!(a.cosine_similarity(b.view()).unwrap(), 1.0);
}
}

2
src/lib.rs Normal file
View File

@@ -0,0 +1,2 @@
mod cosine;
pub use cosine::{CosineSimilarity, CosineSimilarityError};