feat(compare): add face comparison functionality with cosine similarity
This commit is contained in:
331
Cargo.lock
generated
331
Cargo.lock
generated
@@ -139,7 +139,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -148,6 +148,17 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
@@ -192,6 +203,29 @@ dependencies = [
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.60.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "062dddbc1ba4aca46de6338e2bf87771414c335f7b2f2036e8f3e9befebf88e6"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"clap 3.2.25",
|
||||
"env_logger",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
"peeking_take_while",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"shlex",
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.70.1"
|
||||
@@ -210,7 +244,7 @@ dependencies = [
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"shlex",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -281,7 +315,7 @@ checksum = "4f154e572231cb6ba2bd1176980827e3d5dc04cc183a75dea38109fbdd672d29"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -349,6 +383,21 @@ dependencies = [
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "3.2.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"bitflags 1.3.2",
|
||||
"clap_lex 0.2.4",
|
||||
"indexmap 1.9.3",
|
||||
"strsim 0.10.0",
|
||||
"termcolor",
|
||||
"textwrap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.45"
|
||||
@@ -367,8 +416,8 @@ checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
"strsim",
|
||||
"clap_lex 0.7.5",
|
||||
"strsim 0.11.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -377,7 +426,7 @@ version = "4.5.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"clap 4.5.45",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -389,7 +438,16 @@ dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
|
||||
dependencies = [
|
||||
"os_str_bytes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -505,7 +563,7 @@ name = "detector"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bounding-box",
|
||||
"clap",
|
||||
"clap 4.5.45",
|
||||
"clap_complete",
|
||||
"color",
|
||||
"error-stack",
|
||||
@@ -518,12 +576,13 @@ dependencies = [
|
||||
"nalgebra",
|
||||
"ndarray",
|
||||
"ndarray-image",
|
||||
"ndarray-math",
|
||||
"ndarray-math 0.1.0 (git+https://git.darksailor.dev/servius/ndarray-math)",
|
||||
"ndarray-resize",
|
||||
"ndarray-safetensors",
|
||||
"ordered-float",
|
||||
"ort",
|
||||
"rusqlite",
|
||||
"sqlite3-safetensor-cosine",
|
||||
"tap",
|
||||
"thiserror 2.0.15",
|
||||
"tokio",
|
||||
@@ -548,7 +607,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -572,6 +631,19 @@ version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"humantime",
|
||||
"log",
|
||||
"regex",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equator"
|
||||
version = "0.4.2"
|
||||
@@ -589,7 +661,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -598,6 +670,16 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "error-stack"
|
||||
version = "0.5.0"
|
||||
@@ -851,6 +933,12 @@ dependencies = [
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
@@ -866,7 +954,7 @@ version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
|
||||
dependencies = [
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -875,6 +963,30 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "home"
|
||||
version = "0.5.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f"
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.63"
|
||||
@@ -1045,6 +1157,16 @@ version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown 0.12.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.10.0"
|
||||
@@ -1052,7 +1174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1063,7 +1185,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1137,7 +1259,7 @@ checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1172,6 +1294,12 @@ version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "lazycell"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||
|
||||
[[package]]
|
||||
name = "lebe"
|
||||
version = "0.5.2"
|
||||
@@ -1214,6 +1342,12 @@ dependencies = [
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.8.0"
|
||||
@@ -1355,7 +1489,7 @@ version = "0.1.0"
|
||||
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bindgen",
|
||||
"bindgen 0.70.1",
|
||||
"cc",
|
||||
"cmake",
|
||||
"diffy",
|
||||
@@ -1431,6 +1565,15 @@ dependencies = [
|
||||
"ndarray",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray-math"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ndarray",
|
||||
"num",
|
||||
"thiserror 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray-math"
|
||||
version = "0.1.0"
|
||||
@@ -1552,7 +1695,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1652,6 +1795,12 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "os_str_bytes"
|
||||
version = "6.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
version = "0.1.1"
|
||||
@@ -1664,6 +1813,12 @@ version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "peeking_take_while"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.1"
|
||||
@@ -1741,7 +1896,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1769,7 +1924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2000,6 +2155,19 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
||||
dependencies = [
|
||||
"bitflags 2.9.2",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
@@ -2060,7 +2228,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2154,18 +2322,79 @@ dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlite-loadable"
|
||||
version = "0.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a916b7bb8738eef189dea88731b619b80bf3f62b3acf05138fa43fbf8621cc94"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlite-loadable-macros",
|
||||
"sqlite3ext-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlite-loadable-macros"
|
||||
version = "0.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f98bc75a8d6fd24f6a2cfea34f28758780fa17279d3051eec926efa381971e48"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlite3-safetensor-cosine"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ndarray",
|
||||
"ndarray-math 0.1.0",
|
||||
"ndarray-safetensors",
|
||||
"sqlite-loadable",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlite3ext-sys"
|
||||
version = "0.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3afdc2b3dc08f16d6eecf8aa07d19975a268603ab1cca67d3f9b4172c507cf16"
|
||||
dependencies = [
|
||||
"bindgen 0.60.1",
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.106"
|
||||
@@ -2185,7 +2414,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2213,6 +2442,21 @@ version = "0.12.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
@@ -2239,7 +2483,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2250,7 +2494,7 @@ checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2355,7 +2599,7 @@ version = "0.22.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"indexmap 2.10.0",
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
@@ -2381,7 +2625,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2540,7 +2784,7 @@ dependencies = [
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
@@ -2562,7 +2806,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
@@ -2582,6 +2826,18 @@ version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
|
||||
dependencies = [
|
||||
"either",
|
||||
"home",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wide"
|
||||
version = "0.7.33"
|
||||
@@ -2608,6 +2864,15 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0978bf7171b3d90bac376700cb56d606feb40f251a475a5d6634613564460b22"
|
||||
dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
@@ -2635,7 +2900,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2646,7 +2911,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2882,7 +3147,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
@@ -2903,7 +3168,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2923,7 +3188,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
@@ -2957,7 +3222,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[workspace]
|
||||
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors"]
|
||||
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors", "sqlite3-safetensor-cosine"]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
@@ -53,6 +53,7 @@ ordered-float = "5.0.0"
|
||||
ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]}
|
||||
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
|
||||
sqlite3-safetensor-cosine = { version = "0.1.0", path = "sqlite3-safetensor-cosine" }
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
@@ -67,4 +68,4 @@ ort-directml = ["ort/directml"]
|
||||
mnn-metal = ["mnn/metal"]
|
||||
mnn-coreml = ["mnn/coreml"]
|
||||
|
||||
default = []
|
||||
default = ["mnn-metal","mnn-coreml"]
|
||||
|
||||
29
README.md
29
README.md
@@ -55,6 +55,35 @@ cargo run --release detect --output detected.jpg path/to/image.jpg
|
||||
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
|
||||
```
|
||||
|
||||
### Face Comparison
|
||||
|
||||
Compare faces between two images by computing and comparing their embeddings:
|
||||
|
||||
```bash
|
||||
# Compare faces in two images
|
||||
cargo run --release compare image1.jpg image2.jpg
|
||||
|
||||
# Compare with custom thresholds
|
||||
cargo run --release compare --threshold 0.9 --nms-threshold 0.4 image1.jpg image2.jpg
|
||||
|
||||
# Use ONNX Runtime backend for comparison
|
||||
cargo run --release compare -p cpu image1.jpg image2.jpg
|
||||
|
||||
# Use MNN with Metal acceleration
|
||||
cargo run --release compare -f metal image1.jpg image2.jpg
|
||||
```
|
||||
|
||||
The compare command will:
|
||||
1. Detect all faces in both images
|
||||
2. Generate embeddings for each detected face
|
||||
3. Compute cosine similarity between all face pairs
|
||||
4. Display similarity scores and the best match
|
||||
5. Provide interpretation of the similarity scores:
|
||||
- **> 0.8**: Very likely the same person
|
||||
- **0.6-0.8**: Possibly the same person
|
||||
- **0.4-0.6**: Unlikely to be the same person
|
||||
- **< 0.4**: Very unlikely to be the same person
|
||||
|
||||
### Backend Selection
|
||||
|
||||
The project supports two inference backends:
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
description = "A simple rust flake using rust-overlay and craneLib";
|
||||
|
||||
inputs = {
|
||||
self = {
|
||||
lfs = true;
|
||||
};
|
||||
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
crane.url = "github:ipetkov/crane";
|
||||
@@ -206,6 +203,7 @@
|
||||
packages = with pkgs;
|
||||
[
|
||||
stableToolchainWithRustAnalyzer
|
||||
cargo-expand
|
||||
cargo-nextest
|
||||
cargo-deny
|
||||
cmake
|
||||
|
||||
@@ -68,6 +68,7 @@ use safetensors::tensor::SafeTensors;
|
||||
/// let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
||||
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct SafeArraysView<'a> {
|
||||
pub tensors: SafeTensors<'a>,
|
||||
}
|
||||
|
||||
14
sqlite3-safetensor-cosine/Cargo.toml
Normal file
14
sqlite3-safetensor-cosine/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "sqlite3-safetensor-cosine"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "staticlib"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.16.1"
|
||||
# ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||
ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
|
||||
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
|
||||
sqlite-loadable = "0.0.5"
|
||||
61
sqlite3-safetensor-cosine/src/lib.rs
Normal file
61
sqlite3-safetensor-cosine/src/lib.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use sqlite_loadable::prelude::*;
|
||||
use sqlite_loadable::{Error, ErrorKind};
|
||||
use sqlite_loadable::{Result, api, define_scalar_function};
|
||||
|
||||
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
|
||||
#[inline(always)]
|
||||
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
|
||||
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
|
||||
}
|
||||
|
||||
if values.len() != 2 {
|
||||
return Err(Error::new(ErrorKind::Message(
|
||||
"cosine_similarity requires exactly 2 arguments".to_string(),
|
||||
)));
|
||||
}
|
||||
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
|
||||
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
|
||||
let array_1_st =
|
||||
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
|
||||
let array_2_st =
|
||||
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
|
||||
|
||||
let array_view_1 = array_1_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(custom_error)?;
|
||||
let array_view_2 = array_2_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(custom_error)?;
|
||||
|
||||
use ndarray_math::*;
|
||||
let similarity = array_view_1
|
||||
.cosine_similarity(array_view_2)
|
||||
.map_err(custom_error)?;
|
||||
api::result_double(context, similarity as f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
|
||||
define_scalar_function(
|
||||
db,
|
||||
"cosine_similarity",
|
||||
2,
|
||||
cosine_similarity,
|
||||
FunctionFlags::DETERMINISTIC,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Should only be called by underlying SQLite C APIs,
|
||||
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn sqlite3_extension_init(
|
||||
db: *mut sqlite3,
|
||||
pz_err_msg: *mut *mut c_char,
|
||||
p_api: *mut sqlite3_api_routines,
|
||||
) -> c_uint {
|
||||
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
|
||||
}
|
||||
84
src/cli.rs
84
src/cli.rs
@@ -1,6 +1,5 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use mnn::ForwardType;
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct Cli {
|
||||
#[clap(subcommand)]
|
||||
@@ -11,14 +10,16 @@ pub struct Cli {
|
||||
pub enum SubCommand {
|
||||
#[clap(name = "detect")]
|
||||
Detect(Detect),
|
||||
#[clap(name = "list")]
|
||||
List(List),
|
||||
#[clap(name = "detect-multi")]
|
||||
DetectMulti(DetectMulti),
|
||||
#[clap(name = "query")]
|
||||
Query(Query),
|
||||
#[clap(name = "similar")]
|
||||
Similar(Similar),
|
||||
#[clap(name = "stats")]
|
||||
Stats(Stats),
|
||||
#[clap(name = "compare")]
|
||||
Compare(Compare),
|
||||
#[clap(name = "completions")]
|
||||
Completions { shell: clap_complete::Shell },
|
||||
}
|
||||
@@ -74,7 +75,47 @@ pub struct Detect {
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct List {}
|
||||
pub struct DetectMulti {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output_dir: Option<PathBuf>,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(
|
||||
long,
|
||||
help = "Image extensions to process (e.g., jpg,png,jpeg)",
|
||||
default_value = "jpg,jpeg,png,bmp,tiff,webp"
|
||||
)]
|
||||
pub extensions: String,
|
||||
#[clap(help = "Directory containing images to process")]
|
||||
pub input_dir: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Query {
|
||||
@@ -108,6 +149,41 @@ pub struct Stats {
|
||||
pub database: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Compare {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(help = "First image to compare")]
|
||||
pub image1: PathBuf,
|
||||
#[clap(help = "Second image to compare")]
|
||||
pub image2: PathBuf,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
|
||||
136
src/database.rs
136
src/database.rs
@@ -65,7 +65,14 @@ impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
|
||||
unsafe {
|
||||
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||
conn.load_extension(
|
||||
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||
None::<&str>,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
}
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
@@ -190,10 +197,9 @@ impl FaceDatabase {
|
||||
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![file_path, width, height])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
Ok(stmt
|
||||
.insert(params![file_path, width, height])
|
||||
.change_context(Error)?)
|
||||
}
|
||||
|
||||
/// Store face detection results
|
||||
@@ -231,17 +237,16 @@ impl FaceDatabase {
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![
|
||||
image_id,
|
||||
bbox.x1() as f32,
|
||||
bbox.y1() as f32,
|
||||
bbox.x2() as f32,
|
||||
bbox.y2() as f32,
|
||||
confidence
|
||||
])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
Ok(stmt
|
||||
.insert(params![
|
||||
image_id,
|
||||
bbox.x1() as f32,
|
||||
bbox.y1() as f32,
|
||||
bbox.x2() as f32,
|
||||
bbox.y2() as f32,
|
||||
confidence
|
||||
])
|
||||
.change_context(Error)?)
|
||||
}
|
||||
|
||||
/// Store face landmarks
|
||||
@@ -258,22 +263,21 @@ impl FaceDatabase {
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![
|
||||
face_id,
|
||||
landmarks.left_eye.x,
|
||||
landmarks.left_eye.y,
|
||||
landmarks.right_eye.x,
|
||||
landmarks.right_eye.y,
|
||||
landmarks.nose.x,
|
||||
landmarks.nose.y,
|
||||
landmarks.left_mouth.x,
|
||||
landmarks.left_mouth.y,
|
||||
landmarks.right_mouth.x,
|
||||
landmarks.right_mouth.y,
|
||||
])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
Ok(stmt
|
||||
.insert(params![
|
||||
face_id,
|
||||
landmarks.left_eye.x,
|
||||
landmarks.left_eye.y,
|
||||
landmarks.right_eye.x,
|
||||
landmarks.right_eye.y,
|
||||
landmarks.nose.x,
|
||||
landmarks.nose.y,
|
||||
landmarks.left_mouth.x,
|
||||
landmarks.left_mouth.y,
|
||||
landmarks.right_mouth.x,
|
||||
landmarks.right_mouth.y,
|
||||
])
|
||||
.change_context(Error)?)
|
||||
}
|
||||
|
||||
/// Store face embeddings
|
||||
@@ -310,12 +314,12 @@ impl FaceDatabase {
|
||||
embedding: ndarray::ArrayView1<f32>,
|
||||
model_name: &str,
|
||||
) -> Result<i64> {
|
||||
let embedding_bytes =
|
||||
let safe_arrays =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
|
||||
.change_context(Error)?
|
||||
.serialize()
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
|
||||
@@ -462,6 +466,35 @@ impl FaceDatabase {
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT images.id, images.file_path, images.width, images.height, images.created_at
|
||||
FROM images
|
||||
JOIN faces ON faces.image_id = images.id
|
||||
WHERE faces.id = ?1
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_row(params![face_id], |row| {
|
||||
Ok(ImageRecord {
|
||||
id: row.get(0)?,
|
||||
file_path: row.get(1)?,
|
||||
width: row.get(2)?,
|
||||
height: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.optional()
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get database statistics
|
||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let images: usize = self
|
||||
@@ -528,6 +561,39 @@ impl FaceDatabase {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
|
||||
let embedding_bytes =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||
.change_context(Error)
|
||||
.unwrap()
|
||||
.serialize()
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT face_id,
|
||||
cosine_similarity(?1, embedding)
|
||||
FROM embeddings
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
let result_iter = stmt
|
||||
.query_map(params![embedding_bytes], |row| {
|
||||
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||
})
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
for result in result_iter {
|
||||
println!("{:?}", result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||
|
||||
@@ -310,7 +310,7 @@ pub trait FaceDetector {
|
||||
fn detect_faces(
|
||||
&mut self,
|
||||
image: ndarray::ArrayView3<u8>,
|
||||
config: FaceDetectionConfig,
|
||||
config: &FaceDetectionConfig,
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
let (height, width, _channels) = image.dim();
|
||||
let output = self
|
||||
|
||||
@@ -11,6 +11,23 @@ pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
|
||||
use crate::errors::*;
|
||||
use ndarray::{Array2, ArrayView4};
|
||||
|
||||
pub mod preprocessing {
|
||||
use ndarray::*;
|
||||
pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
|
||||
let mut owned = faces.as_standard_layout().mapv(|v| v as f32).to_owned();
|
||||
owned.axis_iter_mut(Axis(0)).for_each(|mut image| {
|
||||
let mean = image.mean().unwrap_or(0.0);
|
||||
let std = image.std(0.0);
|
||||
if std > 0.0 {
|
||||
image.mapv_inplace(|x| (x - mean) / std);
|
||||
} else {
|
||||
image.mapv_inplace(|x| (x - 127.5) / 128.0)
|
||||
}
|
||||
});
|
||||
owned
|
||||
}
|
||||
}
|
||||
|
||||
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||
pub trait FaceEmbedder {
|
||||
/// Generate embeddings for a batch of face images
|
||||
|
||||
@@ -4,6 +4,7 @@ pub mod ort;
|
||||
use crate::errors::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||
use ndarray_math::{CosineSimilarity, EuclideanDistance};
|
||||
|
||||
/// Configuration for face embedding processing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -32,9 +33,9 @@ impl FaceEmbeddingConfig {
|
||||
impl Default for FaceEmbeddingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input_width: 160,
|
||||
input_height: 160,
|
||||
normalize: true,
|
||||
input_width: 320,
|
||||
input_height: 320,
|
||||
normalize: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,15 +64,14 @@ impl FaceEmbedding {
|
||||
|
||||
/// Calculate cosine similarity with another embedding
|
||||
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
|
||||
let dot_product = self.vector.dot(&other.vector);
|
||||
let norm_self = self.vector.mapv(|x| x * x).sum().sqrt();
|
||||
let norm_other = other.vector.mapv(|x| x * x).sum().sqrt();
|
||||
dot_product / (norm_self * norm_other)
|
||||
self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// Calculate Euclidean distance with another embedding
|
||||
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
|
||||
(&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt()
|
||||
self.vector
|
||||
.euclidean_distance(other.vector.view())
|
||||
.unwrap_or(f32::INFINITY)
|
||||
}
|
||||
|
||||
/// Normalize the embedding vector to unit length
|
||||
|
||||
@@ -64,10 +64,7 @@ impl EmbeddingGenerator {
|
||||
}
|
||||
|
||||
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
let tensor = face
|
||||
// .permuted_axes((0, 3, 1, 2))
|
||||
.as_standard_layout()
|
||||
.mapv(|x| x as f32);
|
||||
let tensor = crate::faceembed::preprocessing::preprocess(face);
|
||||
let shape: [usize; 4] = tensor.dim().into();
|
||||
let shape = shape.map(|f| f as i32);
|
||||
let output = self
|
||||
|
||||
@@ -135,10 +135,12 @@ impl EmbeddingGenerator {
|
||||
|
||||
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// Convert input from u8 to f32 and normalize to [0, 1] range
|
||||
let input_tensor = faces
|
||||
.mapv(|x| x as f32 / 255.0)
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
let input_tensor = crate::faceembed::preprocessing::preprocess(faces);
|
||||
|
||||
// face_array = np.asarray(face_resized, 'float32')
|
||||
// mean, std = face_array.mean(), face_array.std()
|
||||
// face_normalized = (face_array - mean) / std
|
||||
// let input_tensor = faces.mean()
|
||||
|
||||
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||
|
||||
|
||||
576
src/main.rs
576
src/main.rs
@@ -75,8 +75,61 @@ pub fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
cli::SubCommand::List(list) => {
|
||||
println!("List: {:?}", list);
|
||||
cli::SubCommand::DetectMulti(detect_multi) => {
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
|
||||
let executor = detect_multi
|
||||
.mnn_forward_type
|
||||
.map(|f| cli::Executor::Mnn(f))
|
||||
.or_else(|| {
|
||||
if detect_multi.ort_execution_provider.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cli::Executor::Ort(
|
||||
detect_multi.ort_execution_provider.clone(),
|
||||
))
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_multi_detection(detect_multi, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Ort(ep) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(&ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_multi_detection(detect_multi, retinaface, facenet)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
cli::SubCommand::Query(query) => {
|
||||
run_query(query)?;
|
||||
@@ -87,6 +140,59 @@ pub fn main() -> Result<()> {
|
||||
cli::SubCommand::Stats(stats) => {
|
||||
run_stats(stats)?;
|
||||
}
|
||||
cli::SubCommand::Compare(compare) => {
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
let executor = compare
|
||||
.mnn_forward_type
|
||||
.map(|f| cli::Executor::Mnn(f))
|
||||
.or_else(|| {
|
||||
if compare.ort_execution_provider.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cli::Executor::Ort(compare.ort_execution_provider.clone()))
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_compare(compare, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Ort(ep) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(&ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_compare(compare, retinaface, facenet)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
cli::SubCommand::Completions { shell } => {
|
||||
cli::Cli::completions(shell);
|
||||
}
|
||||
@@ -122,7 +228,7 @@ where
|
||||
let output = retinaface
|
||||
.detect_faces(
|
||||
array.view(),
|
||||
FaceDetectionConfig::default()
|
||||
&FaceDetectionConfig::default()
|
||||
.with_threshold(detect.threshold)
|
||||
.with_nms_threshold(detect.nms_threshold),
|
||||
)
|
||||
@@ -163,7 +269,7 @@ where
|
||||
// })
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(160, 160, &ResizeOptions::default())
|
||||
.fast_resize(320, 320, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
// .inspect(|f| {
|
||||
@@ -182,11 +288,14 @@ where
|
||||
|
||||
if chunk.len() < chunk_size {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((160, 160, 3));
|
||||
let zero_array = core::iter::repeat(zeros.view())
|
||||
let zeros = Array3::zeros((320, 320, 3));
|
||||
let chunk: Vec<_> = chunk
|
||||
.iter()
|
||||
.map(|arr| arr.reborrow())
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(chunk_size)
|
||||
.collect::<Vec<_>>();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
@@ -328,6 +437,446 @@ fn run_query(query: cli::Query) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_compare<D, E>(compare: cli::Compare, mut retinaface: D, mut facenet: E) -> Result<()>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
// Helper function to detect faces and compute embeddings for an image
|
||||
fn process_image<D, E>(
|
||||
image_path: &std::path::Path,
|
||||
retinaface: &mut D,
|
||||
facenet: &mut E,
|
||||
config: &FaceDetectionConfig,
|
||||
batch_size: usize,
|
||||
) -> Result<(Vec<Array1<f32>>, usize)>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
let image = image::open(image_path)
|
||||
.change_context(Error)
|
||||
.attach_printable(image_path.to_string_lossy().to_string())?;
|
||||
let image = image.into_rgb8();
|
||||
let array = image
|
||||
.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
|
||||
let output = retinaface
|
||||
.detect_faces(array.view(), config)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
|
||||
tracing::info!(
|
||||
"Detected {} faces in {}",
|
||||
output.bbox.len(),
|
||||
image_path.display()
|
||||
);
|
||||
|
||||
if output.bbox.is_empty() {
|
||||
return Ok((Vec::new(), 0));
|
||||
}
|
||||
|
||||
let face_rois = array
|
||||
.view()
|
||||
.multi_roi(&output.bbox)
|
||||
.change_context(Error)?
|
||||
.into_iter()
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(320, 320, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
let chunk_size = batch_size;
|
||||
let embeddings = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
if chunk.len() < chunk_size {
|
||||
let zeros = Array3::zeros((320, 320, 3));
|
||||
let chunk: Vec<_> = chunk
|
||||
.iter()
|
||||
.map(|arr| arr.reborrow())
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(chunk_size)
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>()?;
|
||||
|
||||
// Flatten embeddings into individual face embeddings
|
||||
let mut face_embeddings = Vec::new();
|
||||
for embedding_batch in embeddings {
|
||||
for i in 0..output.bbox.len().min(embedding_batch.nrows()) {
|
||||
face_embeddings.push(embedding_batch.row(i).to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((face_embeddings, output.bbox.len()))
|
||||
}
|
||||
|
||||
// Helper function to compute cosine similarity between two embeddings
|
||||
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
|
||||
let dot_product = a.dot(b);
|
||||
let norm_a = a.dot(a).sqrt();
|
||||
let norm_b = b.dot(b).sqrt();
|
||||
dot_product / (norm_a * norm_b)
|
||||
}
|
||||
|
||||
let config = FaceDetectionConfig::default()
|
||||
.with_threshold(compare.threshold)
|
||||
.with_nms_threshold(compare.nms_threshold);
|
||||
|
||||
// Process both images
|
||||
let (embeddings1, face_count1) = process_image(
|
||||
&compare.image1,
|
||||
&mut retinaface,
|
||||
&mut facenet,
|
||||
&config,
|
||||
compare.batch_size,
|
||||
)?;
|
||||
|
||||
let (embeddings2, face_count2) = process_image(
|
||||
&compare.image2,
|
||||
&mut retinaface,
|
||||
&mut facenet,
|
||||
&config,
|
||||
compare.batch_size,
|
||||
)?;
|
||||
|
||||
println!(
|
||||
"Image 1 ({}): {} faces detected",
|
||||
compare.image1.display(),
|
||||
face_count1
|
||||
);
|
||||
println!(
|
||||
"Image 2 ({}): {} faces detected",
|
||||
compare.image2.display(),
|
||||
face_count2
|
||||
);
|
||||
|
||||
if embeddings1.is_empty() && embeddings2.is_empty() {
|
||||
println!("No faces detected in either image");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if embeddings1.is_empty() {
|
||||
println!("No faces detected in image 1");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if embeddings2.is_empty() {
|
||||
println!("No faces detected in image 2");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Compare all faces between the two images
|
||||
println!("\nFace comparison results:");
|
||||
println!("========================");
|
||||
|
||||
let mut max_similarity = f32::NEG_INFINITY;
|
||||
let mut best_match = (0, 0);
|
||||
|
||||
for (i, emb1) in embeddings1.iter().enumerate() {
|
||||
for (j, emb2) in embeddings2.iter().enumerate() {
|
||||
let similarity = cosine_similarity(emb1, emb2);
|
||||
println!(
|
||||
"Face {} (image 1) vs Face {} (image 2): {:.4}",
|
||||
i + 1,
|
||||
j + 1,
|
||||
similarity
|
||||
);
|
||||
|
||||
if similarity > max_similarity {
|
||||
max_similarity = similarity;
|
||||
best_match = (i + 1, j + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"\nBest match: Face {} (image 1) vs Face {} (image 2) with similarity: {:.4}",
|
||||
best_match.0, best_match.1, max_similarity
|
||||
);
|
||||
|
||||
// Interpretation of similarity score
|
||||
if max_similarity > 0.8 {
|
||||
println!("Interpretation: Very likely the same person");
|
||||
} else if max_similarity > 0.6 {
|
||||
println!("Interpretation: Possibly the same person");
|
||||
} else if max_similarity > 0.4 {
|
||||
println!("Interpretation: Unlikely to be the same person");
|
||||
} else {
|
||||
println!("Interpretation: Very unlikely to be the same person");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_multi_detection<D, E>(
|
||||
detect_multi: cli::DetectMulti,
|
||||
mut retinaface: D,
|
||||
mut facenet: E,
|
||||
) -> Result<()>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
use std::fs;
|
||||
|
||||
// Initialize database - always save to database for multi-detection
|
||||
let db = FaceDatabase::new(&detect_multi.database).change_context(Error)?;
|
||||
|
||||
// Parse supported extensions
|
||||
let extensions: std::collections::HashSet<String> = detect_multi
|
||||
.extensions
|
||||
.split(',')
|
||||
.map(|ext| ext.trim().to_lowercase())
|
||||
.collect();
|
||||
|
||||
// Create output directory if specified
|
||||
if let Some(ref output_dir) = detect_multi.output_dir {
|
||||
fs::create_dir_all(output_dir)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create output directory")?;
|
||||
}
|
||||
|
||||
// Read directory and filter image files
|
||||
let entries = fs::read_dir(&detect_multi.input_dir)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read input directory")?;
|
||||
|
||||
let mut image_paths = Vec::new();
|
||||
for entry in entries {
|
||||
let entry = entry.change_context(Error)?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.is_file() {
|
||||
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
|
||||
if extensions.contains(&ext.to_lowercase()) {
|
||||
image_paths.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if image_paths.is_empty() {
|
||||
tracing::warn!(
|
||||
"No image files found in directory: {:?}",
|
||||
detect_multi.input_dir
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::info!("Found {} image files to process", image_paths.len());
|
||||
|
||||
let mut total_faces = 0;
|
||||
let mut processed_images = 0;
|
||||
|
||||
// Process each image
|
||||
for (idx, image_path) in image_paths.iter().enumerate() {
|
||||
tracing::info!(
|
||||
"Processing image {}/{}: {:?}",
|
||||
idx + 1,
|
||||
image_paths.len(),
|
||||
image_path
|
||||
);
|
||||
|
||||
// Load and process image
|
||||
let image = match image::open(image_path) {
|
||||
Ok(img) => img.into_rgb8(),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to load image {:?}: {}", image_path, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (image_width, image_height) = image.dimensions();
|
||||
let mut array = match image.into_ndarray().change_context(errors::Error) {
|
||||
Ok(arr) => arr,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to convert image to ndarray: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let config = FaceDetectionConfig::default()
|
||||
.with_threshold(detect_multi.threshold)
|
||||
.with_nms_threshold(detect_multi.nms_threshold);
|
||||
// Detect faces
|
||||
let output = match retinaface.detect_faces(array.view(), &config) {
|
||||
Ok(output) => output,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to detect faces in {:?}: {:?}", image_path, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let num_faces = output.bbox.len();
|
||||
total_faces += num_faces;
|
||||
|
||||
if num_faces == 0 {
|
||||
tracing::info!("No faces detected in {:?}", image_path);
|
||||
} else {
|
||||
tracing::info!("Detected {} faces in {:?}", num_faces, image_path);
|
||||
}
|
||||
|
||||
// Store image and detections in database
|
||||
let image_path_str = image_path.to_string_lossy();
|
||||
let img_id = match db.store_image(&image_path_str, image_width, image_height) {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to store image in database: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let face_ids = match db.store_face_detections(img_id, &output) {
|
||||
Ok(ids) => ids,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to store face detections in database: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Draw bounding boxes if output directory is specified
|
||||
if detect_multi.output_dir.is_some() {
|
||||
for bbox in &output.bbox {
|
||||
use bounding_box::draw::*;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Process face embeddings if faces were detected
|
||||
if !face_ids.is_empty() {
|
||||
let face_rois = match array.view().multi_roi(&output.bbox).change_context(Error) {
|
||||
Ok(rois) => rois,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to extract face ROIs: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let face_rois: Result<Vec<_>> = face_rois
|
||||
.into_iter()
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(320, 320, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let face_rois = match face_rois {
|
||||
Ok(rois) => rois,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to resize face ROIs: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
|
||||
let chunk_size = detect_multi.batch_size;
|
||||
let embeddings: Result<Vec<Array2<f32>>> = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
if chunk.len() < chunk_size {
|
||||
let zeros = Array3::zeros((320, 320, 3));
|
||||
let chunk: Vec<_> = chunk
|
||||
.iter()
|
||||
.map(|arr| arr.reborrow())
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(chunk_size)
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let embeddings = match embeddings {
|
||||
Ok(emb) => emb,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to generate embeddings: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Store embeddings in database
|
||||
if let Err(e) = db.store_embeddings(&face_ids, &embeddings, &detect_multi.model_name) {
|
||||
tracing::error!("Failed to store embeddings in database: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Save output image if directory specified
|
||||
if let Some(ref output_dir) = detect_multi.output_dir {
|
||||
let output_filename = format!(
|
||||
"detected_{}",
|
||||
image_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let output_path = output_dir.join(output_filename);
|
||||
|
||||
let v = array.view();
|
||||
let output_image: image::RgbImage = match v.to_image().change_context(errors::Error) {
|
||||
Ok(img) => img,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to convert ndarray to image: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = output_image.save(&output_path) {
|
||||
tracing::error!("Failed to save output image to {:?}: {}", output_path, e);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!("Saved output image to {:?}", output_path);
|
||||
}
|
||||
|
||||
processed_images += 1;
|
||||
}
|
||||
|
||||
// Print final statistics
|
||||
tracing::info!(
|
||||
"Processing complete: {}/{} images processed successfully, {} total faces detected",
|
||||
processed_images,
|
||||
image_paths.len(),
|
||||
total_faces
|
||||
);
|
||||
|
||||
let (num_images, num_faces, num_landmarks, num_embeddings) =
|
||||
db.get_stats().change_context(Error)?;
|
||||
tracing::info!(
|
||||
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
|
||||
num_images,
|
||||
num_faces,
|
||||
num_landmarks,
|
||||
num_embeddings
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_similar(similar: cli::Similar) -> Result<()> {
|
||||
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
|
||||
|
||||
@@ -341,14 +890,19 @@ fn run_similar(similar: cli::Similar) -> Result<()> {
|
||||
let similar_faces = db
|
||||
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Get image information for the similar faces
|
||||
println!(
|
||||
"Found {} similar faces (threshold: {:.3}):",
|
||||
similar_faces.len(),
|
||||
similar.threshold
|
||||
);
|
||||
for (face_id, similarity) in similar_faces {
|
||||
println!(" Face {}: similarity {:.3}", face_id, similarity);
|
||||
for (face_id, similarity) in &similar_faces {
|
||||
if let Some(image_info) = db.get_image_for_face(*face_id).change_context(Error)? {
|
||||
println!(
|
||||
" Face {}: similarity {:.3}, image: {}",
|
||||
face_id, similarity, image_info.file_path
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user