feat: Added facenet
This commit is contained in:
64
Cargo.lock
generated
64
Cargo.lock
generated
@@ -116,15 +116,6 @@ version = "1.0.97"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
|
checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "approx"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278"
|
|
||||||
dependencies = [
|
|
||||||
"num-traits",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "approx"
|
name = "approx"
|
||||||
version = "0.5.1"
|
version = "0.5.1"
|
||||||
@@ -253,7 +244,7 @@ dependencies = [
|
|||||||
"color",
|
"color",
|
||||||
"itertools 0.14.0",
|
"itertools 0.14.0",
|
||||||
"nalgebra",
|
"nalgebra",
|
||||||
"ndarray 0.16.1",
|
"ndarray",
|
||||||
"num",
|
"num",
|
||||||
"ordered-float",
|
"ordered-float",
|
||||||
"simba",
|
"simba",
|
||||||
@@ -506,12 +497,11 @@ dependencies = [
|
|||||||
"fast_image_resize",
|
"fast_image_resize",
|
||||||
"image",
|
"image",
|
||||||
"itertools 0.14.0",
|
"itertools 0.14.0",
|
||||||
"linfa",
|
|
||||||
"mnn",
|
"mnn",
|
||||||
"mnn-bridge",
|
"mnn-bridge",
|
||||||
"mnn-sync",
|
"mnn-sync",
|
||||||
"nalgebra",
|
"nalgebra",
|
||||||
"ndarray 0.16.1",
|
"ndarray",
|
||||||
"ndarray-image",
|
"ndarray-image",
|
||||||
"ndarray-resize",
|
"ndarray-resize",
|
||||||
"ordered-float",
|
"ordered-float",
|
||||||
@@ -1098,20 +1088,6 @@ dependencies = [
|
|||||||
"vcpkg",
|
"vcpkg",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "linfa"
|
|
||||||
version = "0.7.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "56f9097edc7c89d03d526efbacf6d90914e3a8fa53bd56c2d1489e3a90819370"
|
|
||||||
dependencies = [
|
|
||||||
"approx 0.4.0",
|
|
||||||
"ndarray 0.15.6",
|
|
||||||
"num-traits",
|
|
||||||
"rand",
|
|
||||||
"sprs",
|
|
||||||
"thiserror 1.0.69",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "litemap"
|
name = "litemap"
|
||||||
version = "0.8.0"
|
version = "0.8.0"
|
||||||
@@ -1220,7 +1196,7 @@ source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-ty
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"error-stack",
|
"error-stack",
|
||||||
"mnn",
|
"mnn",
|
||||||
"ndarray 0.16.1",
|
"ndarray",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1259,7 +1235,7 @@ version = "0.33.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b"
|
checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"approx 0.5.1",
|
"approx",
|
||||||
"matrixmultiply",
|
"matrixmultiply",
|
||||||
"nalgebra-macros",
|
"nalgebra-macros",
|
||||||
"num-complex",
|
"num-complex",
|
||||||
@@ -1289,20 +1265,6 @@ dependencies = [
|
|||||||
"getrandom 0.2.16",
|
"getrandom 0.2.16",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "ndarray"
|
|
||||||
version = "0.15.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
|
||||||
dependencies = [
|
|
||||||
"approx 0.4.0",
|
|
||||||
"matrixmultiply",
|
|
||||||
"num-complex",
|
|
||||||
"num-integer",
|
|
||||||
"num-traits",
|
|
||||||
"rawpointer",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ndarray"
|
name = "ndarray"
|
||||||
version = "0.16.1"
|
version = "0.16.1"
|
||||||
@@ -1323,7 +1285,7 @@ name = "ndarray-image"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"image",
|
"image",
|
||||||
"ndarray 0.16.1",
|
"ndarray",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1333,7 +1295,7 @@ dependencies = [
|
|||||||
"bytemuck",
|
"bytemuck",
|
||||||
"error-stack",
|
"error-stack",
|
||||||
"fast_image_resize",
|
"fast_image_resize",
|
||||||
"ndarray 0.16.1",
|
"ndarray",
|
||||||
"num",
|
"num",
|
||||||
"thiserror 2.0.12",
|
"thiserror 2.0.12",
|
||||||
]
|
]
|
||||||
@@ -1942,7 +1904,7 @@ version = "0.9.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b3a386a501cd104797982c15ae17aafe8b9261315b5d07e3ec803f2ea26be0fa"
|
checksum = "b3a386a501cd104797982c15ae17aafe8b9261315b5d07e3ec803f2ea26be0fa"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"approx 0.5.1",
|
"approx",
|
||||||
"num-complex",
|
"num-complex",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"paste",
|
"paste",
|
||||||
@@ -1979,18 +1941,6 @@ dependencies = [
|
|||||||
"lock_api",
|
"lock_api",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "sprs"
|
|
||||||
version = "0.11.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "88bab60b0a18fb9b3e0c26e92796b3c3a278bf5fa4880f5ad5cc3bdfb843d0b1"
|
|
||||||
dependencies = [
|
|
||||||
"ndarray 0.15.6",
|
|
||||||
"num-complex",
|
|
||||||
"num-traits",
|
|
||||||
"smallvec",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "stable_deref_trait"
|
name = "stable_deref_trait"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ mnn = { path = "/Users/fs0c131y/Projects/aftershoot/mnn-rs" }
|
|||||||
ndarray-image = { path = "ndarray-image" }
|
ndarray-image = { path = "ndarray-image" }
|
||||||
ndarray-resize = { path = "ndarray-resize" }
|
ndarray-resize = { path = "ndarray-resize" }
|
||||||
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
|
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
|
||||||
# "metal",
|
"metal",
|
||||||
# "coreml",
|
"coreml",
|
||||||
"tracing",
|
"tracing",
|
||||||
], branch = "restructure-tensor-type" }
|
], branch = "restructure-tensor-type" }
|
||||||
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
||||||
@@ -35,7 +35,6 @@ clap_complete = "4.5"
|
|||||||
error-stack = "0.5"
|
error-stack = "0.5"
|
||||||
fast_image_resize = "5.2.0"
|
fast_image_resize = "5.2.0"
|
||||||
image = "0.25.6"
|
image = "0.25.6"
|
||||||
linfa = "0.7.1"
|
|
||||||
nalgebra = "0.33.2"
|
nalgebra = "0.33.2"
|
||||||
ndarray = "0.16.1"
|
ndarray = "0.16.1"
|
||||||
ndarray-image = { workspace = true }
|
ndarray-image = { workspace = true }
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ pub use color::Rgba8;
|
|||||||
use ndarray::{Array1, Array3, ArrayViewMut3};
|
use ndarray::{Array1, Array3, ArrayViewMut3};
|
||||||
|
|
||||||
pub trait Draw<T> {
|
pub trait Draw<T> {
|
||||||
fn draw(&mut self, item: T, color: color::Rgba8, thickness: usize);
|
fn draw(&mut self, item: &T, color: color::Rgba8, thickness: usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Draw<Aabb2<usize>> for Array3<u8> {
|
impl Draw<Aabb2<usize>> for Array3<u8> {
|
||||||
fn draw(&mut self, item: Aabb2<usize>, color: color::Rgba8, thickness: usize) {
|
fn draw(&mut self, item: &Aabb2<usize>, color: color::Rgba8, thickness: usize) {
|
||||||
item.draw(self, color, thickness)
|
item.draw(self, color, thickness)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,10 +5,17 @@ pub trait Roi<'a, Output> {
|
|||||||
type Error;
|
type Error;
|
||||||
fn roi(&'a self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
fn roi(&'a self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait RoiMut<'a, Output> {
|
pub trait RoiMut<'a, Output> {
|
||||||
type Error;
|
type Error;
|
||||||
fn roi_mut(&'a mut self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
fn roi_mut(&'a mut self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait MultiRoi<'a, Output> {
|
||||||
|
type Error;
|
||||||
|
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Output, Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug, Copy, Clone)]
|
#[derive(thiserror::Error, Debug, Copy, Clone)]
|
||||||
pub enum RoiError {
|
pub enum RoiError {
|
||||||
#[error("Region of intereset is out of bounds")]
|
#[error("Region of intereset is out of bounds")]
|
||||||
@@ -36,7 +43,7 @@ impl<'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for Array3<T> {
|
|||||||
let x2 = aabb.x2();
|
let x2 = aabb.x2();
|
||||||
let y1 = aabb.y1();
|
let y1 = aabb.y1();
|
||||||
let y2 = aabb.y2();
|
let y2 = aabb.y2();
|
||||||
if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
|
if x1 > x2 || y1 > y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
|
||||||
return Err(RoiError::RoiOutOfBounds);
|
return Err(RoiError::RoiOutOfBounds);
|
||||||
}
|
}
|
||||||
Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..]))
|
Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..]))
|
||||||
@@ -95,3 +102,47 @@ pub fn reborrow_test() {
|
|||||||
};
|
};
|
||||||
dbg!(y);
|
dbg!(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a> MultiRoi<'a, Vec<ArrayView3<'a, u8>>> for Array3<u8> {
|
||||||
|
type Error = RoiError;
|
||||||
|
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'a, u8>>, Self::Error> {
|
||||||
|
let (height, width, _channels) = self.dim();
|
||||||
|
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
|
||||||
|
aabbs
|
||||||
|
.iter()
|
||||||
|
.map(|aabb| {
|
||||||
|
let slice_arg =
|
||||||
|
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
|
||||||
|
Ok(self.slice(slice_arg))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, RoiError>>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> MultiRoi<'a, Vec<ArrayView3<'b, u8>>> for ArrayView3<'b, u8> {
|
||||||
|
type Error = RoiError;
|
||||||
|
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'b, u8>>, Self::Error> {
|
||||||
|
let (height, width, _channels) = self.dim();
|
||||||
|
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
|
||||||
|
aabbs
|
||||||
|
.iter()
|
||||||
|
.map(|aabb| {
|
||||||
|
let slice_arg =
|
||||||
|
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
|
||||||
|
Ok(self.slice_move(slice_arg))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, RoiError>>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bbox_to_slice_arg(
|
||||||
|
aabb: Aabb2<usize>,
|
||||||
|
) -> ndarray::SliceInfo<[ndarray::SliceInfoElem; 3], ndarray::Ix3, ndarray::Ix3> {
|
||||||
|
// This function should convert the bounding box to a slice argument
|
||||||
|
// For now, we will return a dummy value
|
||||||
|
let x1 = aabb.x1();
|
||||||
|
let x2 = aabb.x2();
|
||||||
|
let y1 = aabb.y1();
|
||||||
|
let y2 = aabb.y2();
|
||||||
|
ndarray::s![y1..y2, x1..x2, ..]
|
||||||
|
}
|
||||||
|
|||||||
6
flake.lock
generated
6
flake.lock
generated
@@ -178,11 +178,11 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1750732748,
|
"lastModified": 1754621349,
|
||||||
"narHash": "sha256-HR2b3RHsPeJm+Fb+1ui8nXibgniVj7hBNvUbXEyz0DU=",
|
"narHash": "sha256-JkXUS/nBHyUqVTuL4EDCvUWauTHV78EYfk+WqiTAMQ4=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "4b4494b2ba7e8a8041b2e28320b2ee02c115c75f",
|
"rev": "c448ab42002ac39d3337da10420c414fccfb1088",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|||||||
BIN
models/retinaface.mnn
LFS
BIN
models/retinaface.mnn
LFS
Binary file not shown.
@@ -261,6 +261,10 @@ impl FaceDetection {
|
|||||||
.change_context(Error)
|
.change_context(Error)
|
||||||
.attach_printable("Failed to load model from bytes")?;
|
.attach_printable("Failed to load model from bytes")?;
|
||||||
model.set_session_mode(mnn::SessionMode::Release);
|
model.set_session_mode(mnn::SessionMode::Release);
|
||||||
|
model
|
||||||
|
.set_cache_file("retinaface.cache", 128)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set cache file")?;
|
||||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||||
let sc = mnn::ScheduleConfig::new()
|
let sc = mnn::ScheduleConfig::new()
|
||||||
.with_type(mnn::ForwardType::CPU)
|
.with_type(mnn::ForwardType::CPU)
|
||||||
@@ -330,13 +334,9 @@ impl FaceDetection {
|
|||||||
|
|
||||||
pub fn run_models(&self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
pub fn run_models(&self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
use ::tap::*;
|
|
||||||
let output = self
|
|
||||||
.handle
|
|
||||||
.run(move |sr| {
|
|
||||||
let mut resized = image
|
let mut resized = image
|
||||||
.fast_resize(1024, 1024, None)
|
.fast_resize(1024, 1024, None)
|
||||||
.change_context(mnn::ErrorKind::TensorError)?
|
.change_context(Error)?
|
||||||
.mapv(|f| f as f32)
|
.mapv(|f| f as f32)
|
||||||
.tap_mut(|arr| {
|
.tap_mut(|arr| {
|
||||||
arr.axis_iter_mut(ndarray::Axis(2))
|
arr.axis_iter_mut(ndarray::Axis(2))
|
||||||
@@ -350,6 +350,10 @@ impl FaceDetection {
|
|||||||
.insert_axis(ndarray::Axis(0))
|
.insert_axis(ndarray::Axis(0))
|
||||||
.as_standard_layout()
|
.as_standard_layout()
|
||||||
.into_owned();
|
.into_owned();
|
||||||
|
use ::tap::*;
|
||||||
|
let output = self
|
||||||
|
.handle
|
||||||
|
.run(move |sr| {
|
||||||
let tensor = resized
|
let tensor = resized
|
||||||
.as_mnn_tensor_mut()
|
.as_mnn_tensor_mut()
|
||||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
|
use mnn_bridge::ndarray::*;
|
||||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
@@ -8,6 +9,8 @@ pub struct EmbeddingGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EmbeddingGenerator {
|
impl EmbeddingGenerator {
|
||||||
|
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||||
|
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||||
let model = std::fs::read(path)
|
let model = std::fs::read(path)
|
||||||
.change_context(Error)
|
.change_context(Error)
|
||||||
@@ -22,9 +25,13 @@ impl EmbeddingGenerator {
|
|||||||
.change_context(Error)
|
.change_context(Error)
|
||||||
.attach_printable("Failed to load model from bytes")?;
|
.attach_printable("Failed to load model from bytes")?;
|
||||||
model.set_session_mode(mnn::SessionMode::Release);
|
model.set_session_mode(mnn::SessionMode::Release);
|
||||||
|
model
|
||||||
|
.set_cache_file("facenet.cache", 128)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set cache file")?;
|
||||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||||
let sc = mnn::ScheduleConfig::new()
|
let sc = mnn::ScheduleConfig::new()
|
||||||
.with_type(mnn::ForwardType::CPU)
|
.with_type(mnn::ForwardType::Metal)
|
||||||
.with_backend_config(bc);
|
.with_backend_config(bc);
|
||||||
tracing::info!("Creating session handle for face embedding model");
|
tracing::info!("Creating session handle for face embedding model");
|
||||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||||
@@ -33,11 +40,55 @@ impl EmbeddingGenerator {
|
|||||||
Ok(Self { handle })
|
Ok(Self { handle })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||||
todo!()
|
let tensor = face
|
||||||
|
// .permuted_axes((0, 3, 1, 2))
|
||||||
|
.as_standard_layout()
|
||||||
|
.mapv(|x| x as f32);
|
||||||
|
let shape: [usize; 4] = tensor.dim().into();
|
||||||
|
let shape = shape.map(|f| f as i32);
|
||||||
|
let output = self
|
||||||
|
.handle
|
||||||
|
.run(move |sr| {
|
||||||
|
let tensor = tensor
|
||||||
|
.as_mnn_tensor()
|
||||||
|
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||||
|
.change_context(mnn::ErrorKind::TensorError)?;
|
||||||
|
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||||
|
let (intptr, session) = sr.both_mut();
|
||||||
|
tracing::trace!("Copying input tensor to host");
|
||||||
|
unsafe {
|
||||||
|
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
|
||||||
|
tracing::trace!("Input shape: {:?}", input.shape());
|
||||||
|
if *input.shape() != shape {
|
||||||
|
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
|
||||||
|
// input.resize(shape);
|
||||||
|
intptr.resize_tensor(input.view_mut(), shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
intptr.resize_session(session);
|
||||||
|
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
|
||||||
|
tracing::trace!("Input shape: {:?}", input.shape());
|
||||||
|
input.copy_from_host_tensor(tensor.view())?;
|
||||||
|
|
||||||
|
tracing::info!("Running face detection session");
|
||||||
|
intptr.run_session(&session)?;
|
||||||
|
let output_tensor = intptr
|
||||||
|
.output::<f32>(&session, Self::OUTPUT_NAME)?
|
||||||
|
.create_host_tensor_from_device(true)
|
||||||
|
.as_ndarray()
|
||||||
|
.to_owned();
|
||||||
|
Ok(output_tensor)
|
||||||
|
})
|
||||||
|
.change_context(Error)?;
|
||||||
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
||||||
todo!()
|
// todo!()
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
||||||
|
// todo!()
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|||||||
69
src/main.rs
69
src/main.rs
@@ -1,9 +1,13 @@
|
|||||||
mod cli;
|
mod cli;
|
||||||
mod errors;
|
mod errors;
|
||||||
use detector::facedet::retinaface::FaceDetectionConfig;
|
use bounding_box::roi::MultiRoi;
|
||||||
|
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
||||||
use errors::*;
|
use errors::*;
|
||||||
|
use fast_image_resize::ResizeOptions;
|
||||||
|
use nalgebra::zero;
|
||||||
use ndarray_image::*;
|
use ndarray_image::*;
|
||||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||||
|
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter("trace")
|
.with_env_filter("trace")
|
||||||
@@ -15,29 +19,84 @@ pub fn main() -> Result<()> {
|
|||||||
match args.cmd {
|
match args.cmd {
|
||||||
cli::SubCommand::Detect(detect) => {
|
cli::SubCommand::Detect(detect) => {
|
||||||
use detector::facedet;
|
use detector::facedet;
|
||||||
let model = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
let retinaface = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face detection model")?;
|
.attach_printable("Failed to create face detection model")?;
|
||||||
|
let facenet = faceembed::facenet::EmbeddingGenerator::new_from_bytes(FACENET_MODEL)
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
let image = image::open(detect.image).change_context(Error)?;
|
let image = image::open(detect.image).change_context(Error)?;
|
||||||
let image = image.into_rgb8();
|
let image = image.into_rgb8();
|
||||||
let mut array = image
|
let mut array = image
|
||||||
.into_ndarray()
|
.into_ndarray()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to convert image to ndarray")?;
|
.attach_printable("Failed to convert image to ndarray")?;
|
||||||
let output = model
|
let output = retinaface
|
||||||
.detect_faces(
|
.detect_faces(
|
||||||
array.clone(),
|
array.view(),
|
||||||
FaceDetectionConfig::default()
|
FaceDetectionConfig::default()
|
||||||
.with_threshold(detect.threshold)
|
.with_threshold(detect.threshold)
|
||||||
.with_nms_threshold(detect.nms_threshold),
|
.with_nms_threshold(detect.nms_threshold),
|
||||||
)
|
)
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to detect faces")?;
|
.attach_printable("Failed to detect faces")?;
|
||||||
for bbox in output.bbox {
|
for bbox in &output.bbox {
|
||||||
tracing::info!("Detected face: {:?}", bbox);
|
tracing::info!("Detected face: {:?}", bbox);
|
||||||
use bounding_box::draw::*;
|
use bounding_box::draw::*;
|
||||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||||
}
|
}
|
||||||
|
use ndarray::{Array2, Array3, Array4, Axis};
|
||||||
|
use ndarray_resize::NdFir;
|
||||||
|
let face_rois = array
|
||||||
|
.view()
|
||||||
|
.multi_roi(&output.bbox)
|
||||||
|
.change_context(Error)?
|
||||||
|
.into_iter()
|
||||||
|
// .inspect(|f| {
|
||||||
|
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
|
||||||
|
// })
|
||||||
|
.map(|roi| {
|
||||||
|
roi.as_standard_layout()
|
||||||
|
.fast_resize(512, 512, &ResizeOptions::default())
|
||||||
|
.change_context(Error)
|
||||||
|
})
|
||||||
|
// .inspect(|f| {
|
||||||
|
// f.as_ref().inspect(|f| {
|
||||||
|
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
|
||||||
|
// });
|
||||||
|
// })
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let embeddings = face_roi_views
|
||||||
|
.chunks(8)
|
||||||
|
.map(|chunk| {
|
||||||
|
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||||
|
|
||||||
|
if chunk.len() < 8 {
|
||||||
|
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||||
|
let zeros = Array3::zeros((512, 512, 3));
|
||||||
|
let padded: Vec<ndarray::ArrayView3<'_, u8>> = chunk
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.chain(core::iter::repeat(zeros.view()))
|
||||||
|
.take(8)
|
||||||
|
.collect();
|
||||||
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), padded.as_slice())
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to stack rois together")?;
|
||||||
|
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||||
|
Ok(output)
|
||||||
|
} else {
|
||||||
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to stack rois together")?;
|
||||||
|
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<Array2<f32>>>>();
|
||||||
|
|
||||||
let v = array.view();
|
let v = array.view();
|
||||||
if let Some(output) = detect.output {
|
if let Some(output) = detect.output {
|
||||||
let image: image::RgbImage = v
|
let image: image::RgbImage = v
|
||||||
|
|||||||
Reference in New Issue
Block a user