From 043a845fc14c1b064d0806b13617b461532edfbe Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 5 Aug 2025 18:14:31 +0530 Subject: [PATCH] feat: Remove bbox crate and use 1024 for image size --- bbox/Cargo.toml | 13 - bbox/src/lib.rs | 708 -------------------------------------- bbox/src/traits.rs | 2 - bbox/src/traits/max.rs | 27 -- bbox/src/traits/min.rs | 27 -- bounding-box/src/draw.rs | 5 +- bounding-box/src/nms.rs | 93 ++--- src/cli.rs | 3 + src/facedet/retinaface.rs | 48 ++- src/main.rs | 2 +- 10 files changed, 89 insertions(+), 839 deletions(-) delete mode 100644 bbox/Cargo.toml delete mode 100644 bbox/src/lib.rs delete mode 100644 bbox/src/traits.rs delete mode 100644 bbox/src/traits/max.rs delete mode 100644 bbox/src/traits/min.rs diff --git a/bbox/Cargo.toml b/bbox/Cargo.toml deleted file mode 100644 index 5f54ac3..0000000 --- a/bbox/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "bbox" -version = "0.1.0" -edition = "2024" - -[dependencies] -ndarray = "0.16" -num = "0.4.3" -serde = { version = "1", features = ["derive"], optional = true } - -[features] -serde = ["dep:serde"] -default = ["serde"] diff --git a/bbox/src/lib.rs b/bbox/src/lib.rs deleted file mode 100644 index 0bba68f..0000000 --- a/bbox/src/lib.rs +++ /dev/null @@ -1,708 +0,0 @@ -pub mod traits; - -/// A bounding box of co-ordinates whose origin is at the top-left corner. -#[derive( - Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Hash, serde::Serialize, serde::Deserialize, -)] -#[non_exhaustive] -pub struct BBox { - pub x: T, - pub y: T, - pub width: T, - pub height: T, -} - -impl From<[T; 4]> for BBox { - fn from([x, y, width, height]: [T; 4]) -> Self { - Self { - x, - y, - width, - height, - } - } -} - -impl BBox { - pub fn new(x: T, y: T, width: T, height: T) -> Self { - Self { - x, - y, - width, - height, - } - } - - /// Casts the internal values to another type using [as] keyword - pub fn cast(self) -> BBox - where - T: num::cast::AsPrimitive, - T2: Copy + 'static, - { - BBox { - x: self.x.as_(), - y: self.y.as_(), - width: self.width.as_(), - height: self.height.as_(), - } - } - - /// Clamps all the internal values to the given min and max. - pub fn clamp(&self, min: T, max: T) -> Self - where - T: std::cmp::PartialOrd, - { - Self { - x: num::clamp(self.x, min, max), - y: num::clamp(self.y, min, max), - width: num::clamp(self.width, min, max), - height: num::clamp(self.height, min, max), - } - } - - pub fn clamp_box(&self, bbox: BBox) -> Self - where - T: std::cmp::PartialOrd, - T: num::Zero, - T: core::ops::Add, - T: core::ops::Sub, - { - let x1 = num::clamp(self.x1(), bbox.x1(), bbox.x2()); - let y1 = num::clamp(self.y1(), bbox.y1(), bbox.y2()); - let x2 = num::clamp(self.x2(), bbox.x1(), bbox.x2()); - let y2 = num::clamp(self.y2(), bbox.y1(), bbox.y2()); - Self::new_xyxy(x1, y1, x2, y2) - } - - pub fn normalize(&self, width: T, height: T) -> Self - where - T: core::ops::Div + Copy, - { - Self { - x: self.x / width, - y: self.y / height, - width: self.width / width, - height: self.height / height, - } - } - - /// Normalize after casting to float - pub fn normalize_f64(&self, width: T, height: T) -> BBox - where - T: core::ops::Div + Copy, - T: num::cast::AsPrimitive, - { - BBox { - x: self.x.as_() / width.as_(), - y: self.y.as_() / height.as_(), - width: self.width.as_() / width.as_(), - height: self.height.as_() / height.as_(), - } - } - - pub fn denormalize(&self, width: T, height: T) -> Self - where - T: core::ops::Mul + Copy, - { - Self { - x: self.x * width, - y: self.y * height, - width: self.width * width, - height: self.height * height, - } - } - - pub fn height(&self) -> T { - self.height - } - - pub fn width(&self) -> T { - self.width - } - - pub fn padding(&self, padding: T) -> Self - where - T: core::ops::Add + core::ops::Sub + Copy, - { - Self { - x: self.x - padding, - y: self.y - padding, - width: self.width + padding + padding, - height: self.height + padding + padding, - } - } - - pub fn padding_height(&self, padding: T) -> Self - where - T: core::ops::Add + core::ops::Sub + Copy, - { - Self { - x: self.x, - y: self.y - padding, - width: self.width, - height: self.height + padding + padding, - } - } - - pub fn padding_width(&self, padding: T) -> Self - where - T: core::ops::Add + core::ops::Sub + Copy, - { - Self { - x: self.x - padding, - y: self.y, - width: self.width + padding + padding, - height: self.height, - } - } - - // Enlarge / shrink the bounding box by a factor while - // keeping the center point and the aspect ratio fixed - pub fn scale(&self, factor: T) -> Self - where - T: core::ops::Mul, - T: core::ops::Sub, - T: core::ops::Add, - T: core::ops::Div, - T: num::One + Copy, - { - let two = num::one::() + num::one::(); - let width = self.width * factor; - let height = self.height * factor; - let width_inc = width - self.width; - let height_inc = height - self.height; - Self { - x: self.x - width_inc / two, - y: self.y - height_inc / two, - width, - height, - } - } - - pub fn scale_x(&self, factor: T) -> Self - where - T: core::ops::Mul - + core::ops::Sub - + core::ops::Add - + core::ops::Div - + num::One - + Copy, - { - let two = num::one::() + num::one::(); - let width = self.width * factor; - let width_inc = width - self.width; - Self { - x: self.x - width_inc / two, - y: self.y, - width, - height: self.height, - } - } - - pub fn scale_y(&self, factor: T) -> Self - where - T: core::ops::Mul - + core::ops::Sub - + core::ops::Add - + core::ops::Div - + num::One - + Copy, - { - let two = num::one::() + num::one::(); - let height = self.height * factor; - let height_inc = height - self.height; - Self { - x: self.x, - y: self.y - height_inc / two, - width: self.width, - height, - } - } - - pub fn offset(&self, offset: Point) -> Self - where - T: core::ops::Add + Copy, - { - Self { - x: self.x + offset.x, - y: self.y + offset.y, - width: self.width, - height: self.height, - } - } - - /// Translate the bounding box by the given offset - /// if they are in the same scale - pub fn translate(&self, bbox: Self) -> Self - where - T: core::ops::Add + Copy, - { - Self { - x: self.x + bbox.x, - y: self.y + bbox.y, - width: self.width, - height: self.height, - } - } - - pub fn with_top_left(&self, top_left: Point) -> Self { - Self { - x: top_left.x, - y: top_left.y, - width: self.width, - height: self.height, - } - } - - pub fn center(&self) -> Point - where - T: core::ops::Add + core::ops::Div + Copy, - T: num::One, - { - let two = T::one() + T::one(); - Point::new(self.x + self.width / two, self.y + self.height / two) - } - - pub fn area(&self) -> T - where - T: core::ops::Mul + Copy, - { - self.width * self.height - } - - // Corresponds to self.x1() and self.y1() - pub fn top_left(&self) -> Point { - Point::new(self.x, self.y) - } - - pub fn top_right(&self) -> Point - where - T: core::ops::Add + Copy, - { - Point::new(self.x + self.width, self.y) - } - - pub fn bottom_left(&self) -> Point - where - T: core::ops::Add + Copy, - { - Point::new(self.x, self.y + self.height) - } - - // Corresponds to self.x2() and self.y2() - pub fn bottom_right(&self) -> Point - where - T: core::ops::Add + Copy, - { - Point::new(self.x + self.width, self.y + self.height) - } - - pub const fn x1(&self) -> T { - self.x - } - - pub const fn y1(&self) -> T { - self.y - } - - pub fn x2(&self) -> T - where - T: core::ops::Add + Copy, - { - self.x + self.width - } - - pub fn y2(&self) -> T - where - T: core::ops::Add + Copy, - { - self.y + self.height - } - - pub fn overlap(&self, other: &Self) -> T - where - T: std::cmp::PartialOrd - + traits::min::Min - + traits::max::Max - + num::Zero - + core::ops::Add - + core::ops::Sub - + core::ops::Mul - + Copy, - { - let x1 = self.x.max(other.x); - let y1 = self.y.max(other.y); - let x2 = (self.x + self.width).min(other.x + other.width); - let y2 = (self.y + self.height).min(other.y + other.height); - let width = (x2 - x1).max(T::zero()); - let height = (y2 - y1).max(T::zero()); - width * height - } - - pub fn iou(&self, other: &Self) -> T - where - T: std::cmp::Ord - + num::Zero - + traits::min::Min - + traits::max::Max - + core::ops::Add - + core::ops::Sub - + core::ops::Mul - + core::ops::Div - + Copy, - { - let overlap = self.overlap(other); - let union = self.area() + other.area() - overlap; - overlap / union - } - - pub fn contains(&self, point: Point) -> bool - where - T: std::cmp::PartialOrd + core::ops::Add + Copy, - { - point.x >= self.x - && point.x <= self.x + self.width - && point.y >= self.y - && point.y <= self.y + self.height - } - - pub fn contains_bbox(&self, other: Self) -> bool - where - T: std::cmp::PartialOrd + Copy, - T: core::ops::Add, - { - self.contains(other.top_left()) - && self.contains(other.top_right()) - && self.contains(other.bottom_left()) - && self.contains(other.bottom_right()) - } - - pub fn new_xywh(x: T, y: T, width: T, height: T) -> Self { - Self { - x, - y, - width, - height, - } - } - pub fn new_xyxy(x1: T, y1: T, x2: T, y2: T) -> Self - where - T: core::ops::Sub + Copy, - { - Self { - x: x1, - y: y1, - width: x2 - x1, - height: y2 - y1, - } - } - - pub fn containing(box1: Self, box2: Self) -> Self - where - T: traits::min::Min + traits::max::Max + Copy, - T: core::ops::Sub, - T: core::ops::Add, - { - let x1 = box1.x.min(box2.x); - let y1 = box1.y.min(box2.y); - let x2 = box1.x2().max(box2.x2()); - let y2 = box1.y2().max(box2.y2()); - Self::new_xyxy(x1, y1, x2, y2) - } -} - -impl + Copy> core::ops::Sub for BBox { - type Output = BBox; - fn sub(self, rhs: T) -> Self::Output { - BBox { - x: self.x - rhs, - y: self.y - rhs, - width: self.width - rhs, - height: self.height - rhs, - } - } -} - -impl + Copy> core::ops::Add for BBox { - type Output = BBox; - fn add(self, rhs: T) -> Self::Output { - BBox { - x: self.x + rhs, - y: self.y + rhs, - width: self.width + rhs, - height: self.height + rhs, - } - } -} -impl + Copy> core::ops::Mul for BBox { - type Output = BBox; - fn mul(self, rhs: T) -> Self::Output { - BBox { - x: self.x * rhs, - y: self.y * rhs, - width: self.width * rhs, - height: self.height * rhs, - } - } -} -impl + Copy> core::ops::Div for BBox { - type Output = BBox; - fn div(self, rhs: T) -> Self::Output { - BBox { - x: self.x / rhs, - y: self.y / rhs, - width: self.width / rhs, - height: self.height / rhs, - } - } -} - -impl core::ops::Add> for BBox -where - T: core::ops::Sub - + core::ops::Add - + traits::min::Min - + traits::max::Max - + Copy, -{ - type Output = BBox; - fn add(self, rhs: BBox) -> Self::Output { - let x1 = self.x1().min(rhs.x1()); - let y1 = self.y1().min(rhs.y1()); - let x2 = self.x2().max(rhs.x2()); - let y2 = self.y2().max(rhs.y2()); - BBox::new_xyxy(x1, y1, x2, y2) - } -} - -#[test] -fn test_bbox_add() { - let bbox1: BBox = BBox::new_xyxy(0, 0, 10, 10); - let bbox2: BBox = BBox::new_xyxy(5, 5, 15, 15); - let bbox3: BBox = bbox1 + bbox2; - assert_eq!(bbox3, BBox::new_xyxy(0, 0, 15, 15).cast()); -} - -#[derive( - Debug, Copy, Clone, serde::Serialize, serde::Deserialize, PartialEq, PartialOrd, Eq, Ord, Hash, -)] -pub struct Point { - x: T, - y: T, -} - -impl Point { - pub const fn new(x: T, y: T) -> Self { - Self { x, y } - } - - pub const fn x(&self) -> T - where - T: Copy, - { - self.x - } - - pub const fn y(&self) -> T - where - T: Copy, - { - self.y - } - - pub fn cast(&self) -> Point - where - T: num::cast::AsPrimitive, - T2: Copy + 'static, - { - Point { - x: self.x.as_(), - y: self.y.as_(), - } - } -} - -impl + Copy> core::ops::Sub> for Point { - type Output = Point; - fn sub(self, rhs: Point) -> Self::Output { - Point { - x: self.x - rhs.x, - y: self.y - rhs.y, - } - } -} - -impl + Copy> core::ops::Add> for Point { - type Output = Point; - fn add(self, rhs: Point) -> Self::Output { - Point { - x: self.x + rhs.x, - y: self.y + rhs.y, - } - } -} - -impl + Copy> Point { - /// If both the boxes are in the same scale then make the translation of the origin to the - /// other box - pub fn with_origin(&self, origin: Self) -> Self { - *self - origin - } -} - -impl + Copy> Point { - pub fn translate(&self, point: Point) -> Self { - *self + point - } -} - -impl BBox -where - I: num::cast::AsPrimitive, -{ - pub fn zeros_ndarray_2d(&self) -> ndarray::Array2 { - ndarray::Array2::::zeros((self.height.as_(), self.width.as_())) - } - pub fn zeros_ndarray_3d(&self, channels: usize) -> ndarray::Array3 { - ndarray::Array3::::zeros((self.height.as_(), self.width.as_(), channels)) - } - pub fn ones_ndarray_2d(&self) -> ndarray::Array2 { - ndarray::Array2::::ones((self.height.as_(), self.width.as_())) - } -} - -impl BBox { - pub fn round(&self) -> Self { - Self { - x: self.x.round(), - y: self.y.round(), - width: self.width.round(), - height: self.height.round(), - } - } -} - -#[cfg(test)] -mod bbox_clamp_tests { - use super::*; - #[test] - pub fn bbox_test_clamp_box() { - let large_box = BBox::new(0, 0, 100, 100); - let small_box = BBox::new(10, 10, 20, 20); - let clamped = large_box.clamp_box(small_box); - assert_eq!(clamped, small_box); - } - - #[test] - pub fn bbox_test_clamp_box_offset() { - let box_a = BBox::new(0, 0, 100, 100); - let box_b = BBox::new(-10, -10, 20, 20); - let clamped = box_b.clamp_box(box_a); - let expected = BBox::new(0, 0, 10, 10); - assert_eq!(expected, clamped); - } -} - -#[cfg(test)] -mod bbox_padding_tests { - use super::*; - #[test] - pub fn bbox_test_padding() { - let bbox = BBox::new(0, 0, 10, 10); - let padded = bbox.padding(2); - assert_eq!(padded, BBox::new(-2, -2, 14, 14)); - } - - #[test] - pub fn bbox_test_padding_height() { - let bbox = BBox::new(0, 0, 10, 10); - let padded = bbox.padding_height(2); - assert_eq!(padded, BBox::new(0, -2, 10, 14)); - } - - #[test] - pub fn bbox_test_padding_width() { - let bbox = BBox::new(0, 0, 10, 10); - let padded = bbox.padding_width(2); - assert_eq!(padded, BBox::new(-2, 0, 14, 10)); - } - - #[test] - pub fn bbox_test_clamped_padding() { - let bbox = BBox::new(0, 0, 10, 10); - let padded = bbox.padding(2); - let clamp = BBox::new(0, 0, 12, 12); - let clamped = padded.clamp_box(clamp); - assert_eq!(clamped, clamp); - } - - #[test] - pub fn bbox_clamp_failure() { - let og = BBox::new(475.0, 79.625, 37.0, 282.15); - let padded = BBox { - x: 471.3, - y: 51.412499999999994, - width: 40.69999999999999, - height: 338.54999999999995, - }; - let clamp = BBox::new(0.0, 0.0, 512.0, 512.0); - let sus = padded.clamp_box(clamp); - assert!(clamp.contains_bbox(sus)); - } -} - -#[cfg(test)] -mod bbox_scale_tests { - use super::*; - #[test] - pub fn bbox_test_scale_int() { - let bbox = BBox::new(0, 0, 10, 10); - let scaled = bbox.scale(2); - assert_eq!(scaled, BBox::new(-5, -5, 20, 20)); - } - - #[test] - pub fn bbox_test_scale_float() { - let bbox = BBox::new(0, 0, 10, 10).cast(); - let scaled = bbox.scale(1.05); // 5% increase - let l = 10.0 * 0.05; - assert_eq!(scaled, BBox::new(-l / 2.0, -l / 2.0, 10.0 + l, 10.0 + l)); - } - - #[test] - pub fn bbox_test_scale_float_negative() { - let bbox = BBox::new(0, 0, 10, 10).cast(); - let scaled = bbox.scale(0.95); // 5% decrease - let l = -10.0 * 0.05; - assert_eq!(scaled, BBox::new(-l / 2.0, -l / 2.0, 10.0 + l, 10.0 + l)); - } - - #[test] - pub fn bbox_scale_float() { - let bbox = BBox::new_xywh(0, 0, 200, 200); - let scaled = bbox.cast::().scale(1.1).cast::().clamp(0, 1000); - let expected = BBox::new(0, 0, 220, 220); - assert_eq!(scaled, expected); - } - #[test] - pub fn add_padding_bbox_example() { - // let result = add_padding_bbox( - // vec![Rect::new(100, 200, 300, 400)], - // (0.1, 0.1), - // (1000, 1000), - // ); - // assert_eq!(result[0], Rect::new(70, 160, 360, 480)); - let bbox = BBox::new(100, 200, 300, 400); - let scaled = bbox.cast::().scale(1.2).cast::().clamp(0, 1000); - assert_eq!(bbox, BBox::new(100, 200, 300, 400)); - assert_eq!(scaled, BBox::new(70, 160, 360, 480)); - } - #[test] - pub fn scale_bboxes() { - // let result = scale_bboxes(Rect::new(100, 200, 300, 400), (1000, 1000), (500, 500)); - // assert_eq!(result[0], Rect::new(200, 400, 600, 800)); - let bbox = BBox::new(100, 200, 300, 400); - let scaled = bbox.scale(2); - assert_eq!(scaled, BBox::new(200, 400, 600, 800)); - } -} diff --git a/bbox/src/traits.rs b/bbox/src/traits.rs deleted file mode 100644 index 8c95292..0000000 --- a/bbox/src/traits.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod max; -pub mod min; diff --git a/bbox/src/traits/max.rs b/bbox/src/traits/max.rs deleted file mode 100644 index 94b501d..0000000 --- a/bbox/src/traits/max.rs +++ /dev/null @@ -1,27 +0,0 @@ -pub trait Max: Sized + Copy { - fn max(self, other: Self) -> Self; -} - -macro_rules! impl_max { - ($($t:ty),*) => { - $( - impl Max for $t { - fn max(self, other: Self) -> Self { - Ord::max(self, other) - } - } - )* - }; - (float $($t:ty),*) => { - $( - impl Max for $t { - fn max(self, other: Self) -> Self { - Self::max(self, other) - } - } - )* - }; -} - -impl_max!(usize, u8, u16, u32, u64, u128, isize, i8, i16, i32, i64, i128); -impl_max!(float f32, f64); diff --git a/bbox/src/traits/min.rs b/bbox/src/traits/min.rs deleted file mode 100644 index e4d24ae..0000000 --- a/bbox/src/traits/min.rs +++ /dev/null @@ -1,27 +0,0 @@ -pub trait Min: Sized + Copy { - fn min(self, other: Self) -> Self; -} - -macro_rules! impl_min { - ($($t:ty),*) => { - $( - impl Min for $t { - fn min(self, other: Self) -> Self { - Ord::min(self, other) - } - } - )* - }; - (float $($t:ty),*) => { - $( - impl Min for $t { - fn min(self, other: Self) -> Self { - Self::min(self, other) - } - } - )* - }; -} - -impl_min!(usize, u8, u16, u32, u64, u128, isize, i8, i16, i32, i64, i128); -impl_min!(float f32, f64); diff --git a/bounding-box/src/draw.rs b/bounding-box/src/draw.rs index 1f97b2a..8ed4e38 100644 --- a/bounding-box/src/draw.rs +++ b/bounding-box/src/draw.rs @@ -53,8 +53,9 @@ impl Drawable> for Aabb2 { let bottom = Aabb2::from_x1y1x2y2(x1y2.x, x1y2.y, x2y2.x, x2y2.y + thickness); let left = Aabb2::from_x1y1x2y2(x1y1.x, x1y1.y, x1y2.x + thickness, x1y2.y); let right = Aabb2::from_x1y1x2y2(x2y1.x, x2y1.y, x2y2.x + thickness, x2y2.y + thickness); - let lines = [top, bottom, left, right]; - lines.into_iter().for_each(|line| { + let canvas_bbox = Aabb2::from_x1y1x2y2(0, 0, canvas.dim().1 - 1, canvas.dim().0 - 1); + let lines = [top, bottom, left, right].map(|bbox| bbox.clamp(&canvas_bbox)); + lines.into_iter().flatten().for_each(|line| { canvas .roi_mut(line) .map(|mut line| { diff --git a/bounding-box/src/nms.rs b/bounding-box/src/nms.rs index 04a051f..d2a652c 100644 --- a/bounding-box/src/nms.rs +++ b/bounding-box/src/nms.rs @@ -6,7 +6,9 @@ use crate::*; /// # Arguments /// /// * `boxes` - A slice of bounding boxes to apply NMS on. -/// * `threshold` - The IoU threshold for suppression. +/// * `scores` - A slice of confidence scores corresponding to the bounding boxes. +/// * `score_threshold` - The minimum score threshold for consideration. +/// * `nms_threshold` - The IoU threshold for suppression. /// /// # Returns /// @@ -16,7 +18,7 @@ pub fn nms( scores: &[T], score_threshold: T, nms_threshold: T, -) -> Vec> +) -> HashSet where T: Num + num::Float @@ -28,48 +30,55 @@ where + nalgebra::SimdPartialOrd, { use itertools::Itertools; - let bboxes: Vec<_> = boxes - .iter() - .zip(scores.iter()) - .filter_map(|(bbox, score)| (score >= &score_threshold).then_some((bbox, score))) - .sorted_by(|(_, score_a), (_, score_b)| { - score_b - .partial_cmp(score_a) - .unwrap_or(std::cmp::Ordering::Equal) - }) - .map(|(bbox, _)| bbox) - .collect(); - let outputs = bboxes + // Create vector of (index, box, score) tuples for boxes with scores above threshold + let mut indexed_boxes: Vec<(usize, &Aabb2, &T)> = boxes .iter() .enumerate() - .scan( - HashSet::with_capacity(bboxes.len()), - |state, (index, bbox)| { - if state.is_empty() { - state.insert(index); - return Some(Some(bbox)); - } else { - if state.contains(&index) { - return Some(None); - } - let to_remove = bboxes - .iter() - .enumerate() - .skip(index + 1) - .filter_map(|(index, bbox_b)| { - (!state.contains(&index)).then_some(index)?; - let iou = bbox.iou(bbox_b)?; - (iou >= nms_threshold).then_some(index) - }) - .collect_vec(); - state.extend(to_remove); - Some(Some(bbox)) + .zip(scores.iter()) + .filter_map(|((idx, bbox), score)| { + if *score >= score_threshold { + Some((idx, bbox, score)) + } else { + None + } + }) + .collect(); + + // Sort by score in descending order + indexed_boxes.sort_by(|(_, _, score_a), (_, _, score_b)| { + score_b + .partial_cmp(score_a) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let mut keep_indices = HashSet::new(); + let mut suppressed = HashSet::new(); + + for (i, (idx_i, bbox_i, _)) in indexed_boxes.iter().enumerate() { + // Skip if this box is already suppressed + if suppressed.contains(idx_i) { + continue; + } + + // Keep this box + keep_indices.insert(*idx_i); + + // Compare with remaining boxes + for (idx_j, bbox_j, _) in indexed_boxes.iter().skip(i + 1) { + // Skip if this box is already suppressed + if suppressed.contains(idx_j) { + continue; + } + + // Calculate IoU and suppress if above threshold + if let Some(iou) = bbox_i.iou(bbox_j) { + if iou >= nms_threshold { + suppressed.insert(*idx_j); } - }, - ) - .flatten() - .map(|bbox| **bbox) - .collect_vec(); - outputs + } + } + } + + keep_indices } diff --git a/src/cli.rs b/src/cli.rs index 3e88032..bb0227a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -18,6 +18,7 @@ pub enum SubCommand { #[derive(Debug, clap::ValueEnum, Clone, Copy)] pub enum Models { RetinaFace, + Yolo, } #[derive(Debug, clap::ValueEnum, Clone, Copy)] @@ -49,6 +50,8 @@ pub struct Detect { pub output: Option, #[clap(short, long, default_value_t = 0.8)] pub threshold: f32, + #[clap(short, long, default_value_t = 0.3)] + pub nms_threshold: f32, pub image: PathBuf, } diff --git a/src/facedet/retinaface.rs b/src/facedet/retinaface.rs index 6f36d24..69b8f52 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/retinaface.rs @@ -7,7 +7,7 @@ use ndarray_resize::NdFir; use std::path::Path; pub struct FaceDetectionConfig { - min_sizes: Vec>, + anchor_sizes: Vec>, steps: Vec, variance: Vec, threshold: f32, @@ -16,7 +16,7 @@ pub struct FaceDetectionConfig { impl FaceDetectionConfig { pub fn with_min_sizes(mut self, min_sizes: Vec>) -> Self { - self.min_sizes = min_sizes; + self.anchor_sizes = min_sizes; self } pub fn with_steps(mut self, steps: Vec) -> Self { @@ -40,7 +40,7 @@ impl FaceDetectionConfig { impl Default for FaceDetectionConfig { fn default() -> Self { FaceDetectionConfig { - min_sizes: vec![ + anchor_sizes: vec![ Vector2::new(16, 32), Vector2::new(64, 128), Vector2::new(256, 512), @@ -48,7 +48,7 @@ impl Default for FaceDetectionConfig { steps: vec![8, 16, 32], variance: vec![0.1, 0.2], threshold: 0.8, - nms_threshold: 0.6, + nms_threshold: 0.4, } } } @@ -91,15 +91,15 @@ impl FaceDetectionModelOutput { pub fn postprocess(self, config: &FaceDetectionConfig) -> Result { let mut anchors = Vec::new(); for (k, &step) in config.steps.iter().enumerate() { - let feature_size = 640 / step; - let min_sizes = config.min_sizes[k]; + let feature_size = 1024 / step; + let min_sizes = config.anchor_sizes[k]; let sizes = [min_sizes.x, min_sizes.y]; for i in 0..feature_size { for j in 0..feature_size { for &size in &sizes { - let cx = (j as f32 + 0.5) * step as f32 / 640.0; - let cy = (i as f32 + 0.5) * step as f32 / 640.0; - let s_k = size as f32 / 640.0; + let cx = (j as f32 + 0.5) * step as f32 / 1024.0; + let cy = (i as f32 + 0.5) * step as f32 / 1024.0; + let s_k = size as f32 / 1024.0; anchors.push((cx, cy, s_k, s_k)); } } @@ -220,7 +220,7 @@ impl FaceDetection { image: ndarray::Array3, config: FaceDetectionConfig, ) -> Result { - let (height, width, channels) = image.dim(); + let (height, width, _channels) = image.dim(); let output = self .run_models(image) .change_context(Error) @@ -242,17 +242,31 @@ impl FaceDetection { .map(|((b, s), l)| (b, s, l)) .multiunzip(); - let boxes = nms(&boxes, &scores, config.threshold, config.nms_threshold); + let keep_indices = nms(&boxes, &scores, config.threshold, config.nms_threshold); let bboxes = boxes .into_iter() - .flat_map(|x| x.denormalize(factor).try_cast::()) + .enumerate() + .filter(|(i, _)| keep_indices.contains(i)) + .flat_map(|(_, x)| x.denormalize(factor).try_cast::()) + .collect(); + let confidence = scores + .into_iter() + .enumerate() + .filter(|(i, _)| keep_indices.contains(i)) + .map(|(_, score)| score) + .collect(); + let landmark = landmarks + .into_iter() + .enumerate() + .filter(|(i, _)| keep_indices.contains(i)) + .map(|(_, score)| score) .collect(); Ok(FaceDetectionOutput { bbox: bboxes, - confidence: processed.confidence, - landmark: processed.landmarks, + confidence, + landmark, }) } @@ -263,7 +277,7 @@ impl FaceDetection { .handle .run(move |sr| { let mut resized = image - .fast_resize(640, 640, None) + .fast_resize(1024, 1024, None) .change_context(mnn::ErrorKind::TensorError)? .mapv(|f| f as f32) .tap_mut(|arr| { @@ -292,8 +306,8 @@ impl FaceDetection { input.view_mut(), 1, 3, - 640, - 640, + 1024, + 1024, ); } intptr.resize_session(session); diff --git a/src/main.rs b/src/main.rs index 6fe5ad8..cfef0f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,7 +34,7 @@ pub fn main() -> Result<()> { for bbox in output.bbox { tracing::info!("Detected face: {:?}", bbox); use bounding_box::draw::*; - array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 10); + array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1); } let v = array.view(); if let Some(output) = detect.output {