From df5584d797b0a63ab70c31b7534160ab65b9e475 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 4 Aug 2025 19:27:45 +0530 Subject: [PATCH] feat: Added postprocessing for retinaface --- Cargo.lock | 3 ++ Cargo.toml | 1 + bounding-box/Cargo.toml | 2 + bounding-box/src/draw.rs | 48 +++++++++++++------- bounding-box/src/lib.rs | 67 +++++++++++++++++++++++++++ bounding-box/src/roi.rs | 97 ++++++++++++++++++++++++++++++++++++++++ src/cli.rs | 2 + src/main.rs | 36 +++++++++++++-- 8 files changed, 235 insertions(+), 21 deletions(-) create mode 100644 bounding-box/src/roi.rs diff --git a/Cargo.lock b/Cargo.lock index f6db907..12effd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,6 +255,8 @@ dependencies = [ "nalgebra", "ndarray 0.16.1", "num", + "simba", + "thiserror 2.0.12", ] [[package]] @@ -498,6 +500,7 @@ dependencies = [ "bounding-box", "clap", "clap_complete", + "color", "error-stack", "fast_image_resize", "image", diff --git a/Cargo.toml b/Cargo.toml index 543bffa..a38bc79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ mnn = { workspace = true } mnn-bridge = { workspace = true } mnn-sync = { workspace = true } bounding-box = { version = "0.1.0", path = "bounding-box" } +color = "0.3.1" [profile.release] debug = true diff --git a/bounding-box/Cargo.toml b/bounding-box/Cargo.toml index 9d875ba..a3f5525 100644 --- a/bounding-box/Cargo.toml +++ b/bounding-box/Cargo.toml @@ -9,6 +9,8 @@ itertools = "0.14.0" nalgebra = "0.33.2" ndarray = { version = "0.16.1", optional = true } num = "0.4.3" +simba = "0.9.0" +thiserror = "2.0.12" [features] ndarray = ["dep:ndarray"] diff --git a/bounding-box/src/draw.rs b/bounding-box/src/draw.rs index ada3edd..7a80864 100644 --- a/bounding-box/src/draw.rs +++ b/bounding-box/src/draw.rs @@ -1,39 +1,53 @@ use crate::*; pub use color::Rgba8; -use ndarray::{Array3, ArrayViewMut3}; +use ndarray::{Array1, Array3, ArrayViewMut3}; pub trait Draw { fn draw(&mut self, item: T, color: color::Rgba8, thickness: usize); } -// impl> Draw for Array3 { -// fn draw(&self, item: T, color: color::Rgba8, thickness: usize) { -// item.draw(&self, color, thickness); -// } -// } +impl Draw> for Array3 { + fn draw(&mut self, item: Aabb2, color: color::Rgba8, thickness: usize) { + item.draw(self, color, thickness) + } +} -pub trait Drawable { - fn draw(&self, canvas: &mut Canvas, color: color::Rgba8, thickness: T); +pub trait Drawable { + fn draw(&self, canvas: &mut Canvas, color: color::Rgba8, thickness: usize); } /// Implementing Drawable for Aabb2 with Array3 as the canvas type /// Assuming Array3 is a 3D array representing an image with RGB/RGBA channels -impl Drawable, T> for Aabb2 -where - T: Num + core::ops::SubAssign + core::ops::AddAssign + core::ops::DivAssign, - T: PartialOrd, -{ - fn draw(&self, canvas: &mut ArrayViewMut3, color: color::Rgba8, thickness: T) { +impl Drawable> for Aabb2 { + fn draw(&self, canvas: &mut ArrayViewMut3, color: color::Rgba8, thickness: usize) { use itertools::Itertools; - let (height, width, channels) = canvas.dim(); + // let (height, width, channels) = canvas.dim(); + let color = Array1::from_vec(vec![color.r, color.g, color.b, color.a]); self.corners() .iter() .zip(self.padding(thickness).corners()) .tuple_windows() .for_each(|((a, b), (c, d))| { - let bbox = Aabb2::from_vertices([*a, b, *c, d]); - todo!(); + let bbox = Aabb2::from_vertices([*a, b, *c, d]).expect("Invalid bounding box"); + use crate::roi::RoiMut; + let mut out = canvas.roi_mut(bbox).expect("Failed to get ROI"); + out.lanes_mut(ndarray::Axis(2)) + .into_iter() + .for_each(|mut pixel| { + pixel.assign(&color); + }); }); } } + +impl Drawable> for Aabb2 { + fn draw(&self, canvas: &mut Array3, color: color::Rgba8, thickness: usize) { + use itertools::Itertools; + // let (height, width, channels) = canvas.dim(); + let color = Array1::from_vec(vec![color.r, color.g, color.b, color.a]); + let pixel_size = canvas.dim().2; + let color = color.slice(ndarray::s![..pixel_size]); + let [x1y1, x2y1, x1y2, x2y2] = self.corners(); + } +} diff --git a/bounding-box/src/lib.rs b/bounding-box/src/lib.rs index 281ff62..26e5b30 100644 --- a/bounding-box/src/lib.rs +++ b/bounding-box/src/lib.rs @@ -1,5 +1,6 @@ pub mod draw; pub mod nms; +pub mod roi; use nalgebra::{Point, Point2, Point3, SVector}; pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {} @@ -31,6 +32,8 @@ impl AxisAlignedBoundingBox { Self::new(point1, SVector::from(size)) } + /// Only considers the points closest and furthest from origin + /// Points which are rotated along in the z axis (in 2d) are not considered pub fn from_vertices(points: [Point; 4]) -> Option where T: core::ops::SubAssign, @@ -182,9 +185,51 @@ impl AxisAlignedBoundingBox { Point::from(max), )) } + + pub fn denormalize(&self, factor: nalgebra::SVector) -> Self + where + T: core::ops::MulAssign, + T: core::ops::AddAssign, + // nalgebra::constraint::ShapeConstraint: + // nalgebra::constraint::DimEq, nalgebra::Const>, + { + Self { + point: (self.point.coords.component_mul(&factor)).into(), + size: self.size.component_mul(&factor), + } + } + + pub fn cast(&self) -> Option> + where + // T: num::NumCast, + T2: Num + simba::scalar::SubsetOf, + { + Some(Aabb { + point: Point::from(self.point.coords.try_cast::()?), + size: self.size.try_cast::()?, + }) + } + + // pub fn as_(&self) -> Option> + // where + // T2: Num + simba::scalar::SubsetOf, + // { + // Some(Aabb { + // point: Point::from(self.point.coords.as_()), + // size: self.size.as_(), + // }) + // } } impl Aabb2 { + pub fn from_x1y1x2y2(x1: T, x2: T, y1: T, y2: T) -> Self + where + T: core::ops::SubAssign, + { + let point1 = Point2::new(x1, y1); + let point2 = Point2::new(x2, y2); + Self::from_min_max_vertices(point1, point2) + } pub fn new_2d(point1: Point2, point2: Point2) -> Self where T: core::ops::SubAssign, @@ -217,6 +262,28 @@ impl Aabb2 { Point2::new(self.point.x, self.point.y + self.size.y) } + pub fn x1(&self) -> T { + self.point.x + } + + pub fn y1(&self) -> T { + self.point.y + } + + pub fn x2(&self) -> T + where + T: core::ops::AddAssign, + { + self.point.x + self.size.x + } + + pub fn y2(&self) -> T + where + T: core::ops::AddAssign, + { + self.point.y + self.size.y + } + pub fn corners(&self) -> [Point2; 4] where T: core::ops::AddAssign, diff --git a/bounding-box/src/roi.rs b/bounding-box/src/roi.rs new file mode 100644 index 0000000..3d07357 --- /dev/null +++ b/bounding-box/src/roi.rs @@ -0,0 +1,97 @@ +use crate::*; +use ndarray::{Array3, ArrayView3, ArrayViewMut3}; +/// A trait that extracts a region of interest from an image +pub trait Roi<'a, Output> { + type Error; + fn roi(&'a self, aabb: Aabb2) -> Result; +} +pub trait RoiMut<'a, Output> { + type Error; + fn roi_mut(&'a mut self, aabb: Aabb2) -> Result; +} +#[derive(thiserror::Error, Debug, Copy, Clone)] +pub enum RoiError { + #[error("Region of intereset is out of bounds")] + RoiOutOfBounds, +} + +impl<'a, T: Num> Roi<'a, ArrayView3<'a, T>> for Array3 { + type Error = RoiError; + fn roi(&'a self, aabb: Aabb2) -> Result, Self::Error> { + let x1 = aabb.x1(); + let x2 = aabb.x2(); + let y1 = aabb.y1(); + let y2 = aabb.y2(); + if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] { + return Err(RoiError::RoiOutOfBounds); + } + Ok(self.slice(ndarray::s![y1..y2, x1..x2, ..])) + } +} + +impl<'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for Array3 { + type Error = RoiError; + fn roi_mut(&'a mut self, aabb: Aabb2) -> Result, Self::Error> { + let x1 = aabb.x1(); + let x2 = aabb.x2(); + let y1 = aabb.y1(); + let y2 = aabb.y2(); + if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] { + return Err(RoiError::RoiOutOfBounds); + } + Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..])) + } +} + +impl<'a, 'b, T: Num> Roi<'a, ArrayView3<'b, T>> for ArrayView3<'b, T> { + type Error = RoiError; + fn roi(&'a self, aabb: Aabb2) -> Result, Self::Error> { + let x1 = aabb.x1(); + let x2 = aabb.x2(); + let y1 = aabb.y1(); + let y2 = aabb.y2(); + if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] { + return Err(RoiError::RoiOutOfBounds); + } + Ok(self.slice_move(ndarray::s![y1..y2, x1..x2, ..])) + } +} +// impl<'a, 'b, T: Num> Roi<'a, ArrayViewMut3<'b, T>> for ArrayViewMut3<'b, T> { +// type Error = RoiError; +// fn roi(&'a self, aabb: Aabb2) -> Result, Self::Error> { +// let x1 = aabb.x1(); +// let x2 = aabb.x2(); +// let y1 = aabb.y1(); +// let y2 = aabb.y2(); +// if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] { +// return Err(RoiError::RoiOutOfBounds); +// } +// Ok(self.slice(ndarray::s![y1..y2, x1..x2, ..])) +// } +// } + +impl<'a, 'b: 'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for ArrayViewMut3<'b, T> { + type Error = RoiError; + fn roi_mut(&'a mut self, aabb: Aabb2) -> Result, Self::Error> { + let x1 = aabb.x1(); + let x2 = aabb.x2(); + let y1 = aabb.y1(); + let y2 = aabb.y2(); + if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] { + return Err(RoiError::RoiOutOfBounds); + } + let out: ArrayViewMut3<'a, T> = self.slice_mut(ndarray::s![y1..y2, x1..x2, ..]); + Ok(out) + } +} + +#[test] +pub fn reborrow_test() { + let ndarray = ndarray::Array::from_shape_vec((5, 5, 5), vec![33; 5 * 5 * 5]).unwrap(); + let aabb = Aabb2::from_x1y1x2y2(2, 3, 4, 5); + let y = { + let view = ndarray.view(); + view.roi(aabb).unwrap() + }; + dbg!(y); +} diff --git a/src/cli.rs b/src/cli.rs index afd12e1..0be2112 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -45,6 +45,8 @@ pub struct Detect { pub model: Option, #[clap(short = 'M', long, default_value = "retina-face")] pub model_type: Models, + #[clap(short, long)] + pub output: Option, pub image: PathBuf, } diff --git a/src/main.rs b/src/main.rs index 1f0869f..10d59e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,17 +19,45 @@ pub fn main() -> Result<()> { .attach_printable("Failed to create face detection model")?; let image = image::open(detect.image).change_context(Error)?; let image = image.into_rgb8(); - let array = image + let mut array = image .into_ndarray() .change_context(errors::Error) .attach_printable("Failed to convert image to ndarray")?; let output = model - .detect_faces(array) + .detect_faces(array.clone()) .change_context(errors::Error) .attach_printable("Failed to detect faces")?; // output.print(20); - let aabbs = output.postprocess(Default::default()); - dbg!(aabbs); + let aabbs = output + .postprocess(Default::default()) + .change_context(errors::Error) + .attach_printable("Failed to attach context")?; + for bbox in aabbs { + println!("Detected face: {:?}", bbox); + use bounding_box::draw::*; + let bbox = bbox + .denormalize(nalgebra::SVector::::new( + array.shape()[1] as f32, + array.shape()[0] as f32, + )) + .cast() + .ok_or(errors::Error) + .attach_printable("Failed to cast f32 to usize")?; + dbg!(bbox); + array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 20); + break; + } + let v = array.view(); + if let Some(output) = detect.output { + let image: image::RgbImage = v + .to_image() + .change_context(errors::Error) + .attach_printable("Failed to convert ndarray to image")?; + image + .save(output) + .change_context(errors::Error) + .attach_printable("Failed to save output image")?; + } } cli::SubCommand::List(list) => { println!("List: {:?}", list);