feat: Added stuff
This commit is contained in:
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "rfcs"]
|
||||||
|
path = rfcs
|
||||||
|
url = git@github.com:aftershootco/rfcs.git
|
||||||
199
README.md
199
README.md
@@ -1,3 +1,198 @@
|
|||||||
# Face Detection
|
# Face Detection and Embedding
|
||||||
|
|
||||||
Rust programs to do face detection and face embedding
|
A high-performance Rust implementation for face detection and face embedding generation using neural networks.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This project provides a complete face detection and recognition pipeline with the following capabilities:
|
||||||
|
|
||||||
|
- **Face Detection**: Detect faces in images using RetinaFace model
|
||||||
|
- **Face Embedding**: Generate face embeddings using FaceNet model
|
||||||
|
- **Multiple Backends**: Support for both MNN and ONNX runtime execution
|
||||||
|
- **Hardware Acceleration**: Metal, CoreML, and OpenCL support on compatible platforms
|
||||||
|
- **Modular Design**: Workspace architecture with reusable components
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- 🔍 **Accurate Face Detection** - Uses RetinaFace model for robust face detection
|
||||||
|
- 🧠 **Face Embeddings** - Generate 512-dimensional face embeddings with FaceNet
|
||||||
|
- ⚡ **High Performance** - Optimized with hardware acceleration (Metal, CoreML)
|
||||||
|
- 🔧 **Flexible Configuration** - Adjustable detection thresholds and NMS parameters
|
||||||
|
- 📦 **Modular Architecture** - Reusable components for image processing and bounding boxes
|
||||||
|
- 🖼️ **Visual Output** - Draw bounding boxes on detected faces
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
The project is organized as a Rust workspace with the following components:
|
||||||
|
|
||||||
|
- **`detector`** - Main face detection and embedding application
|
||||||
|
- **`bounding-box`** - Geometric operations and drawing utilities for bounding boxes
|
||||||
|
- **`ndarray-image`** - Conversion utilities between ndarray and image formats
|
||||||
|
- **`ndarray-resize`** - Fast image resizing operations on ndarray data
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
The project includes pre-trained neural network models:
|
||||||
|
|
||||||
|
- **RetinaFace** - Face detection model (`.mnn` and `.onnx` formats)
|
||||||
|
- **FaceNet** - Face embedding model (`.mnn` and `.onnx` formats)
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Basic Face Detection
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Detect faces using MNN backend (default)
|
||||||
|
cargo run --release detect path/to/image.jpg
|
||||||
|
|
||||||
|
# Detect faces using ONNX Runtime backend
|
||||||
|
cargo run --release detect --executor onnx path/to/image.jpg
|
||||||
|
|
||||||
|
# Save output with bounding boxes drawn
|
||||||
|
cargo run --release detect --output detected.jpg path/to/image.jpg
|
||||||
|
|
||||||
|
# Adjust detection sensitivity
|
||||||
|
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend Selection
|
||||||
|
|
||||||
|
The project supports two inference backends:
|
||||||
|
|
||||||
|
- **MNN Backend** (default): High-performance inference framework with Metal/CoreML support
|
||||||
|
- **ONNX Runtime Backend**: Cross-platform ML inference with broad hardware support
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use MNN backend with Metal acceleration (macOS)
|
||||||
|
cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg
|
||||||
|
|
||||||
|
# Use ONNX Runtime backend
|
||||||
|
cargo run --release detect --executor onnx path/to/image.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
### Command Line Options
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Face detection with custom parameters
|
||||||
|
cargo run --release detect [OPTIONS] <IMAGE>
|
||||||
|
|
||||||
|
Options:
|
||||||
|
-m, --model <MODEL> Custom model path
|
||||||
|
-M, --model-type <MODEL_TYPE> Model type [default: retina-face]
|
||||||
|
-o, --output <OUTPUT> Output image path
|
||||||
|
-e, --executor <EXECUTOR> Inference backend [mnn, onnx]
|
||||||
|
-f, --forward-type <FORWARD_TYPE> MNN execution backend [default: cpu]
|
||||||
|
-t, --threshold <THRESHOLD> Detection threshold [default: 0.8]
|
||||||
|
-n, --nms-threshold <NMS_THRESHOLD> NMS threshold [default: 0.3]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build the project
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
# Run face detection on sample image
|
||||||
|
just run
|
||||||
|
# or
|
||||||
|
cargo run --release detect ./1000066593.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
## Hardware Acceleration
|
||||||
|
|
||||||
|
### MNN Backend
|
||||||
|
|
||||||
|
The MNN backend supports various execution backends:
|
||||||
|
|
||||||
|
- **CPU** - Default, works on all platforms
|
||||||
|
- **Metal** - macOS GPU acceleration
|
||||||
|
- **CoreML** - macOS/iOS neural engine acceleration
|
||||||
|
- **OpenCL** - Cross-platform GPU acceleration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use Metal acceleration on macOS
|
||||||
|
cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg
|
||||||
|
|
||||||
|
# Use CoreML on macOS/iOS
|
||||||
|
cargo run --release detect --executor mnn --forward-type coreml path/to/image.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
### ONNX Runtime Backend
|
||||||
|
|
||||||
|
The ONNX Runtime backend automatically selects the best available execution provider based on your system configuration.
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Rust 2024 edition
|
||||||
|
- MNN runtime (automatically linked)
|
||||||
|
- ONNX runtime (for ONNX backend)
|
||||||
|
|
||||||
|
### Building
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Standard build
|
||||||
|
cargo build
|
||||||
|
|
||||||
|
# Release build with optimizations
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
cargo test
|
||||||
|
```
|
||||||
|
|
||||||
|
### Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
├── src/
|
||||||
|
│ ├── facedet/ # Face detection modules
|
||||||
|
│ │ ├── mnn/ # MNN backend implementations
|
||||||
|
│ │ ├── ort/ # ONNX Runtime backend implementations
|
||||||
|
│ │ └── postprocess.rs # Shared postprocessing logic
|
||||||
|
│ ├── faceembed/ # Face embedding modules
|
||||||
|
│ │ ├── mnn/ # MNN backend implementations
|
||||||
|
│ │ └── ort/ # ONNX Runtime backend implementations
|
||||||
|
│ ├── cli.rs # Command line interface
|
||||||
|
│ └── main.rs # Application entry point
|
||||||
|
├── models/ # Neural network models (.mnn and .onnx)
|
||||||
|
├── bounding-box/ # Bounding box utilities
|
||||||
|
├── ndarray-image/ # Image conversion utilities
|
||||||
|
└── ndarray-resize/ # Image resizing utilities
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend Architecture
|
||||||
|
|
||||||
|
The codebase is organized to support multiple inference backends:
|
||||||
|
|
||||||
|
- **Common interfaces**: `FaceDetector` and `FaceEmbedder` traits provide unified APIs
|
||||||
|
- **Shared postprocessing**: Common logic for anchor generation, NMS, and coordinate decoding
|
||||||
|
- **Backend-specific implementations**: Separate modules for MNN and ONNX Runtime
|
||||||
|
- **Modular design**: Easy to add new backends by implementing the common traits
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
Key dependencies include:
|
||||||
|
|
||||||
|
- **MNN** - High-performance neural network inference framework (MNN backend)
|
||||||
|
- **ONNX Runtime** - Cross-platform ML inference (ORT backend)
|
||||||
|
- **ndarray** - N-dimensional array processing
|
||||||
|
- **image** - Image processing and I/O
|
||||||
|
- **clap** - Command line argument parsing
|
||||||
|
- **bounding-box** - Geometric operations for face detection
|
||||||
|
- **error-stack** - Structured error handling
|
||||||
|
|
||||||
|
### Backend Status
|
||||||
|
|
||||||
|
- ✅ **MNN Backend**: Fully implemented with hardware acceleration support
|
||||||
|
- 🚧 **ONNX Runtime Backend**: Framework implemented, inference logic to be completed
|
||||||
|
|
||||||
|
*Note: The ORT backend currently provides the framework but requires completion of the inference implementation.*
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*Built with Rust for maximum performance and safety in computer vision applications.*
|
||||||
|
|||||||
1
rfcs
Submodule
1
rfcs
Submodule
Submodule rfcs added at ad85f4c819
@@ -48,6 +48,8 @@ pub struct Detect {
|
|||||||
pub model_type: Models,
|
pub model_type: Models,
|
||||||
#[clap(short, long)]
|
#[clap(short, long)]
|
||||||
pub output: Option<PathBuf>,
|
pub output: Option<PathBuf>,
|
||||||
|
#[clap(short = 'e', long)]
|
||||||
|
pub executor: Option<Executor>,
|
||||||
#[clap(short, long, default_value = "cpu")]
|
#[clap(short, long, default_value = "cpu")]
|
||||||
pub forward_type: mnn::ForwardType,
|
pub forward_type: mnn::ForwardType,
|
||||||
#[clap(short, long, default_value_t = 0.8)]
|
#[clap(short, long, default_value_t = 0.8)]
|
||||||
|
|||||||
@@ -1,2 +1,24 @@
|
|||||||
pub mod retinaface;
|
pub mod mnn;
|
||||||
|
pub mod ort;
|
||||||
|
pub mod postprocess;
|
||||||
pub mod yolo;
|
pub mod yolo;
|
||||||
|
|
||||||
|
// Re-export common types and traits
|
||||||
|
pub use postprocess::{
|
||||||
|
FaceDetectionConfig, FaceDetectionModelOutput, FaceDetectionOutput,
|
||||||
|
FaceDetectionProcessedOutput, FaceDetector, FaceLandmarks,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convenience type aliases for different backends
|
||||||
|
pub mod retinaface {
|
||||||
|
pub use crate::facedet::mnn::retinaface as mnn;
|
||||||
|
pub use crate::facedet::ort::retinaface as ort;
|
||||||
|
|
||||||
|
// Re-export common types
|
||||||
|
pub use crate::facedet::postprocess::{
|
||||||
|
FaceDetectionConfig, FaceDetectionOutput, FaceDetector, FaceLandmarks,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to MNN implementation for backward compatibility
|
||||||
|
pub use mnn::retinaface::FaceDetection;
|
||||||
|
|||||||
3
src/facedet/mnn/mod.rs
Normal file
3
src/facedet/mnn/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod retinaface;
|
||||||
|
|
||||||
|
pub use retinaface::FaceDetection;
|
||||||
174
src/facedet/mnn/retinaface.rs
Normal file
174
src/facedet/mnn/retinaface.rs
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
use crate::errors::*;
|
||||||
|
use crate::facedet::postprocess::*;
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use mnn_bridge::ndarray::*;
|
||||||
|
use ndarray_resize::NdFir;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct FaceDetection {
|
||||||
|
handle: mnn_sync::SessionHandle,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FaceDetectionBuilder {
|
||||||
|
schedule_config: Option<mnn::ScheduleConfig>,
|
||||||
|
backend_config: Option<mnn::BackendConfig>,
|
||||||
|
model: mnn::Interpreter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetectionBuilder {
|
||||||
|
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
schedule_config: None,
|
||||||
|
backend_config: None,
|
||||||
|
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||||
|
.map_err(|e| e.into_inner())
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to load model from bytes")?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||||
|
self.schedule_config
|
||||||
|
.get_or_insert_default()
|
||||||
|
.set_type(forward_type);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||||
|
self.schedule_config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||||
|
self.backend_config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> Result<FaceDetection> {
|
||||||
|
let model = self.model;
|
||||||
|
let sc = self.schedule_config.unwrap_or_default();
|
||||||
|
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create session handle")?;
|
||||||
|
Ok(FaceDetection { handle })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetection {
|
||||||
|
pub fn builder<T: AsRef<[u8]>>()
|
||||||
|
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
|
||||||
|
FaceDetectionBuilder::new
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||||
|
let model = std::fs::read(path)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to read model file")?;
|
||||||
|
Self::new_from_bytes(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||||
|
tracing::info!("Loading face detection model from bytes");
|
||||||
|
let mut model = mnn::Interpreter::from_bytes(model)
|
||||||
|
.map_err(|e| e.into_inner())
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to load model from bytes")?;
|
||||||
|
model.set_session_mode(mnn::SessionMode::Release);
|
||||||
|
model
|
||||||
|
.set_cache_file("retinaface.cache", 128)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set cache file")?;
|
||||||
|
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||||
|
let sc = mnn::ScheduleConfig::new()
|
||||||
|
.with_type(mnn::ForwardType::Metal)
|
||||||
|
.with_backend_config(bc);
|
||||||
|
tracing::info!("Creating session handle for face detection model");
|
||||||
|
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create session handle")?;
|
||||||
|
Ok(FaceDetection { handle })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetector for FaceDetection {
|
||||||
|
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let mut resized = image
|
||||||
|
.fast_resize(1024, 1024, None)
|
||||||
|
.change_context(Error)?
|
||||||
|
.mapv(|f| f as f32);
|
||||||
|
|
||||||
|
// Apply mean subtraction: [104, 117, 123]
|
||||||
|
resized
|
||||||
|
.axis_iter_mut(ndarray::Axis(2))
|
||||||
|
.zip([104, 117, 123])
|
||||||
|
.for_each(|(mut array, pixel)| {
|
||||||
|
let pixel = pixel as f32;
|
||||||
|
array.map_inplace(|v| *v -= pixel);
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut resized = resized
|
||||||
|
.permuted_axes((2, 0, 1))
|
||||||
|
.insert_axis(ndarray::Axis(0))
|
||||||
|
.as_standard_layout()
|
||||||
|
.into_owned();
|
||||||
|
|
||||||
|
use ::tap::*;
|
||||||
|
let output = self
|
||||||
|
.handle
|
||||||
|
.run(move |sr| {
|
||||||
|
let tensor = resized
|
||||||
|
.as_mnn_tensor_mut()
|
||||||
|
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||||
|
.change_context(mnn::error::ErrorKind::TensorError)?;
|
||||||
|
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||||
|
let (intptr, session) = sr.both_mut();
|
||||||
|
tracing::trace!("Copying input tensor to host");
|
||||||
|
unsafe {
|
||||||
|
let mut input = intptr.input_unresized::<f32>(session, "input")?;
|
||||||
|
tracing::trace!("Input shape: {:?}", input.shape());
|
||||||
|
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
|
||||||
|
input.view_mut(),
|
||||||
|
1,
|
||||||
|
3,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
intptr.resize_session(session);
|
||||||
|
let mut input = intptr.input::<f32>(session, "input")?;
|
||||||
|
tracing::trace!("Input shape: {:?}", input.shape());
|
||||||
|
input.copy_from_host_tensor(tensor.view())?;
|
||||||
|
|
||||||
|
tracing::info!("Running face detection session");
|
||||||
|
intptr.run_session(&session)?;
|
||||||
|
let output_tensor = intptr
|
||||||
|
.output::<f32>(&session, "bbox")?
|
||||||
|
.create_host_tensor_from_device(true)
|
||||||
|
.as_ndarray()
|
||||||
|
.to_owned();
|
||||||
|
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
|
||||||
|
let output_confidence = intptr
|
||||||
|
.output::<f32>(&session, "confidence")?
|
||||||
|
.create_host_tensor_from_device(true)
|
||||||
|
.as_ndarray::<ndarray::Ix3>()
|
||||||
|
.to_owned();
|
||||||
|
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
|
||||||
|
let output_landmark = intptr
|
||||||
|
.output::<f32>(&session, "landmark")?
|
||||||
|
.create_host_tensor_from_device(true)
|
||||||
|
.as_ndarray::<ndarray::Ix3>()
|
||||||
|
.to_owned();
|
||||||
|
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
|
||||||
|
Ok(FaceDetectionModelOutput {
|
||||||
|
bbox: output_tensor,
|
||||||
|
confidence: output_confidence,
|
||||||
|
landmark: output_landmark,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.map_err(|e| e.into_inner())
|
||||||
|
.change_context(Error)?;
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/facedet/ort/mod.rs
Normal file
3
src/facedet/ort/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod retinaface;
|
||||||
|
|
||||||
|
pub use retinaface::FaceDetection;
|
||||||
264
src/facedet/ort/retinaface.rs
Normal file
264
src/facedet/ort/retinaface.rs
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
use crate::errors::*;
|
||||||
|
use crate::facedet::postprocess::*;
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use ndarray_resize::NdFir;
|
||||||
|
use ort::{
|
||||||
|
execution_providers::{
|
||||||
|
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
||||||
|
},
|
||||||
|
session::{Session, builder::GraphOptimizationLevel},
|
||||||
|
value::Tensor,
|
||||||
|
};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct FaceDetection {
|
||||||
|
session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FaceDetectionBuilder {
|
||||||
|
model_data: Vec<u8>,
|
||||||
|
execution_providers: Option<Vec<ExecutionProviderDispatch>>,
|
||||||
|
intra_threads: Option<usize>,
|
||||||
|
inter_threads: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetectionBuilder {
|
||||||
|
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
model_data: model.as_ref().to_vec(),
|
||||||
|
execution_providers: None,
|
||||||
|
intra_threads: None,
|
||||||
|
inter_threads: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
||||||
|
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|provider| match provider.as_str() {
|
||||||
|
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
||||||
|
_ => {
|
||||||
|
tracing::warn!("Unknown execution provider: {}", provider);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if !execution_providers.is_empty() {
|
||||||
|
self.execution_providers = Some(execution_providers);
|
||||||
|
} else {
|
||||||
|
tracing::warn!("No valid execution providers found, falling back to CPU");
|
||||||
|
self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_intra_threads(mut self, threads: usize) -> Self {
|
||||||
|
self.intra_threads = Some(threads);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_inter_threads(mut self, threads: usize) -> Self {
|
||||||
|
self.inter_threads = Some(threads);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> crate::errors::Result<FaceDetection> {
|
||||||
|
let mut session_builder = Session::builder()
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create session builder")?;
|
||||||
|
|
||||||
|
// Set execution providers
|
||||||
|
if let Some(providers) = self.execution_providers {
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_execution_providers(providers)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set execution providers")?;
|
||||||
|
} else {
|
||||||
|
// Default to CPU
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_execution_providers([CPUExecutionProvider::default().build()])
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set default CPU execution provider")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set threading options
|
||||||
|
if let Some(threads) = self.intra_threads {
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_intra_threads(threads)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set intra threads")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(threads) = self.inter_threads {
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_inter_threads(threads)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set inter threads")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set optimization level
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set optimization level")?;
|
||||||
|
|
||||||
|
// Create session from model bytes
|
||||||
|
let session = session_builder
|
||||||
|
.commit_from_memory(&self.model_data)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create ORT session from model bytes")?;
|
||||||
|
|
||||||
|
tracing::info!("Successfully created ORT RetinaFace session");
|
||||||
|
|
||||||
|
Ok(FaceDetection { session })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetection {
|
||||||
|
pub fn builder<T: AsRef<[u8]>>()
|
||||||
|
-> fn(T) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>>
|
||||||
|
{
|
||||||
|
FaceDetectionBuilder::new
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
|
||||||
|
let model = std::fs::read(path)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to read model file")?;
|
||||||
|
Self::new_from_bytes(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_from_bytes(model: &[u8]) -> crate::errors::Result<Self> {
|
||||||
|
tracing::info!("Loading ORT RetinaFace model from bytes");
|
||||||
|
Self::builder()(model)?.build()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetector for FaceDetection {
|
||||||
|
fn run_model(
|
||||||
|
&mut self,
|
||||||
|
image: ndarray::ArrayView3<u8>,
|
||||||
|
) -> crate::errors::Result<FaceDetectionModelOutput> {
|
||||||
|
// Resize image to 1024x1024
|
||||||
|
let mut resized = image
|
||||||
|
.fast_resize(1024, 1024, None)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to resize image")?
|
||||||
|
.mapv(|f| f as f32);
|
||||||
|
|
||||||
|
// Apply mean subtraction: [104, 117, 123] for BGR format
|
||||||
|
resized
|
||||||
|
.axis_iter_mut(ndarray::Axis(2))
|
||||||
|
.zip([104.0, 117.0, 123.0])
|
||||||
|
.for_each(|(mut array, mean)| {
|
||||||
|
array.map_inplace(|v| *v -= mean);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert from HWC to NCHW format (add batch dimension and transpose)
|
||||||
|
let input_tensor = resized
|
||||||
|
.permuted_axes((2, 0, 1))
|
||||||
|
.insert_axis(ndarray::Axis(0))
|
||||||
|
.as_standard_layout()
|
||||||
|
.into_owned();
|
||||||
|
|
||||||
|
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||||
|
|
||||||
|
// Create ORT input tensor
|
||||||
|
let input_value = Tensor::from_array(input_tensor)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create input tensor")?;
|
||||||
|
|
||||||
|
// Run inference
|
||||||
|
tracing::debug!("Running ORT RetinaFace inference");
|
||||||
|
let outputs = self
|
||||||
|
.session
|
||||||
|
.run(ort::inputs!["input" => input_value])
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to run inference")?;
|
||||||
|
|
||||||
|
// Extract outputs by name
|
||||||
|
let bbox_output = outputs
|
||||||
|
.get("bbox")
|
||||||
|
.ok_or(Error)
|
||||||
|
.attach_printable("Missing bbox output from model")?
|
||||||
|
.try_extract_tensor::<f32>()
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to extract bbox tensor")?;
|
||||||
|
|
||||||
|
let confidence_output = outputs
|
||||||
|
.get("confidence")
|
||||||
|
.ok_or(Error)
|
||||||
|
.attach_printable("Missing confidence output from model")?
|
||||||
|
.try_extract_tensor::<f32>()
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to extract confidence tensor")?;
|
||||||
|
|
||||||
|
let landmark_output = outputs
|
||||||
|
.get("landmark")
|
||||||
|
.ok_or(Error)
|
||||||
|
.attach_printable("Missing landmark output from model")?
|
||||||
|
.try_extract_tensor::<f32>()
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to extract landmark tensor")?;
|
||||||
|
|
||||||
|
// Get tensor shapes and data
|
||||||
|
let (bbox_shape, bbox_data) = bbox_output;
|
||||||
|
let (confidence_shape, confidence_data) = confidence_output;
|
||||||
|
let (landmark_shape, landmark_data) = landmark_output;
|
||||||
|
|
||||||
|
tracing::trace!(
|
||||||
|
"Output shapes - bbox: {:?}, confidence: {:?}, landmark: {:?}",
|
||||||
|
bbox_shape,
|
||||||
|
confidence_shape,
|
||||||
|
landmark_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
// Convert to ndarray format
|
||||||
|
let bbox_dims = bbox_shape.as_ref();
|
||||||
|
let confidence_dims = confidence_shape.as_ref();
|
||||||
|
let landmark_dims = landmark_shape.as_ref();
|
||||||
|
|
||||||
|
let bbox_array = ndarray::Array3::from_shape_vec(
|
||||||
|
(
|
||||||
|
bbox_dims[0] as usize,
|
||||||
|
bbox_dims[1] as usize,
|
||||||
|
bbox_dims[2] as usize,
|
||||||
|
),
|
||||||
|
bbox_data.to_vec(),
|
||||||
|
)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create bbox ndarray")?;
|
||||||
|
|
||||||
|
let confidence_array = ndarray::Array3::from_shape_vec(
|
||||||
|
(
|
||||||
|
confidence_dims[0] as usize,
|
||||||
|
confidence_dims[1] as usize,
|
||||||
|
confidence_dims[2] as usize,
|
||||||
|
),
|
||||||
|
confidence_data.to_vec(),
|
||||||
|
)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create confidence ndarray")?;
|
||||||
|
|
||||||
|
let landmark_array = ndarray::Array3::from_shape_vec(
|
||||||
|
(
|
||||||
|
landmark_dims[0] as usize,
|
||||||
|
landmark_dims[1] as usize,
|
||||||
|
landmark_dims[2] as usize,
|
||||||
|
),
|
||||||
|
landmark_data.to_vec(),
|
||||||
|
)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create landmark ndarray")?;
|
||||||
|
|
||||||
|
Ok(FaceDetectionModelOutput {
|
||||||
|
bbox: bbox_array,
|
||||||
|
confidence: confidence_array,
|
||||||
|
landmark: landmark_array,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,10 +1,8 @@
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use bounding_box::{Aabb2, nms::nms};
|
use bounding_box::{Aabb2, nms::nms};
|
||||||
use error_stack::ResultExt;
|
use error_stack::ResultExt;
|
||||||
use mnn_bridge::ndarray::*;
|
|
||||||
use nalgebra::{Point2, Vector2};
|
use nalgebra::{Point2, Vector2};
|
||||||
use ndarray_resize::NdFir;
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
/// Configuration for face detection postprocessing
|
/// Configuration for face detection postprocessing
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@@ -32,30 +30,37 @@ impl FaceDetectionConfig {
|
|||||||
self.threshold = threshold;
|
self.threshold = threshold;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
|
pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
|
||||||
self.nms_threshold = nms_threshold;
|
self.nms_threshold = nms_threshold;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
|
pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
|
||||||
self.variances = variances;
|
self.variances = variances;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
||||||
self.steps = steps;
|
self.steps = steps;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
|
pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
|
||||||
self.min_sizes = min_sizes;
|
self.min_sizes = min_sizes;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_clip(mut self, clip: bool) -> Self {
|
pub fn with_clip(mut self, clip: bool) -> Self {
|
||||||
self.clamp = clip;
|
self.clamp = clip;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_input_width(mut self, input_width: usize) -> Self {
|
pub fn with_input_width(mut self, input_width: usize) -> Self {
|
||||||
self.input_width = input_width;
|
self.input_width = input_width;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_input_height(mut self, input_height: usize) -> Self {
|
pub fn with_input_height(mut self, input_height: usize) -> Self {
|
||||||
self.input_height = input_height;
|
self.input_height = input_height;
|
||||||
self
|
self
|
||||||
@@ -77,18 +82,6 @@ impl Default for FaceDetectionConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct FaceDetection {
|
|
||||||
handle: mnn_sync::SessionHandle,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub struct FaceDetectionModelOutput {
|
|
||||||
pub bbox: ndarray::Array3<f32>,
|
|
||||||
pub confidence: ndarray::Array3<f32>,
|
|
||||||
pub landmark: ndarray::Array3<f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Represents the 5 facial landmarks detected by RetinaFace
|
/// Represents the 5 facial landmarks detected by RetinaFace
|
||||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||||
pub struct FaceLandmarks {
|
pub struct FaceLandmarks {
|
||||||
@@ -99,6 +92,13 @@ pub struct FaceLandmarks {
|
|||||||
pub right_mouth: Point2<f32>,
|
pub right_mouth: Point2<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct FaceDetectionModelOutput {
|
||||||
|
pub bbox: ndarray::Array3<f32>,
|
||||||
|
pub confidence: ndarray::Array3<f32>,
|
||||||
|
pub landmark: ndarray::Array3<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct FaceDetectionProcessedOutput {
|
pub struct FaceDetectionProcessedOutput {
|
||||||
pub bbox: Vec<Aabb2<f32>>,
|
pub bbox: Vec<Aabb2<f32>>,
|
||||||
@@ -113,7 +113,13 @@ pub struct FaceDetectionOutput {
|
|||||||
pub landmark: Vec<FaceLandmarks>,
|
pub landmark: Vec<FaceLandmarks>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
|
/// Raw model outputs that can be converted to FaceDetectionModelOutput
|
||||||
|
pub trait IntoModelOutput {
|
||||||
|
fn into_model_output(self) -> Result<FaceDetectionModelOutput>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate anchors for RetinaFace model
|
||||||
|
pub fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
|
||||||
let mut anchors = Vec::new();
|
let mut anchors = Vec::new();
|
||||||
let feature_maps: Vec<(usize, usize)> = config
|
let feature_maps: Vec<(usize, usize)> = config
|
||||||
.steps
|
.steps
|
||||||
@@ -220,9 +226,7 @@ impl FaceDetectionModelOutput {
|
|||||||
landmarks: decoded_landmarks,
|
landmarks: decoded_landmarks,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl FaceDetectionModelOutput {
|
|
||||||
pub fn print(&self, limit: usize) {
|
pub fn print(&self, limit: usize) {
|
||||||
tracing::info!("Detected {} faces", self.bbox.shape()[1]);
|
tracing::info!("Detected {} faces", self.bbox.shape()[1]);
|
||||||
|
|
||||||
@@ -246,102 +250,16 @@ impl FaceDetectionModelOutput {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct FaceDetectionBuilder {
|
/// Apply Non-Maximum Suppression and convert to final output format
|
||||||
schedule_config: Option<mnn::ScheduleConfig>,
|
pub fn apply_nms_and_finalize(
|
||||||
backend_config: Option<mnn::BackendConfig>,
|
processed: FaceDetectionProcessedOutput,
|
||||||
model: mnn::Interpreter,
|
config: &FaceDetectionConfig,
|
||||||
}
|
image_size: (usize, usize), // (width, height)
|
||||||
|
|
||||||
impl FaceDetectionBuilder {
|
|
||||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
|
||||||
Ok(Self {
|
|
||||||
schedule_config: None,
|
|
||||||
backend_config: None,
|
|
||||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
|
||||||
.map_err(|e| e.into_inner())
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to load model from bytes")?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
|
||||||
self.schedule_config
|
|
||||||
.get_or_insert_default()
|
|
||||||
.set_type(forward_type);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
|
||||||
self.schedule_config = Some(config);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
|
||||||
self.backend_config = Some(config);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build(self) -> Result<FaceDetection> {
|
|
||||||
let model = self.model;
|
|
||||||
let sc = self.schedule_config.unwrap_or_default();
|
|
||||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to create session handle")?;
|
|
||||||
Ok(FaceDetection { handle })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FaceDetection {
|
|
||||||
pub fn builder<T: AsRef<[u8]>>()
|
|
||||||
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
|
|
||||||
FaceDetectionBuilder::new
|
|
||||||
}
|
|
||||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
|
||||||
let model = std::fs::read(path)
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to read model file")?;
|
|
||||||
Self::new_from_bytes(&model)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
|
||||||
tracing::info!("Loading face detection model from bytes");
|
|
||||||
let mut model = mnn::Interpreter::from_bytes(model)
|
|
||||||
.map_err(|e| e.into_inner())
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to load model from bytes")?;
|
|
||||||
model.set_session_mode(mnn::SessionMode::Release);
|
|
||||||
model
|
|
||||||
.set_cache_file("retinaface.cache", 128)
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to set cache file")?;
|
|
||||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
|
||||||
let sc = mnn::ScheduleConfig::new()
|
|
||||||
.with_type(mnn::ForwardType::Metal)
|
|
||||||
.with_backend_config(bc);
|
|
||||||
tracing::info!("Creating session handle for face detection model");
|
|
||||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to create session handle")?;
|
|
||||||
Ok(FaceDetection { handle })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn detect_faces(
|
|
||||||
&self,
|
|
||||||
image: ndarray::ArrayView3<u8>,
|
|
||||||
config: FaceDetectionConfig,
|
|
||||||
) -> Result<FaceDetectionOutput> {
|
) -> Result<FaceDetectionOutput> {
|
||||||
let (height, width, _channels) = image.dim();
|
|
||||||
let output = self
|
|
||||||
.run_models(image)
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to detect faces")?;
|
|
||||||
// denormalize the bounding boxes
|
|
||||||
let factor = Vector2::new(width as f32, height as f32);
|
|
||||||
let mut processed = output
|
|
||||||
.postprocess(&config)
|
|
||||||
.attach_printable("Failed to postprocess")?;
|
|
||||||
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
let factor = Vector2::new(image_size.0 as f32, image_size.1 as f32);
|
||||||
|
|
||||||
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
|
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
|
||||||
.bbox
|
.bbox
|
||||||
.iter()
|
.iter()
|
||||||
@@ -381,79 +299,27 @@ impl FaceDetection {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run_models(&self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
/// Common trait for face detection backends
|
||||||
#[rustfmt::skip]
|
pub trait FaceDetector {
|
||||||
let mut resized = image
|
/// Run inference on the model and return raw outputs
|
||||||
.fast_resize(1024, 1024, None)
|
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput>;
|
||||||
.change_context(Error)?
|
|
||||||
.mapv(|f| f as f32)
|
|
||||||
.tap_mut(|arr| {
|
|
||||||
arr.axis_iter_mut(ndarray::Axis(2))
|
|
||||||
.zip([104, 117, 123])
|
|
||||||
.for_each(|(mut array, pixel)| {
|
|
||||||
let pixel = pixel as f32;
|
|
||||||
array.map_inplace(|v| *v -= pixel);
|
|
||||||
});
|
|
||||||
})
|
|
||||||
.permuted_axes((2, 0, 1))
|
|
||||||
.insert_axis(ndarray::Axis(0))
|
|
||||||
.as_standard_layout()
|
|
||||||
.into_owned();
|
|
||||||
use ::tap::*;
|
|
||||||
let output = self
|
|
||||||
.handle
|
|
||||||
.run(move |sr| {
|
|
||||||
let tensor = resized
|
|
||||||
.as_mnn_tensor_mut()
|
|
||||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
|
||||||
.change_context(mnn::error::ErrorKind::TensorError)?;
|
|
||||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
|
||||||
let (intptr, session) = sr.both_mut();
|
|
||||||
tracing::trace!("Copying input tensor to host");
|
|
||||||
unsafe {
|
|
||||||
let mut input = intptr.input_unresized::<f32>(session, "input")?;
|
|
||||||
tracing::trace!("Input shape: {:?}", input.shape());
|
|
||||||
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
|
|
||||||
input.view_mut(),
|
|
||||||
1,
|
|
||||||
3,
|
|
||||||
1024,
|
|
||||||
1024,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
intptr.resize_session(session);
|
|
||||||
let mut input = intptr.input::<f32>(session, "input")?;
|
|
||||||
tracing::trace!("Input shape: {:?}", input.shape());
|
|
||||||
input.copy_from_host_tensor(tensor.view())?;
|
|
||||||
|
|
||||||
tracing::info!("Running face detection session");
|
/// Detect faces with full pipeline including postprocessing
|
||||||
intptr.run_session(&session)?;
|
fn detect_faces(
|
||||||
let output_tensor = intptr
|
&mut self,
|
||||||
.output::<f32>(&session, "bbox")?
|
image: ndarray::ArrayView3<u8>,
|
||||||
.create_host_tensor_from_device(true)
|
config: FaceDetectionConfig,
|
||||||
.as_ndarray()
|
) -> Result<FaceDetectionOutput> {
|
||||||
.to_owned();
|
let (height, width, _channels) = image.dim();
|
||||||
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
|
let output = self
|
||||||
let output_confidence = intptr
|
.run_model(image)
|
||||||
.output::<f32>(&session, "confidence")?
|
.change_context(Error)
|
||||||
.create_host_tensor_from_device(true)
|
.attach_printable("Failed to detect faces")?;
|
||||||
.as_ndarray::<ndarray::Ix3>()
|
|
||||||
.to_owned();
|
let processed = output
|
||||||
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
|
.postprocess(&config)
|
||||||
let output_landmark = intptr
|
.attach_printable("Failed to postprocess")?;
|
||||||
.output::<f32>(&session, "landmark")?
|
|
||||||
.create_host_tensor_from_device(true)
|
apply_nms_and_finalize(processed, &config, (width, height))
|
||||||
.as_ndarray::<ndarray::Ix3>()
|
|
||||||
.to_owned();
|
|
||||||
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
|
|
||||||
Ok(FaceDetectionModelOutput {
|
|
||||||
bbox: output_tensor,
|
|
||||||
confidence: output_confidence,
|
|
||||||
landmark: output_landmark,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.map_err(|e| e.into_inner())
|
|
||||||
.change_context(Error)?;
|
|
||||||
Ok(output)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1 +1,20 @@
|
|||||||
pub mod facenet;
|
use crate::errors::*;
|
||||||
|
use ndarray::{Array2, ArrayView4};
|
||||||
|
|
||||||
|
pub mod mnn;
|
||||||
|
pub mod ort;
|
||||||
|
|
||||||
|
/// Common trait for face embedding backends
|
||||||
|
pub trait FaceEmbedder {
|
||||||
|
/// Generate embeddings for a batch of face images
|
||||||
|
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience type aliases for different backends
|
||||||
|
pub mod facenet {
|
||||||
|
pub use crate::faceembed::mnn::facenet as mnn;
|
||||||
|
pub use crate::faceembed::ort::facenet as ort;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to MNN implementation for backward compatibility
|
||||||
|
pub use mnn::facenet::EmbeddingGenerator;
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
use crate::errors::{Result, *};
|
|
||||||
use ndarray::*;
|
|
||||||
use ort::*;
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct EmbeddingGenerator {
|
|
||||||
handle: ort::session::Session,
|
|
||||||
}
|
|
||||||
|
|
||||||
// impl EmbeddingGeneratorBuilder {
|
|
||||||
// pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
|
||||||
// Ok(Self {
|
|
||||||
// schedule_config: None,
|
|
||||||
// backend_config: None,
|
|
||||||
// model: mnn::Interpreter::from_bytes(model.as_ref())
|
|
||||||
// .map_err(|e| e.into_inner())
|
|
||||||
// .change_context(Error)
|
|
||||||
// .attach_printable("Failed to load model from bytes")?,
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
|
||||||
// self.schedule_config
|
|
||||||
// .get_or_insert_default()
|
|
||||||
// .set_type(forward_type);
|
|
||||||
// self
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
|
||||||
// self.schedule_config = Some(config);
|
|
||||||
// self
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
|
||||||
// self.backend_config = Some(config);
|
|
||||||
// self
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// pub fn build(self) -> Result<EmbeddingGenerator> {
|
|
||||||
// let model = self.model;
|
|
||||||
// let sc = self.schedule_config.unwrap_or_default();
|
|
||||||
// let handle = mnn_sync::SessionHandle::new(model, sc)
|
|
||||||
// .change_context(Error)
|
|
||||||
// .attach_printable("Failed to create session handle")?;
|
|
||||||
// Ok(EmbeddingGenerator { handle })
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
impl EmbeddingGenerator {
|
|
||||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
|
||||||
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
|
||||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
|
||||||
let model = std::fs::read(path)
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to read model file")?;
|
|
||||||
Self::new_from_bytes(&model)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result<Self> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
// pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {}
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
|
use crate::faceembed::FaceEmbedder;
|
||||||
use mnn_bridge::ndarray::*;
|
use mnn_bridge::ndarray::*;
|
||||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
mod mnn_impl;
|
|
||||||
mod ort_impl;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct EmbeddingGenerator {
|
pub struct EmbeddingGenerator {
|
||||||
@@ -151,3 +150,9 @@ impl EmbeddingGenerator {
|
|||||||
// todo!()
|
// todo!()
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FaceEmbedder for EmbeddingGenerator {
|
||||||
|
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||||
|
self.run_models(faces)
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/faceembed/mnn/mod.rs
Normal file
3
src/faceembed/mnn/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod facenet;
|
||||||
|
|
||||||
|
pub use facenet::EmbeddingGenerator;
|
||||||
79
src/faceembed/ort/facenet.rs
Normal file
79
src/faceembed/ort/facenet.rs
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
use crate::errors::*;
|
||||||
|
use crate::faceembed::FaceEmbedder;
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use ndarray::{Array2, ArrayView4};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EmbeddingGenerator {
|
||||||
|
// Placeholder - ORT implementation to be completed later
|
||||||
|
_placeholder: (),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EmbeddingGeneratorBuilder {
|
||||||
|
_model_data: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingGeneratorBuilder {
|
||||||
|
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
_model_data: model.as_ref().to_vec(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_execution_providers(self, _providers: Vec<String>) -> Self {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_intra_threads(self, _threads: usize) -> Self {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_inter_threads(self, _threads: usize) -> Self {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> crate::errors::Result<EmbeddingGenerator> {
|
||||||
|
// TODO: Implement ORT session creation
|
||||||
|
tracing::warn!("ORT FaceNet implementation is not yet complete");
|
||||||
|
Ok(EmbeddingGenerator { _placeholder: () })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingGenerator {
|
||||||
|
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||||
|
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||||
|
|
||||||
|
pub fn builder<T: AsRef<[u8]>>() -> fn(
|
||||||
|
T,
|
||||||
|
) -> std::result::Result<
|
||||||
|
EmbeddingGeneratorBuilder,
|
||||||
|
error_stack::Report<crate::errors::Error>,
|
||||||
|
> {
|
||||||
|
EmbeddingGeneratorBuilder::new
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
|
||||||
|
let model = std::fs::read(path)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to read model file")?;
|
||||||
|
Self::new_from_bytes(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||||
|
tracing::info!("Loading face embedding model from bytes");
|
||||||
|
Self::builder()(model)?.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
|
// TODO: Implement ORT inference
|
||||||
|
tracing::error!("ORT FaceNet inference not yet implemented");
|
||||||
|
Err(Error).attach_printable("ORT FaceNet implementation is incomplete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceEmbedder for EmbeddingGenerator {
|
||||||
|
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
|
self.run_models(faces)
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/faceembed/ort/mod.rs
Normal file
3
src/faceembed/ort/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod facenet;
|
||||||
|
|
||||||
|
pub use facenet::EmbeddingGenerator;
|
||||||
60
src/main.rs
60
src/main.rs
@@ -1,7 +1,7 @@
|
|||||||
mod cli;
|
mod cli;
|
||||||
mod errors;
|
mod errors;
|
||||||
use bounding_box::roi::MultiRoi;
|
use bounding_box::roi::MultiRoi;
|
||||||
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
use detector::{facedet, facedet::FaceDetectionConfig, faceembed};
|
||||||
use errors::*;
|
use errors::*;
|
||||||
use fast_image_resize::ResizeOptions;
|
use fast_image_resize::ResizeOptions;
|
||||||
use ndarray::*;
|
use ndarray::*;
|
||||||
@@ -20,19 +20,62 @@ pub fn main() -> Result<()> {
|
|||||||
let args = <cli::Cli as clap::Parser>::parse();
|
let args = <cli::Cli as clap::Parser>::parse();
|
||||||
match args.cmd {
|
match args.cmd {
|
||||||
cli::SubCommand::Detect(detect) => {
|
cli::SubCommand::Detect(detect) => {
|
||||||
use detector::facedet;
|
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||||
let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL)
|
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
|
||||||
|
|
||||||
|
match executor {
|
||||||
|
cli::Executor::Mnn => {
|
||||||
|
let retinaface = facedet::mnn::FaceDetection::builder()(RETINAFACE_MODEL)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
.with_forward_type(detect.forward_type)
|
.with_forward_type(detect.forward_type)
|
||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face detection model")?;
|
.attach_printable("Failed to create face detection model")?;
|
||||||
let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL)
|
let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
.with_forward_type(detect.forward_type)
|
.with_forward_type(detect.forward_type)
|
||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face embedding model")?;
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|
||||||
|
run_detection(detect, retinaface, facenet)?;
|
||||||
|
}
|
||||||
|
cli::Executor::Onnx => {
|
||||||
|
// Load ONNX models
|
||||||
|
const RETINAFACE_ONNX_MODEL: &[u8] =
|
||||||
|
include_bytes!("../models/retinaface.onnx");
|
||||||
|
const FACENET_ONNX_MODEL: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||||
|
|
||||||
|
let retinaface = facedet::ort::FaceDetection::builder()(RETINAFACE_ONNX_MODEL)
|
||||||
|
.change_context(Error)?
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face detection model")?;
|
||||||
|
let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_ONNX_MODEL)
|
||||||
|
.change_context(Error)?
|
||||||
|
.build()
|
||||||
|
.change_context(errors::Error)
|
||||||
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|
||||||
|
run_detection(detect, retinaface, facenet)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cli::SubCommand::List(list) => {
|
||||||
|
println!("List: {:?}", list);
|
||||||
|
}
|
||||||
|
cli::SubCommand::Completions { shell } => {
|
||||||
|
cli::Cli::completions(shell);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()>
|
||||||
|
where
|
||||||
|
D: facedet::FaceDetector,
|
||||||
|
E: faceembed::FaceEmbedder,
|
||||||
|
{
|
||||||
let image = image::open(detect.image).change_context(Error)?;
|
let image = image::open(detect.image).change_context(Error)?;
|
||||||
let image = image.into_rgb8();
|
let image = image.into_rgb8();
|
||||||
let mut array = image
|
let mut array = image
|
||||||
@@ -112,13 +155,6 @@ pub fn main() -> Result<()> {
|
|||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to save output image")?;
|
.attach_printable("Failed to save output image")?;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
cli::SubCommand::List(list) => {
|
|
||||||
println!("List: {:?}", list);
|
|
||||||
}
|
|
||||||
cli::SubCommand::Completions { shell } => {
|
|
||||||
cli::Cli::completions(shell);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user