feat: Added nms
This commit is contained in:
13
Cargo.lock
generated
13
Cargo.lock
generated
@@ -211,7 +211,7 @@ dependencies = [
|
|||||||
"bitflags 2.9.1",
|
"bitflags 2.9.1",
|
||||||
"cexpr",
|
"cexpr",
|
||||||
"clang-sys",
|
"clang-sys",
|
||||||
"itertools 0.12.1",
|
"itertools 0.13.0",
|
||||||
"log",
|
"log",
|
||||||
"prettyplease",
|
"prettyplease",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
@@ -504,6 +504,7 @@ dependencies = [
|
|||||||
"error-stack",
|
"error-stack",
|
||||||
"fast_image_resize",
|
"fast_image_resize",
|
||||||
"image",
|
"image",
|
||||||
|
"itertools 0.14.0",
|
||||||
"linfa",
|
"linfa",
|
||||||
"mnn",
|
"mnn",
|
||||||
"mnn-bridge",
|
"mnn-bridge",
|
||||||
@@ -512,6 +513,7 @@ dependencies = [
|
|||||||
"ndarray 0.16.1",
|
"ndarray 0.16.1",
|
||||||
"ndarray-image",
|
"ndarray-image",
|
||||||
"ndarray-resize",
|
"ndarray-resize",
|
||||||
|
"ordered-float",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
"tap",
|
"tap",
|
||||||
"thiserror 2.0.12",
|
"thiserror 2.0.12",
|
||||||
@@ -1487,6 +1489,15 @@ version = "0.1.11"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
|
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ordered-float"
|
||||||
|
version = "5.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "overload"
|
name = "overload"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ mnn-bridge = { workspace = true }
|
|||||||
mnn-sync = { workspace = true }
|
mnn-sync = { workspace = true }
|
||||||
bounding-box = { version = "0.1.0", path = "bounding-box" }
|
bounding-box = { version = "0.1.0", path = "bounding-box" }
|
||||||
color = "0.3.1"
|
color = "0.3.1"
|
||||||
|
itertools = "0.14.0"
|
||||||
|
ordered-float = "5.0.0"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = true
|
debug = true
|
||||||
|
|||||||
@@ -57,12 +57,17 @@ impl Drawable<Array3<u8>> for Aabb2<usize> {
|
|||||||
lines.into_iter().for_each(|line| {
|
lines.into_iter().for_each(|line| {
|
||||||
canvas
|
canvas
|
||||||
.roi_mut(line)
|
.roi_mut(line)
|
||||||
.expect("Failed to get Roi")
|
.map(|mut line| {
|
||||||
.lanes_mut(ndarray::Axis(2))
|
line.lanes_mut(ndarray::Axis(2))
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.for_each(|mut pixel| {
|
.for_each(|mut pixel| {
|
||||||
pixel.assign(&color);
|
pixel.assign(&color);
|
||||||
});
|
})
|
||||||
|
})
|
||||||
|
.inspect_err(|e| {
|
||||||
|
dbg!(e);
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use bounding_box::Aabb2;
|
use bounding_box::{Aabb2, nms::nms};
|
||||||
use error_stack::ResultExt;
|
use error_stack::ResultExt;
|
||||||
use mnn_bridge::ndarray::*;
|
use mnn_bridge::ndarray::*;
|
||||||
use nalgebra::{Point2, Vector2};
|
use nalgebra::{Point2, Vector2};
|
||||||
@@ -56,6 +56,7 @@ pub struct FaceDetection {
|
|||||||
handle: mnn_sync::SessionHandle,
|
handle: mnn_sync::SessionHandle,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct FaceDetectionModelOutput {
|
pub struct FaceDetectionModelOutput {
|
||||||
pub bbox: ndarray::Array3<f32>,
|
pub bbox: ndarray::Array3<f32>,
|
||||||
pub confidence: ndarray::Array3<f32>,
|
pub confidence: ndarray::Array3<f32>,
|
||||||
@@ -63,6 +64,7 @@ pub struct FaceDetectionModelOutput {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Represents the 5 facial landmarks detected by RetinaFace
|
/// Represents the 5 facial landmarks detected by RetinaFace
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||||
pub struct FaceLandmarks {
|
pub struct FaceLandmarks {
|
||||||
pub left_eye: Point2<f32>,
|
pub left_eye: Point2<f32>,
|
||||||
pub right_eye: Point2<f32>,
|
pub right_eye: Point2<f32>,
|
||||||
@@ -70,14 +72,23 @@ pub struct FaceLandmarks {
|
|||||||
pub left_mouth: Point2<f32>,
|
pub left_mouth: Point2<f32>,
|
||||||
pub right_mouth: Point2<f32>,
|
pub right_mouth: Point2<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct FaceDetectionProcessedOutput {
|
pub struct FaceDetectionProcessedOutput {
|
||||||
pub bbox: Vec<Aabb2<f32>>,
|
pub bbox: Vec<Aabb2<f32>>,
|
||||||
pub confidence: Vec<f32>,
|
pub confidence: Vec<f32>,
|
||||||
pub landmarks: Vec<FaceLandmarks>,
|
pub landmarks: Vec<FaceLandmarks>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct FaceDetectionOutput {
|
||||||
|
pub bbox: Vec<Aabb2<usize>>,
|
||||||
|
pub confidence: Vec<f32>,
|
||||||
|
pub landmark: Vec<FaceLandmarks>,
|
||||||
|
}
|
||||||
|
|
||||||
impl FaceDetectionModelOutput {
|
impl FaceDetectionModelOutput {
|
||||||
pub fn postprocess(self, config: FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
||||||
let mut anchors = Vec::new();
|
let mut anchors = Vec::new();
|
||||||
for (k, &step) in config.steps.iter().enumerate() {
|
for (k, &step) in config.steps.iter().enumerate() {
|
||||||
let feature_size = 640 / step;
|
let feature_size = 640 / step;
|
||||||
@@ -204,7 +215,48 @@ impl FaceDetection {
|
|||||||
Ok(FaceDetection { handle })
|
Ok(FaceDetection { handle })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn detect_faces(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
|
pub fn detect_faces(
|
||||||
|
&self,
|
||||||
|
image: ndarray::Array3<u8>,
|
||||||
|
config: FaceDetectionConfig,
|
||||||
|
) -> Result<FaceDetectionOutput> {
|
||||||
|
let (height, width, channels) = image.dim();
|
||||||
|
let output = self
|
||||||
|
.run_models(image)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to detect faces")?;
|
||||||
|
// denormalize the bounding boxes
|
||||||
|
let factor = Vector2::new(width as f32, height as f32);
|
||||||
|
let mut processed = output
|
||||||
|
.postprocess(&config)
|
||||||
|
.attach_printable("Failed to postprocess")?;
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
|
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
|
||||||
|
.bbox
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.zip(processed.confidence.iter().cloned())
|
||||||
|
.zip(processed.landmarks.iter().cloned())
|
||||||
|
.sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score))
|
||||||
|
.map(|((b, s), l)| (b, s, l))
|
||||||
|
.multiunzip();
|
||||||
|
|
||||||
|
let boxes = nms(&boxes, &scores, config.threshold, config.nms_threshold);
|
||||||
|
|
||||||
|
let bboxes = boxes
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|x| x.denormalize(factor).try_cast::<usize>())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(FaceDetectionOutput {
|
||||||
|
bbox: bboxes,
|
||||||
|
confidence: processed.confidence,
|
||||||
|
landmark: processed.landmarks,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run_models(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
use ::tap::*;
|
use ::tap::*;
|
||||||
let output = self
|
let output = self
|
||||||
|
|||||||
20
src/main.rs
20
src/main.rs
@@ -25,25 +25,15 @@ pub fn main() -> Result<()> {
|
|||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to convert image to ndarray")?;
|
.attach_printable("Failed to convert image to ndarray")?;
|
||||||
let output = model
|
let output = model
|
||||||
.detect_faces(array.clone())
|
.detect_faces(
|
||||||
|
array.clone(),
|
||||||
|
FaceDetectionConfig::default().with_threshold(detect.threshold),
|
||||||
|
)
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to detect faces")?;
|
.attach_printable("Failed to detect faces")?;
|
||||||
// output.print(20);
|
for bbox in output.bbox {
|
||||||
let faces = output
|
|
||||||
.postprocess(FaceDetectionConfig::default().with_threshold(detect.threshold))
|
|
||||||
.change_context(errors::Error)
|
|
||||||
.attach_printable("Failed to attach context")?;
|
|
||||||
for bbox in faces.bbox {
|
|
||||||
tracing::info!("Detected face: {:?}", bbox);
|
tracing::info!("Detected face: {:?}", bbox);
|
||||||
use bounding_box::draw::*;
|
use bounding_box::draw::*;
|
||||||
let bbox = bbox
|
|
||||||
.denormalize(nalgebra::SVector::<f32, 2>::new(
|
|
||||||
array.shape()[1] as f32,
|
|
||||||
array.shape()[0] as f32,
|
|
||||||
))
|
|
||||||
.try_cast()
|
|
||||||
.ok_or(errors::Error)
|
|
||||||
.attach_printable("Failed to cast f32 to usize")?;
|
|
||||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 10);
|
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 10);
|
||||||
}
|
}
|
||||||
let v = array.view();
|
let v = array.view();
|
||||||
|
|||||||
Reference in New Issue
Block a user