feat: Added nms
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use crate::errors::*;
|
||||
use bounding_box::Aabb2;
|
||||
use bounding_box::{Aabb2, nms::nms};
|
||||
use error_stack::ResultExt;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use nalgebra::{Point2, Vector2};
|
||||
@@ -56,6 +56,7 @@ pub struct FaceDetection {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionModelOutput {
|
||||
pub bbox: ndarray::Array3<f32>,
|
||||
pub confidence: ndarray::Array3<f32>,
|
||||
@@ -63,6 +64,7 @@ pub struct FaceDetectionModelOutput {
|
||||
}
|
||||
|
||||
/// Represents the 5 facial landmarks detected by RetinaFace
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct FaceLandmarks {
|
||||
pub left_eye: Point2<f32>,
|
||||
pub right_eye: Point2<f32>,
|
||||
@@ -70,14 +72,23 @@ pub struct FaceLandmarks {
|
||||
pub left_mouth: Point2<f32>,
|
||||
pub right_mouth: Point2<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionProcessedOutput {
|
||||
pub bbox: Vec<Aabb2<f32>>,
|
||||
pub confidence: Vec<f32>,
|
||||
pub landmarks: Vec<FaceLandmarks>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionOutput {
|
||||
pub bbox: Vec<Aabb2<usize>>,
|
||||
pub confidence: Vec<f32>,
|
||||
pub landmark: Vec<FaceLandmarks>,
|
||||
}
|
||||
|
||||
impl FaceDetectionModelOutput {
|
||||
pub fn postprocess(self, config: FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
||||
pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
||||
let mut anchors = Vec::new();
|
||||
for (k, &step) in config.steps.iter().enumerate() {
|
||||
let feature_size = 640 / step;
|
||||
@@ -204,7 +215,48 @@ impl FaceDetection {
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
|
||||
pub fn detect_faces(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
pub fn detect_faces(
|
||||
&self,
|
||||
image: ndarray::Array3<u8>,
|
||||
config: FaceDetectionConfig,
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
let (height, width, channels) = image.dim();
|
||||
let output = self
|
||||
.run_models(image)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
// denormalize the bounding boxes
|
||||
let factor = Vector2::new(width as f32, height as f32);
|
||||
let mut processed = output
|
||||
.postprocess(&config)
|
||||
.attach_printable("Failed to postprocess")?;
|
||||
|
||||
use itertools::Itertools;
|
||||
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
|
||||
.bbox
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(processed.confidence.iter().cloned())
|
||||
.zip(processed.landmarks.iter().cloned())
|
||||
.sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score))
|
||||
.map(|((b, s), l)| (b, s, l))
|
||||
.multiunzip();
|
||||
|
||||
let boxes = nms(&boxes, &scores, config.threshold, config.nms_threshold);
|
||||
|
||||
let bboxes = boxes
|
||||
.into_iter()
|
||||
.flat_map(|x| x.denormalize(factor).try_cast::<usize>())
|
||||
.collect();
|
||||
|
||||
Ok(FaceDetectionOutput {
|
||||
bbox: bboxes,
|
||||
confidence: processed.confidence,
|
||||
landmark: processed.landmarks,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run_models(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
#[rustfmt::skip]
|
||||
use ::tap::*;
|
||||
let output = self
|
||||
|
||||
20
src/main.rs
20
src/main.rs
@@ -25,25 +25,15 @@ pub fn main() -> Result<()> {
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
let output = model
|
||||
.detect_faces(array.clone())
|
||||
.detect_faces(
|
||||
array.clone(),
|
||||
FaceDetectionConfig::default().with_threshold(detect.threshold),
|
||||
)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
// output.print(20);
|
||||
let faces = output
|
||||
.postprocess(FaceDetectionConfig::default().with_threshold(detect.threshold))
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to attach context")?;
|
||||
for bbox in faces.bbox {
|
||||
for bbox in output.bbox {
|
||||
tracing::info!("Detected face: {:?}", bbox);
|
||||
use bounding_box::draw::*;
|
||||
let bbox = bbox
|
||||
.denormalize(nalgebra::SVector::<f32, 2>::new(
|
||||
array.shape()[1] as f32,
|
||||
array.shape()[0] as f32,
|
||||
))
|
||||
.try_cast()
|
||||
.ok_or(errors::Error)
|
||||
.attach_printable("Failed to cast f32 to usize")?;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 10);
|
||||
}
|
||||
let v = array.view();
|
||||
|
||||
Reference in New Issue
Block a user