feat: Added facenet
This commit is contained in:
@@ -261,6 +261,10 @@ impl FaceDetection {
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
model
|
||||
.set_cache_file("retinaface.cache", 128)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set cache file")?;
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::CPU)
|
||||
@@ -330,26 +334,26 @@ impl FaceDetection {
|
||||
|
||||
pub fn run_models(&self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
#[rustfmt::skip]
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(Error)?
|
||||
.mapv(|f| f as f32)
|
||||
.tap_mut(|arr| {
|
||||
arr.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104, 117, 123])
|
||||
.for_each(|(mut array, pixel)| {
|
||||
let pixel = pixel as f32;
|
||||
array.map_inplace(|v| *v -= pixel);
|
||||
});
|
||||
})
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
use ::tap::*;
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(mnn::ErrorKind::TensorError)?
|
||||
.mapv(|f| f as f32)
|
||||
.tap_mut(|arr| {
|
||||
arr.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104, 117, 123])
|
||||
.for_each(|(mut array, pixel)| {
|
||||
let pixel = pixel as f32;
|
||||
array.map_inplace(|v| *v -= pixel);
|
||||
});
|
||||
})
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
let tensor = resized
|
||||
.as_mnn_tensor_mut()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::errors::*;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||
use std::path::Path;
|
||||
|
||||
@@ -8,6 +9,8 @@ pub struct EmbeddingGenerator {
|
||||
}
|
||||
|
||||
impl EmbeddingGenerator {
|
||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
@@ -22,9 +25,13 @@ impl EmbeddingGenerator {
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
model
|
||||
.set_cache_file("facenet.cache", 128)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set cache file")?;
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::CPU)
|
||||
.with_type(mnn::ForwardType::Metal)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face embedding model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
@@ -33,11 +40,55 @@ impl EmbeddingGenerator {
|
||||
Ok(Self { handle })
|
||||
}
|
||||
|
||||
pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
||||
todo!()
|
||||
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
let tensor = face
|
||||
// .permuted_axes((0, 3, 1, 2))
|
||||
.as_standard_layout()
|
||||
.mapv(|x| x as f32);
|
||||
let shape: [usize; 4] = tensor.dim().into();
|
||||
let shape = shape.map(|f| f as i32);
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let tensor = tensor
|
||||
.as_mnn_tensor()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
.change_context(mnn::ErrorKind::TensorError)?;
|
||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
unsafe {
|
||||
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
if *input.shape() != shape {
|
||||
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
|
||||
// input.resize(shape);
|
||||
intptr.resize_tensor(input.view_mut(), shape);
|
||||
}
|
||||
}
|
||||
intptr.resize_session(session);
|
||||
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
input.copy_from_host_tensor(tensor.view())?;
|
||||
|
||||
tracing::info!("Running face detection session");
|
||||
intptr.run_session(&session)?;
|
||||
let output_tensor = intptr
|
||||
.output::<f32>(&session, Self::OUTPUT_NAME)?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray()
|
||||
.to_owned();
|
||||
Ok(output_tensor)
|
||||
})
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Placeholder for generating embeddings from a batch of regions of interest.
///
/// # Panics
///
/// Not implemented yet: the body is `todo!()`, so any call panics.
pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
||||
todo!()
|
||||
}
|
||||
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
}
|
||||
|
||||
69
src/main.rs
69
src/main.rs
@@ -1,9 +1,13 @@
|
||||
mod cli;
|
||||
mod errors;
|
||||
use detector::facedet::retinaface::FaceDetectionConfig;
|
||||
use bounding_box::roi::MultiRoi;
|
||||
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
||||
use errors::*;
|
||||
use fast_image_resize::ResizeOptions;
|
||||
use nalgebra::zero;
|
||||
use ndarray_image::*;
|
||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
@@ -15,29 +19,84 @@ pub fn main() -> Result<()> {
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
use detector::facedet;
|
||||
let model = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
||||
let retinaface = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::facenet::EmbeddingGenerator::new_from_bytes(FACENET_MODEL)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
let image = image::open(detect.image).change_context(Error)?;
|
||||
let image = image.into_rgb8();
|
||||
let mut array = image
|
||||
.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
let output = model
|
||||
let output = retinaface
|
||||
.detect_faces(
|
||||
array.clone(),
|
||||
array.view(),
|
||||
FaceDetectionConfig::default()
|
||||
.with_threshold(detect.threshold)
|
||||
.with_nms_threshold(detect.nms_threshold),
|
||||
)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
for bbox in output.bbox {
|
||||
for bbox in &output.bbox {
|
||||
tracing::info!("Detected face: {:?}", bbox);
|
||||
use bounding_box::draw::*;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||
}
|
||||
use ndarray::{Array2, Array3, Array4, Axis};
|
||||
use ndarray_resize::NdFir;
|
||||
let face_rois = array
|
||||
.view()
|
||||
.multi_roi(&output.bbox)
|
||||
.change_context(Error)?
|
||||
.into_iter()
|
||||
// .inspect(|f| {
|
||||
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
|
||||
// })
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(512, 512, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
// .inspect(|f| {
|
||||
// f.as_ref().inspect(|f| {
|
||||
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
|
||||
// });
|
||||
// })
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
|
||||
let embeddings = face_roi_views
|
||||
.chunks(8)
|
||||
.map(|chunk| {
|
||||
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||
|
||||
if chunk.len() < 8 {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((512, 512, 3));
|
||||
let padded: Vec<ndarray::ArrayView3<'_, u8>> = chunk
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(8)
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), padded.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>();
|
||||
|
||||
let v = array.view();
|
||||
if let Some(output) = detect.output {
|
||||
let image: image::RgbImage = v
|
||||
|
||||
Reference in New Issue
Block a user