feat(compare): add face comparison functionality with cosine similarity
This commit is contained in:
331
Cargo.lock
generated
331
Cargo.lock
generated
@@ -139,7 +139,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -148,6 +148,17 @@ version = "0.7.6"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "atty"
|
||||||
|
version = "0.2.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||||
|
dependencies = [
|
||||||
|
"hermit-abi",
|
||||||
|
"libc",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
@@ -192,6 +203,29 @@ dependencies = [
|
|||||||
"windows-targets 0.52.6",
|
"windows-targets 0.52.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bindgen"
|
||||||
|
version = "0.60.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "062dddbc1ba4aca46de6338e2bf87771414c335f7b2f2036e8f3e9befebf88e6"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 1.3.2",
|
||||||
|
"cexpr",
|
||||||
|
"clang-sys",
|
||||||
|
"clap 3.2.25",
|
||||||
|
"env_logger",
|
||||||
|
"lazy_static",
|
||||||
|
"lazycell",
|
||||||
|
"log",
|
||||||
|
"peeking_take_while",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"regex",
|
||||||
|
"rustc-hash",
|
||||||
|
"shlex",
|
||||||
|
"which",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bindgen"
|
name = "bindgen"
|
||||||
version = "0.70.1"
|
version = "0.70.1"
|
||||||
@@ -210,7 +244,7 @@ dependencies = [
|
|||||||
"regex",
|
"regex",
|
||||||
"rustc-hash",
|
"rustc-hash",
|
||||||
"shlex",
|
"shlex",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -281,7 +315,7 @@ checksum = "4f154e572231cb6ba2bd1176980827e3d5dc04cc183a75dea38109fbdd672d29"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -349,6 +383,21 @@ dependencies = [
|
|||||||
"libloading",
|
"libloading",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap"
|
||||||
|
version = "3.2.25"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
|
||||||
|
dependencies = [
|
||||||
|
"atty",
|
||||||
|
"bitflags 1.3.2",
|
||||||
|
"clap_lex 0.2.4",
|
||||||
|
"indexmap 1.9.3",
|
||||||
|
"strsim 0.10.0",
|
||||||
|
"termcolor",
|
||||||
|
"textwrap",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap"
|
name = "clap"
|
||||||
version = "4.5.45"
|
version = "4.5.45"
|
||||||
@@ -367,8 +416,8 @@ checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anstream",
|
"anstream",
|
||||||
"anstyle",
|
"anstyle",
|
||||||
"clap_lex",
|
"clap_lex 0.7.5",
|
||||||
"strsim",
|
"strsim 0.11.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -377,7 +426,7 @@ version = "4.5.57"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad"
|
checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap",
|
"clap 4.5.45",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -389,7 +438,16 @@ dependencies = [
|
|||||||
"heck",
|
"heck",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_lex"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
|
||||||
|
dependencies = [
|
||||||
|
"os_str_bytes",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -505,7 +563,7 @@ name = "detector"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bounding-box",
|
"bounding-box",
|
||||||
"clap",
|
"clap 4.5.45",
|
||||||
"clap_complete",
|
"clap_complete",
|
||||||
"color",
|
"color",
|
||||||
"error-stack",
|
"error-stack",
|
||||||
@@ -518,12 +576,13 @@ dependencies = [
|
|||||||
"nalgebra",
|
"nalgebra",
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"ndarray-image",
|
"ndarray-image",
|
||||||
"ndarray-math",
|
"ndarray-math 0.1.0 (git+https://git.darksailor.dev/servius/ndarray-math)",
|
||||||
"ndarray-resize",
|
"ndarray-resize",
|
||||||
"ndarray-safetensors",
|
"ndarray-safetensors",
|
||||||
"ordered-float",
|
"ordered-float",
|
||||||
"ort",
|
"ort",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
|
"sqlite3-safetensor-cosine",
|
||||||
"tap",
|
"tap",
|
||||||
"thiserror 2.0.15",
|
"thiserror 2.0.15",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -548,7 +607,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -572,6 +631,19 @@ version = "1.15.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "env_logger"
|
||||||
|
version = "0.9.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7"
|
||||||
|
dependencies = [
|
||||||
|
"atty",
|
||||||
|
"humantime",
|
||||||
|
"log",
|
||||||
|
"regex",
|
||||||
|
"termcolor",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "equator"
|
name = "equator"
|
||||||
version = "0.4.2"
|
version = "0.4.2"
|
||||||
@@ -589,7 +661,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -598,6 +670,16 @@ version = "1.0.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "errno"
|
||||||
|
version = "0.3.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"windows-sys 0.60.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "error-stack"
|
name = "error-stack"
|
||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
@@ -851,6 +933,12 @@ dependencies = [
|
|||||||
"crunchy",
|
"crunchy",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.12.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.15.5"
|
version = "0.15.5"
|
||||||
@@ -866,7 +954,7 @@ version = "0.10.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
|
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"hashbrown",
|
"hashbrown 0.15.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -875,6 +963,30 @@ version = "0.5.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hermit-abi"
|
||||||
|
version = "0.1.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "home"
|
||||||
|
version = "0.5.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys 0.59.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "humantime"
|
||||||
|
version = "2.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "iana-time-zone"
|
name = "iana-time-zone"
|
||||||
version = "0.1.63"
|
version = "0.1.63"
|
||||||
@@ -1045,6 +1157,16 @@ version = "1.11.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
|
checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "indexmap"
|
||||||
|
version = "1.9.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"hashbrown 0.12.3",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "2.10.0"
|
version = "2.10.0"
|
||||||
@@ -1052,7 +1174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661"
|
checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"equivalent",
|
"equivalent",
|
||||||
"hashbrown",
|
"hashbrown 0.15.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1063,7 +1185,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1137,7 +1259,7 @@ checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1172,6 +1294,12 @@ version = "1.5.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazycell"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lebe"
|
name = "lebe"
|
||||||
version = "0.5.2"
|
version = "0.5.2"
|
||||||
@@ -1214,6 +1342,12 @@ dependencies = [
|
|||||||
"vcpkg",
|
"vcpkg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "linux-raw-sys"
|
||||||
|
version = "0.4.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "litemap"
|
name = "litemap"
|
||||||
version = "0.8.0"
|
version = "0.8.0"
|
||||||
@@ -1355,7 +1489,7 @@ version = "0.1.0"
|
|||||||
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
|
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"bindgen",
|
"bindgen 0.70.1",
|
||||||
"cc",
|
"cc",
|
||||||
"cmake",
|
"cmake",
|
||||||
"diffy",
|
"diffy",
|
||||||
@@ -1431,6 +1565,15 @@ dependencies = [
|
|||||||
"ndarray",
|
"ndarray",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray-math"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"ndarray",
|
||||||
|
"num",
|
||||||
|
"thiserror 2.0.15",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ndarray-math"
|
name = "ndarray-math"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -1552,7 +1695,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1652,6 +1795,12 @@ dependencies = [
|
|||||||
"pkg-config",
|
"pkg-config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "os_str_bytes"
|
||||||
|
version = "6.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "overload"
|
name = "overload"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
@@ -1664,6 +1813,12 @@ version = "1.0.15"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "peeking_take_while"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "percent-encoding"
|
name = "percent-encoding"
|
||||||
version = "2.3.1"
|
version = "2.3.1"
|
||||||
@@ -1741,7 +1896,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
|
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1769,7 +1924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b"
|
checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2000,6 +2155,19 @@ dependencies = [
|
|||||||
"semver",
|
"semver",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustix"
|
||||||
|
version = "0.38.44"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.9.2",
|
||||||
|
"errno",
|
||||||
|
"libc",
|
||||||
|
"linux-raw-sys",
|
||||||
|
"windows-sys 0.59.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustversion"
|
name = "rustversion"
|
||||||
version = "1.0.22"
|
version = "1.0.22"
|
||||||
@@ -2060,7 +2228,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2154,18 +2322,79 @@ dependencies = [
|
|||||||
"lock_api",
|
"lock_api",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sqlite-loadable"
|
||||||
|
version = "0.0.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a916b7bb8738eef189dea88731b619b80bf3f62b3acf05138fa43fbf8621cc94"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 1.3.2",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"sqlite-loadable-macros",
|
||||||
|
"sqlite3ext-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sqlite-loadable-macros"
|
||||||
|
version = "0.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f98bc75a8d6fd24f6a2cfea34f28758780fa17279d3051eec926efa381971e48"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sqlite3-safetensor-cosine"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"ndarray",
|
||||||
|
"ndarray-math 0.1.0",
|
||||||
|
"ndarray-safetensors",
|
||||||
|
"sqlite-loadable",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sqlite3ext-sys"
|
||||||
|
version = "0.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3afdc2b3dc08f16d6eecf8aa07d19975a268603ab1cca67d3f9b4172c507cf16"
|
||||||
|
dependencies = [
|
||||||
|
"bindgen 0.60.1",
|
||||||
|
"cc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "stable_deref_trait"
|
name = "stable_deref_trait"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strsim"
|
||||||
|
version = "0.10.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "strsim"
|
name = "strsim"
|
||||||
version = "0.11.1"
|
version = "0.11.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "1.0.109"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.106"
|
version = "2.0.106"
|
||||||
@@ -2185,7 +2414,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2213,6 +2442,21 @@ version = "0.12.16"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "termcolor"
|
||||||
|
version = "1.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "textwrap"
|
||||||
|
version = "0.16.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.69"
|
version = "1.0.69"
|
||||||
@@ -2239,7 +2483,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2250,7 +2494,7 @@ checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2355,7 +2599,7 @@ version = "0.22.27"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
|
checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"indexmap",
|
"indexmap 2.10.0",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_spanned",
|
"serde_spanned",
|
||||||
"toml_datetime",
|
"toml_datetime",
|
||||||
@@ -2381,7 +2625,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2540,7 +2784,7 @@ dependencies = [
|
|||||||
"log",
|
"log",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -2562,7 +2806,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
"wasm-bindgen-backend",
|
"wasm-bindgen-backend",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
@@ -2582,6 +2826,18 @@ version = "0.1.10"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
|
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "which"
|
||||||
|
version = "4.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
"home",
|
||||||
|
"once_cell",
|
||||||
|
"rustix",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wide"
|
name = "wide"
|
||||||
version = "0.7.33"
|
version = "0.7.33"
|
||||||
@@ -2608,6 +2864,15 @@ version = "0.4.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-util"
|
||||||
|
version = "0.1.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0978bf7171b3d90bac376700cb56d606feb40f251a475a5d6634613564460b22"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys 0.60.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winapi-x86_64-pc-windows-gnu"
|
name = "winapi-x86_64-pc-windows-gnu"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
@@ -2635,7 +2900,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2646,7 +2911,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2882,7 +3147,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
"synstructure",
|
"synstructure",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -2903,7 +3168,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2923,7 +3188,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
"synstructure",
|
"synstructure",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -2957,7 +3222,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.106",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors"]
|
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors", "sqlite3-safetensor-cosine"]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -53,6 +53,7 @@ ordered-float = "5.0.0"
|
|||||||
ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]}
|
ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]}
|
||||||
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||||
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
|
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
|
||||||
|
sqlite3-safetensor-cosine = { version = "0.1.0", path = "sqlite3-safetensor-cosine" }
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = true
|
debug = true
|
||||||
@@ -67,4 +68,4 @@ ort-directml = ["ort/directml"]
|
|||||||
mnn-metal = ["mnn/metal"]
|
mnn-metal = ["mnn/metal"]
|
||||||
mnn-coreml = ["mnn/coreml"]
|
mnn-coreml = ["mnn/coreml"]
|
||||||
|
|
||||||
default = []
|
default = ["mnn-metal","mnn-coreml"]
|
||||||
|
|||||||
33
README.md
33
README.md
@@ -55,6 +55,35 @@ cargo run --release detect --output detected.jpg path/to/image.jpg
|
|||||||
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
|
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Face Comparison
|
||||||
|
|
||||||
|
Compare faces between two images by computing and comparing their embeddings:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Compare faces in two images
|
||||||
|
cargo run --release compare image1.jpg image2.jpg
|
||||||
|
|
||||||
|
# Compare with custom thresholds
|
||||||
|
cargo run --release compare --threshold 0.9 --nms-threshold 0.4 image1.jpg image2.jpg
|
||||||
|
|
||||||
|
# Use ONNX Runtime backend for comparison
|
||||||
|
cargo run --release compare -p cpu image1.jpg image2.jpg
|
||||||
|
|
||||||
|
# Use MNN with Metal acceleration
|
||||||
|
cargo run --release compare -f metal image1.jpg image2.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
The compare command will:
|
||||||
|
1. Detect all faces in both images
|
||||||
|
2. Generate embeddings for each detected face
|
||||||
|
3. Compute cosine similarity between all face pairs
|
||||||
|
4. Display similarity scores and the best match
|
||||||
|
5. Provide interpretation of the similarity scores:
|
||||||
|
- **> 0.8**: Very likely the same person
|
||||||
|
- **0.6-0.8**: Possibly the same person
|
||||||
|
- **0.4-0.6**: Unlikely to be the same person
|
||||||
|
- **< 0.4**: Very unlikely to be the same person
|
||||||
|
|
||||||
### Backend Selection
|
### Backend Selection
|
||||||
|
|
||||||
The project supports two inference backends:
|
The project supports two inference backends:
|
||||||
@@ -106,7 +135,7 @@ The MNN backend supports various execution backends:
|
|||||||
|
|
||||||
- **CPU** - Default, works on all platforms
|
- **CPU** - Default, works on all platforms
|
||||||
- **Metal** - macOS GPU acceleration
|
- **Metal** - macOS GPU acceleration
|
||||||
- **CoreML** - macOS/iOS neural engine acceleration
|
- **CoreML** - macOS/iOS neural engine acceleration
|
||||||
- **OpenCL** - Cross-platform GPU acceleration
|
- **OpenCL** - Cross-platform GPU acceleration
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -179,7 +208,7 @@ MIT License
|
|||||||
Key dependencies include:
|
Key dependencies include:
|
||||||
|
|
||||||
- **MNN** - High-performance neural network inference framework (MNN backend)
|
- **MNN** - High-performance neural network inference framework (MNN backend)
|
||||||
- **ONNX Runtime** - Cross-platform ML inference (ORT backend)
|
- **ONNX Runtime** - Cross-platform ML inference (ORT backend)
|
||||||
- **ndarray** - N-dimensional array processing
|
- **ndarray** - N-dimensional array processing
|
||||||
- **image** - Image processing and I/O
|
- **image** - Image processing and I/O
|
||||||
- **clap** - Command line argument parsing
|
- **clap** - Command line argument parsing
|
||||||
|
|||||||
@@ -2,9 +2,6 @@
|
|||||||
description = "A simple rust flake using rust-overlay and craneLib";
|
description = "A simple rust flake using rust-overlay and craneLib";
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
self = {
|
|
||||||
lfs = true;
|
|
||||||
};
|
|
||||||
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
|
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
crane.url = "github:ipetkov/crane";
|
crane.url = "github:ipetkov/crane";
|
||||||
@@ -206,6 +203,7 @@
|
|||||||
packages = with pkgs;
|
packages = with pkgs;
|
||||||
[
|
[
|
||||||
stableToolchainWithRustAnalyzer
|
stableToolchainWithRustAnalyzer
|
||||||
|
cargo-expand
|
||||||
cargo-nextest
|
cargo-nextest
|
||||||
cargo-deny
|
cargo-deny
|
||||||
cmake
|
cmake
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ use safetensors::tensor::SafeTensors;
|
|||||||
/// let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
/// let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
||||||
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||||
/// ```
|
/// ```
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct SafeArraysView<'a> {
|
pub struct SafeArraysView<'a> {
|
||||||
pub tensors: SafeTensors<'a>,
|
pub tensors: SafeTensors<'a>,
|
||||||
}
|
}
|
||||||
|
|||||||
14
sqlite3-safetensor-cosine/Cargo.toml
Normal file
14
sqlite3-safetensor-cosine/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
[package]
|
||||||
|
name = "sqlite3-safetensor-cosine"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
crate-type = ["cdylib", "staticlib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
ndarray = "0.16.1"
|
||||||
|
# ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||||
|
ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
|
||||||
|
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
|
||||||
|
sqlite-loadable = "0.0.5"
|
||||||
61
sqlite3-safetensor-cosine/src/lib.rs
Normal file
61
sqlite3-safetensor-cosine/src/lib.rs
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
use sqlite_loadable::prelude::*;
|
||||||
|
use sqlite_loadable::{Error, ErrorKind};
|
||||||
|
use sqlite_loadable::{Result, api, define_scalar_function};
|
||||||
|
|
||||||
|
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
|
||||||
|
#[inline(always)]
|
||||||
|
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
|
||||||
|
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if values.len() != 2 {
|
||||||
|
return Err(Error::new(ErrorKind::Message(
|
||||||
|
"cosine_similarity requires exactly 2 arguments".to_string(),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
|
||||||
|
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
|
||||||
|
let array_1_st =
|
||||||
|
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
|
||||||
|
let array_2_st =
|
||||||
|
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
|
||||||
|
|
||||||
|
let array_view_1 = array_1_st
|
||||||
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||||
|
.map_err(custom_error)?;
|
||||||
|
let array_view_2 = array_2_st
|
||||||
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||||
|
.map_err(custom_error)?;
|
||||||
|
|
||||||
|
use ndarray_math::*;
|
||||||
|
let similarity = array_view_1
|
||||||
|
.cosine_similarity(array_view_2)
|
||||||
|
.map_err(custom_error)?;
|
||||||
|
api::result_double(context, similarity as f64);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
|
||||||
|
define_scalar_function(
|
||||||
|
db,
|
||||||
|
"cosine_similarity",
|
||||||
|
2,
|
||||||
|
cosine_similarity,
|
||||||
|
FunctionFlags::DETERMINISTIC,
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// Should only be called by underlying SQLite C APIs,
|
||||||
|
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
|
||||||
|
#[unsafe(no_mangle)]
|
||||||
|
pub unsafe extern "C" fn sqlite3_extension_init(
|
||||||
|
db: *mut sqlite3,
|
||||||
|
pz_err_msg: *mut *mut c_char,
|
||||||
|
p_api: *mut sqlite3_api_routines,
|
||||||
|
) -> c_uint {
|
||||||
|
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
|
||||||
|
}
|
||||||
84
src/cli.rs
84
src/cli.rs
@@ -1,6 +1,5 @@
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use mnn::ForwardType;
|
|
||||||
#[derive(Debug, clap::Parser)]
|
#[derive(Debug, clap::Parser)]
|
||||||
pub struct Cli {
|
pub struct Cli {
|
||||||
#[clap(subcommand)]
|
#[clap(subcommand)]
|
||||||
@@ -11,14 +10,16 @@ pub struct Cli {
|
|||||||
pub enum SubCommand {
|
pub enum SubCommand {
|
||||||
#[clap(name = "detect")]
|
#[clap(name = "detect")]
|
||||||
Detect(Detect),
|
Detect(Detect),
|
||||||
#[clap(name = "list")]
|
#[clap(name = "detect-multi")]
|
||||||
List(List),
|
DetectMulti(DetectMulti),
|
||||||
#[clap(name = "query")]
|
#[clap(name = "query")]
|
||||||
Query(Query),
|
Query(Query),
|
||||||
#[clap(name = "similar")]
|
#[clap(name = "similar")]
|
||||||
Similar(Similar),
|
Similar(Similar),
|
||||||
#[clap(name = "stats")]
|
#[clap(name = "stats")]
|
||||||
Stats(Stats),
|
Stats(Stats),
|
||||||
|
#[clap(name = "compare")]
|
||||||
|
Compare(Compare),
|
||||||
#[clap(name = "completions")]
|
#[clap(name = "completions")]
|
||||||
Completions { shell: clap_complete::Shell },
|
Completions { shell: clap_complete::Shell },
|
||||||
}
|
}
|
||||||
@@ -74,7 +75,47 @@ pub struct Detect {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, clap::Args)]
|
#[derive(Debug, clap::Args)]
|
||||||
pub struct List {}
|
pub struct DetectMulti {
|
||||||
|
#[clap(short, long)]
|
||||||
|
pub model: Option<PathBuf>,
|
||||||
|
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||||
|
pub model_type: Models,
|
||||||
|
#[clap(short, long)]
|
||||||
|
pub output_dir: Option<PathBuf>,
|
||||||
|
#[clap(
|
||||||
|
short = 'p',
|
||||||
|
long,
|
||||||
|
default_value = "cpu",
|
||||||
|
group = "execution_provider",
|
||||||
|
required_unless_present = "mnn_forward_type"
|
||||||
|
)]
|
||||||
|
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||||
|
#[clap(
|
||||||
|
short = 'f',
|
||||||
|
long,
|
||||||
|
group = "execution_provider",
|
||||||
|
required_unless_present = "ort_execution_provider"
|
||||||
|
)]
|
||||||
|
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||||
|
#[clap(short, long, default_value_t = 0.8)]
|
||||||
|
pub threshold: f32,
|
||||||
|
#[clap(short, long, default_value_t = 0.3)]
|
||||||
|
pub nms_threshold: f32,
|
||||||
|
#[clap(short, long, default_value_t = 8)]
|
||||||
|
pub batch_size: usize,
|
||||||
|
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||||
|
pub database: PathBuf,
|
||||||
|
#[clap(long, default_value = "facenet")]
|
||||||
|
pub model_name: String,
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
help = "Image extensions to process (e.g., jpg,png,jpeg)",
|
||||||
|
default_value = "jpg,jpeg,png,bmp,tiff,webp"
|
||||||
|
)]
|
||||||
|
pub extensions: String,
|
||||||
|
#[clap(help = "Directory containing images to process")]
|
||||||
|
pub input_dir: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, clap::Args)]
|
#[derive(Debug, clap::Args)]
|
||||||
pub struct Query {
|
pub struct Query {
|
||||||
@@ -108,6 +149,41 @@ pub struct Stats {
|
|||||||
pub database: PathBuf,
|
pub database: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, clap::Args)]
|
||||||
|
pub struct Compare {
|
||||||
|
#[clap(short, long)]
|
||||||
|
pub model: Option<PathBuf>,
|
||||||
|
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||||
|
pub model_type: Models,
|
||||||
|
#[clap(
|
||||||
|
short = 'p',
|
||||||
|
long,
|
||||||
|
default_value = "cpu",
|
||||||
|
group = "execution_provider",
|
||||||
|
required_unless_present = "mnn_forward_type"
|
||||||
|
)]
|
||||||
|
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||||
|
#[clap(
|
||||||
|
short = 'f',
|
||||||
|
long,
|
||||||
|
group = "execution_provider",
|
||||||
|
required_unless_present = "ort_execution_provider"
|
||||||
|
)]
|
||||||
|
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||||
|
#[clap(short, long, default_value_t = 0.8)]
|
||||||
|
pub threshold: f32,
|
||||||
|
#[clap(short, long, default_value_t = 0.3)]
|
||||||
|
pub nms_threshold: f32,
|
||||||
|
#[clap(short, long, default_value_t = 8)]
|
||||||
|
pub batch_size: usize,
|
||||||
|
#[clap(long, default_value = "facenet")]
|
||||||
|
pub model_name: String,
|
||||||
|
#[clap(help = "First image to compare")]
|
||||||
|
pub image1: PathBuf,
|
||||||
|
#[clap(help = "Second image to compare")]
|
||||||
|
pub image2: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
impl Cli {
|
impl Cli {
|
||||||
pub fn completions(shell: clap_complete::Shell) {
|
pub fn completions(shell: clap_complete::Shell) {
|
||||||
let mut command = <Cli as clap::CommandFactory>::command();
|
let mut command = <Cli as clap::CommandFactory>::command();
|
||||||
|
|||||||
136
src/database.rs
136
src/database.rs
@@ -65,7 +65,14 @@ impl FaceDatabase {
|
|||||||
/// Create a new database connection and initialize tables
|
/// Create a new database connection and initialize tables
|
||||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||||
let conn = Connection::open(db_path).change_context(Error)?;
|
let conn = Connection::open(db_path).change_context(Error)?;
|
||||||
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
|
unsafe {
|
||||||
|
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||||
|
conn.load_extension(
|
||||||
|
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||||
|
None::<&str>,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
}
|
||||||
let db = Self { conn };
|
let db = Self { conn };
|
||||||
db.create_tables()?;
|
db.create_tables()?;
|
||||||
Ok(db)
|
Ok(db)
|
||||||
@@ -190,10 +197,9 @@ impl FaceDatabase {
|
|||||||
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
||||||
.change_context(Error)?;
|
.change_context(Error)?;
|
||||||
|
|
||||||
stmt.execute(params![file_path, width, height])
|
Ok(stmt
|
||||||
.change_context(Error)?;
|
.insert(params![file_path, width, height])
|
||||||
|
.change_context(Error)?)
|
||||||
Ok(self.conn.last_insert_rowid())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store face detection results
|
/// Store face detection results
|
||||||
@@ -231,17 +237,16 @@ impl FaceDatabase {
|
|||||||
)
|
)
|
||||||
.change_context(Error)?;
|
.change_context(Error)?;
|
||||||
|
|
||||||
stmt.execute(params![
|
Ok(stmt
|
||||||
image_id,
|
.insert(params![
|
||||||
bbox.x1() as f32,
|
image_id,
|
||||||
bbox.y1() as f32,
|
bbox.x1() as f32,
|
||||||
bbox.x2() as f32,
|
bbox.y1() as f32,
|
||||||
bbox.y2() as f32,
|
bbox.x2() as f32,
|
||||||
confidence
|
bbox.y2() as f32,
|
||||||
])
|
confidence
|
||||||
.change_context(Error)?;
|
])
|
||||||
|
.change_context(Error)?)
|
||||||
Ok(self.conn.last_insert_rowid())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store face landmarks
|
/// Store face landmarks
|
||||||
@@ -258,22 +263,21 @@ impl FaceDatabase {
|
|||||||
)
|
)
|
||||||
.change_context(Error)?;
|
.change_context(Error)?;
|
||||||
|
|
||||||
stmt.execute(params![
|
Ok(stmt
|
||||||
face_id,
|
.insert(params![
|
||||||
landmarks.left_eye.x,
|
face_id,
|
||||||
landmarks.left_eye.y,
|
landmarks.left_eye.x,
|
||||||
landmarks.right_eye.x,
|
landmarks.left_eye.y,
|
||||||
landmarks.right_eye.y,
|
landmarks.right_eye.x,
|
||||||
landmarks.nose.x,
|
landmarks.right_eye.y,
|
||||||
landmarks.nose.y,
|
landmarks.nose.x,
|
||||||
landmarks.left_mouth.x,
|
landmarks.nose.y,
|
||||||
landmarks.left_mouth.y,
|
landmarks.left_mouth.x,
|
||||||
landmarks.right_mouth.x,
|
landmarks.left_mouth.y,
|
||||||
landmarks.right_mouth.y,
|
landmarks.right_mouth.x,
|
||||||
])
|
landmarks.right_mouth.y,
|
||||||
.change_context(Error)?;
|
])
|
||||||
|
.change_context(Error)?)
|
||||||
Ok(self.conn.last_insert_rowid())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store face embeddings
|
/// Store face embeddings
|
||||||
@@ -310,12 +314,12 @@ impl FaceDatabase {
|
|||||||
embedding: ndarray::ArrayView1<f32>,
|
embedding: ndarray::ArrayView1<f32>,
|
||||||
model_name: &str,
|
model_name: &str,
|
||||||
) -> Result<i64> {
|
) -> Result<i64> {
|
||||||
let embedding_bytes =
|
let safe_arrays =
|
||||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
|
||||||
.change_context(Error)?
|
|
||||||
.serialize()
|
|
||||||
.change_context(Error)?;
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
|
||||||
|
|
||||||
let mut stmt = self
|
let mut stmt = self
|
||||||
.conn
|
.conn
|
||||||
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
|
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
|
||||||
@@ -462,6 +466,35 @@ impl FaceDatabase {
|
|||||||
Ok(embeddings)
|
Ok(embeddings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#"
|
||||||
|
SELECT images.id, images.file_path, images.width, images.height, images.created_at
|
||||||
|
FROM images
|
||||||
|
JOIN faces ON faces.image_id = images.id
|
||||||
|
WHERE faces.id = ?1
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let result = stmt
|
||||||
|
.query_row(params![face_id], |row| {
|
||||||
|
Ok(ImageRecord {
|
||||||
|
id: row.get(0)?,
|
||||||
|
file_path: row.get(1)?,
|
||||||
|
width: row.get(2)?,
|
||||||
|
height: row.get(3)?,
|
||||||
|
created_at: row.get(4)?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.optional()
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
/// Get database statistics
|
/// Get database statistics
|
||||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||||
let images: usize = self
|
let images: usize = self
|
||||||
@@ -528,6 +561,39 @@ impl FaceDatabase {
|
|||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
|
||||||
|
let embedding_bytes =
|
||||||
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||||
|
.change_context(Error)
|
||||||
|
.unwrap()
|
||||||
|
.serialize()
|
||||||
|
.change_context(Error)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#"
|
||||||
|
SELECT face_id,
|
||||||
|
cosine_similarity(?1, embedding)
|
||||||
|
FROM embeddings
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let result_iter = stmt
|
||||||
|
.query_map(params![embedding_bytes], |row| {
|
||||||
|
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||||
|
})
|
||||||
|
.change_context(Error)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
for result in result_iter {
|
||||||
|
println!("{:?}", result);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ pub trait FaceDetector {
|
|||||||
fn detect_faces(
|
fn detect_faces(
|
||||||
&mut self,
|
&mut self,
|
||||||
image: ndarray::ArrayView3<u8>,
|
image: ndarray::ArrayView3<u8>,
|
||||||
config: FaceDetectionConfig,
|
config: &FaceDetectionConfig,
|
||||||
) -> Result<FaceDetectionOutput> {
|
) -> Result<FaceDetectionOutput> {
|
||||||
let (height, width, _channels) = image.dim();
|
let (height, width, _channels) = image.dim();
|
||||||
let output = self
|
let output = self
|
||||||
|
|||||||
@@ -11,6 +11,23 @@ pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use ndarray::{Array2, ArrayView4};
|
use ndarray::{Array2, ArrayView4};
|
||||||
|
|
||||||
|
pub mod preprocessing {
|
||||||
|
use ndarray::*;
|
||||||
|
pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
|
||||||
|
let mut owned = faces.as_standard_layout().mapv(|v| v as f32).to_owned();
|
||||||
|
owned.axis_iter_mut(Axis(0)).for_each(|mut image| {
|
||||||
|
let mean = image.mean().unwrap_or(0.0);
|
||||||
|
let std = image.std(0.0);
|
||||||
|
if std > 0.0 {
|
||||||
|
image.mapv_inplace(|x| (x - mean) / std);
|
||||||
|
} else {
|
||||||
|
image.mapv_inplace(|x| (x - 127.5) / 128.0)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
owned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Common trait for face embedding backends - maintained for backward compatibility
|
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||||
pub trait FaceEmbedder {
|
pub trait FaceEmbedder {
|
||||||
/// Generate embeddings for a batch of face images
|
/// Generate embeddings for a batch of face images
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ pub mod ort;
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use error_stack::ResultExt;
|
use error_stack::ResultExt;
|
||||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||||
|
use ndarray_math::{CosineSimilarity, EuclideanDistance};
|
||||||
|
|
||||||
/// Configuration for face embedding processing
|
/// Configuration for face embedding processing
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@@ -32,9 +33,9 @@ impl FaceEmbeddingConfig {
|
|||||||
impl Default for FaceEmbeddingConfig {
|
impl Default for FaceEmbeddingConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
input_width: 160,
|
input_width: 320,
|
||||||
input_height: 160,
|
input_height: 320,
|
||||||
normalize: true,
|
normalize: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,15 +64,14 @@ impl FaceEmbedding {
|
|||||||
|
|
||||||
/// Calculate cosine similarity with another embedding
|
/// Calculate cosine similarity with another embedding
|
||||||
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
|
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
|
||||||
let dot_product = self.vector.dot(&other.vector);
|
self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
|
||||||
let norm_self = self.vector.mapv(|x| x * x).sum().sqrt();
|
|
||||||
let norm_other = other.vector.mapv(|x| x * x).sum().sqrt();
|
|
||||||
dot_product / (norm_self * norm_other)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate Euclidean distance with another embedding
|
/// Calculate Euclidean distance with another embedding
|
||||||
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
|
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
|
||||||
(&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt()
|
self.vector
|
||||||
|
.euclidean_distance(other.vector.view())
|
||||||
|
.unwrap_or(f32::INFINITY)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Normalize the embedding vector to unit length
|
/// Normalize the embedding vector to unit length
|
||||||
|
|||||||
@@ -64,10 +64,7 @@ impl EmbeddingGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
|
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||||
let tensor = face
|
let tensor = crate::faceembed::preprocessing::preprocess(face);
|
||||||
// .permuted_axes((0, 3, 1, 2))
|
|
||||||
.as_standard_layout()
|
|
||||||
.mapv(|x| x as f32);
|
|
||||||
let shape: [usize; 4] = tensor.dim().into();
|
let shape: [usize; 4] = tensor.dim().into();
|
||||||
let shape = shape.map(|f| f as i32);
|
let shape = shape.map(|f| f as i32);
|
||||||
let output = self
|
let output = self
|
||||||
|
|||||||
@@ -135,10 +135,12 @@ impl EmbeddingGenerator {
|
|||||||
|
|
||||||
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
// Convert input from u8 to f32 and normalize to [0, 1] range
|
// Convert input from u8 to f32 and normalize to [0, 1] range
|
||||||
let input_tensor = faces
|
let input_tensor = crate::faceembed::preprocessing::preprocess(faces);
|
||||||
.mapv(|x| x as f32 / 255.0)
|
|
||||||
.as_standard_layout()
|
// face_array = np.asarray(face_resized, 'float32')
|
||||||
.into_owned();
|
// mean, std = face_array.mean(), face_array.std()
|
||||||
|
// face_normalized = (face_array - mean) / std
|
||||||
|
// let input_tensor = faces.mean()
|
||||||
|
|
||||||
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||||
|
|
||||||
|
|||||||
576
src/main.rs
576
src/main.rs
@@ -75,8 +75,61 @@ pub fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cli::SubCommand::List(list) => {
|
cli::SubCommand::DetectMulti(detect_multi) => {
|
||||||
println!("List: {:?}", list);
|
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||||
|
|
||||||
|
let executor = detect_multi
|
||||||
|
.mnn_forward_type
|
||||||
|
.map(|f| cli::Executor::Mnn(f))
|
||||||
|
.or_else(|| {
|
||||||
|
if detect_multi.ort_execution_provider.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(cli::Executor::Ort(
|
||||||
|
detect_multi.ort_execution_provider.clone(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||||
|
|
||||||
|
match executor {
|
||||||
|
cli::Executor::Mnn(forward) => {
|
||||||
|
let retinaface =
|
||||||
|
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_forward_type(forward)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face detection model")?;
|
||||||
|
let facenet =
|
||||||
|
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_forward_type(forward)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|
||||||
|
run_multi_detection(detect_multi, retinaface, facenet)?;
|
||||||
|
}
|
||||||
|
cli::Executor::Ort(ep) => {
|
||||||
|
let retinaface =
|
||||||
|
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_execution_providers(&ep)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face detection model")?;
|
||||||
|
let facenet =
|
||||||
|
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_execution_providers(ep)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|
||||||
|
run_multi_detection(detect_multi, retinaface, facenet)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cli::SubCommand::Query(query) => {
|
cli::SubCommand::Query(query) => {
|
||||||
run_query(query)?;
|
run_query(query)?;
|
||||||
@@ -87,6 +140,59 @@ pub fn main() -> Result<()> {
|
|||||||
cli::SubCommand::Stats(stats) => {
|
cli::SubCommand::Stats(stats) => {
|
||||||
run_stats(stats)?;
|
run_stats(stats)?;
|
||||||
}
|
}
|
||||||
|
cli::SubCommand::Compare(compare) => {
|
||||||
|
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||||
|
let executor = compare
|
||||||
|
.mnn_forward_type
|
||||||
|
.map(|f| cli::Executor::Mnn(f))
|
||||||
|
.or_else(|| {
|
||||||
|
if compare.ort_execution_provider.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(cli::Executor::Ort(compare.ort_execution_provider.clone()))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||||
|
|
||||||
|
match executor {
|
||||||
|
cli::Executor::Mnn(forward) => {
|
||||||
|
let retinaface =
|
||||||
|
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_forward_type(forward)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face detection model")?;
|
||||||
|
let facenet =
|
||||||
|
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_forward_type(forward)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|
||||||
|
run_compare(compare, retinaface, facenet)?;
|
||||||
|
}
|
||||||
|
cli::Executor::Ort(ep) => {
|
||||||
|
let retinaface =
|
||||||
|
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_execution_providers(&ep)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face detection model")?;
|
||||||
|
let facenet =
|
||||||
|
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_execution_providers(ep)
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|
||||||
|
run_compare(compare, retinaface, facenet)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
cli::SubCommand::Completions { shell } => {
|
cli::SubCommand::Completions { shell } => {
|
||||||
cli::Cli::completions(shell);
|
cli::Cli::completions(shell);
|
||||||
}
|
}
|
||||||
@@ -122,7 +228,7 @@ where
|
|||||||
let output = retinaface
|
let output = retinaface
|
||||||
.detect_faces(
|
.detect_faces(
|
||||||
array.view(),
|
array.view(),
|
||||||
FaceDetectionConfig::default()
|
&FaceDetectionConfig::default()
|
||||||
.with_threshold(detect.threshold)
|
.with_threshold(detect.threshold)
|
||||||
.with_nms_threshold(detect.nms_threshold),
|
.with_nms_threshold(detect.nms_threshold),
|
||||||
)
|
)
|
||||||
@@ -163,7 +269,7 @@ where
|
|||||||
// })
|
// })
|
||||||
.map(|roi| {
|
.map(|roi| {
|
||||||
roi.as_standard_layout()
|
roi.as_standard_layout()
|
||||||
.fast_resize(160, 160, &ResizeOptions::default())
|
.fast_resize(320, 320, &ResizeOptions::default())
|
||||||
.change_context(Error)
|
.change_context(Error)
|
||||||
})
|
})
|
||||||
// .inspect(|f| {
|
// .inspect(|f| {
|
||||||
@@ -182,11 +288,14 @@ where
|
|||||||
|
|
||||||
if chunk.len() < chunk_size {
|
if chunk.len() < chunk_size {
|
||||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||||
let zeros = Array3::zeros((160, 160, 3));
|
let zeros = Array3::zeros((320, 320, 3));
|
||||||
let zero_array = core::iter::repeat(zeros.view())
|
let chunk: Vec<_> = chunk
|
||||||
|
.iter()
|
||||||
|
.map(|arr| arr.reborrow())
|
||||||
|
.chain(core::iter::repeat(zeros.view()))
|
||||||
.take(chunk_size)
|
.take(chunk_size)
|
||||||
.collect::<Vec<_>>();
|
.collect();
|
||||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to stack rois together")?;
|
.attach_printable("Failed to stack rois together")?;
|
||||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||||
@@ -328,6 +437,446 @@ fn run_query(query: cli::Query) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_compare<D, E>(compare: cli::Compare, mut retinaface: D, mut facenet: E) -> Result<()>
|
||||||
|
where
|
||||||
|
D: facedet::FaceDetector,
|
||||||
|
E: faceembed::FaceEmbedder,
|
||||||
|
{
|
||||||
|
// Helper function to detect faces and compute embeddings for an image
|
||||||
|
fn process_image<D, E>(
|
||||||
|
image_path: &std::path::Path,
|
||||||
|
retinaface: &mut D,
|
||||||
|
facenet: &mut E,
|
||||||
|
config: &FaceDetectionConfig,
|
||||||
|
batch_size: usize,
|
||||||
|
) -> Result<(Vec<Array1<f32>>, usize)>
|
||||||
|
where
|
||||||
|
D: facedet::FaceDetector,
|
||||||
|
E: faceembed::FaceEmbedder,
|
||||||
|
{
|
||||||
|
let image = image::open(image_path)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable(image_path.to_string_lossy().to_string())?;
|
||||||
|
let image = image.into_rgb8();
|
||||||
|
let array = image
|
||||||
|
.into_ndarray()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to convert image to ndarray")?;
|
||||||
|
|
||||||
|
let output = retinaface
|
||||||
|
.detect_faces(array.view(), config)
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to detect faces")?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"Detected {} faces in {}",
|
||||||
|
output.bbox.len(),
|
||||||
|
image_path.display()
|
||||||
|
);
|
||||||
|
|
||||||
|
if output.bbox.is_empty() {
|
||||||
|
return Ok((Vec::new(), 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
let face_rois = array
|
||||||
|
.view()
|
||||||
|
.multi_roi(&output.bbox)
|
||||||
|
.change_context(Error)?
|
||||||
|
.into_iter()
|
||||||
|
.map(|roi| {
|
||||||
|
roi.as_standard_layout()
|
||||||
|
.fast_resize(320, 320, &ResizeOptions::default())
|
||||||
|
.change_context(Error)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||||
|
let chunk_size = batch_size;
|
||||||
|
let embeddings = face_roi_views
|
||||||
|
.chunks(chunk_size)
|
||||||
|
.map(|chunk| {
|
||||||
|
if chunk.len() < chunk_size {
|
||||||
|
let zeros = Array3::zeros((320, 320, 3));
|
||||||
|
let chunk: Vec<_> = chunk
|
||||||
|
.iter()
|
||||||
|
.map(|arr| arr.reborrow())
|
||||||
|
.chain(core::iter::repeat(zeros.view()))
|
||||||
|
.take(chunk_size)
|
||||||
|
.collect();
|
||||||
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to stack rois together")?;
|
||||||
|
facenet.run_models(face_rois.view()).change_context(Error)
|
||||||
|
} else {
|
||||||
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to stack rois together")?;
|
||||||
|
facenet.run_models(face_rois.view()).change_context(Error)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<Array2<f32>>>>()?;
|
||||||
|
|
||||||
|
// Flatten embeddings into individual face embeddings
|
||||||
|
let mut face_embeddings = Vec::new();
|
||||||
|
for embedding_batch in embeddings {
|
||||||
|
for i in 0..output.bbox.len().min(embedding_batch.nrows()) {
|
||||||
|
face_embeddings.push(embedding_batch.row(i).to_owned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((face_embeddings, output.bbox.len()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to compute cosine similarity between two embeddings
|
||||||
|
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
|
||||||
|
let dot_product = a.dot(b);
|
||||||
|
let norm_a = a.dot(a).sqrt();
|
||||||
|
let norm_b = b.dot(b).sqrt();
|
||||||
|
dot_product / (norm_a * norm_b)
|
||||||
|
}
|
||||||
|
|
||||||
|
let config = FaceDetectionConfig::default()
|
||||||
|
.with_threshold(compare.threshold)
|
||||||
|
.with_nms_threshold(compare.nms_threshold);
|
||||||
|
|
||||||
|
// Process both images
|
||||||
|
let (embeddings1, face_count1) = process_image(
|
||||||
|
&compare.image1,
|
||||||
|
&mut retinaface,
|
||||||
|
&mut facenet,
|
||||||
|
&config,
|
||||||
|
compare.batch_size,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let (embeddings2, face_count2) = process_image(
|
||||||
|
&compare.image2,
|
||||||
|
&mut retinaface,
|
||||||
|
&mut facenet,
|
||||||
|
&config,
|
||||||
|
compare.batch_size,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Image 1 ({}): {} faces detected",
|
||||||
|
compare.image1.display(),
|
||||||
|
face_count1
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"Image 2 ({}): {} faces detected",
|
||||||
|
compare.image2.display(),
|
||||||
|
face_count2
|
||||||
|
);
|
||||||
|
|
||||||
|
if embeddings1.is_empty() && embeddings2.is_empty() {
|
||||||
|
println!("No faces detected in either image");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if embeddings1.is_empty() {
|
||||||
|
println!("No faces detected in image 1");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if embeddings2.is_empty() {
|
||||||
|
println!("No faces detected in image 2");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare all faces between the two images
|
||||||
|
println!("\nFace comparison results:");
|
||||||
|
println!("========================");
|
||||||
|
|
||||||
|
let mut max_similarity = f32::NEG_INFINITY;
|
||||||
|
let mut best_match = (0, 0);
|
||||||
|
|
||||||
|
for (i, emb1) in embeddings1.iter().enumerate() {
|
||||||
|
for (j, emb2) in embeddings2.iter().enumerate() {
|
||||||
|
let similarity = cosine_similarity(emb1, emb2);
|
||||||
|
println!(
|
||||||
|
"Face {} (image 1) vs Face {} (image 2): {:.4}",
|
||||||
|
i + 1,
|
||||||
|
j + 1,
|
||||||
|
similarity
|
||||||
|
);
|
||||||
|
|
||||||
|
if similarity > max_similarity {
|
||||||
|
max_similarity = similarity;
|
||||||
|
best_match = (i + 1, j + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"\nBest match: Face {} (image 1) vs Face {} (image 2) with similarity: {:.4}",
|
||||||
|
best_match.0, best_match.1, max_similarity
|
||||||
|
);
|
||||||
|
|
||||||
|
// Interpretation of similarity score
|
||||||
|
if max_similarity > 0.8 {
|
||||||
|
println!("Interpretation: Very likely the same person");
|
||||||
|
} else if max_similarity > 0.6 {
|
||||||
|
println!("Interpretation: Possibly the same person");
|
||||||
|
} else if max_similarity > 0.4 {
|
||||||
|
println!("Interpretation: Unlikely to be the same person");
|
||||||
|
} else {
|
||||||
|
println!("Interpretation: Very unlikely to be the same person");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_multi_detection<D, E>(
|
||||||
|
detect_multi: cli::DetectMulti,
|
||||||
|
mut retinaface: D,
|
||||||
|
mut facenet: E,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
D: facedet::FaceDetector,
|
||||||
|
E: faceembed::FaceEmbedder,
|
||||||
|
{
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
// Initialize database - always save to database for multi-detection
|
||||||
|
let db = FaceDatabase::new(&detect_multi.database).change_context(Error)?;
|
||||||
|
|
||||||
|
// Parse supported extensions
|
||||||
|
let extensions: std::collections::HashSet<String> = detect_multi
|
||||||
|
.extensions
|
||||||
|
.split(',')
|
||||||
|
.map(|ext| ext.trim().to_lowercase())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Create output directory if specified
|
||||||
|
if let Some(ref output_dir) = detect_multi.output_dir {
|
||||||
|
fs::create_dir_all(output_dir)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create output directory")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read directory and filter image files
|
||||||
|
let entries = fs::read_dir(&detect_multi.input_dir)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to read input directory")?;
|
||||||
|
|
||||||
|
let mut image_paths = Vec::new();
|
||||||
|
for entry in entries {
|
||||||
|
let entry = entry.change_context(Error)?;
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.is_file() {
|
||||||
|
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
|
||||||
|
if extensions.contains(&ext.to_lowercase()) {
|
||||||
|
image_paths.push(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if image_paths.is_empty() {
|
||||||
|
tracing::warn!(
|
||||||
|
"No image files found in directory: {:?}",
|
||||||
|
detect_multi.input_dir
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("Found {} image files to process", image_paths.len());
|
||||||
|
|
||||||
|
let mut total_faces = 0;
|
||||||
|
let mut processed_images = 0;
|
||||||
|
|
||||||
|
// Process each image
|
||||||
|
for (idx, image_path) in image_paths.iter().enumerate() {
|
||||||
|
tracing::info!(
|
||||||
|
"Processing image {}/{}: {:?}",
|
||||||
|
idx + 1,
|
||||||
|
image_paths.len(),
|
||||||
|
image_path
|
||||||
|
);
|
||||||
|
|
||||||
|
// Load and process image
|
||||||
|
let image = match image::open(image_path) {
|
||||||
|
Ok(img) => img.into_rgb8(),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to load image {:?}: {}", image_path, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (image_width, image_height) = image.dimensions();
|
||||||
|
let mut array = match image.into_ndarray().change_context(errors::Error) {
|
||||||
|
Ok(arr) => arr,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to convert image to ndarray: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = FaceDetectionConfig::default()
|
||||||
|
.with_threshold(detect_multi.threshold)
|
||||||
|
.with_nms_threshold(detect_multi.nms_threshold);
|
||||||
|
// Detect faces
|
||||||
|
let output = match retinaface.detect_faces(array.view(), &config) {
|
||||||
|
Ok(output) => output,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to detect faces in {:?}: {:?}", image_path, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let num_faces = output.bbox.len();
|
||||||
|
total_faces += num_faces;
|
||||||
|
|
||||||
|
if num_faces == 0 {
|
||||||
|
tracing::info!("No faces detected in {:?}", image_path);
|
||||||
|
} else {
|
||||||
|
tracing::info!("Detected {} faces in {:?}", num_faces, image_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store image and detections in database
|
||||||
|
let image_path_str = image_path.to_string_lossy();
|
||||||
|
let img_id = match db.store_image(&image_path_str, image_width, image_height) {
|
||||||
|
Ok(id) => id,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to store image in database: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let face_ids = match db.store_face_detections(img_id, &output) {
|
||||||
|
Ok(ids) => ids,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to store face detections in database: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Draw bounding boxes if output directory is specified
|
||||||
|
if detect_multi.output_dir.is_some() {
|
||||||
|
for bbox in &output.bbox {
|
||||||
|
use bounding_box::draw::*;
|
||||||
|
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process face embeddings if faces were detected
|
||||||
|
if !face_ids.is_empty() {
|
||||||
|
let face_rois = match array.view().multi_roi(&output.bbox).change_context(Error) {
|
||||||
|
Ok(rois) => rois,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to extract face ROIs: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let face_rois: Result<Vec<_>> = face_rois
|
||||||
|
.into_iter()
|
||||||
|
.map(|roi| {
|
||||||
|
roi.as_standard_layout()
|
||||||
|
.fast_resize(320, 320, &ResizeOptions::default())
|
||||||
|
.change_context(Error)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let face_rois = match face_rois {
|
||||||
|
Ok(rois) => rois,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to resize face ROIs: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let chunk_size = detect_multi.batch_size;
|
||||||
|
let embeddings: Result<Vec<Array2<f32>>> = face_roi_views
|
||||||
|
.chunks(chunk_size)
|
||||||
|
.map(|chunk| {
|
||||||
|
if chunk.len() < chunk_size {
|
||||||
|
let zeros = Array3::zeros((320, 320, 3));
|
||||||
|
let chunk: Vec<_> = chunk
|
||||||
|
.iter()
|
||||||
|
.map(|arr| arr.reborrow())
|
||||||
|
.chain(core::iter::repeat(zeros.view()))
|
||||||
|
.take(chunk_size)
|
||||||
|
.collect();
|
||||||
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to stack rois together")?;
|
||||||
|
facenet.run_models(face_rois.view()).change_context(Error)
|
||||||
|
} else {
|
||||||
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to stack rois together")?;
|
||||||
|
facenet.run_models(face_rois.view()).change_context(Error)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let embeddings = match embeddings {
|
||||||
|
Ok(emb) => emb,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to generate embeddings: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Store embeddings in database
|
||||||
|
if let Err(e) = db.store_embeddings(&face_ids, &embeddings, &detect_multi.model_name) {
|
||||||
|
tracing::error!("Failed to store embeddings in database: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save output image if directory specified
|
||||||
|
if let Some(ref output_dir) = detect_multi.output_dir {
|
||||||
|
let output_filename = format!(
|
||||||
|
"detected_{}",
|
||||||
|
image_path.file_name().unwrap().to_string_lossy()
|
||||||
|
);
|
||||||
|
let output_path = output_dir.join(output_filename);
|
||||||
|
|
||||||
|
let v = array.view();
|
||||||
|
let output_image: image::RgbImage = match v.to_image().change_context(errors::Error) {
|
||||||
|
Ok(img) => img,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to convert ndarray to image: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(e) = output_image.save(&output_path) {
|
||||||
|
tracing::error!("Failed to save output image to {:?}: {}", output_path, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("Saved output image to {:?}", output_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
processed_images += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print final statistics
|
||||||
|
tracing::info!(
|
||||||
|
"Processing complete: {}/{} images processed successfully, {} total faces detected",
|
||||||
|
processed_images,
|
||||||
|
image_paths.len(),
|
||||||
|
total_faces
|
||||||
|
);
|
||||||
|
|
||||||
|
let (num_images, num_faces, num_landmarks, num_embeddings) =
|
||||||
|
db.get_stats().change_context(Error)?;
|
||||||
|
tracing::info!(
|
||||||
|
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
|
||||||
|
num_images,
|
||||||
|
num_faces,
|
||||||
|
num_landmarks,
|
||||||
|
num_embeddings
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn run_similar(similar: cli::Similar) -> Result<()> {
|
fn run_similar(similar: cli::Similar) -> Result<()> {
|
||||||
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
|
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
|
||||||
|
|
||||||
@@ -341,14 +890,19 @@ fn run_similar(similar: cli::Similar) -> Result<()> {
|
|||||||
let similar_faces = db
|
let similar_faces = db
|
||||||
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
|
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
|
||||||
.change_context(Error)?;
|
.change_context(Error)?;
|
||||||
|
// Get image information for the similar faces
|
||||||
println!(
|
println!(
|
||||||
"Found {} similar faces (threshold: {:.3}):",
|
"Found {} similar faces (threshold: {:.3}):",
|
||||||
similar_faces.len(),
|
similar_faces.len(),
|
||||||
similar.threshold
|
similar.threshold
|
||||||
);
|
);
|
||||||
for (face_id, similarity) in similar_faces {
|
for (face_id, similarity) in &similar_faces {
|
||||||
println!(" Face {}: similarity {:.3}", face_id, similarity);
|
if let Some(image_info) = db.get_image_for_face(*face_id).change_context(Error)? {
|
||||||
|
println!(
|
||||||
|
" Face {}: similarity {:.3}, image: {}",
|
||||||
|
face_id, similarity, image_info.file_path
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
Reference in New Issue
Block a user