diff --git a/src/gui/app.rs b/src/gui/app.rs index a537192..642800c 100644 --- a/src/gui/app.rs +++ b/src/gui/app.rs @@ -76,6 +76,8 @@ pub enum ComparisonResult { Success { image1_faces: usize, image2_faces: usize, + image1_face_rois: Vec>, + image2_face_rois: Vec>, best_similarity: f32, processing_time: f64, }, @@ -765,6 +767,8 @@ impl FaceDetectorApp { ComparisonResult::Success { image1_faces, image2_faces, + image1_face_rois: _, + image2_face_rois: _, best_similarity, processing_time, } => { diff --git a/src/gui/bridge.rs b/src/gui/bridge.rs index beece2e..5b28770 100644 --- a/src/gui/bridge.rs +++ b/src/gui/bridge.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use crate::errors; use crate::facedet::{FaceDetectionConfig, FaceDetector, retinaface}; -use crate::faceembed::{FaceNetEmbedder, facenet}; +use crate::faceembed::facenet; use crate::gui::app::{ComparisonResult, DetectionResult, ExecutorType}; use bounding_box::Aabb2; use bounding_box::roi::MultiRoi as _; @@ -70,11 +70,19 @@ impl FaceDetectionBridge { ) .await { - Ok((image1_faces, image2_faces, best_similarity)) => { + Ok(( + image1_faces, + image2_faces, + image1_face_rois, + image2_face_rois, + best_similarity, + )) => { let processing_time = start_time.elapsed().as_secs_f64(); ComparisonResult::Success { image1_faces, image2_faces, + image1_face_rois, + image2_face_rois, best_similarity, processing_time, } @@ -180,53 +188,75 @@ impl FaceDetectionBridge { threshold: f32, nms_threshold: f32, executor_type: ExecutorType, - ) -> Result<(usize, usize, f32), Box> { + ) -> Result< + (usize, usize, Vec>, Vec>, f32), + Box, + > { // Create detector and embedder, detect faces and generate embeddings - let (faces1, faces2, best_similarity) = match executor_type { - ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => { - let forward_type = match executor_type { - ExecutorType::MnnCpu => mnn::ForwardType::CPU, - ExecutorType::MnnMetal => mnn::ForwardType::Metal, - ExecutorType::MnnCoreML => mnn::ForwardType::CoreML, - _ => unreachable!(), - }; + let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) = + match executor_type { + ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => { + let forward_type = match executor_type { + ExecutorType::MnnCpu => mnn::ForwardType::CPU, + ExecutorType::MnnMetal => mnn::ForwardType::Metal, + ExecutorType::MnnCoreML => mnn::ForwardType::CoreML, + _ => unreachable!(), + }; - let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) - .map_err(|e| format!("Failed to create MNN detector: {}", e))? - .with_forward_type(forward_type.clone()) - .build() - .map_err(|e| format!("Failed to build MNN detector: {}", e))?; + let mut detector = + retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) + .map_err(|e| format!("Failed to create MNN detector: {}", e))? + .with_forward_type(forward_type.clone()) + .build() + .map_err(|e| format!("Failed to build MNN detector: {}", e))?; - let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) - .map_err(|e| format!("Failed to create MNN embedder: {}", e))? - .with_forward_type(forward_type) - .build() - .map_err(|e| format!("Failed to build MNN embedder: {}", e))?; + let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) + .map_err(|e| format!("Failed to create MNN embedder: {}", e))? + .with_forward_type(forward_type) + .build() + .map_err(|e| format!("Failed to build MNN embedder: {}", e))?; - let img_1 = run_detection( - image1_path, - &mut detector, - &mut embedder, - threshold, - nms_threshold, - 2, - )?; - let img_2 = run_detection( - image2_path, - &mut detector, - &mut embedder, - threshold, - nms_threshold, - 2, - )?; + let img_1 = run_detection( + image1_path, + &mut detector, + &mut embedder, + threshold, + nms_threshold, + 2, + )?; + let img_2 = run_detection( + image2_path, + &mut detector, + &mut embedder, + threshold, + nms_threshold, + 2, + )?; - let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?; - (img_1, img_2, best_similarity) - } - ExecutorType::OnnxCpu => unimplemented!(), - }; + let image1_rois = img_1.rois; + let image2_rois = img_2.rois; + let image1_bbox_len = img_1.bbox.len(); + let image2_bbox_len = img_2.bbox.len(); + let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?; - Ok((faces1.bbox.len(), faces2.bbox.len(), best_similarity)) + ( + image1_bbox_len, + image2_bbox_len, + image1_rois, + image2_rois, + best_similarity, + ) + } + ExecutorType::OnnxCpu => unimplemented!(), + }; + + Ok(( + image1_faces, + image2_faces, + image1_rois, + image2_rois, + best_similarity, + )) } } @@ -308,7 +338,7 @@ where .into_iter() .map(|roi| { roi.as_standard_layout() - .fast_resize(320, 320, &ResizeOptions::default()) + .fast_resize(224, 224, &ResizeOptions::default()) .change_context(Error) }) .collect::>>()?; @@ -322,7 +352,7 @@ where let og_size = chunk.len(); if chunk.len() < chunk_size { tracing::warn!("Chunk size is less than 8, padding with zeros"); - let zeros = Array3::zeros((320, 320, 3)); + let zeros = Array3::zeros((224, 224, 3)); let chunk: Vec<_> = chunk .iter() .map(|arr| arr.reborrow())