mod cli; mod errors; use bounding_box::roi::MultiRoi; use detector::{facedet, facedet::FaceDetectionConfig, faceembed}; use errors::*; use fast_image_resize::ResizeOptions; use ndarray::*; use ndarray_image::*; use ndarray_resize::NdFir; const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn"); const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn"); const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx"); const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx"); pub fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter("error") .with_thread_ids(true) .with_thread_names(true) .with_target(false) .init(); let args = ::parse(); match args.cmd { cli::SubCommand::Detect(detect) => { // Choose backend based on executor type (defaulting to MNN for backward compatibility) let executor = detect .mnn_forward_type .map(|f| cli::Executor::Mnn(f)) .or_else(|| { if detect.ort_execution_provider.is_empty() { None } else { Some(cli::Executor::Ort(detect.ort_execution_provider.clone())) } }) .unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU)); match executor { cli::Executor::Mnn(forward) => { let retinaface = facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN) .change_context(Error)? .with_forward_type(forward) .build() .change_context(errors::Error) .attach_printable("Failed to create face detection model")?; let facenet = faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) .change_context(Error)? .with_forward_type(forward) .build() .change_context(errors::Error) .attach_printable("Failed to create face embedding model")?; run_detection(detect, retinaface, facenet)?; } cli::Executor::Ort(ep) => { let retinaface = facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX) .change_context(Error)? .with_execution_providers(&ep) .build() .change_context(errors::Error) .attach_printable("Failed to create face detection model")?; let facenet = faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX) .change_context(Error)? .with_execution_providers(ep) .build() .change_context(errors::Error) .attach_printable("Failed to create face embedding model")?; run_detection(detect, retinaface, facenet)?; } } } cli::SubCommand::List(list) => { println!("List: {:?}", list); } cli::SubCommand::Completions { shell } => { cli::Cli::completions(shell); } } Ok(()) } fn run_detection(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()> where D: facedet::FaceDetector, E: faceembed::FaceEmbedder, { let image = image::open(&detect.image) .change_context(Error) .attach_printable(detect.image.to_string_lossy().to_string())?; let image = image.into_rgb8(); let mut array = image .into_ndarray() .change_context(errors::Error) .attach_printable("Failed to convert image to ndarray")?; let output = retinaface .detect_faces( array.view(), FaceDetectionConfig::default() .with_threshold(detect.threshold) .with_nms_threshold(detect.nms_threshold), ) .change_context(errors::Error) .attach_printable("Failed to detect faces")?; for bbox in &output.bbox { tracing::info!("Detected face: {:?}", bbox); use bounding_box::draw::*; array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1); } let face_rois = array .view() .multi_roi(&output.bbox) .change_context(Error)? .into_iter() // .inspect(|f| { // tracing::info!("Face ROI shape before resize: {:?}", f.dim()); // }) .map(|roi| { roi.as_standard_layout() .fast_resize(160, 160, &ResizeOptions::default()) .change_context(Error) }) // .inspect(|f| { // f.as_ref().inspect(|f| { // tracing::info!("Face ROI shape after resize: {:?}", f.dim()); // }); // }) .collect::>>()?; let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::>(); let chunk_size = detect.batch_size; let embeddings = face_roi_views .chunks(chunk_size) .map(|chunk| { tracing::info!("Processing chunk of size: {}", chunk.len()); if chunk.len() < chunk_size { tracing::warn!("Chunk size is less than 8, padding with zeros"); let zeros = Array3::zeros((160, 160, 3)); let zero_array = core::iter::repeat(zeros.view()) .take(chunk_size) .collect::>(); let face_rois: Array4 = ndarray::stack(Axis(0), zero_array.as_slice()) .change_context(errors::Error) .attach_printable("Failed to stack rois together")?; let output = facenet.run_models(face_rois.view()).change_context(Error)?; Ok(output) } else { let face_rois: Array4 = ndarray::stack(Axis(0), chunk) .change_context(errors::Error) .attach_printable("Failed to stack rois together")?; let output = facenet.run_models(face_rois.view()).change_context(Error)?; Ok(output) } }) .collect::>>>()?; let v = array.view(); if let Some(output) = detect.output { let image: image::RgbImage = v .to_image() .change_context(errors::Error) .attach_printable("Failed to convert ndarray to image")?; image .save(output) .change_context(errors::Error) .attach_printable("Failed to save output image")?; } Ok(()) }