feat: Added threshold for scores and nms

This commit is contained in:
uttarayan21
2025-08-05 13:39:15 +05:30
parent bcb7c94390
commit 561fb2a924
4 changed files with 47 additions and 10 deletions

View File

@@ -2,7 +2,6 @@ pub mod draw;
pub mod nms; pub mod nms;
pub mod roi; pub mod roi;
use itertools::Itertools;
use nalgebra::{Point, Point2, Point3, SVector}; use nalgebra::{Point, Point2, Point3, SVector};
pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {} pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {}
impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {} impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {}
@@ -458,6 +457,7 @@ fn test_bounding_box_contains_point() {
let point1 = Point2::new(2, 3); let point1 = Point2::new(2, 3);
let point2 = Point2::new(5, 4); let point2 = Point2::new(5, 4);
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2); let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
use itertools::Itertools;
for (i, j) in (0..=10).cartesian_product(0..=10) { for (i, j) in (0..=10).cartesian_product(0..=10) {
if bbox.contains_point(&Point2::new(i, j)) { if bbox.contains_point(&Point2::new(i, j)) {
if !(2..=5).contains(&i) && !(3..=4).contains(&j) { if !(2..=5).contains(&i) && !(3..=4).contains(&j) {

View File

@@ -47,6 +47,8 @@ pub struct Detect {
pub model_type: Models, pub model_type: Models,
#[clap(short, long)] #[clap(short, long)]
pub output: Option<PathBuf>, pub output: Option<PathBuf>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
pub image: PathBuf, pub image: PathBuf,
} }

View File

@@ -10,6 +10,31 @@ pub struct FaceDetectionConfig {
min_sizes: Vec<Vector2<usize>>, min_sizes: Vec<Vector2<usize>>,
steps: Vec<usize>, steps: Vec<usize>,
variance: Vec<f32>, variance: Vec<f32>,
threshold: f32,
nms_threshold: f32,
}
impl FaceDetectionConfig {
pub fn with_min_sizes(mut self, min_sizes: Vec<Vector2<usize>>) -> Self {
self.min_sizes = min_sizes;
self
}
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
self.steps = steps;
self
}
pub fn with_variance(mut self, variance: Vec<f32>) -> Self {
self.variance = variance;
self
}
pub fn with_threshold(mut self, threshold: f32) -> Self {
self.threshold = threshold;
self
}
pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
self.nms_threshold = nms_threshold;
self
}
} }
impl Default for FaceDetectionConfig { impl Default for FaceDetectionConfig {
@@ -22,6 +47,8 @@ impl Default for FaceDetectionConfig {
], ],
steps: vec![8, 16, 32], steps: vec![8, 16, 32],
variance: vec![0.1, 0.2], variance: vec![0.1, 0.2],
threshold: 0.8,
nms_threshold: 0.6,
} }
} }
} }
@@ -35,8 +62,13 @@ pub struct FaceDetectionModelOutput {
pub landmark: ndarray::Array3<f32>, pub landmark: ndarray::Array3<f32>,
} }
/// Result of `FaceDetectionModelOutput::postprocess`: the detections whose
/// score exceeded the configured threshold.
///
/// `bbox` and `confidence` are filled in lockstep, so entry `i` of one
/// corresponds to entry `i` of the other.
pub struct FaceDetectionProcessedOutput {
/// Bounding box of each retained detection.
pub bbox: Vec<Aabb2<f32>>,
/// Score of each retained detection, in the same order as `bbox`.
pub confidence: Vec<f32>,
}
impl FaceDetectionModelOutput { impl FaceDetectionModelOutput {
pub fn postprocess(self, config: FaceDetectionConfig) -> Result<Vec<Aabb2<f32>>> { pub fn postprocess(self, config: FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
let mut anchors = Vec::new(); let mut anchors = Vec::new();
for (k, &step) in config.steps.iter().enumerate() { for (k, &step) in config.steps.iter().enumerate() {
let feature_size = 640 / step; let feature_size = 640 / step;
@@ -54,6 +86,7 @@ impl FaceDetectionModelOutput {
} }
} }
let mut boxes = Vec::new(); let mut boxes = Vec::new();
let mut scores = Vec::new();
let var0 = config.variance[0]; let var0 = config.variance[0];
let var1 = config.variance[1]; let var1 = config.variance[1];
let bbox_data = self.bbox; let bbox_data = self.bbox;
@@ -74,14 +107,15 @@ impl FaceDetectionModelOutput {
let x_max = pred_cx + pred_w / 2.0; let x_max = pred_cx + pred_w / 2.0;
let y_max = pred_cy + pred_h / 2.0; let y_max = pred_cy + pred_h / 2.0;
let score = conf_data[[0, idx, 1]]; let score = conf_data[[0, idx, 1]];
if score > 0.6 { if score > config.threshold {
boxes.push(Aabb2::from_min_max_vertices( boxes.push(Aabb2::from_x1y1x2y2(x_min, y_min, x_max, y_max));
Point2::new(x_min, y_min), scores.push(score);
Point2::new(x_max, y_max),
));
} }
} }
Ok(boxes) Ok(FaceDetectionProcessedOutput {
bbox: boxes,
confidence: scores,
})
} }
} }

View File

@@ -1,5 +1,6 @@
mod cli; mod cli;
mod errors; mod errors;
use detector::facedet::retinaface::FaceDetectionConfig;
use errors::*; use errors::*;
use ndarray_image::*; use ndarray_image::*;
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn"); const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
@@ -29,11 +30,11 @@ pub fn main() -> Result<()> {
.attach_printable("Failed to detect faces")?; .attach_printable("Failed to detect faces")?;
// output.print(20); // output.print(20);
let aabbs = output let aabbs = output
.postprocess(Default::default()) .postprocess(FaceDetectionConfig::default().with_threshold(detect.threshold))
.change_context(errors::Error) .change_context(errors::Error)
.attach_printable("Failed to attach context")?; .attach_printable("Failed to attach context")?;
for bbox in aabbs { for bbox in aabbs {
println!("Detected face: {:?}", bbox); tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*; use bounding_box::draw::*;
let bbox = bbox let bbox = bbox
.denormalize(nalgebra::SVector::<f32, 2>::new( .denormalize(nalgebra::SVector::<f32, 2>::new(