feat(detector): add CUDA support for ONNX face detection
This commit is contained in:
@@ -22,10 +22,12 @@ const FACENET_MODEL_ONNX: &[u8] =
|
||||
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/facenet.onnx"));
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("info")
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_target(false)
|
||||
.with_env_filter("info,ort=warn")
|
||||
// .with_thread_ids(true)
|
||||
// .with_thread_names(true)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.with_target(true)
|
||||
.init();
|
||||
let args = <cli::Cli as clap::Parser>::parse();
|
||||
match args.cmd {
|
||||
|
||||
@@ -2,9 +2,9 @@ use detector::errors::*;
|
||||
fn main() -> Result<()> {
|
||||
// Initialize logging
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("info")
|
||||
// .with_thread_ids(true)
|
||||
// .with_file(true)
|
||||
.with_env_filter("warn,ort=warn")
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
// .with_thread_names(true)
|
||||
.with_target(true)
|
||||
.init();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use iced::{
|
||||
Alignment, Element, Length, Task, Theme,
|
||||
Alignment, Element, Length, Settings, Task, Theme,
|
||||
widget::{
|
||||
Space, button, column, container, image, pick_list, progress_bar, row, scrollable, slider,
|
||||
text,
|
||||
@@ -57,9 +57,13 @@ pub enum Tab {
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ExecutorType {
|
||||
MnnCpu,
|
||||
#[cfg(feature = "mnn-metal")]
|
||||
MnnMetal,
|
||||
#[cfg(feature = "mnn-coreml")]
|
||||
MnnCoreML,
|
||||
OnnxCpu,
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
OrtCuda,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -129,7 +133,10 @@ impl Default for FaceDetectorApp {
|
||||
output_path: None,
|
||||
threshold: 0.8,
|
||||
nms_threshold: 0.3,
|
||||
#[cfg(not(any(feature = "mnn-metal", feature = "ort-cuda")))]
|
||||
executor_type: ExecutorType::MnnCpu,
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
executor_type: ExecutorType::OrtCuda,
|
||||
is_processing: false,
|
||||
progress: 0.0,
|
||||
status_message: "Ready".to_string(),
|
||||
@@ -939,12 +946,17 @@ impl FaceDetectorApp {
|
||||
}
|
||||
|
||||
fn settings_view(&self) -> Element<'_, Message> {
|
||||
let executor_options = vec![
|
||||
ExecutorType::MnnCpu,
|
||||
ExecutorType::MnnMetal,
|
||||
ExecutorType::MnnCoreML,
|
||||
ExecutorType::OnnxCpu,
|
||||
];
|
||||
#[allow(unused_mut)]
|
||||
let mut executor_options = vec![ExecutorType::MnnCpu, ExecutorType::OnnxCpu];
|
||||
|
||||
#[cfg(feature = "mnn-metal")]
|
||||
executor_options.push(ExecutorType::MnnMetal);
|
||||
|
||||
#[cfg(feature = "mnn-coreml")]
|
||||
executor_options.push(ExecutorType::MnnCoreML);
|
||||
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
executor_options.push(ExecutorType::OrtCuda);
|
||||
|
||||
container(
|
||||
column![
|
||||
@@ -990,9 +1002,13 @@ impl std::fmt::Display for ExecutorType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ExecutorType::MnnCpu => write!(f, "MNN (CPU)"),
|
||||
#[cfg(feature = "mnn-metal")]
|
||||
ExecutorType::MnnMetal => write!(f, "MNN (Metal)"),
|
||||
#[cfg(feature = "mnn-coreml")]
|
||||
ExecutorType::MnnCoreML => write!(f, "MNN (CoreML)"),
|
||||
ExecutorType::OnnxCpu => write!(f, "ONNX (CPU)"),
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
ExecutorType::OrtCuda => write!(f, "ONNX (CUDA)"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1023,10 +1039,15 @@ fn convert_face_rois_to_handles(face_rois: Vec<ndarray::Array3<u8>>) -> Vec<imag
|
||||
}
|
||||
|
||||
pub fn run() -> iced::Result {
|
||||
let settings = Settings {
|
||||
antialiasing: true,
|
||||
..Default::default()
|
||||
};
|
||||
iced::application(
|
||||
"Face Detector",
|
||||
FaceDetectorApp::update,
|
||||
FaceDetectorApp::view,
|
||||
)
|
||||
.settings(settings)
|
||||
.run_with(FaceDetectorApp::new)
|
||||
}
|
||||
|
||||
@@ -114,17 +114,34 @@ impl FaceDetectionBridge {
|
||||
|
||||
// Create detector and detect faces
|
||||
let faces = match executor_type {
|
||||
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => {
|
||||
let forward_type = match executor_type {
|
||||
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
|
||||
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
|
||||
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ExecutorType::MnnCpu => {
|
||||
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||
.with_forward_type(forward_type)
|
||||
.with_forward_type(mnn::ForwardType::CPU)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||
|
||||
detector
|
||||
.detect_faces(image_array.view(), &config)
|
||||
.map_err(|e| format!("Detection failed: {}", e))?
|
||||
}
|
||||
#[cfg(feature = "mnn-metal")]
|
||||
ExecutorType::MnnMetal => {
|
||||
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||
.with_forward_type(mnn::ForwardType::Metal)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||
|
||||
detector
|
||||
.detect_faces(image_array.view(), &config)
|
||||
.map_err(|e| format!("Detection failed: {}", e))?
|
||||
}
|
||||
#[cfg(feature = "mnn-coreml")]
|
||||
ExecutorType::MnnCoreML => {
|
||||
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||
.with_forward_type(mnn::ForwardType::CoreML)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||
|
||||
@@ -142,6 +159,21 @@ impl FaceDetectionBridge {
|
||||
.detect_faces(image_array.view(), &config)
|
||||
.map_err(|e| format!("Detection failed: {}", e))?
|
||||
}
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
ExecutorType::OrtCuda => {
|
||||
use crate::ort_ep::ExecutionProvider;
|
||||
|
||||
let ep = ExecutionProvider::CUDA;
|
||||
let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.map_err(|e| format!("Failed to create ONNX CUDA detector: {}", e))?
|
||||
.with_execution_providers([ep])
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build ONNX CUDA detector: {}", e))?;
|
||||
|
||||
detector
|
||||
.detect_faces(image_array.view(), &config)
|
||||
.map_err(|e| format!("CUDA detection failed: {}", e))?
|
||||
}
|
||||
};
|
||||
|
||||
let faces_count = faces.bbox.len();
|
||||
@@ -195,24 +227,17 @@ impl FaceDetectionBridge {
|
||||
// Create detector and embedder, detect faces and generate embeddings
|
||||
let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) =
|
||||
match executor_type {
|
||||
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => {
|
||||
let forward_type = match executor_type {
|
||||
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
|
||||
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
|
||||
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ExecutorType::MnnCpu => {
|
||||
let mut detector =
|
||||
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||
.with_forward_type(forward_type.clone())
|
||||
.with_forward_type(mnn::ForwardType::CPU)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||
|
||||
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||
.with_forward_type(forward_type)
|
||||
.with_forward_type(mnn::ForwardType::CPU)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||
|
||||
@@ -247,7 +272,148 @@ impl FaceDetectionBridge {
|
||||
best_similarity,
|
||||
)
|
||||
}
|
||||
ExecutorType::OnnxCpu => unimplemented!(),
|
||||
#[cfg(feature = "mnn-metal")]
|
||||
ExecutorType::MnnMetal => {
|
||||
let mut detector =
|
||||
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||
.with_forward_type(mnn::ForwardType::Metal)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||
|
||||
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||
.with_forward_type(mnn::ForwardType::Metal)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||
|
||||
let img_1 = run_detection(
|
||||
image1_path,
|
||||
&mut detector,
|
||||
&mut embedder,
|
||||
threshold,
|
||||
nms_threshold,
|
||||
2,
|
||||
)?;
|
||||
let img_2 = run_detection(
|
||||
image2_path,
|
||||
&mut detector,
|
||||
&mut embedder,
|
||||
threshold,
|
||||
nms_threshold,
|
||||
2,
|
||||
)?;
|
||||
|
||||
let image1_rois = img_1.rois;
|
||||
let image2_rois = img_2.rois;
|
||||
let image1_bbox_len = img_1.bbox.len();
|
||||
let image2_bbox_len = img_2.bbox.len();
|
||||
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
|
||||
|
||||
(
|
||||
image1_bbox_len,
|
||||
image2_bbox_len,
|
||||
image1_rois,
|
||||
image2_rois,
|
||||
best_similarity,
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "mnn-coreml")]
|
||||
ExecutorType::MnnCoreML => {
|
||||
let mut detector =
|
||||
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||
.with_forward_type(mnn::ForwardType::CoreML)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||
|
||||
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||
.with_forward_type(mnn::ForwardType::CoreML)
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||
|
||||
let img_1 = run_detection(
|
||||
image1_path,
|
||||
&mut detector,
|
||||
&mut embedder,
|
||||
threshold,
|
||||
nms_threshold,
|
||||
2,
|
||||
)?;
|
||||
let img_2 = run_detection(
|
||||
image2_path,
|
||||
&mut detector,
|
||||
&mut embedder,
|
||||
threshold,
|
||||
nms_threshold,
|
||||
2,
|
||||
)?;
|
||||
|
||||
let image1_rois = img_1.rois;
|
||||
let image2_rois = img_2.rois;
|
||||
let image1_bbox_len = img_1.bbox.len();
|
||||
let image2_bbox_len = img_2.bbox.len();
|
||||
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
|
||||
|
||||
(
|
||||
image1_bbox_len,
|
||||
image2_bbox_len,
|
||||
image1_rois,
|
||||
image2_rois,
|
||||
best_similarity,
|
||||
)
|
||||
}
|
||||
ExecutorType::OnnxCpu => unimplemented!("ONNX face comparison not yet implemented"),
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
ExecutorType::OrtCuda => {
|
||||
use crate::ort_ep::ExecutionProvider;
|
||||
let ep = ExecutionProvider::CUDA;
|
||||
let mut detector =
|
||||
retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
|
||||
.with_execution_providers([ep])
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
|
||||
|
||||
let mut embedder =
|
||||
facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
|
||||
.with_execution_providers([ep])
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
|
||||
|
||||
let img_1 = run_detection(
|
||||
image1_path,
|
||||
&mut detector,
|
||||
&mut embedder,
|
||||
threshold,
|
||||
nms_threshold,
|
||||
2,
|
||||
)?;
|
||||
let img_2 = run_detection(
|
||||
image2_path,
|
||||
&mut detector,
|
||||
&mut embedder,
|
||||
threshold,
|
||||
nms_threshold,
|
||||
2,
|
||||
)?;
|
||||
|
||||
let image1_rois = img_1.rois;
|
||||
let image2_rois = img_2.rois;
|
||||
let image1_bbox_len = img_1.bbox.len();
|
||||
let image2_bbox_len = img_2.bbox.len();
|
||||
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
|
||||
|
||||
(
|
||||
image1_bbox_len,
|
||||
image2_bbox_len,
|
||||
image1_rois,
|
||||
image2_rois,
|
||||
best_similarity,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok((
|
||||
|
||||
@@ -13,7 +13,7 @@ use ort::execution_providers::TensorRTExecutionProvider;
|
||||
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
|
||||
|
||||
/// Supported execution providers for ONNX Runtime
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub enum ExecutionProvider {
|
||||
/// CPU execution provider (always available)
|
||||
CPU,
|
||||
|
||||
Reference in New Issue
Block a user