feat(detector): add CUDA support for ONNX face detection
This commit is contained in:
42
Cargo.lock
generated
42
Cargo.lock
generated
@@ -1271,13 +1271,14 @@ dependencies = [
|
|||||||
"image 0.25.6",
|
"image 0.25.6",
|
||||||
"imageproc",
|
"imageproc",
|
||||||
"itertools 0.14.0",
|
"itertools 0.14.0",
|
||||||
|
"linfa",
|
||||||
"mnn",
|
"mnn",
|
||||||
"mnn-bridge",
|
"mnn-bridge",
|
||||||
"mnn-sync",
|
"mnn-sync",
|
||||||
"nalgebra 0.34.0",
|
"nalgebra 0.34.0",
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"ndarray-image",
|
"ndarray-image",
|
||||||
"ndarray-math 0.1.0 (git+https://git.darksailor.dev/servius/ndarray-math)",
|
"ndarray-math",
|
||||||
"ndarray-resize",
|
"ndarray-resize",
|
||||||
"ndarray-safetensors",
|
"ndarray-safetensors",
|
||||||
"ordered-float",
|
"ordered-float",
|
||||||
@@ -2916,6 +2917,19 @@ dependencies = [
|
|||||||
"vcpkg",
|
"vcpkg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "linfa"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"ndarray",
|
||||||
|
"num-traits",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"sprs",
|
||||||
|
"thiserror 2.0.15",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.4.15"
|
version = "0.4.15"
|
||||||
@@ -3223,6 +3237,7 @@ version = "0.16.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
|
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"approx",
|
||||||
"matrixmultiply",
|
"matrixmultiply",
|
||||||
"num-complex",
|
"num-complex",
|
||||||
"num-integer",
|
"num-integer",
|
||||||
@@ -3244,16 +3259,7 @@ dependencies = [
|
|||||||
[[package]]
|
[[package]]
|
||||||
name = "ndarray-math"
|
name = "ndarray-math"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
source = "git+https://git.darksailor.dev/servius/ndarray-math#df17c36193df60e070e4e120c9feebe68ff3f517"
|
||||||
"ndarray",
|
|
||||||
"num",
|
|
||||||
"thiserror 2.0.15",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "ndarray-math"
|
|
||||||
version = "0.1.0"
|
|
||||||
source = "git+https://git.darksailor.dev/servius/ndarray-math#f047966f20835267f20e5839272b9ab36c445796"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"num",
|
"num",
|
||||||
@@ -5120,6 +5126,18 @@ dependencies = [
|
|||||||
"bitflags 2.9.2",
|
"bitflags 2.9.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sprs"
|
||||||
|
version = "0.11.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8bff8419009a08f6cb7519a602c5590241fbff1446bcc823c07af15386eb801b"
|
||||||
|
dependencies = [
|
||||||
|
"ndarray",
|
||||||
|
"num-complex",
|
||||||
|
"num-traits",
|
||||||
|
"smallvec 1.15.1",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sqlite-loadable"
|
name = "sqlite-loadable"
|
||||||
version = "0.0.5"
|
version = "0.0.5"
|
||||||
@@ -5149,7 +5167,7 @@ name = "sqlite3-safetensor-cosine"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"ndarray-math 0.1.0",
|
"ndarray-math",
|
||||||
"ndarray-safetensors",
|
"ndarray-safetensors",
|
||||||
"sqlite-loadable",
|
"sqlite-loadable",
|
||||||
]
|
]
|
||||||
|
|||||||
45
Cargo.toml
45
Cargo.toml
@@ -1,43 +1,42 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"ndarray-image",
|
"ndarray-image",
|
||||||
"ndarray-resize",
|
"ndarray-resize",
|
||||||
".",
|
".",
|
||||||
"bounding-box",
|
"bounding-box",
|
||||||
"ndarray-safetensors",
|
"ndarray-safetensors",
|
||||||
"sqlite3-safetensor-cosine",
|
"sqlite3-safetensor-cosine",
|
||||||
"ndcv-bridge",
|
"ndcv-bridge",
|
||||||
"bbox",
|
"bbox",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
|
[patch.crates-io]
|
||||||
|
linfa = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
bbox = { path = "bbox" }
|
|
||||||
divan = { version = "0.1.21" }
|
divan = { version = "0.1.21" }
|
||||||
ndarray-npy = "0.9.1"
|
ndarray-npy = "0.9.1"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
ndarray-image = { path = "ndarray-image" }
|
ndarray-image = { path = "ndarray-image" }
|
||||||
ndarray-resize = { path = "ndarray-resize" }
|
ndarray-resize = { path = "ndarray-resize" }
|
||||||
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
|
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
|
||||||
# "metal",
|
# "metal",
|
||||||
# "coreml",
|
# "coreml",
|
||||||
"tracing",
|
"tracing",
|
||||||
], branch = "restructure-tensor-type" }
|
], branch = "restructure-tensor-type" }
|
||||||
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
||||||
"ndarray",
|
"ndarray",
|
||||||
], branch = "restructure-tensor-type" }
|
], branch = "restructure-tensor-type" }
|
||||||
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
||||||
"tracing",
|
"tracing",
|
||||||
], branch = "restructure-tensor-type" }
|
], branch = "restructure-tensor-type" }
|
||||||
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
|
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
|
||||||
opencv = { version = "0.95.1" }
|
opencv = { version = "0.95.1" }
|
||||||
bounding-box = { path = "bounding-box" }
|
bounding-box = { path = "bounding-box" }
|
||||||
ndarray-safetensors = { path = "ndarray-safetensors" }
|
|
||||||
wide = "0.7.33"
|
|
||||||
rayon = "1.11.0"
|
|
||||||
bytemuck = "1.23.2"
|
bytemuck = "1.23.2"
|
||||||
error-stack = "0.5.0"
|
error-stack = "0.5.0"
|
||||||
thiserror = "2.0"
|
thiserror = "2.0"
|
||||||
@@ -76,9 +75,10 @@ color = "0.3.1"
|
|||||||
itertools = "0.14.0"
|
itertools = "0.14.0"
|
||||||
ordered-float = "5.0.0"
|
ordered-float = "5.0.0"
|
||||||
ort = { version = "2.0.0-rc.10", default-features = false, features = [
|
ort = { version = "2.0.0-rc.10", default-features = false, features = [
|
||||||
"std",
|
"std",
|
||||||
"tracing",
|
"tracing",
|
||||||
"ndarray",
|
"ndarray",
|
||||||
|
"cuda",
|
||||||
] }
|
] }
|
||||||
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||||
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
|
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
|
||||||
@@ -89,12 +89,13 @@ iced = { version = "0.13", features = ["tokio", "image"] }
|
|||||||
rfd = "0.15"
|
rfd = "0.15"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
imageproc = "0.25"
|
imageproc = "0.25"
|
||||||
|
linfa = "0.7.1"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = true
|
debug = true
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
ort-cuda = ["ort/cuda"]
|
ort-cuda = []
|
||||||
ort-coreml = ["ort/coreml"]
|
ort-coreml = ["ort/coreml"]
|
||||||
ort-tensorrt = ["ort/tensorrt"]
|
ort-tensorrt = ["ort/tensorrt"]
|
||||||
ort-tvm = ["ort/tvm"]
|
ort-tvm = ["ort/tvm"]
|
||||||
@@ -103,7 +104,7 @@ ort-directml = ["ort/directml"]
|
|||||||
mnn-metal = ["mnn/metal"]
|
mnn-metal = ["mnn/metal"]
|
||||||
mnn-coreml = ["mnn/coreml"]
|
mnn-coreml = ["mnn/coreml"]
|
||||||
|
|
||||||
default = ["mnn-metal", "mnn-coreml"]
|
default = ["ort-cuda"]
|
||||||
|
|
||||||
[[test]]
|
[[test]]
|
||||||
name = "test_bbox_replacement"
|
name = "test_bbox_replacement"
|
||||||
|
|||||||
@@ -5,34 +5,34 @@ workspace = false
|
|||||||
[tasks.convert_facenet]
|
[tasks.convert_facenet]
|
||||||
command = "MNNConvert"
|
command = "MNNConvert"
|
||||||
args = [
|
args = [
|
||||||
"-f",
|
"-f",
|
||||||
"ONNX",
|
"ONNX",
|
||||||
"--modelFile",
|
"--modelFile",
|
||||||
"models/facenet.onnx",
|
"models/facenet.onnx",
|
||||||
"--MNNModel",
|
"--MNNModel",
|
||||||
"models/facenet.mnn",
|
"models/facenet.mnn",
|
||||||
"--fp16",
|
"--fp16",
|
||||||
"--bizCode",
|
"--bizCode",
|
||||||
"MNN",
|
"MNN",
|
||||||
]
|
]
|
||||||
workspace = false
|
workspace = false
|
||||||
|
|
||||||
[tasks.convert_retinaface]
|
[tasks.convert_retinaface]
|
||||||
command = "MNNConvert"
|
command = "MNNConvert"
|
||||||
args = [
|
args = [
|
||||||
"-f",
|
"-f",
|
||||||
"ONNX",
|
"ONNX",
|
||||||
"--modelFile",
|
"--modelFile",
|
||||||
"models/retinaface.onnx",
|
"models/retinaface.onnx",
|
||||||
"--MNNModel",
|
"--MNNModel",
|
||||||
"models/retinaface.mnn",
|
"models/retinaface.mnn",
|
||||||
"--fp16",
|
"--fp16",
|
||||||
"--bizCode",
|
"--bizCode",
|
||||||
"MNN",
|
"MNN",
|
||||||
]
|
]
|
||||||
workspace = false
|
workspace = false
|
||||||
|
|
||||||
[tasks.gui]
|
[tasks.gui]
|
||||||
command = "cargo"
|
command = "cargo"
|
||||||
args = ["run", "--bin", "gui"]
|
args = ["run", "--release", "--bin", "gui"]
|
||||||
workspace = false
|
workspace = false
|
||||||
|
|||||||
19
flake.nix
19
flake.nix
@@ -43,6 +43,8 @@
|
|||||||
system: let
|
system: let
|
||||||
pkgs = import nixpkgs {
|
pkgs = import nixpkgs {
|
||||||
inherit system;
|
inherit system;
|
||||||
|
config.allowUnfree = true;
|
||||||
|
config.cudaSupport = pkgs.stdenv.isLinux;
|
||||||
overlays = [
|
overlays = [
|
||||||
rust-overlay.overlays.default
|
rust-overlay.overlays.default
|
||||||
(final: prev: {
|
(final: prev: {
|
||||||
@@ -75,7 +77,7 @@
|
|||||||
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
||||||
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
||||||
|
|
||||||
ort_static = pkgs.onnxruntime.overrideAttrs (old: {
|
ort_static = (pkgs.onnxruntime.overide {cudaSupport = true;}).overrideAttrs (old: {
|
||||||
cmakeFlags =
|
cmakeFlags =
|
||||||
old.cmakeFlags
|
old.cmakeFlags
|
||||||
++ [
|
++ [
|
||||||
@@ -198,8 +200,9 @@
|
|||||||
devShells = {
|
devShells = {
|
||||||
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
|
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
|
||||||
commonArgs
|
commonArgs
|
||||||
// {
|
// rec {
|
||||||
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
||||||
|
LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${builtins.toString (pkgs.lib.makeLibraryPath packages)}";
|
||||||
packages = with pkgs;
|
packages = with pkgs;
|
||||||
[
|
[
|
||||||
stableToolchainWithRustAnalyzer
|
stableToolchainWithRustAnalyzer
|
||||||
@@ -215,6 +218,18 @@
|
|||||||
]
|
]
|
||||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||||
apple-sdk_13
|
apple-sdk_13
|
||||||
|
])
|
||||||
|
++ (lib.optionals pkgs.stdenv.isLinux [
|
||||||
|
xorg.libX11
|
||||||
|
xorg.libXcursor
|
||||||
|
xorg.libXrandr
|
||||||
|
xorg.libXi
|
||||||
|
xorg.libxcb
|
||||||
|
libxkbcommon
|
||||||
|
vulkan-loader
|
||||||
|
wayland
|
||||||
|
zenity
|
||||||
|
cudatoolkit
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ crate-type = ["cdylib", "staticlib"]
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ndarray = "0.16.1"
|
ndarray = "0.16.1"
|
||||||
# ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||||
ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
|
# ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
|
||||||
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
|
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
|
||||||
sqlite-loadable = "0.0.5"
|
sqlite-loadable = "0.0.5"
|
||||||
|
|||||||
@@ -22,10 +22,12 @@ const FACENET_MODEL_ONNX: &[u8] =
|
|||||||
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/facenet.onnx"));
|
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/facenet.onnx"));
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter("info")
|
.with_env_filter("info,ort=warn")
|
||||||
.with_thread_ids(true)
|
// .with_thread_ids(true)
|
||||||
.with_thread_names(true)
|
// .with_thread_names(true)
|
||||||
.with_target(false)
|
.with_file(true)
|
||||||
|
.with_line_number(true)
|
||||||
|
.with_target(true)
|
||||||
.init();
|
.init();
|
||||||
let args = <cli::Cli as clap::Parser>::parse();
|
let args = <cli::Cli as clap::Parser>::parse();
|
||||||
match args.cmd {
|
match args.cmd {
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ use detector::errors::*;
|
|||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
// Initialize logging
|
// Initialize logging
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter("info")
|
.with_env_filter("warn,ort=warn")
|
||||||
// .with_thread_ids(true)
|
.with_file(true)
|
||||||
// .with_file(true)
|
.with_line_number(true)
|
||||||
// .with_thread_names(true)
|
// .with_thread_names(true)
|
||||||
.with_target(true)
|
.with_target(true)
|
||||||
.init();
|
.init();
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use iced::{
|
use iced::{
|
||||||
Alignment, Element, Length, Task, Theme,
|
Alignment, Element, Length, Settings, Task, Theme,
|
||||||
widget::{
|
widget::{
|
||||||
Space, button, column, container, image, pick_list, progress_bar, row, scrollable, slider,
|
Space, button, column, container, image, pick_list, progress_bar, row, scrollable, slider,
|
||||||
text,
|
text,
|
||||||
@@ -57,9 +57,13 @@ pub enum Tab {
|
|||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum ExecutorType {
|
pub enum ExecutorType {
|
||||||
MnnCpu,
|
MnnCpu,
|
||||||
|
#[cfg(feature = "mnn-metal")]
|
||||||
MnnMetal,
|
MnnMetal,
|
||||||
|
#[cfg(feature = "mnn-coreml")]
|
||||||
MnnCoreML,
|
MnnCoreML,
|
||||||
OnnxCpu,
|
OnnxCpu,
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
OrtCuda,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -129,7 +133,10 @@ impl Default for FaceDetectorApp {
|
|||||||
output_path: None,
|
output_path: None,
|
||||||
threshold: 0.8,
|
threshold: 0.8,
|
||||||
nms_threshold: 0.3,
|
nms_threshold: 0.3,
|
||||||
|
#[cfg(not(any(feature = "mnn-metal", feature = "ort-cuda")))]
|
||||||
executor_type: ExecutorType::MnnCpu,
|
executor_type: ExecutorType::MnnCpu,
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
executor_type: ExecutorType::OrtCuda,
|
||||||
is_processing: false,
|
is_processing: false,
|
||||||
progress: 0.0,
|
progress: 0.0,
|
||||||
status_message: "Ready".to_string(),
|
status_message: "Ready".to_string(),
|
||||||
@@ -939,12 +946,17 @@ impl FaceDetectorApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn settings_view(&self) -> Element<'_, Message> {
|
fn settings_view(&self) -> Element<'_, Message> {
|
||||||
let executor_options = vec![
|
#[allow(unused_mut)]
|
||||||
ExecutorType::MnnCpu,
|
let mut executor_options = vec![ExecutorType::MnnCpu, ExecutorType::OnnxCpu];
|
||||||
ExecutorType::MnnMetal,
|
|
||||||
ExecutorType::MnnCoreML,
|
#[cfg(feature = "mnn-metal")]
|
||||||
ExecutorType::OnnxCpu,
|
executor_options.push(ExecutorType::MnnMetal);
|
||||||
];
|
|
||||||
|
#[cfg(feature = "mnn-coreml")]
|
||||||
|
executor_options.push(ExecutorType::MnnCoreML);
|
||||||
|
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
executor_options.push(ExecutorType::OrtCuda);
|
||||||
|
|
||||||
container(
|
container(
|
||||||
column![
|
column![
|
||||||
@@ -990,9 +1002,13 @@ impl std::fmt::Display for ExecutorType {
|
|||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
ExecutorType::MnnCpu => write!(f, "MNN (CPU)"),
|
ExecutorType::MnnCpu => write!(f, "MNN (CPU)"),
|
||||||
|
#[cfg(feature = "mnn-metal")]
|
||||||
ExecutorType::MnnMetal => write!(f, "MNN (Metal)"),
|
ExecutorType::MnnMetal => write!(f, "MNN (Metal)"),
|
||||||
|
#[cfg(feature = "mnn-coreml")]
|
||||||
ExecutorType::MnnCoreML => write!(f, "MNN (CoreML)"),
|
ExecutorType::MnnCoreML => write!(f, "MNN (CoreML)"),
|
||||||
ExecutorType::OnnxCpu => write!(f, "ONNX (CPU)"),
|
ExecutorType::OnnxCpu => write!(f, "ONNX (CPU)"),
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
ExecutorType::OrtCuda => write!(f, "ONNX (CUDA)"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1023,10 +1039,15 @@ fn convert_face_rois_to_handles(face_rois: Vec<ndarray::Array3<u8>>) -> Vec<imag
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn run() -> iced::Result {
|
pub fn run() -> iced::Result {
|
||||||
|
let settings = Settings {
|
||||||
|
antialiasing: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
iced::application(
|
iced::application(
|
||||||
"Face Detector",
|
"Face Detector",
|
||||||
FaceDetectorApp::update,
|
FaceDetectorApp::update,
|
||||||
FaceDetectorApp::view,
|
FaceDetectorApp::view,
|
||||||
)
|
)
|
||||||
|
.settings(settings)
|
||||||
.run_with(FaceDetectorApp::new)
|
.run_with(FaceDetectorApp::new)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,17 +114,34 @@ impl FaceDetectionBridge {
|
|||||||
|
|
||||||
// Create detector and detect faces
|
// Create detector and detect faces
|
||||||
let faces = match executor_type {
|
let faces = match executor_type {
|
||||||
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => {
|
ExecutorType::MnnCpu => {
|
||||||
let forward_type = match executor_type {
|
|
||||||
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
|
|
||||||
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
|
|
||||||
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||||
.with_forward_type(forward_type)
|
.with_forward_type(mnn::ForwardType::CPU)
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||||
|
|
||||||
|
detector
|
||||||
|
.detect_faces(image_array.view(), &config)
|
||||||
|
.map_err(|e| format!("Detection failed: {}", e))?
|
||||||
|
}
|
||||||
|
#[cfg(feature = "mnn-metal")]
|
||||||
|
ExecutorType::MnnMetal => {
|
||||||
|
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
|
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||||
|
.with_forward_type(mnn::ForwardType::Metal)
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||||
|
|
||||||
|
detector
|
||||||
|
.detect_faces(image_array.view(), &config)
|
||||||
|
.map_err(|e| format!("Detection failed: {}", e))?
|
||||||
|
}
|
||||||
|
#[cfg(feature = "mnn-coreml")]
|
||||||
|
ExecutorType::MnnCoreML => {
|
||||||
|
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
|
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||||
|
.with_forward_type(mnn::ForwardType::CoreML)
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||||
|
|
||||||
@@ -142,6 +159,21 @@ impl FaceDetectionBridge {
|
|||||||
.detect_faces(image_array.view(), &config)
|
.detect_faces(image_array.view(), &config)
|
||||||
.map_err(|e| format!("Detection failed: {}", e))?
|
.map_err(|e| format!("Detection failed: {}", e))?
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
ExecutorType::OrtCuda => {
|
||||||
|
use crate::ort_ep::ExecutionProvider;
|
||||||
|
|
||||||
|
let ep = ExecutionProvider::CUDA;
|
||||||
|
let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||||
|
.map_err(|e| format!("Failed to create ONNX CUDA detector: {}", e))?
|
||||||
|
.with_execution_providers([ep])
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build ONNX CUDA detector: {}", e))?;
|
||||||
|
|
||||||
|
detector
|
||||||
|
.detect_faces(image_array.view(), &config)
|
||||||
|
.map_err(|e| format!("CUDA detection failed: {}", e))?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let faces_count = faces.bbox.len();
|
let faces_count = faces.bbox.len();
|
||||||
@@ -195,24 +227,17 @@ impl FaceDetectionBridge {
|
|||||||
// Create detector and embedder, detect faces and generate embeddings
|
// Create detector and embedder, detect faces and generate embeddings
|
||||||
let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) =
|
let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) =
|
||||||
match executor_type {
|
match executor_type {
|
||||||
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => {
|
ExecutorType::MnnCpu => {
|
||||||
let forward_type = match executor_type {
|
|
||||||
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
|
|
||||||
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
|
|
||||||
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut detector =
|
let mut detector =
|
||||||
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||||
.with_forward_type(forward_type.clone())
|
.with_forward_type(mnn::ForwardType::CPU)
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||||
|
|
||||||
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||||
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||||
.with_forward_type(forward_type)
|
.with_forward_type(mnn::ForwardType::CPU)
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||||
|
|
||||||
@@ -247,7 +272,148 @@ impl FaceDetectionBridge {
|
|||||||
best_similarity,
|
best_similarity,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ExecutorType::OnnxCpu => unimplemented!(),
|
#[cfg(feature = "mnn-metal")]
|
||||||
|
ExecutorType::MnnMetal => {
|
||||||
|
let mut detector =
|
||||||
|
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
|
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||||
|
.with_forward_type(mnn::ForwardType::Metal)
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||||
|
|
||||||
|
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||||
|
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||||
|
.with_forward_type(mnn::ForwardType::Metal)
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||||
|
|
||||||
|
let img_1 = run_detection(
|
||||||
|
image1_path,
|
||||||
|
&mut detector,
|
||||||
|
&mut embedder,
|
||||||
|
threshold,
|
||||||
|
nms_threshold,
|
||||||
|
2,
|
||||||
|
)?;
|
||||||
|
let img_2 = run_detection(
|
||||||
|
image2_path,
|
||||||
|
&mut detector,
|
||||||
|
&mut embedder,
|
||||||
|
threshold,
|
||||||
|
nms_threshold,
|
||||||
|
2,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let image1_rois = img_1.rois;
|
||||||
|
let image2_rois = img_2.rois;
|
||||||
|
let image1_bbox_len = img_1.bbox.len();
|
||||||
|
let image2_bbox_len = img_2.bbox.len();
|
||||||
|
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
|
||||||
|
|
||||||
|
(
|
||||||
|
image1_bbox_len,
|
||||||
|
image2_bbox_len,
|
||||||
|
image1_rois,
|
||||||
|
image2_rois,
|
||||||
|
best_similarity,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "mnn-coreml")]
|
||||||
|
ExecutorType::MnnCoreML => {
|
||||||
|
let mut detector =
|
||||||
|
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
|
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||||
|
.with_forward_type(mnn::ForwardType::CoreML)
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||||
|
|
||||||
|
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||||
|
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||||
|
.with_forward_type(mnn::ForwardType::CoreML)
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||||
|
|
||||||
|
let img_1 = run_detection(
|
||||||
|
image1_path,
|
||||||
|
&mut detector,
|
||||||
|
&mut embedder,
|
||||||
|
threshold,
|
||||||
|
nms_threshold,
|
||||||
|
2,
|
||||||
|
)?;
|
||||||
|
let img_2 = run_detection(
|
||||||
|
image2_path,
|
||||||
|
&mut detector,
|
||||||
|
&mut embedder,
|
||||||
|
threshold,
|
||||||
|
nms_threshold,
|
||||||
|
2,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let image1_rois = img_1.rois;
|
||||||
|
let image2_rois = img_2.rois;
|
||||||
|
let image1_bbox_len = img_1.bbox.len();
|
||||||
|
let image2_bbox_len = img_2.bbox.len();
|
||||||
|
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
|
||||||
|
|
||||||
|
(
|
||||||
|
image1_bbox_len,
|
||||||
|
image2_bbox_len,
|
||||||
|
image1_rois,
|
||||||
|
image2_rois,
|
||||||
|
best_similarity,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ExecutorType::OnnxCpu => unimplemented!("ONNX face comparison not yet implemented"),
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
ExecutorType::OrtCuda => {
|
||||||
|
use crate::ort_ep::ExecutionProvider;
|
||||||
|
let ep = ExecutionProvider::CUDA;
|
||||||
|
let mut detector =
|
||||||
|
retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||||
|
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||||
|
.with_execution_providers([ep])
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||||
|
|
||||||
|
let mut embedder =
|
||||||
|
facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||||
|
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||||
|
.with_execution_providers([ep])
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||||
|
|
||||||
|
let img_1 = run_detection(
|
||||||
|
image1_path,
|
||||||
|
&mut detector,
|
||||||
|
&mut embedder,
|
||||||
|
threshold,
|
||||||
|
nms_threshold,
|
||||||
|
2,
|
||||||
|
)?;
|
||||||
|
let img_2 = run_detection(
|
||||||
|
image2_path,
|
||||||
|
&mut detector,
|
||||||
|
&mut embedder,
|
||||||
|
threshold,
|
||||||
|
nms_threshold,
|
||||||
|
2,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let image1_rois = img_1.rois;
|
||||||
|
let image2_rois = img_2.rois;
|
||||||
|
let image1_bbox_len = img_1.bbox.len();
|
||||||
|
let image2_bbox_len = img_2.bbox.len();
|
||||||
|
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
|
||||||
|
|
||||||
|
(
|
||||||
|
image1_bbox_len,
|
||||||
|
image2_bbox_len,
|
||||||
|
image1_rois,
|
||||||
|
image2_rois,
|
||||||
|
best_similarity,
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use ort::execution_providers::TensorRTExecutionProvider;
|
|||||||
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
|
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
|
||||||
|
|
||||||
/// Supported execution providers for ONNX Runtime
|
/// Supported execution providers for ONNX Runtime
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Copy, Clone)]
|
||||||
pub enum ExecutionProvider {
|
pub enum ExecutionProvider {
|
||||||
/// CPU execution provider (always available)
|
/// CPU execution provider (always available)
|
||||||
CPU,
|
CPU,
|
||||||
|
|||||||
Reference in New Issue
Block a user