feat: Added stuff
This commit is contained in:
214
src/main.rs
214
src/main.rs
@@ -1,7 +1,7 @@
|
||||
mod cli;
|
||||
mod errors;
|
||||
use bounding_box::roi::MultiRoi;
|
||||
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
||||
use detector::{facedet, facedet::FaceDetectionConfig, faceembed};
|
||||
use errors::*;
|
||||
use fast_image_resize::ResizeOptions;
|
||||
use ndarray::*;
|
||||
@@ -20,97 +20,45 @@ pub fn main() -> Result<()> {
|
||||
let args = <cli::Cli as clap::Parser>::parse();
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
use detector::facedet;
|
||||
let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
let image = image::open(detect.image).change_context(Error)?;
|
||||
let image = image.into_rgb8();
|
||||
let mut array = image
|
||||
.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
let output = retinaface
|
||||
.detect_faces(
|
||||
array.view(),
|
||||
FaceDetectionConfig::default()
|
||||
.with_threshold(detect.threshold)
|
||||
.with_nms_threshold(detect.nms_threshold),
|
||||
)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
for bbox in &output.bbox {
|
||||
tracing::info!("Detected face: {:?}", bbox);
|
||||
use bounding_box::draw::*;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||
}
|
||||
let face_rois = array
|
||||
.view()
|
||||
.multi_roi(&output.bbox)
|
||||
.change_context(Error)?
|
||||
.into_iter()
|
||||
// .inspect(|f| {
|
||||
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
|
||||
// })
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(512, 512, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
// .inspect(|f| {
|
||||
// f.as_ref().inspect(|f| {
|
||||
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
|
||||
// });
|
||||
// })
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
|
||||
|
||||
let chunk_size = CHUNK_SIZE;
|
||||
let embeddings = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||
match executor {
|
||||
cli::Executor::Mnn => {
|
||||
let retinaface = facedet::mnn::FaceDetection::builder()(RETINAFACE_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
if chunk.len() < 8 {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((512, 512, 3));
|
||||
let zero_array = core::iter::repeat(zeros.view())
|
||||
.take(chunk_size)
|
||||
.collect::<Vec<_>>();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>();
|
||||
run_detection(detect, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Onnx => {
|
||||
// Load ONNX models
|
||||
const RETINAFACE_ONNX_MODEL: &[u8] =
|
||||
include_bytes!("../models/retinaface.onnx");
|
||||
const FACENET_ONNX_MODEL: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
|
||||
let v = array.view();
|
||||
if let Some(output) = detect.output {
|
||||
let image: image::RgbImage = v
|
||||
.to_image()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert ndarray to image")?;
|
||||
image
|
||||
.save(output)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to save output image")?;
|
||||
let retinaface = facedet::ort::FaceDetection::builder()(RETINAFACE_ONNX_MODEL)
|
||||
.change_context(Error)?
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_ONNX_MODEL)
|
||||
.change_context(Error)?
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_detection(detect, retinaface, facenet)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
cli::SubCommand::List(list) => {
|
||||
@@ -122,3 +70,91 @@ pub fn main() -> Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
let image = image::open(detect.image).change_context(Error)?;
|
||||
let image = image.into_rgb8();
|
||||
let mut array = image
|
||||
.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
let output = retinaface
|
||||
.detect_faces(
|
||||
array.view(),
|
||||
FaceDetectionConfig::default()
|
||||
.with_threshold(detect.threshold)
|
||||
.with_nms_threshold(detect.nms_threshold),
|
||||
)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
for bbox in &output.bbox {
|
||||
tracing::info!("Detected face: {:?}", bbox);
|
||||
use bounding_box::draw::*;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||
}
|
||||
let face_rois = array
|
||||
.view()
|
||||
.multi_roi(&output.bbox)
|
||||
.change_context(Error)?
|
||||
.into_iter()
|
||||
// .inspect(|f| {
|
||||
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
|
||||
// })
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(512, 512, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
// .inspect(|f| {
|
||||
// f.as_ref().inspect(|f| {
|
||||
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
|
||||
// });
|
||||
// })
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
|
||||
let chunk_size = CHUNK_SIZE;
|
||||
let embeddings = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||
|
||||
if chunk.len() < 8 {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((512, 512, 3));
|
||||
let zero_array = core::iter::repeat(zeros.view())
|
||||
.take(chunk_size)
|
||||
.collect::<Vec<_>>();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>();
|
||||
|
||||
let v = array.view();
|
||||
if let Some(output) = detect.output {
|
||||
let image: image::RgbImage = v
|
||||
.to_image()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert ndarray to image")?;
|
||||
image
|
||||
.save(output)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to save output image")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user