feat: Make nms return result if scores.len() != boxes.len()

This commit is contained in:
uttarayan21
2025-08-07 15:50:48 +05:30
parent e91ae5b865
commit e60921b099

View File

@@ -1,6 +1,11 @@
use std::collections::{HashSet, VecDeque};
use itertools::Itertools;
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum NmsError {
#[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")]
BoxesAndScoresLengthMismatch { boxes: usize, scores: usize },
}
use crate::*;
/// Apply Non-Maximum Suppression to a set of bounding boxes.
@@ -20,7 +25,7 @@ pub fn nms<T>(
scores: &[T],
score_threshold: T,
nms_threshold: T,
) -> HashSet<usize>
) -> Result<HashSet<usize>, NmsError>
where
T: Num
+ ordered_float::FloatCore
@@ -32,6 +37,12 @@ where
+ nalgebra::SimdValue
+ nalgebra::SimdPartialOrd,
{
if boxes.len() != scores.len() {
return Err(NmsError::BoxesAndScoresLengthMismatch {
boxes: boxes.len(),
scores: scores.len(),
});
}
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
.iter()
.enumerate()
@@ -55,8 +66,8 @@ where
}
}
combined
Ok(combined
.into_iter()
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
.collect()
.collect())
}