mod cli; mod errors; use bounding_box::roi::MultiRoi; use detector::{facedet::retinaface::FaceDetectionConfig, faceembed}; use errors::*; use fast_image_resize::ResizeOptions; use nalgebra::zero; use ndarray_image::*; const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn"); const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn"); pub fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter("trace") .with_thread_ids(true) .with_thread_names(true) .with_target(false) .init(); let args = ::parse(); match args.cmd { cli::SubCommand::Detect(detect) => { use detector::facedet; let retinaface = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL) .change_context(errors::Error) .attach_printable("Failed to create face detection model")?; let facenet = faceembed::facenet::EmbeddingGenerator::new_from_bytes(FACENET_MODEL) .change_context(errors::Error) .attach_printable("Failed to create face embedding model")?; let image = image::open(detect.image).change_context(Error)?; let image = image.into_rgb8(); let mut array = image .into_ndarray() .change_context(errors::Error) .attach_printable("Failed to convert image to ndarray")?; let output = retinaface .detect_faces( array.view(), FaceDetectionConfig::default() .with_threshold(detect.threshold) .with_nms_threshold(detect.nms_threshold), ) .change_context(errors::Error) .attach_printable("Failed to detect faces")?; for bbox in &output.bbox { tracing::info!("Detected face: {:?}", bbox); use bounding_box::draw::*; array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1); } use ndarray::{Array2, Array3, Array4, Axis}; use ndarray_resize::NdFir; let face_rois = array .view() .multi_roi(&output.bbox) .change_context(Error)? .into_iter() // .inspect(|f| { // tracing::info!("Face ROI shape before resize: {:?}", f.dim()); // }) .map(|roi| { roi.as_standard_layout() .fast_resize(512, 512, &ResizeOptions::default()) .change_context(Error) }) // .inspect(|f| { // f.as_ref().inspect(|f| { // tracing::info!("Face ROI shape after resize: {:?}", f.dim()); // }); // }) .collect::>>()?; let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::>(); let embeddings = face_roi_views .chunks(8) .map(|chunk| { tracing::info!("Processing chunk of size: {}", chunk.len()); if chunk.len() < 8 { tracing::warn!("Chunk size is less than 8, padding with zeros"); let zeros = Array3::zeros((512, 512, 3)); let padded: Vec> = chunk .iter() .cloned() .chain(core::iter::repeat(zeros.view())) .take(8) .collect(); let face_rois: Array4 = ndarray::stack(Axis(0), padded.as_slice()) .change_context(errors::Error) .attach_printable("Failed to stack rois together")?; let output = facenet.run_models(face_rois.view()).change_context(Error)?; Ok(output) } else { let face_rois: Array4 = ndarray::stack(Axis(0), chunk) .change_context(errors::Error) .attach_printable("Failed to stack rois together")?; let output = facenet.run_models(face_rois.view()).change_context(Error)?; Ok(output) } }) .collect::>>>(); let v = array.view(); if let Some(output) = detect.output { let image: image::RgbImage = v .to_image() .change_context(errors::Error) .attach_printable("Failed to convert ndarray to image")?; image .save(output) .change_context(errors::Error) .attach_printable("Failed to save output image")?; } } cli::SubCommand::List(list) => { println!("List: {:?}", list); } cli::SubCommand::Completions { shell } => { cli::Cli::completions(shell); } } Ok(()) }