feat: Added a manual implementation of nms
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -255,6 +255,7 @@ dependencies = [
|
|||||||
"nalgebra",
|
"nalgebra",
|
||||||
"ndarray 0.16.1",
|
"ndarray 0.16.1",
|
||||||
"num",
|
"num",
|
||||||
|
"ordered-float",
|
||||||
"simba",
|
"simba",
|
||||||
"thiserror 2.0.12",
|
"thiserror 2.0.12",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ itertools = "0.14.0"
|
|||||||
nalgebra = "0.33.2"
|
nalgebra = "0.33.2"
|
||||||
ndarray = { version = "0.16.1", optional = true }
|
ndarray = { version = "0.16.1", optional = true }
|
||||||
num = "0.4.3"
|
num = "0.4.3"
|
||||||
|
ordered-float = "5.0.0"
|
||||||
simba = "0.9.0"
|
simba = "0.9.0"
|
||||||
thiserror = "2.0.12"
|
thiserror = "2.0.12"
|
||||||
|
|
||||||
|
|||||||
@@ -51,10 +51,12 @@ pub type Aabb3<T> = AxisAlignedBoundingBox<T, 3>;
|
|||||||
impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||||
// Panics if max < min
|
// Panics if max < min
|
||||||
pub fn new(min_point: Point<T, D>, max_point: Point<T, D>) -> Self {
|
pub fn new(min_point: Point<T, D>, max_point: Point<T, D>) -> Self {
|
||||||
if max_point < min_point {
|
if max_point >= min_point {
|
||||||
|
Self::from_min_max_vertices(min_point, max_point)
|
||||||
|
} else {
|
||||||
|
dbg!(max_point, min_point);
|
||||||
panic!("max_point must be greater than or equal to min_point");
|
panic!("max_point must be greater than or equal to min_point");
|
||||||
}
|
}
|
||||||
Self::from_min_max_vertices(min_point, max_point)
|
|
||||||
}
|
}
|
||||||
pub fn try_new(min_point: Point<T, D>, max_point: Point<T, D>) -> Option<Self> {
|
pub fn try_new(min_point: Point<T, D>, max_point: Point<T, D>) -> Option<Self> {
|
||||||
if max_point < min_point {
|
if max_point < min_point {
|
||||||
@@ -66,9 +68,9 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
|||||||
Self { point, size }
|
Self { point, size }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_min_max_vertices(point1: Point<T, D>, point2: Point<T, D>) -> Self {
|
pub fn from_min_max_vertices(min: Point<T, D>, max: Point<T, D>) -> Self {
|
||||||
let size = point2 - point1;
|
let size = max - min;
|
||||||
Self::new_point_size(point1, SVector::from(size))
|
Self::new_point_size(min, SVector::from(size))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Only considers the points closest and furthest from origin
|
/// Only considers the points closest and furthest from origin
|
||||||
@@ -301,11 +303,11 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
|||||||
|
|
||||||
let inter_min = lhs_min.sup(&rhs_min);
|
let inter_min = lhs_min.sup(&rhs_min);
|
||||||
let inter_max = lhs_max.inf(&rhs_max);
|
let inter_max = lhs_max.inf(&rhs_max);
|
||||||
if inter_max < inter_min {
|
if inter_max >= inter_min {
|
||||||
return T::zero();
|
|
||||||
} else {
|
|
||||||
let intersection = Aabb::new(inter_min, inter_max).measure();
|
let intersection = Aabb::new(inter_min, inter_max).measure();
|
||||||
intersection / (self.measure() + other.measure() - intersection)
|
intersection / (self.measure() + other.measure() - intersection)
|
||||||
|
} else {
|
||||||
|
return T::zero();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -605,11 +607,8 @@ mod boudning_box_tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_specific_values() {
|
fn test_specific_values() {
|
||||||
let res = Vector2::new(1920, 1080).cast();
|
let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264);
|
||||||
let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264).denormalize(res);
|
let box2 = Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436);
|
||||||
let box2 =
|
assert!(box1.iou(&box2) >= 0.0);
|
||||||
Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436).denormalize(res);
|
|
||||||
dbg!(box1, box2);
|
|
||||||
assert!(box1.iou(&box2) > 0.0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::{HashSet, VecDeque};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::*;
|
use crate::*;
|
||||||
/// Apply Non-Maximum Suppression to a set of bounding boxes.
|
/// Apply Non-Maximum Suppression to a set of bounding boxes.
|
||||||
@@ -21,7 +23,8 @@ pub fn nms<T>(
|
|||||||
) -> HashSet<usize>
|
) -> HashSet<usize>
|
||||||
where
|
where
|
||||||
T: Num
|
T: Num
|
||||||
+ num::Float
|
+ ordered_float::FloatCore
|
||||||
|
+ core::ops::Neg<Output = T>
|
||||||
+ core::iter::Product<T>
|
+ core::iter::Product<T>
|
||||||
+ core::ops::AddAssign
|
+ core::ops::AddAssign
|
||||||
+ core::ops::SubAssign
|
+ core::ops::SubAssign
|
||||||
@@ -29,71 +32,31 @@ where
|
|||||||
+ nalgebra::SimdValue
|
+ nalgebra::SimdValue
|
||||||
+ nalgebra::SimdPartialOrd,
|
+ nalgebra::SimdPartialOrd,
|
||||||
{
|
{
|
||||||
let mut indices: Vec<usize> = (0..boxes.len())
|
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
|
||||||
.filter(|&i| scores[i] >= score_threshold)
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.zip(scores)
|
||||||
|
.filter_map(|((idx, bbox), score)| {
|
||||||
|
(*score > score_threshold).then_some((idx, *bbox, *score, true))
|
||||||
|
})
|
||||||
|
.sorted_by_cached_key(|(_, _, score, _)| -ordered_float::OrderedFloat(*score))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).unwrap());
|
for i in 0..combined.len() {
|
||||||
|
let first = combined[i];
|
||||||
let mut selected_indices = HashSet::new();
|
if first.3 == false {
|
||||||
|
continue;
|
||||||
while let Some(¤t) = indices.first() {
|
}
|
||||||
selected_indices.insert(current);
|
let bbox = first.1;
|
||||||
indices.remove(0);
|
for item in combined.iter_mut().skip(i + 1) {
|
||||||
|
if bbox.iou(&item.1) > nms_threshold {
|
||||||
indices.retain(|&i| {
|
item.3 = false
|
||||||
let iou = calculate_iou(&boxes[current], &boxes[i]);
|
}
|
||||||
let iou_ = boxes[current].iou(&boxes[i]);
|
}
|
||||||
if iou != iou_ {
|
|
||||||
dbg!(boxes[current], boxes[i]);
|
|
||||||
panic!()
|
|
||||||
};
|
|
||||||
iou < nms_threshold
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
selected_indices
|
combined
|
||||||
}
|
.into_iter()
|
||||||
|
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
|
||||||
/// Calculate the Intersection over Union (IoU) of two bounding boxes.
|
.collect()
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `box1` - The first bounding box.
|
|
||||||
/// * `box2` - The second bounding box.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// The IoU as a value between 0 and 1.
|
|
||||||
fn calculate_iou<T>(box1: &Aabb2<T>, box2: &Aabb2<T>) -> T
|
|
||||||
where
|
|
||||||
T: Num
|
|
||||||
+ num::Float
|
|
||||||
+ core::iter::Product<T>
|
|
||||||
+ core::ops::AddAssign
|
|
||||||
+ core::ops::SubAssign
|
|
||||||
+ core::ops::MulAssign
|
|
||||||
+ nalgebra::SimdValue
|
|
||||||
+ nalgebra::SimdPartialOrd,
|
|
||||||
{
|
|
||||||
let x_left = box1.min_vertex().x.max(box2.min_vertex().x);
|
|
||||||
let y_top = box1.min_vertex().y.max(box2.min_vertex().y);
|
|
||||||
let x_right = box1.max_vertex().x.min(box2.max_vertex().x);
|
|
||||||
let y_bottom = box1.max_vertex().y.min(box2.max_vertex().y);
|
|
||||||
|
|
||||||
let zero = T::zero();
|
|
||||||
let inter_width = (x_right - x_left).max(zero);
|
|
||||||
let inter_height = (y_bottom - y_top).max(zero);
|
|
||||||
let intersection = inter_width * inter_height;
|
|
||||||
|
|
||||||
let area1 = box1.area();
|
|
||||||
let area2 = box2.area();
|
|
||||||
|
|
||||||
let union = area1 + area2 - intersection;
|
|
||||||
|
|
||||||
if union > zero {
|
|
||||||
intersection / union
|
|
||||||
} else {
|
|
||||||
zero
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ pub fn main() -> Result<()> {
|
|||||||
let output = model
|
let output = model
|
||||||
.detect_faces(
|
.detect_faces(
|
||||||
array.clone(),
|
array.clone(),
|
||||||
FaceDetectionConfig::default().with_threshold(detect.threshold),
|
FaceDetectionConfig::default()
|
||||||
|
.with_threshold(detect.threshold)
|
||||||
|
.with_nms_threshold(detect.nms_threshold),
|
||||||
)
|
)
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to detect faces")?;
|
.attach_printable("Failed to detect faces")?;
|
||||||
|
|||||||
Reference in New Issue
Block a user