From bfa389b497e9e1e646581289a9976949b079a44b Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Thu, 21 Aug 2025 17:34:07 +0530 Subject: [PATCH] feat(compare): add face comparison functionality with cosine similarity --- Cargo.lock | 331 +++++++++++++-- Cargo.toml | 5 +- README.md | 33 +- flake.nix | 4 +- ndarray-safetensors/src/lib.rs | 1 + sqlite3-safetensor-cosine/Cargo.toml | 14 + sqlite3-safetensor-cosine/src/lib.rs | 61 +++ src/cli.rs | 84 +++- src/database.rs | 136 +++++-- src/facedet/retinaface.rs | 2 +- src/faceembed.rs | 17 + src/faceembed/facenet.rs | 16 +- src/faceembed/facenet/mnn.rs | 5 +- src/faceembed/facenet/ort.rs | 10 +- src/main.rs | 576 ++++++++++++++++++++++++++- 15 files changed, 1188 insertions(+), 107 deletions(-) create mode 100644 sqlite3-safetensor-cosine/Cargo.toml create mode 100644 sqlite3-safetensor-cosine/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 540a5e9..2889848 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,7 +139,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -148,6 +148,17 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -192,6 +203,29 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "bindgen" +version = "0.60.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "062dddbc1ba4aca46de6338e2bf87771414c335f7b2f2036e8f3e9befebf88e6" +dependencies = [ + "bitflags 1.3.2", + "cexpr", + "clang-sys", + "clap 3.2.25", + "env_logger", + "lazy_static", + "lazycell", + "log", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "which", +] + [[package]] name = "bindgen" version = "0.70.1" @@ -210,7 +244,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn", + "syn 2.0.106", ] [[package]] @@ -281,7 +315,7 @@ checksum = "4f154e572231cb6ba2bd1176980827e3d5dc04cc183a75dea38109fbdd672d29" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -349,6 +383,21 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "3.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" +dependencies = [ + "atty", + "bitflags 1.3.2", + "clap_lex 0.2.4", + "indexmap 1.9.3", + "strsim 0.10.0", + "termcolor", + "textwrap", +] + [[package]] name = "clap" version = "4.5.45" @@ -367,8 +416,8 @@ checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8" dependencies = [ "anstream", "anstyle", - "clap_lex", - "strsim", + "clap_lex 0.7.5", + "strsim 0.11.1", ] [[package]] @@ -377,7 +426,7 @@ version = "4.5.57" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad" dependencies = [ - "clap", + "clap 4.5.45", ] [[package]] @@ -389,7 +438,16 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "clap_lex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" +dependencies = [ + "os_str_bytes", ] [[package]] @@ -505,7 +563,7 @@ name = "detector" version = "0.1.0" dependencies = [ "bounding-box", - "clap", + "clap 4.5.45", "clap_complete", "color", "error-stack", @@ -518,12 +576,13 @@ dependencies = [ "nalgebra", "ndarray", "ndarray-image", - "ndarray-math", + "ndarray-math 0.1.0 (git+https://git.darksailor.dev/servius/ndarray-math)", "ndarray-resize", "ndarray-safetensors", "ordered-float", "ort", "rusqlite", + "sqlite3-safetensor-cosine", "tap", "thiserror 2.0.15", "tokio", @@ -548,7 +607,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -572,6 +631,19 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + [[package]] name = "equator" version = "0.4.2" @@ -589,7 +661,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -598,6 +670,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + [[package]] name = "error-stack" version = "0.5.0" @@ -851,6 +933,12 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.15.5" @@ -866,7 +954,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown", + "hashbrown 0.15.5", ] [[package]] @@ -875,6 +963,30 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "humantime" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" + [[package]] name = "iana-time-zone" version = "0.1.63" @@ -1045,6 +1157,16 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + [[package]] name = "indexmap" version = "2.10.0" @@ -1052,7 +1174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.5", ] [[package]] @@ -1063,7 +1185,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -1137,7 +1259,7 @@ checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -1172,6 +1294,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "lebe" version = "0.5.2" @@ -1214,6 +1342,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "litemap" version = "0.8.0" @@ -1355,7 +1489,7 @@ version = "0.1.0" source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5" dependencies = [ "anyhow", - "bindgen", + "bindgen 0.70.1", "cc", "cmake", "diffy", @@ -1431,6 +1565,15 @@ dependencies = [ "ndarray", ] +[[package]] +name = "ndarray-math" +version = "0.1.0" +dependencies = [ + "ndarray", + "num", + "thiserror 2.0.15", +] + [[package]] name = "ndarray-math" version = "0.1.0" @@ -1552,7 +1695,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -1652,6 +1795,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "os_str_bytes" +version = "6.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" + [[package]] name = "overload" version = "0.1.1" @@ -1664,6 +1813,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "percent-encoding" version = "2.3.1" @@ -1741,7 +1896,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.106", ] [[package]] @@ -1769,7 +1924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b" dependencies = [ "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2000,6 +2155,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.9.2", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -2060,7 +2228,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2154,18 +2322,79 @@ dependencies = [ "lock_api", ] +[[package]] +name = "sqlite-loadable" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a916b7bb8738eef189dea88731b619b80bf3f62b3acf05138fa43fbf8621cc94" +dependencies = [ + "bitflags 1.3.2", + "serde", + "serde_json", + "sqlite-loadable-macros", + "sqlite3ext-sys", +] + +[[package]] +name = "sqlite-loadable-macros" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98bc75a8d6fd24f6a2cfea34f28758780fa17279d3051eec926efa381971e48" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "sqlite3-safetensor-cosine" +version = "0.1.0" +dependencies = [ + "ndarray", + "ndarray-math 0.1.0", + "ndarray-safetensors", + "sqlite-loadable", +] + +[[package]] +name = "sqlite3ext-sys" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3afdc2b3dc08f16d6eecf8aa07d19975a268603ab1cca67d3f9b4172c507cf16" +dependencies = [ + "bindgen 0.60.1", + "cc", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.106" @@ -2185,7 +2414,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2213,6 +2442,21 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "textwrap" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057" + [[package]] name = "thiserror" version = "1.0.69" @@ -2239,7 +2483,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2250,7 +2494,7 @@ checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2355,7 +2599,7 @@ version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", @@ -2381,7 +2625,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2540,7 +2784,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 2.0.106", "wasm-bindgen-shared", ] @@ -2562,7 +2806,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2582,6 +2826,18 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "wide" version = "0.7.33" @@ -2608,6 +2864,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0978bf7171b3d90bac376700cb56d606feb40f251a475a5d6634613564460b22" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2635,7 +2900,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2646,7 +2911,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2882,7 +3147,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", "synstructure", ] @@ -2903,7 +3168,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2923,7 +3188,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", "synstructure", ] @@ -2957,7 +3222,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4e9ca64..add6c8f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors"] +members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors", "sqlite3-safetensor-cosine"] [workspace.package] version = "0.1.0" @@ -53,6 +53,7 @@ ordered-float = "5.0.0" ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]} ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" } ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" } +sqlite3-safetensor-cosine = { version = "0.1.0", path = "sqlite3-safetensor-cosine" } [profile.release] debug = true @@ -67,4 +68,4 @@ ort-directml = ["ort/directml"] mnn-metal = ["mnn/metal"] mnn-coreml = ["mnn/coreml"] -default = [] +default = ["mnn-metal","mnn-coreml"] diff --git a/README.md b/README.md index 0e8fc44..58376a5 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,35 @@ cargo run --release detect --output detected.jpg path/to/image.jpg cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg ``` +### Face Comparison + +Compare faces between two images by computing and comparing their embeddings: + +```bash +# Compare faces in two images +cargo run --release compare image1.jpg image2.jpg + +# Compare with custom thresholds +cargo run --release compare --threshold 0.9 --nms-threshold 0.4 image1.jpg image2.jpg + +# Use ONNX Runtime backend for comparison +cargo run --release compare -p cpu image1.jpg image2.jpg + +# Use MNN with Metal acceleration +cargo run --release compare -f metal image1.jpg image2.jpg +``` + +The compare command will: +1. Detect all faces in both images +2. Generate embeddings for each detected face +3. Compute cosine similarity between all face pairs +4. Display similarity scores and the best match +5. Provide interpretation of the similarity scores: + - **> 0.8**: Very likely the same person + - **0.6-0.8**: Possibly the same person + - **0.4-0.6**: Unlikely to be the same person + - **< 0.4**: Very unlikely to be the same person + ### Backend Selection The project supports two inference backends: @@ -106,7 +135,7 @@ The MNN backend supports various execution backends: - **CPU** - Default, works on all platforms - **Metal** - macOS GPU acceleration -- **CoreML** - macOS/iOS neural engine acceleration +- **CoreML** - macOS/iOS neural engine acceleration - **OpenCL** - Cross-platform GPU acceleration ```bash @@ -179,7 +208,7 @@ MIT License Key dependencies include: - **MNN** - High-performance neural network inference framework (MNN backend) -- **ONNX Runtime** - Cross-platform ML inference (ORT backend) +- **ONNX Runtime** - Cross-platform ML inference (ORT backend) - **ndarray** - N-dimensional array processing - **image** - Image processing and I/O - **clap** - Command line argument parsing diff --git a/flake.nix b/flake.nix index 1c2275f..f38df5b 100644 --- a/flake.nix +++ b/flake.nix @@ -2,9 +2,6 @@ description = "A simple rust flake using rust-overlay and craneLib"; inputs = { - self = { - lfs = true; - }; nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; flake-utils.url = "github:numtide/flake-utils"; crane.url = "github:ipetkov/crane"; @@ -206,6 +203,7 @@ packages = with pkgs; [ stableToolchainWithRustAnalyzer + cargo-expand cargo-nextest cargo-deny cmake diff --git a/ndarray-safetensors/src/lib.rs b/ndarray-safetensors/src/lib.rs index ab7e96d..3ec5a74 100644 --- a/ndarray-safetensors/src/lib.rs +++ b/ndarray-safetensors/src/lib.rs @@ -68,6 +68,7 @@ use safetensors::tensor::SafeTensors; /// let view = SafeArrayView::from_bytes(&bytes).unwrap(); /// let tensor: ndarray::ArrayView2 = view.tensor("data").unwrap(); /// ``` +#[derive(Debug)] pub struct SafeArraysView<'a> { pub tensors: SafeTensors<'a>, } diff --git a/sqlite3-safetensor-cosine/Cargo.toml b/sqlite3-safetensor-cosine/Cargo.toml new file mode 100644 index 0000000..0002b0a --- /dev/null +++ b/sqlite3-safetensor-cosine/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "sqlite3-safetensor-cosine" +version.workspace = true +edition.workspace = true + +[lib] +crate-type = ["cdylib", "staticlib"] + +[dependencies] +ndarray = "0.16.1" +# ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" } +ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" } +ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" } +sqlite-loadable = "0.0.5" diff --git a/sqlite3-safetensor-cosine/src/lib.rs b/sqlite3-safetensor-cosine/src/lib.rs new file mode 100644 index 0000000..99a317d --- /dev/null +++ b/sqlite3-safetensor-cosine/src/lib.rs @@ -0,0 +1,61 @@ +use sqlite_loadable::prelude::*; +use sqlite_loadable::{Error, ErrorKind}; +use sqlite_loadable::{Result, api, define_scalar_function}; + +fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + #[inline(always)] + fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error { + sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string())) + } + + if values.len() != 2 { + return Err(Error::new(ErrorKind::Message( + "cosine_similarity requires exactly 2 arguments".to_string(), + ))); + } + let array_1 = api::value_blob(values.get(0).expect("1st argument")); + let array_2 = api::value_blob(values.get(1).expect("2nd argument")); + let array_1_st = + ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?; + let array_2_st = + ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?; + + let array_view_1 = array_1_st + .tensor_by_index::(0) + .map_err(custom_error)?; + let array_view_2 = array_2_st + .tensor_by_index::(0) + .map_err(custom_error)?; + + use ndarray_math::*; + let similarity = array_view_1 + .cosine_similarity(array_view_2) + .map_err(custom_error)?; + api::result_double(context, similarity as f64); + + Ok(()) +} + +pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> { + define_scalar_function( + db, + "cosine_similarity", + 2, + cosine_similarity, + FunctionFlags::DETERMINISTIC, + )?; + Ok(()) +} + +/// # Safety +/// +/// Should only be called by underlying SQLite C APIs, +/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn sqlite3_extension_init( + db: *mut sqlite3, + pz_err_msg: *mut *mut c_char, + p_api: *mut sqlite3_api_routines, +) -> c_uint { + register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init) +} diff --git a/src/cli.rs b/src/cli.rs index 1fb20d4..fae29ac 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,6 +1,5 @@ use std::path::PathBuf; -use mnn::ForwardType; #[derive(Debug, clap::Parser)] pub struct Cli { #[clap(subcommand)] @@ -11,14 +10,16 @@ pub struct Cli { pub enum SubCommand { #[clap(name = "detect")] Detect(Detect), - #[clap(name = "list")] - List(List), + #[clap(name = "detect-multi")] + DetectMulti(DetectMulti), #[clap(name = "query")] Query(Query), #[clap(name = "similar")] Similar(Similar), #[clap(name = "stats")] Stats(Stats), + #[clap(name = "compare")] + Compare(Compare), #[clap(name = "completions")] Completions { shell: clap_complete::Shell }, } @@ -74,7 +75,47 @@ pub struct Detect { } #[derive(Debug, clap::Args)] -pub struct List {} +pub struct DetectMulti { + #[clap(short, long)] + pub model: Option, + #[clap(short = 'M', long, default_value = "retina-face")] + pub model_type: Models, + #[clap(short, long)] + pub output_dir: Option, + #[clap( + short = 'p', + long, + default_value = "cpu", + group = "execution_provider", + required_unless_present = "mnn_forward_type" + )] + pub ort_execution_provider: Vec, + #[clap( + short = 'f', + long, + group = "execution_provider", + required_unless_present = "ort_execution_provider" + )] + pub mnn_forward_type: Option, + #[clap(short, long, default_value_t = 0.8)] + pub threshold: f32, + #[clap(short, long, default_value_t = 0.3)] + pub nms_threshold: f32, + #[clap(short, long, default_value_t = 8)] + pub batch_size: usize, + #[clap(short = 'd', long, default_value = "face_detections.db")] + pub database: PathBuf, + #[clap(long, default_value = "facenet")] + pub model_name: String, + #[clap( + long, + help = "Image extensions to process (e.g., jpg,png,jpeg)", + default_value = "jpg,jpeg,png,bmp,tiff,webp" + )] + pub extensions: String, + #[clap(help = "Directory containing images to process")] + pub input_dir: PathBuf, +} #[derive(Debug, clap::Args)] pub struct Query { @@ -108,6 +149,41 @@ pub struct Stats { pub database: PathBuf, } +#[derive(Debug, clap::Args)] +pub struct Compare { + #[clap(short, long)] + pub model: Option, + #[clap(short = 'M', long, default_value = "retina-face")] + pub model_type: Models, + #[clap( + short = 'p', + long, + default_value = "cpu", + group = "execution_provider", + required_unless_present = "mnn_forward_type" + )] + pub ort_execution_provider: Vec, + #[clap( + short = 'f', + long, + group = "execution_provider", + required_unless_present = "ort_execution_provider" + )] + pub mnn_forward_type: Option, + #[clap(short, long, default_value_t = 0.8)] + pub threshold: f32, + #[clap(short, long, default_value_t = 0.3)] + pub nms_threshold: f32, + #[clap(short, long, default_value_t = 8)] + pub batch_size: usize, + #[clap(long, default_value = "facenet")] + pub model_name: String, + #[clap(help = "First image to compare")] + pub image1: PathBuf, + #[clap(help = "Second image to compare")] + pub image2: PathBuf, +} + impl Cli { pub fn completions(shell: clap_complete::Shell) { let mut command = ::command(); diff --git a/src/database.rs b/src/database.rs index 35211e7..048e54d 100644 --- a/src/database.rs +++ b/src/database.rs @@ -65,7 +65,14 @@ impl FaceDatabase { /// Create a new database connection and initialize tables pub fn new>(db_path: P) -> Result { let conn = Connection::open(db_path).change_context(Error)?; - add_sqlite_cosine_similarity(&conn).change_context(Error)?; + unsafe { + let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?; + conn.load_extension( + "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib", + None::<&str>, + ) + .change_context(Error)?; + } let db = Self { conn }; db.create_tables()?; Ok(db) @@ -190,10 +197,9 @@ impl FaceDatabase { .prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)") .change_context(Error)?; - stmt.execute(params![file_path, width, height]) - .change_context(Error)?; - - Ok(self.conn.last_insert_rowid()) + Ok(stmt + .insert(params![file_path, width, height]) + .change_context(Error)?) } /// Store face detection results @@ -231,17 +237,16 @@ impl FaceDatabase { ) .change_context(Error)?; - stmt.execute(params![ - image_id, - bbox.x1() as f32, - bbox.y1() as f32, - bbox.x2() as f32, - bbox.y2() as f32, - confidence - ]) - .change_context(Error)?; - - Ok(self.conn.last_insert_rowid()) + Ok(stmt + .insert(params![ + image_id, + bbox.x1() as f32, + bbox.y1() as f32, + bbox.x2() as f32, + bbox.y2() as f32, + confidence + ]) + .change_context(Error)?) } /// Store face landmarks @@ -258,22 +263,21 @@ impl FaceDatabase { ) .change_context(Error)?; - stmt.execute(params![ - face_id, - landmarks.left_eye.x, - landmarks.left_eye.y, - landmarks.right_eye.x, - landmarks.right_eye.y, - landmarks.nose.x, - landmarks.nose.y, - landmarks.left_mouth.x, - landmarks.left_mouth.y, - landmarks.right_mouth.x, - landmarks.right_mouth.y, - ]) - .change_context(Error)?; - - Ok(self.conn.last_insert_rowid()) + Ok(stmt + .insert(params![ + face_id, + landmarks.left_eye.x, + landmarks.left_eye.y, + landmarks.right_eye.x, + landmarks.right_eye.y, + landmarks.nose.x, + landmarks.nose.y, + landmarks.left_mouth.x, + landmarks.left_mouth.y, + landmarks.right_mouth.x, + landmarks.right_mouth.y, + ]) + .change_context(Error)?) } /// Store face embeddings @@ -310,12 +314,12 @@ impl FaceDatabase { embedding: ndarray::ArrayView1, model_name: &str, ) -> Result { - let embedding_bytes = + let safe_arrays = ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)]) - .change_context(Error)? - .serialize() .change_context(Error)?; + let embedding_bytes = safe_arrays.serialize().change_context(Error)?; + let mut stmt = self .conn .prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)") @@ -462,6 +466,35 @@ impl FaceDatabase { Ok(embeddings) } + pub fn get_image_for_face(&self, face_id: i64) -> Result> { + let mut stmt = self + .conn + .prepare( + r#" + SELECT images.id, images.file_path, images.width, images.height, images.created_at + FROM images + JOIN faces ON faces.image_id = images.id + WHERE faces.id = ?1 + "#, + ) + .change_context(Error)?; + + let result = stmt + .query_row(params![face_id], |row| { + Ok(ImageRecord { + id: row.get(0)?, + file_path: row.get(1)?, + width: row.get(2)?, + height: row.get(3)?, + created_at: row.get(4)?, + }) + }) + .optional() + .change_context(Error)?; + + Ok(result) + } + /// Get database statistics pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> { let images: usize = self @@ -528,6 +561,39 @@ impl FaceDatabase { Ok(result) } + + pub fn query_similarity(&self, embedding: &ndarray::Array1) { + let embedding_bytes = + ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())]) + .change_context(Error) + .unwrap() + .serialize() + .change_context(Error) + .unwrap(); + + let mut stmt = self + .conn + .prepare( + r#" + SELECT face_id, + cosine_similarity(?1, embedding) + FROM embeddings + "#, + ) + .change_context(Error) + .unwrap(); + + let result_iter = stmt + .query_map(params![embedding_bytes], |row| { + Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?)) + }) + .change_context(Error) + .unwrap(); + + for result in result_iter { + println!("{:?}", result); + } + } } fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> { diff --git a/src/facedet/retinaface.rs b/src/facedet/retinaface.rs index 874618e..3cab72f 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/retinaface.rs @@ -310,7 +310,7 @@ pub trait FaceDetector { fn detect_faces( &mut self, image: ndarray::ArrayView3, - config: FaceDetectionConfig, + config: &FaceDetectionConfig, ) -> Result { let (height, width, _channels) = image.dim(); let output = self diff --git a/src/faceembed.rs b/src/faceembed.rs index bdc5294..3c06a20 100644 --- a/src/faceembed.rs +++ b/src/faceembed.rs @@ -11,6 +11,23 @@ pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator; use crate::errors::*; use ndarray::{Array2, ArrayView4}; +pub mod preprocessing { + use ndarray::*; + pub fn preprocess(faces: ArrayView4) -> Array4 { + let mut owned = faces.as_standard_layout().mapv(|v| v as f32).to_owned(); + owned.axis_iter_mut(Axis(0)).for_each(|mut image| { + let mean = image.mean().unwrap_or(0.0); + let std = image.std(0.0); + if std > 0.0 { + image.mapv_inplace(|x| (x - mean) / std); + } else { + image.mapv_inplace(|x| (x - 127.5) / 128.0) + } + }); + owned + } +} + /// Common trait for face embedding backends - maintained for backward compatibility pub trait FaceEmbedder { /// Generate embeddings for a batch of face images diff --git a/src/faceembed/facenet.rs b/src/faceembed/facenet.rs index 6977d2e..04a6adb 100644 --- a/src/faceembed/facenet.rs +++ b/src/faceembed/facenet.rs @@ -4,6 +4,7 @@ pub mod ort; use crate::errors::*; use error_stack::ResultExt; use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; +use ndarray_math::{CosineSimilarity, EuclideanDistance}; /// Configuration for face embedding processing #[derive(Debug, Clone, PartialEq)] @@ -32,9 +33,9 @@ impl FaceEmbeddingConfig { impl Default for FaceEmbeddingConfig { fn default() -> Self { Self { - input_width: 160, - input_height: 160, - normalize: true, + input_width: 320, + input_height: 320, + normalize: false, } } } @@ -63,15 +64,14 @@ impl FaceEmbedding { /// Calculate cosine similarity with another embedding pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 { - let dot_product = self.vector.dot(&other.vector); - let norm_self = self.vector.mapv(|x| x * x).sum().sqrt(); - let norm_other = other.vector.mapv(|x| x * x).sum().sqrt(); - dot_product / (norm_self * norm_other) + self.vector.cosine_similarity(&other.vector).unwrap_or(0.0) } /// Calculate Euclidean distance with another embedding pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 { - (&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt() + self.vector + .euclidean_distance(other.vector.view()) + .unwrap_or(f32::INFINITY) } /// Normalize the embedding vector to unit length diff --git a/src/faceembed/facenet/mnn.rs b/src/faceembed/facenet/mnn.rs index c6819e5..4cff396 100644 --- a/src/faceembed/facenet/mnn.rs +++ b/src/faceembed/facenet/mnn.rs @@ -64,10 +64,7 @@ impl EmbeddingGenerator { } pub fn run_models(&self, face: ArrayView4) -> Result> { - let tensor = face - // .permuted_axes((0, 3, 1, 2)) - .as_standard_layout() - .mapv(|x| x as f32); + let tensor = crate::faceembed::preprocessing::preprocess(face); let shape: [usize; 4] = tensor.dim().into(); let shape = shape.map(|f| f as i32); let output = self diff --git a/src/faceembed/facenet/ort.rs b/src/faceembed/facenet/ort.rs index 1cd272f..b66256d 100644 --- a/src/faceembed/facenet/ort.rs +++ b/src/faceembed/facenet/ort.rs @@ -135,10 +135,12 @@ impl EmbeddingGenerator { pub fn run_models(&mut self, faces: ArrayView4) -> crate::errors::Result> { // Convert input from u8 to f32 and normalize to [0, 1] range - let input_tensor = faces - .mapv(|x| x as f32 / 255.0) - .as_standard_layout() - .into_owned(); + let input_tensor = crate::faceembed::preprocessing::preprocess(faces); + + // face_array = np.asarray(face_resized, 'float32') + // mean, std = face_array.mean(), face_array.std() + // face_normalized = (face_array - mean) / std + // let input_tensor = faces.mean() tracing::trace!("Input tensor shape: {:?}", input_tensor.shape()); diff --git a/src/main.rs b/src/main.rs index 387115c..d2e705a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -75,8 +75,61 @@ pub fn main() -> Result<()> { } } } - cli::SubCommand::List(list) => { - println!("List: {:?}", list); + cli::SubCommand::DetectMulti(detect_multi) => { + // Choose backend based on executor type (defaulting to MNN for backward compatibility) + + let executor = detect_multi + .mnn_forward_type + .map(|f| cli::Executor::Mnn(f)) + .or_else(|| { + if detect_multi.ort_execution_provider.is_empty() { + None + } else { + Some(cli::Executor::Ort( + detect_multi.ort_execution_provider.clone(), + )) + } + }) + .unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU)); + + match executor { + cli::Executor::Mnn(forward) => { + let retinaface = + facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) + .change_context(Error)? + .with_forward_type(forward) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = + faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) + .change_context(Error)? + .with_forward_type(forward) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; + + run_multi_detection(detect_multi, retinaface, facenet)?; + } + cli::Executor::Ort(ep) => { + let retinaface = + facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX) + .change_context(Error)? + .with_execution_providers(&ep) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = + faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX) + .change_context(Error)? + .with_execution_providers(ep) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; + + run_multi_detection(detect_multi, retinaface, facenet)?; + } + } } cli::SubCommand::Query(query) => { run_query(query)?; @@ -87,6 +140,59 @@ pub fn main() -> Result<()> { cli::SubCommand::Stats(stats) => { run_stats(stats)?; } + cli::SubCommand::Compare(compare) => { + // Choose backend based on executor type (defaulting to MNN for backward compatibility) + let executor = compare + .mnn_forward_type + .map(|f| cli::Executor::Mnn(f)) + .or_else(|| { + if compare.ort_execution_provider.is_empty() { + None + } else { + Some(cli::Executor::Ort(compare.ort_execution_provider.clone())) + } + }) + .unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU)); + + match executor { + cli::Executor::Mnn(forward) => { + let retinaface = + facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) + .change_context(Error)? + .with_forward_type(forward) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = + faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) + .change_context(Error)? + .with_forward_type(forward) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; + + run_compare(compare, retinaface, facenet)?; + } + cli::Executor::Ort(ep) => { + let retinaface = + facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX) + .change_context(Error)? + .with_execution_providers(&ep) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = + faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX) + .change_context(Error)? + .with_execution_providers(ep) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; + + run_compare(compare, retinaface, facenet)?; + } + } + } cli::SubCommand::Completions { shell } => { cli::Cli::completions(shell); } @@ -122,7 +228,7 @@ where let output = retinaface .detect_faces( array.view(), - FaceDetectionConfig::default() + &FaceDetectionConfig::default() .with_threshold(detect.threshold) .with_nms_threshold(detect.nms_threshold), ) @@ -163,7 +269,7 @@ where // }) .map(|roi| { roi.as_standard_layout() - .fast_resize(160, 160, &ResizeOptions::default()) + .fast_resize(320, 320, &ResizeOptions::default()) .change_context(Error) }) // .inspect(|f| { @@ -182,11 +288,14 @@ where if chunk.len() < chunk_size { tracing::warn!("Chunk size is less than 8, padding with zeros"); - let zeros = Array3::zeros((160, 160, 3)); - let zero_array = core::iter::repeat(zeros.view()) + let zeros = Array3::zeros((320, 320, 3)); + let chunk: Vec<_> = chunk + .iter() + .map(|arr| arr.reborrow()) + .chain(core::iter::repeat(zeros.view())) .take(chunk_size) - .collect::>(); - let face_rois: Array4 = ndarray::stack(Axis(0), zero_array.as_slice()) + .collect(); + let face_rois: Array4 = ndarray::stack(Axis(0), chunk.as_slice()) .change_context(errors::Error) .attach_printable("Failed to stack rois together")?; let output = facenet.run_models(face_rois.view()).change_context(Error)?; @@ -328,6 +437,446 @@ fn run_query(query: cli::Query) -> Result<()> { Ok(()) } +fn run_compare(compare: cli::Compare, mut retinaface: D, mut facenet: E) -> Result<()> +where + D: facedet::FaceDetector, + E: faceembed::FaceEmbedder, +{ + // Helper function to detect faces and compute embeddings for an image + fn process_image( + image_path: &std::path::Path, + retinaface: &mut D, + facenet: &mut E, + config: &FaceDetectionConfig, + batch_size: usize, + ) -> Result<(Vec>, usize)> + where + D: facedet::FaceDetector, + E: faceembed::FaceEmbedder, + { + let image = image::open(image_path) + .change_context(Error) + .attach_printable(image_path.to_string_lossy().to_string())?; + let image = image.into_rgb8(); + let array = image + .into_ndarray() + .change_context(errors::Error) + .attach_printable("Failed to convert image to ndarray")?; + + let output = retinaface + .detect_faces(array.view(), config) + .change_context(errors::Error) + .attach_printable("Failed to detect faces")?; + + tracing::info!( + "Detected {} faces in {}", + output.bbox.len(), + image_path.display() + ); + + if output.bbox.is_empty() { + return Ok((Vec::new(), 0)); + } + + let face_rois = array + .view() + .multi_roi(&output.bbox) + .change_context(Error)? + .into_iter() + .map(|roi| { + roi.as_standard_layout() + .fast_resize(320, 320, &ResizeOptions::default()) + .change_context(Error) + }) + .collect::>>()?; + + let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::>(); + let chunk_size = batch_size; + let embeddings = face_roi_views + .chunks(chunk_size) + .map(|chunk| { + if chunk.len() < chunk_size { + let zeros = Array3::zeros((320, 320, 3)); + let chunk: Vec<_> = chunk + .iter() + .map(|arr| arr.reborrow()) + .chain(core::iter::repeat(zeros.view())) + .take(chunk_size) + .collect(); + let face_rois: Array4 = ndarray::stack(Axis(0), chunk.as_slice()) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + facenet.run_models(face_rois.view()).change_context(Error) + } else { + let face_rois: Array4 = ndarray::stack(Axis(0), chunk) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + facenet.run_models(face_rois.view()).change_context(Error) + } + }) + .collect::>>>()?; + + // Flatten embeddings into individual face embeddings + let mut face_embeddings = Vec::new(); + for embedding_batch in embeddings { + for i in 0..output.bbox.len().min(embedding_batch.nrows()) { + face_embeddings.push(embedding_batch.row(i).to_owned()); + } + } + + Ok((face_embeddings, output.bbox.len())) + } + + // Helper function to compute cosine similarity between two embeddings + fn cosine_similarity(a: &Array1, b: &Array1) -> f32 { + let dot_product = a.dot(b); + let norm_a = a.dot(a).sqrt(); + let norm_b = b.dot(b).sqrt(); + dot_product / (norm_a * norm_b) + } + + let config = FaceDetectionConfig::default() + .with_threshold(compare.threshold) + .with_nms_threshold(compare.nms_threshold); + + // Process both images + let (embeddings1, face_count1) = process_image( + &compare.image1, + &mut retinaface, + &mut facenet, + &config, + compare.batch_size, + )?; + + let (embeddings2, face_count2) = process_image( + &compare.image2, + &mut retinaface, + &mut facenet, + &config, + compare.batch_size, + )?; + + println!( + "Image 1 ({}): {} faces detected", + compare.image1.display(), + face_count1 + ); + println!( + "Image 2 ({}): {} faces detected", + compare.image2.display(), + face_count2 + ); + + if embeddings1.is_empty() && embeddings2.is_empty() { + println!("No faces detected in either image"); + return Ok(()); + } + + if embeddings1.is_empty() { + println!("No faces detected in image 1"); + return Ok(()); + } + + if embeddings2.is_empty() { + println!("No faces detected in image 2"); + return Ok(()); + } + + // Compare all faces between the two images + println!("\nFace comparison results:"); + println!("========================"); + + let mut max_similarity = f32::NEG_INFINITY; + let mut best_match = (0, 0); + + for (i, emb1) in embeddings1.iter().enumerate() { + for (j, emb2) in embeddings2.iter().enumerate() { + let similarity = cosine_similarity(emb1, emb2); + println!( + "Face {} (image 1) vs Face {} (image 2): {:.4}", + i + 1, + j + 1, + similarity + ); + + if similarity > max_similarity { + max_similarity = similarity; + best_match = (i + 1, j + 1); + } + } + } + + println!( + "\nBest match: Face {} (image 1) vs Face {} (image 2) with similarity: {:.4}", + best_match.0, best_match.1, max_similarity + ); + + // Interpretation of similarity score + if max_similarity > 0.8 { + println!("Interpretation: Very likely the same person"); + } else if max_similarity > 0.6 { + println!("Interpretation: Possibly the same person"); + } else if max_similarity > 0.4 { + println!("Interpretation: Unlikely to be the same person"); + } else { + println!("Interpretation: Very unlikely to be the same person"); + } + + Ok(()) +} + +fn run_multi_detection( + detect_multi: cli::DetectMulti, + mut retinaface: D, + mut facenet: E, +) -> Result<()> +where + D: facedet::FaceDetector, + E: faceembed::FaceEmbedder, +{ + use std::fs; + + // Initialize database - always save to database for multi-detection + let db = FaceDatabase::new(&detect_multi.database).change_context(Error)?; + + // Parse supported extensions + let extensions: std::collections::HashSet = detect_multi + .extensions + .split(',') + .map(|ext| ext.trim().to_lowercase()) + .collect(); + + // Create output directory if specified + if let Some(ref output_dir) = detect_multi.output_dir { + fs::create_dir_all(output_dir) + .change_context(Error) + .attach_printable("Failed to create output directory")?; + } + + // Read directory and filter image files + let entries = fs::read_dir(&detect_multi.input_dir) + .change_context(Error) + .attach_printable("Failed to read input directory")?; + + let mut image_paths = Vec::new(); + for entry in entries { + let entry = entry.change_context(Error)?; + let path = entry.path(); + + if path.is_file() { + if let Some(ext) = path.extension().and_then(|e| e.to_str()) { + if extensions.contains(&ext.to_lowercase()) { + image_paths.push(path); + } + } + } + } + + if image_paths.is_empty() { + tracing::warn!( + "No image files found in directory: {:?}", + detect_multi.input_dir + ); + return Ok(()); + } + + tracing::info!("Found {} image files to process", image_paths.len()); + + let mut total_faces = 0; + let mut processed_images = 0; + + // Process each image + for (idx, image_path) in image_paths.iter().enumerate() { + tracing::info!( + "Processing image {}/{}: {:?}", + idx + 1, + image_paths.len(), + image_path + ); + + // Load and process image + let image = match image::open(image_path) { + Ok(img) => img.into_rgb8(), + Err(e) => { + tracing::error!("Failed to load image {:?}: {}", image_path, e); + continue; + } + }; + + let (image_width, image_height) = image.dimensions(); + let mut array = match image.into_ndarray().change_context(errors::Error) { + Ok(arr) => arr, + Err(e) => { + tracing::error!("Failed to convert image to ndarray: {:?}", e); + continue; + } + }; + + let config = FaceDetectionConfig::default() + .with_threshold(detect_multi.threshold) + .with_nms_threshold(detect_multi.nms_threshold); + // Detect faces + let output = match retinaface.detect_faces(array.view(), &config) { + Ok(output) => output, + Err(e) => { + tracing::error!("Failed to detect faces in {:?}: {:?}", image_path, e); + continue; + } + }; + + let num_faces = output.bbox.len(); + total_faces += num_faces; + + if num_faces == 0 { + tracing::info!("No faces detected in {:?}", image_path); + } else { + tracing::info!("Detected {} faces in {:?}", num_faces, image_path); + } + + // Store image and detections in database + let image_path_str = image_path.to_string_lossy(); + let img_id = match db.store_image(&image_path_str, image_width, image_height) { + Ok(id) => id, + Err(e) => { + tracing::error!("Failed to store image in database: {:?}", e); + continue; + } + }; + + let face_ids = match db.store_face_detections(img_id, &output) { + Ok(ids) => ids, + Err(e) => { + tracing::error!("Failed to store face detections in database: {:?}", e); + continue; + } + }; + + // Draw bounding boxes if output directory is specified + if detect_multi.output_dir.is_some() { + for bbox in &output.bbox { + use bounding_box::draw::*; + array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1); + } + } + + // Process face embeddings if faces were detected + if !face_ids.is_empty() { + let face_rois = match array.view().multi_roi(&output.bbox).change_context(Error) { + Ok(rois) => rois, + Err(e) => { + tracing::error!("Failed to extract face ROIs: {:?}", e); + continue; + } + }; + + let face_rois: Result> = face_rois + .into_iter() + .map(|roi| { + roi.as_standard_layout() + .fast_resize(320, 320, &ResizeOptions::default()) + .change_context(Error) + }) + .collect(); + + let face_rois = match face_rois { + Ok(rois) => rois, + Err(e) => { + tracing::error!("Failed to resize face ROIs: {:?}", e); + continue; + } + }; + + let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::>(); + + let chunk_size = detect_multi.batch_size; + let embeddings: Result>> = face_roi_views + .chunks(chunk_size) + .map(|chunk| { + if chunk.len() < chunk_size { + let zeros = Array3::zeros((320, 320, 3)); + let chunk: Vec<_> = chunk + .iter() + .map(|arr| arr.reborrow()) + .chain(core::iter::repeat(zeros.view())) + .take(chunk_size) + .collect(); + let face_rois: Array4 = ndarray::stack(Axis(0), chunk.as_slice()) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + facenet.run_models(face_rois.view()).change_context(Error) + } else { + let face_rois: Array4 = ndarray::stack(Axis(0), chunk) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + facenet.run_models(face_rois.view()).change_context(Error) + } + }) + .collect(); + + let embeddings = match embeddings { + Ok(emb) => emb, + Err(e) => { + tracing::error!("Failed to generate embeddings: {:?}", e); + continue; + } + }; + + // Store embeddings in database + if let Err(e) = db.store_embeddings(&face_ids, &embeddings, &detect_multi.model_name) { + tracing::error!("Failed to store embeddings in database: {:?}", e); + continue; + } + } + + // Save output image if directory specified + if let Some(ref output_dir) = detect_multi.output_dir { + let output_filename = format!( + "detected_{}", + image_path.file_name().unwrap().to_string_lossy() + ); + let output_path = output_dir.join(output_filename); + + let v = array.view(); + let output_image: image::RgbImage = match v.to_image().change_context(errors::Error) { + Ok(img) => img, + Err(e) => { + tracing::error!("Failed to convert ndarray to image: {:?}", e); + continue; + } + }; + + if let Err(e) = output_image.save(&output_path) { + tracing::error!("Failed to save output image to {:?}: {}", output_path, e); + continue; + } + + tracing::info!("Saved output image to {:?}", output_path); + } + + processed_images += 1; + } + + // Print final statistics + tracing::info!( + "Processing complete: {}/{} images processed successfully, {} total faces detected", + processed_images, + image_paths.len(), + total_faces + ); + + let (num_images, num_faces, num_landmarks, num_embeddings) = + db.get_stats().change_context(Error)?; + tracing::info!( + "Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}", + num_images, + num_faces, + num_landmarks, + num_embeddings + ); + + Ok(()) +} + fn run_similar(similar: cli::Similar) -> Result<()> { let db = FaceDatabase::new(&similar.database).change_context(Error)?; @@ -341,14 +890,19 @@ fn run_similar(similar: cli::Similar) -> Result<()> { let similar_faces = db .find_similar_faces(query_embedding, similar.threshold, similar.limit) .change_context(Error)?; - + // Get image information for the similar faces println!( "Found {} similar faces (threshold: {:.3}):", similar_faces.len(), similar.threshold ); - for (face_id, similarity) in similar_faces { - println!(" Face {}: similarity {:.3}", face_id, similarity); + for (face_id, similarity) in &similar_faces { + if let Some(image_info) = db.get_image_for_face(*face_id).change_context(Error)? { + println!( + " Face {}: similarity {:.3}, image: {}", + face_id, similarity, image_info.file_path + ); + } } Ok(())