feat: Added threshold for scores and nms
This commit is contained in:
@@ -2,7 +2,6 @@ pub mod draw;
|
|||||||
pub mod nms;
|
pub mod nms;
|
||||||
pub mod roi;
|
pub mod roi;
|
||||||
|
|
||||||
use itertools::Itertools;
|
|
||||||
use nalgebra::{Point, Point2, Point3, SVector};
|
use nalgebra::{Point, Point2, Point3, SVector};
|
||||||
pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {}
|
pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {}
|
||||||
impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {}
|
impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {}
|
||||||
@@ -458,6 +457,7 @@ fn test_bounding_box_contains_point() {
|
|||||||
let point1 = Point2::new(2, 3);
|
let point1 = Point2::new(2, 3);
|
||||||
let point2 = Point2::new(5, 4);
|
let point2 = Point2::new(5, 4);
|
||||||
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||||
|
use itertools::Itertools;
|
||||||
for (i, j) in (0..=10).cartesian_product(0..=10) {
|
for (i, j) in (0..=10).cartesian_product(0..=10) {
|
||||||
if bbox.contains_point(&Point2::new(i, j)) {
|
if bbox.contains_point(&Point2::new(i, j)) {
|
||||||
if !(2..=5).contains(&i) && !(3..=4).contains(&j) {
|
if !(2..=5).contains(&i) && !(3..=4).contains(&j) {
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ pub struct Detect {
|
|||||||
pub model_type: Models,
|
pub model_type: Models,
|
||||||
#[clap(short, long)]
|
#[clap(short, long)]
|
||||||
pub output: Option<PathBuf>,
|
pub output: Option<PathBuf>,
|
||||||
|
#[clap(short, long, default_value_t = 0.8)]
|
||||||
|
pub threshold: f32,
|
||||||
pub image: PathBuf,
|
pub image: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,31 @@ pub struct FaceDetectionConfig {
|
|||||||
min_sizes: Vec<Vector2<usize>>,
|
min_sizes: Vec<Vector2<usize>>,
|
||||||
steps: Vec<usize>,
|
steps: Vec<usize>,
|
||||||
variance: Vec<f32>,
|
variance: Vec<f32>,
|
||||||
|
threshold: f32,
|
||||||
|
nms_threshold: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetectionConfig {
|
||||||
|
pub fn with_min_sizes(mut self, min_sizes: Vec<Vector2<usize>>) -> Self {
|
||||||
|
self.min_sizes = min_sizes;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
||||||
|
self.steps = steps;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
pub fn with_variance(mut self, variance: Vec<f32>) -> Self {
|
||||||
|
self.variance = variance;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
pub fn with_threshold(mut self, threshold: f32) -> Self {
|
||||||
|
self.threshold = threshold;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
|
||||||
|
self.nms_threshold = nms_threshold;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for FaceDetectionConfig {
|
impl Default for FaceDetectionConfig {
|
||||||
@@ -22,6 +47,8 @@ impl Default for FaceDetectionConfig {
|
|||||||
],
|
],
|
||||||
steps: vec![8, 16, 32],
|
steps: vec![8, 16, 32],
|
||||||
variance: vec![0.1, 0.2],
|
variance: vec![0.1, 0.2],
|
||||||
|
threshold: 0.8,
|
||||||
|
nms_threshold: 0.6,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -35,8 +62,13 @@ pub struct FaceDetectionModelOutput {
|
|||||||
pub landmark: ndarray::Array3<f32>,
|
pub landmark: ndarray::Array3<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct FaceDetectionProcessedOutput {
|
||||||
|
pub bbox: Vec<Aabb2<f32>>,
|
||||||
|
pub confidence: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
impl FaceDetectionModelOutput {
|
impl FaceDetectionModelOutput {
|
||||||
pub fn postprocess(self, config: FaceDetectionConfig) -> Result<Vec<Aabb2<f32>>> {
|
pub fn postprocess(self, config: FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
||||||
let mut anchors = Vec::new();
|
let mut anchors = Vec::new();
|
||||||
for (k, &step) in config.steps.iter().enumerate() {
|
for (k, &step) in config.steps.iter().enumerate() {
|
||||||
let feature_size = 640 / step;
|
let feature_size = 640 / step;
|
||||||
@@ -54,6 +86,7 @@ impl FaceDetectionModelOutput {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
let mut boxes = Vec::new();
|
let mut boxes = Vec::new();
|
||||||
|
let mut scores = Vec::new();
|
||||||
let var0 = config.variance[0];
|
let var0 = config.variance[0];
|
||||||
let var1 = config.variance[1];
|
let var1 = config.variance[1];
|
||||||
let bbox_data = self.bbox;
|
let bbox_data = self.bbox;
|
||||||
@@ -74,14 +107,15 @@ impl FaceDetectionModelOutput {
|
|||||||
let x_max = pred_cx + pred_w / 2.0;
|
let x_max = pred_cx + pred_w / 2.0;
|
||||||
let y_max = pred_cy + pred_h / 2.0;
|
let y_max = pred_cy + pred_h / 2.0;
|
||||||
let score = conf_data[[0, idx, 1]];
|
let score = conf_data[[0, idx, 1]];
|
||||||
if score > 0.6 {
|
if score > config.threshold {
|
||||||
boxes.push(Aabb2::from_min_max_vertices(
|
boxes.push(Aabb2::from_x1y1x2y2(x_min, y_min, x_max, y_max));
|
||||||
Point2::new(x_min, y_min),
|
scores.push(score);
|
||||||
Point2::new(x_max, y_max),
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(boxes)
|
Ok(FaceDetectionProcessedOutput {
|
||||||
|
bbox: boxes,
|
||||||
|
confidence: scores,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
mod cli;
|
mod cli;
|
||||||
mod errors;
|
mod errors;
|
||||||
|
use detector::facedet::retinaface::FaceDetectionConfig;
|
||||||
use errors::*;
|
use errors::*;
|
||||||
use ndarray_image::*;
|
use ndarray_image::*;
|
||||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||||
@@ -29,11 +30,11 @@ pub fn main() -> Result<()> {
|
|||||||
.attach_printable("Failed to detect faces")?;
|
.attach_printable("Failed to detect faces")?;
|
||||||
// output.print(20);
|
// output.print(20);
|
||||||
let aabbs = output
|
let aabbs = output
|
||||||
.postprocess(Default::default())
|
.postprocess(FaceDetectionConfig::default().with_threshold(detect.threshold))
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to attach context")?;
|
.attach_printable("Failed to attach context")?;
|
||||||
for bbox in aabbs {
|
for bbox in aabbs {
|
||||||
println!("Detected face: {:?}", bbox);
|
tracing::info!("Detected face: {:?}", bbox);
|
||||||
use bounding_box::draw::*;
|
use bounding_box::draw::*;
|
||||||
let bbox = bbox
|
let bbox = bbox
|
||||||
.denormalize(nalgebra::SVector::<f32, 2>::new(
|
.denormalize(nalgebra::SVector::<f32, 2>::new(
|
||||||
|
|||||||
Reference in New Issue
Block a user