diff --git a/bounding-box/src/lib.rs b/bounding-box/src/lib.rs index 7247a2b..8df0dd5 100644 --- a/bounding-box/src/lib.rs +++ b/bounding-box/src/lib.rs @@ -2,7 +2,7 @@ pub mod draw; pub mod nms; pub mod roi; -use nalgebra::{Point, Point2, Point3, SVector, SimdPartialOrd, SimdValue}; +use nalgebra::{Point, Point2, SVector, Vector2}; pub trait Num: num::Num + core::ops::AddAssign @@ -212,13 +212,9 @@ impl AxisAlignedBoundingBox { T: nalgebra::SimdValue, T: nalgebra::SimdPartialOrd, { - let self_min = self.min_vertex(); - let self_max = self.max_vertex(); - let other_min = other.min_vertex(); - let other_max = other.max_vertex(); - let min = self_min.inf(&other_min); - let max = self_max.sup(&other_max); - Self::from_min_max_vertices(min, max) + let min = self.min_vertex().inf(&other.min_vertex()); + let max = self.min_vertex().sup(&other.max_vertex()); + Self::new(min, max) } pub fn union(&self, other: &Self) -> T @@ -244,13 +240,9 @@ impl AxisAlignedBoundingBox { T: nalgebra::SimdPartialOrd, T: nalgebra::SimdValue, { - let inter_min = self.min_vertex().inf(&other.min_vertex()); - let inter_max = self.max_vertex().sup(&other.max_vertex()); - - if inter_max < inter_min { - return None; // No intersection - } - Some(Self::new(inter_min, inter_max)) + let inter_min = self.min_vertex().sup(&other.min_vertex()); + let inter_max = self.max_vertex().inf(&other.max_vertex()); + Self::try_new(inter_min, inter_max) } pub fn denormalize(&self, factor: nalgebra::SVector) -> Self @@ -302,12 +294,19 @@ impl AxisAlignedBoundingBox { T: nalgebra::SimdValue, T: core::ops::MulAssign, { - let intersection = self - .intersection(other) - .map(|v| v.measure()) - .unwrap_or(T::zero()); - let union = self.union(other); - intersection / union + let lhs_min = self.min_vertex(); + let lhs_max = self.max_vertex(); + let rhs_min = other.min_vertex(); + let rhs_max = other.max_vertex(); + + let inter_min = lhs_min.sup(&rhs_min); + let inter_max = lhs_max.inf(&rhs_max); + if inter_max < inter_min { + return T::zero(); + } else { + let intersection = Aabb::new(inter_min, inter_max).measure(); + intersection / (self.measure() + other.measure() - intersection) + } } } @@ -318,8 +317,15 @@ impl Aabb2 { { let point1 = Point2::new(x1, y1); let point2 = Point2::new(x2, y2); - Self::from_min_max_vertices(point1, point2) + Self::new(point1, point2) } + + pub fn from_xywh(x: T, y: T, w: T, h: T) -> Self { + let point = Point2::new(x, y); + let size = Vector2::new(w, h); + Self::new_point_size(point, size) + } + pub fn x1y1(&self) -> Point2 { self.point } @@ -391,156 +397,219 @@ impl Aabb3 { } } -#[test] -fn test_bbox_new() { - use nalgebra::{Point2, Vector2}; +#[cfg(test)] +mod boudning_box_tests { + use super::*; + use nalgebra::*; - let point1 = Point2::new(1.0, 2.0); - let point2 = Point2::new(4.0, 6.0); - let bbox = AxisAlignedBoundingBox::new(point1, point2); + #[test] + fn test_bbox_new() { + let point1 = Point2::new(1.0, 2.0); + let point2 = Point2::new(4.0, 6.0); + let bbox = AxisAlignedBoundingBox::new(point1, point2); - assert_eq!(bbox.min_vertex(), point1); - assert_eq!(bbox.size(), Vector2::new(3.0, 4.0)); - assert_eq!(bbox.center(), Point2::new(2.5, 4.0)); -} + assert_eq!(bbox.min_vertex(), point1); + assert_eq!(bbox.size(), Vector2::new(3.0, 4.0)); + assert_eq!(bbox.center(), Point2::new(2.5, 4.0)); + } -#[test] -fn test_bounding_box_center_2d() { - use nalgebra::{Point2, Vector2}; + #[test] + fn test_intersection_and_merge() { + let point1 = Point2::new(1, 5); + let point2 = Point2::new(3, 2); + let size1 = Vector2::new(3, 4); + let size2 = Vector2::new(1, 3); - let point = Point2::new(1.0, 2.0); - let size = Vector2::new(3.0, 4.0); - let bbox = AxisAlignedBoundingBox::new(point, size); + let this = Aabb2::new_point_size(point1, size1); + let other = Aabb2::new_point_size(point2, size2); + let inter = this.intersection(&other); + let merged = this.merge(&other); + assert_ne!(inter, Some(merged)) + } - assert_eq!(bbox.min_vertex(), point); - assert_eq!(bbox.size(), size); - assert_eq!(bbox.center(), Point2::new(2.5, 4.0)); -} + #[test] + fn test_bounding_box_center_2d() { + let point = Point2::new(1.0, 2.0); + let size = Vector2::new(3.0, 4.0); + let bbox = AxisAlignedBoundingBox::new_point_size(point, size); -#[test] -fn test_bounding_box_center_3d() { - use nalgebra::{Point3, Vector3}; + assert_eq!(bbox.min_vertex(), point); + assert_eq!(bbox.size(), size); + assert_eq!(bbox.center(), Point2::new(2.5, 4.0)); + } - let point = Point3::new(1.0, 2.0, 3.0); - let size = Vector3::new(4.0, 5.0, 6.0); - let bbox = AxisAlignedBoundingBox::new(point, size); + #[test] + fn test_bounding_box_center_3d() { + let point = Point3::new(1.0, 2.0, 3.0); + let size = Vector3::new(4.0, 5.0, 6.0); + let bbox = AxisAlignedBoundingBox::new_point_size(point, size); - assert_eq!(bbox.min_vertex(), point); - assert_eq!(bbox.size(), size); - assert_eq!(bbox.center(), Point3::new(3.0, 4.5, 6.0)); -} + assert_eq!(bbox.min_vertex(), point); + assert_eq!(bbox.size(), size); + assert_eq!(bbox.center(), Point3::new(3.0, 4.5, 6.0)); + } -#[test] -fn test_bounding_box_padding_2d() { - use nalgebra::{Point2, Vector2}; + #[test] + fn test_bounding_box_padding_2d() { + let point = Point2::new(1.0, 2.0); + let size = Vector2::new(3.0, 4.0); + let bbox = AxisAlignedBoundingBox::new_point_size(point, size); - let point = Point2::new(1.0, 2.0); - let size = Vector2::new(3.0, 4.0); - let bbox = AxisAlignedBoundingBox::new(point, size); + let padded_bbox = bbox.padding(1.0); + assert_eq!(padded_bbox.min_vertex(), Point2::new(0.5, 1.5)); + assert_eq!(padded_bbox.size(), Vector2::new(4.0, 5.0)); + } - let padded_bbox = bbox.padding(1.0); - assert_eq!(padded_bbox.min_vertex(), Point2::new(0.5, 1.5)); - assert_eq!(padded_bbox.size(), Vector2::new(4.0, 5.0)); -} + #[test] + fn test_bounding_box_scaling_2d() { + let point = Point2::new(1.0, 1.0); + let size = Vector2::new(3.0, 4.0); + let bbox = AxisAlignedBoundingBox::new_point_size(point, size); -#[test] -fn test_bounding_box_scaling_2d() { - use nalgebra::{Point2, Vector2}; + let padded_bbox = bbox.scale(Vector2::new(2.0, 2.0)); + assert_eq!(padded_bbox.min_vertex(), Point2::new(-2.0, -3.0)); + assert_eq!(padded_bbox.size(), Vector2::new(6.0, 8.0)); + } - let point = Point2::new(1.0, 1.0); - let size = Vector2::new(3.0, 4.0); - let bbox = AxisAlignedBoundingBox::new(point, size); + #[test] + fn test_bounding_box_contains_2d() { + let point1 = Point2::new(1.0, 2.0); + let point2 = Point2::new(4.0, 6.0); + let bbox = AxisAlignedBoundingBox::new(point1, point2); - let padded_bbox = bbox.scale(Vector2::new(2.0, 2.0)); - assert_eq!(padded_bbox.min_vertex(), Point2::new(-2.0, -3.0)); - assert_eq!(padded_bbox.size(), Vector2::new(6.0, 8.0)); -} + assert!(bbox.contains_point(&Point2::new(2.0, 3.0))); + assert!(!bbox.contains_point(&Point2::new(5.0, 7.0))); + } -#[test] -fn test_bounding_box_contains_2d() { - use nalgebra::Point2; + #[test] + fn test_bounding_box_union_2d() { + let point1 = Point2::new(1.0, 2.0); + let point2 = Point2::new(4.0, 6.0); + let bbox1 = AxisAlignedBoundingBox::new(point1, point2); - let point1 = Point2::new(1.0, 2.0); - let point2 = Point2::new(4.0, 6.0); - let bbox = AxisAlignedBoundingBox::new(point1, point2); + let point3 = Point2::new(3.0, 5.0); + let point4 = Point2::new(7.0, 8.0); + let bbox2 = AxisAlignedBoundingBox::new(point3, point4); - assert!(bbox.contains_point(&Point2::new(2.0, 3.0))); - assert!(!bbox.contains_point(&Point2::new(5.0, 7.0))); -} + let union_bbox = bbox1.merge(&bbox2); + assert_eq!(union_bbox.min_vertex(), Point2::new(1.0, 2.0)); + assert_eq!(union_bbox.size(), Vector2::new(6.0, 6.0)); + } -#[test] -fn test_bounding_box_union_2d() { - use nalgebra::{Point2, Vector2}; + #[test] + fn test_bounding_box_intersection_2d() { + let point1 = Point2::new(1.0, 2.0); + let point2 = Point2::new(4.0, 6.0); + let bbox1 = AxisAlignedBoundingBox::new(point1, point2); - let point1 = Point2::new(1.0, 2.0); - let point2 = Point2::new(4.0, 6.0); - let bbox1 = AxisAlignedBoundingBox::new(point1, point2); + let point3 = Point2::new(3.0, 5.0); + let point4 = Point2::new(5.0, 7.0); + let bbox2 = AxisAlignedBoundingBox::new(point3, point4); - let point3 = Point2::new(3.0, 5.0); - let point4 = Point2::new(7.0, 8.0); - let bbox2 = AxisAlignedBoundingBox::new(point3, point4); + let intersection_bbox = bbox1.intersection(&bbox2).unwrap(); + assert_eq!(intersection_bbox.min_vertex(), Point2::new(3.0, 5.0)); + assert_eq!(intersection_bbox.size(), Vector2::new(1.0, 1.0)); + } - let union_bbox = bbox1.merge(&bbox2); - assert_eq!(union_bbox.min_vertex(), Point2::new(1.0, 2.0)); - assert_eq!(union_bbox.size(), Vector2::new(6.0, 6.0)); -} - -#[test] -fn test_bounding_box_intersection_2d() { - use nalgebra::{Point2, Vector2}; - - let point1 = Point2::new(1.0, 2.0); - let point2 = Point2::new(4.0, 6.0); - let bbox1 = AxisAlignedBoundingBox::new(point1, point2); - - let point3 = Point2::new(3.0, 5.0); - let point4 = Point2::new(5.0, 7.0); - let bbox2 = AxisAlignedBoundingBox::new(point3, point4); - - let intersection_bbox = bbox1.intersection(&bbox2).unwrap(); - assert_eq!(intersection_bbox.min_vertex(), Point2::new(3.0, 5.0)); - assert_eq!(intersection_bbox.size(), Vector2::new(1.0, 1.0)); -} - -#[test] -fn test_bounding_box_contains_point() { - use nalgebra::Point2; - - let point1 = Point2::new(2, 3); - let point2 = Point2::new(5, 4); - let bbox = AxisAlignedBoundingBox::new(point1, point2); - use itertools::Itertools; - for (i, j) in (0..=10).cartesian_product(0..=10) { - if bbox.contains_point(&Point2::new(i, j)) { - if !(2..=5).contains(&i) && !(3..=4).contains(&j) { - panic!( - "Point ({}, {}) should not be contained in the bounding box", - i, j - ); - } - } else { - if (2..=5).contains(&i) && (3..=4).contains(&j) { - panic!( - "Point ({}, {}) should be contained in the bounding box", - i, j - ); + #[test] + fn test_bounding_box_contains_point() { + let point1 = Point2::new(2, 3); + let point2 = Point2::new(5, 4); + let bbox = AxisAlignedBoundingBox::new(point1, point2); + use itertools::Itertools; + for (i, j) in (0..=10).cartesian_product(0..=10) { + if bbox.contains_point(&Point2::new(i, j)) { + if !(2..=5).contains(&i) && !(3..=4).contains(&j) { + panic!( + "Point ({}, {}) should not be contained in the bounding box", + i, j + ); + } + } else { + if (2..=5).contains(&i) && (3..=4).contains(&j) { + panic!( + "Point ({}, {}) should be contained in the bounding box", + i, j + ); + } } } } -} -#[test] -fn test_bounding_box_clamp_box_2d() { - let bbox1 = Aabb2::from_x1y1x2y2(1, 1, 4, 4); - let bbox2 = Aabb2::from_x1y1x2y2(2, 2, 3, 3); - let clamped = bbox2.clamp(&bbox1).unwrap(); - assert_eq!(bbox2, clamped); - let clamped = bbox1.clamp(&bbox2).unwrap(); - assert_eq!(bbox2, clamped); + #[test] + fn test_bounding_box_clamp_box_2d() { + let bbox1 = Aabb2::from_x1y1x2y2(1, 1, 4, 4); + let bbox2 = Aabb2::from_x1y1x2y2(2, 2, 3, 3); + let clamped = bbox2.clamp(&bbox1).unwrap(); + assert_eq!(bbox2, clamped); + let clamped = bbox1.clamp(&bbox2).unwrap(); + assert_eq!(bbox2, clamped); - let bbox1 = Aabb2::from_x1y1x2y2(4, 5, 7, 8); - let bbox2 = Aabb2::from_x1y1x2y2(5, 4, 8, 7); - let clamped = bbox1.clamp(&bbox2).unwrap(); - let expected = Aabb2::from_x1y1x2y2(5, 5, 7, 7); - assert_eq!(clamped, expected) + let bbox1 = Aabb2::from_x1y1x2y2(4, 5, 7, 8); + let bbox2 = Aabb2::from_x1y1x2y2(5, 4, 8, 7); + let clamped = bbox1.clamp(&bbox2).unwrap(); + let expected = Aabb2::from_x1y1x2y2(5, 5, 7, 7); + assert_eq!(clamped, expected) + } + + #[test] + fn test_iou_identical_boxes() { + let a = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0); + let b = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0); + assert_eq!(a.iou(&b), 1.0); + } + + #[test] + fn test_iou_non_overlapping_boxes() { + let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0); + let b = Aabb2::from_x1y1x2y2(2.0, 2.0, 3.0, 3.0); + assert_eq!(a.iou(&b), 0.0); + } + + #[test] + fn test_iou_partial_overlap() { + let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 2.0, 2.0); + let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0); + // Intersection area = 1, Union area = 7 + assert!((a.iou(&b) - 1.0 / 7.0).abs() < 1e-6); + } + + #[test] + fn test_iou_one_inside_another() { + let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 4.0, 4.0); + let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0); + // Intersection area = 4, Union area = 16 + assert!((a.iou(&b) - 0.25).abs() < 1e-6); + } + + #[test] + fn test_iou_edge_touching() { + let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0); + let b = Aabb2::from_x1y1x2y2(1.0, 0.0, 2.0, 1.0); + assert_eq!(a.iou(&b), 0.0); + } + + #[test] + fn test_iou_corner_touching() { + let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0); + let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 2.0, 2.0); + assert_eq!(a.iou(&b), 0.0); + } + + #[test] + fn test_iou_zero_area_box() { + let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 0.0, 0.0); + let b = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0); + assert_eq!(a.iou(&b), 0.0); + } + + #[test] + fn test_specific_values() { + let res = Vector2::new(1920, 1080).cast(); + let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264).denormalize(res); + let box2 = + Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436).denormalize(res); + dbg!(box1, box2); + assert!(box1.iou(&box2) > 0.0); + } } diff --git a/bounding-box/src/nms.rs b/bounding-box/src/nms.rs index ff534d2..985921e 100644 --- a/bounding-box/src/nms.rs +++ b/bounding-box/src/nms.rs @@ -33,11 +33,7 @@ where .filter(|&i| scores[i] >= score_threshold) .collect(); - indices.sort_by(|&i, &j| { - scores[j] - .partial_cmp(&scores[i]) - .unwrap_or(std::cmp::Ordering::Equal) - }); + indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).unwrap()); let mut selected_indices = HashSet::new(); @@ -46,8 +42,12 @@ where indices.remove(0); indices.retain(|&i| { - // let iou = calculate_iou(&boxes[current], &boxes[i]); - let iou = boxes[current].iou(&boxes[i]); + let iou = calculate_iou(&boxes[current], &boxes[i]); + let iou_ = boxes[current].iou(&boxes[i]); + if iou != iou_ { + dbg!(boxes[current], boxes[i]); + panic!() + }; iou < nms_threshold }); } @@ -55,7 +55,7 @@ where selected_indices } -/// Calculate the Intersection over Union (IoU) between two bounding boxes. +/// Calculate the Intersection over Union (IoU) of two bounding boxes. /// /// # Arguments /// @@ -64,36 +64,36 @@ where /// /// # Returns /// -/// The IoU value as a floating-point number. +/// The IoU as a value between 0 and 1. fn calculate_iou(box1: &Aabb2, box2: &Aabb2) -> T where - T: Num + num::Float, - T: core::ops::MulAssign, - T: core::ops::AddAssign, - T: core::ops::SubAssign, - T: nalgebra::SimdValue, - T: nalgebra::SimdPartialOrd, + T: Num + + num::Float + + core::iter::Product + + core::ops::AddAssign + + core::ops::SubAssign + + core::ops::MulAssign + + nalgebra::SimdValue + + nalgebra::SimdPartialOrd, { - // let inter_min_x = box1.min_vertex().x.max(box2.min_vertex().x); - // let inter_min_y = box1.min_vertex().y.max(box2.min_vertex().y); - // let inter_max_x = box1.maxs.x.min(box2.max_vertex().x); - // let inter_max_y = box1.maxs.y.min(box2.max_vertex().y); - let inter_min = box1.min_vertex().inf(&box2.min_vertex()); - let inter_max = box1.max_vertex().sup(&box2.max_vertex()); + let x_left = box1.min_vertex().x.max(box2.min_vertex().x); + let y_top = box1.min_vertex().y.max(box2.min_vertex().y); + let x_right = box1.max_vertex().x.min(box2.max_vertex().x); + let y_bottom = box1.max_vertex().y.min(box2.max_vertex().y); - // let inter_width = (inter_max_x - inter_min_x).max(T::zero()); - // let inter_height = (inter_max_y - inter_min_y).max(T::zero()); - // let inter_width = (inter_max.x - inter_min.x).max(T::zero()); - // let inter_height = (inter_max.y - inter_min.y).max(T::zero()); - // let inter_area = inter_width * inter_height; + let zero = T::zero(); + let inter_width = (x_right - x_left).max(zero); + let inter_height = (y_bottom - y_top).max(zero); + let intersection = inter_width * inter_height; - let inter_area = Aabb2::new(inter_min, inter_max); - let inter_area_2 = box1.intersection(box2); - let union = box1.area() + box2.area() - inter_area.area(); - assert_eq!(Some(inter_area), inter_area_2); - assert_eq!(box1.union(&box2), union); + let area1 = box1.area(); + let area2 = box2.area(); - let inter_area = inter_area.area(); + let union = area1 + area2 - intersection; - inter_area / (box1.area() + box2.area() - inter_area) + if union > zero { + intersection / union + } else { + zero + } } diff --git a/src/facedet/retinaface.rs b/src/facedet/retinaface.rs index 2f4ea17..20911b0 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/retinaface.rs @@ -77,6 +77,7 @@ impl Default for FaceDetectionConfig { } } +#[derive(Debug)] pub struct FaceDetection { handle: mnn_sync::SessionHandle, } diff --git a/src/faceembed.rs b/src/faceembed.rs new file mode 100644 index 0000000..e90e6fb --- /dev/null +++ b/src/faceembed.rs @@ -0,0 +1 @@ +pub mod facenet; diff --git a/src/faceembed/facenet.rs b/src/faceembed/facenet.rs new file mode 100644 index 0000000..441a9c3 --- /dev/null +++ b/src/faceembed/facenet.rs @@ -0,0 +1,39 @@ +use crate::errors::*; +use ndarray::{Array1, ArrayView3}; +use std::path::Path; + +#[derive(Debug)] +pub struct EmbeddingGenerator { + handle: mnn_sync::SessionHandle, +} + +impl EmbeddingGenerator { + pub fn new(path: impl AsRef) -> Result { + let model = std::fs::read(path) + .change_context(Error) + .attach_printable("Failed to read model file")?; + Self::new_from_bytes(&model) + } + + pub fn new_from_bytes(model: &[u8]) -> Result { + tracing::info!("Loading face embedding model from bytes"); + let mut model = mnn::Interpreter::from_bytes(model) + .map_err(|e| e.into_inner()) + .change_context(Error) + .attach_printable("Failed to load model from bytes")?; + model.set_session_mode(mnn::SessionMode::Release); + let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High); + let sc = mnn::ScheduleConfig::new() + .with_type(mnn::ForwardType::CPU) + .with_backend_config(bc); + tracing::info!("Creating session handle for face embedding model"); + let handle = mnn_sync::SessionHandle::new(model, sc) + .change_context(Error) + .attach_printable("Failed to create session handle")?; + Ok(Self { handle }) + } + + pub fn embedding(&self, roi: ArrayView3) -> Result> { + todo!() + } +} diff --git a/src/lib.rs b/src/lib.rs index ced92a7..965eab5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod errors; pub mod facedet; +pub mod faceembed; pub mod image; use errors::*;