feat: Added nms

This commit is contained in:
uttarayan21
2025-08-05 15:36:47 +05:30
parent 42ac210bba
commit 06fb0b4487
5 changed files with 85 additions and 25 deletions

13
Cargo.lock generated
View File

@@ -211,7 +211,7 @@ dependencies = [
"bitflags 2.9.1", "bitflags 2.9.1",
"cexpr", "cexpr",
"clang-sys", "clang-sys",
"itertools 0.12.1", "itertools 0.13.0",
"log", "log",
"prettyplease", "prettyplease",
"proc-macro2", "proc-macro2",
@@ -504,6 +504,7 @@ dependencies = [
"error-stack", "error-stack",
"fast_image_resize", "fast_image_resize",
"image", "image",
"itertools 0.14.0",
"linfa", "linfa",
"mnn", "mnn",
"mnn-bridge", "mnn-bridge",
@@ -512,6 +513,7 @@ dependencies = [
"ndarray 0.16.1", "ndarray 0.16.1",
"ndarray-image", "ndarray-image",
"ndarray-resize", "ndarray-resize",
"ordered-float",
"rusqlite", "rusqlite",
"tap", "tap",
"thiserror 2.0.12", "thiserror 2.0.12",
@@ -1487,6 +1489,15 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]]
name = "ordered-float"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"

View File

@@ -51,6 +51,8 @@ mnn-bridge = { workspace = true }
mnn-sync = { workspace = true } mnn-sync = { workspace = true }
bounding-box = { version = "0.1.0", path = "bounding-box" } bounding-box = { version = "0.1.0", path = "bounding-box" }
color = "0.3.1" color = "0.3.1"
itertools = "0.14.0"
ordered-float = "5.0.0"
[profile.release] [profile.release]
debug = true debug = true

View File

@@ -57,12 +57,17 @@ impl Drawable<Array3<u8>> for Aabb2<usize> {
lines.into_iter().for_each(|line| { lines.into_iter().for_each(|line| {
canvas canvas
.roi_mut(line) .roi_mut(line)
.expect("Failed to get Roi") .map(|mut line| {
.lanes_mut(ndarray::Axis(2)) line.lanes_mut(ndarray::Axis(2))
.into_iter() .into_iter()
.for_each(|mut pixel| { .for_each(|mut pixel| {
pixel.assign(&color); pixel.assign(&color);
}); })
})
.inspect_err(|e| {
dbg!(e);
})
.ok();
}); });
} }
} }

View File

@@ -1,5 +1,5 @@
use crate::errors::*; use crate::errors::*;
use bounding_box::Aabb2; use bounding_box::{Aabb2, nms::nms};
use error_stack::ResultExt; use error_stack::ResultExt;
use mnn_bridge::ndarray::*; use mnn_bridge::ndarray::*;
use nalgebra::{Point2, Vector2}; use nalgebra::{Point2, Vector2};
@@ -56,6 +56,7 @@ pub struct FaceDetection {
handle: mnn_sync::SessionHandle, handle: mnn_sync::SessionHandle,
} }
#[derive(Debug, Clone, PartialEq)]
pub struct FaceDetectionModelOutput { pub struct FaceDetectionModelOutput {
pub bbox: ndarray::Array3<f32>, pub bbox: ndarray::Array3<f32>,
pub confidence: ndarray::Array3<f32>, pub confidence: ndarray::Array3<f32>,
@@ -63,6 +64,7 @@ pub struct FaceDetectionModelOutput {
} }
/// Represents the 5 facial landmarks detected by RetinaFace /// Represents the 5 facial landmarks detected by RetinaFace
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct FaceLandmarks { pub struct FaceLandmarks {
pub left_eye: Point2<f32>, pub left_eye: Point2<f32>,
pub right_eye: Point2<f32>, pub right_eye: Point2<f32>,
@@ -70,14 +72,23 @@ pub struct FaceLandmarks {
pub left_mouth: Point2<f32>, pub left_mouth: Point2<f32>,
pub right_mouth: Point2<f32>, pub right_mouth: Point2<f32>,
} }
#[derive(Debug, Clone, PartialEq)]
pub struct FaceDetectionProcessedOutput { pub struct FaceDetectionProcessedOutput {
pub bbox: Vec<Aabb2<f32>>, pub bbox: Vec<Aabb2<f32>>,
pub confidence: Vec<f32>, pub confidence: Vec<f32>,
pub landmarks: Vec<FaceLandmarks>, pub landmarks: Vec<FaceLandmarks>,
} }
/// Final result returned by `FaceDetection::detect_faces`.
///
/// NOTE(review): in `detect_faces` the `bbox` field is built from the boxes returned by `nms`,
/// while `confidence` and `landmark` are moved straight out of the postprocessed output. The
/// three vectors may therefore differ in length and ordering — confirm before indexing them
/// in lockstep.
#[derive(Debug, Clone, PartialEq)]
pub struct FaceDetectionOutput {
/// Bounding boxes with `usize` coordinates (scaled by the input image width/height in
/// `detect_faces`).
pub bbox: Vec<Aabb2<usize>>,
/// Confidence scores as produced by postprocessing.
pub confidence: Vec<f32>,
/// Facial landmarks as produced by postprocessing.
// NOTE(review): named `landmark` here but `landmarks` on `FaceDetectionProcessedOutput`.
pub landmark: Vec<FaceLandmarks>,
}
impl FaceDetectionModelOutput { impl FaceDetectionModelOutput {
pub fn postprocess(self, config: FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> { pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
let mut anchors = Vec::new(); let mut anchors = Vec::new();
for (k, &step) in config.steps.iter().enumerate() { for (k, &step) in config.steps.iter().enumerate() {
let feature_size = 640 / step; let feature_size = 640 / step;
@@ -204,7 +215,48 @@ impl FaceDetection {
Ok(FaceDetection { handle }) Ok(FaceDetection { handle })
} }
pub fn detect_faces(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> { pub fn detect_faces(
&self,
image: ndarray::Array3<u8>,
config: FaceDetectionConfig,
) -> Result<FaceDetectionOutput> {
let (height, width, channels) = image.dim();
let output = self
.run_models(image)
.change_context(Error)
.attach_printable("Failed to detect faces")?;
// denormalize the bounding boxes
let factor = Vector2::new(width as f32, height as f32);
let mut processed = output
.postprocess(&config)
.attach_printable("Failed to postprocess")?;
use itertools::Itertools;
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
.bbox
.iter()
.cloned()
.zip(processed.confidence.iter().cloned())
.zip(processed.landmarks.iter().cloned())
.sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score))
.map(|((b, s), l)| (b, s, l))
.multiunzip();
let boxes = nms(&boxes, &scores, config.threshold, config.nms_threshold);
let bboxes = boxes
.into_iter()
.flat_map(|x| x.denormalize(factor).try_cast::<usize>())
.collect();
Ok(FaceDetectionOutput {
bbox: bboxes,
confidence: processed.confidence,
landmark: processed.landmarks,
})
}
pub fn run_models(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
#[rustfmt::skip] #[rustfmt::skip]
use ::tap::*; use ::tap::*;
let output = self let output = self

View File

@@ -25,25 +25,15 @@ pub fn main() -> Result<()> {
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?; .attach_printable("Failed to convert image to ndarray")?;
let output = model let output = model
.detect_faces(array.clone()) .detect_faces(
array.clone(),
FaceDetectionConfig::default().with_threshold(detect.threshold),
)
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to detect faces")?; .attach_printable("Failed to detect faces")?;
// output.print(20); for bbox in output.bbox {
let faces = output
.postprocess(FaceDetectionConfig::default().with_threshold(detect.threshold))
.change_context(errors::Error)
.attach_printable("Failed to attach context")?;
for bbox in faces.bbox {
tracing::info!("Detected face: {:?}", bbox); tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*; use bounding_box::draw::*;
let bbox = bbox
.denormalize(nalgebra::SVector::<f32, 2>::new(
array.shape()[1] as f32,
array.shape()[0] as f32,
))
.try_cast()
.ok_or(errors::Error)
.attach_printable("Failed to cast f32 to usize")?;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 10); array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 10);
} }
let v = array.view(); let v = array.view();