feat(compare): add face comparison functionality with cosine similarity
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m18s
docs / docs (push) Failing after 28m50s
build / checks-build (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-21 17:34:07 +05:30
parent f8122892e0
commit bfa389b497
15 changed files with 1188 additions and 107 deletions

331
Cargo.lock generated
View File

@@ -139,7 +139,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -148,6 +148,17 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi",
"libc",
"winapi",
]
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.5.0" version = "1.5.0"
@@ -192,6 +203,29 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "bindgen"
version = "0.60.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "062dddbc1ba4aca46de6338e2bf87771414c335f7b2f2036e8f3e9befebf88e6"
dependencies = [
"bitflags 1.3.2",
"cexpr",
"clang-sys",
"clap 3.2.25",
"env_logger",
"lazy_static",
"lazycell",
"log",
"peeking_take_while",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"which",
]
[[package]] [[package]]
name = "bindgen" name = "bindgen"
version = "0.70.1" version = "0.70.1"
@@ -210,7 +244,7 @@ dependencies = [
"regex", "regex",
"rustc-hash", "rustc-hash",
"shlex", "shlex",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -281,7 +315,7 @@ checksum = "4f154e572231cb6ba2bd1176980827e3d5dc04cc183a75dea38109fbdd672d29"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -349,6 +383,21 @@ dependencies = [
"libloading", "libloading",
] ]
[[package]]
name = "clap"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
dependencies = [
"atty",
"bitflags 1.3.2",
"clap_lex 0.2.4",
"indexmap 1.9.3",
"strsim 0.10.0",
"termcolor",
"textwrap",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.45" version = "4.5.45"
@@ -367,8 +416,8 @@ checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
"clap_lex", "clap_lex 0.7.5",
"strsim", "strsim 0.11.1",
] ]
[[package]] [[package]]
@@ -377,7 +426,7 @@ version = "4.5.57"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad" checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad"
dependencies = [ dependencies = [
"clap", "clap 4.5.45",
] ]
[[package]] [[package]]
@@ -389,7 +438,16 @@ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
]
[[package]]
name = "clap_lex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
dependencies = [
"os_str_bytes",
] ]
[[package]] [[package]]
@@ -505,7 +563,7 @@ name = "detector"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bounding-box", "bounding-box",
"clap", "clap 4.5.45",
"clap_complete", "clap_complete",
"color", "color",
"error-stack", "error-stack",
@@ -518,12 +576,13 @@ dependencies = [
"nalgebra", "nalgebra",
"ndarray", "ndarray",
"ndarray-image", "ndarray-image",
"ndarray-math", "ndarray-math 0.1.0 (git+https://git.darksailor.dev/servius/ndarray-math)",
"ndarray-resize", "ndarray-resize",
"ndarray-safetensors", "ndarray-safetensors",
"ordered-float", "ordered-float",
"ort", "ort",
"rusqlite", "rusqlite",
"sqlite3-safetensor-cosine",
"tap", "tap",
"thiserror 2.0.15", "thiserror 2.0.15",
"tokio", "tokio",
@@ -548,7 +607,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -572,6 +631,19 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "env_logger"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7"
dependencies = [
"atty",
"humantime",
"log",
"regex",
"termcolor",
]
[[package]] [[package]]
name = "equator" name = "equator"
version = "0.4.2" version = "0.4.2"
@@ -589,7 +661,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -598,6 +670,16 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "errno"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
dependencies = [
"libc",
"windows-sys 0.60.2",
]
[[package]] [[package]]
name = "error-stack" name = "error-stack"
version = "0.5.0" version = "0.5.0"
@@ -851,6 +933,12 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.5" version = "0.15.5"
@@ -866,7 +954,7 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
dependencies = [ dependencies = [
"hashbrown", "hashbrown 0.15.5",
] ]
[[package]] [[package]]
@@ -875,6 +963,30 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "humantime"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f"
[[package]] [[package]]
name = "iana-time-zone" name = "iana-time-zone"
version = "0.1.63" version = "0.1.63"
@@ -1045,6 +1157,16 @@ version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
[[package]]
name = "indexmap"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.10.0" version = "2.10.0"
@@ -1052,7 +1174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown", "hashbrown 0.15.5",
] ]
[[package]] [[package]]
@@ -1063,7 +1185,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -1137,7 +1259,7 @@ checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -1172,6 +1294,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]] [[package]]
name = "lebe" name = "lebe"
version = "0.5.2" version = "0.5.2"
@@ -1214,6 +1342,12 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]] [[package]]
name = "litemap" name = "litemap"
version = "0.8.0" version = "0.8.0"
@@ -1355,7 +1489,7 @@ version = "0.1.0"
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5" source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen", "bindgen 0.70.1",
"cc", "cc",
"cmake", "cmake",
"diffy", "diffy",
@@ -1431,6 +1565,15 @@ dependencies = [
"ndarray", "ndarray",
] ]
[[package]]
name = "ndarray-math"
version = "0.1.0"
dependencies = [
"ndarray",
"num",
"thiserror 2.0.15",
]
[[package]] [[package]]
name = "ndarray-math" name = "ndarray-math"
version = "0.1.0" version = "0.1.0"
@@ -1552,7 +1695,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -1652,6 +1795,12 @@ dependencies = [
"pkg-config", "pkg-config",
] ]
[[package]]
name = "os_str_bytes"
version = "6.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"
@@ -1664,6 +1813,12 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.1" version = "2.3.1"
@@ -1741,7 +1896,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -1769,7 +1924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b" checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b"
dependencies = [ dependencies = [
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2000,6 +2155,19 @@ dependencies = [
"semver", "semver",
] ]
[[package]]
name = "rustix"
version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.9.2",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.22" version = "1.0.22"
@@ -2060,7 +2228,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2154,18 +2322,79 @@ dependencies = [
"lock_api", "lock_api",
] ]
[[package]]
name = "sqlite-loadable"
version = "0.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a916b7bb8738eef189dea88731b619b80bf3f62b3acf05138fa43fbf8621cc94"
dependencies = [
"bitflags 1.3.2",
"serde",
"serde_json",
"sqlite-loadable-macros",
"sqlite3ext-sys",
]
[[package]]
name = "sqlite-loadable-macros"
version = "0.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f98bc75a8d6fd24f6a2cfea34f28758780fa17279d3051eec926efa381971e48"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "sqlite3-safetensor-cosine"
version = "0.1.0"
dependencies = [
"ndarray",
"ndarray-math 0.1.0",
"ndarray-safetensors",
"sqlite-loadable",
]
[[package]]
name = "sqlite3ext-sys"
version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3afdc2b3dc08f16d6eecf8aa07d19975a268603ab1cca67d3f9b4172c507cf16"
dependencies = [
"bindgen 0.60.1",
"cc",
]
[[package]] [[package]]
name = "stable_deref_trait" name = "stable_deref_trait"
version = "1.2.0" version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.11.1" version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.106" version = "2.0.106"
@@ -2185,7 +2414,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2213,6 +2442,21 @@ version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
dependencies = [
"winapi-util",
]
[[package]]
name = "textwrap"
version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.69" version = "1.0.69"
@@ -2239,7 +2483,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2250,7 +2494,7 @@ checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2355,7 +2599,7 @@ version = "0.22.27"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
dependencies = [ dependencies = [
"indexmap", "indexmap 2.10.0",
"serde", "serde",
"serde_spanned", "serde_spanned",
"toml_datetime", "toml_datetime",
@@ -2381,7 +2625,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2540,7 +2784,7 @@ dependencies = [
"log", "log",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@@ -2562,7 +2806,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@@ -2582,6 +2826,18 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3" checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]] [[package]]
name = "wide" name = "wide"
version = "0.7.33" version = "0.7.33"
@@ -2608,6 +2864,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0978bf7171b3d90bac376700cb56d606feb40f251a475a5d6634613564460b22"
dependencies = [
"windows-sys 0.60.2",
]
[[package]] [[package]]
name = "winapi-x86_64-pc-windows-gnu" name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0" version = "0.4.0"
@@ -2635,7 +2900,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2646,7 +2911,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2882,7 +3147,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
"synstructure", "synstructure",
] ]
@@ -2903,7 +3168,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]
@@ -2923,7 +3188,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
"synstructure", "synstructure",
] ]
@@ -2957,7 +3222,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.106",
] ]
[[package]] [[package]]

View File

@@ -1,5 +1,5 @@
[workspace] [workspace]
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors"] members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors", "sqlite3-safetensor-cosine"]
[workspace.package] [workspace.package]
version = "0.1.0" version = "0.1.0"
@@ -53,6 +53,7 @@ ordered-float = "5.0.0"
ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]} ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]}
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" } ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" } ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
sqlite3-safetensor-cosine = { version = "0.1.0", path = "sqlite3-safetensor-cosine" }
[profile.release] [profile.release]
debug = true debug = true
@@ -67,4 +68,4 @@ ort-directml = ["ort/directml"]
mnn-metal = ["mnn/metal"] mnn-metal = ["mnn/metal"]
mnn-coreml = ["mnn/coreml"] mnn-coreml = ["mnn/coreml"]
default = [] default = ["mnn-metal","mnn-coreml"]

View File

@@ -55,6 +55,35 @@ cargo run --release detect --output detected.jpg path/to/image.jpg
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
``` ```
### Face Comparison
Compare faces between two images by computing and comparing their embeddings:
```bash
# Compare faces in two images
cargo run --release compare image1.jpg image2.jpg
# Compare with custom thresholds
cargo run --release compare --threshold 0.9 --nms-threshold 0.4 image1.jpg image2.jpg
# Use ONNX Runtime backend for comparison
cargo run --release compare -p cpu image1.jpg image2.jpg
# Use MNN with Metal acceleration
cargo run --release compare -f metal image1.jpg image2.jpg
```
The compare command will:
1. Detect all faces in both images
2. Generate embeddings for each detected face
3. Compute cosine similarity between all face pairs
4. Display similarity scores and the best match
5. Provide interpretation of the similarity scores:
- **> 0.8**: Very likely the same person
- **0.6-0.8**: Possibly the same person
- **0.4-0.6**: Unlikely to be the same person
- **< 0.4**: Very unlikely to be the same person
### Backend Selection ### Backend Selection
The project supports two inference backends: The project supports two inference backends:

View File

@@ -2,9 +2,6 @@
description = "A simple rust flake using rust-overlay and craneLib"; description = "A simple rust flake using rust-overlay and craneLib";
inputs = { inputs = {
self = {
lfs = true;
};
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
crane.url = "github:ipetkov/crane"; crane.url = "github:ipetkov/crane";
@@ -206,6 +203,7 @@
packages = with pkgs; packages = with pkgs;
[ [
stableToolchainWithRustAnalyzer stableToolchainWithRustAnalyzer
cargo-expand
cargo-nextest cargo-nextest
cargo-deny cargo-deny
cmake cmake

View File

@@ -68,6 +68,7 @@ use safetensors::tensor::SafeTensors;
/// let view = SafeArrayView::from_bytes(&bytes).unwrap(); /// let view = SafeArrayView::from_bytes(&bytes).unwrap();
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap(); /// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
/// ``` /// ```
#[derive(Debug)]
pub struct SafeArraysView<'a> { pub struct SafeArraysView<'a> {
pub tensors: SafeTensors<'a>, pub tensors: SafeTensors<'a>,
} }

View File

@@ -0,0 +1,14 @@
[package]
name = "sqlite3-safetensor-cosine"
version.workspace = true
edition.workspace = true
[lib]
crate-type = ["cdylib", "staticlib"]
[dependencies]
ndarray = "0.16.1"
# ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
sqlite-loadable = "0.0.5"

View File

@@ -0,0 +1,61 @@
use sqlite_loadable::prelude::*;
use sqlite_loadable::{Error, ErrorKind};
use sqlite_loadable::{Result, api, define_scalar_function};
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
#[inline(always)]
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
}
if values.len() != 2 {
return Err(Error::new(ErrorKind::Message(
"cosine_similarity requires exactly 2 arguments".to_string(),
)));
}
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
let array_1_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
let array_2_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
let array_view_1 = array_1_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
let array_view_2 = array_2_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
use ndarray_math::*;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(custom_error)?;
api::result_double(context, similarity as f64);
Ok(())
}
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
define_scalar_function(
db,
"cosine_similarity",
2,
cosine_similarity,
FunctionFlags::DETERMINISTIC,
)?;
Ok(())
}
/// # Safety
///
/// Should only be called by underlying SQLite C APIs,
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sqlite3_extension_init(
db: *mut sqlite3,
pz_err_msg: *mut *mut c_char,
p_api: *mut sqlite3_api_routines,
) -> c_uint {
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
}

View File

@@ -1,6 +1,5 @@
use std::path::PathBuf; use std::path::PathBuf;
use mnn::ForwardType;
#[derive(Debug, clap::Parser)] #[derive(Debug, clap::Parser)]
pub struct Cli { pub struct Cli {
#[clap(subcommand)] #[clap(subcommand)]
@@ -11,14 +10,16 @@ pub struct Cli {
pub enum SubCommand { pub enum SubCommand {
#[clap(name = "detect")] #[clap(name = "detect")]
Detect(Detect), Detect(Detect),
#[clap(name = "list")] #[clap(name = "detect-multi")]
List(List), DetectMulti(DetectMulti),
#[clap(name = "query")] #[clap(name = "query")]
Query(Query), Query(Query),
#[clap(name = "similar")] #[clap(name = "similar")]
Similar(Similar), Similar(Similar),
#[clap(name = "stats")] #[clap(name = "stats")]
Stats(Stats), Stats(Stats),
#[clap(name = "compare")]
Compare(Compare),
#[clap(name = "completions")] #[clap(name = "completions")]
Completions { shell: clap_complete::Shell }, Completions { shell: clap_complete::Shell },
} }
@@ -74,7 +75,47 @@ pub struct Detect {
} }
#[derive(Debug, clap::Args)] #[derive(Debug, clap::Args)]
pub struct List {} pub struct DetectMulti {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(short, long)]
pub output_dir: Option<PathBuf>,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(
long,
help = "Image extensions to process (e.g., jpg,png,jpeg)",
default_value = "jpg,jpeg,png,bmp,tiff,webp"
)]
pub extensions: String,
#[clap(help = "Directory containing images to process")]
pub input_dir: PathBuf,
}
#[derive(Debug, clap::Args)] #[derive(Debug, clap::Args)]
pub struct Query { pub struct Query {
@@ -108,6 +149,41 @@ pub struct Stats {
pub database: PathBuf, pub database: PathBuf,
} }
#[derive(Debug, clap::Args)]
pub struct Compare {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(help = "First image to compare")]
pub image1: PathBuf,
#[clap(help = "Second image to compare")]
pub image2: PathBuf,
}
impl Cli { impl Cli {
pub fn completions(shell: clap_complete::Shell) { pub fn completions(shell: clap_complete::Shell) {
let mut command = <Cli as clap::CommandFactory>::command(); let mut command = <Cli as clap::CommandFactory>::command();

View File

@@ -65,7 +65,14 @@ impl FaceDatabase {
/// Create a new database connection and initialize tables /// Create a new database connection and initialize tables
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> { pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let conn = Connection::open(db_path).change_context(Error)?; let conn = Connection::open(db_path).change_context(Error)?;
add_sqlite_cosine_similarity(&conn).change_context(Error)?; unsafe {
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
conn.load_extension(
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
None::<&str>,
)
.change_context(Error)?;
}
let db = Self { conn }; let db = Self { conn };
db.create_tables()?; db.create_tables()?;
Ok(db) Ok(db)
@@ -190,10 +197,9 @@ impl FaceDatabase {
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)") .prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
.change_context(Error)?; .change_context(Error)?;
stmt.execute(params![file_path, width, height]) Ok(stmt
.change_context(Error)?; .insert(params![file_path, width, height])
.change_context(Error)?)
Ok(self.conn.last_insert_rowid())
} }
/// Store face detection results /// Store face detection results
@@ -231,17 +237,16 @@ impl FaceDatabase {
) )
.change_context(Error)?; .change_context(Error)?;
stmt.execute(params![ Ok(stmt
image_id, .insert(params![
bbox.x1() as f32, image_id,
bbox.y1() as f32, bbox.x1() as f32,
bbox.x2() as f32, bbox.y1() as f32,
bbox.y2() as f32, bbox.x2() as f32,
confidence bbox.y2() as f32,
]) confidence
.change_context(Error)?; ])
.change_context(Error)?)
Ok(self.conn.last_insert_rowid())
} }
/// Store face landmarks /// Store face landmarks
@@ -258,22 +263,21 @@ impl FaceDatabase {
) )
.change_context(Error)?; .change_context(Error)?;
stmt.execute(params![ Ok(stmt
face_id, .insert(params![
landmarks.left_eye.x, face_id,
landmarks.left_eye.y, landmarks.left_eye.x,
landmarks.right_eye.x, landmarks.left_eye.y,
landmarks.right_eye.y, landmarks.right_eye.x,
landmarks.nose.x, landmarks.right_eye.y,
landmarks.nose.y, landmarks.nose.x,
landmarks.left_mouth.x, landmarks.nose.y,
landmarks.left_mouth.y, landmarks.left_mouth.x,
landmarks.right_mouth.x, landmarks.left_mouth.y,
landmarks.right_mouth.y, landmarks.right_mouth.x,
]) landmarks.right_mouth.y,
.change_context(Error)?; ])
.change_context(Error)?)
Ok(self.conn.last_insert_rowid())
} }
/// Store face embeddings /// Store face embeddings
@@ -310,12 +314,12 @@ impl FaceDatabase {
embedding: ndarray::ArrayView1<f32>, embedding: ndarray::ArrayView1<f32>,
model_name: &str, model_name: &str,
) -> Result<i64> { ) -> Result<i64> {
let embedding_bytes = let safe_arrays =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)]) ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
.change_context(Error)?
.serialize()
.change_context(Error)?; .change_context(Error)?;
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)") .prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
@@ -462,6 +466,35 @@ impl FaceDatabase {
Ok(embeddings) Ok(embeddings)
} }
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT images.id, images.file_path, images.width, images.height, images.created_at
FROM images
JOIN faces ON faces.image_id = images.id
WHERE faces.id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get database statistics /// Get database statistics
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> { pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
let images: usize = self let images: usize = self
@@ -528,6 +561,39 @@ impl FaceDatabase {
Ok(result) Ok(result)
} }
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)
.unwrap()
.serialize()
.change_context(Error)
.unwrap();
let mut stmt = self
.conn
.prepare(
r#"
SELECT face_id,
cosine_similarity(?1, embedding)
FROM embeddings
"#,
)
.change_context(Error)
.unwrap();
let result_iter = stmt
.query_map(params![embedding_bytes], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)
.unwrap();
for result in result_iter {
println!("{:?}", result);
}
}
} }
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> { fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {

View File

@@ -310,7 +310,7 @@ pub trait FaceDetector {
fn detect_faces( fn detect_faces(
&mut self, &mut self,
image: ndarray::ArrayView3<u8>, image: ndarray::ArrayView3<u8>,
config: FaceDetectionConfig, config: &FaceDetectionConfig,
) -> Result<FaceDetectionOutput> { ) -> Result<FaceDetectionOutput> {
let (height, width, _channels) = image.dim(); let (height, width, _channels) = image.dim();
let output = self let output = self

View File

@@ -11,6 +11,23 @@ pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
use crate::errors::*; use crate::errors::*;
use ndarray::{Array2, ArrayView4}; use ndarray::{Array2, ArrayView4};
pub mod preprocessing {
use ndarray::*;
pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
let mut owned = faces.as_standard_layout().mapv(|v| v as f32).to_owned();
owned.axis_iter_mut(Axis(0)).for_each(|mut image| {
let mean = image.mean().unwrap_or(0.0);
let std = image.std(0.0);
if std > 0.0 {
image.mapv_inplace(|x| (x - mean) / std);
} else {
image.mapv_inplace(|x| (x - 127.5) / 128.0)
}
});
owned
}
}
/// Common trait for face embedding backends - maintained for backward compatibility /// Common trait for face embedding backends - maintained for backward compatibility
pub trait FaceEmbedder { pub trait FaceEmbedder {
/// Generate embeddings for a batch of face images /// Generate embeddings for a batch of face images

View File

@@ -4,6 +4,7 @@ pub mod ort;
use crate::errors::*; use crate::errors::*;
use error_stack::ResultExt; use error_stack::ResultExt;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use ndarray_math::{CosineSimilarity, EuclideanDistance};
/// Configuration for face embedding processing /// Configuration for face embedding processing
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@@ -32,9 +33,9 @@ impl FaceEmbeddingConfig {
impl Default for FaceEmbeddingConfig { impl Default for FaceEmbeddingConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
input_width: 160, input_width: 320,
input_height: 160, input_height: 320,
normalize: true, normalize: false,
} }
} }
} }
@@ -63,15 +64,14 @@ impl FaceEmbedding {
/// Calculate cosine similarity with another embedding /// Calculate cosine similarity with another embedding
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 { pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
let dot_product = self.vector.dot(&other.vector); self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
let norm_self = self.vector.mapv(|x| x * x).sum().sqrt();
let norm_other = other.vector.mapv(|x| x * x).sum().sqrt();
dot_product / (norm_self * norm_other)
} }
/// Calculate Euclidean distance with another embedding /// Calculate Euclidean distance with another embedding
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 { pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
(&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt() self.vector
.euclidean_distance(other.vector.view())
.unwrap_or(f32::INFINITY)
} }
/// Normalize the embedding vector to unit length /// Normalize the embedding vector to unit length

View File

@@ -64,10 +64,7 @@ impl EmbeddingGenerator {
} }
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> { pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
let tensor = face let tensor = crate::faceembed::preprocessing::preprocess(face);
// .permuted_axes((0, 3, 1, 2))
.as_standard_layout()
.mapv(|x| x as f32);
let shape: [usize; 4] = tensor.dim().into(); let shape: [usize; 4] = tensor.dim().into();
let shape = shape.map(|f| f as i32); let shape = shape.map(|f| f as i32);
let output = self let output = self

View File

@@ -135,10 +135,12 @@ impl EmbeddingGenerator {
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> { pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// Convert input from u8 to f32 and normalize to [0, 1] range // Convert input from u8 to f32 and normalize to [0, 1] range
let input_tensor = faces let input_tensor = crate::faceembed::preprocessing::preprocess(faces);
.mapv(|x| x as f32 / 255.0)
.as_standard_layout() // face_array = np.asarray(face_resized, 'float32')
.into_owned(); // mean, std = face_array.mean(), face_array.std()
// face_normalized = (face_array - mean) / std
// let input_tensor = faces.mean()
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape()); tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());

View File

@@ -75,8 +75,61 @@ pub fn main() -> Result<()> {
} }
} }
} }
cli::SubCommand::List(list) => { cli::SubCommand::DetectMulti(detect_multi) => {
println!("List: {:?}", list); // Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect_multi
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if detect_multi.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(
detect_multi.ort_execution_provider.clone(),
))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_multi_detection(detect_multi, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_multi_detection(detect_multi, retinaface, facenet)?;
}
}
} }
cli::SubCommand::Query(query) => { cli::SubCommand::Query(query) => {
run_query(query)?; run_query(query)?;
@@ -87,6 +140,59 @@ pub fn main() -> Result<()> {
cli::SubCommand::Stats(stats) => { cli::SubCommand::Stats(stats) => {
run_stats(stats)?; run_stats(stats)?;
} }
cli::SubCommand::Compare(compare) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = compare
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if compare.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(compare.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_compare(compare, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_compare(compare, retinaface, facenet)?;
}
}
}
cli::SubCommand::Completions { shell } => { cli::SubCommand::Completions { shell } => {
cli::Cli::completions(shell); cli::Cli::completions(shell);
} }
@@ -122,7 +228,7 @@ where
let output = retinaface let output = retinaface
.detect_faces( .detect_faces(
array.view(), array.view(),
FaceDetectionConfig::default() &FaceDetectionConfig::default()
.with_threshold(detect.threshold) .with_threshold(detect.threshold)
.with_nms_threshold(detect.nms_threshold), .with_nms_threshold(detect.nms_threshold),
) )
@@ -163,7 +269,7 @@ where
// }) // })
.map(|roi| { .map(|roi| {
roi.as_standard_layout() roi.as_standard_layout()
.fast_resize(160, 160, &ResizeOptions::default()) .fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error) .change_context(Error)
}) })
// .inspect(|f| { // .inspect(|f| {
@@ -182,11 +288,14 @@ where
if chunk.len() < chunk_size { if chunk.len() < chunk_size {
tracing::warn!("Chunk size is less than 8, padding with zeros"); tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((160, 160, 3)); let zeros = Array3::zeros((320, 320, 3));
let zero_array = core::iter::repeat(zeros.view()) let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size) .take(chunk_size)
.collect::<Vec<_>>(); .collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice()) let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to stack rois together")?; .attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?; let output = facenet.run_models(face_rois.view()).change_context(Error)?;
@@ -328,6 +437,446 @@ fn run_query(query: cli::Query) -> Result<()> {
Ok(()) Ok(())
} }
fn run_compare<D, E>(compare: cli::Compare, mut retinaface: D, mut facenet: E) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
// Helper function to detect faces and compute embeddings for an image
fn process_image<D, E>(
image_path: &std::path::Path,
retinaface: &mut D,
facenet: &mut E,
config: &FaceDetectionConfig,
batch_size: usize,
) -> Result<(Vec<Array1<f32>>, usize)>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
let image = image::open(image_path)
.change_context(Error)
.attach_printable(image_path.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = retinaface
.detect_faces(array.view(), config)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
tracing::info!(
"Detected {} faces in {}",
output.bbox.len(),
image_path.display()
);
if output.bbox.is_empty() {
return Ok((Vec::new(), 0));
}
let face_rois = array
.view()
.multi_roi(&output.bbox)
.change_context(Error)?
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = batch_size;
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
if chunk.len() < chunk_size {
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
}
})
.collect::<Result<Vec<Array2<f32>>>>()?;
// Flatten embeddings into individual face embeddings
let mut face_embeddings = Vec::new();
for embedding_batch in embeddings {
for i in 0..output.bbox.len().min(embedding_batch.nrows()) {
face_embeddings.push(embedding_batch.row(i).to_owned());
}
}
Ok((face_embeddings, output.bbox.len()))
}
// Helper function to compute cosine similarity between two embeddings
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.dot(a).sqrt();
let norm_b = b.dot(b).sqrt();
dot_product / (norm_a * norm_b)
}
let config = FaceDetectionConfig::default()
.with_threshold(compare.threshold)
.with_nms_threshold(compare.nms_threshold);
// Process both images
let (embeddings1, face_count1) = process_image(
&compare.image1,
&mut retinaface,
&mut facenet,
&config,
compare.batch_size,
)?;
let (embeddings2, face_count2) = process_image(
&compare.image2,
&mut retinaface,
&mut facenet,
&config,
compare.batch_size,
)?;
println!(
"Image 1 ({}): {} faces detected",
compare.image1.display(),
face_count1
);
println!(
"Image 2 ({}): {} faces detected",
compare.image2.display(),
face_count2
);
if embeddings1.is_empty() && embeddings2.is_empty() {
println!("No faces detected in either image");
return Ok(());
}
if embeddings1.is_empty() {
println!("No faces detected in image 1");
return Ok(());
}
if embeddings2.is_empty() {
println!("No faces detected in image 2");
return Ok(());
}
// Compare all faces between the two images
println!("\nFace comparison results:");
println!("========================");
let mut max_similarity = f32::NEG_INFINITY;
let mut best_match = (0, 0);
for (i, emb1) in embeddings1.iter().enumerate() {
for (j, emb2) in embeddings2.iter().enumerate() {
let similarity = cosine_similarity(emb1, emb2);
println!(
"Face {} (image 1) vs Face {} (image 2): {:.4}",
i + 1,
j + 1,
similarity
);
if similarity > max_similarity {
max_similarity = similarity;
best_match = (i + 1, j + 1);
}
}
}
println!(
"\nBest match: Face {} (image 1) vs Face {} (image 2) with similarity: {:.4}",
best_match.0, best_match.1, max_similarity
);
// Interpretation of similarity score
if max_similarity > 0.8 {
println!("Interpretation: Very likely the same person");
} else if max_similarity > 0.6 {
println!("Interpretation: Possibly the same person");
} else if max_similarity > 0.4 {
println!("Interpretation: Unlikely to be the same person");
} else {
println!("Interpretation: Very unlikely to be the same person");
}
Ok(())
}
fn run_multi_detection<D, E>(
detect_multi: cli::DetectMulti,
mut retinaface: D,
mut facenet: E,
) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
use std::fs;
// Initialize database - always save to database for multi-detection
let db = FaceDatabase::new(&detect_multi.database).change_context(Error)?;
// Parse supported extensions
let extensions: std::collections::HashSet<String> = detect_multi
.extensions
.split(',')
.map(|ext| ext.trim().to_lowercase())
.collect();
// Create output directory if specified
if let Some(ref output_dir) = detect_multi.output_dir {
fs::create_dir_all(output_dir)
.change_context(Error)
.attach_printable("Failed to create output directory")?;
}
// Read directory and filter image files
let entries = fs::read_dir(&detect_multi.input_dir)
.change_context(Error)
.attach_printable("Failed to read input directory")?;
let mut image_paths = Vec::new();
for entry in entries {
let entry = entry.change_context(Error)?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
if extensions.contains(&ext.to_lowercase()) {
image_paths.push(path);
}
}
}
}
if image_paths.is_empty() {
tracing::warn!(
"No image files found in directory: {:?}",
detect_multi.input_dir
);
return Ok(());
}
tracing::info!("Found {} image files to process", image_paths.len());
let mut total_faces = 0;
let mut processed_images = 0;
// Process each image
for (idx, image_path) in image_paths.iter().enumerate() {
tracing::info!(
"Processing image {}/{}: {:?}",
idx + 1,
image_paths.len(),
image_path
);
// Load and process image
let image = match image::open(image_path) {
Ok(img) => img.into_rgb8(),
Err(e) => {
tracing::error!("Failed to load image {:?}: {}", image_path, e);
continue;
}
};
let (image_width, image_height) = image.dimensions();
let mut array = match image.into_ndarray().change_context(errors::Error) {
Ok(arr) => arr,
Err(e) => {
tracing::error!("Failed to convert image to ndarray: {:?}", e);
continue;
}
};
let config = FaceDetectionConfig::default()
.with_threshold(detect_multi.threshold)
.with_nms_threshold(detect_multi.nms_threshold);
// Detect faces
let output = match retinaface.detect_faces(array.view(), &config) {
Ok(output) => output,
Err(e) => {
tracing::error!("Failed to detect faces in {:?}: {:?}", image_path, e);
continue;
}
};
let num_faces = output.bbox.len();
total_faces += num_faces;
if num_faces == 0 {
tracing::info!("No faces detected in {:?}", image_path);
} else {
tracing::info!("Detected {} faces in {:?}", num_faces, image_path);
}
// Store image and detections in database
let image_path_str = image_path.to_string_lossy();
let img_id = match db.store_image(&image_path_str, image_width, image_height) {
Ok(id) => id,
Err(e) => {
tracing::error!("Failed to store image in database: {:?}", e);
continue;
}
};
let face_ids = match db.store_face_detections(img_id, &output) {
Ok(ids) => ids,
Err(e) => {
tracing::error!("Failed to store face detections in database: {:?}", e);
continue;
}
};
// Draw bounding boxes if output directory is specified
if detect_multi.output_dir.is_some() {
for bbox in &output.bbox {
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
}
// Process face embeddings if faces were detected
if !face_ids.is_empty() {
let face_rois = match array.view().multi_roi(&output.bbox).change_context(Error) {
Ok(rois) => rois,
Err(e) => {
tracing::error!("Failed to extract face ROIs: {:?}", e);
continue;
}
};
let face_rois: Result<Vec<_>> = face_rois
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
.collect();
let face_rois = match face_rois {
Ok(rois) => rois,
Err(e) => {
tracing::error!("Failed to resize face ROIs: {:?}", e);
continue;
}
};
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = detect_multi.batch_size;
let embeddings: Result<Vec<Array2<f32>>> = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
if chunk.len() < chunk_size {
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
}
})
.collect();
let embeddings = match embeddings {
Ok(emb) => emb,
Err(e) => {
tracing::error!("Failed to generate embeddings: {:?}", e);
continue;
}
};
// Store embeddings in database
if let Err(e) = db.store_embeddings(&face_ids, &embeddings, &detect_multi.model_name) {
tracing::error!("Failed to store embeddings in database: {:?}", e);
continue;
}
}
// Save output image if directory specified
if let Some(ref output_dir) = detect_multi.output_dir {
let output_filename = format!(
"detected_{}",
image_path.file_name().unwrap().to_string_lossy()
);
let output_path = output_dir.join(output_filename);
let v = array.view();
let output_image: image::RgbImage = match v.to_image().change_context(errors::Error) {
Ok(img) => img,
Err(e) => {
tracing::error!("Failed to convert ndarray to image: {:?}", e);
continue;
}
};
if let Err(e) = output_image.save(&output_path) {
tracing::error!("Failed to save output image to {:?}: {}", output_path, e);
continue;
}
tracing::info!("Saved output image to {:?}", output_path);
}
processed_images += 1;
}
// Print final statistics
tracing::info!(
"Processing complete: {}/{} images processed successfully, {} total faces detected",
processed_images,
image_paths.len(),
total_faces
);
let (num_images, num_faces, num_landmarks, num_embeddings) =
db.get_stats().change_context(Error)?;
tracing::info!(
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
num_images,
num_faces,
num_landmarks,
num_embeddings
);
Ok(())
}
fn run_similar(similar: cli::Similar) -> Result<()> { fn run_similar(similar: cli::Similar) -> Result<()> {
let db = FaceDatabase::new(&similar.database).change_context(Error)?; let db = FaceDatabase::new(&similar.database).change_context(Error)?;
@@ -341,14 +890,19 @@ fn run_similar(similar: cli::Similar) -> Result<()> {
let similar_faces = db let similar_faces = db
.find_similar_faces(query_embedding, similar.threshold, similar.limit) .find_similar_faces(query_embedding, similar.threshold, similar.limit)
.change_context(Error)?; .change_context(Error)?;
// Get image information for the similar faces
println!( println!(
"Found {} similar faces (threshold: {:.3}):", "Found {} similar faces (threshold: {:.3}):",
similar_faces.len(), similar_faces.len(),
similar.threshold similar.threshold
); );
for (face_id, similarity) in similar_faces { for (face_id, similarity) in &similar_faces {
println!(" Face {}: similarity {:.3}", face_id, similarity); if let Some(image_info) = db.get_image_for_face(*face_id).change_context(Error)? {
println!(
" Face {}: similarity {:.3}, image: {}",
face_id, similarity, image_info.file_path
);
}
} }
Ok(()) Ok(())