From 561fb2a9244bae14223afdbfa5618f6f9a4aa6c8 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 5 Aug 2025 13:39:15 +0530 Subject: [PATCH] feat: Added threshold for scores and nms --- bounding-box/src/lib.rs | 2 +- src/cli.rs | 2 ++ src/facedet/retinaface.rs | 48 +++++++++++++++++++++++++++++++++------ src/main.rs | 5 ++-- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/bounding-box/src/lib.rs b/bounding-box/src/lib.rs index cde1095..a91262f 100644 --- a/bounding-box/src/lib.rs +++ b/bounding-box/src/lib.rs @@ -2,7 +2,6 @@ pub mod draw; pub mod nms; pub mod roi; -use itertools::Itertools; use nalgebra::{Point, Point2, Point3, SVector}; pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {} impl Num for T {} @@ -458,6 +457,7 @@ fn test_bounding_box_contains_point() { let point1 = Point2::new(2, 3); let point2 = Point2::new(5, 4); let bbox = AxisAlignedBoundingBox::new_2d(point1, point2); + use itertools::Itertools; for (i, j) in (0..=10).cartesian_product(0..=10) { if bbox.contains_point(&Point2::new(i, j)) { if !(2..=5).contains(&i) && !(3..=4).contains(&j) { diff --git a/src/cli.rs b/src/cli.rs index 0be2112..3e88032 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -47,6 +47,8 @@ pub struct Detect { pub model_type: Models, #[clap(short, long)] pub output: Option, + #[clap(short, long, default_value_t = 0.8)] + pub threshold: f32, pub image: PathBuf, } diff --git a/src/facedet/retinaface.rs b/src/facedet/retinaface.rs index 40c12e9..e8ca9c7 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/retinaface.rs @@ -10,6 +10,31 @@ pub struct FaceDetectionConfig { min_sizes: Vec>, steps: Vec, variance: Vec, + threshold: f32, + nms_threshold: f32, +} + +impl FaceDetectionConfig { + pub fn with_min_sizes(mut self, min_sizes: Vec>) -> Self { + self.min_sizes = min_sizes; + self + } + pub fn with_steps(mut self, steps: Vec) -> Self { + self.steps = steps; + self + } + pub fn with_variance(mut self, variance: Vec) -> Self { + self.variance = variance; + self + } + pub fn with_threshold(mut self, threshold: f32) -> Self { + self.threshold = threshold; + self + } + pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self { + self.nms_threshold = nms_threshold; + self + } } impl Default for FaceDetectionConfig { @@ -22,6 +47,8 @@ impl Default for FaceDetectionConfig { ], steps: vec![8, 16, 32], variance: vec![0.1, 0.2], + threshold: 0.8, + nms_threshold: 0.6, } } } @@ -35,8 +62,13 @@ pub struct FaceDetectionModelOutput { pub landmark: ndarray::Array3, } +pub struct FaceDetectionProcessedOutput { + pub bbox: Vec>, + pub confidence: Vec, +} + impl FaceDetectionModelOutput { - pub fn postprocess(self, config: FaceDetectionConfig) -> Result>> { + pub fn postprocess(self, config: FaceDetectionConfig) -> Result { let mut anchors = Vec::new(); for (k, &step) in config.steps.iter().enumerate() { let feature_size = 640 / step; @@ -54,6 +86,7 @@ impl FaceDetectionModelOutput { } } let mut boxes = Vec::new(); + let mut scores = Vec::new(); let var0 = config.variance[0]; let var1 = config.variance[1]; let bbox_data = self.bbox; @@ -74,14 +107,15 @@ impl FaceDetectionModelOutput { let x_max = pred_cx + pred_w / 2.0; let y_max = pred_cy + pred_h / 2.0; let score = conf_data[[0, idx, 1]]; - if score > 0.6 { - boxes.push(Aabb2::from_min_max_vertices( - Point2::new(x_min, y_min), - Point2::new(x_max, y_max), - )); + if score > config.threshold { + boxes.push(Aabb2::from_x1y1x2y2(x_min, y_min, x_max, y_max)); + scores.push(score); } } - Ok(boxes) + Ok(FaceDetectionProcessedOutput { + bbox: boxes, + confidence: scores, + }) } } diff --git a/src/main.rs b/src/main.rs index 616d1c1..ab5922a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod cli; mod errors; +use detector::facedet::retinaface::FaceDetectionConfig; use errors::*; use ndarray_image::*; const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn"); @@ -29,11 +30,11 @@ pub fn main() -> Result<()> { .attach_printable("Failed to detect faces")?; // output.print(20); let aabbs = output - .postprocess(Default::default()) + .postprocess(FaceDetectionConfig::default().with_threshold(detect.threshold)) .change_context(errors::Error) .attach_printable("Failed to attach context")?; for bbox in aabbs { - println!("Detected face: {:?}", bbox); + tracing::info!("Detected face: {:?}", bbox); use bounding_box::draw::*; let bbox = bbox .denormalize(nalgebra::SVector::::new(