diff --git a/Cargo.lock b/Cargo.lock index 88537b3..de2bba7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,7 +211,7 @@ dependencies = [ "bitflags 2.9.1", "cexpr", "clang-sys", - "itertools 0.12.1", + "itertools 0.13.0", "log", "prettyplease", "proc-macro2", @@ -504,6 +504,7 @@ dependencies = [ "error-stack", "fast_image_resize", "image", + "itertools 0.14.0", "linfa", "mnn", "mnn-bridge", @@ -512,6 +513,7 @@ dependencies = [ "ndarray 0.16.1", "ndarray-image", "ndarray-resize", + "ordered-float", "rusqlite", "tap", "thiserror 2.0.12", @@ -1487,6 +1489,15 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" +[[package]] +name = "ordered-float" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index eb6a62e..ea89789 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,8 @@ mnn-bridge = { workspace = true } mnn-sync = { workspace = true } bounding-box = { version = "0.1.0", path = "bounding-box" } color = "0.3.1" +itertools = "0.14.0" +ordered-float = "5.0.0" [profile.release] debug = true diff --git a/bounding-box/src/draw.rs b/bounding-box/src/draw.rs index 6d8e044..1f97b2a 100644 --- a/bounding-box/src/draw.rs +++ b/bounding-box/src/draw.rs @@ -57,12 +57,17 @@ impl Drawable> for Aabb2 { lines.into_iter().for_each(|line| { canvas .roi_mut(line) - .expect("Failed to get Roi") - .lanes_mut(ndarray::Axis(2)) - .into_iter() - .for_each(|mut pixel| { - pixel.assign(&color); - }); + .map(|mut line| { + line.lanes_mut(ndarray::Axis(2)) + .into_iter() + .for_each(|mut pixel| { + pixel.assign(&color); + }) + }) + .inspect_err(|e| { + dbg!(e); + }) + .ok(); }); } } diff --git a/src/facedet/retinaface.rs b/src/facedet/retinaface.rs index d0afe5d..6f36d24 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/retinaface.rs @@ -1,5 +1,5 @@ use crate::errors::*; -use bounding_box::Aabb2; +use bounding_box::{Aabb2, nms::nms}; use error_stack::ResultExt; use mnn_bridge::ndarray::*; use nalgebra::{Point2, Vector2}; @@ -56,6 +56,7 @@ pub struct FaceDetection { handle: mnn_sync::SessionHandle, } +#[derive(Debug, Clone, PartialEq)] pub struct FaceDetectionModelOutput { pub bbox: ndarray::Array3, pub confidence: ndarray::Array3, @@ -63,6 +64,7 @@ pub struct FaceDetectionModelOutput { } /// Represents the 5 facial landmarks detected by RetinaFace +#[derive(Debug, Copy, Clone, PartialEq)] pub struct FaceLandmarks { pub left_eye: Point2, pub right_eye: Point2, @@ -70,14 +72,23 @@ pub struct FaceLandmarks { pub left_mouth: Point2, pub right_mouth: Point2, } + +#[derive(Debug, Clone, PartialEq)] pub struct FaceDetectionProcessedOutput { pub bbox: Vec>, pub confidence: Vec, pub landmarks: Vec, } +#[derive(Debug, Clone, PartialEq)] +pub struct FaceDetectionOutput { + pub bbox: Vec>, + pub confidence: Vec, + pub landmark: Vec, +} + impl FaceDetectionModelOutput { - pub fn postprocess(self, config: FaceDetectionConfig) -> Result { + pub fn postprocess(self, config: &FaceDetectionConfig) -> Result { let mut anchors = Vec::new(); for (k, &step) in config.steps.iter().enumerate() { let feature_size = 640 / step; @@ -204,7 +215,48 @@ impl FaceDetection { Ok(FaceDetection { handle }) } - pub fn detect_faces(&self, image: ndarray::Array3) -> Result { + pub fn detect_faces( + &self, + image: ndarray::Array3, + config: FaceDetectionConfig, + ) -> Result { + let (height, width, channels) = image.dim(); + let output = self + .run_models(image) + .change_context(Error) + .attach_printable("Failed to detect faces")?; + // denormalize the bounding boxes + let factor = Vector2::new(width as f32, height as f32); + let mut processed = output + .postprocess(&config) + .attach_printable("Failed to postprocess")?; + + use itertools::Itertools; + let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed + .bbox + .iter() + .cloned() + .zip(processed.confidence.iter().cloned()) + .zip(processed.landmarks.iter().cloned()) + .sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score)) + .map(|((b, s), l)| (b, s, l)) + .multiunzip(); + + let boxes = nms(&boxes, &scores, config.threshold, config.nms_threshold); + + let bboxes = boxes + .into_iter() + .flat_map(|x| x.denormalize(factor).try_cast::()) + .collect(); + + Ok(FaceDetectionOutput { + bbox: bboxes, + confidence: processed.confidence, + landmark: processed.landmarks, + }) + } + + pub fn run_models(&self, image: ndarray::Array3) -> Result { #[rustfmt::skip] use ::tap::*; let output = self diff --git a/src/main.rs b/src/main.rs index 56973bb..6fe5ad8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,25 +25,15 @@ pub fn main() -> Result<()> { .change_context(errors::Error) .attach_printable("Failed to convert image to ndarray")?; let output = model - .detect_faces(array.clone()) + .detect_faces( + array.clone(), + FaceDetectionConfig::default().with_threshold(detect.threshold), + ) .change_context(errors::Error) .attach_printable("Failed to detect faces")?; - // output.print(20); - let faces = output - .postprocess(FaceDetectionConfig::default().with_threshold(detect.threshold)) - .change_context(errors::Error) - .attach_printable("Failed to attach context")?; - for bbox in faces.bbox { + for bbox in output.bbox { tracing::info!("Detected face: {:?}", bbox); use bounding_box::draw::*; - let bbox = bbox - .denormalize(nalgebra::SVector::::new( - array.shape()[1] as f32, - array.shape()[0] as f32, - )) - .try_cast() - .ok_or(errors::Error) - .attach_printable("Failed to cast f32 to usize")?; array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 10); } let v = array.view();