From e60921b09990651889d756e758b8617da8644a65 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Thu, 7 Aug 2025 15:50:48 +0530 Subject: [PATCH] feat: Make nms return result if scores.len() != boxes.len() --- bounding-box/src/nms.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/bounding-box/src/nms.rs b/bounding-box/src/nms.rs index 18838fe..bc3a41b 100644 --- a/bounding-box/src/nms.rs +++ b/bounding-box/src/nms.rs @@ -1,6 +1,11 @@ use std::collections::{HashSet, VecDeque}; use itertools::Itertools; +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum NmsError { + #[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")] + BoxesAndScoresLengthMismatch { boxes: usize, scores: usize }, +} use crate::*; /// Apply Non-Maximum Suppression to a set of bounding boxes. @@ -20,7 +25,7 @@ pub fn nms( scores: &[T], score_threshold: T, nms_threshold: T, -) -> HashSet +) -> Result, NmsError> where T: Num + ordered_float::FloatCore @@ -32,6 +37,12 @@ where + nalgebra::SimdValue + nalgebra::SimdPartialOrd, { + if boxes.len() != scores.len() { + return Err(NmsError::BoxesAndScoresLengthMismatch { + boxes: boxes.len(), + scores: scores.len(), + }); + } let mut combined: VecDeque<(usize, Aabb2, T, bool)> = boxes .iter() .enumerate() @@ -55,8 +66,8 @@ where } } - combined + Ok(combined .into_iter() .filter_map(|(idx, _, _, keep)| keep.then_some(idx)) - .collect() + .collect()) }