From e91ae5b8658ef7ab33ff44e8f0ca8b2162a3d835 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Thu, 7 Aug 2025 15:45:54 +0530 Subject: [PATCH] feat: Added a manual implementation of nms --- Cargo.lock | 1 + bounding-box/Cargo.toml | 1 + bounding-box/src/lib.rs | 27 ++++++------ bounding-box/src/nms.rs | 93 +++++++++++++---------------------------- src/main.rs | 4 +- 5 files changed, 46 insertions(+), 80 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index de2bba7..05c0241 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,6 +255,7 @@ dependencies = [ "nalgebra", "ndarray 0.16.1", "num", + "ordered-float", "simba", "thiserror 2.0.12", ] diff --git a/bounding-box/Cargo.toml b/bounding-box/Cargo.toml index a3f5525..3169ed9 100644 --- a/bounding-box/Cargo.toml +++ b/bounding-box/Cargo.toml @@ -9,6 +9,7 @@ itertools = "0.14.0" nalgebra = "0.33.2" ndarray = { version = "0.16.1", optional = true } num = "0.4.3" +ordered-float = "5.0.0" simba = "0.9.0" thiserror = "2.0.12" diff --git a/bounding-box/src/lib.rs b/bounding-box/src/lib.rs index 8df0dd5..a60cf6b 100644 --- a/bounding-box/src/lib.rs +++ b/bounding-box/src/lib.rs @@ -51,10 +51,12 @@ pub type Aabb3 = AxisAlignedBoundingBox; impl AxisAlignedBoundingBox { // Panics if max < min pub fn new(min_point: Point, max_point: Point) -> Self { - if max_point < min_point { + if max_point >= min_point { + Self::from_min_max_vertices(min_point, max_point) + } else { + dbg!(max_point, min_point); panic!("max_point must be greater than or equal to min_point"); } - Self::from_min_max_vertices(min_point, max_point) } pub fn try_new(min_point: Point, max_point: Point) -> Option { if max_point < min_point { @@ -66,9 +68,9 @@ impl AxisAlignedBoundingBox { Self { point, size } } - pub fn from_min_max_vertices(point1: Point, point2: Point) -> Self { - let size = point2 - point1; - Self::new_point_size(point1, SVector::from(size)) + pub fn from_min_max_vertices(min: Point, max: Point) -> Self { + let size = max - min; + Self::new_point_size(min, SVector::from(size)) } /// Only considers the points closest and furthest from origin @@ -301,11 +303,11 @@ impl AxisAlignedBoundingBox { let inter_min = lhs_min.sup(&rhs_min); let inter_max = lhs_max.inf(&rhs_max); - if inter_max < inter_min { - return T::zero(); - } else { + if inter_max >= inter_min { let intersection = Aabb::new(inter_min, inter_max).measure(); intersection / (self.measure() + other.measure() - intersection) + } else { + return T::zero(); } } } @@ -605,11 +607,8 @@ mod boudning_box_tests { #[test] fn test_specific_values() { - let res = Vector2::new(1920, 1080).cast(); - let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264).denormalize(res); - let box2 = - Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436).denormalize(res); - dbg!(box1, box2); - assert!(box1.iou(&box2) > 0.0); + let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264); + let box2 = Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436); + assert!(box1.iou(&box2) >= 0.0); } } diff --git a/bounding-box/src/nms.rs b/bounding-box/src/nms.rs index 985921e..18838fe 100644 --- a/bounding-box/src/nms.rs +++ b/bounding-box/src/nms.rs @@ -1,4 +1,6 @@ -use std::collections::HashSet; +use std::collections::{HashSet, VecDeque}; + +use itertools::Itertools; use crate::*; /// Apply Non-Maximum Suppression to a set of bounding boxes. @@ -21,7 +23,8 @@ pub fn nms( ) -> HashSet where T: Num - + num::Float + + ordered_float::FloatCore + + core::ops::Neg + core::iter::Product + core::ops::AddAssign + core::ops::SubAssign @@ -29,71 +32,31 @@ where + nalgebra::SimdValue + nalgebra::SimdPartialOrd, { - let mut indices: Vec = (0..boxes.len()) - .filter(|&i| scores[i] >= score_threshold) + let mut combined: VecDeque<(usize, Aabb2, T, bool)> = boxes + .iter() + .enumerate() + .zip(scores) + .filter_map(|((idx, bbox), score)| { + (*score > score_threshold).then_some((idx, *bbox, *score, true)) + }) + .sorted_by_cached_key(|(_, _, score, _)| -ordered_float::OrderedFloat(*score)) .collect(); - indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).unwrap()); - - let mut selected_indices = HashSet::new(); - - while let Some(¤t) = indices.first() { - selected_indices.insert(current); - indices.remove(0); - - indices.retain(|&i| { - let iou = calculate_iou(&boxes[current], &boxes[i]); - let iou_ = boxes[current].iou(&boxes[i]); - if iou != iou_ { - dbg!(boxes[current], boxes[i]); - panic!() - }; - iou < nms_threshold - }); + for i in 0..combined.len() { + let first = combined[i]; + if first.3 == false { + continue; + } + let bbox = first.1; + for item in combined.iter_mut().skip(i + 1) { + if bbox.iou(&item.1) > nms_threshold { + item.3 = false + } + } } - selected_indices -} - -/// Calculate the Intersection over Union (IoU) of two bounding boxes. -/// -/// # Arguments -/// -/// * `box1` - The first bounding box. -/// * `box2` - The second bounding box. -/// -/// # Returns -/// -/// The IoU as a value between 0 and 1. -fn calculate_iou(box1: &Aabb2, box2: &Aabb2) -> T -where - T: Num - + num::Float - + core::iter::Product - + core::ops::AddAssign - + core::ops::SubAssign - + core::ops::MulAssign - + nalgebra::SimdValue - + nalgebra::SimdPartialOrd, -{ - let x_left = box1.min_vertex().x.max(box2.min_vertex().x); - let y_top = box1.min_vertex().y.max(box2.min_vertex().y); - let x_right = box1.max_vertex().x.min(box2.max_vertex().x); - let y_bottom = box1.max_vertex().y.min(box2.max_vertex().y); - - let zero = T::zero(); - let inter_width = (x_right - x_left).max(zero); - let inter_height = (y_bottom - y_top).max(zero); - let intersection = inter_width * inter_height; - - let area1 = box1.area(); - let area2 = box2.area(); - - let union = area1 + area2 - intersection; - - if union > zero { - intersection / union - } else { - zero - } + combined + .into_iter() + .filter_map(|(idx, _, _, keep)| keep.then_some(idx)) + .collect() } diff --git a/src/main.rs b/src/main.rs index cfef0f3..b389049 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,7 +27,9 @@ pub fn main() -> Result<()> { let output = model .detect_faces( array.clone(), - FaceDetectionConfig::default().with_threshold(detect.threshold), + FaceDetectionConfig::default() + .with_threshold(detect.threshold) + .with_nms_threshold(detect.nms_threshold), ) .change_context(errors::Error) .attach_printable("Failed to detect faces")?;