Compare commits

...

10 Commits

Author SHA1 Message Date
uttarayan21
2d2309837f feat: Added stuff
Some checks failed
build / checks-matrix (push) Successful in 23m6s
build / codecov (push) Failing after 19m30s
docs / docs (push) Failing after 28m54s
build / checks-build (push) Has been cancelled
2025-08-13 18:08:03 +05:30
uttarayan21
f5740dc87f feat: Added .gitattributes and .gitignore 2025-08-08 15:19:59 +05:30
uttarayan21
3753e399b1 feat: Added models 2025-08-08 15:15:50 +05:30
uttarayan21
d52b69911f feat: Added facenet 2025-08-08 15:01:25 +05:30
uttarayan21
a3ea01b7b6 feat: Added facenet 2025-08-07 17:24:01 +05:30
uttarayan21
e60921b099 feat: Make nms return result if scores.len() != boxes.len() 2025-08-07 15:50:48 +05:30
uttarayan21
e91ae5b865 feat: Added a manual implementation of nms 2025-08-07 15:45:54 +05:30
uttarayan21
2c43f657aa backup: broken backup 2025-08-07 13:30:34 +05:30
uttarayan21
8d07b0846c feat: Working retinaface 2025-08-07 11:51:10 +05:30
uttarayan21
f7aae32caf broken: Remove the FaceDetectionConfig 2025-08-05 19:17:31 +05:30
25 changed files with 1875 additions and 609 deletions

4
.gitattributes vendored Normal file
View File

@@ -0,0 +1,4 @@
models/retinaface.mnn filter=lfs diff=lfs merge=lfs -text
models/facenet.mnn filter=lfs diff=lfs merge=lfs -text
models/facenet.onnx filter=lfs diff=lfs merge=lfs -text
models/retinaface.onnx filter=lfs diff=lfs merge=lfs -text

3
.gitignore vendored
View File

@@ -2,3 +2,6 @@
/target
.direnv
*.jpg
face_net.onnx
.DS_Store
*.cache

968
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -12,8 +12,8 @@ mnn = { path = "/Users/fs0c131y/Projects/aftershoot/mnn-rs" }
ndarray-image = { path = "ndarray-image" }
ndarray-resize = { path = "ndarray-resize" }
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
# "metal",
# "coreml",
"metal",
"coreml",
"tracing",
], branch = "restructure-tensor-type" }
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
@@ -22,6 +22,7 @@ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0",
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
"tracing",
], branch = "restructure-tensor-type" }
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
[package]
name = "detector"
@@ -35,8 +36,7 @@ clap_complete = "4.5"
error-stack = "0.5"
fast_image_resize = "5.2.0"
image = "0.25.6"
linfa = "0.7.1"
nalgebra = "0.33.2"
nalgebra = { workspace = true }
ndarray = "0.16.1"
ndarray-image = { workspace = true }
ndarray-resize = { workspace = true }
@@ -53,6 +53,7 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
color = "0.3.1"
itertools = "0.14.0"
ordered-float = "5.0.0"
ort = "2.0.0-rc.10"
[profile.release]
debug = true

27
Makefile.toml Normal file
View File

@@ -0,0 +1,27 @@
# cargo-make task: convert the FaceNet ONNX model to MNN format.
# Requires the `MNNConvert` binary on PATH; weights are quantized to fp16.
[tasks.convert_facenet]
command = "MNNConvert"
args = [
"-f",
"ONNX",
"--modelFile",
"models/facenet.onnx",
"--MNNModel",
"models/facenet.mnn",
"--fp16",
"--bizCode",
"MNN",
]
# cargo-make task: convert the RetinaFace ONNX model to MNN format (fp16).
[tasks.convert_retinaface]
command = "MNNConvert"
args = [
"-f",
"ONNX",
"--modelFile",
"models/retinaface.onnx",
"--MNNModel",
"models/retinaface.mnn",
"--fp16",
"--bizCode",
"MNN",
]

View File

@@ -6,9 +6,10 @@ edition = "2024"
[dependencies]
color = "0.3.1"
itertools = "0.14.0"
nalgebra = "0.33.2"
nalgebra = { workspace = true }
ndarray = { version = "0.16.1", optional = true }
num = "0.4.3"
ordered-float = "5.0.0"
simba = "0.9.0"
thiserror = "2.0.12"

View File

@@ -4,11 +4,11 @@ pub use color::Rgba8;
use ndarray::{Array1, Array3, ArrayViewMut3};
pub trait Draw<T> {
fn draw(&mut self, item: T, color: color::Rgba8, thickness: usize);
fn draw(&mut self, item: &T, color: color::Rgba8, thickness: usize);
}
impl Draw<Aabb2<usize>> for Array3<u8> {
fn draw(&mut self, item: Aabb2<usize>, color: color::Rgba8, thickness: usize) {
fn draw(&mut self, item: &Aabb2<usize>, color: color::Rgba8, thickness: usize) {
item.draw(self, color, thickness)
}
}

View File

@@ -2,9 +2,38 @@ pub mod draw;
pub mod nms;
pub mod roi;
use nalgebra::{Point, Point2, Point3, SVector, SimdPartialOrd, SimdValue};
pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {}
impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {}
use nalgebra::{Point, Point2, SVector, Vector2};
pub trait Num:
num::Num
+ core::ops::AddAssign
+ core::ops::SubAssign
+ core::ops::MulAssign
+ core::ops::DivAssign
+ core::cmp::PartialOrd
+ core::cmp::PartialEq
+ nalgebra::SimdPartialOrd
+ nalgebra::SimdValue
+ Copy
+ core::fmt::Debug
+ 'static
{
}
impl<
T: num::Num
+ core::ops::AddAssign
+ core::ops::SubAssign
+ core::ops::MulAssign
+ core::ops::DivAssign
+ core::cmp::PartialOrd
+ core::cmp::PartialEq
+ nalgebra::SimdPartialOrd
+ nalgebra::SimdValue
+ Copy
+ core::fmt::Debug
+ 'static,
> Num for T
{
}
/// An axis aligned bounding box in `D` dimensions, defined by the minimum vertex and a size vector.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
@@ -20,16 +49,28 @@ pub type Aabb2<T> = AxisAlignedBoundingBox<T, 2>;
pub type Aabb3<T> = AxisAlignedBoundingBox<T, 3>;
impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
pub fn new(point: Point<T, D>, size: SVector<T, D>) -> Self {
/// Construct a box from its minimum and maximum vertices.
///
/// # Panics
/// Panics if `max_point < min_point`. The comparison is the componentwise
/// partial order from `SimdPartialOrd`; pairs of points that are
/// incomparable (mixed ordering per component) are accepted here —
/// use [`Self::try_new`] if you need the fallible variant.
pub fn new(min_point: Point<T, D>, max_point: Point<T, D>) -> Self {
    if max_point >= min_point {
        Self::from_min_max_vertices(min_point, max_point)
    } else {
        // The offending vertices are reported in the panic message itself;
        // the previous `dbg!` leaked debug output and is removed.
        panic!(
            "max_point ({max_point:?}) must be greater than or equal to min_point ({min_point:?})"
        );
    }
}
/// Fallible counterpart of [`Self::new`]: returns `None` instead of
/// panicking when the vertices are out of order (`max_point < min_point`).
pub fn try_new(min_point: Point<T, D>, max_point: Point<T, D>) -> Option<Self> {
    if max_point < min_point {
        None
    } else {
        Some(Self::from_min_max_vertices(min_point, max_point))
    }
}
/// Construct a box directly from its minimum vertex and a size vector.
/// No validation is performed — a negative size component is the caller's
/// responsibility (see `new`/`try_new` for validated construction).
pub fn new_point_size(point: Point<T, D>, size: SVector<T, D>) -> Self {
Self { point, size }
}
pub fn from_min_max_vertices(point1: Point<T, D>, point2: Point<T, D>) -> Self
where
T: core::ops::SubAssign,
{
let size = point2 - point1;
Self::new(point1, SVector::from(size))
/// Build a box from two vertices, taking `max - min` as the size.
/// NOTE(review): no ordering check happens here; validated construction
/// goes through `new`/`try_new`.
pub fn from_min_max_vertices(min: Point<T, D>, max: Point<T, D>) -> Self {
let size = max - min;
Self::new_point_size(min, SVector::from(size))
}
/// Only considers the points closest and furthest from origin
@@ -151,7 +192,21 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
self.intersection(other)
}
pub fn union(&self, other: &Self) -> Self
/// Clamp every coordinate of the min vertex AND every component of the
/// size vector into the closed range `[min, max]`.
/// NOTE(review): clamping the *size* with the same bounds as the point is
/// unusual — this clamps extent, not the max vertex; confirm intentional.
pub fn component_clamp(&self, min: T, max: T) -> Self
where
T: PartialOrd,
{
// Work on a copy; `Self` is `Copy` so this is cheap.
let mut this = *self;
this.point.iter_mut().for_each(|x| {
*x = nalgebra::clamp(*x, min, max);
});
this.size.iter_mut().for_each(|x| {
*x = nalgebra::clamp(*x, min, max);
});
this
}
pub fn merge(&self, other: &Self) -> Self
where
T: core::ops::AddAssign,
T: core::ops::SubAssign,
@@ -159,13 +214,24 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
T: nalgebra::SimdValue,
T: nalgebra::SimdPartialOrd,
{
let self_min = self.min_vertex();
let self_max = self.max_vertex();
let other_min = other.min_vertex();
let other_max = other.max_vertex();
let min = self_min.inf(&other_min);
let max = self_max.sup(&other_max);
Self::from_min_max_vertices(min, max)
let min = self.min_vertex().inf(&other.min_vertex());
let max = self.min_vertex().sup(&other.max_vertex());
Self::new(min, max)
}
/// Measure (area in 2D, volume in 3D) of the union of the two boxes,
/// computed by inclusion–exclusion: |A| + |B| - |A ∩ B|.
/// Disjoint boxes contribute zero intersection. Note this returns a
/// scalar measure, NOT a bounding box — see `merge` for the box-valued
/// union hull.
pub fn union(&self, other: &Self) -> T
where
T: core::ops::AddAssign,
T: core::ops::SubAssign,
T: core::ops::MulAssign,
T: PartialOrd,
T: nalgebra::SimdValue,
T: nalgebra::SimdPartialOrd,
{
self.measure() + other.measure()
- Self::intersection(self, other)
.map(|x| x.measure())
.unwrap_or(T::zero())
}
pub fn intersection(&self, other: &Self) -> Option<Self>
@@ -176,21 +242,9 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
T: nalgebra::SimdPartialOrd,
T: nalgebra::SimdValue,
{
let self_min = self.min_vertex();
let self_max = self.max_vertex();
let other_min = other.min_vertex();
let other_max = other.max_vertex();
if self_max < other_min || other_max < self_min {
return None; // No intersection
}
let min = self_min.sup(&other_min);
let max = self_max.inf(&other_max);
Some(Self::from_min_max_vertices(
Point::from(min),
Point::from(max),
))
let inter_min = self.min_vertex().sup(&other.min_vertex());
let inter_max = self.max_vertex().inf(&other.max_vertex());
Self::try_new(inter_min, inter_max)
}
pub fn denormalize(&self, factor: nalgebra::SVector<T, D>) -> Self
@@ -233,7 +287,7 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
self.size.product()
}
pub fn iou(&self, other: &Self) -> Option<T>
pub fn iou(&self, other: &Self) -> T
where
T: core::ops::AddAssign,
T: core::ops::SubAssign,
@@ -242,9 +296,19 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
T: nalgebra::SimdValue,
T: core::ops::MulAssign,
{
let intersection = self.intersection(other)?;
let union = self.union(other);
Some(intersection.measure() / union.measure())
let lhs_min = self.min_vertex();
let lhs_max = self.max_vertex();
let rhs_min = other.min_vertex();
let rhs_max = other.max_vertex();
let inter_min = lhs_min.sup(&rhs_min);
let inter_max = lhs_max.inf(&rhs_max);
if inter_max >= inter_min {
let intersection = Aabb::new(inter_min, inter_max).measure();
intersection / (self.measure() + other.measure() - intersection)
} else {
return T::zero();
}
}
}
@@ -255,15 +319,15 @@ impl<T: Num> Aabb2<T> {
{
let point1 = Point2::new(x1, y1);
let point2 = Point2::new(x2, y2);
Self::from_min_max_vertices(point1, point2)
Self::new(point1, point2)
}
pub fn new_2d(point1: Point2<T>, point2: Point2<T>) -> Self
where
T: core::ops::SubAssign,
{
let size = point2.coords - point1.coords;
Self::new(point1, SVector::from(size))
pub fn from_xywh(x: T, y: T, w: T, h: T) -> Self {
let point = Point2::new(x, y);
let size = Vector2::new(w, h);
Self::new_point_size(point, size)
}
pub fn x1y1(&self) -> Point2<T> {
self.point
}
@@ -327,14 +391,6 @@ impl<T: Num> Aabb2<T> {
}
impl<T: Num> Aabb3<T> {
pub fn new_3d(point1: Point3<T>, point2: Point3<T>) -> Self
where
T: core::ops::SubAssign,
{
let size = point2.coords - point1.coords;
Self::new(point1, SVector::from(size))
}
pub fn volume(&self) -> T
where
T: core::ops::MulAssign,
@@ -343,156 +399,216 @@ impl<T: Num> Aabb3<T> {
}
}
#[test]
fn test_bbox_new() {
use nalgebra::{Point2, Vector2};
#[cfg(test)]
mod boudning_box_tests {
use super::*;
use nalgebra::*;
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
#[test]
fn test_bbox_new() {
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox = AxisAlignedBoundingBox::new(point1, point2);
assert_eq!(bbox.min_vertex(), point1);
assert_eq!(bbox.size(), Vector2::new(3.0, 4.0));
assert_eq!(bbox.center(), Point2::new(2.5, 4.0));
}
assert_eq!(bbox.min_vertex(), point1);
assert_eq!(bbox.size(), Vector2::new(3.0, 4.0));
assert_eq!(bbox.center(), Point2::new(2.5, 4.0));
}
#[test]
fn test_bounding_box_center_2d() {
use nalgebra::{Point2, Vector2};
#[test]
fn test_intersection_and_merge() {
let point1 = Point2::new(1, 5);
let point2 = Point2::new(3, 2);
let size1 = Vector2::new(3, 4);
let size2 = Vector2::new(1, 3);
let point = Point2::new(1.0, 2.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
let this = Aabb2::new_point_size(point1, size1);
let other = Aabb2::new_point_size(point2, size2);
let inter = this.intersection(&other);
let merged = this.merge(&other);
assert_ne!(inter, Some(merged))
}
assert_eq!(bbox.min_vertex(), point);
assert_eq!(bbox.size(), size);
assert_eq!(bbox.center(), Point2::new(2.5, 4.0));
}
#[test]
fn test_bounding_box_center_2d() {
let point = Point2::new(1.0, 2.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
#[test]
fn test_bounding_box_center_3d() {
use nalgebra::{Point3, Vector3};
assert_eq!(bbox.min_vertex(), point);
assert_eq!(bbox.size(), size);
assert_eq!(bbox.center(), Point2::new(2.5, 4.0));
}
let point = Point3::new(1.0, 2.0, 3.0);
let size = Vector3::new(4.0, 5.0, 6.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
#[test]
fn test_bounding_box_center_3d() {
let point = Point3::new(1.0, 2.0, 3.0);
let size = Vector3::new(4.0, 5.0, 6.0);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
assert_eq!(bbox.min_vertex(), point);
assert_eq!(bbox.size(), size);
assert_eq!(bbox.center(), Point3::new(3.0, 4.5, 6.0));
}
assert_eq!(bbox.min_vertex(), point);
assert_eq!(bbox.size(), size);
assert_eq!(bbox.center(), Point3::new(3.0, 4.5, 6.0));
}
#[test]
fn test_bounding_box_padding_2d() {
use nalgebra::{Point2, Vector2};
#[test]
fn test_bounding_box_padding_2d() {
let point = Point2::new(1.0, 2.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
let point = Point2::new(1.0, 2.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
let padded_bbox = bbox.padding(1.0);
assert_eq!(padded_bbox.min_vertex(), Point2::new(0.5, 1.5));
assert_eq!(padded_bbox.size(), Vector2::new(4.0, 5.0));
}
let padded_bbox = bbox.padding(1.0);
assert_eq!(padded_bbox.min_vertex(), Point2::new(0.5, 1.5));
assert_eq!(padded_bbox.size(), Vector2::new(4.0, 5.0));
}
#[test]
fn test_bounding_box_scaling_2d() {
let point = Point2::new(1.0, 1.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
#[test]
fn test_bounding_box_scaling_2d() {
use nalgebra::{Point2, Vector2};
let padded_bbox = bbox.scale(Vector2::new(2.0, 2.0));
assert_eq!(padded_bbox.min_vertex(), Point2::new(-2.0, -3.0));
assert_eq!(padded_bbox.size(), Vector2::new(6.0, 8.0));
}
let point = Point2::new(1.0, 1.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
#[test]
fn test_bounding_box_contains_2d() {
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox = AxisAlignedBoundingBox::new(point1, point2);
let padded_bbox = bbox.scale(Vector2::new(2.0, 2.0));
assert_eq!(padded_bbox.min_vertex(), Point2::new(-2.0, -3.0));
assert_eq!(padded_bbox.size(), Vector2::new(6.0, 8.0));
}
assert!(bbox.contains_point(&Point2::new(2.0, 3.0)));
assert!(!bbox.contains_point(&Point2::new(5.0, 7.0)));
}
#[test]
fn test_bounding_box_contains_2d() {
use nalgebra::Point2;
#[test]
fn test_bounding_box_union_2d() {
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
let point3 = Point2::new(3.0, 5.0);
let point4 = Point2::new(7.0, 8.0);
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
assert!(bbox.contains_point(&Point2::new(2.0, 3.0)));
assert!(!bbox.contains_point(&Point2::new(5.0, 7.0)));
}
let union_bbox = bbox1.merge(&bbox2);
assert_eq!(union_bbox.min_vertex(), Point2::new(1.0, 2.0));
assert_eq!(union_bbox.size(), Vector2::new(6.0, 6.0));
}
#[test]
fn test_bounding_box_union_2d() {
use nalgebra::{Point2, Vector2};
#[test]
fn test_bounding_box_intersection_2d() {
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
let point3 = Point2::new(3.0, 5.0);
let point4 = Point2::new(5.0, 7.0);
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
let point3 = Point2::new(3.0, 5.0);
let point4 = Point2::new(7.0, 8.0);
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
let intersection_bbox = bbox1.intersection(&bbox2).unwrap();
assert_eq!(intersection_bbox.min_vertex(), Point2::new(3.0, 5.0));
assert_eq!(intersection_bbox.size(), Vector2::new(1.0, 1.0));
}
let union_bbox = bbox1.union(&bbox2);
assert_eq!(union_bbox.min_vertex(), Point2::new(1.0, 2.0));
assert_eq!(union_bbox.size(), Vector2::new(6.0, 6.0));
}
#[test]
fn test_bounding_box_intersection_2d() {
use nalgebra::{Point2, Vector2};
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
let point3 = Point2::new(3.0, 5.0);
let point4 = Point2::new(5.0, 7.0);
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
let intersection_bbox = bbox1.intersection(&bbox2).unwrap();
assert_eq!(intersection_bbox.min_vertex(), Point2::new(3.0, 5.0));
assert_eq!(intersection_bbox.size(), Vector2::new(1.0, 1.0));
}
#[test]
fn test_bounding_box_contains_point() {
use nalgebra::Point2;
let point1 = Point2::new(2, 3);
let point2 = Point2::new(5, 4);
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
use itertools::Itertools;
for (i, j) in (0..=10).cartesian_product(0..=10) {
if bbox.contains_point(&Point2::new(i, j)) {
if !(2..=5).contains(&i) && !(3..=4).contains(&j) {
panic!(
"Point ({}, {}) should not be contained in the bounding box",
i, j
);
}
} else {
if (2..=5).contains(&i) && (3..=4).contains(&j) {
panic!(
"Point ({}, {}) should be contained in the bounding box",
i, j
);
#[test]
fn test_bounding_box_contains_point() {
let point1 = Point2::new(2, 3);
let point2 = Point2::new(5, 4);
let bbox = AxisAlignedBoundingBox::new(point1, point2);
use itertools::Itertools;
for (i, j) in (0..=10).cartesian_product(0..=10) {
if bbox.contains_point(&Point2::new(i, j)) {
if !(2..=5).contains(&i) && !(3..=4).contains(&j) {
panic!(
"Point ({}, {}) should not be contained in the bounding box",
i, j
);
}
} else {
if (2..=5).contains(&i) && (3..=4).contains(&j) {
panic!(
"Point ({}, {}) should be contained in the bounding box",
i, j
);
}
}
}
}
}
#[test]
fn test_bounding_box_clamp_box_2d() {
let bbox1 = Aabb2::from_x1y1x2y2(1, 1, 4, 4);
let bbox2 = Aabb2::from_x1y1x2y2(2, 2, 3, 3);
let clamped = bbox2.clamp(&bbox1).unwrap();
assert_eq!(bbox2, clamped);
let clamped = bbox1.clamp(&bbox2).unwrap();
assert_eq!(bbox2, clamped);
#[test]
fn test_bounding_box_clamp_box_2d() {
let bbox1 = Aabb2::from_x1y1x2y2(1, 1, 4, 4);
let bbox2 = Aabb2::from_x1y1x2y2(2, 2, 3, 3);
let clamped = bbox2.clamp(&bbox1).unwrap();
assert_eq!(bbox2, clamped);
let clamped = bbox1.clamp(&bbox2).unwrap();
assert_eq!(bbox2, clamped);
let bbox1 = Aabb2::from_x1y1x2y2(4, 5, 7, 8);
let bbox2 = Aabb2::from_x1y1x2y2(5, 4, 8, 7);
let clamped = bbox1.clamp(&bbox2).unwrap();
let expected = Aabb2::from_x1y1x2y2(5, 5, 7, 7);
assert_eq!(clamped, expected)
let bbox1 = Aabb2::from_x1y1x2y2(4, 5, 7, 8);
let bbox2 = Aabb2::from_x1y1x2y2(5, 4, 8, 7);
let clamped = bbox1.clamp(&bbox2).unwrap();
let expected = Aabb2::from_x1y1x2y2(5, 5, 7, 7);
assert_eq!(clamped, expected)
}
// IoU is 1 for identical boxes.
#[test]
fn test_iou_identical_boxes() {
let a = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
let b = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
assert_eq!(a.iou(&b), 1.0);
}
// Fully disjoint boxes have IoU 0.
#[test]
fn test_iou_non_overlapping_boxes() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
let b = Aabb2::from_x1y1x2y2(2.0, 2.0, 3.0, 3.0);
assert_eq!(a.iou(&b), 0.0);
}
#[test]
fn test_iou_partial_overlap() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 2.0, 2.0);
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
// Intersection area = 1, Union area = 7
assert!((a.iou(&b) - 1.0 / 7.0).abs() < 1e-6);
}
#[test]
fn test_iou_one_inside_another() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 4.0, 4.0);
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
// Intersection area = 4, Union area = 16
assert!((a.iou(&b) - 0.25).abs() < 1e-6);
}
// Boxes sharing only an edge intersect with zero area -> IoU 0.
#[test]
fn test_iou_edge_touching() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
let b = Aabb2::from_x1y1x2y2(1.0, 0.0, 2.0, 1.0);
assert_eq!(a.iou(&b), 0.0);
}
// Boxes sharing only a corner also have IoU 0.
#[test]
fn test_iou_corner_touching() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 2.0, 2.0);
assert_eq!(a.iou(&b), 0.0);
}
// A degenerate (zero-area) box never divides by zero: IoU is 0.
#[test]
fn test_iou_zero_area_box() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 0.0, 0.0);
let b = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
assert_eq!(a.iou(&b), 0.0);
}
// Regression test: normalized-coordinate boxes from a real detection run;
// IoU must be non-negative (guards against the sign bugs seen earlier).
#[test]
fn test_specific_values() {
let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264);
let box2 = Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436);
assert!(box1.iou(&box2) >= 0.0);
}
}

View File

@@ -1,4 +1,11 @@
use std::collections::HashSet;
use std::collections::{HashSet, VecDeque};
use itertools::Itertools;
/// Errors returned by the non-maximum-suppression entry point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum NmsError {
/// `boxes` and `scores` must be parallel slices of equal length.
#[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")]
BoxesAndScoresLengthMismatch { boxes: usize, scores: usize },
}
use crate::*;
/// Apply Non-Maximum Suppression to a set of bounding boxes.
@@ -18,10 +25,11 @@ pub fn nms<T>(
scores: &[T],
score_threshold: T,
nms_threshold: T,
) -> HashSet<usize>
) -> Result<HashSet<usize>, NmsError>
where
T: Num
+ num::Float
+ ordered_float::FloatCore
+ core::ops::Neg<Output = T>
+ core::iter::Product<T>
+ core::ops::AddAssign
+ core::ops::SubAssign
@@ -29,56 +37,37 @@ where
+ nalgebra::SimdValue
+ nalgebra::SimdPartialOrd,
{
use itertools::Itertools;
// Create vector of (index, box, score) tuples for boxes with scores above threshold
let mut indexed_boxes: Vec<(usize, &Aabb2<T>, &T)> = boxes
if boxes.len() != scores.len() {
return Err(NmsError::BoxesAndScoresLengthMismatch {
boxes: boxes.len(),
scores: scores.len(),
});
}
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
.iter()
.enumerate()
.zip(scores.iter())
.zip(scores)
.filter_map(|((idx, bbox), score)| {
if *score >= score_threshold {
Some((idx, bbox, score))
} else {
None
}
(*score > score_threshold).then_some((idx, *bbox, *score, true))
})
.sorted_by_cached_key(|(_, _, score, _)| -ordered_float::OrderedFloat(*score))
.collect();
// Sort by score in descending order
indexed_boxes.sort_by(|(_, _, score_a), (_, _, score_b)| {
score_b
.partial_cmp(score_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut keep_indices = HashSet::new();
let mut suppressed = HashSet::new();
for (i, (idx_i, bbox_i, _)) in indexed_boxes.iter().enumerate() {
// Skip if this box is already suppressed
if suppressed.contains(idx_i) {
for i in 0..combined.len() {
let first = combined[i];
if first.3 == false {
continue;
}
// Keep this box
keep_indices.insert(*idx_i);
// Compare with remaining boxes
for (idx_j, bbox_j, _) in indexed_boxes.iter().skip(i + 1) {
// Skip if this box is already suppressed
if suppressed.contains(idx_j) {
continue;
}
// Calculate IoU and suppress if above threshold
if let Some(iou) = bbox_i.iou(bbox_j) {
if iou >= nms_threshold {
suppressed.insert(*idx_j);
}
let bbox = first.1;
for item in combined.iter_mut().skip(i + 1) {
if bbox.iou(&item.1) > nms_threshold {
item.3 = false
}
}
}
keep_indices
Ok(combined
.into_iter()
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
.collect())
}

View File

@@ -5,10 +5,17 @@ pub trait Roi<'a, Output> {
type Error;
fn roi(&'a self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
}
pub trait RoiMut<'a, Output> {
type Error;
fn roi_mut(&'a mut self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
}
pub trait MultiRoi<'a, Output> {
type Error;
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Output, Self::Error>;
}
#[derive(thiserror::Error, Debug, Copy, Clone)]
pub enum RoiError {
#[error("Region of intereset is out of bounds")]
@@ -36,7 +43,7 @@ impl<'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for Array3<T> {
let x2 = aabb.x2();
let y1 = aabb.y1();
let y2 = aabb.y2();
if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
if x1 > x2 || y1 > y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
return Err(RoiError::RoiOutOfBounds);
}
Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..]))
@@ -95,3 +102,47 @@ pub fn reborrow_test() {
};
dbg!(y);
}
/// Extract one read-only view per bounding box from an owned image array
/// (layout assumed HWC: height, width, channels — matches `self.dim()` use).
impl<'a> MultiRoi<'a, Vec<ArrayView3<'a, u8>>> for Array3<u8> {
type Error = RoiError;
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'a, u8>>, Self::Error> {
let (height, width, _channels) = self.dim();
// Boxes are clamped to the image rectangle first; a box entirely
// outside the image makes `clamp` return None -> RoiOutOfBounds.
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
aabbs
.iter()
.map(|aabb| {
let slice_arg =
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
Ok(self.slice(slice_arg))
})
.collect::<Result<Vec<_>, RoiError>>()
}
}
/// Same as the `Array3` impl but for a borrowed view: the returned sub-views
/// borrow the underlying data for `'b` (the view's lifetime), which is why
/// `slice_move` is used instead of `slice`.
impl<'a, 'b> MultiRoi<'a, Vec<ArrayView3<'b, u8>>> for ArrayView3<'b, u8> {
type Error = RoiError;
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'b, u8>>, Self::Error> {
let (height, width, _channels) = self.dim();
// Clamp every requested box to the image bounds; fully-outside boxes
// surface as RoiOutOfBounds.
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
aabbs
.iter()
.map(|aabb| {
let slice_arg =
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
Ok(self.slice_move(slice_arg))
})
.collect::<Result<Vec<_>, RoiError>>()
}
}
/// Convert a 2D bounding box into an ndarray slice specification over an
/// HWC array: rows y1..y2, columns x1..x2, all channels.
/// (Previous comment claiming this returns a dummy value was stale — the
/// returned `s![]` spec is the real slice argument.)
fn bbox_to_slice_arg(
aabb: Aabb2<usize>,
) -> ndarray::SliceInfo<[ndarray::SliceInfoElem; 3], ndarray::Ix3, ndarray::Ix3> {
let x1 = aabb.x1();
let x2 = aabb.x2();
let y1 = aabb.y1();
let y2 = aabb.y2();
ndarray::s![y1..y2, x1..x2, ..]
}

BIN
facenet.mnn Normal file

Binary file not shown.

14
flake.lock generated
View File

@@ -109,16 +109,16 @@
"mnn-src": {
"flake": false,
"locked": {
"lastModified": 1749173738,
"narHash": "sha256-pNljvQ4xMZ4VmuxQyXt+boNBZD0+UZNpNLrWrj8Rtfw=",
"lastModified": 1753256753,
"narHash": "sha256-aTpwVZBkpQiwOVVXDfQIVEx9CswNiPbvNftw8KsoW+Q=",
"owner": "alibaba",
"repo": "MNN",
"rev": "ebdada82634300956e08bd4056ecfeb1e4f23b32",
"rev": "a739ea5870a4a45680f0e36ba9662ca39f2f4eec",
"type": "github"
},
"original": {
"owner": "alibaba",
"ref": "3.2.0",
"ref": "3.2.2",
"repo": "MNN",
"type": "github"
}
@@ -178,11 +178,11 @@
]
},
"locked": {
"lastModified": 1750732748,
"narHash": "sha256-HR2b3RHsPeJm+Fb+1ui8nXibgniVj7hBNvUbXEyz0DU=",
"lastModified": 1754621349,
"narHash": "sha256-JkXUS/nBHyUqVTuL4EDCvUWauTHV78EYfk+WqiTAMQ4=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "4b4494b2ba7e8a8041b2e28320b2ee02c115c75f",
"rev": "c448ab42002ac39d3337da10420c414fccfb1088",
"type": "github"
},
"original": {

216
flake.nix
View File

@@ -22,25 +22,27 @@
inputs.nixpkgs.follows = "nixpkgs";
};
mnn-src = {
url = "github:alibaba/MNN/3.2.0";
url = "github:alibaba/MNN/3.2.2";
flake = false;
};
};
outputs = {
self,
crane,
flake-utils,
nixpkgs,
rust-overlay,
advisory-db,
nix-github-actions,
mnn-overlay,
mnn-src,
...
}:
outputs =
{
self,
crane,
flake-utils,
nixpkgs,
rust-overlay,
advisory-db,
nix-github-actions,
mnn-overlay,
mnn-src,
...
}:
flake-utils.lib.eachDefaultSystem (
system: let
system:
let
pkgs = import nixpkgs {
inherit system;
overlays = [
@@ -61,118 +63,148 @@
stableToolchain = pkgs.rust-bin.stable.latest.default;
stableToolchainWithLLvmTools = stableToolchain.override {
extensions = ["rust-src" "llvm-tools"];
extensions = [
"rust-src"
"llvm-tools"
];
};
stableToolchainWithRustAnalyzer = stableToolchain.override {
extensions = ["rust-src" "rust-analyzer"];
extensions = [
"rust-src"
"rust-analyzer"
];
};
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
src = let
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
in
src =
let
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
sourceFilters =
path: type:
(craneLib.filterCargoSources path type)
|| filterBySuffix path [
".c"
".h"
".hpp"
".cpp"
".cc"
];
in
lib.cleanSourceWith {
filter = sourceFilters;
src = ./.;
};
commonArgs =
{
inherit src;
pname = name;
stdenv = pkgs.clangStdenv;
doCheck = false;
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
# nativeBuildInputs = with pkgs; [
# cmake
# llvmPackages.libclang.lib
# ];
buildInputs = with pkgs;
[]
++ (lib.optionals pkgs.stdenv.isDarwin [
libiconv
apple-sdk_13
]);
}
// (lib.optionalAttrs pkgs.stdenv.isLinux {
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
});
commonArgs = {
inherit src;
pname = name;
stdenv = pkgs.clangStdenv;
doCheck = false;
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
# nativeBuildInputs = with pkgs; [
# cmake
# llvmPackages.libclang.lib
# ];
buildInputs =
with pkgs;
[ ]
++ (lib.optionals pkgs.stdenv.isDarwin [
libiconv
apple-sdk_13
]);
}
// (lib.optionalAttrs pkgs.stdenv.isLinux {
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
});
cargoArtifacts = craneLib.buildPackage commonArgs;
in {
checks =
{
"${name}-clippy" = craneLib.cargoClippy (commonArgs
// {
inherit cargoArtifacts;
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
});
"${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
"${name}-fmt" = craneLib.cargoFmt {inherit src;};
"${name}-toml-fmt" = craneLib.taploFmt {
src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"];
};
# Audit dependencies
"${name}-audit" = craneLib.cargoAudit {
inherit src advisory-db;
};
# Audit licenses
"${name}-deny" = craneLib.cargoDeny {
inherit src;
};
"${name}-nextest" = craneLib.cargoNextest (commonArgs
// {
inherit cargoArtifacts;
partitions = 1;
partitionType = "count";
});
}
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
in
{
checks = {
"${name}-clippy" = craneLib.cargoClippy (
commonArgs
// {
inherit cargoArtifacts;
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
}
);
"${name}-docs" = craneLib.cargoDoc (commonArgs // { inherit cargoArtifacts; });
"${name}-fmt" = craneLib.cargoFmt { inherit src; };
"${name}-toml-fmt" = craneLib.taploFmt {
src = pkgs.lib.sources.sourceFilesBySuffices src [ ".toml" ];
};
# Audit dependencies
"${name}-audit" = craneLib.cargoAudit {
inherit src advisory-db;
};
packages = let
pkg = craneLib.buildPackage (commonArgs
// {inherit cargoArtifacts;}
# Audit licenses
"${name}-deny" = craneLib.cargoDeny {
inherit src;
};
"${name}-nextest" = craneLib.cargoNextest (
commonArgs
// {
nativeBuildInputs = with pkgs; [
installShellFiles
];
postInstall = ''
installShellCompletion --cmd ${name} \
--bash <($out/bin/${name} completions bash) \
--fish <($out/bin/${name} completions fish) \
--zsh <($out/bin/${name} completions zsh)
'';
});
in {
"${name}" = pkg;
default = pkg;
inherit cargoArtifacts;
partitions = 1;
partitionType = "count";
}
);
}
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // { inherit cargoArtifacts; });
};
packages =
let
pkg = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
}
// {
nativeBuildInputs = with pkgs; [
installShellFiles
];
postInstall = ''
installShellCompletion --cmd ${name} \
--bash <($out/bin/${name} completions bash) \
--fish <($out/bin/${name} completions fish) \
--zsh <($out/bin/${name} completions zsh)
'';
}
);
in
{
"${name}" = pkg;
default = pkg;
};
devShells = {
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
default = pkgs.mkShell.override { stdenv = pkgs.clangStdenv; } (
commonArgs
// {
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
packages = with pkgs;
packages =
with pkgs;
[
stableToolchainWithRustAnalyzer
cargo-nextest
cargo-deny
cmake
mnn
cargo-make
]
++ (lib.optionals pkgs.stdenv.isDarwin [
apple-sdk_13
]);
});
}
);
};
}
)
// {
githubActions = nix-github-actions.lib.mkGithubMatrix {
checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks;
checks = nixpkgs.lib.getAttrs [ "x86_64-linux" ] self.checks;
};
};
}

BIN
models/facenet.mnn LFS Normal file

Binary file not shown.

BIN
models/facenet.onnx LFS Normal file

Binary file not shown.

Binary file not shown.

BIN
models/retinaface.onnx LFS Normal file

Binary file not shown.

View File

@@ -48,6 +48,8 @@ pub struct Detect {
pub model_type: Models,
#[clap(short, long)]
pub output: Option<PathBuf>,
#[clap(short, long, default_value = "cpu")]
pub forward_type: mnn::ForwardType,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]

View File

@@ -6,27 +6,28 @@ use nalgebra::{Point2, Vector2};
use ndarray_resize::NdFir;
use std::path::Path;
/// Configuration for face detection postprocessing
#[derive(Debug, Clone, PartialEq)]
pub struct FaceDetectionConfig {
anchor_sizes: Vec<Vector2<usize>>,
steps: Vec<usize>,
variance: Vec<f32>,
threshold: f32,
nms_threshold: f32,
/// Minimum confidence to keep a detection
pub threshold: f32,
/// NMS threshold for suppressing overlapping boxes
pub nms_threshold: f32,
/// Variances for bounding box decoding
pub variances: [f32; 2],
/// The step size (stride) for each feature map
pub steps: Vec<usize>,
/// The minimum anchor sizes for each feature map
pub min_sizes: Vec<Vec<usize>>,
/// Whether to clip bounding boxes to the image dimensions
pub clamp: bool,
/// Input image width (used for anchor generation)
pub input_width: usize,
/// Input image height (used for anchor generation)
pub input_height: usize,
}
impl FaceDetectionConfig {
pub fn with_min_sizes(mut self, min_sizes: Vec<Vector2<usize>>) -> Self {
self.anchor_sizes = min_sizes;
self
}
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
self.steps = steps;
self
}
pub fn with_variance(mut self, variance: Vec<f32>) -> Self {
self.variance = variance;
self
}
pub fn with_threshold(mut self, threshold: f32) -> Self {
self.threshold = threshold;
self
@@ -35,23 +36,48 @@ impl FaceDetectionConfig {
self.nms_threshold = nms_threshold;
self
}
pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
self.variances = variances;
self
}
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
self.steps = steps;
self
}
pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
self.min_sizes = min_sizes;
self
}
pub fn with_clip(mut self, clip: bool) -> Self {
self.clamp = clip;
self
}
pub fn with_input_width(mut self, input_width: usize) -> Self {
self.input_width = input_width;
self
}
pub fn with_input_height(mut self, input_height: usize) -> Self {
self.input_height = input_height;
self
}
}
impl Default for FaceDetectionConfig {
fn default() -> Self {
FaceDetectionConfig {
anchor_sizes: vec![
Vector2::new(16, 32),
Vector2::new(64, 128),
Vector2::new(256, 512),
],
steps: vec![8, 16, 32],
variance: vec![0.1, 0.2],
threshold: 0.8,
Self {
threshold: 0.5,
nms_threshold: 0.4,
variances: [0.1, 0.2],
steps: vec![8, 16, 32],
min_sizes: vec![vec![16, 32], vec![64, 128], vec![256, 512]],
clamp: true,
input_width: 1024,
input_height: 1024,
}
}
}
/// RetinaFace-based face detector.
///
/// Wraps a thread-safe MNN session handle; inference is dispatched through
/// `handle.run(...)` so the interpreter/session pair stays on one worker.
#[derive(Debug)]
pub struct FaceDetection {
    // Owns the MNN interpreter + session; all model calls go through this.
    handle: mnn_sync::SessionHandle,
}
@@ -87,80 +113,111 @@ pub struct FaceDetectionOutput {
pub landmark: Vec<FaceLandmarks>,
}
impl FaceDetectionModelOutput {
pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
let mut anchors = Vec::new();
for (k, &step) in config.steps.iter().enumerate() {
let feature_size = 1024 / step;
let min_sizes = config.anchor_sizes[k];
let sizes = [min_sizes.x, min_sizes.y];
for i in 0..feature_size {
for j in 0..feature_size {
for &size in &sizes {
let cx = (j as f32 + 0.5) * step as f32 / 1024.0;
let cy = (i as f32 + 0.5) * step as f32 / 1024.0;
let s_k = size as f32 / 1024.0;
anchors.push((cx, cy, s_k, s_k));
}
fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
let mut anchors = Vec::new();
let feature_maps: Vec<(usize, usize)> = config
.steps
.iter()
.map(|&step| {
(
(config.input_height as f32 / step as f32).ceil() as usize,
(config.input_width as f32 / step as f32).ceil() as usize,
)
})
.collect();
for (k, f) in feature_maps.iter().enumerate() {
let min_sizes = &config.min_sizes[k];
for i in 0..f.0 {
for j in 0..f.1 {
for &min_size in min_sizes {
let s_kx = min_size as f32 / config.input_width as f32;
let s_ky = min_size as f32 / config.input_height as f32;
let dense_cx =
(j as f32 + 0.5) * config.steps[k] as f32 / config.input_width as f32;
let dense_cy =
(i as f32 + 0.5) * config.steps[k] as f32 / config.input_height as f32;
anchors.push([
dense_cx - s_kx / 2.,
dense_cy - s_ky / 2.,
dense_cx + s_kx / 2.,
dense_cy + s_ky / 2.,
]);
}
}
}
let mut boxes = Vec::new();
let mut scores = Vec::new();
let mut landmarks = Vec::new();
let var0 = config.variance[0];
let var1 = config.variance[1];
let bbox_data = self.bbox;
let conf_data = self.confidence;
let landmark_data = self.landmark;
let num_priors = bbox_data.shape()[1];
for idx in 0..num_priors {
let dx = bbox_data[[0, idx, 0]];
let dy = bbox_data[[0, idx, 1]];
let dw = bbox_data[[0, idx, 2]];
let dh = bbox_data[[0, idx, 3]];
let (anchor_cx, anchor_cy, anchor_w, anchor_h) = anchors[idx];
let pred_cx = anchor_cx + dx * var0 * anchor_w;
let pred_cy = anchor_cy + dy * var0 * anchor_h;
let pred_w = anchor_w * (dw * var1).exp();
let pred_h = anchor_h * (dh * var1).exp();
let x_min = pred_cx - pred_w / 2.0;
let y_min = pred_cy - pred_h / 2.0;
let x_max = pred_cx + pred_w / 2.0;
let y_max = pred_cy + pred_h / 2.0;
let score = conf_data[[0, idx, 1]];
if score > config.threshold {
boxes.push(Aabb2::from_x1y1x2y2(x_min, y_min, x_max, y_max));
scores.push(score);
}
let left_eye_x = landmark_data[[0, idx, 0]] * anchor_w * var0 + anchor_cx;
let left_eye_y = landmark_data[[0, idx, 1]] * anchor_h * var0 + anchor_cy;
ndarray::Array2::from_shape_vec((anchors.len(), 4), anchors.into_iter().flatten().collect())
.unwrap()
}
let right_eye_x = landmark_data[[0, idx, 2]] * anchor_w * var0 + anchor_cx;
let right_eye_y = landmark_data[[0, idx, 3]] * anchor_h * var0 + anchor_cy;
impl FaceDetectionModelOutput {
pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
use ndarray::s;
let nose_x = landmark_data[[0, idx, 4]] * anchor_w * var0 + anchor_cx;
let nose_y = landmark_data[[0, idx, 5]] * anchor_h * var0 + anchor_cy;
let priors = generate_anchors(config);
let left_mouth_x = landmark_data[[0, idx, 6]] * anchor_w * var0 + anchor_cx;
let left_mouth_y = landmark_data[[0, idx, 7]] * anchor_h * var0 + anchor_cy;
let scores = self.confidence.slice(s![0, .., 1]);
let boxes = self.bbox.slice(s![0, .., ..]);
let landmarks_raw = self.landmark.slice(s![0, .., ..]);
let right_mouth_x = landmark_data[[0, idx, 8]] * anchor_w * var0 + anchor_cx;
let right_mouth_y = landmark_data[[0, idx, 9]] * anchor_h * var0 + anchor_cy;
let mut decoded_boxes = Vec::new();
let mut decoded_landmarks = Vec::new();
let mut confidences = Vec::new();
landmarks.push(FaceLandmarks {
left_eye: Point2::new(left_eye_x, left_eye_y),
right_eye: Point2::new(right_eye_x, right_eye_y),
nose: Point2::new(nose_x, nose_y),
left_mouth: Point2::new(left_mouth_x, left_mouth_y),
right_mouth: Point2::new(right_mouth_x, right_mouth_y),
});
for i in 0..priors.shape()[0] {
if scores[i] > config.threshold {
let prior = priors.row(i);
let loc = boxes.row(i);
let landm = landmarks_raw.row(i);
// Decode bounding box
let prior_cx = (prior[0] + prior[2]) / 2.0;
let prior_cy = (prior[1] + prior[3]) / 2.0;
let prior_w = prior[2] - prior[0];
let prior_h = prior[3] - prior[1];
let var = config.variances;
let cx = prior_cx + loc[0] * var[0] * prior_w;
let cy = prior_cy + loc[1] * var[0] * prior_h;
let w = prior_w * (loc[2] * var[1]).exp();
let h = prior_h * (loc[3] * var[1]).exp();
let xmin = cx - w / 2.0;
let ymin = cy - h / 2.0;
let xmax = cx + w / 2.0;
let ymax = cy + h / 2.0;
let mut bbox =
Aabb2::from_min_max_vertices(Point2::new(xmin, ymin), Point2::new(xmax, ymax));
if config.clamp {
bbox.component_clamp(0.0, 1.0);
}
decoded_boxes.push(bbox);
// Decode landmarks
let mut points = [Point2::new(0.0, 0.0); 5];
for j in 0..5 {
points[j].x = prior_cx + landm[j * 2] * var[0] * prior_w;
points[j].y = prior_cy + landm[j * 2 + 1] * var[0] * prior_h;
}
let landmarks = FaceLandmarks {
left_eye: points[0],
right_eye: points[1],
nose: points[2],
left_mouth: points[3],
right_mouth: points[4],
};
decoded_landmarks.push(landmarks);
confidences.push(scores[i]);
}
}
Ok(FaceDetectionProcessedOutput {
bbox: boxes,
confidence: scores,
landmarks,
bbox: decoded_boxes,
confidence: confidences,
landmarks: decoded_landmarks,
})
}
}
@@ -189,7 +246,56 @@ impl FaceDetectionModelOutput {
}
}
/// Builder for [`FaceDetection`].
///
/// Lets callers pick the MNN schedule (forward type) and backend
/// configuration before the session is created.
pub struct FaceDetectionBuilder {
    // Execution schedule (forward type etc.); defaults are used when `None`.
    schedule_config: Option<mnn::ScheduleConfig>,
    // Backend hints (memory/precision); optional.
    backend_config: Option<mnn::BackendConfig>,
    // Interpreter already loaded from the model bytes.
    model: mnn::Interpreter,
}
impl FaceDetectionBuilder {
    /// Creates a builder from raw MNN model bytes.
    ///
    /// # Errors
    /// Returns an error when the bytes cannot be parsed as an MNN model.
    pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
        Ok(Self {
            schedule_config: None,
            backend_config: None,
            model: mnn::Interpreter::from_bytes(model.as_ref())
                .map_err(|e| e.into_inner())
                .change_context(Error)
                .attach_printable("Failed to load model from bytes")?,
        })
    }

    /// Sets the execution backend (CPU / Metal / CoreML / ...).
    pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
        self.schedule_config
            .get_or_insert_default()
            .set_type(forward_type);
        self
    }

    /// Replaces the whole schedule config (overrides earlier `with_forward_type`).
    pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
        self.schedule_config = Some(config);
        self
    }

    /// Sets the backend config (memory/power/precision hints).
    pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
        self.backend_config = Some(config);
        self
    }

    /// Creates the MNN session and returns the detector.
    ///
    /// # Errors
    /// Returns an error when session creation fails.
    pub fn build(self) -> Result<FaceDetection> {
        let model = self.model;
        let mut sc = self.schedule_config.unwrap_or_default();
        // BUGFIX: `backend_config` used to be stored but never applied, so
        // `with_backend_config` was a silent no-op. Attach it to the schedule
        // config the same way `new_from_bytes` does.
        if let Some(bc) = self.backend_config {
            sc = sc.with_backend_config(bc);
        }
        let handle = mnn_sync::SessionHandle::new(model, sc)
            .change_context(Error)
            .attach_printable("Failed to create session handle")?;
        Ok(FaceDetection { handle })
    }
}
impl FaceDetection {
pub fn builder<T: AsRef<[u8]>>()
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
FaceDetectionBuilder::new
}
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
@@ -204,9 +310,13 @@ impl FaceDetection {
.change_context(Error)
.attach_printable("Failed to load model from bytes")?;
model.set_session_mode(mnn::SessionMode::Release);
model
.set_cache_file("retinaface.cache", 128)
.change_context(Error)
.attach_printable("Failed to set cache file")?;
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
let sc = mnn::ScheduleConfig::new()
.with_type(mnn::ForwardType::CPU)
.with_type(mnn::ForwardType::Metal)
.with_backend_config(bc);
tracing::info!("Creating session handle for face detection model");
let handle = mnn_sync::SessionHandle::new(model, sc)
@@ -217,7 +327,7 @@ impl FaceDetection {
pub fn detect_faces(
&self,
image: ndarray::Array3<u8>,
image: ndarray::ArrayView3<u8>,
config: FaceDetectionConfig,
) -> Result<FaceDetectionOutput> {
let (height, width, _channels) = image.dim();
@@ -242,7 +352,8 @@ impl FaceDetection {
.map(|((b, s), l)| (b, s, l))
.multiunzip();
let keep_indices = nms(&boxes, &scores, config.threshold, config.nms_threshold);
let keep_indices =
nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?;
let bboxes = boxes
.into_iter()
@@ -270,28 +381,28 @@ impl FaceDetection {
})
}
pub fn run_models(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
pub fn run_models(&self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
#[rustfmt::skip]
let mut resized = image
.fast_resize(1024, 1024, None)
.change_context(Error)?
.mapv(|f| f as f32)
.tap_mut(|arr| {
arr.axis_iter_mut(ndarray::Axis(2))
.zip([104, 117, 123])
.for_each(|(mut array, pixel)| {
let pixel = pixel as f32;
array.map_inplace(|v| *v -= pixel);
});
})
.permuted_axes((2, 0, 1))
.insert_axis(ndarray::Axis(0))
.as_standard_layout()
.into_owned();
use ::tap::*;
let output = self
.handle
.run(move |sr| {
let mut resized = image
.fast_resize(1024, 1024, None)
.change_context(mnn::ErrorKind::TensorError)?
.mapv(|f| f as f32)
.tap_mut(|arr| {
arr.axis_iter_mut(ndarray::Axis(2))
.zip([104, 117, 123])
.for_each(|(mut array, pixel)| {
let pixel = pixel as f32;
array.map_inplace(|v| *v -= pixel);
});
})
.permuted_axes((2, 0, 1))
.insert_axis(ndarray::Axis(0))
.as_standard_layout()
.into_owned();
let tensor = resized
.as_mnn_tensor_mut()
.attach_printable("Failed to convert ndarray to mnn tensor")

1
src/faceembed.rs Normal file
View File

@@ -0,0 +1 @@
pub mod facenet;

153
src/faceembed/facenet.rs Normal file
View File

@@ -0,0 +1,153 @@
use crate::errors::*;
use mnn_bridge::ndarray::*;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use std::path::Path;
mod mnn_impl;
mod ort_impl;
/// FaceNet-based face embedding generator (MNN backend).
///
/// Produces one embedding vector per input face crop; inference runs through
/// a thread-safe MNN session handle.
#[derive(Debug)]
pub struct EmbeddingGenerator {
    // Owns the MNN interpreter + session; all model calls go through this.
    handle: mnn_sync::SessionHandle,
}
/// Builder for [`EmbeddingGenerator`].
///
/// Mirrors `FaceDetectionBuilder`: configure the MNN schedule (forward type)
/// and backend before the session is created.
pub struct EmbeddingGeneratorBuilder {
    // Execution schedule (forward type etc.); defaults are used when `None`.
    schedule_config: Option<mnn::ScheduleConfig>,
    // Backend hints (memory/precision); optional.
    backend_config: Option<mnn::BackendConfig>,
    // Interpreter already loaded from the model bytes.
    model: mnn::Interpreter,
}
impl EmbeddingGeneratorBuilder {
    /// Creates a builder from raw MNN model bytes.
    ///
    /// # Errors
    /// Returns an error when the bytes cannot be parsed as an MNN model.
    pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
        Ok(Self {
            schedule_config: None,
            backend_config: None,
            model: mnn::Interpreter::from_bytes(model.as_ref())
                .map_err(|e| e.into_inner())
                .change_context(Error)
                .attach_printable("Failed to load model from bytes")?,
        })
    }

    /// Sets the execution backend (CPU / Metal / CoreML / ...).
    pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
        self.schedule_config
            .get_or_insert_default()
            .set_type(forward_type);
        self
    }

    /// Replaces the whole schedule config (overrides earlier `with_forward_type`).
    pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
        self.schedule_config = Some(config);
        self
    }

    /// Sets the backend config (memory/power/precision hints).
    pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
        self.backend_config = Some(config);
        self
    }

    /// Creates the MNN session and returns the generator.
    ///
    /// # Errors
    /// Returns an error when session creation fails.
    pub fn build(self) -> Result<EmbeddingGenerator> {
        let model = self.model;
        let mut sc = self.schedule_config.unwrap_or_default();
        // BUGFIX: `backend_config` used to be stored but never applied, so
        // `with_backend_config` was a silent no-op. Attach it to the schedule
        // config the same way `new_from_bytes` does.
        if let Some(bc) = self.backend_config {
            sc = sc.with_backend_config(bc);
        }
        let handle = mnn_sync::SessionHandle::new(model, sc)
            .change_context(Error)
            .attach_printable("Failed to create session handle")?;
        Ok(EmbeddingGenerator { handle })
    }
}
impl EmbeddingGenerator {
    /// Input/output tensor names of the exported FaceNet model
    /// (TensorFlow SavedModel serving signature).
    const INPUT_NAME: &'static str = "serving_default_input_6:0";
    const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";

    /// Loads the embedding model from a file on disk.
    ///
    /// # Errors
    /// Returns an error when the file cannot be read or the model fails to load.
    pub fn new(path: impl AsRef<Path>) -> Result<Self> {
        let model = std::fs::read(path)
            .change_context(Error)
            .attach_printable("Failed to read model file")?;
        Self::new_from_bytes(&model)
    }

    /// Returns the builder entry point; use as
    /// `EmbeddingGenerator::builder()(model_bytes)?.with_forward_type(...).build()`.
    pub fn builder<T: AsRef<[u8]>>()
    -> fn(T) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
        EmbeddingGeneratorBuilder::new
    }

    /// Creates the generator from raw MNN model bytes with fixed defaults.
    ///
    /// NOTE(review): the forward type is hard-coded to Metal here, which only
    /// works on Apple platforms — prefer [`Self::builder`] with an explicit
    /// forward type for portability.
    ///
    /// # Errors
    /// Returns an error when model loading, cache setup, or session creation fails.
    pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
        tracing::info!("Loading face embedding model from bytes");
        let mut model = mnn::Interpreter::from_bytes(model)
            .map_err(|e| e.into_inner())
            .change_context(Error)
            .attach_printable("Failed to load model from bytes")?;
        model.set_session_mode(mnn::SessionMode::Release);
        // Persist the compiled-kernel cache between runs to reduce warm-up time.
        model
            .set_cache_file("facenet.cache", 128)
            .change_context(Error)
            .attach_printable("Failed to set cache file")?;
        let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
        let sc = mnn::ScheduleConfig::new()
            .with_type(mnn::ForwardType::Metal)
            .with_backend_config(bc);
        tracing::info!("Creating session handle for face embedding model");
        let handle = mnn_sync::SessionHandle::new(model, sc)
            .change_context(Error)
            .attach_printable("Failed to create session handle")?;
        Ok(Self { handle })
    }

    /// Runs the embedding model on a batch of face crops.
    ///
    /// `face` is NHWC `u8`; it is converted to `f32` without any mean/std
    /// normalization (assumes the exported graph normalizes internally —
    /// TODO confirm against the model export). Returns one embedding row per
    /// batch element.
    ///
    /// # Errors
    /// Returns an error when tensor conversion or session execution fails.
    pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
        let tensor = face
            // .permuted_axes((0, 3, 1, 2))
            .as_standard_layout()
            .mapv(|x| x as f32);
        let shape: [usize; 4] = tensor.dim().into();
        let shape = shape.map(|f| f as i32);
        let output = self
            .handle
            .run(move |sr| {
                let tensor = tensor
                    .as_mnn_tensor()
                    .attach_printable("Failed to convert ndarray to mnn tensor")
                    .change_context(mnn::ErrorKind::TensorError)?;
                tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
                let (intptr, session) = sr.both_mut();
                tracing::trace!("Copying input tensor to host");
                // The session's current input shape may differ from this batch
                // (e.g. a different batch size); detect that and resize lazily.
                // SAFETY: `input_unresized` hands back a tensor that may not be
                // fully allocated yet; we only inspect its shape and request a
                // resize, both of which are valid before `resize_session`.
                let needs_resize = unsafe {
                    let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
                    tracing::trace!("Input shape: {:?}", input.shape());
                    if *input.shape() != shape {
                        tracing::trace!("Resizing input tensor to shape: {:?}", shape);
                        // input.resize(shape);
                        intptr.resize_tensor(input.view_mut(), shape);
                        true
                    } else {
                        false
                    }
                };
                if needs_resize {
                    tracing::trace!("Resized input tensor to shape: {:?}", shape);
                    let now = std::time::Instant::now();
                    intptr.resize_session(session);
                    tracing::trace!("Session resized in {:?}", now.elapsed());
                }
                let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
                tracing::trace!("Input shape: {:?}", input.shape());
                input.copy_from_host_tensor(tensor.view())?;
                // BUGFIX: this log previously said "face detection" — a
                // copy-paste from the retinaface module; this is the embedding
                // session.
                tracing::info!("Running face embedding session");
                intptr.run_session(&session)?;
                let output_tensor = intptr
                    .output::<f32>(&session, Self::OUTPUT_NAME)?
                    .create_host_tensor_from_device(true)
                    .as_ndarray()
                    .to_owned();
                Ok(output_tensor)
            })
            .change_context(Error)?;
        Ok(output)
    }

    // pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
    //     todo!()
    // }
    // pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
    //     todo!()
    // }
}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,65 @@
use crate::errors::{Result, *};
use ndarray::*;
use ort::*;
use std::path::Path;
/// ONNX Runtime-backed face embedding generator (experimental alternative to
/// the MNN implementation).
#[derive(Debug)]
pub struct EmbeddingGenerator {
    // ORT session used for inference.
    handle: ort::session::Session,
}
// impl EmbeddingGeneratorBuilder {
// pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
// Ok(Self {
// schedule_config: None,
// backend_config: None,
// model: mnn::Interpreter::from_bytes(model.as_ref())
// .map_err(|e| e.into_inner())
// .change_context(Error)
// .attach_printable("Failed to load model from bytes")?,
// })
// }
//
// pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
// self.schedule_config
// .get_or_insert_default()
// .set_type(forward_type);
// self
// }
//
// pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
// self.schedule_config = Some(config);
// self
// }
//
// pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
// self.backend_config = Some(config);
// self
// }
//
// pub fn build(self) -> Result<EmbeddingGenerator> {
// let model = self.model;
// let sc = self.schedule_config.unwrap_or_default();
// let handle = mnn_sync::SessionHandle::new(model, sc)
// .change_context(Error)
// .attach_printable("Failed to create session handle")?;
// Ok(EmbeddingGenerator { handle })
// }
// }
impl EmbeddingGenerator {
    // Model I/O tensor names (TensorFlow SavedModel serving signature);
    // currently unused here but kept in sync with the MNN implementation.
    const INPUT_NAME: &'static str = "serving_default_input_6:0";
    const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";

    /// Loads the embedding model from a file on disk.
    ///
    /// # Errors
    /// Returns an error when the file cannot be read.
    pub fn new(path: impl AsRef<Path>) -> Result<Self> {
        let model = std::fs::read(path)
            .change_context(Error)
            .attach_printable("Failed to read model file")?;
        Self::new_from_bytes(&model)
    }

    /// Constructs the ORT-backed generator from raw model bytes.
    ///
    /// NOTE(review): not yet implemented — calling this panics via `todo!()`.
    pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result<Self> {
        todo!()
    }
    // pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {}
}

View File

@@ -1,4 +1,5 @@
pub mod errors;
pub mod facedet;
pub mod faceembed;
pub mod image;
use errors::*;

View File

@@ -1,9 +1,15 @@
mod cli;
mod errors;
use detector::facedet::retinaface::FaceDetectionConfig;
use bounding_box::roi::MultiRoi;
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
use errors::*;
use fast_image_resize::ResizeOptions;
use ndarray::*;
use ndarray_image::*;
use ndarray_resize::NdFir;
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
const CHUNK_SIZE: usize = 8;
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("trace")
@@ -15,27 +21,86 @@ pub fn main() -> Result<()> {
match args.cmd {
cli::SubCommand::Detect(detect) => {
use detector::facedet;
let model = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL)
.change_context(Error)?
.with_forward_type(detect.forward_type)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL)
.change_context(Error)?
.with_forward_type(detect.forward_type)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
let image = image::open(detect.image).change_context(Error)?;
let image = image.into_rgb8();
let mut array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = model
let output = retinaface
.detect_faces(
array.clone(),
FaceDetectionConfig::default().with_threshold(detect.threshold),
array.view(),
FaceDetectionConfig::default()
.with_threshold(detect.threshold)
.with_nms_threshold(detect.nms_threshold),
)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
for bbox in output.bbox {
for bbox in &output.bbox {
tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
let face_rois = array
.view()
.multi_roi(&output.bbox)
.change_context(Error)?
.into_iter()
// .inspect(|f| {
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
// })
.map(|roi| {
roi.as_standard_layout()
.fast_resize(512, 512, &ResizeOptions::default())
.change_context(Error)
})
// .inspect(|f| {
// f.as_ref().inspect(|f| {
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
// });
// })
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = CHUNK_SIZE;
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
tracing::info!("Processing chunk of size: {}", chunk.len());
if chunk.len() < 8 {
tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((512, 512, 3));
let zero_array = core::iter::repeat(zeros.view())
.take(chunk_size)
.collect::<Vec<_>>();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
}
})
.collect::<Result<Vec<Array2<f32>>>>();
let v = array.view();
if let Some(output) = detect.output {
let image: image::RgbImage = v