Compare commits
10 Commits
043a845fc1
...
2d2309837f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d2309837f | ||
|
|
f5740dc87f | ||
|
|
3753e399b1 | ||
|
|
d52b69911f | ||
|
|
a3ea01b7b6 | ||
|
|
e60921b099 | ||
|
|
e91ae5b865 | ||
|
|
2c43f657aa | ||
|
|
8d07b0846c | ||
|
|
f7aae32caf |
4
.gitattributes
vendored
Normal file
4
.gitattributes
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
models/retinaface.mnn filter=lfs diff=lfs merge=lfs -text
|
||||
models/facenet.mnn filter=lfs diff=lfs merge=lfs -text
|
||||
models/facenet.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
models/retinaface.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -2,3 +2,6 @@
|
||||
/target
|
||||
.direnv
|
||||
*.jpg
|
||||
face_net.onnx
|
||||
.DS_Store
|
||||
*.cache
|
||||
|
||||
968
Cargo.lock
generated
968
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -12,8 +12,8 @@ mnn = { path = "/Users/fs0c131y/Projects/aftershoot/mnn-rs" }
|
||||
ndarray-image = { path = "ndarray-image" }
|
||||
ndarray-resize = { path = "ndarray-resize" }
|
||||
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
|
||||
# "metal",
|
||||
# "coreml",
|
||||
"metal",
|
||||
"coreml",
|
||||
"tracing",
|
||||
], branch = "restructure-tensor-type" }
|
||||
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
||||
@@ -22,6 +22,7 @@ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0",
|
||||
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
||||
"tracing",
|
||||
], branch = "restructure-tensor-type" }
|
||||
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
|
||||
|
||||
[package]
|
||||
name = "detector"
|
||||
@@ -35,8 +36,7 @@ clap_complete = "4.5"
|
||||
error-stack = "0.5"
|
||||
fast_image_resize = "5.2.0"
|
||||
image = "0.25.6"
|
||||
linfa = "0.7.1"
|
||||
nalgebra = "0.33.2"
|
||||
nalgebra = { workspace = true }
|
||||
ndarray = "0.16.1"
|
||||
ndarray-image = { workspace = true }
|
||||
ndarray-resize = { workspace = true }
|
||||
@@ -53,6 +53,7 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
|
||||
color = "0.3.1"
|
||||
itertools = "0.14.0"
|
||||
ordered-float = "5.0.0"
|
||||
ort = "2.0.0-rc.10"
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
||||
27
Makefile.toml
Normal file
27
Makefile.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
[tasks.convert_facenet]
|
||||
command = "MNNConvert"
|
||||
args = [
|
||||
"-f",
|
||||
"ONNX",
|
||||
"--modelFile",
|
||||
"models/facenet.onnx",
|
||||
"--MNNModel",
|
||||
"models/facenet.mnn",
|
||||
"--fp16",
|
||||
"--bizCode",
|
||||
"MNN",
|
||||
]
|
||||
|
||||
[tasks.convert_retinaface]
|
||||
command = "MNNConvert"
|
||||
args = [
|
||||
"-f",
|
||||
"ONNX",
|
||||
"--modelFile",
|
||||
"models/retinaface.onnx",
|
||||
"--MNNModel",
|
||||
"models/retinaface.mnn",
|
||||
"--fp16",
|
||||
"--bizCode",
|
||||
"MNN",
|
||||
]
|
||||
@@ -6,9 +6,10 @@ edition = "2024"
|
||||
[dependencies]
|
||||
color = "0.3.1"
|
||||
itertools = "0.14.0"
|
||||
nalgebra = "0.33.2"
|
||||
nalgebra = { workspace = true }
|
||||
ndarray = { version = "0.16.1", optional = true }
|
||||
num = "0.4.3"
|
||||
ordered-float = "5.0.0"
|
||||
simba = "0.9.0"
|
||||
thiserror = "2.0.12"
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ pub use color::Rgba8;
|
||||
use ndarray::{Array1, Array3, ArrayViewMut3};
|
||||
|
||||
pub trait Draw<T> {
|
||||
fn draw(&mut self, item: T, color: color::Rgba8, thickness: usize);
|
||||
fn draw(&mut self, item: &T, color: color::Rgba8, thickness: usize);
|
||||
}
|
||||
|
||||
impl Draw<Aabb2<usize>> for Array3<u8> {
|
||||
fn draw(&mut self, item: Aabb2<usize>, color: color::Rgba8, thickness: usize) {
|
||||
fn draw(&mut self, item: &Aabb2<usize>, color: color::Rgba8, thickness: usize) {
|
||||
item.draw(self, color, thickness)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,38 @@ pub mod draw;
|
||||
pub mod nms;
|
||||
pub mod roi;
|
||||
|
||||
use nalgebra::{Point, Point2, Point3, SVector, SimdPartialOrd, SimdValue};
|
||||
pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {}
|
||||
impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {}
|
||||
use nalgebra::{Point, Point2, SVector, Vector2};
|
||||
pub trait Num:
|
||||
num::Num
|
||||
+ core::ops::AddAssign
|
||||
+ core::ops::SubAssign
|
||||
+ core::ops::MulAssign
|
||||
+ core::ops::DivAssign
|
||||
+ core::cmp::PartialOrd
|
||||
+ core::cmp::PartialEq
|
||||
+ nalgebra::SimdPartialOrd
|
||||
+ nalgebra::SimdValue
|
||||
+ Copy
|
||||
+ core::fmt::Debug
|
||||
+ 'static
|
||||
{
|
||||
}
|
||||
impl<
|
||||
T: num::Num
|
||||
+ core::ops::AddAssign
|
||||
+ core::ops::SubAssign
|
||||
+ core::ops::MulAssign
|
||||
+ core::ops::DivAssign
|
||||
+ core::cmp::PartialOrd
|
||||
+ core::cmp::PartialEq
|
||||
+ nalgebra::SimdPartialOrd
|
||||
+ nalgebra::SimdValue
|
||||
+ Copy
|
||||
+ core::fmt::Debug
|
||||
+ 'static,
|
||||
> Num for T
|
||||
{
|
||||
}
|
||||
|
||||
/// An axis aligned bounding box in `D` dimensions, defined by the minimum vertex and a size vector.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
@@ -20,16 +49,28 @@ pub type Aabb2<T> = AxisAlignedBoundingBox<T, 2>;
|
||||
pub type Aabb3<T> = AxisAlignedBoundingBox<T, 3>;
|
||||
|
||||
impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
pub fn new(point: Point<T, D>, size: SVector<T, D>) -> Self {
|
||||
// Panics if max < min
|
||||
pub fn new(min_point: Point<T, D>, max_point: Point<T, D>) -> Self {
|
||||
if max_point >= min_point {
|
||||
Self::from_min_max_vertices(min_point, max_point)
|
||||
} else {
|
||||
dbg!(max_point, min_point);
|
||||
panic!("max_point must be greater than or equal to min_point");
|
||||
}
|
||||
}
|
||||
pub fn try_new(min_point: Point<T, D>, max_point: Point<T, D>) -> Option<Self> {
|
||||
if max_point < min_point {
|
||||
return None;
|
||||
}
|
||||
Some(Self::from_min_max_vertices(min_point, max_point))
|
||||
}
|
||||
pub fn new_point_size(point: Point<T, D>, size: SVector<T, D>) -> Self {
|
||||
Self { point, size }
|
||||
}
|
||||
|
||||
pub fn from_min_max_vertices(point1: Point<T, D>, point2: Point<T, D>) -> Self
|
||||
where
|
||||
T: core::ops::SubAssign,
|
||||
{
|
||||
let size = point2 - point1;
|
||||
Self::new(point1, SVector::from(size))
|
||||
pub fn from_min_max_vertices(min: Point<T, D>, max: Point<T, D>) -> Self {
|
||||
let size = max - min;
|
||||
Self::new_point_size(min, SVector::from(size))
|
||||
}
|
||||
|
||||
/// Only considers the points closest and furthest from origin
|
||||
@@ -151,7 +192,21 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
self.intersection(other)
|
||||
}
|
||||
|
||||
pub fn union(&self, other: &Self) -> Self
|
||||
pub fn component_clamp(&self, min: T, max: T) -> Self
|
||||
where
|
||||
T: PartialOrd,
|
||||
{
|
||||
let mut this = *self;
|
||||
this.point.iter_mut().for_each(|x| {
|
||||
*x = nalgebra::clamp(*x, min, max);
|
||||
});
|
||||
this.size.iter_mut().for_each(|x| {
|
||||
*x = nalgebra::clamp(*x, min, max);
|
||||
});
|
||||
this
|
||||
}
|
||||
|
||||
pub fn merge(&self, other: &Self) -> Self
|
||||
where
|
||||
T: core::ops::AddAssign,
|
||||
T: core::ops::SubAssign,
|
||||
@@ -159,13 +214,24 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
T: nalgebra::SimdValue,
|
||||
T: nalgebra::SimdPartialOrd,
|
||||
{
|
||||
let self_min = self.min_vertex();
|
||||
let self_max = self.max_vertex();
|
||||
let other_min = other.min_vertex();
|
||||
let other_max = other.max_vertex();
|
||||
let min = self_min.inf(&other_min);
|
||||
let max = self_max.sup(&other_max);
|
||||
Self::from_min_max_vertices(min, max)
|
||||
let min = self.min_vertex().inf(&other.min_vertex());
|
||||
let max = self.min_vertex().sup(&other.max_vertex());
|
||||
Self::new(min, max)
|
||||
}
|
||||
|
||||
pub fn union(&self, other: &Self) -> T
|
||||
where
|
||||
T: core::ops::AddAssign,
|
||||
T: core::ops::SubAssign,
|
||||
T: core::ops::MulAssign,
|
||||
T: PartialOrd,
|
||||
T: nalgebra::SimdValue,
|
||||
T: nalgebra::SimdPartialOrd,
|
||||
{
|
||||
self.measure() + other.measure()
|
||||
- Self::intersection(self, other)
|
||||
.map(|x| x.measure())
|
||||
.unwrap_or(T::zero())
|
||||
}
|
||||
|
||||
pub fn intersection(&self, other: &Self) -> Option<Self>
|
||||
@@ -176,21 +242,9 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
T: nalgebra::SimdPartialOrd,
|
||||
T: nalgebra::SimdValue,
|
||||
{
|
||||
let self_min = self.min_vertex();
|
||||
let self_max = self.max_vertex();
|
||||
let other_min = other.min_vertex();
|
||||
let other_max = other.max_vertex();
|
||||
|
||||
if self_max < other_min || other_max < self_min {
|
||||
return None; // No intersection
|
||||
}
|
||||
|
||||
let min = self_min.sup(&other_min);
|
||||
let max = self_max.inf(&other_max);
|
||||
Some(Self::from_min_max_vertices(
|
||||
Point::from(min),
|
||||
Point::from(max),
|
||||
))
|
||||
let inter_min = self.min_vertex().sup(&other.min_vertex());
|
||||
let inter_max = self.max_vertex().inf(&other.max_vertex());
|
||||
Self::try_new(inter_min, inter_max)
|
||||
}
|
||||
|
||||
pub fn denormalize(&self, factor: nalgebra::SVector<T, D>) -> Self
|
||||
@@ -233,7 +287,7 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
self.size.product()
|
||||
}
|
||||
|
||||
pub fn iou(&self, other: &Self) -> Option<T>
|
||||
pub fn iou(&self, other: &Self) -> T
|
||||
where
|
||||
T: core::ops::AddAssign,
|
||||
T: core::ops::SubAssign,
|
||||
@@ -242,9 +296,19 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
T: nalgebra::SimdValue,
|
||||
T: core::ops::MulAssign,
|
||||
{
|
||||
let intersection = self.intersection(other)?;
|
||||
let union = self.union(other);
|
||||
Some(intersection.measure() / union.measure())
|
||||
let lhs_min = self.min_vertex();
|
||||
let lhs_max = self.max_vertex();
|
||||
let rhs_min = other.min_vertex();
|
||||
let rhs_max = other.max_vertex();
|
||||
|
||||
let inter_min = lhs_min.sup(&rhs_min);
|
||||
let inter_max = lhs_max.inf(&rhs_max);
|
||||
if inter_max >= inter_min {
|
||||
let intersection = Aabb::new(inter_min, inter_max).measure();
|
||||
intersection / (self.measure() + other.measure() - intersection)
|
||||
} else {
|
||||
return T::zero();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,15 +319,15 @@ impl<T: Num> Aabb2<T> {
|
||||
{
|
||||
let point1 = Point2::new(x1, y1);
|
||||
let point2 = Point2::new(x2, y2);
|
||||
Self::from_min_max_vertices(point1, point2)
|
||||
Self::new(point1, point2)
|
||||
}
|
||||
pub fn new_2d(point1: Point2<T>, point2: Point2<T>) -> Self
|
||||
where
|
||||
T: core::ops::SubAssign,
|
||||
{
|
||||
let size = point2.coords - point1.coords;
|
||||
Self::new(point1, SVector::from(size))
|
||||
|
||||
pub fn from_xywh(x: T, y: T, w: T, h: T) -> Self {
|
||||
let point = Point2::new(x, y);
|
||||
let size = Vector2::new(w, h);
|
||||
Self::new_point_size(point, size)
|
||||
}
|
||||
|
||||
pub fn x1y1(&self) -> Point2<T> {
|
||||
self.point
|
||||
}
|
||||
@@ -327,14 +391,6 @@ impl<T: Num> Aabb2<T> {
|
||||
}
|
||||
|
||||
impl<T: Num> Aabb3<T> {
|
||||
pub fn new_3d(point1: Point3<T>, point2: Point3<T>) -> Self
|
||||
where
|
||||
T: core::ops::SubAssign,
|
||||
{
|
||||
let size = point2.coords - point1.coords;
|
||||
Self::new(point1, SVector::from(size))
|
||||
}
|
||||
|
||||
pub fn volume(&self) -> T
|
||||
where
|
||||
T: core::ops::MulAssign,
|
||||
@@ -343,13 +399,16 @@ impl<T: Num> Aabb3<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod boudning_box_tests {
|
||||
use super::*;
|
||||
use nalgebra::*;
|
||||
|
||||
#[test]
|
||||
fn test_bbox_new() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
assert_eq!(bbox.min_vertex(), point1);
|
||||
assert_eq!(bbox.size(), Vector2::new(3.0, 4.0));
|
||||
@@ -357,12 +416,24 @@ fn test_bbox_new() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_center_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
fn test_intersection_and_merge() {
|
||||
let point1 = Point2::new(1, 5);
|
||||
let point2 = Point2::new(3, 2);
|
||||
let size1 = Vector2::new(3, 4);
|
||||
let size2 = Vector2::new(1, 3);
|
||||
|
||||
let this = Aabb2::new_point_size(point1, size1);
|
||||
let other = Aabb2::new_point_size(point2, size2);
|
||||
let inter = this.intersection(&other);
|
||||
let merged = this.merge(&other);
|
||||
assert_ne!(inter, Some(merged))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_center_2d() {
|
||||
let point = Point2::new(1.0, 2.0);
|
||||
let size = Vector2::new(3.0, 4.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
assert_eq!(bbox.min_vertex(), point);
|
||||
assert_eq!(bbox.size(), size);
|
||||
@@ -371,11 +442,9 @@ fn test_bounding_box_center_2d() {
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_center_3d() {
|
||||
use nalgebra::{Point3, Vector3};
|
||||
|
||||
let point = Point3::new(1.0, 2.0, 3.0);
|
||||
let size = Vector3::new(4.0, 5.0, 6.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
assert_eq!(bbox.min_vertex(), point);
|
||||
assert_eq!(bbox.size(), size);
|
||||
@@ -384,11 +453,9 @@ fn test_bounding_box_center_3d() {
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_padding_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
|
||||
let point = Point2::new(1.0, 2.0);
|
||||
let size = Vector2::new(3.0, 4.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
let padded_bbox = bbox.padding(1.0);
|
||||
assert_eq!(padded_bbox.min_vertex(), Point2::new(0.5, 1.5));
|
||||
@@ -397,11 +464,9 @@ fn test_bounding_box_padding_2d() {
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_scaling_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
|
||||
let point = Point2::new(1.0, 1.0);
|
||||
let size = Vector2::new(3.0, 4.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
let padded_bbox = bbox.scale(Vector2::new(2.0, 2.0));
|
||||
assert_eq!(padded_bbox.min_vertex(), Point2::new(-2.0, -3.0));
|
||||
@@ -410,11 +475,9 @@ fn test_bounding_box_scaling_2d() {
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_contains_2d() {
|
||||
use nalgebra::Point2;
|
||||
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
assert!(bbox.contains_point(&Point2::new(2.0, 3.0)));
|
||||
assert!(!bbox.contains_point(&Point2::new(5.0, 7.0)));
|
||||
@@ -422,32 +485,28 @@ fn test_bounding_box_contains_2d() {
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_union_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
let point3 = Point2::new(3.0, 5.0);
|
||||
let point4 = Point2::new(7.0, 8.0);
|
||||
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
|
||||
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
|
||||
|
||||
let union_bbox = bbox1.union(&bbox2);
|
||||
let union_bbox = bbox1.merge(&bbox2);
|
||||
assert_eq!(union_bbox.min_vertex(), Point2::new(1.0, 2.0));
|
||||
assert_eq!(union_bbox.size(), Vector2::new(6.0, 6.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_intersection_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
let point3 = Point2::new(3.0, 5.0);
|
||||
let point4 = Point2::new(5.0, 7.0);
|
||||
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
|
||||
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
|
||||
|
||||
let intersection_bbox = bbox1.intersection(&bbox2).unwrap();
|
||||
assert_eq!(intersection_bbox.min_vertex(), Point2::new(3.0, 5.0));
|
||||
@@ -456,11 +515,9 @@ fn test_bounding_box_intersection_2d() {
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_contains_point() {
|
||||
use nalgebra::Point2;
|
||||
|
||||
let point1 = Point2::new(2, 3);
|
||||
let point2 = Point2::new(5, 4);
|
||||
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox = AxisAlignedBoundingBox::new(point1, point2);
|
||||
use itertools::Itertools;
|
||||
for (i, j) in (0..=10).cartesian_product(0..=10) {
|
||||
if bbox.contains_point(&Point2::new(i, j)) {
|
||||
@@ -496,3 +553,62 @@ fn test_bounding_box_clamp_box_2d() {
|
||||
let expected = Aabb2::from_x1y1x2y2(5, 5, 7, 7);
|
||||
assert_eq!(clamped, expected)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_identical_boxes() {
|
||||
let a = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
|
||||
assert_eq!(a.iou(&b), 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_non_overlapping_boxes() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
let b = Aabb2::from_x1y1x2y2(2.0, 2.0, 3.0, 3.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_partial_overlap() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 2.0, 2.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
|
||||
// Intersection area = 1, Union area = 7
|
||||
assert!((a.iou(&b) - 1.0 / 7.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_one_inside_another() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 4.0, 4.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
|
||||
// Intersection area = 4, Union area = 16
|
||||
assert!((a.iou(&b) - 0.25).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_edge_touching() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 0.0, 2.0, 1.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_corner_touching() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 2.0, 2.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_zero_area_box() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 0.0, 0.0);
|
||||
let b = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_specific_values() {
|
||||
let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264);
|
||||
let box2 = Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436);
|
||||
assert!(box1.iou(&box2) >= 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
|
||||
use itertools::Itertools;
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
|
||||
pub enum NmsError {
|
||||
#[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")]
|
||||
BoxesAndScoresLengthMismatch { boxes: usize, scores: usize },
|
||||
}
|
||||
|
||||
use crate::*;
|
||||
/// Apply Non-Maximum Suppression to a set of bounding boxes.
|
||||
@@ -18,10 +25,11 @@ pub fn nms<T>(
|
||||
scores: &[T],
|
||||
score_threshold: T,
|
||||
nms_threshold: T,
|
||||
) -> HashSet<usize>
|
||||
) -> Result<HashSet<usize>, NmsError>
|
||||
where
|
||||
T: Num
|
||||
+ num::Float
|
||||
+ ordered_float::FloatCore
|
||||
+ core::ops::Neg<Output = T>
|
||||
+ core::iter::Product<T>
|
||||
+ core::ops::AddAssign
|
||||
+ core::ops::SubAssign
|
||||
@@ -29,56 +37,37 @@ where
|
||||
+ nalgebra::SimdValue
|
||||
+ nalgebra::SimdPartialOrd,
|
||||
{
|
||||
use itertools::Itertools;
|
||||
|
||||
// Create vector of (index, box, score) tuples for boxes with scores above threshold
|
||||
let mut indexed_boxes: Vec<(usize, &Aabb2<T>, &T)> = boxes
|
||||
if boxes.len() != scores.len() {
|
||||
return Err(NmsError::BoxesAndScoresLengthMismatch {
|
||||
boxes: boxes.len(),
|
||||
scores: scores.len(),
|
||||
});
|
||||
}
|
||||
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.zip(scores.iter())
|
||||
.zip(scores)
|
||||
.filter_map(|((idx, bbox), score)| {
|
||||
if *score >= score_threshold {
|
||||
Some((idx, bbox, score))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
(*score > score_threshold).then_some((idx, *bbox, *score, true))
|
||||
})
|
||||
.sorted_by_cached_key(|(_, _, score, _)| -ordered_float::OrderedFloat(*score))
|
||||
.collect();
|
||||
|
||||
// Sort by score in descending order
|
||||
indexed_boxes.sort_by(|(_, _, score_a), (_, _, score_b)| {
|
||||
score_b
|
||||
.partial_cmp(score_a)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut keep_indices = HashSet::new();
|
||||
let mut suppressed = HashSet::new();
|
||||
|
||||
for (i, (idx_i, bbox_i, _)) in indexed_boxes.iter().enumerate() {
|
||||
// Skip if this box is already suppressed
|
||||
if suppressed.contains(idx_i) {
|
||||
for i in 0..combined.len() {
|
||||
let first = combined[i];
|
||||
if first.3 == false {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Keep this box
|
||||
keep_indices.insert(*idx_i);
|
||||
|
||||
// Compare with remaining boxes
|
||||
for (idx_j, bbox_j, _) in indexed_boxes.iter().skip(i + 1) {
|
||||
// Skip if this box is already suppressed
|
||||
if suppressed.contains(idx_j) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Calculate IoU and suppress if above threshold
|
||||
if let Some(iou) = bbox_i.iou(bbox_j) {
|
||||
if iou >= nms_threshold {
|
||||
suppressed.insert(*idx_j);
|
||||
}
|
||||
let bbox = first.1;
|
||||
for item in combined.iter_mut().skip(i + 1) {
|
||||
if bbox.iou(&item.1) > nms_threshold {
|
||||
item.3 = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keep_indices
|
||||
Ok(combined
|
||||
.into_iter()
|
||||
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
|
||||
.collect())
|
||||
}
|
||||
|
||||
@@ -5,10 +5,17 @@ pub trait Roi<'a, Output> {
|
||||
type Error;
|
||||
fn roi(&'a self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
||||
}
|
||||
|
||||
pub trait RoiMut<'a, Output> {
|
||||
type Error;
|
||||
fn roi_mut(&'a mut self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
||||
}
|
||||
|
||||
pub trait MultiRoi<'a, Output> {
|
||||
type Error;
|
||||
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Output, Self::Error>;
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug, Copy, Clone)]
|
||||
pub enum RoiError {
|
||||
#[error("Region of intereset is out of bounds")]
|
||||
@@ -36,7 +43,7 @@ impl<'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for Array3<T> {
|
||||
let x2 = aabb.x2();
|
||||
let y1 = aabb.y1();
|
||||
let y2 = aabb.y2();
|
||||
if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
|
||||
if x1 > x2 || y1 > y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
|
||||
return Err(RoiError::RoiOutOfBounds);
|
||||
}
|
||||
Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..]))
|
||||
@@ -95,3 +102,47 @@ pub fn reborrow_test() {
|
||||
};
|
||||
dbg!(y);
|
||||
}
|
||||
|
||||
impl<'a> MultiRoi<'a, Vec<ArrayView3<'a, u8>>> for Array3<u8> {
|
||||
type Error = RoiError;
|
||||
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'a, u8>>, Self::Error> {
|
||||
let (height, width, _channels) = self.dim();
|
||||
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
|
||||
aabbs
|
||||
.iter()
|
||||
.map(|aabb| {
|
||||
let slice_arg =
|
||||
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
|
||||
Ok(self.slice(slice_arg))
|
||||
})
|
||||
.collect::<Result<Vec<_>, RoiError>>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> MultiRoi<'a, Vec<ArrayView3<'b, u8>>> for ArrayView3<'b, u8> {
|
||||
type Error = RoiError;
|
||||
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'b, u8>>, Self::Error> {
|
||||
let (height, width, _channels) = self.dim();
|
||||
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
|
||||
aabbs
|
||||
.iter()
|
||||
.map(|aabb| {
|
||||
let slice_arg =
|
||||
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
|
||||
Ok(self.slice_move(slice_arg))
|
||||
})
|
||||
.collect::<Result<Vec<_>, RoiError>>()
|
||||
}
|
||||
}
|
||||
|
||||
fn bbox_to_slice_arg(
|
||||
aabb: Aabb2<usize>,
|
||||
) -> ndarray::SliceInfo<[ndarray::SliceInfoElem; 3], ndarray::Ix3, ndarray::Ix3> {
|
||||
// This function should convert the bounding box to a slice argument
|
||||
// For now, we will return a dummy value
|
||||
let x1 = aabb.x1();
|
||||
let x2 = aabb.x2();
|
||||
let y1 = aabb.y1();
|
||||
let y2 = aabb.y2();
|
||||
ndarray::s![y1..y2, x1..x2, ..]
|
||||
}
|
||||
|
||||
BIN
facenet.mnn
Normal file
BIN
facenet.mnn
Normal file
Binary file not shown.
14
flake.lock
generated
14
flake.lock
generated
@@ -109,16 +109,16 @@
|
||||
"mnn-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1749173738,
|
||||
"narHash": "sha256-pNljvQ4xMZ4VmuxQyXt+boNBZD0+UZNpNLrWrj8Rtfw=",
|
||||
"lastModified": 1753256753,
|
||||
"narHash": "sha256-aTpwVZBkpQiwOVVXDfQIVEx9CswNiPbvNftw8KsoW+Q=",
|
||||
"owner": "alibaba",
|
||||
"repo": "MNN",
|
||||
"rev": "ebdada82634300956e08bd4056ecfeb1e4f23b32",
|
||||
"rev": "a739ea5870a4a45680f0e36ba9662ca39f2f4eec",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "alibaba",
|
||||
"ref": "3.2.0",
|
||||
"ref": "3.2.2",
|
||||
"repo": "MNN",
|
||||
"type": "github"
|
||||
}
|
||||
@@ -178,11 +178,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1750732748,
|
||||
"narHash": "sha256-HR2b3RHsPeJm+Fb+1ui8nXibgniVj7hBNvUbXEyz0DU=",
|
||||
"lastModified": 1754621349,
|
||||
"narHash": "sha256-JkXUS/nBHyUqVTuL4EDCvUWauTHV78EYfk+WqiTAMQ4=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "4b4494b2ba7e8a8041b2e28320b2ee02c115c75f",
|
||||
"rev": "c448ab42002ac39d3337da10420c414fccfb1088",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
80
flake.nix
80
flake.nix
@@ -22,12 +22,13 @@
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
mnn-src = {
|
||||
url = "github:alibaba/MNN/3.2.0";
|
||||
url = "github:alibaba/MNN/3.2.2";
|
||||
flake = false;
|
||||
};
|
||||
};
|
||||
|
||||
outputs = {
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
crane,
|
||||
flake-utils,
|
||||
@@ -40,7 +41,8 @@
|
||||
...
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system: let
|
||||
system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
overlays = [
|
||||
@@ -61,24 +63,39 @@
|
||||
|
||||
stableToolchain = pkgs.rust-bin.stable.latest.default;
|
||||
stableToolchainWithLLvmTools = stableToolchain.override {
|
||||
extensions = ["rust-src" "llvm-tools"];
|
||||
extensions = [
|
||||
"rust-src"
|
||||
"llvm-tools"
|
||||
];
|
||||
};
|
||||
stableToolchainWithRustAnalyzer = stableToolchain.override {
|
||||
extensions = ["rust-src" "rust-analyzer"];
|
||||
extensions = [
|
||||
"rust-src"
|
||||
"rust-analyzer"
|
||||
];
|
||||
};
|
||||
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
||||
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
||||
|
||||
src = let
|
||||
src =
|
||||
let
|
||||
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
|
||||
sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
|
||||
sourceFilters =
|
||||
path: type:
|
||||
(craneLib.filterCargoSources path type)
|
||||
|| filterBySuffix path [
|
||||
".c"
|
||||
".h"
|
||||
".hpp"
|
||||
".cpp"
|
||||
".cc"
|
||||
];
|
||||
in
|
||||
lib.cleanSourceWith {
|
||||
filter = sourceFilters;
|
||||
src = ./.;
|
||||
};
|
||||
commonArgs =
|
||||
{
|
||||
commonArgs = {
|
||||
inherit src;
|
||||
pname = name;
|
||||
stdenv = pkgs.clangStdenv;
|
||||
@@ -88,7 +105,8 @@
|
||||
# cmake
|
||||
# llvmPackages.libclang.lib
|
||||
# ];
|
||||
buildInputs = with pkgs;
|
||||
buildInputs =
|
||||
with pkgs;
|
||||
[ ]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
libiconv
|
||||
@@ -99,14 +117,16 @@
|
||||
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
|
||||
});
|
||||
cargoArtifacts = craneLib.buildPackage commonArgs;
|
||||
in {
|
||||
checks =
|
||||
in
|
||||
{
|
||||
"${name}-clippy" = craneLib.cargoClippy (commonArgs
|
||||
checks = {
|
||||
"${name}-clippy" = craneLib.cargoClippy (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||
});
|
||||
}
|
||||
);
|
||||
"${name}-docs" = craneLib.cargoDoc (commonArgs // { inherit cargoArtifacts; });
|
||||
"${name}-fmt" = craneLib.cargoFmt { inherit src; };
|
||||
"${name}-toml-fmt" = craneLib.taploFmt {
|
||||
@@ -121,20 +141,26 @@
|
||||
"${name}-deny" = craneLib.cargoDeny {
|
||||
inherit src;
|
||||
};
|
||||
"${name}-nextest" = craneLib.cargoNextest (commonArgs
|
||||
"${name}-nextest" = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
partitions = 1;
|
||||
partitionType = "count";
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
|
||||
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // { inherit cargoArtifacts; });
|
||||
};
|
||||
|
||||
packages = let
|
||||
pkg = craneLib.buildPackage (commonArgs
|
||||
// {inherit cargoArtifacts;}
|
||||
packages =
|
||||
let
|
||||
pkg = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
}
|
||||
// {
|
||||
nativeBuildInputs = with pkgs; [
|
||||
installShellFiles
|
||||
@@ -145,28 +171,34 @@
|
||||
--fish <($out/bin/${name} completions fish) \
|
||||
--zsh <($out/bin/${name} completions zsh)
|
||||
'';
|
||||
});
|
||||
in {
|
||||
}
|
||||
);
|
||||
in
|
||||
{
|
||||
"${name}" = pkg;
|
||||
default = pkg;
|
||||
};
|
||||
|
||||
devShells = {
|
||||
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
|
||||
default = pkgs.mkShell.override { stdenv = pkgs.clangStdenv; } (
|
||||
commonArgs
|
||||
// {
|
||||
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
||||
packages = with pkgs;
|
||||
packages =
|
||||
with pkgs;
|
||||
[
|
||||
stableToolchainWithRustAnalyzer
|
||||
cargo-nextest
|
||||
cargo-deny
|
||||
cmake
|
||||
mnn
|
||||
cargo-make
|
||||
]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
apple-sdk_13
|
||||
]);
|
||||
});
|
||||
}
|
||||
);
|
||||
};
|
||||
}
|
||||
)
|
||||
|
||||
BIN
models/facenet.mnn
LFS
Normal file
BIN
models/facenet.mnn
LFS
Normal file
Binary file not shown.
BIN
models/facenet.onnx
LFS
Normal file
BIN
models/facenet.onnx
LFS
Normal file
Binary file not shown.
Binary file not shown.
BIN
models/retinaface.onnx
LFS
Normal file
BIN
models/retinaface.onnx
LFS
Normal file
Binary file not shown.
@@ -48,6 +48,8 @@ pub struct Detect {
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
#[clap(short, long, default_value = "cpu")]
|
||||
pub forward_type: mnn::ForwardType,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
|
||||
@@ -6,27 +6,28 @@ use nalgebra::{Point2, Vector2};
|
||||
use ndarray_resize::NdFir;
|
||||
use std::path::Path;
|
||||
|
||||
/// Configuration for face detection postprocessing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionConfig {
|
||||
anchor_sizes: Vec<Vector2<usize>>,
|
||||
steps: Vec<usize>,
|
||||
variance: Vec<f32>,
|
||||
threshold: f32,
|
||||
nms_threshold: f32,
|
||||
/// Minimum confidence to keep a detection
|
||||
pub threshold: f32,
|
||||
/// NMS threshold for suppressing overlapping boxes
|
||||
pub nms_threshold: f32,
|
||||
/// Variances for bounding box decoding
|
||||
pub variances: [f32; 2],
|
||||
/// The step size (stride) for each feature map
|
||||
pub steps: Vec<usize>,
|
||||
/// The minimum anchor sizes for each feature map
|
||||
pub min_sizes: Vec<Vec<usize>>,
|
||||
/// Whether to clip bounding boxes to the image dimensions
|
||||
pub clamp: bool,
|
||||
/// Input image width (used for anchor generation)
|
||||
pub input_width: usize,
|
||||
/// Input image height (used for anchor generation)
|
||||
pub input_height: usize,
|
||||
}
|
||||
|
||||
impl FaceDetectionConfig {
|
||||
pub fn with_min_sizes(mut self, min_sizes: Vec<Vector2<usize>>) -> Self {
|
||||
self.anchor_sizes = min_sizes;
|
||||
self
|
||||
}
|
||||
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
||||
self.steps = steps;
|
||||
self
|
||||
}
|
||||
pub fn with_variance(mut self, variance: Vec<f32>) -> Self {
|
||||
self.variance = variance;
|
||||
self
|
||||
}
|
||||
pub fn with_threshold(mut self, threshold: f32) -> Self {
|
||||
self.threshold = threshold;
|
||||
self
|
||||
@@ -35,23 +36,48 @@ impl FaceDetectionConfig {
|
||||
self.nms_threshold = nms_threshold;
|
||||
self
|
||||
}
|
||||
pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
|
||||
self.variances = variances;
|
||||
self
|
||||
}
|
||||
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
||||
self.steps = steps;
|
||||
self
|
||||
}
|
||||
pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
|
||||
self.min_sizes = min_sizes;
|
||||
self
|
||||
}
|
||||
pub fn with_clip(mut self, clip: bool) -> Self {
|
||||
self.clamp = clip;
|
||||
self
|
||||
}
|
||||
pub fn with_input_width(mut self, input_width: usize) -> Self {
|
||||
self.input_width = input_width;
|
||||
self
|
||||
}
|
||||
pub fn with_input_height(mut self, input_height: usize) -> Self {
|
||||
self.input_height = input_height;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FaceDetectionConfig {
|
||||
fn default() -> Self {
|
||||
FaceDetectionConfig {
|
||||
anchor_sizes: vec![
|
||||
Vector2::new(16, 32),
|
||||
Vector2::new(64, 128),
|
||||
Vector2::new(256, 512),
|
||||
],
|
||||
steps: vec![8, 16, 32],
|
||||
variance: vec![0.1, 0.2],
|
||||
threshold: 0.8,
|
||||
Self {
|
||||
threshold: 0.5,
|
||||
nms_threshold: 0.4,
|
||||
variances: [0.1, 0.2],
|
||||
steps: vec![8, 16, 32],
|
||||
min_sizes: vec![vec![16, 32], vec![64, 128], vec![256, 512]],
|
||||
clamp: true,
|
||||
input_width: 1024,
|
||||
input_height: 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FaceDetection {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
@@ -87,80 +113,111 @@ pub struct FaceDetectionOutput {
|
||||
pub landmark: Vec<FaceLandmarks>,
|
||||
}
|
||||
|
||||
/// Generates the prior (anchor) boxes described by `config`.
///
/// For every entry in `config.steps` a feature map of
/// `ceil(input_height / step)` x `ceil(input_width / step)` cells is walked in
/// row-major order, and one anchor per entry of `config.min_sizes[level]` is
/// emitted for each cell. Every anchor is stored as
/// `[x_min, y_min, x_max, y_max]`, expressed as a fraction of the input
/// width/height.
///
/// Returns an array of shape `(anchor_count, 4)`.
///
/// # Panics
///
/// Panics if `config.min_sizes` has fewer entries than `config.steps`.
fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
    let width = config.input_width as f32;
    let height = config.input_height as f32;

    // Flat row-major storage: four coordinates per anchor.
    let mut coords: Vec<f32> = Vec::new();

    for (level, &step) in config.steps.iter().enumerate() {
        let stride = step as f32;
        let rows = (height / stride).ceil() as usize;
        let cols = (width / stride).ceil() as usize;
        let sizes = &config.min_sizes[level];

        for row in 0..rows {
            // Cell centre, normalised by the input height.
            let center_y = (row as f32 + 0.5) * stride / height;
            for col in 0..cols {
                // Cell centre, normalised by the input width.
                let center_x = (col as f32 + 0.5) * stride / width;
                for &size in sizes {
                    let half_w = size as f32 / width / 2.;
                    let half_h = size as f32 / height / 2.;
                    coords.extend_from_slice(&[
                        center_x - half_w,
                        center_y - half_h,
                        center_x + half_w,
                        center_y + half_h,
                    ]);
                }
            }
        }
    }

    let anchor_count = coords.len() / 4;
    // `coords` always holds exactly `anchor_count * 4` values, so the shape matches.
    ndarray::Array2::from_shape_vec((anchor_count, 4), coords).unwrap()
}
|
||||
|
||||
impl FaceDetectionModelOutput {
|
||||
pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
||||
let mut anchors = Vec::new();
|
||||
for (k, &step) in config.steps.iter().enumerate() {
|
||||
let feature_size = 1024 / step;
|
||||
let min_sizes = config.anchor_sizes[k];
|
||||
let sizes = [min_sizes.x, min_sizes.y];
|
||||
for i in 0..feature_size {
|
||||
for j in 0..feature_size {
|
||||
for &size in &sizes {
|
||||
let cx = (j as f32 + 0.5) * step as f32 / 1024.0;
|
||||
let cy = (i as f32 + 0.5) * step as f32 / 1024.0;
|
||||
let s_k = size as f32 / 1024.0;
|
||||
anchors.push((cx, cy, s_k, s_k));
|
||||
use ndarray::s;
|
||||
|
||||
let priors = generate_anchors(config);
|
||||
|
||||
let scores = self.confidence.slice(s![0, .., 1]);
|
||||
let boxes = self.bbox.slice(s![0, .., ..]);
|
||||
let landmarks_raw = self.landmark.slice(s![0, .., ..]);
|
||||
|
||||
let mut decoded_boxes = Vec::new();
|
||||
let mut decoded_landmarks = Vec::new();
|
||||
let mut confidences = Vec::new();
|
||||
|
||||
for i in 0..priors.shape()[0] {
|
||||
if scores[i] > config.threshold {
|
||||
let prior = priors.row(i);
|
||||
let loc = boxes.row(i);
|
||||
let landm = landmarks_raw.row(i);
|
||||
|
||||
// Decode bounding box
|
||||
let prior_cx = (prior[0] + prior[2]) / 2.0;
|
||||
let prior_cy = (prior[1] + prior[3]) / 2.0;
|
||||
let prior_w = prior[2] - prior[0];
|
||||
let prior_h = prior[3] - prior[1];
|
||||
|
||||
let var = config.variances;
|
||||
let cx = prior_cx + loc[0] * var[0] * prior_w;
|
||||
let cy = prior_cy + loc[1] * var[0] * prior_h;
|
||||
let w = prior_w * (loc[2] * var[1]).exp();
|
||||
let h = prior_h * (loc[3] * var[1]).exp();
|
||||
|
||||
let xmin = cx - w / 2.0;
|
||||
let ymin = cy - h / 2.0;
|
||||
let xmax = cx + w / 2.0;
|
||||
let ymax = cy + h / 2.0;
|
||||
|
||||
let mut bbox =
|
||||
Aabb2::from_min_max_vertices(Point2::new(xmin, ymin), Point2::new(xmax, ymax));
|
||||
if config.clamp {
|
||||
bbox.component_clamp(0.0, 1.0);
|
||||
}
|
||||
decoded_boxes.push(bbox);
|
||||
|
||||
// Decode landmarks
|
||||
let mut points = [Point2::new(0.0, 0.0); 5];
|
||||
for j in 0..5 {
|
||||
points[j].x = prior_cx + landm[j * 2] * var[0] * prior_w;
|
||||
points[j].y = prior_cy + landm[j * 2 + 1] * var[0] * prior_h;
|
||||
}
|
||||
let landmarks = FaceLandmarks {
|
||||
left_eye: points[0],
|
||||
right_eye: points[1],
|
||||
nose: points[2],
|
||||
left_mouth: points[3],
|
||||
right_mouth: points[4],
|
||||
};
|
||||
decoded_landmarks.push(landmarks);
|
||||
confidences.push(scores[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut boxes = Vec::new();
|
||||
let mut scores = Vec::new();
|
||||
let mut landmarks = Vec::new();
|
||||
let var0 = config.variance[0];
|
||||
let var1 = config.variance[1];
|
||||
let bbox_data = self.bbox;
|
||||
let conf_data = self.confidence;
|
||||
let landmark_data = self.landmark;
|
||||
let num_priors = bbox_data.shape()[1];
|
||||
for idx in 0..num_priors {
|
||||
let dx = bbox_data[[0, idx, 0]];
|
||||
let dy = bbox_data[[0, idx, 1]];
|
||||
let dw = bbox_data[[0, idx, 2]];
|
||||
let dh = bbox_data[[0, idx, 3]];
|
||||
let (anchor_cx, anchor_cy, anchor_w, anchor_h) = anchors[idx];
|
||||
let pred_cx = anchor_cx + dx * var0 * anchor_w;
|
||||
let pred_cy = anchor_cy + dy * var0 * anchor_h;
|
||||
let pred_w = anchor_w * (dw * var1).exp();
|
||||
let pred_h = anchor_h * (dh * var1).exp();
|
||||
let x_min = pred_cx - pred_w / 2.0;
|
||||
let y_min = pred_cy - pred_h / 2.0;
|
||||
let x_max = pred_cx + pred_w / 2.0;
|
||||
let y_max = pred_cy + pred_h / 2.0;
|
||||
let score = conf_data[[0, idx, 1]];
|
||||
if score > config.threshold {
|
||||
boxes.push(Aabb2::from_x1y1x2y2(x_min, y_min, x_max, y_max));
|
||||
scores.push(score);
|
||||
|
||||
let left_eye_x = landmark_data[[0, idx, 0]] * anchor_w * var0 + anchor_cx;
|
||||
let left_eye_y = landmark_data[[0, idx, 1]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let right_eye_x = landmark_data[[0, idx, 2]] * anchor_w * var0 + anchor_cx;
|
||||
let right_eye_y = landmark_data[[0, idx, 3]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let nose_x = landmark_data[[0, idx, 4]] * anchor_w * var0 + anchor_cx;
|
||||
let nose_y = landmark_data[[0, idx, 5]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let left_mouth_x = landmark_data[[0, idx, 6]] * anchor_w * var0 + anchor_cx;
|
||||
let left_mouth_y = landmark_data[[0, idx, 7]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let right_mouth_x = landmark_data[[0, idx, 8]] * anchor_w * var0 + anchor_cx;
|
||||
let right_mouth_y = landmark_data[[0, idx, 9]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
landmarks.push(FaceLandmarks {
|
||||
left_eye: Point2::new(left_eye_x, left_eye_y),
|
||||
right_eye: Point2::new(right_eye_x, right_eye_y),
|
||||
nose: Point2::new(nose_x, nose_y),
|
||||
left_mouth: Point2::new(left_mouth_x, left_mouth_y),
|
||||
right_mouth: Point2::new(right_mouth_x, right_mouth_y),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(FaceDetectionProcessedOutput {
|
||||
bbox: boxes,
|
||||
confidence: scores,
|
||||
landmarks,
|
||||
bbox: decoded_boxes,
|
||||
confidence: confidences,
|
||||
landmarks: decoded_landmarks,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -189,7 +246,56 @@ impl FaceDetectionModelOutput {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FaceDetectionBuilder {
|
||||
schedule_config: Option<mnn::ScheduleConfig>,
|
||||
backend_config: Option<mnn::BackendConfig>,
|
||||
model: mnn::Interpreter,
|
||||
}
|
||||
|
||||
impl FaceDetectionBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
schedule_config: None,
|
||||
backend_config: None,
|
||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
self.schedule_config
|
||||
.get_or_insert_default()
|
||||
.set_type(forward_type);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
self.schedule_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
self.backend_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<FaceDetection> {
|
||||
let model = self.model;
|
||||
let sc = self.schedule_config.unwrap_or_default();
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn builder<T: AsRef<[u8]>>()
|
||||
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
|
||||
FaceDetectionBuilder::new
|
||||
}
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
@@ -204,9 +310,13 @@ impl FaceDetection {
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
model
|
||||
.set_cache_file("retinaface.cache", 128)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set cache file")?;
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::CPU)
|
||||
.with_type(mnn::ForwardType::Metal)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face detection model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
@@ -217,7 +327,7 @@ impl FaceDetection {
|
||||
|
||||
pub fn detect_faces(
|
||||
&self,
|
||||
image: ndarray::Array3<u8>,
|
||||
image: ndarray::ArrayView3<u8>,
|
||||
config: FaceDetectionConfig,
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
let (height, width, _channels) = image.dim();
|
||||
@@ -242,7 +352,8 @@ impl FaceDetection {
|
||||
.map(|((b, s), l)| (b, s, l))
|
||||
.multiunzip();
|
||||
|
||||
let keep_indices = nms(&boxes, &scores, config.threshold, config.nms_threshold);
|
||||
let keep_indices =
|
||||
nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?;
|
||||
|
||||
let bboxes = boxes
|
||||
.into_iter()
|
||||
@@ -270,15 +381,11 @@ impl FaceDetection {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run_models(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
pub fn run_models(&self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
#[rustfmt::skip]
|
||||
use ::tap::*;
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(mnn::ErrorKind::TensorError)?
|
||||
.change_context(Error)?
|
||||
.mapv(|f| f as f32)
|
||||
.tap_mut(|arr| {
|
||||
arr.axis_iter_mut(ndarray::Axis(2))
|
||||
@@ -292,6 +399,10 @@ impl FaceDetection {
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
use ::tap::*;
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let tensor = resized
|
||||
.as_mnn_tensor_mut()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
|
||||
1
src/faceembed.rs
Normal file
1
src/faceembed.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod facenet;
|
||||
153
src/faceembed/facenet.rs
Normal file
153
src/faceembed/facenet.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
use crate::errors::*;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||
use std::path::Path;
|
||||
mod mnn_impl;
|
||||
mod ort_impl;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingGenerator {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
pub struct EmbeddingGeneratorBuilder {
|
||||
schedule_config: Option<mnn::ScheduleConfig>,
|
||||
backend_config: Option<mnn::BackendConfig>,
|
||||
model: mnn::Interpreter,
|
||||
}
|
||||
|
||||
impl EmbeddingGeneratorBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
schedule_config: None,
|
||||
backend_config: None,
|
||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
self.schedule_config
|
||||
.get_or_insert_default()
|
||||
.set_type(forward_type);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
self.schedule_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
self.backend_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<EmbeddingGenerator> {
|
||||
let model = self.model;
|
||||
let sc = self.schedule_config.unwrap_or_default();
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(EmbeddingGenerator { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingGenerator {
|
||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn builder<T: AsRef<[u8]>>()
|
||||
-> fn(T) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
|
||||
EmbeddingGeneratorBuilder::new
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||
tracing::info!("Loading face embedding model from bytes");
|
||||
let mut model = mnn::Interpreter::from_bytes(model)
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
model
|
||||
.set_cache_file("facenet.cache", 128)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set cache file")?;
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::Metal)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face embedding model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(Self { handle })
|
||||
}
|
||||
|
||||
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
let tensor = face
|
||||
// .permuted_axes((0, 3, 1, 2))
|
||||
.as_standard_layout()
|
||||
.mapv(|x| x as f32);
|
||||
let shape: [usize; 4] = tensor.dim().into();
|
||||
let shape = shape.map(|f| f as i32);
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let tensor = tensor
|
||||
.as_mnn_tensor()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
.change_context(mnn::ErrorKind::TensorError)?;
|
||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
let needs_resize = unsafe {
|
||||
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
if *input.shape() != shape {
|
||||
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
|
||||
// input.resize(shape);
|
||||
intptr.resize_tensor(input.view_mut(), shape);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
if needs_resize {
|
||||
tracing::trace!("Resized input tensor to shape: {:?}", shape);
|
||||
let now = std::time::Instant::now();
|
||||
intptr.resize_session(session);
|
||||
tracing::trace!("Session resized in {:?}", now.elapsed());
|
||||
}
|
||||
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
input.copy_from_host_tensor(tensor.view())?;
|
||||
|
||||
tracing::info!("Running face detection session");
|
||||
intptr.run_session(&session)?;
|
||||
let output_tensor = intptr
|
||||
.output::<f32>(&session, Self::OUTPUT_NAME)?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray()
|
||||
.to_owned();
|
||||
Ok(output_tensor)
|
||||
})
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
}
|
||||
1
src/faceembed/facenet/mnn_impl.rs
Normal file
1
src/faceembed/facenet/mnn_impl.rs
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
65
src/faceembed/facenet/ort_impl.rs
Normal file
65
src/faceembed/facenet/ort_impl.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use crate::errors::{Result, *};
|
||||
use ndarray::*;
|
||||
use ort::*;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingGenerator {
|
||||
handle: ort::session::Session,
|
||||
}
|
||||
|
||||
// impl EmbeddingGeneratorBuilder {
|
||||
// pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
// Ok(Self {
|
||||
// schedule_config: None,
|
||||
// backend_config: None,
|
||||
// model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
// .map_err(|e| e.into_inner())
|
||||
// .change_context(Error)
|
||||
// .attach_printable("Failed to load model from bytes")?,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
// self.schedule_config
|
||||
// .get_or_insert_default()
|
||||
// .set_type(forward_type);
|
||||
// self
|
||||
// }
|
||||
//
|
||||
// pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
// self.schedule_config = Some(config);
|
||||
// self
|
||||
// }
|
||||
//
|
||||
// pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
// self.backend_config = Some(config);
|
||||
// self
|
||||
// }
|
||||
//
|
||||
// pub fn build(self) -> Result<EmbeddingGenerator> {
|
||||
// let model = self.model;
|
||||
// let sc = self.schedule_config.unwrap_or_default();
|
||||
// let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
// .change_context(Error)
|
||||
// .attach_printable("Failed to create session handle")?;
|
||||
// Ok(EmbeddingGenerator { handle })
|
||||
// }
|
||||
// }
|
||||
|
||||
impl EmbeddingGenerator {
|
||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod errors;
|
||||
pub mod facedet;
|
||||
pub mod faceembed;
|
||||
pub mod image;
|
||||
use errors::*;
|
||||
|
||||
77
src/main.rs
77
src/main.rs
@@ -1,9 +1,15 @@
|
||||
mod cli;
|
||||
mod errors;
|
||||
use detector::facedet::retinaface::FaceDetectionConfig;
|
||||
use bounding_box::roi::MultiRoi;
|
||||
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
||||
use errors::*;
|
||||
use fast_image_resize::ResizeOptions;
|
||||
use ndarray::*;
|
||||
use ndarray_image::*;
|
||||
use ndarray_resize::NdFir;
|
||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const CHUNK_SIZE: usize = 8;
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
@@ -15,27 +21,86 @@ pub fn main() -> Result<()> {
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
use detector::facedet;
|
||||
let model = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
||||
let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
let image = image::open(detect.image).change_context(Error)?;
|
||||
let image = image.into_rgb8();
|
||||
let mut array = image
|
||||
.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
let output = model
|
||||
let output = retinaface
|
||||
.detect_faces(
|
||||
array.clone(),
|
||||
FaceDetectionConfig::default().with_threshold(detect.threshold),
|
||||
array.view(),
|
||||
FaceDetectionConfig::default()
|
||||
.with_threshold(detect.threshold)
|
||||
.with_nms_threshold(detect.nms_threshold),
|
||||
)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
for bbox in output.bbox {
|
||||
for bbox in &output.bbox {
|
||||
tracing::info!("Detected face: {:?}", bbox);
|
||||
use bounding_box::draw::*;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||
}
|
||||
let face_rois = array
|
||||
.view()
|
||||
.multi_roi(&output.bbox)
|
||||
.change_context(Error)?
|
||||
.into_iter()
|
||||
// .inspect(|f| {
|
||||
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
|
||||
// })
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(512, 512, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
// .inspect(|f| {
|
||||
// f.as_ref().inspect(|f| {
|
||||
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
|
||||
// });
|
||||
// })
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
|
||||
let chunk_size = CHUNK_SIZE;
|
||||
let embeddings = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||
|
||||
if chunk.len() < 8 {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((512, 512, 3));
|
||||
let zero_array = core::iter::repeat(zeros.view())
|
||||
.take(chunk_size)
|
||||
.collect::<Vec<_>>();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>();
|
||||
|
||||
let v = array.view();
|
||||
if let Some(output) = detect.output {
|
||||
let image: image::RgbImage = v
|
||||
|
||||
Reference in New Issue
Block a user