Compare commits

...

3 Commits

Author SHA1 Message Date
uttarayan21
7fc958b299 feat: Added more ort execution_provider
Some checks failed
build / checks-matrix (push) Failing after 19m0s
build / checks-build (push) Has been skipped
build / codecov (push) Failing after 19m3s
docs / docs (push) Failing after 28m31s
2025-08-18 16:31:16 +05:30
uttarayan21
3aa95a2ef5 feat: Added cli features for mnn and ort 2025-08-18 15:07:17 +05:30
uttarayan21
e7c9c38ed7 feat: implement the facenet implementation for ort 2025-08-18 13:20:55 +05:30
14 changed files with 499 additions and 267 deletions

127
Cargo.lock generated
View File

@@ -112,9 +112,9 @@ dependencies = [
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.98" version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"
[[package]] [[package]]
name = "approx" name = "approx"
@@ -127,9 +127,9 @@ dependencies = [
[[package]] [[package]]
name = "arbitrary" name = "arbitrary"
version = "1.4.1" version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1"
[[package]] [[package]]
name = "arg_enum_proc_macro" name = "arg_enum_proc_macro"
@@ -170,9 +170,9 @@ dependencies = [
[[package]] [[package]]
name = "avif-serialize" name = "avif-serialize"
version = "0.8.5" version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ea8ef51aced2b9191c08197f55450d830876d9933f8f48a429b354f1d496b42" checksum = "47c8fbc0f831f4519fe8b810b6a7a91410ec83031b8233f730a0480029f6a23f"
dependencies = [ dependencies = [
"arrayvec", "arrayvec",
] ]
@@ -211,7 +211,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f" checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [ dependencies = [
"annotate-snippets", "annotate-snippets",
"bitflags 2.9.1", "bitflags 2.9.2",
"cexpr", "cexpr",
"clang-sys", "clang-sys",
"itertools 0.13.0", "itertools 0.13.0",
@@ -239,9 +239,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.9.1" version = "2.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" checksum = "6a65b545ab31d687cff52899d4890855fec459eb6afe0da6417b8a18da87aa29"
[[package]] [[package]]
name = "bitstream-io" name = "bitstream-io"
@@ -269,7 +269,7 @@ dependencies = [
"num", "num",
"ordered-float", "ordered-float",
"simba", "simba",
"thiserror 2.0.12", "thiserror 2.0.15",
] ]
[[package]] [[package]]
@@ -310,9 +310,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.32" version = "1.2.33"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e" checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
@@ -369,9 +369,9 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.43" version = "4.5.45"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50fd97c9dc2399518aa331917ac6f274280ec5eb34e555dd291899745c48ec6f" checksum = "1fc0e74a703892159f5ae7d3aac52c8e6c392f5ae5f359c70b5881d60aaac318"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@@ -379,9 +379,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.43" version = "4.5.44"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c35b5830294e1fa0462034af85cc95225a4cb07092c088c55bda3147cfcd8f65" checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@@ -391,18 +391,18 @@ dependencies = [
[[package]] [[package]]
name = "clap_complete" name = "clap_complete"
version = "4.5.56" version = "4.5.57"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67e4efcbb5da11a92e8a609233aa1e8a7d91e38de0be865f016d14700d45a7fd" checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad"
dependencies = [ dependencies = [
"clap", "clap",
] ]
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.41" version = "4.5.45"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" checksum = "14cb31bb0a7d536caef2639baa7fad459e15c3144efefa6dbd1c84562c4739f6"
dependencies = [ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
@@ -580,7 +580,7 @@ dependencies = [
"ort", "ort",
"rusqlite", "rusqlite",
"tap", "tap",
"thiserror 2.0.12", "thiserror 2.0.15",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@@ -960,9 +960,9 @@ checksum = "f2d1aab06663bdce00d6ca5e5ed586ec8d18033a771906c993a1e3755b368d85"
[[package]] [[package]]
name = "glob" name = "glob"
version = "0.3.2" version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
[[package]] [[package]]
name = "half" name = "half"
@@ -1212,7 +1212,7 @@ version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
"cfg-if", "cfg-if",
"libc", "libc",
] ]
@@ -1320,9 +1320,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.174" version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]] [[package]]
name = "libfuzzer-sys" name = "libfuzzer-sys"
@@ -1350,7 +1350,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3" checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
"libc", "libc",
"redox_syscall", "redox_syscall",
] ]
@@ -1479,14 +1479,14 @@ dependencies = [
"libc", "libc",
"mnn-sys", "mnn-sys",
"oneshot", "oneshot",
"thiserror 2.0.12", "thiserror 2.0.15",
"tracing", "tracing",
] ]
[[package]] [[package]]
name = "mnn-bridge" name = "mnn-bridge"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#390c1bad0a4f719520bdb105e389b18982ee12a5" source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#f972f1eb903d6ae38f1f1c409a01559282f2f036"
dependencies = [ dependencies = [
"error-stack", "error-stack",
"mnn", "mnn",
@@ -1496,7 +1496,7 @@ dependencies = [
[[package]] [[package]]
name = "mnn-sync" name = "mnn-sync"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#390c1bad0a4f719520bdb105e389b18982ee12a5" source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#f972f1eb903d6ae38f1f1c409a01559282f2f036"
dependencies = [ dependencies = [
"error-stack", "error-stack",
"flume", "flume",
@@ -1612,7 +1612,7 @@ dependencies = [
"fast_image_resize", "fast_image_resize",
"ndarray", "ndarray",
"num", "num",
"thiserror 2.0.12", "thiserror 2.0.15",
] ]
[[package]] [[package]]
@@ -1779,7 +1779,7 @@ version = "0.10.73"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
"cfg-if", "cfg-if",
"foreign-types", "foreign-types",
"libc", "libc",
@@ -1954,9 +1954,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.95" version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@@ -2098,9 +2098,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.10.0" version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
dependencies = [ dependencies = [
"either", "either",
"rayon-core", "rayon-core",
@@ -2108,9 +2108,9 @@ dependencies = [
[[package]] [[package]]
name = "rayon-core" name = "rayon-core"
version = "1.12.1" version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
dependencies = [ dependencies = [
"crossbeam-deque", "crossbeam-deque",
"crossbeam-utils", "crossbeam-utils",
@@ -2122,7 +2122,7 @@ version = "0.5.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77" checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
] ]
[[package]] [[package]]
@@ -2181,7 +2181,7 @@ version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "165ca6e57b20e1351573e3729b958bc62f0e48025386970b6e4d29e7a7e71f3f" checksum = "165ca6e57b20e1351573e3729b958bc62f0e48025386970b6e4d29e7a7e71f3f"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
"chrono", "chrono",
"csv", "csv",
"fallible-iterator", "fallible-iterator",
@@ -2223,7 +2223,7 @@ version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
@@ -2250,9 +2250,9 @@ dependencies = [
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.21" version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]] [[package]]
name = "ryu" name = "ryu"
@@ -2290,7 +2290,7 @@ version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
"core-foundation", "core-foundation",
"core-foundation-sys", "core-foundation-sys",
"libc", "libc",
@@ -2460,9 +2460,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.104" version = "2.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@@ -2540,11 +2540,11 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "2.0.12" version = "2.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850"
dependencies = [ dependencies = [
"thiserror-impl 2.0.12", "thiserror-impl 2.0.15",
] ]
[[package]] [[package]]
@@ -2560,9 +2560,9 @@ dependencies = [
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "2.0.12" version = "2.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@@ -2759,9 +2759,9 @@ checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
[[package]] [[package]]
name = "ureq" name = "ureq"
version = "3.0.12" version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f0fde9bc91026e381155f8c67cb354bcd35260b2f4a29bcc84639f762760c39" checksum = "00432f493971db5d8e47a65aeb3b02f8226b9b11f1450ff86bb772776ebadd70"
dependencies = [ dependencies = [
"base64", "base64",
"der", "der",
@@ -2773,14 +2773,14 @@ dependencies = [
"socks", "socks",
"ureq-proto", "ureq-proto",
"utf-8", "utf-8",
"webpki-root-certs 0.26.11", "webpki-root-certs",
] ]
[[package]] [[package]]
name = "ureq-proto" name = "ureq-proto"
version = "0.4.2" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59db78ad1923f2b1be62b6da81fe80b173605ca0d57f85da2e005382adf693f7" checksum = "c5b6cabebbecc4c45189ab06b52f956206cea7d8c8a20851c35a85cb169224cc"
dependencies = [ dependencies = [
"base64", "base64",
"http", "http",
@@ -2819,9 +2819,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.17.0" version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" checksum = "f33196643e165781c20a5ead5582283a7dacbb87855d867fbc2df3f81eddc1be"
dependencies = [ dependencies = [
"js-sys", "js-sys",
"wasm-bindgen", "wasm-bindgen",
@@ -2935,15 +2935,6 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "webpki-root-certs"
version = "0.26.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75c7f0ef91146ebfb530314f5f1d24528d7f0767efbfd31dce919275413e393e"
dependencies = [
"webpki-root-certs 1.0.2",
]
[[package]] [[package]]
name = "webpki-root-certs" name = "webpki-root-certs"
version = "1.0.2" version = "1.0.2"
@@ -3221,7 +3212,7 @@ version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.2",
] ]
[[package]] [[package]]

View File

@@ -53,7 +53,17 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
color = "0.3.1" color = "0.3.1"
itertools = "0.14.0" itertools = "0.14.0"
ordered-float = "5.0.0" ordered-float = "5.0.0"
ort = "2.0.0-rc.10" ort = { version = "2.0.0-rc.10" }
[profile.release] [profile.release]
debug = true debug = true
[features]
ort-cuda = ["ort/cuda"]
ort-coreml = ["ort/coreml"]
ort-tensorrt = ["ort/tensorrt"]
ort-tvm = ["ort/tvm"]
ort-openvino = ["ort/openvino"]
ort-directml = ["ort/directml"]
default = ["ort-coreml"]

24
flake.lock generated
View File

@@ -3,11 +3,11 @@
"advisory-db": { "advisory-db": {
"flake": false, "flake": false,
"locked": { "locked": {
"lastModified": 1750151065, "lastModified": 1755283329,
"narHash": "sha256-il+CAqChFIB82xP6bO43dWlUVs+NlG7a4g8liIP5HcI=", "narHash": "sha256-33bd+PHbon+cgEiWE/zkr7dpEF5E0DiHOzyoUQbkYBc=",
"owner": "rustsec", "owner": "rustsec",
"repo": "advisory-db", "repo": "advisory-db",
"rev": "7573f55ba337263f61167dbb0ea926cdc7c8eb5d", "rev": "61aac2116c8cb7cc80ff8ca283eec7687d384038",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -18,11 +18,11 @@
}, },
"crane": { "crane": {
"locked": { "locked": {
"lastModified": 1750266157, "lastModified": 1754269165,
"narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=", "narHash": "sha256-0tcS8FHd4QjbCVoxN9jI+PjHgA4vc/IjkUSp+N3zy0U=",
"owner": "ipetkov", "owner": "ipetkov",
"repo": "crane", "repo": "crane",
"rev": "e37c943371b73ed87faf33f7583860f81f1d5a48", "rev": "444e81206df3f7d92780680e45858e31d2f07a08",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -145,11 +145,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1750506804, "lastModified": 1755186698,
"narHash": "sha256-VLFNc4egNjovYVxDGyBYTrvVCgDYgENp5bVi9fPTDYc=", "narHash": "sha256-wNO3+Ks2jZJ4nTHMuks+cxAiVBGNuEBXsT29Bz6HASo=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "4206c4cb56751df534751b058295ea61357bbbaa", "rev": "fbcf476f790d8a217c3eab4e12033dc4a0f6d23c",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -178,11 +178,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1754621349, "lastModified": 1755485198,
"narHash": "sha256-JkXUS/nBHyUqVTuL4EDCvUWauTHV78EYfk+WqiTAMQ4=", "narHash": "sha256-C3042ST2lUg0nh734gmuP4lRRIBitA6Maegg2/jYRM4=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "c448ab42002ac39d3337da10420c414fccfb1088", "rev": "aa45e63d431b28802ca4490cfc796b9e31731df7",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@@ -2,6 +2,9 @@
description = "A simple rust flake using rust-overlay and craneLib"; description = "A simple rust flake using rust-overlay and craneLib";
inputs = { inputs = {
self = {
lfs = true;
};
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
crane.url = "github:ipetkov/crane"; crane.url = "github:ipetkov/crane";
@@ -27,8 +30,7 @@
}; };
}; };
outputs = outputs = {
{
self, self,
crane, crane,
flake-utils, flake-utils,
@@ -41,8 +43,7 @@
... ...
}: }:
flake-utils.lib.eachDefaultSystem ( flake-utils.lib.eachDefaultSystem (
system: system: let
let
pkgs = import nixpkgs { pkgs = import nixpkgs {
inherit system; inherit system;
overlays = [ overlays = [
@@ -77,11 +78,9 @@
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain; craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools; craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
src = src = let
let
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts; filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
sourceFilters = sourceFilters = path: type:
path: type:
(craneLib.filterCargoSources path type) (craneLib.filterCargoSources path type)
|| filterBySuffix path [ || filterBySuffix path [
".c" ".c"
@@ -95,19 +94,19 @@
filter = sourceFilters; filter = sourceFilters;
src = ./.; src = ./.;
}; };
commonArgs = { commonArgs =
{
inherit src; inherit src;
pname = name; pname = name;
stdenv = pkgs.clangStdenv; stdenv = p: p.clangStdenv;
doCheck = false; doCheck = false;
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
# nativeBuildInputs = with pkgs; [ nativeBuildInputs = with pkgs; [
# cmake cmake
# llvmPackages.libclang.lib pkg-config
# ]; ];
buildInputs = buildInputs = with pkgs;
with pkgs; [onnxruntime]
[ ]
++ (lib.optionals pkgs.stdenv.isDarwin [ ++ (lib.optionals pkgs.stdenv.isDarwin [
libiconv libiconv
apple-sdk_13 apple-sdk_13
@@ -117,9 +116,9 @@
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include"; # BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
}); });
cargoArtifacts = craneLib.buildPackage commonArgs; cargoArtifacts = craneLib.buildPackage commonArgs;
in in {
checks =
{ {
checks = {
"${name}-clippy" = craneLib.cargoClippy ( "${name}-clippy" = craneLib.cargoClippy (
commonArgs commonArgs
// { // {
@@ -127,10 +126,10 @@
cargoClippyExtraArgs = "--all-targets -- --deny warnings"; cargoClippyExtraArgs = "--all-targets -- --deny warnings";
} }
); );
"${name}-docs" = craneLib.cargoDoc (commonArgs // { inherit cargoArtifacts; }); "${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
"${name}-fmt" = craneLib.cargoFmt { inherit src; }; "${name}-fmt" = craneLib.cargoFmt {inherit src;};
"${name}-toml-fmt" = craneLib.taploFmt { "${name}-toml-fmt" = craneLib.taploFmt {
src = pkgs.lib.sources.sourceFilesBySuffices src [ ".toml" ]; src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"];
}; };
# Audit dependencies # Audit dependencies
"${name}-audit" = craneLib.cargoAudit { "${name}-audit" = craneLib.cargoAudit {
@@ -151,11 +150,10 @@
); );
} }
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) { // lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // { inherit cargoArtifacts; }); "${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
}; };
packages = packages = let
let
pkg = craneLib.buildPackage ( pkg = craneLib.buildPackage (
commonArgs commonArgs
// { // {
@@ -173,19 +171,17 @@
''; '';
} }
); );
in in {
{
"${name}" = pkg; "${name}" = pkg;
default = pkg; default = pkg;
}; };
devShells = { devShells = {
default = pkgs.mkShell.override { stdenv = pkgs.clangStdenv; } ( default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
commonArgs commonArgs
// { // {
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver"; LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
packages = packages = with pkgs;
with pkgs;
[ [
stableToolchainWithRustAnalyzer stableToolchainWithRustAnalyzer
cargo-nextest cargo-nextest
@@ -204,7 +200,7 @@
) )
// { // {
githubActions = nix-github-actions.lib.mkGithubMatrix { githubActions = nix-github-actions.lib.mkGithubMatrix {
checks = nixpkgs.lib.getAttrs [ "x86_64-linux" ] self.checks; checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks;
}; };
}; };
} }

View File

@@ -5,7 +5,7 @@ fn shape_error() -> ndarray::ShapeError {
mod rgb8 { mod rgb8 {
use super::Result; use super::Result;
pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<u8>> { pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<'_, u8>> {
let (width, height) = image.dimensions(); let (width, height) = image.dimensions();
let data = image.as_raw(); let data = image.as_raw();
ndarray::ArrayView3::from_shape((height as usize, width as usize, 3), data) ndarray::ArrayView3::from_shape((height as usize, width as usize, 3), data)
@@ -31,7 +31,9 @@ mod rgb8 {
mod rgba8 { mod rgba8 {
use super::Result; use super::Result;
pub(super) fn image_as_ndarray(image: &image::RgbaImage) -> Result<ndarray::ArrayView3<u8>> { pub(super) fn image_as_ndarray(
image: &image::RgbaImage,
) -> Result<ndarray::ArrayView3<'_, u8>> {
let (width, height) = image.dimensions(); let (width, height) = image.dimensions();
let data = image.as_raw(); let data = image.as_raw();
ndarray::ArrayView3::from_shape((height as usize, width as usize, 4), data) ndarray::ArrayView3::from_shape((height as usize, width as usize, 4), data)
@@ -82,7 +84,7 @@ mod gray_alpha8 {
use super::Result; use super::Result;
pub(super) fn image_as_ndarray( pub(super) fn image_as_ndarray(
image: &image::GrayAlphaImage, image: &image::GrayAlphaImage,
) -> Result<ndarray::ArrayView3<u8>> { ) -> Result<ndarray::ArrayView3<'_, u8>> {
let (width, height) = image.dimensions(); let (width, height) = image.dimensions();
let data = image.as_raw(); let data = image.as_raw();
ndarray::ArrayView3::from_shape((height as usize, width as usize, 2), data) ndarray::ArrayView3::from_shape((height as usize, width as usize, 2), data)

View File

@@ -147,7 +147,7 @@ impl<S: ndarray::Data<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Di
NdAsImage<T, D> for ndarray::ArrayBase<S, D> NdAsImage<T, D> for ndarray::ArrayBase<S, D>
{ {
/// Clones self and makes a new image /// Clones self and makes a new image
fn as_image_ref(&self) -> Result<ImageRef> { fn as_image_ref(&self) -> Result<ImageRef<'_>> {
let shape = self.shape(); let shape = self.shape();
let rows = *shape let rows = *shape
.first() .first()

View File

@@ -1,4 +1,6 @@
use std::path::PathBuf; use std::path::PathBuf;
use mnn::ForwardType;
#[derive(Debug, clap::Parser)] #[derive(Debug, clap::Parser)]
pub struct Cli { pub struct Cli {
#[clap(subcommand)] #[clap(subcommand)]
@@ -21,23 +23,10 @@ pub enum Models {
Yolo, Yolo,
} }
#[derive(Debug, clap::ValueEnum, Clone, Copy)] #[derive(Debug, Clone)]
pub enum Executor { pub enum Executor {
Mnn, Mnn(mnn::ForwardType),
Onnx, Ort(Vec<detector::ort_ep::ExecutionProvider>),
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum OnnxEp {
Cpu,
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum MnnEp {
Cpu,
Metal,
OpenCL,
CoreML,
} }
#[derive(Debug, clap::Args)] #[derive(Debug, clap::Args)]
@@ -48,10 +37,21 @@ pub struct Detect {
pub model_type: Models, pub model_type: Models,
#[clap(short, long)] #[clap(short, long)]
pub output: Option<PathBuf>, pub output: Option<PathBuf>,
#[clap(short = 'e', long)] #[clap(
pub executor: Option<Executor>, short = 'p',
#[clap(short, long, default_value = "cpu")] long,
pub forward_type: mnn::ForwardType, default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)] #[clap(short, long, default_value_t = 0.8)]
pub threshold: f32, pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)] #[clap(short, long, default_value_t = 0.3)]

View File

@@ -1,11 +1,10 @@
use crate::errors::*; use crate::errors::*;
use crate::facedet::*; use crate::facedet::*;
use crate::ort_ep::*;
use error_stack::ResultExt; use error_stack::ResultExt;
use ndarray_resize::NdFir; use ndarray_resize::NdFir;
use ort::{ use ort::{
execution_providers::{ execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
},
session::{Session, builder::GraphOptimizationLevel}, session::{Session, builder::GraphOptimizationLevel},
value::Tensor, value::Tensor,
}; };
@@ -33,18 +32,11 @@ impl FaceDetectionBuilder {
}) })
} }
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self { pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
let execution_providers: Vec<ExecutionProviderDispatch> = providers let execution_providers: Vec<ExecutionProviderDispatch> = providers
.into_iter() .as_ref()
.filter_map(|provider| match provider.as_str() { .iter()
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()), .filter_map(|provider| provider.to_dispatch())
#[cfg(target_os = "macos")]
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
_ => {
tracing::warn!("Unknown execution provider: {}", provider);
None
}
})
.collect(); .collect();
if !execution_providers.is_empty() { if !execution_providers.is_empty() {

View File

@@ -14,5 +14,5 @@ use ndarray::{Array2, ArrayView4};
/// Common trait for face embedding backends - maintained for backward compatibility /// Common trait for face embedding backends - maintained for backward compatibility
pub trait FaceEmbedder { pub trait FaceEmbedder {
/// Generate embeddings for a batch of face images /// Generate embeddings for a batch of face images
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>; fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
} }

View File

@@ -142,14 +142,6 @@ impl EmbeddingGenerator {
.change_context(Error)?; .change_context(Error)?;
Ok(output) Ok(output)
} }
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
// todo!()
// }
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
// todo!()
// }
} }
impl FaceNetEmbedder for EmbeddingGenerator { impl FaceNetEmbedder for EmbeddingGenerator {
@@ -160,7 +152,7 @@ impl FaceNetEmbedder for EmbeddingGenerator {
// Main trait implementation for backward compatibility // Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator { impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> { fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
self.run_models(faces) EmbeddingGenerator::run_models(self, faces)
} }
} }

View File

@@ -1,11 +1,10 @@
use crate::errors::*; use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder; use crate::faceembed::facenet::FaceNetEmbedder;
use crate::ort_ep::*;
use error_stack::ResultExt; use error_stack::ResultExt;
use ndarray::{Array2, ArrayView4}; use ndarray::{Array2, ArrayView4};
use ort::{ use ort::{
execution_providers::{ execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
},
session::{Session, builder::GraphOptimizationLevel}, session::{Session, builder::GraphOptimizationLevel},
value::Tensor, value::Tensor,
}; };
@@ -33,18 +32,11 @@ impl EmbeddingGeneratorBuilder {
}) })
} }
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self { pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
let execution_providers: Vec<ExecutionProviderDispatch> = providers let execution_providers: Vec<ExecutionProviderDispatch> = providers
.into_iter() .as_ref()
.filter_map(|provider| match provider.as_str() { .iter()
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()), .filter_map(|provider| provider.to_dispatch())
#[cfg(target_os = "macos")]
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
_ => {
tracing::warn!("Unknown execution provider: {}", provider);
None
}
})
.collect(); .collect();
if !execution_providers.is_empty() { if !execution_providers.is_empty() {
@@ -112,7 +104,7 @@ impl EmbeddingGeneratorBuilder {
.change_context(Error) .change_context(Error)
.attach_printable("Failed to create ORT session from model bytes")?; .attach_printable("Failed to create ORT session from model bytes")?;
tracing::info!("Successfully created ORT RetinaFace session"); tracing::info!("Successfully created ORT FaceNet session");
Ok(EmbeddingGenerator { session }) Ok(EmbeddingGenerator { session })
} }
@@ -137,14 +129,63 @@ impl EmbeddingGenerator {
} }
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> { pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
tracing::info!("Loading face embedding model from bytes"); tracing::info!("Loading ORT face embedding model from bytes");
Self::builder(model)?.build() Self::builder(model)?.build()
} }
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> { pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// TODO: Implement ORT inference // Convert input from u8 to f32 and normalize to [0, 1] range
tracing::error!("ORT FaceNet inference not yet implemented"); let input_tensor = faces
Err(Error).attach_printable("ORT FaceNet implementation is incomplete") .mapv(|x| x as f32 / 255.0)
.as_standard_layout()
.into_owned();
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
// Create ORT input tensor
let input_value = Tensor::from_array(input_tensor)
.change_context(Error)
.attach_printable("Failed to create input tensor")?;
// Run inference
tracing::debug!("Running ORT FaceNet inference");
let outputs = self
.session
.run(ort::inputs![Self::INPUT_NAME => input_value])
.change_context(Error)
.attach_printable("Failed to run inference")?;
// Extract output tensor
let output = outputs
.get(Self::OUTPUT_NAME)
.ok_or(Error)
.attach_printable("Missing output from FaceNet model")?
.try_extract_tensor::<f32>()
.change_context(Error)
.attach_printable("Failed to extract output tensor")?;
let (output_shape, output_data) = output;
tracing::trace!("Output shape: {:?}", output_shape);
// Convert to ndarray format
let output_dims = output_shape.as_ref();
// FaceNet typically outputs embeddings as [batch_size, embedding_dim]
let batch_size = output_dims[0] as usize;
let embedding_dim = output_dims[1] as usize;
let output_array =
ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec())
.change_context(Error)
.attach_printable("Failed to create output ndarray")?;
tracing::trace!(
"Generated embeddings with shape: {:?}",
output_array.shape()
);
Ok(output_array)
} }
} }
@@ -156,7 +197,9 @@ impl FaceNetEmbedder for EmbeddingGenerator {
// Main trait implementation for backward compatibility // Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator { impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> { fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// Need to create a mutable reference for the session
// This is a workaround for the trait signature mismatch
self.run_models(faces) self.run_models(faces)
} }
} }

View File

@@ -2,4 +2,6 @@ pub mod errors;
pub mod facedet; pub mod facedet;
pub mod faceembed; pub mod faceembed;
pub mod image; pub mod image;
pub mod ort_ep;
use errors::*; use errors::*;

View File

@@ -11,7 +11,7 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn"); const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx"); const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx"); const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
const CHUNK_SIZE: usize = 8; const CHUNK_SIZE: usize = 2;
pub fn main() -> Result<()> { pub fn main() -> Result<()> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter("trace") .with_env_filter("trace")
@@ -23,37 +23,52 @@ pub fn main() -> Result<()> {
match args.cmd { match args.cmd {
cli::SubCommand::Detect(detect) => { cli::SubCommand::Detect(detect) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility) // Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
let executor = detect
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if detect.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
// .then_some(cli::Executor::Mnn)
// .unwrap_or(cli::Executor::Ort);
match executor { match executor {
cli::Executor::Mnn => { cli::Executor::Mnn(forward) => {
let retinaface = let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)? .change_context(Error)?
.with_forward_type(detect.forward_type) .with_forward_type(forward)
.build() .build()
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to create face detection model")?; .attach_printable("Failed to create face detection model")?;
let facenet = let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)? .change_context(Error)?
.with_forward_type(detect.forward_type) .with_forward_type(forward)
.build() .build()
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?; .attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?; run_detection(detect, retinaface, facenet)?;
} }
cli::Executor::Onnx => { cli::Executor::Ort(ep) => {
let retinaface = let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX) facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)? .change_context(Error)?
.with_execution_providers(&ep)
.build() .build()
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to create face detection model")?; .attach_printable("Failed to create face detection model")?;
let facenet = let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX) faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)? .change_context(Error)?
.with_execution_providers(ep)
.build() .build()
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?; .attach_printable("Failed to create face embedding model")?;
@@ -72,7 +87,7 @@ pub fn main() -> Result<()> {
Ok(()) Ok(())
} }
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()> fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
where where
D: facedet::FaceDetector, D: facedet::FaceDetector,
E: faceembed::FaceEmbedder, E: faceembed::FaceEmbedder,

189
src/ort_ep.rs Normal file
View File

@@ -0,0 +1,189 @@
#[cfg(feature = "ort-cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(feature = "ort-coreml")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(feature = "ort-directml")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "ort-openvino")]
use ort::execution_providers::OpenVINOExecutionProvider;
#[cfg(feature = "ort-tvm")]
use ort::execution_providers::TVMExecutionProvider;
#[cfg(feature = "ort-tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
/// Supported execution providers for ONNX Runtime
#[derive(Debug, Clone)]
pub enum ExecutionProvider {
/// CPU execution provider (always available)
CPU,
/// CoreML execution provider (macOS only)
CoreML,
/// CUDA execution provider (requires cuda feature)
CUDA,
/// TensorRT execution provider (requires tensorrt feature)
TensorRT,
/// TVM execution provider (requires tvm feature)
TVM,
/// OpenVINO execution provider (requires openvino feature)
OpenVINO,
/// DirectML execution provider (Windows only, requires directml feature)
DirectML,
}
impl std::fmt::Display for ExecutionProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExecutionProvider::CPU => write!(f, "CPU"),
ExecutionProvider::CoreML => write!(f, "CoreML"),
ExecutionProvider::CUDA => write!(f, "CUDA"),
ExecutionProvider::TensorRT => write!(f, "TensorRT"),
ExecutionProvider::TVM => write!(f, "TVM"),
ExecutionProvider::OpenVINO => write!(f, "OpenVINO"),
ExecutionProvider::DirectML => write!(f, "DirectML"),
}
}
}
impl std::str::FromStr for ExecutionProvider {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"cpu" => Ok(ExecutionProvider::CPU),
"coreml" => Ok(ExecutionProvider::CoreML),
"cuda" => Ok(ExecutionProvider::CUDA),
"tensorrt" => Ok(ExecutionProvider::TensorRT),
"tvm" => Ok(ExecutionProvider::TVM),
"openvino" => Ok(ExecutionProvider::OpenVINO),
"directml" => Ok(ExecutionProvider::DirectML),
_ => Err(format!("Unknown execution provider: {}", s)),
}
}
}
impl ExecutionProvider {
/// Returns all available execution providers for the current platform and features
pub fn available_providers() -> Vec<ExecutionProvider> {
vec![
ExecutionProvider::CPU,
#[cfg(all(target_os = "macos", feature = "ort-coreml"))]
ExecutionProvider::CoreML,
#[cfg(feature = "ort-cuda")]
ExecutionProvider::CUDA,
#[cfg(feature = "ort-tensorrt")]
ExecutionProvider::TensorRT,
#[cfg(feature = "ort-tvm")]
ExecutionProvider::TVM,
#[cfg(feature = "ort-openvino")]
ExecutionProvider::OpenVINO,
#[cfg(all(target_os = "windows", feature = "ort-directml"))]
ExecutionProvider::DirectML,
]
}
/// Check if this execution provider is available on the current platform
pub fn is_available(&self) -> bool {
match self {
ExecutionProvider::CPU => true,
ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
ExecutionProvider::TVM => cfg!(feature = "ort-tvm"),
ExecutionProvider::OpenVINO => cfg!(feature = "ort-openvino"),
ExecutionProvider::DirectML => {
cfg!(target_os = "windows") && cfg!(feature = "ort-directml")
}
}
}
}
impl ExecutionProvider {
pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
match self {
ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
ExecutionProvider::CoreML => {
#[cfg(target_os = "macos")]
{
#[cfg(feature = "ort-coreml")]
{
Some(CoreMLExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-coreml"))]
{
tracing::error!("coreml support not compiled in");
None
}
}
#[cfg(not(target_os = "macos"))]
{
tracing::error!("CoreML is only available on macOS");
None
}
}
ExecutionProvider::CUDA => {
#[cfg(feature = "ort-cuda")]
{
Some(CUDAExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-cuda"))]
{
tracing::error!("CUDA support not compiled in");
None
}
}
ExecutionProvider::TensorRT => {
#[cfg(feature = "ort-tensorrt")]
{
Some(TensorRTExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-tensorrt"))]
{
tracing::error!("TensorRT support not compiled in");
None
}
}
ExecutionProvider::TVM => {
#[cfg(feature = "ort-tvm")]
{
Some(TVMExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-tvm"))]
{
tracing::error!("TVM support not compiled in");
None
}
}
ExecutionProvider::OpenVINO => {
#[cfg(feature = "ort-openvino")]
{
Some(OpenVINOExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-openvino"))]
{
tracing::error!("OpenVINO support not compiled in");
None
}
}
ExecutionProvider::DirectML => {
#[cfg(target_os = "windows")]
{
#[cfg(feature = "ort-directml")]
{
Some(DirectMLExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-directml"))]
{
tracing::error!("DirectML support not compiled in");
None
}
}
#[cfg(not(target_os = "windows"))]
{
tracing::error!("DirectML is only available on Windows");
None
}
}
}
}
}