feat(detector): add CUDA support for ONNX face detection
Some checks failed
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
build / checks-matrix (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-28 18:32:00 +05:30
parent 4256c0af74
commit ac8f1d01b4
10 changed files with 315 additions and 92 deletions

42
Cargo.lock generated
View File

@@ -1271,13 +1271,14 @@ dependencies = [
"image 0.25.6", "image 0.25.6",
"imageproc", "imageproc",
"itertools 0.14.0", "itertools 0.14.0",
"linfa",
"mnn", "mnn",
"mnn-bridge", "mnn-bridge",
"mnn-sync", "mnn-sync",
"nalgebra 0.34.0", "nalgebra 0.34.0",
"ndarray", "ndarray",
"ndarray-image", "ndarray-image",
"ndarray-math 0.1.0 (git+https://git.darksailor.dev/servius/ndarray-math)", "ndarray-math",
"ndarray-resize", "ndarray-resize",
"ndarray-safetensors", "ndarray-safetensors",
"ordered-float", "ordered-float",
@@ -2916,6 +2917,19 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "linfa"
version = "0.7.1"
source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba"
dependencies = [
"approx",
"ndarray",
"num-traits",
"rand 0.8.5",
"sprs",
"thiserror 2.0.15",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.4.15" version = "0.4.15"
@@ -3223,6 +3237,7 @@ version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [ dependencies = [
"approx",
"matrixmultiply", "matrixmultiply",
"num-complex", "num-complex",
"num-integer", "num-integer",
@@ -3244,16 +3259,7 @@ dependencies = [
[[package]] [[package]]
name = "ndarray-math" name = "ndarray-math"
version = "0.1.0" version = "0.1.0"
dependencies = [ source = "git+https://git.darksailor.dev/servius/ndarray-math#df17c36193df60e070e4e120c9feebe68ff3f517"
"ndarray",
"num",
"thiserror 2.0.15",
]
[[package]]
name = "ndarray-math"
version = "0.1.0"
source = "git+https://git.darksailor.dev/servius/ndarray-math#f047966f20835267f20e5839272b9ab36c445796"
dependencies = [ dependencies = [
"ndarray", "ndarray",
"num", "num",
@@ -5120,6 +5126,18 @@ dependencies = [
"bitflags 2.9.2", "bitflags 2.9.2",
] ]
[[package]]
name = "sprs"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bff8419009a08f6cb7519a602c5590241fbff1446bcc823c07af15386eb801b"
dependencies = [
"ndarray",
"num-complex",
"num-traits",
"smallvec 1.15.1",
]
[[package]] [[package]]
name = "sqlite-loadable" name = "sqlite-loadable"
version = "0.0.5" version = "0.0.5"
@@ -5149,7 +5167,7 @@ name = "sqlite3-safetensor-cosine"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"ndarray", "ndarray",
"ndarray-math 0.1.0", "ndarray-math",
"ndarray-safetensors", "ndarray-safetensors",
"sqlite-loadable", "sqlite-loadable",
] ]

View File

@@ -1,43 +1,42 @@
[workspace] [workspace]
members = [ members = [
"ndarray-image", "ndarray-image",
"ndarray-resize", "ndarray-resize",
".", ".",
"bounding-box", "bounding-box",
"ndarray-safetensors", "ndarray-safetensors",
"sqlite3-safetensor-cosine", "sqlite3-safetensor-cosine",
"ndcv-bridge", "ndcv-bridge",
"bbox", "bbox",
] ]
[workspace.package] [workspace.package]
version = "0.1.0" version = "0.1.0"
edition = "2024" edition = "2024"
[patch.crates-io]
linfa = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
[workspace.dependencies] [workspace.dependencies]
bbox = { path = "bbox" }
divan = { version = "0.1.21" } divan = { version = "0.1.21" }
ndarray-npy = "0.9.1" ndarray-npy = "0.9.1"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
ndarray-image = { path = "ndarray-image" } ndarray-image = { path = "ndarray-image" }
ndarray-resize = { path = "ndarray-resize" } ndarray-resize = { path = "ndarray-resize" }
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [ mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
# "metal", # "metal",
# "coreml", # "coreml",
"tracing", "tracing",
], branch = "restructure-tensor-type" } ], branch = "restructure-tensor-type" }
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
"ndarray", "ndarray",
], branch = "restructure-tensor-type" } ], branch = "restructure-tensor-type" }
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [ mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
"tracing", "tracing",
], branch = "restructure-tensor-type" } ], branch = "restructure-tensor-type" }
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] } nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
opencv = { version = "0.95.1" } opencv = { version = "0.95.1" }
bounding-box = { path = "bounding-box" } bounding-box = { path = "bounding-box" }
ndarray-safetensors = { path = "ndarray-safetensors" }
wide = "0.7.33"
rayon = "1.11.0"
bytemuck = "1.23.2" bytemuck = "1.23.2"
error-stack = "0.5.0" error-stack = "0.5.0"
thiserror = "2.0" thiserror = "2.0"
@@ -76,9 +75,10 @@ color = "0.3.1"
itertools = "0.14.0" itertools = "0.14.0"
ordered-float = "5.0.0" ordered-float = "5.0.0"
ort = { version = "2.0.0-rc.10", default-features = false, features = [ ort = { version = "2.0.0-rc.10", default-features = false, features = [
"std", "std",
"tracing", "tracing",
"ndarray", "ndarray",
"cuda",
] } ] }
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" } ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" } ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
@@ -89,12 +89,13 @@ iced = { version = "0.13", features = ["tokio", "image"] }
rfd = "0.15" rfd = "0.15"
futures = "0.3" futures = "0.3"
imageproc = "0.25" imageproc = "0.25"
linfa = "0.7.1"
[profile.release] [profile.release]
debug = true debug = true
[features] [features]
ort-cuda = ["ort/cuda"] ort-cuda = []
ort-coreml = ["ort/coreml"] ort-coreml = ["ort/coreml"]
ort-tensorrt = ["ort/tensorrt"] ort-tensorrt = ["ort/tensorrt"]
ort-tvm = ["ort/tvm"] ort-tvm = ["ort/tvm"]
@@ -103,7 +104,7 @@ ort-directml = ["ort/directml"]
mnn-metal = ["mnn/metal"] mnn-metal = ["mnn/metal"]
mnn-coreml = ["mnn/coreml"] mnn-coreml = ["mnn/coreml"]
default = ["mnn-metal", "mnn-coreml"] default = ["ort-cuda"]
[[test]] [[test]]
name = "test_bbox_replacement" name = "test_bbox_replacement"

View File

@@ -5,34 +5,34 @@ workspace = false
[tasks.convert_facenet] [tasks.convert_facenet]
command = "MNNConvert" command = "MNNConvert"
args = [ args = [
"-f", "-f",
"ONNX", "ONNX",
"--modelFile", "--modelFile",
"models/facenet.onnx", "models/facenet.onnx",
"--MNNModel", "--MNNModel",
"models/facenet.mnn", "models/facenet.mnn",
"--fp16", "--fp16",
"--bizCode", "--bizCode",
"MNN", "MNN",
] ]
workspace = false workspace = false
[tasks.convert_retinaface] [tasks.convert_retinaface]
command = "MNNConvert" command = "MNNConvert"
args = [ args = [
"-f", "-f",
"ONNX", "ONNX",
"--modelFile", "--modelFile",
"models/retinaface.onnx", "models/retinaface.onnx",
"--MNNModel", "--MNNModel",
"models/retinaface.mnn", "models/retinaface.mnn",
"--fp16", "--fp16",
"--bizCode", "--bizCode",
"MNN", "MNN",
] ]
workspace = false workspace = false
[tasks.gui] [tasks.gui]
command = "cargo" command = "cargo"
args = ["run", "--bin", "gui"] args = ["run", "--release", "--bin", "gui"]
workspace = false workspace = false

View File

@@ -43,6 +43,8 @@
system: let system: let
pkgs = import nixpkgs { pkgs = import nixpkgs {
inherit system; inherit system;
config.allowUnfree = true;
config.cudaSupport = pkgs.stdenv.isLinux;
overlays = [ overlays = [
rust-overlay.overlays.default rust-overlay.overlays.default
(final: prev: { (final: prev: {
@@ -75,7 +77,7 @@
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain; craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools; craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
ort_static = pkgs.onnxruntime.overrideAttrs (old: { ort_static = (pkgs.onnxruntime.override {cudaSupport = true;}).overrideAttrs (old: {
cmakeFlags = cmakeFlags =
old.cmakeFlags old.cmakeFlags
++ [ ++ [
@@ -198,8 +200,9 @@
devShells = { devShells = {
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} ( default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
commonArgs commonArgs
// { // rec {
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver"; LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${builtins.toString (pkgs.lib.makeLibraryPath packages)}";
packages = with pkgs; packages = with pkgs;
[ [
stableToolchainWithRustAnalyzer stableToolchainWithRustAnalyzer
@@ -215,6 +218,18 @@
] ]
++ (lib.optionals pkgs.stdenv.isDarwin [ ++ (lib.optionals pkgs.stdenv.isDarwin [
apple-sdk_13 apple-sdk_13
])
++ (lib.optionals pkgs.stdenv.isLinux [
xorg.libX11
xorg.libXcursor
xorg.libXrandr
xorg.libXi
xorg.libxcb
libxkbcommon
vulkan-loader
wayland
zenity
cudatoolkit
]); ]);
} }
); );

View File

@@ -8,7 +8,7 @@ crate-type = ["cdylib", "staticlib"]
[dependencies] [dependencies]
ndarray = "0.16.1" ndarray = "0.16.1"
# ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" } ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" } # ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" } ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
sqlite-loadable = "0.0.5" sqlite-loadable = "0.0.5"

View File

@@ -22,10 +22,12 @@ const FACENET_MODEL_ONNX: &[u8] =
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/facenet.onnx")); include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/facenet.onnx"));
pub fn main() -> Result<()> { pub fn main() -> Result<()> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter("info") .with_env_filter("info,ort=warn")
.with_thread_ids(true) // .with_thread_ids(true)
.with_thread_names(true) // .with_thread_names(true)
.with_target(false) .with_file(true)
.with_line_number(true)
.with_target(true)
.init(); .init();
let args = <cli::Cli as clap::Parser>::parse(); let args = <cli::Cli as clap::Parser>::parse();
match args.cmd { match args.cmd {

View File

@@ -2,9 +2,9 @@ use detector::errors::*;
fn main() -> Result<()> { fn main() -> Result<()> {
// Initialize logging // Initialize logging
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter("info") .with_env_filter("warn,ort=warn")
// .with_thread_ids(true) .with_file(true)
// .with_file(true) .with_line_number(true)
// .with_thread_names(true) // .with_thread_names(true)
.with_target(true) .with_target(true)
.init(); .init();

View File

@@ -1,5 +1,5 @@
use iced::{ use iced::{
Alignment, Element, Length, Task, Theme, Alignment, Element, Length, Settings, Task, Theme,
widget::{ widget::{
Space, button, column, container, image, pick_list, progress_bar, row, scrollable, slider, Space, button, column, container, image, pick_list, progress_bar, row, scrollable, slider,
text, text,
@@ -57,9 +57,13 @@ pub enum Tab {
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum ExecutorType { pub enum ExecutorType {
MnnCpu, MnnCpu,
#[cfg(feature = "mnn-metal")]
MnnMetal, MnnMetal,
#[cfg(feature = "mnn-coreml")]
MnnCoreML, MnnCoreML,
OnnxCpu, OnnxCpu,
#[cfg(feature = "ort-cuda")]
OrtCuda,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -129,7 +133,10 @@ impl Default for FaceDetectorApp {
output_path: None, output_path: None,
threshold: 0.8, threshold: 0.8,
nms_threshold: 0.3, nms_threshold: 0.3,
#[cfg(not(feature = "ort-cuda"))]
executor_type: ExecutorType::MnnCpu, executor_type: ExecutorType::MnnCpu,
#[cfg(feature = "ort-cuda")]
executor_type: ExecutorType::OrtCuda,
is_processing: false, is_processing: false,
progress: 0.0, progress: 0.0,
status_message: "Ready".to_string(), status_message: "Ready".to_string(),
@@ -939,12 +946,17 @@ impl FaceDetectorApp {
} }
fn settings_view(&self) -> Element<'_, Message> { fn settings_view(&self) -> Element<'_, Message> {
let executor_options = vec![ #[allow(unused_mut)]
ExecutorType::MnnCpu, let mut executor_options = vec![ExecutorType::MnnCpu, ExecutorType::OnnxCpu];
ExecutorType::MnnMetal,
ExecutorType::MnnCoreML, #[cfg(feature = "mnn-metal")]
ExecutorType::OnnxCpu, executor_options.push(ExecutorType::MnnMetal);
];
#[cfg(feature = "mnn-coreml")]
executor_options.push(ExecutorType::MnnCoreML);
#[cfg(feature = "ort-cuda")]
executor_options.push(ExecutorType::OrtCuda);
container( container(
column![ column![
@@ -990,9 +1002,13 @@ impl std::fmt::Display for ExecutorType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
ExecutorType::MnnCpu => write!(f, "MNN (CPU)"), ExecutorType::MnnCpu => write!(f, "MNN (CPU)"),
#[cfg(feature = "mnn-metal")]
ExecutorType::MnnMetal => write!(f, "MNN (Metal)"), ExecutorType::MnnMetal => write!(f, "MNN (Metal)"),
#[cfg(feature = "mnn-coreml")]
ExecutorType::MnnCoreML => write!(f, "MNN (CoreML)"), ExecutorType::MnnCoreML => write!(f, "MNN (CoreML)"),
ExecutorType::OnnxCpu => write!(f, "ONNX (CPU)"), ExecutorType::OnnxCpu => write!(f, "ONNX (CPU)"),
#[cfg(feature = "ort-cuda")]
ExecutorType::OrtCuda => write!(f, "ONNX (CUDA)"),
} }
} }
} }
@@ -1023,10 +1039,15 @@ fn convert_face_rois_to_handles(face_rois: Vec<ndarray::Array3<u8>>) -> Vec<imag
} }
pub fn run() -> iced::Result { pub fn run() -> iced::Result {
let settings = Settings {
antialiasing: true,
..Default::default()
};
iced::application( iced::application(
"Face Detector", "Face Detector",
FaceDetectorApp::update, FaceDetectorApp::update,
FaceDetectorApp::view, FaceDetectorApp::view,
) )
.settings(settings)
.run_with(FaceDetectorApp::new) .run_with(FaceDetectorApp::new)
} }

View File

@@ -114,17 +114,34 @@ impl FaceDetectionBridge {
// Create detector and detect faces // Create detector and detect faces
let faces = match executor_type { let faces = match executor_type {
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => { ExecutorType::MnnCpu => {
let forward_type = match executor_type {
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
_ => unreachable!(),
};
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))? .map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(forward_type) .with_forward_type(mnn::ForwardType::CPU)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
#[cfg(feature = "mnn-metal")]
ExecutorType::MnnMetal => {
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::Metal)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
#[cfg(feature = "mnn-coreml")]
ExecutorType::MnnCoreML => {
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::CoreML)
.build() .build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?; .map_err(|e| format!("Failed to build MNN detector: {}", e))?;
@@ -142,6 +159,21 @@ impl FaceDetectionBridge {
.detect_faces(image_array.view(), &config) .detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))? .map_err(|e| format!("Detection failed: {}", e))?
} }
#[cfg(feature = "ort-cuda")]
ExecutorType::OrtCuda => {
use crate::ort_ep::ExecutionProvider;
let ep = ExecutionProvider::CUDA;
let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX CUDA detector: {}", e))?
.with_execution_providers([ep])
.build()
.map_err(|e| format!("Failed to build ONNX CUDA detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("CUDA detection failed: {}", e))?
}
}; };
let faces_count = faces.bbox.len(); let faces_count = faces.bbox.len();
@@ -195,24 +227,17 @@ impl FaceDetectionBridge {
// Create detector and embedder, detect faces and generate embeddings // Create detector and embedder, detect faces and generate embeddings
let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) = let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) =
match executor_type { match executor_type {
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => { ExecutorType::MnnCpu => {
let forward_type = match executor_type {
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
_ => unreachable!(),
};
let mut detector = let mut detector =
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))? .map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(forward_type.clone()) .with_forward_type(mnn::ForwardType::CPU)
.build() .build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?; .map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))? .map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_forward_type(forward_type) .with_forward_type(mnn::ForwardType::CPU)
.build() .build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?; .map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
@@ -247,7 +272,148 @@ impl FaceDetectionBridge {
best_similarity, best_similarity,
) )
} }
ExecutorType::OnnxCpu => unimplemented!(), #[cfg(feature = "mnn-metal")]
ExecutorType::MnnMetal => {
let mut detector =
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::Metal)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_forward_type(mnn::ForwardType::Metal)
.build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
let img_1 = run_detection(
image1_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let img_2 = run_detection(
image2_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let image1_rois = img_1.rois;
let image2_rois = img_2.rois;
let image1_bbox_len = img_1.bbox.len();
let image2_bbox_len = img_2.bbox.len();
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(
image1_bbox_len,
image2_bbox_len,
image1_rois,
image2_rois,
best_similarity,
)
}
#[cfg(feature = "mnn-coreml")]
ExecutorType::MnnCoreML => {
let mut detector =
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::CoreML)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_forward_type(mnn::ForwardType::CoreML)
.build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
let img_1 = run_detection(
image1_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let img_2 = run_detection(
image2_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let image1_rois = img_1.rois;
let image2_rois = img_2.rois;
let image1_bbox_len = img_1.bbox.len();
let image2_bbox_len = img_2.bbox.len();
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(
image1_bbox_len,
image2_bbox_len,
image1_rois,
image2_rois,
best_similarity,
)
}
ExecutorType::OnnxCpu => unimplemented!("ONNX face comparison not yet implemented"),
#[cfg(feature = "ort-cuda")]
ExecutorType::OrtCuda => {
use crate::ort_ep::ExecutionProvider;
let ep = ExecutionProvider::CUDA;
let mut detector =
retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX CUDA detector: {}", e))?
.with_execution_providers([ep])
.build()
.map_err(|e| format!("Failed to build ONNX CUDA detector: {}", e))?;
let mut embedder =
facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX CUDA embedder: {}", e))?
.with_execution_providers([ep])
.build()
.map_err(|e| format!("Failed to build ONNX CUDA embedder: {}", e))?;
let img_1 = run_detection(
image1_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let img_2 = run_detection(
image2_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let image1_rois = img_1.rois;
let image2_rois = img_2.rois;
let image1_bbox_len = img_1.bbox.len();
let image2_bbox_len = img_2.bbox.len();
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(
image1_bbox_len,
image2_bbox_len,
image1_rois,
image2_rois,
best_similarity,
)
}
}; };
Ok(( Ok((

View File

@@ -13,7 +13,7 @@ use ort::execution_providers::TensorRTExecutionProvider;
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch}; use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
/// Supported execution providers for ONNX Runtime /// Supported execution providers for ONNX Runtime
#[derive(Debug, Clone)] #[derive(Debug, Copy, Clone)]
pub enum ExecutionProvider { pub enum ExecutionProvider {
/// CPU execution provider (always available) /// CPU execution provider (always available)
CPU, CPU,