feat: Make nms return result if scores.len() != boxes.len()
This commit is contained in:
@@ -1,6 +1,11 @@
|
|||||||
use std::collections::{HashSet, VecDeque};
|
use std::collections::{HashSet, VecDeque};
|
||||||
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
|
||||||
|
pub enum NmsError {
|
||||||
|
#[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")]
|
||||||
|
BoxesAndScoresLengthMismatch { boxes: usize, scores: usize },
|
||||||
|
}
|
||||||
|
|
||||||
use crate::*;
|
use crate::*;
|
||||||
/// Apply Non-Maximum Suppression to a set of bounding boxes.
|
/// Apply Non-Maximum Suppression to a set of bounding boxes.
|
||||||
@@ -20,7 +25,7 @@ pub fn nms<T>(
|
|||||||
scores: &[T],
|
scores: &[T],
|
||||||
score_threshold: T,
|
score_threshold: T,
|
||||||
nms_threshold: T,
|
nms_threshold: T,
|
||||||
) -> HashSet<usize>
|
) -> Result<HashSet<usize>, NmsError>
|
||||||
where
|
where
|
||||||
T: Num
|
T: Num
|
||||||
+ ordered_float::FloatCore
|
+ ordered_float::FloatCore
|
||||||
@@ -32,6 +37,12 @@ where
|
|||||||
+ nalgebra::SimdValue
|
+ nalgebra::SimdValue
|
||||||
+ nalgebra::SimdPartialOrd,
|
+ nalgebra::SimdPartialOrd,
|
||||||
{
|
{
|
||||||
|
if boxes.len() != scores.len() {
|
||||||
|
return Err(NmsError::BoxesAndScoresLengthMismatch {
|
||||||
|
boxes: boxes.len(),
|
||||||
|
scores: scores.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
|
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
@@ -55,8 +66,8 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
combined
|
Ok(combined
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
|
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
|
||||||
.collect()
|
.collect())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user