feat: Make nms return result if scores.len() != boxes.len()
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
|
||||
use itertools::Itertools;
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
|
||||
pub enum NmsError {
|
||||
#[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")]
|
||||
BoxesAndScoresLengthMismatch { boxes: usize, scores: usize },
|
||||
}
|
||||
|
||||
use crate::*;
|
||||
/// Apply Non-Maximum Suppression to a set of bounding boxes.
|
||||
@@ -20,7 +25,7 @@ pub fn nms<T>(
|
||||
scores: &[T],
|
||||
score_threshold: T,
|
||||
nms_threshold: T,
|
||||
) -> HashSet<usize>
|
||||
) -> Result<HashSet<usize>, NmsError>
|
||||
where
|
||||
T: Num
|
||||
+ ordered_float::FloatCore
|
||||
@@ -32,6 +37,12 @@ where
|
||||
+ nalgebra::SimdValue
|
||||
+ nalgebra::SimdPartialOrd,
|
||||
{
|
||||
if boxes.len() != scores.len() {
|
||||
return Err(NmsError::BoxesAndScoresLengthMismatch {
|
||||
boxes: boxes.len(),
|
||||
scores: scores.len(),
|
||||
});
|
||||
}
|
||||
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -55,8 +66,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
combined
|
||||
Ok(combined
|
||||
.into_iter()
|
||||
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
|
||||
.collect()
|
||||
.collect())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user