Compare commits
3 Commits
5a1f4b9ef6
...
7fc958b299
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fc958b299 | ||
|
|
3aa95a2ef5 | ||
|
|
e7c9c38ed7 |
127
Cargo.lock
generated
127
Cargo.lock
generated
@@ -112,9 +112,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.98"
|
||||
version = "1.0.99"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
|
||||
checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
@@ -127,9 +127,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "arbitrary"
|
||||
version = "1.4.1"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
|
||||
checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1"
|
||||
|
||||
[[package]]
|
||||
name = "arg_enum_proc_macro"
|
||||
@@ -170,9 +170,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "avif-serialize"
|
||||
version = "0.8.5"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ea8ef51aced2b9191c08197f55450d830876d9933f8f48a429b354f1d496b42"
|
||||
checksum = "47c8fbc0f831f4519fe8b810b6a7a91410ec83031b8233f730a0480029f6a23f"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
]
|
||||
@@ -211,7 +211,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
|
||||
dependencies = [
|
||||
"annotate-snippets",
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.13.0",
|
||||
@@ -239,9 +239,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.9.1"
|
||||
version = "2.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967"
|
||||
checksum = "6a65b545ab31d687cff52899d4890855fec459eb6afe0da6417b8a18da87aa29"
|
||||
|
||||
[[package]]
|
||||
name = "bitstream-io"
|
||||
@@ -269,7 +269,7 @@ dependencies = [
|
||||
"num",
|
||||
"ordered-float",
|
||||
"simba",
|
||||
"thiserror 2.0.12",
|
||||
"thiserror 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -310,9 +310,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.32"
|
||||
version = "1.2.33"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e"
|
||||
checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
@@ -369,9 +369,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.43"
|
||||
version = "4.5.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50fd97c9dc2399518aa331917ac6f274280ec5eb34e555dd291899745c48ec6f"
|
||||
checksum = "1fc0e74a703892159f5ae7d3aac52c8e6c392f5ae5f359c70b5881d60aaac318"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
@@ -379,9 +379,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.43"
|
||||
version = "4.5.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c35b5830294e1fa0462034af85cc95225a4cb07092c088c55bda3147cfcd8f65"
|
||||
checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
@@ -391,18 +391,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_complete"
|
||||
version = "4.5.56"
|
||||
version = "4.5.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67e4efcbb5da11a92e8a609233aa1e8a7d91e38de0be865f016d14700d45a7fd"
|
||||
checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad"
|
||||
dependencies = [
|
||||
"clap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.41"
|
||||
version = "4.5.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491"
|
||||
checksum = "14cb31bb0a7d536caef2639baa7fad459e15c3144efefa6dbd1c84562c4739f6"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
@@ -580,7 +580,7 @@ dependencies = [
|
||||
"ort",
|
||||
"rusqlite",
|
||||
"tap",
|
||||
"thiserror 2.0.12",
|
||||
"thiserror 2.0.15",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
@@ -960,9 +960,9 @@ checksum = "f2d1aab06663bdce00d6ca5e5ed586ec8d18033a771906c993a1e3755b368d85"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.2"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
@@ -1212,7 +1212,7 @@ version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
@@ -1320,9 +1320,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.174"
|
||||
version = "0.2.175"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
|
||||
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
|
||||
|
||||
[[package]]
|
||||
name = "libfuzzer-sys"
|
||||
@@ -1350,7 +1350,7 @@ version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
]
|
||||
@@ -1479,14 +1479,14 @@ dependencies = [
|
||||
"libc",
|
||||
"mnn-sys",
|
||||
"oneshot",
|
||||
"thiserror 2.0.12",
|
||||
"thiserror 2.0.15",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mnn-bridge"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#390c1bad0a4f719520bdb105e389b18982ee12a5"
|
||||
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#f972f1eb903d6ae38f1f1c409a01559282f2f036"
|
||||
dependencies = [
|
||||
"error-stack",
|
||||
"mnn",
|
||||
@@ -1496,7 +1496,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "mnn-sync"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#390c1bad0a4f719520bdb105e389b18982ee12a5"
|
||||
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#f972f1eb903d6ae38f1f1c409a01559282f2f036"
|
||||
dependencies = [
|
||||
"error-stack",
|
||||
"flume",
|
||||
@@ -1612,7 +1612,7 @@ dependencies = [
|
||||
"fast_image_resize",
|
||||
"ndarray",
|
||||
"num",
|
||||
"thiserror 2.0.12",
|
||||
"thiserror 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1779,7 +1779,7 @@ version = "0.10.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
"cfg-if",
|
||||
"foreign-types",
|
||||
"libc",
|
||||
@@ -1954,9 +1954,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.95"
|
||||
version = "1.0.101"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
|
||||
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
@@ -2098,9 +2098,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.10.0"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
|
||||
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
@@ -2108,9 +2108,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.12.1"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
|
||||
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
@@ -2122,7 +2122,7 @@ version = "0.5.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2181,7 +2181,7 @@ version = "0.37.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "165ca6e57b20e1351573e3729b958bc62f0e48025386970b6e4d29e7a7e71f3f"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
"chrono",
|
||||
"csv",
|
||||
"fallible-iterator",
|
||||
@@ -2223,7 +2223,7 @@ version = "1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
@@ -2250,9 +2250,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.21"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
|
||||
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
@@ -2290,7 +2290,7 @@ version = "2.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
@@ -2460,9 +2460,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.104"
|
||||
version = "2.0.106"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40"
|
||||
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2540,11 +2540,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.12"
|
||||
version = "2.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
|
||||
checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850"
|
||||
dependencies = [
|
||||
"thiserror-impl 2.0.12",
|
||||
"thiserror-impl 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2560,9 +2560,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.12"
|
||||
version = "2.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
|
||||
checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2759,9 +2759,9 @@ checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "3.0.12"
|
||||
version = "3.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f0fde9bc91026e381155f8c67cb354bcd35260b2f4a29bcc84639f762760c39"
|
||||
checksum = "00432f493971db5d8e47a65aeb3b02f8226b9b11f1450ff86bb772776ebadd70"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"der",
|
||||
@@ -2773,14 +2773,14 @@ dependencies = [
|
||||
"socks",
|
||||
"ureq-proto",
|
||||
"utf-8",
|
||||
"webpki-root-certs 0.26.11",
|
||||
"webpki-root-certs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ureq-proto"
|
||||
version = "0.4.2"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59db78ad1923f2b1be62b6da81fe80b173605ca0d57f85da2e005382adf693f7"
|
||||
checksum = "c5b6cabebbecc4c45189ab06b52f956206cea7d8c8a20851c35a85cb169224cc"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"http",
|
||||
@@ -2819,9 +2819,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.17.0"
|
||||
version = "1.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d"
|
||||
checksum = "f33196643e165781c20a5ead5582283a7dacbb87855d867fbc2df3f81eddc1be"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
@@ -2935,15 +2935,6 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-root-certs"
|
||||
version = "0.26.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75c7f0ef91146ebfb530314f5f1d24528d7f0767efbfd31dce919275413e393e"
|
||||
dependencies = [
|
||||
"webpki-root-certs 1.0.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-root-certs"
|
||||
version = "1.0.2"
|
||||
@@ -3221,7 +3212,7 @@ version = "0.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bitflags 2.9.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
12
Cargo.toml
12
Cargo.toml
@@ -53,7 +53,17 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
|
||||
color = "0.3.1"
|
||||
itertools = "0.14.0"
|
||||
ordered-float = "5.0.0"
|
||||
ort = "2.0.0-rc.10"
|
||||
ort = { version = "2.0.0-rc.10" }
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
||||
[features]
|
||||
ort-cuda = ["ort/cuda"]
|
||||
ort-coreml = ["ort/coreml"]
|
||||
ort-tensorrt = ["ort/tensorrt"]
|
||||
ort-tvm = ["ort/tvm"]
|
||||
ort-openvino = ["ort/openvino"]
|
||||
ort-directml = ["ort/directml"]
|
||||
|
||||
default = ["ort-coreml"]
|
||||
|
||||
24
flake.lock
generated
24
flake.lock
generated
@@ -3,11 +3,11 @@
|
||||
"advisory-db": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1750151065,
|
||||
"narHash": "sha256-il+CAqChFIB82xP6bO43dWlUVs+NlG7a4g8liIP5HcI=",
|
||||
"lastModified": 1755283329,
|
||||
"narHash": "sha256-33bd+PHbon+cgEiWE/zkr7dpEF5E0DiHOzyoUQbkYBc=",
|
||||
"owner": "rustsec",
|
||||
"repo": "advisory-db",
|
||||
"rev": "7573f55ba337263f61167dbb0ea926cdc7c8eb5d",
|
||||
"rev": "61aac2116c8cb7cc80ff8ca283eec7687d384038",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -18,11 +18,11 @@
|
||||
},
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1750266157,
|
||||
"narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=",
|
||||
"lastModified": 1754269165,
|
||||
"narHash": "sha256-0tcS8FHd4QjbCVoxN9jI+PjHgA4vc/IjkUSp+N3zy0U=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "e37c943371b73ed87faf33f7583860f81f1d5a48",
|
||||
"rev": "444e81206df3f7d92780680e45858e31d2f07a08",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -145,11 +145,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1750506804,
|
||||
"narHash": "sha256-VLFNc4egNjovYVxDGyBYTrvVCgDYgENp5bVi9fPTDYc=",
|
||||
"lastModified": 1755186698,
|
||||
"narHash": "sha256-wNO3+Ks2jZJ4nTHMuks+cxAiVBGNuEBXsT29Bz6HASo=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "4206c4cb56751df534751b058295ea61357bbbaa",
|
||||
"rev": "fbcf476f790d8a217c3eab4e12033dc4a0f6d23c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -178,11 +178,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1754621349,
|
||||
"narHash": "sha256-JkXUS/nBHyUqVTuL4EDCvUWauTHV78EYfk+WqiTAMQ4=",
|
||||
"lastModified": 1755485198,
|
||||
"narHash": "sha256-C3042ST2lUg0nh734gmuP4lRRIBitA6Maegg2/jYRM4=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "c448ab42002ac39d3337da10420c414fccfb1088",
|
||||
"rev": "aa45e63d431b28802ca4490cfc796b9e31731df7",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
214
flake.nix
214
flake.nix
@@ -2,6 +2,9 @@
|
||||
description = "A simple rust flake using rust-overlay and craneLib";
|
||||
|
||||
inputs = {
|
||||
self = {
|
||||
lfs = true;
|
||||
};
|
||||
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
crane.url = "github:ipetkov/crane";
|
||||
@@ -27,22 +30,20 @@
|
||||
};
|
||||
};
|
||||
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
crane,
|
||||
flake-utils,
|
||||
nixpkgs,
|
||||
rust-overlay,
|
||||
advisory-db,
|
||||
nix-github-actions,
|
||||
mnn-overlay,
|
||||
mnn-src,
|
||||
...
|
||||
}:
|
||||
outputs = {
|
||||
self,
|
||||
crane,
|
||||
flake-utils,
|
||||
nixpkgs,
|
||||
rust-overlay,
|
||||
advisory-db,
|
||||
nix-github-actions,
|
||||
mnn-overlay,
|
||||
mnn-src,
|
||||
...
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system:
|
||||
let
|
||||
system: let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
overlays = [
|
||||
@@ -77,115 +78,110 @@
|
||||
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
||||
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
||||
|
||||
src =
|
||||
let
|
||||
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
|
||||
sourceFilters =
|
||||
path: type:
|
||||
(craneLib.filterCargoSources path type)
|
||||
|| filterBySuffix path [
|
||||
".c"
|
||||
".h"
|
||||
".hpp"
|
||||
".cpp"
|
||||
".cc"
|
||||
];
|
||||
in
|
||||
src = let
|
||||
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
|
||||
sourceFilters = path: type:
|
||||
(craneLib.filterCargoSources path type)
|
||||
|| filterBySuffix path [
|
||||
".c"
|
||||
".h"
|
||||
".hpp"
|
||||
".cpp"
|
||||
".cc"
|
||||
];
|
||||
in
|
||||
lib.cleanSourceWith {
|
||||
filter = sourceFilters;
|
||||
src = ./.;
|
||||
};
|
||||
commonArgs = {
|
||||
inherit src;
|
||||
pname = name;
|
||||
stdenv = pkgs.clangStdenv;
|
||||
doCheck = false;
|
||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||
# nativeBuildInputs = with pkgs; [
|
||||
# cmake
|
||||
# llvmPackages.libclang.lib
|
||||
# ];
|
||||
buildInputs =
|
||||
with pkgs;
|
||||
[ ]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
libiconv
|
||||
apple-sdk_13
|
||||
]);
|
||||
}
|
||||
// (lib.optionalAttrs pkgs.stdenv.isLinux {
|
||||
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
|
||||
});
|
||||
cargoArtifacts = craneLib.buildPackage commonArgs;
|
||||
in
|
||||
{
|
||||
checks = {
|
||||
"${name}-clippy" = craneLib.cargoClippy (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||
}
|
||||
);
|
||||
"${name}-docs" = craneLib.cargoDoc (commonArgs // { inherit cargoArtifacts; });
|
||||
"${name}-fmt" = craneLib.cargoFmt { inherit src; };
|
||||
"${name}-toml-fmt" = craneLib.taploFmt {
|
||||
src = pkgs.lib.sources.sourceFilesBySuffices src [ ".toml" ];
|
||||
};
|
||||
# Audit dependencies
|
||||
"${name}-audit" = craneLib.cargoAudit {
|
||||
inherit src advisory-db;
|
||||
};
|
||||
|
||||
# Audit licenses
|
||||
"${name}-deny" = craneLib.cargoDeny {
|
||||
commonArgs =
|
||||
{
|
||||
inherit src;
|
||||
};
|
||||
"${name}-nextest" = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
partitions = 1;
|
||||
partitionType = "count";
|
||||
}
|
||||
);
|
||||
}
|
||||
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
|
||||
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // { inherit cargoArtifacts; });
|
||||
};
|
||||
|
||||
packages =
|
||||
let
|
||||
pkg = craneLib.buildPackage (
|
||||
pname = name;
|
||||
stdenv = p: p.clangStdenv;
|
||||
doCheck = false;
|
||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||
nativeBuildInputs = with pkgs; [
|
||||
cmake
|
||||
pkg-config
|
||||
];
|
||||
buildInputs = with pkgs;
|
||||
[onnxruntime]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
libiconv
|
||||
apple-sdk_13
|
||||
]);
|
||||
}
|
||||
// (lib.optionalAttrs pkgs.stdenv.isLinux {
|
||||
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
|
||||
});
|
||||
cargoArtifacts = craneLib.buildPackage commonArgs;
|
||||
in {
|
||||
checks =
|
||||
{
|
||||
"${name}-clippy" = craneLib.cargoClippy (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
}
|
||||
// {
|
||||
nativeBuildInputs = with pkgs; [
|
||||
installShellFiles
|
||||
];
|
||||
postInstall = ''
|
||||
installShellCompletion --cmd ${name} \
|
||||
--bash <($out/bin/${name} completions bash) \
|
||||
--fish <($out/bin/${name} completions fish) \
|
||||
--zsh <($out/bin/${name} completions zsh)
|
||||
'';
|
||||
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||
}
|
||||
);
|
||||
in
|
||||
{
|
||||
"${name}" = pkg;
|
||||
default = pkg;
|
||||
"${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
|
||||
"${name}-fmt" = craneLib.cargoFmt {inherit src;};
|
||||
"${name}-toml-fmt" = craneLib.taploFmt {
|
||||
src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"];
|
||||
};
|
||||
# Audit dependencies
|
||||
"${name}-audit" = craneLib.cargoAudit {
|
||||
inherit src advisory-db;
|
||||
};
|
||||
|
||||
# Audit licenses
|
||||
"${name}-deny" = craneLib.cargoDeny {
|
||||
inherit src;
|
||||
};
|
||||
"${name}-nextest" = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
partitions = 1;
|
||||
partitionType = "count";
|
||||
}
|
||||
);
|
||||
}
|
||||
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
|
||||
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
|
||||
};
|
||||
|
||||
packages = let
|
||||
pkg = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
}
|
||||
// {
|
||||
nativeBuildInputs = with pkgs; [
|
||||
installShellFiles
|
||||
];
|
||||
postInstall = ''
|
||||
installShellCompletion --cmd ${name} \
|
||||
--bash <($out/bin/${name} completions bash) \
|
||||
--fish <($out/bin/${name} completions fish) \
|
||||
--zsh <($out/bin/${name} completions zsh)
|
||||
'';
|
||||
}
|
||||
);
|
||||
in {
|
||||
"${name}" = pkg;
|
||||
default = pkg;
|
||||
};
|
||||
|
||||
devShells = {
|
||||
default = pkgs.mkShell.override { stdenv = pkgs.clangStdenv; } (
|
||||
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
|
||||
commonArgs
|
||||
// {
|
||||
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
||||
packages =
|
||||
with pkgs;
|
||||
packages = with pkgs;
|
||||
[
|
||||
stableToolchainWithRustAnalyzer
|
||||
cargo-nextest
|
||||
@@ -204,7 +200,7 @@
|
||||
)
|
||||
// {
|
||||
githubActions = nix-github-actions.lib.mkGithubMatrix {
|
||||
checks = nixpkgs.lib.getAttrs [ "x86_64-linux" ] self.checks;
|
||||
checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ fn shape_error() -> ndarray::ShapeError {
|
||||
|
||||
mod rgb8 {
|
||||
use super::Result;
|
||||
pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<u8>> {
|
||||
pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<'_, u8>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let data = image.as_raw();
|
||||
ndarray::ArrayView3::from_shape((height as usize, width as usize, 3), data)
|
||||
@@ -31,7 +31,9 @@ mod rgb8 {
|
||||
|
||||
mod rgba8 {
|
||||
use super::Result;
|
||||
pub(super) fn image_as_ndarray(image: &image::RgbaImage) -> Result<ndarray::ArrayView3<u8>> {
|
||||
pub(super) fn image_as_ndarray(
|
||||
image: &image::RgbaImage,
|
||||
) -> Result<ndarray::ArrayView3<'_, u8>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let data = image.as_raw();
|
||||
ndarray::ArrayView3::from_shape((height as usize, width as usize, 4), data)
|
||||
@@ -82,7 +84,7 @@ mod gray_alpha8 {
|
||||
use super::Result;
|
||||
pub(super) fn image_as_ndarray(
|
||||
image: &image::GrayAlphaImage,
|
||||
) -> Result<ndarray::ArrayView3<u8>> {
|
||||
) -> Result<ndarray::ArrayView3<'_, u8>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let data = image.as_raw();
|
||||
ndarray::ArrayView3::from_shape((height as usize, width as usize, 2), data)
|
||||
|
||||
@@ -147,7 +147,7 @@ impl<S: ndarray::Data<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Di
|
||||
NdAsImage<T, D> for ndarray::ArrayBase<S, D>
|
||||
{
|
||||
/// Clones self and makes a new image
|
||||
fn as_image_ref(&self) -> Result<ImageRef> {
|
||||
fn as_image_ref(&self) -> Result<ImageRef<'_>> {
|
||||
let shape = self.shape();
|
||||
let rows = *shape
|
||||
.first()
|
||||
|
||||
40
src/cli.rs
40
src/cli.rs
@@ -1,4 +1,6 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use mnn::ForwardType;
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct Cli {
|
||||
#[clap(subcommand)]
|
||||
@@ -21,23 +23,10 @@ pub enum Models {
|
||||
Yolo,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Executor {
|
||||
Mnn,
|
||||
Onnx,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum OnnxEp {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum MnnEp {
|
||||
Cpu,
|
||||
Metal,
|
||||
OpenCL,
|
||||
CoreML,
|
||||
Mnn(mnn::ForwardType),
|
||||
Ort(Vec<detector::ort_ep::ExecutionProvider>),
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
@@ -48,10 +37,21 @@ pub struct Detect {
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
#[clap(short = 'e', long)]
|
||||
pub executor: Option<Executor>,
|
||||
#[clap(short, long, default_value = "cpu")]
|
||||
pub forward_type: mnn::ForwardType,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::*;
|
||||
use crate::ort_ep::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_resize::NdFir;
|
||||
use ort::{
|
||||
execution_providers::{
|
||||
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
||||
},
|
||||
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
value::Tensor,
|
||||
};
|
||||
@@ -33,18 +32,11 @@ impl FaceDetectionBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
||||
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
|
||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||
.into_iter()
|
||||
.filter_map(|provider| match provider.as_str() {
|
||||
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
||||
#[cfg(target_os = "macos")]
|
||||
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
||||
_ => {
|
||||
tracing::warn!("Unknown execution provider: {}", provider);
|
||||
None
|
||||
}
|
||||
})
|
||||
.as_ref()
|
||||
.iter()
|
||||
.filter_map(|provider| provider.to_dispatch())
|
||||
.collect();
|
||||
|
||||
if !execution_providers.is_empty() {
|
||||
|
||||
@@ -14,5 +14,5 @@ use ndarray::{Array2, ArrayView4};
|
||||
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||
pub trait FaceEmbedder {
|
||||
/// Generate embeddings for a batch of face images
|
||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||
}
|
||||
|
||||
@@ -142,14 +142,6 @@ impl EmbeddingGenerator {
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
}
|
||||
|
||||
impl FaceNetEmbedder for EmbeddingGenerator {
|
||||
@@ -160,7 +152,7 @@ impl FaceNetEmbedder for EmbeddingGenerator {
|
||||
|
||||
// Main trait implementation for backward compatibility
|
||||
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
self.run_models(faces)
|
||||
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
EmbeddingGenerator::run_models(self, faces)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use crate::errors::*;
|
||||
use crate::faceembed::facenet::FaceNetEmbedder;
|
||||
use crate::ort_ep::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray::{Array2, ArrayView4};
|
||||
use ort::{
|
||||
execution_providers::{
|
||||
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
||||
},
|
||||
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
value::Tensor,
|
||||
};
|
||||
@@ -33,18 +32,11 @@ impl EmbeddingGeneratorBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
||||
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
|
||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||
.into_iter()
|
||||
.filter_map(|provider| match provider.as_str() {
|
||||
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
||||
#[cfg(target_os = "macos")]
|
||||
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
||||
_ => {
|
||||
tracing::warn!("Unknown execution provider: {}", provider);
|
||||
None
|
||||
}
|
||||
})
|
||||
.as_ref()
|
||||
.iter()
|
||||
.filter_map(|provider| provider.to_dispatch())
|
||||
.collect();
|
||||
|
||||
if !execution_providers.is_empty() {
|
||||
@@ -112,7 +104,7 @@ impl EmbeddingGeneratorBuilder {
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create ORT session from model bytes")?;
|
||||
|
||||
tracing::info!("Successfully created ORT RetinaFace session");
|
||||
tracing::info!("Successfully created ORT FaceNet session");
|
||||
|
||||
Ok(EmbeddingGenerator { session })
|
||||
}
|
||||
@@ -137,14 +129,63 @@ impl EmbeddingGenerator {
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||
tracing::info!("Loading face embedding model from bytes");
|
||||
tracing::info!("Loading ORT face embedding model from bytes");
|
||||
Self::builder(model)?.build()
|
||||
}
|
||||
|
||||
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// TODO: Implement ORT inference
|
||||
tracing::error!("ORT FaceNet inference not yet implemented");
|
||||
Err(Error).attach_printable("ORT FaceNet implementation is incomplete")
|
||||
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// Convert input from u8 to f32 and normalize to [0, 1] range
|
||||
let input_tensor = faces
|
||||
.mapv(|x| x as f32 / 255.0)
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
|
||||
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||
|
||||
// Create ORT input tensor
|
||||
let input_value = Tensor::from_array(input_tensor)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create input tensor")?;
|
||||
|
||||
// Run inference
|
||||
tracing::debug!("Running ORT FaceNet inference");
|
||||
let outputs = self
|
||||
.session
|
||||
.run(ort::inputs![Self::INPUT_NAME => input_value])
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to run inference")?;
|
||||
|
||||
// Extract output tensor
|
||||
let output = outputs
|
||||
.get(Self::OUTPUT_NAME)
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing output from FaceNet model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract output tensor")?;
|
||||
|
||||
let (output_shape, output_data) = output;
|
||||
|
||||
tracing::trace!("Output shape: {:?}", output_shape);
|
||||
|
||||
// Convert to ndarray format
|
||||
let output_dims = output_shape.as_ref();
|
||||
|
||||
// FaceNet typically outputs embeddings as [batch_size, embedding_dim]
|
||||
let batch_size = output_dims[0] as usize;
|
||||
let embedding_dim = output_dims[1] as usize;
|
||||
|
||||
let output_array =
|
||||
ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create output ndarray")?;
|
||||
|
||||
tracing::trace!(
|
||||
"Generated embeddings with shape: {:?}",
|
||||
output_array.shape()
|
||||
);
|
||||
|
||||
Ok(output_array)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +197,9 @@ impl FaceNetEmbedder for EmbeddingGenerator {
|
||||
|
||||
// Main trait implementation for backward compatibility
|
||||
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// Need to create a mutable reference for the session
|
||||
// This is a workaround for the trait signature mismatch
|
||||
self.run_models(faces)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,6 @@ pub mod errors;
|
||||
pub mod facedet;
|
||||
pub mod faceembed;
|
||||
pub mod image;
|
||||
pub mod ort_ep;
|
||||
|
||||
use errors::*;
|
||||
|
||||
29
src/main.rs
29
src/main.rs
@@ -11,7 +11,7 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
|
||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
const CHUNK_SIZE: usize = 8;
|
||||
const CHUNK_SIZE: usize = 2;
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
@@ -23,37 +23,52 @@ pub fn main() -> Result<()> {
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
|
||||
|
||||
let executor = detect
|
||||
.mnn_forward_type
|
||||
.map(|f| cli::Executor::Mnn(f))
|
||||
.or_else(|| {
|
||||
if detect.ort_execution_provider.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
// .then_some(cli::Executor::Mnn)
|
||||
// .unwrap_or(cli::Executor::Ort);
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn => {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_detection(detect, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Onnx => {
|
||||
cli::Executor::Ort(ep) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(&ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
@@ -72,7 +87,7 @@ pub fn main() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()>
|
||||
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
|
||||
189
src/ort_ep.rs
Normal file
189
src/ort_ep.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
use ort::execution_providers::CUDAExecutionProvider;
|
||||
#[cfg(feature = "ort-coreml")]
|
||||
use ort::execution_providers::CoreMLExecutionProvider;
|
||||
#[cfg(feature = "ort-directml")]
|
||||
use ort::execution_providers::DirectMLExecutionProvider;
|
||||
#[cfg(feature = "ort-openvino")]
|
||||
use ort::execution_providers::OpenVINOExecutionProvider;
|
||||
#[cfg(feature = "ort-tvm")]
|
||||
use ort::execution_providers::TVMExecutionProvider;
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
use ort::execution_providers::TensorRTExecutionProvider;
|
||||
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
|
||||
|
||||
/// Supported execution providers for ONNX Runtime
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ExecutionProvider {
|
||||
/// CPU execution provider (always available)
|
||||
CPU,
|
||||
/// CoreML execution provider (macOS only)
|
||||
CoreML,
|
||||
/// CUDA execution provider (requires cuda feature)
|
||||
CUDA,
|
||||
/// TensorRT execution provider (requires tensorrt feature)
|
||||
TensorRT,
|
||||
/// TVM execution provider (requires tvm feature)
|
||||
TVM,
|
||||
/// OpenVINO execution provider (requires openvino feature)
|
||||
OpenVINO,
|
||||
/// DirectML execution provider (Windows only, requires directml feature)
|
||||
DirectML,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ExecutionProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ExecutionProvider::CPU => write!(f, "CPU"),
|
||||
ExecutionProvider::CoreML => write!(f, "CoreML"),
|
||||
ExecutionProvider::CUDA => write!(f, "CUDA"),
|
||||
ExecutionProvider::TensorRT => write!(f, "TensorRT"),
|
||||
ExecutionProvider::TVM => write!(f, "TVM"),
|
||||
ExecutionProvider::OpenVINO => write!(f, "OpenVINO"),
|
||||
ExecutionProvider::DirectML => write!(f, "DirectML"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for ExecutionProvider {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"cpu" => Ok(ExecutionProvider::CPU),
|
||||
"coreml" => Ok(ExecutionProvider::CoreML),
|
||||
"cuda" => Ok(ExecutionProvider::CUDA),
|
||||
"tensorrt" => Ok(ExecutionProvider::TensorRT),
|
||||
"tvm" => Ok(ExecutionProvider::TVM),
|
||||
"openvino" => Ok(ExecutionProvider::OpenVINO),
|
||||
"directml" => Ok(ExecutionProvider::DirectML),
|
||||
_ => Err(format!("Unknown execution provider: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider {
|
||||
/// Returns all available execution providers for the current platform and features
|
||||
pub fn available_providers() -> Vec<ExecutionProvider> {
|
||||
vec![
|
||||
ExecutionProvider::CPU,
|
||||
#[cfg(all(target_os = "macos", feature = "ort-coreml"))]
|
||||
ExecutionProvider::CoreML,
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
ExecutionProvider::CUDA,
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
ExecutionProvider::TensorRT,
|
||||
#[cfg(feature = "ort-tvm")]
|
||||
ExecutionProvider::TVM,
|
||||
#[cfg(feature = "ort-openvino")]
|
||||
ExecutionProvider::OpenVINO,
|
||||
#[cfg(all(target_os = "windows", feature = "ort-directml"))]
|
||||
ExecutionProvider::DirectML,
|
||||
]
|
||||
}
|
||||
|
||||
/// Check if this execution provider is available on the current platform
|
||||
pub fn is_available(&self) -> bool {
|
||||
match self {
|
||||
ExecutionProvider::CPU => true,
|
||||
ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
|
||||
ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
|
||||
ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
|
||||
ExecutionProvider::TVM => cfg!(feature = "ort-tvm"),
|
||||
ExecutionProvider::OpenVINO => cfg!(feature = "ort-openvino"),
|
||||
ExecutionProvider::DirectML => {
|
||||
cfg!(target_os = "windows") && cfg!(feature = "ort-directml")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider {
|
||||
pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
|
||||
match self {
|
||||
ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
|
||||
ExecutionProvider::CoreML => {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
#[cfg(feature = "ort-coreml")]
|
||||
{
|
||||
Some(CoreMLExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-coreml"))]
|
||||
{
|
||||
tracing::error!("coreml support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
tracing::error!("CoreML is only available on macOS");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::CUDA => {
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
{
|
||||
Some(CUDAExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-cuda"))]
|
||||
{
|
||||
tracing::error!("CUDA support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::TensorRT => {
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
{
|
||||
Some(TensorRTExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-tensorrt"))]
|
||||
{
|
||||
tracing::error!("TensorRT support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::TVM => {
|
||||
#[cfg(feature = "ort-tvm")]
|
||||
{
|
||||
Some(TVMExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-tvm"))]
|
||||
{
|
||||
tracing::error!("TVM support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::OpenVINO => {
|
||||
#[cfg(feature = "ort-openvino")]
|
||||
{
|
||||
Some(OpenVINOExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-openvino"))]
|
||||
{
|
||||
tracing::error!("OpenVINO support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::DirectML => {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
#[cfg(feature = "ort-directml")]
|
||||
{
|
||||
Some(DirectMLExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-directml"))]
|
||||
{
|
||||
tracing::error!("DirectML support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
tracing::error!("DirectML is only available on Windows");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user