From 33afbfc2b836b905c2b182bb6266e16115f17bbc Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 18 Aug 2025 11:31:03 +0530 Subject: [PATCH] feat: Added stuff --- .gitmodules | 3 + README.md | 199 +++++++++++- rfcs | 1 + src/cli.rs | 2 + src/facedet.rs | 24 +- src/facedet/mnn/mod.rs | 3 + src/facedet/mnn/retinaface.rs | 174 +++++++++++ src/facedet/ort/mod.rs | 3 + src/facedet/ort/retinaface.rs | 264 ++++++++++++++++ src/facedet/{retinaface.rs => postprocess.rs} | 294 +++++------------- src/faceembed.rs | 21 +- src/faceembed/facenet/mnn_impl.rs | 1 - src/faceembed/facenet/ort_impl.rs | 65 ---- src/faceembed/{ => mnn}/facenet.rs | 9 +- src/faceembed/mnn/mod.rs | 3 + src/faceembed/ort/facenet.rs | 79 +++++ src/faceembed/ort/mod.rs | 3 + src/main.rs | 214 +++++++------ 18 files changed, 987 insertions(+), 375 deletions(-) create mode 100644 .gitmodules create mode 160000 rfcs create mode 100644 src/facedet/mnn/mod.rs create mode 100644 src/facedet/mnn/retinaface.rs create mode 100644 src/facedet/ort/mod.rs create mode 100644 src/facedet/ort/retinaface.rs rename src/facedet/{retinaface.rs => postprocess.rs} (50%) delete mode 100644 src/faceembed/facenet/mnn_impl.rs delete mode 100644 src/faceembed/facenet/ort_impl.rs rename src/faceembed/{ => mnn}/facenet.rs (96%) create mode 100644 src/faceembed/mnn/mod.rs create mode 100644 src/faceembed/ort/facenet.rs create mode 100644 src/faceembed/ort/mod.rs diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..132a516 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "rfcs"] + path = rfcs + url = git@github.com:aftershootco/rfcs.git diff --git a/README.md b/README.md index 4221030..0e8fc44 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,198 @@ -# Face Detection +# Face Detection and Embedding -Rust programs to do face detection and face embedding +A high-performance Rust implementation for face detection and face embedding generation using neural networks. + +## Overview + +This project provides a complete face detection and recognition pipeline with the following capabilities: + +- **Face Detection**: Detect faces in images using RetinaFace model +- **Face Embedding**: Generate face embeddings using FaceNet model +- **Multiple Backends**: Support for both MNN and ONNX runtime execution +- **Hardware Acceleration**: Metal, CoreML, and OpenCL support on compatible platforms +- **Modular Design**: Workspace architecture with reusable components + +## Features + +- 🔍 **Accurate Face Detection** - Uses RetinaFace model for robust face detection +- 🧠 **Face Embeddings** - Generate 512-dimensional face embeddings with FaceNet +- ⚡ **High Performance** - Optimized with hardware acceleration (Metal, CoreML) +- 🔧 **Flexible Configuration** - Adjustable detection thresholds and NMS parameters +- 📦 **Modular Architecture** - Reusable components for image processing and bounding boxes +- 🖼️ **Visual Output** - Draw bounding boxes on detected faces + +## Architecture + +The project is organized as a Rust workspace with the following components: + +- **`detector`** - Main face detection and embedding application +- **`bounding-box`** - Geometric operations and drawing utilities for bounding boxes +- **`ndarray-image`** - Conversion utilities between ndarray and image formats +- **`ndarray-resize`** - Fast image resizing operations on ndarray data + +## Models + +The project includes pre-trained neural network models: + +- **RetinaFace** - Face detection model (`.mnn` and `.onnx` formats) +- **FaceNet** - Face embedding model (`.mnn` and `.onnx` formats) + +## Usage + +### Basic Face Detection + +```bash +# Detect faces using MNN backend (default) +cargo run --release detect path/to/image.jpg + +# Detect faces using ONNX Runtime backend +cargo run --release detect --executor onnx path/to/image.jpg + +# Save output with bounding boxes drawn +cargo run --release detect --output detected.jpg path/to/image.jpg + +# Adjust detection sensitivity +cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg +``` + +### Backend Selection + +The project supports two inference backends: + +- **MNN Backend** (default): High-performance inference framework with Metal/CoreML support +- **ONNX Runtime Backend**: Cross-platform ML inference with broad hardware support + +```bash +# Use MNN backend with Metal acceleration (macOS) +cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg + +# Use ONNX Runtime backend +cargo run --release detect --executor onnx path/to/image.jpg +``` + +### Command Line Options + +```bash +# Face detection with custom parameters +cargo run --release detect [OPTIONS] + +Options: + -m, --model Custom model path + -M, --model-type Model type [default: retina-face] + -o, --output Output image path + -e, --executor Inference backend [mnn, onnx] + -f, --forward-type MNN execution backend [default: cpu] + -t, --threshold Detection threshold [default: 0.8] + -n, --nms-threshold NMS threshold [default: 0.3] +``` + +### Quick Start + +```bash +# Build the project +cargo build --release + +# Run face detection on sample image +just run +# or +cargo run --release detect ./1000066593.jpg +``` + +## Hardware Acceleration + +### MNN Backend + +The MNN backend supports various execution backends: + +- **CPU** - Default, works on all platforms +- **Metal** - macOS GPU acceleration +- **CoreML** - macOS/iOS neural engine acceleration +- **OpenCL** - Cross-platform GPU acceleration + +```bash +# Use Metal acceleration on macOS +cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg + +# Use CoreML on macOS/iOS +cargo run --release detect --executor mnn --forward-type coreml path/to/image.jpg +``` + +### ONNX Runtime Backend + +The ONNX Runtime backend automatically selects the best available execution provider based on your system configuration. + +## Development + +### Prerequisites + +- Rust 2024 edition +- MNN runtime (automatically linked) +- ONNX runtime (for ONNX backend) + +### Building + +```bash +# Standard build +cargo build + +# Release build with optimizations +cargo build --release + +# Run tests +cargo test +``` + +### Project Structure + +``` +├── src/ +│ ├── facedet/ # Face detection modules +│ │ ├── mnn/ # MNN backend implementations +│ │ ├── ort/ # ONNX Runtime backend implementations +│ │ └── postprocess.rs # Shared postprocessing logic +│ ├── faceembed/ # Face embedding modules +│ │ ├── mnn/ # MNN backend implementations +│ │ └── ort/ # ONNX Runtime backend implementations +│ ├── cli.rs # Command line interface +│ └── main.rs # Application entry point +├── models/ # Neural network models (.mnn and .onnx) +├── bounding-box/ # Bounding box utilities +├── ndarray-image/ # Image conversion utilities +└── ndarray-resize/ # Image resizing utilities +``` + +### Backend Architecture + +The codebase is organized to support multiple inference backends: + +- **Common interfaces**: `FaceDetector` and `FaceEmbedder` traits provide unified APIs +- **Shared postprocessing**: Common logic for anchor generation, NMS, and coordinate decoding +- **Backend-specific implementations**: Separate modules for MNN and ONNX Runtime +- **Modular design**: Easy to add new backends by implementing the common traits + +## License + +MIT License + +## Dependencies + +Key dependencies include: + +- **MNN** - High-performance neural network inference framework (MNN backend) +- **ONNX Runtime** - Cross-platform ML inference (ORT backend) +- **ndarray** - N-dimensional array processing +- **image** - Image processing and I/O +- **clap** - Command line argument parsing +- **bounding-box** - Geometric operations for face detection +- **error-stack** - Structured error handling + +### Backend Status + +- ✅ **MNN Backend**: Fully implemented with hardware acceleration support +- 🚧 **ONNX Runtime Backend**: Framework implemented, inference logic to be completed + +*Note: The ORT backend currently provides the framework but requires completion of the inference implementation.* + +--- + +*Built with Rust for maximum performance and safety in computer vision applications.* diff --git a/rfcs b/rfcs new file mode 160000 index 0000000..ad85f4c --- /dev/null +++ b/rfcs @@ -0,0 +1 @@ +Subproject commit ad85f4c8197b2b7e052a2fea062b855fff4533bf diff --git a/src/cli.rs b/src/cli.rs index c6665ab..482f2cd 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -48,6 +48,8 @@ pub struct Detect { pub model_type: Models, #[clap(short, long)] pub output: Option, + #[clap(short = 'e', long)] + pub executor: Option, #[clap(short, long, default_value = "cpu")] pub forward_type: mnn::ForwardType, #[clap(short, long, default_value_t = 0.8)] diff --git a/src/facedet.rs b/src/facedet.rs index 83db01d..c50a491 100644 --- a/src/facedet.rs +++ b/src/facedet.rs @@ -1,2 +1,24 @@ -pub mod retinaface; +pub mod mnn; +pub mod ort; +pub mod postprocess; pub mod yolo; + +// Re-export common types and traits +pub use postprocess::{ + FaceDetectionConfig, FaceDetectionModelOutput, FaceDetectionOutput, + FaceDetectionProcessedOutput, FaceDetector, FaceLandmarks, +}; + +// Convenience type aliases for different backends +pub mod retinaface { + pub use crate::facedet::mnn::retinaface as mnn; + pub use crate::facedet::ort::retinaface as ort; + + // Re-export common types + pub use crate::facedet::postprocess::{ + FaceDetectionConfig, FaceDetectionOutput, FaceDetector, FaceLandmarks, + }; +} + +// Default to MNN implementation for backward compatibility +pub use mnn::retinaface::FaceDetection; diff --git a/src/facedet/mnn/mod.rs b/src/facedet/mnn/mod.rs new file mode 100644 index 0000000..067752d --- /dev/null +++ b/src/facedet/mnn/mod.rs @@ -0,0 +1,3 @@ +pub mod retinaface; + +pub use retinaface::FaceDetection; diff --git a/src/facedet/mnn/retinaface.rs b/src/facedet/mnn/retinaface.rs new file mode 100644 index 0000000..c02ada1 --- /dev/null +++ b/src/facedet/mnn/retinaface.rs @@ -0,0 +1,174 @@ +use crate::errors::*; +use crate::facedet::postprocess::*; +use error_stack::ResultExt; +use mnn_bridge::ndarray::*; +use ndarray_resize::NdFir; +use std::path::Path; + +#[derive(Debug)] +pub struct FaceDetection { + handle: mnn_sync::SessionHandle, +} + +pub struct FaceDetectionBuilder { + schedule_config: Option, + backend_config: Option, + model: mnn::Interpreter, +} + +impl FaceDetectionBuilder { + pub fn new(model: impl AsRef<[u8]>) -> Result { + Ok(Self { + schedule_config: None, + backend_config: None, + model: mnn::Interpreter::from_bytes(model.as_ref()) + .map_err(|e| e.into_inner()) + .change_context(Error) + .attach_printable("Failed to load model from bytes")?, + }) + } + + pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self { + self.schedule_config + .get_or_insert_default() + .set_type(forward_type); + self + } + + pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self { + self.schedule_config = Some(config); + self + } + + pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self { + self.backend_config = Some(config); + self + } + + pub fn build(self) -> Result { + let model = self.model; + let sc = self.schedule_config.unwrap_or_default(); + let handle = mnn_sync::SessionHandle::new(model, sc) + .change_context(Error) + .attach_printable("Failed to create session handle")?; + Ok(FaceDetection { handle }) + } +} + +impl FaceDetection { + pub fn builder>() + -> fn(T) -> std::result::Result> { + FaceDetectionBuilder::new + } + + pub fn new(path: impl AsRef) -> Result { + let model = std::fs::read(path) + .change_context(Error) + .attach_printable("Failed to read model file")?; + Self::new_from_bytes(&model) + } + + pub fn new_from_bytes(model: &[u8]) -> Result { + tracing::info!("Loading face detection model from bytes"); + let mut model = mnn::Interpreter::from_bytes(model) + .map_err(|e| e.into_inner()) + .change_context(Error) + .attach_printable("Failed to load model from bytes")?; + model.set_session_mode(mnn::SessionMode::Release); + model + .set_cache_file("retinaface.cache", 128) + .change_context(Error) + .attach_printable("Failed to set cache file")?; + let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High); + let sc = mnn::ScheduleConfig::new() + .with_type(mnn::ForwardType::Metal) + .with_backend_config(bc); + tracing::info!("Creating session handle for face detection model"); + let handle = mnn_sync::SessionHandle::new(model, sc) + .change_context(Error) + .attach_printable("Failed to create session handle")?; + Ok(FaceDetection { handle }) + } +} + +impl FaceDetector for FaceDetection { + fn run_model(&mut self, image: ndarray::ArrayView3) -> Result { + #[rustfmt::skip] + let mut resized = image + .fast_resize(1024, 1024, None) + .change_context(Error)? + .mapv(|f| f as f32); + + // Apply mean subtraction: [104, 117, 123] + resized + .axis_iter_mut(ndarray::Axis(2)) + .zip([104, 117, 123]) + .for_each(|(mut array, pixel)| { + let pixel = pixel as f32; + array.map_inplace(|v| *v -= pixel); + }); + + let mut resized = resized + .permuted_axes((2, 0, 1)) + .insert_axis(ndarray::Axis(0)) + .as_standard_layout() + .into_owned(); + + use ::tap::*; + let output = self + .handle + .run(move |sr| { + let tensor = resized + .as_mnn_tensor_mut() + .attach_printable("Failed to convert ndarray to mnn tensor") + .change_context(mnn::error::ErrorKind::TensorError)?; + tracing::trace!("Image Tensor shape: {:?}", tensor.shape()); + let (intptr, session) = sr.both_mut(); + tracing::trace!("Copying input tensor to host"); + unsafe { + let mut input = intptr.input_unresized::(session, "input")?; + tracing::trace!("Input shape: {:?}", input.shape()); + intptr.resize_tensor_by_nchw::, _>( + input.view_mut(), + 1, + 3, + 1024, + 1024, + ); + } + intptr.resize_session(session); + let mut input = intptr.input::(session, "input")?; + tracing::trace!("Input shape: {:?}", input.shape()); + input.copy_from_host_tensor(tensor.view())?; + + tracing::info!("Running face detection session"); + intptr.run_session(&session)?; + let output_tensor = intptr + .output::(&session, "bbox")? + .create_host_tensor_from_device(true) + .as_ndarray() + .to_owned(); + tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape()); + let output_confidence = intptr + .output::(&session, "confidence")? + .create_host_tensor_from_device(true) + .as_ndarray::() + .to_owned(); + tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape()); + let output_landmark = intptr + .output::(&session, "landmark")? + .create_host_tensor_from_device(true) + .as_ndarray::() + .to_owned(); + tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape()); + Ok(FaceDetectionModelOutput { + bbox: output_tensor, + confidence: output_confidence, + landmark: output_landmark, + }) + }) + .map_err(|e| e.into_inner()) + .change_context(Error)?; + Ok(output) + } +} diff --git a/src/facedet/ort/mod.rs b/src/facedet/ort/mod.rs new file mode 100644 index 0000000..067752d --- /dev/null +++ b/src/facedet/ort/mod.rs @@ -0,0 +1,3 @@ +pub mod retinaface; + +pub use retinaface::FaceDetection; diff --git a/src/facedet/ort/retinaface.rs b/src/facedet/ort/retinaface.rs new file mode 100644 index 0000000..10a870b --- /dev/null +++ b/src/facedet/ort/retinaface.rs @@ -0,0 +1,264 @@ +use crate::errors::*; +use crate::facedet::postprocess::*; +use error_stack::ResultExt; +use ndarray_resize::NdFir; +use ort::{ + execution_providers::{ + CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch, + }, + session::{Session, builder::GraphOptimizationLevel}, + value::Tensor, +}; +use std::path::Path; + +#[derive(Debug)] +pub struct FaceDetection { + session: Session, +} + +pub struct FaceDetectionBuilder { + model_data: Vec, + execution_providers: Option>, + intra_threads: Option, + inter_threads: Option, +} + +impl FaceDetectionBuilder { + pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result { + Ok(Self { + model_data: model.as_ref().to_vec(), + execution_providers: None, + intra_threads: None, + inter_threads: None, + }) + } + + pub fn with_execution_providers(mut self, providers: Vec) -> Self { + let execution_providers: Vec = providers + .into_iter() + .filter_map(|provider| match provider.as_str() { + "cpu" | "CPU" => Some(CPUExecutionProvider::default().build()), + #[cfg(target_os = "macos")] + "coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()), + _ => { + tracing::warn!("Unknown execution provider: {}", provider); + None + } + }) + .collect(); + + if !execution_providers.is_empty() { + self.execution_providers = Some(execution_providers); + } else { + tracing::warn!("No valid execution providers found, falling back to CPU"); + self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]); + } + self + } + + pub fn with_intra_threads(mut self, threads: usize) -> Self { + self.intra_threads = Some(threads); + self + } + + pub fn with_inter_threads(mut self, threads: usize) -> Self { + self.inter_threads = Some(threads); + self + } + + pub fn build(self) -> crate::errors::Result { + let mut session_builder = Session::builder() + .change_context(Error) + .attach_printable("Failed to create session builder")?; + + // Set execution providers + if let Some(providers) = self.execution_providers { + session_builder = session_builder + .with_execution_providers(providers) + .change_context(Error) + .attach_printable("Failed to set execution providers")?; + } else { + // Default to CPU + session_builder = session_builder + .with_execution_providers([CPUExecutionProvider::default().build()]) + .change_context(Error) + .attach_printable("Failed to set default CPU execution provider")?; + } + + // Set threading options + if let Some(threads) = self.intra_threads { + session_builder = session_builder + .with_intra_threads(threads) + .change_context(Error) + .attach_printable("Failed to set intra threads")?; + } + + if let Some(threads) = self.inter_threads { + session_builder = session_builder + .with_inter_threads(threads) + .change_context(Error) + .attach_printable("Failed to set inter threads")?; + } + + // Set optimization level + session_builder = session_builder + .with_optimization_level(GraphOptimizationLevel::Level3) + .change_context(Error) + .attach_printable("Failed to set optimization level")?; + + // Create session from model bytes + let session = session_builder + .commit_from_memory(&self.model_data) + .change_context(Error) + .attach_printable("Failed to create ORT session from model bytes")?; + + tracing::info!("Successfully created ORT RetinaFace session"); + + Ok(FaceDetection { session }) + } +} + +impl FaceDetection { + pub fn builder>() + -> fn(T) -> std::result::Result> + { + FaceDetectionBuilder::new + } + + pub fn new(path: impl AsRef) -> crate::errors::Result { + let model = std::fs::read(path) + .change_context(Error) + .attach_printable("Failed to read model file")?; + Self::new_from_bytes(&model) + } + + pub fn new_from_bytes(model: &[u8]) -> crate::errors::Result { + tracing::info!("Loading ORT RetinaFace model from bytes"); + Self::builder()(model)?.build() + } +} + +impl FaceDetector for FaceDetection { + fn run_model( + &mut self, + image: ndarray::ArrayView3, + ) -> crate::errors::Result { + // Resize image to 1024x1024 + let mut resized = image + .fast_resize(1024, 1024, None) + .change_context(Error) + .attach_printable("Failed to resize image")? + .mapv(|f| f as f32); + + // Apply mean subtraction: [104, 117, 123] for BGR format + resized + .axis_iter_mut(ndarray::Axis(2)) + .zip([104.0, 117.0, 123.0]) + .for_each(|(mut array, mean)| { + array.map_inplace(|v| *v -= mean); + }); + + // Convert from HWC to NCHW format (add batch dimension and transpose) + let input_tensor = resized + .permuted_axes((2, 0, 1)) + .insert_axis(ndarray::Axis(0)) + .as_standard_layout() + .into_owned(); + + tracing::trace!("Input tensor shape: {:?}", input_tensor.shape()); + + // Create ORT input tensor + let input_value = Tensor::from_array(input_tensor) + .change_context(Error) + .attach_printable("Failed to create input tensor")?; + + // Run inference + tracing::debug!("Running ORT RetinaFace inference"); + let outputs = self + .session + .run(ort::inputs!["input" => input_value]) + .change_context(Error) + .attach_printable("Failed to run inference")?; + + // Extract outputs by name + let bbox_output = outputs + .get("bbox") + .ok_or(Error) + .attach_printable("Missing bbox output from model")? + .try_extract_tensor::() + .change_context(Error) + .attach_printable("Failed to extract bbox tensor")?; + + let confidence_output = outputs + .get("confidence") + .ok_or(Error) + .attach_printable("Missing confidence output from model")? + .try_extract_tensor::() + .change_context(Error) + .attach_printable("Failed to extract confidence tensor")?; + + let landmark_output = outputs + .get("landmark") + .ok_or(Error) + .attach_printable("Missing landmark output from model")? + .try_extract_tensor::() + .change_context(Error) + .attach_printable("Failed to extract landmark tensor")?; + + // Get tensor shapes and data + let (bbox_shape, bbox_data) = bbox_output; + let (confidence_shape, confidence_data) = confidence_output; + let (landmark_shape, landmark_data) = landmark_output; + + tracing::trace!( + "Output shapes - bbox: {:?}, confidence: {:?}, landmark: {:?}", + bbox_shape, + confidence_shape, + landmark_shape + ); + + // Convert to ndarray format + let bbox_dims = bbox_shape.as_ref(); + let confidence_dims = confidence_shape.as_ref(); + let landmark_dims = landmark_shape.as_ref(); + + let bbox_array = ndarray::Array3::from_shape_vec( + ( + bbox_dims[0] as usize, + bbox_dims[1] as usize, + bbox_dims[2] as usize, + ), + bbox_data.to_vec(), + ) + .change_context(Error) + .attach_printable("Failed to create bbox ndarray")?; + + let confidence_array = ndarray::Array3::from_shape_vec( + ( + confidence_dims[0] as usize, + confidence_dims[1] as usize, + confidence_dims[2] as usize, + ), + confidence_data.to_vec(), + ) + .change_context(Error) + .attach_printable("Failed to create confidence ndarray")?; + + let landmark_array = ndarray::Array3::from_shape_vec( + ( + landmark_dims[0] as usize, + landmark_dims[1] as usize, + landmark_dims[2] as usize, + ), + landmark_data.to_vec(), + ) + .change_context(Error) + .attach_printable("Failed to create landmark ndarray")?; + + Ok(FaceDetectionModelOutput { + bbox: bbox_array, + confidence: confidence_array, + landmark: landmark_array, + }) + } +} diff --git a/src/facedet/retinaface.rs b/src/facedet/postprocess.rs similarity index 50% rename from src/facedet/retinaface.rs rename to src/facedet/postprocess.rs index 174a9c3..22aa9af 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/postprocess.rs @@ -1,10 +1,8 @@ use crate::errors::*; use bounding_box::{Aabb2, nms::nms}; use error_stack::ResultExt; -use mnn_bridge::ndarray::*; use nalgebra::{Point2, Vector2}; -use ndarray_resize::NdFir; -use std::path::Path; +use std::collections::HashMap; /// Configuration for face detection postprocessing #[derive(Debug, Clone, PartialEq)] @@ -32,30 +30,37 @@ impl FaceDetectionConfig { self.threshold = threshold; self } + pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self { self.nms_threshold = nms_threshold; self } + pub fn with_variances(mut self, variances: [f32; 2]) -> Self { self.variances = variances; self } + pub fn with_steps(mut self, steps: Vec) -> Self { self.steps = steps; self } + pub fn with_min_sizes(mut self, min_sizes: Vec>) -> Self { self.min_sizes = min_sizes; self } + pub fn with_clip(mut self, clip: bool) -> Self { self.clamp = clip; self } + pub fn with_input_width(mut self, input_width: usize) -> Self { self.input_width = input_width; self } + pub fn with_input_height(mut self, input_height: usize) -> Self { self.input_height = input_height; self @@ -77,18 +82,6 @@ impl Default for FaceDetectionConfig { } } -#[derive(Debug)] -pub struct FaceDetection { - handle: mnn_sync::SessionHandle, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct FaceDetectionModelOutput { - pub bbox: ndarray::Array3, - pub confidence: ndarray::Array3, - pub landmark: ndarray::Array3, -} - /// Represents the 5 facial landmarks detected by RetinaFace #[derive(Debug, Copy, Clone, PartialEq)] pub struct FaceLandmarks { @@ -99,6 +92,13 @@ pub struct FaceLandmarks { pub right_mouth: Point2, } +#[derive(Debug, Clone, PartialEq)] +pub struct FaceDetectionModelOutput { + pub bbox: ndarray::Array3, + pub confidence: ndarray::Array3, + pub landmark: ndarray::Array3, +} + #[derive(Debug, Clone, PartialEq)] pub struct FaceDetectionProcessedOutput { pub bbox: Vec>, @@ -113,7 +113,13 @@ pub struct FaceDetectionOutput { pub landmark: Vec, } -fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2 { +/// Raw model outputs that can be converted to FaceDetectionModelOutput +pub trait IntoModelOutput { + fn into_model_output(self) -> Result; +} + +/// Generate anchors for RetinaFace model +pub fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2 { let mut anchors = Vec::new(); let feature_maps: Vec<(usize, usize)> = config .steps @@ -220,9 +226,7 @@ impl FaceDetectionModelOutput { landmarks: decoded_landmarks, }) } -} -impl FaceDetectionModelOutput { pub fn print(&self, limit: usize) { tracing::info!("Detected {} faces", self.bbox.shape()[1]); @@ -246,214 +250,76 @@ impl FaceDetectionModelOutput { } } -pub struct FaceDetectionBuilder { - schedule_config: Option, - backend_config: Option, - model: mnn::Interpreter, +/// Apply Non-Maximum Suppression and convert to final output format +pub fn apply_nms_and_finalize( + processed: FaceDetectionProcessedOutput, + config: &FaceDetectionConfig, + image_size: (usize, usize), // (width, height) +) -> Result { + use itertools::Itertools; + + let factor = Vector2::new(image_size.0 as f32, image_size.1 as f32); + + let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed + .bbox + .iter() + .cloned() + .zip(processed.confidence.iter().cloned()) + .zip(processed.landmarks.iter().cloned()) + .sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score)) + .map(|((b, s), l)| (b, s, l)) + .multiunzip(); + + let keep_indices = + nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?; + + let bboxes = boxes + .into_iter() + .enumerate() + .filter(|(i, _)| keep_indices.contains(i)) + .flat_map(|(_, x)| x.denormalize(factor).try_cast::()) + .collect(); + let confidence = scores + .into_iter() + .enumerate() + .filter(|(i, _)| keep_indices.contains(i)) + .map(|(_, score)| score) + .collect(); + let landmark = landmarks + .into_iter() + .enumerate() + .filter(|(i, _)| keep_indices.contains(i)) + .map(|(_, score)| score) + .collect(); + + Ok(FaceDetectionOutput { + bbox: bboxes, + confidence, + landmark, + }) } -impl FaceDetectionBuilder { - pub fn new(model: impl AsRef<[u8]>) -> Result { - Ok(Self { - schedule_config: None, - backend_config: None, - model: mnn::Interpreter::from_bytes(model.as_ref()) - .map_err(|e| e.into_inner()) - .change_context(Error) - .attach_printable("Failed to load model from bytes")?, - }) - } +/// Common trait for face detection backends +pub trait FaceDetector { + /// Run inference on the model and return raw outputs + fn run_model(&mut self, image: ndarray::ArrayView3) -> Result; - pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self { - self.schedule_config - .get_or_insert_default() - .set_type(forward_type); - self - } - - pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self { - self.schedule_config = Some(config); - self - } - - pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self { - self.backend_config = Some(config); - self - } - - pub fn build(self) -> Result { - let model = self.model; - let sc = self.schedule_config.unwrap_or_default(); - let handle = mnn_sync::SessionHandle::new(model, sc) - .change_context(Error) - .attach_printable("Failed to create session handle")?; - Ok(FaceDetection { handle }) - } -} - -impl FaceDetection { - pub fn builder>() - -> fn(T) -> std::result::Result> { - FaceDetectionBuilder::new - } - pub fn new(path: impl AsRef) -> Result { - let model = std::fs::read(path) - .change_context(Error) - .attach_printable("Failed to read model file")?; - Self::new_from_bytes(&model) - } - - pub fn new_from_bytes(model: &[u8]) -> Result { - tracing::info!("Loading face detection model from bytes"); - let mut model = mnn::Interpreter::from_bytes(model) - .map_err(|e| e.into_inner()) - .change_context(Error) - .attach_printable("Failed to load model from bytes")?; - model.set_session_mode(mnn::SessionMode::Release); - model - .set_cache_file("retinaface.cache", 128) - .change_context(Error) - .attach_printable("Failed to set cache file")?; - let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High); - let sc = mnn::ScheduleConfig::new() - .with_type(mnn::ForwardType::Metal) - .with_backend_config(bc); - tracing::info!("Creating session handle for face detection model"); - let handle = mnn_sync::SessionHandle::new(model, sc) - .change_context(Error) - .attach_printable("Failed to create session handle")?; - Ok(FaceDetection { handle }) - } - - pub fn detect_faces( - &self, + /// Detect faces with full pipeline including postprocessing + fn detect_faces( + &mut self, image: ndarray::ArrayView3, config: FaceDetectionConfig, ) -> Result { let (height, width, _channels) = image.dim(); let output = self - .run_models(image) + .run_model(image) .change_context(Error) .attach_printable("Failed to detect faces")?; - // denormalize the bounding boxes - let factor = Vector2::new(width as f32, height as f32); - let mut processed = output + + let processed = output .postprocess(&config) .attach_printable("Failed to postprocess")?; - use itertools::Itertools; - let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed - .bbox - .iter() - .cloned() - .zip(processed.confidence.iter().cloned()) - .zip(processed.landmarks.iter().cloned()) - .sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score)) - .map(|((b, s), l)| (b, s, l)) - .multiunzip(); - - let keep_indices = - nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?; - - let bboxes = boxes - .into_iter() - .enumerate() - .filter(|(i, _)| keep_indices.contains(i)) - .flat_map(|(_, x)| x.denormalize(factor).try_cast::()) - .collect(); - let confidence = scores - .into_iter() - .enumerate() - .filter(|(i, _)| keep_indices.contains(i)) - .map(|(_, score)| score) - .collect(); - let landmark = landmarks - .into_iter() - .enumerate() - .filter(|(i, _)| keep_indices.contains(i)) - .map(|(_, score)| score) - .collect(); - - Ok(FaceDetectionOutput { - bbox: bboxes, - confidence, - landmark, - }) - } - - pub fn run_models(&self, image: ndarray::ArrayView3) -> Result { - #[rustfmt::skip] - let mut resized = image - .fast_resize(1024, 1024, None) - .change_context(Error)? - .mapv(|f| f as f32) - .tap_mut(|arr| { - arr.axis_iter_mut(ndarray::Axis(2)) - .zip([104, 117, 123]) - .for_each(|(mut array, pixel)| { - let pixel = pixel as f32; - array.map_inplace(|v| *v -= pixel); - }); - }) - .permuted_axes((2, 0, 1)) - .insert_axis(ndarray::Axis(0)) - .as_standard_layout() - .into_owned(); - use ::tap::*; - let output = self - .handle - .run(move |sr| { - let tensor = resized - .as_mnn_tensor_mut() - .attach_printable("Failed to convert ndarray to mnn tensor") - .change_context(mnn::error::ErrorKind::TensorError)?; - tracing::trace!("Image Tensor shape: {:?}", tensor.shape()); - let (intptr, session) = sr.both_mut(); - tracing::trace!("Copying input tensor to host"); - unsafe { - let mut input = intptr.input_unresized::(session, "input")?; - tracing::trace!("Input shape: {:?}", input.shape()); - intptr.resize_tensor_by_nchw::, _>( - input.view_mut(), - 1, - 3, - 1024, - 1024, - ); - } - intptr.resize_session(session); - let mut input = intptr.input::(session, "input")?; - tracing::trace!("Input shape: {:?}", input.shape()); - input.copy_from_host_tensor(tensor.view())?; - - tracing::info!("Running face detection session"); - intptr.run_session(&session)?; - let output_tensor = intptr - .output::(&session, "bbox")? - .create_host_tensor_from_device(true) - .as_ndarray() - .to_owned(); - tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape()); - let output_confidence = intptr - .output::(&session, "confidence")? - .create_host_tensor_from_device(true) - .as_ndarray::() - .to_owned(); - tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape()); - let output_landmark = intptr - .output::(&session, "landmark")? - .create_host_tensor_from_device(true) - .as_ndarray::() - .to_owned(); - tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape()); - Ok(FaceDetectionModelOutput { - bbox: output_tensor, - confidence: output_confidence, - landmark: output_landmark, - }) - }) - .map_err(|e| e.into_inner()) - .change_context(Error)?; - Ok(output) + apply_nms_and_finalize(processed, &config, (width, height)) } } diff --git a/src/faceembed.rs b/src/faceembed.rs index e90e6fb..d89ceb0 100644 --- a/src/faceembed.rs +++ b/src/faceembed.rs @@ -1 +1,20 @@ -pub mod facenet; +use crate::errors::*; +use ndarray::{Array2, ArrayView4}; + +pub mod mnn; +pub mod ort; + +/// Common trait for face embedding backends +pub trait FaceEmbedder { + /// Generate embeddings for a batch of face images + fn run_models(&self, faces: ArrayView4) -> Result>; +} + +// Convenience type aliases for different backends +pub mod facenet { + pub use crate::faceembed::mnn::facenet as mnn; + pub use crate::faceembed::ort::facenet as ort; +} + +// Default to MNN implementation for backward compatibility +pub use mnn::facenet::EmbeddingGenerator; diff --git a/src/faceembed/facenet/mnn_impl.rs b/src/faceembed/facenet/mnn_impl.rs deleted file mode 100644 index 8b13789..0000000 --- a/src/faceembed/facenet/mnn_impl.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/faceembed/facenet/ort_impl.rs b/src/faceembed/facenet/ort_impl.rs deleted file mode 100644 index bcc5a90..0000000 --- a/src/faceembed/facenet/ort_impl.rs +++ /dev/null @@ -1,65 +0,0 @@ -use crate::errors::{Result, *}; -use ndarray::*; -use ort::*; -use std::path::Path; - -#[derive(Debug)] -pub struct EmbeddingGenerator { - handle: ort::session::Session, -} - -// impl EmbeddingGeneratorBuilder { -// pub fn new(model: impl AsRef<[u8]>) -> Result { -// Ok(Self { -// schedule_config: None, -// backend_config: None, -// model: mnn::Interpreter::from_bytes(model.as_ref()) -// .map_err(|e| e.into_inner()) -// .change_context(Error) -// .attach_printable("Failed to load model from bytes")?, -// }) -// } -// -// pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self { -// self.schedule_config -// .get_or_insert_default() -// .set_type(forward_type); -// self -// } -// -// pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self { -// self.schedule_config = Some(config); -// self -// } -// -// pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self { -// self.backend_config = Some(config); -// self -// } -// -// pub fn build(self) -> Result { -// let model = self.model; -// let sc = self.schedule_config.unwrap_or_default(); -// let handle = mnn_sync::SessionHandle::new(model, sc) -// .change_context(Error) -// .attach_printable("Failed to create session handle")?; -// Ok(EmbeddingGenerator { handle }) -// } -// } - -impl EmbeddingGenerator { - const INPUT_NAME: &'static str = "serving_default_input_6:0"; - const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0"; - pub fn new(path: impl AsRef) -> Result { - let model = std::fs::read(path) - .change_context(Error) - .attach_printable("Failed to read model file")?; - Self::new_from_bytes(&model) - } - - pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result { - todo!() - } - - // pub fn run_models(&self, face: ArrayView4) -> Result> {} -} diff --git a/src/faceembed/facenet.rs b/src/faceembed/mnn/facenet.rs similarity index 96% rename from src/faceembed/facenet.rs rename to src/faceembed/mnn/facenet.rs index 5ef3233..6206c79 100644 --- a/src/faceembed/facenet.rs +++ b/src/faceembed/mnn/facenet.rs @@ -1,9 +1,8 @@ use crate::errors::*; +use crate::faceembed::FaceEmbedder; use mnn_bridge::ndarray::*; use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; use std::path::Path; -mod mnn_impl; -mod ort_impl; #[derive(Debug)] pub struct EmbeddingGenerator { @@ -151,3 +150,9 @@ impl EmbeddingGenerator { // todo!() // } } + +impl FaceEmbedder for EmbeddingGenerator { + fn run_models(&self, faces: ArrayView4) -> Result> { + self.run_models(faces) + } +} diff --git a/src/faceembed/mnn/mod.rs b/src/faceembed/mnn/mod.rs new file mode 100644 index 0000000..94700e6 --- /dev/null +++ b/src/faceembed/mnn/mod.rs @@ -0,0 +1,3 @@ +pub mod facenet; + +pub use facenet::EmbeddingGenerator; diff --git a/src/faceembed/ort/facenet.rs b/src/faceembed/ort/facenet.rs new file mode 100644 index 0000000..c2f3272 --- /dev/null +++ b/src/faceembed/ort/facenet.rs @@ -0,0 +1,79 @@ +use crate::errors::*; +use crate::faceembed::FaceEmbedder; +use error_stack::ResultExt; +use ndarray::{Array2, ArrayView4}; +use std::path::Path; + +#[derive(Debug)] +pub struct EmbeddingGenerator { + // Placeholder - ORT implementation to be completed later + _placeholder: (), +} + +pub struct EmbeddingGeneratorBuilder { + _model_data: Vec, +} + +impl EmbeddingGeneratorBuilder { + pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result { + Ok(Self { + _model_data: model.as_ref().to_vec(), + }) + } + + pub fn with_execution_providers(self, _providers: Vec) -> Self { + self + } + + pub fn with_intra_threads(self, _threads: usize) -> Self { + self + } + + pub fn with_inter_threads(self, _threads: usize) -> Self { + self + } + + pub fn build(self) -> crate::errors::Result { + // TODO: Implement ORT session creation + tracing::warn!("ORT FaceNet implementation is not yet complete"); + Ok(EmbeddingGenerator { _placeholder: () }) + } +} + +impl EmbeddingGenerator { + const INPUT_NAME: &'static str = "serving_default_input_6:0"; + const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0"; + + pub fn builder>() -> fn( + T, + ) -> std::result::Result< + EmbeddingGeneratorBuilder, + error_stack::Report, + > { + EmbeddingGeneratorBuilder::new + } + + pub fn new(path: impl AsRef) -> crate::errors::Result { + let model = std::fs::read(path) + .change_context(Error) + .attach_printable("Failed to read model file")?; + Self::new_from_bytes(&model) + } + + pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result { + tracing::info!("Loading face embedding model from bytes"); + Self::builder()(model)?.build() + } + + pub fn run_models(&self, _face: ArrayView4) -> crate::errors::Result> { + // TODO: Implement ORT inference + tracing::error!("ORT FaceNet inference not yet implemented"); + Err(Error).attach_printable("ORT FaceNet implementation is incomplete") + } +} + +impl FaceEmbedder for EmbeddingGenerator { + fn run_models(&self, faces: ArrayView4) -> crate::errors::Result> { + self.run_models(faces) + } +} diff --git a/src/faceembed/ort/mod.rs b/src/faceembed/ort/mod.rs new file mode 100644 index 0000000..94700e6 --- /dev/null +++ b/src/faceembed/ort/mod.rs @@ -0,0 +1,3 @@ +pub mod facenet; + +pub use facenet::EmbeddingGenerator; diff --git a/src/main.rs b/src/main.rs index 1f8150f..5685524 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod cli; mod errors; use bounding_box::roi::MultiRoi; -use detector::{facedet::retinaface::FaceDetectionConfig, faceembed}; +use detector::{facedet, facedet::FaceDetectionConfig, faceembed}; use errors::*; use fast_image_resize::ResizeOptions; use ndarray::*; @@ -20,97 +20,45 @@ pub fn main() -> Result<()> { let args = ::parse(); match args.cmd { cli::SubCommand::Detect(detect) => { - use detector::facedet; - let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL) - .change_context(Error)? - .with_forward_type(detect.forward_type) - .build() - .change_context(errors::Error) - .attach_printable("Failed to create face detection model")?; - let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL) - .change_context(Error)? - .with_forward_type(detect.forward_type) - .build() - .change_context(errors::Error) - .attach_printable("Failed to create face embedding model")?; - let image = image::open(detect.image).change_context(Error)?; - let image = image.into_rgb8(); - let mut array = image - .into_ndarray() - .change_context(errors::Error) - .attach_printable("Failed to convert image to ndarray")?; - let output = retinaface - .detect_faces( - array.view(), - FaceDetectionConfig::default() - .with_threshold(detect.threshold) - .with_nms_threshold(detect.nms_threshold), - ) - .change_context(errors::Error) - .attach_printable("Failed to detect faces")?; - for bbox in &output.bbox { - tracing::info!("Detected face: {:?}", bbox); - use bounding_box::draw::*; - array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1); - } - let face_rois = array - .view() - .multi_roi(&output.bbox) - .change_context(Error)? - .into_iter() - // .inspect(|f| { - // tracing::info!("Face ROI shape before resize: {:?}", f.dim()); - // }) - .map(|roi| { - roi.as_standard_layout() - .fast_resize(512, 512, &ResizeOptions::default()) - .change_context(Error) - }) - // .inspect(|f| { - // f.as_ref().inspect(|f| { - // tracing::info!("Face ROI shape after resize: {:?}", f.dim()); - // }); - // }) - .collect::>>()?; - let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::>(); + // Choose backend based on executor type (defaulting to MNN for backward compatibility) + let executor = detect.executor.unwrap_or(cli::Executor::Mnn); - let chunk_size = CHUNK_SIZE; - let embeddings = face_roi_views - .chunks(chunk_size) - .map(|chunk| { - tracing::info!("Processing chunk of size: {}", chunk.len()); + match executor { + cli::Executor::Mnn => { + let retinaface = facedet::mnn::FaceDetection::builder()(RETINAFACE_MODEL) + .change_context(Error)? + .with_forward_type(detect.forward_type) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL) + .change_context(Error)? + .with_forward_type(detect.forward_type) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; - if chunk.len() < 8 { - tracing::warn!("Chunk size is less than 8, padding with zeros"); - let zeros = Array3::zeros((512, 512, 3)); - let zero_array = core::iter::repeat(zeros.view()) - .take(chunk_size) - .collect::>(); - let face_rois: Array4 = ndarray::stack(Axis(0), zero_array.as_slice()) - .change_context(errors::Error) - .attach_printable("Failed to stack rois together")?; - let output = facenet.run_models(face_rois.view()).change_context(Error)?; - Ok(output) - } else { - let face_rois: Array4 = ndarray::stack(Axis(0), chunk) - .change_context(errors::Error) - .attach_printable("Failed to stack rois together")?; - let output = facenet.run_models(face_rois.view()).change_context(Error)?; - Ok(output) - } - }) - .collect::>>>(); + run_detection(detect, retinaface, facenet)?; + } + cli::Executor::Onnx => { + // Load ONNX models + const RETINAFACE_ONNX_MODEL: &[u8] = + include_bytes!("../models/retinaface.onnx"); + const FACENET_ONNX_MODEL: &[u8] = include_bytes!("../models/facenet.onnx"); - let v = array.view(); - if let Some(output) = detect.output { - let image: image::RgbImage = v - .to_image() - .change_context(errors::Error) - .attach_printable("Failed to convert ndarray to image")?; - image - .save(output) - .change_context(errors::Error) - .attach_printable("Failed to save output image")?; + let retinaface = facedet::ort::FaceDetection::builder()(RETINAFACE_ONNX_MODEL) + .change_context(Error)? + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_ONNX_MODEL) + .change_context(Error)? + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; + + run_detection(detect, retinaface, facenet)?; + } } } cli::SubCommand::List(list) => { @@ -122,3 +70,91 @@ pub fn main() -> Result<()> { } Ok(()) } + +fn run_detection(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()> +where + D: facedet::FaceDetector, + E: faceembed::FaceEmbedder, +{ + let image = image::open(detect.image).change_context(Error)?; + let image = image.into_rgb8(); + let mut array = image + .into_ndarray() + .change_context(errors::Error) + .attach_printable("Failed to convert image to ndarray")?; + let output = retinaface + .detect_faces( + array.view(), + FaceDetectionConfig::default() + .with_threshold(detect.threshold) + .with_nms_threshold(detect.nms_threshold), + ) + .change_context(errors::Error) + .attach_printable("Failed to detect faces")?; + for bbox in &output.bbox { + tracing::info!("Detected face: {:?}", bbox); + use bounding_box::draw::*; + array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1); + } + let face_rois = array + .view() + .multi_roi(&output.bbox) + .change_context(Error)? + .into_iter() + // .inspect(|f| { + // tracing::info!("Face ROI shape before resize: {:?}", f.dim()); + // }) + .map(|roi| { + roi.as_standard_layout() + .fast_resize(512, 512, &ResizeOptions::default()) + .change_context(Error) + }) + // .inspect(|f| { + // f.as_ref().inspect(|f| { + // tracing::info!("Face ROI shape after resize: {:?}", f.dim()); + // }); + // }) + .collect::>>()?; + let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::>(); + + let chunk_size = CHUNK_SIZE; + let embeddings = face_roi_views + .chunks(chunk_size) + .map(|chunk| { + tracing::info!("Processing chunk of size: {}", chunk.len()); + + if chunk.len() < 8 { + tracing::warn!("Chunk size is less than 8, padding with zeros"); + let zeros = Array3::zeros((512, 512, 3)); + let zero_array = core::iter::repeat(zeros.view()) + .take(chunk_size) + .collect::>(); + let face_rois: Array4 = ndarray::stack(Axis(0), zero_array.as_slice()) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + let output = facenet.run_models(face_rois.view()).change_context(Error)?; + Ok(output) + } else { + let face_rois: Array4 = ndarray::stack(Axis(0), chunk) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + let output = facenet.run_models(face_rois.view()).change_context(Error)?; + Ok(output) + } + }) + .collect::>>>(); + + let v = array.view(); + if let Some(output) = detect.output { + let image: image::RgbImage = v + .to_image() + .change_context(errors::Error) + .attach_printable("Failed to convert ndarray to image")?; + image + .save(output) + .change_context(errors::Error) + .attach_printable("Failed to save output image")?; + } + + Ok(()) +}