feat: Added stuff
This commit is contained in:
3
src/facedet/mnn/mod.rs
Normal file
3
src/facedet/mnn/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod retinaface;
|
||||
|
||||
pub use retinaface::FaceDetection;
|
||||
174
src/facedet/mnn/retinaface.rs
Normal file
174
src/facedet/mnn/retinaface.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::postprocess::*;
|
||||
use error_stack::ResultExt;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use ndarray_resize::NdFir;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FaceDetection {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
|
||||
pub struct FaceDetectionBuilder {
|
||||
schedule_config: Option<mnn::ScheduleConfig>,
|
||||
backend_config: Option<mnn::BackendConfig>,
|
||||
model: mnn::Interpreter,
|
||||
}
|
||||
|
||||
impl FaceDetectionBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
schedule_config: None,
|
||||
backend_config: None,
|
||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
self.schedule_config
|
||||
.get_or_insert_default()
|
||||
.set_type(forward_type);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
self.schedule_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
self.backend_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<FaceDetection> {
|
||||
let model = self.model;
|
||||
let sc = self.schedule_config.unwrap_or_default();
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn builder<T: AsRef<[u8]>>()
|
||||
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
|
||||
FaceDetectionBuilder::new
|
||||
}
|
||||
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||
tracing::info!("Loading face detection model from bytes");
|
||||
let mut model = mnn::Interpreter::from_bytes(model)
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
model
|
||||
.set_cache_file("retinaface.cache", 128)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set cache file")?;
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::Metal)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face detection model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetector for FaceDetection {
|
||||
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
#[rustfmt::skip]
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(Error)?
|
||||
.mapv(|f| f as f32);
|
||||
|
||||
// Apply mean subtraction: [104, 117, 123]
|
||||
resized
|
||||
.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104, 117, 123])
|
||||
.for_each(|(mut array, pixel)| {
|
||||
let pixel = pixel as f32;
|
||||
array.map_inplace(|v| *v -= pixel);
|
||||
});
|
||||
|
||||
let mut resized = resized
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
|
||||
use ::tap::*;
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let tensor = resized
|
||||
.as_mnn_tensor_mut()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
.change_context(mnn::error::ErrorKind::TensorError)?;
|
||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
unsafe {
|
||||
let mut input = intptr.input_unresized::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
|
||||
input.view_mut(),
|
||||
1,
|
||||
3,
|
||||
1024,
|
||||
1024,
|
||||
);
|
||||
}
|
||||
intptr.resize_session(session);
|
||||
let mut input = intptr.input::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
input.copy_from_host_tensor(tensor.view())?;
|
||||
|
||||
tracing::info!("Running face detection session");
|
||||
intptr.run_session(&session)?;
|
||||
let output_tensor = intptr
|
||||
.output::<f32>(&session, "bbox")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
|
||||
let output_confidence = intptr
|
||||
.output::<f32>(&session, "confidence")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
|
||||
let output_landmark = intptr
|
||||
.output::<f32>(&session, "landmark")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
|
||||
Ok(FaceDetectionModelOutput {
|
||||
bbox: output_tensor,
|
||||
confidence: output_confidence,
|
||||
landmark: output_landmark,
|
||||
})
|
||||
})
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
3
src/facedet/ort/mod.rs
Normal file
3
src/facedet/ort/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod retinaface;
|
||||
|
||||
pub use retinaface::FaceDetection;
|
||||
264
src/facedet/ort/retinaface.rs
Normal file
264
src/facedet/ort/retinaface.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::postprocess::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_resize::NdFir;
|
||||
use ort::{
|
||||
execution_providers::{
|
||||
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
||||
},
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
value::Tensor,
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FaceDetection {
|
||||
session: Session,
|
||||
}
|
||||
|
||||
pub struct FaceDetectionBuilder {
|
||||
model_data: Vec<u8>,
|
||||
execution_providers: Option<Vec<ExecutionProviderDispatch>>,
|
||||
intra_threads: Option<usize>,
|
||||
inter_threads: Option<usize>,
|
||||
}
|
||||
|
||||
impl FaceDetectionBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||
Ok(Self {
|
||||
model_data: model.as_ref().to_vec(),
|
||||
execution_providers: None,
|
||||
intra_threads: None,
|
||||
inter_threads: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||
.into_iter()
|
||||
.filter_map(|provider| match provider.as_str() {
|
||||
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
||||
#[cfg(target_os = "macos")]
|
||||
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
||||
_ => {
|
||||
tracing::warn!("Unknown execution provider: {}", provider);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !execution_providers.is_empty() {
|
||||
self.execution_providers = Some(execution_providers);
|
||||
} else {
|
||||
tracing::warn!("No valid execution providers found, falling back to CPU");
|
||||
self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_intra_threads(mut self, threads: usize) -> Self {
|
||||
self.intra_threads = Some(threads);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_inter_threads(mut self, threads: usize) -> Self {
|
||||
self.inter_threads = Some(threads);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> crate::errors::Result<FaceDetection> {
|
||||
let mut session_builder = Session::builder()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session builder")?;
|
||||
|
||||
// Set execution providers
|
||||
if let Some(providers) = self.execution_providers {
|
||||
session_builder = session_builder
|
||||
.with_execution_providers(providers)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set execution providers")?;
|
||||
} else {
|
||||
// Default to CPU
|
||||
session_builder = session_builder
|
||||
.with_execution_providers([CPUExecutionProvider::default().build()])
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set default CPU execution provider")?;
|
||||
}
|
||||
|
||||
// Set threading options
|
||||
if let Some(threads) = self.intra_threads {
|
||||
session_builder = session_builder
|
||||
.with_intra_threads(threads)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set intra threads")?;
|
||||
}
|
||||
|
||||
if let Some(threads) = self.inter_threads {
|
||||
session_builder = session_builder
|
||||
.with_inter_threads(threads)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set inter threads")?;
|
||||
}
|
||||
|
||||
// Set optimization level
|
||||
session_builder = session_builder
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set optimization level")?;
|
||||
|
||||
// Create session from model bytes
|
||||
let session = session_builder
|
||||
.commit_from_memory(&self.model_data)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create ORT session from model bytes")?;
|
||||
|
||||
tracing::info!("Successfully created ORT RetinaFace session");
|
||||
|
||||
Ok(FaceDetection { session })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn builder<T: AsRef<[u8]>>()
|
||||
-> fn(T) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>>
|
||||
{
|
||||
FaceDetectionBuilder::new
|
||||
}
|
||||
|
||||
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> crate::errors::Result<Self> {
|
||||
tracing::info!("Loading ORT RetinaFace model from bytes");
|
||||
Self::builder()(model)?.build()
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetector for FaceDetection {
|
||||
fn run_model(
|
||||
&mut self,
|
||||
image: ndarray::ArrayView3<u8>,
|
||||
) -> crate::errors::Result<FaceDetectionModelOutput> {
|
||||
// Resize image to 1024x1024
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to resize image")?
|
||||
.mapv(|f| f as f32);
|
||||
|
||||
// Apply mean subtraction: [104, 117, 123] for BGR format
|
||||
resized
|
||||
.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104.0, 117.0, 123.0])
|
||||
.for_each(|(mut array, mean)| {
|
||||
array.map_inplace(|v| *v -= mean);
|
||||
});
|
||||
|
||||
// Convert from HWC to NCHW format (add batch dimension and transpose)
|
||||
let input_tensor = resized
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
|
||||
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||
|
||||
// Create ORT input tensor
|
||||
let input_value = Tensor::from_array(input_tensor)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create input tensor")?;
|
||||
|
||||
// Run inference
|
||||
tracing::debug!("Running ORT RetinaFace inference");
|
||||
let outputs = self
|
||||
.session
|
||||
.run(ort::inputs!["input" => input_value])
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to run inference")?;
|
||||
|
||||
// Extract outputs by name
|
||||
let bbox_output = outputs
|
||||
.get("bbox")
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing bbox output from model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract bbox tensor")?;
|
||||
|
||||
let confidence_output = outputs
|
||||
.get("confidence")
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing confidence output from model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract confidence tensor")?;
|
||||
|
||||
let landmark_output = outputs
|
||||
.get("landmark")
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing landmark output from model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract landmark tensor")?;
|
||||
|
||||
// Get tensor shapes and data
|
||||
let (bbox_shape, bbox_data) = bbox_output;
|
||||
let (confidence_shape, confidence_data) = confidence_output;
|
||||
let (landmark_shape, landmark_data) = landmark_output;
|
||||
|
||||
tracing::trace!(
|
||||
"Output shapes - bbox: {:?}, confidence: {:?}, landmark: {:?}",
|
||||
bbox_shape,
|
||||
confidence_shape,
|
||||
landmark_shape
|
||||
);
|
||||
|
||||
// Convert to ndarray format
|
||||
let bbox_dims = bbox_shape.as_ref();
|
||||
let confidence_dims = confidence_shape.as_ref();
|
||||
let landmark_dims = landmark_shape.as_ref();
|
||||
|
||||
let bbox_array = ndarray::Array3::from_shape_vec(
|
||||
(
|
||||
bbox_dims[0] as usize,
|
||||
bbox_dims[1] as usize,
|
||||
bbox_dims[2] as usize,
|
||||
),
|
||||
bbox_data.to_vec(),
|
||||
)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create bbox ndarray")?;
|
||||
|
||||
let confidence_array = ndarray::Array3::from_shape_vec(
|
||||
(
|
||||
confidence_dims[0] as usize,
|
||||
confidence_dims[1] as usize,
|
||||
confidence_dims[2] as usize,
|
||||
),
|
||||
confidence_data.to_vec(),
|
||||
)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create confidence ndarray")?;
|
||||
|
||||
let landmark_array = ndarray::Array3::from_shape_vec(
|
||||
(
|
||||
landmark_dims[0] as usize,
|
||||
landmark_dims[1] as usize,
|
||||
landmark_dims[2] as usize,
|
||||
),
|
||||
landmark_data.to_vec(),
|
||||
)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create landmark ndarray")?;
|
||||
|
||||
Ok(FaceDetectionModelOutput {
|
||||
bbox: bbox_array,
|
||||
confidence: confidence_array,
|
||||
landmark: landmark_array,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,8 @@
|
||||
use crate::errors::*;
|
||||
use bounding_box::{Aabb2, nms::nms};
|
||||
use error_stack::ResultExt;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use nalgebra::{Point2, Vector2};
|
||||
use ndarray_resize::NdFir;
|
||||
use std::path::Path;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for face detection postprocessing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -32,30 +30,37 @@ impl FaceDetectionConfig {
|
||||
self.threshold = threshold;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
|
||||
self.nms_threshold = nms_threshold;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
|
||||
self.variances = variances;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
||||
self.steps = steps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
|
||||
self.min_sizes = min_sizes;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_clip(mut self, clip: bool) -> Self {
|
||||
self.clamp = clip;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_input_width(mut self, input_width: usize) -> Self {
|
||||
self.input_width = input_width;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_input_height(mut self, input_height: usize) -> Self {
|
||||
self.input_height = input_height;
|
||||
self
|
||||
@@ -77,18 +82,6 @@ impl Default for FaceDetectionConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FaceDetection {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionModelOutput {
|
||||
pub bbox: ndarray::Array3<f32>,
|
||||
pub confidence: ndarray::Array3<f32>,
|
||||
pub landmark: ndarray::Array3<f32>,
|
||||
}
|
||||
|
||||
/// Represents the 5 facial landmarks detected by RetinaFace
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct FaceLandmarks {
|
||||
@@ -99,6 +92,13 @@ pub struct FaceLandmarks {
|
||||
pub right_mouth: Point2<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionModelOutput {
|
||||
pub bbox: ndarray::Array3<f32>,
|
||||
pub confidence: ndarray::Array3<f32>,
|
||||
pub landmark: ndarray::Array3<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionProcessedOutput {
|
||||
pub bbox: Vec<Aabb2<f32>>,
|
||||
@@ -113,7 +113,13 @@ pub struct FaceDetectionOutput {
|
||||
pub landmark: Vec<FaceLandmarks>,
|
||||
}
|
||||
|
||||
fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
|
||||
/// Raw model outputs that can be converted to FaceDetectionModelOutput
|
||||
pub trait IntoModelOutput {
|
||||
fn into_model_output(self) -> Result<FaceDetectionModelOutput>;
|
||||
}
|
||||
|
||||
/// Generate anchors for RetinaFace model
|
||||
pub fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
|
||||
let mut anchors = Vec::new();
|
||||
let feature_maps: Vec<(usize, usize)> = config
|
||||
.steps
|
||||
@@ -220,9 +226,7 @@ impl FaceDetectionModelOutput {
|
||||
landmarks: decoded_landmarks,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetectionModelOutput {
|
||||
pub fn print(&self, limit: usize) {
|
||||
tracing::info!("Detected {} faces", self.bbox.shape()[1]);
|
||||
|
||||
@@ -246,214 +250,76 @@ impl FaceDetectionModelOutput {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FaceDetectionBuilder {
|
||||
schedule_config: Option<mnn::ScheduleConfig>,
|
||||
backend_config: Option<mnn::BackendConfig>,
|
||||
model: mnn::Interpreter,
|
||||
/// Apply Non-Maximum Suppression and convert to final output format
|
||||
pub fn apply_nms_and_finalize(
|
||||
processed: FaceDetectionProcessedOutput,
|
||||
config: &FaceDetectionConfig,
|
||||
image_size: (usize, usize), // (width, height)
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
use itertools::Itertools;
|
||||
|
||||
let factor = Vector2::new(image_size.0 as f32, image_size.1 as f32);
|
||||
|
||||
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
|
||||
.bbox
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(processed.confidence.iter().cloned())
|
||||
.zip(processed.landmarks.iter().cloned())
|
||||
.sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score))
|
||||
.map(|((b, s), l)| (b, s, l))
|
||||
.multiunzip();
|
||||
|
||||
let keep_indices =
|
||||
nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?;
|
||||
|
||||
let bboxes = boxes
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| keep_indices.contains(i))
|
||||
.flat_map(|(_, x)| x.denormalize(factor).try_cast::<usize>())
|
||||
.collect();
|
||||
let confidence = scores
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| keep_indices.contains(i))
|
||||
.map(|(_, score)| score)
|
||||
.collect();
|
||||
let landmark = landmarks
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| keep_indices.contains(i))
|
||||
.map(|(_, score)| score)
|
||||
.collect();
|
||||
|
||||
Ok(FaceDetectionOutput {
|
||||
bbox: bboxes,
|
||||
confidence,
|
||||
landmark,
|
||||
})
|
||||
}
|
||||
|
||||
impl FaceDetectionBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
schedule_config: None,
|
||||
backend_config: None,
|
||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?,
|
||||
})
|
||||
}
|
||||
/// Common trait for face detection backends
|
||||
pub trait FaceDetector {
|
||||
/// Run inference on the model and return raw outputs
|
||||
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput>;
|
||||
|
||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
self.schedule_config
|
||||
.get_or_insert_default()
|
||||
.set_type(forward_type);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
self.schedule_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
self.backend_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<FaceDetection> {
|
||||
let model = self.model;
|
||||
let sc = self.schedule_config.unwrap_or_default();
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn builder<T: AsRef<[u8]>>()
|
||||
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
|
||||
FaceDetectionBuilder::new
|
||||
}
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||
tracing::info!("Loading face detection model from bytes");
|
||||
let mut model = mnn::Interpreter::from_bytes(model)
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
model
|
||||
.set_cache_file("retinaface.cache", 128)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set cache file")?;
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::Metal)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face detection model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
|
||||
pub fn detect_faces(
|
||||
&self,
|
||||
/// Detect faces with full pipeline including postprocessing
|
||||
fn detect_faces(
|
||||
&mut self,
|
||||
image: ndarray::ArrayView3<u8>,
|
||||
config: FaceDetectionConfig,
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
let (height, width, _channels) = image.dim();
|
||||
let output = self
|
||||
.run_models(image)
|
||||
.run_model(image)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
// denormalize the bounding boxes
|
||||
let factor = Vector2::new(width as f32, height as f32);
|
||||
let mut processed = output
|
||||
|
||||
let processed = output
|
||||
.postprocess(&config)
|
||||
.attach_printable("Failed to postprocess")?;
|
||||
|
||||
use itertools::Itertools;
|
||||
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
|
||||
.bbox
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(processed.confidence.iter().cloned())
|
||||
.zip(processed.landmarks.iter().cloned())
|
||||
.sorted_by_key(|((_, score), _)| ordered_float::OrderedFloat(*score))
|
||||
.map(|((b, s), l)| (b, s, l))
|
||||
.multiunzip();
|
||||
|
||||
let keep_indices =
|
||||
nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?;
|
||||
|
||||
let bboxes = boxes
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| keep_indices.contains(i))
|
||||
.flat_map(|(_, x)| x.denormalize(factor).try_cast::<usize>())
|
||||
.collect();
|
||||
let confidence = scores
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| keep_indices.contains(i))
|
||||
.map(|(_, score)| score)
|
||||
.collect();
|
||||
let landmark = landmarks
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| keep_indices.contains(i))
|
||||
.map(|(_, score)| score)
|
||||
.collect();
|
||||
|
||||
Ok(FaceDetectionOutput {
|
||||
bbox: bboxes,
|
||||
confidence,
|
||||
landmark,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run_models(&self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
#[rustfmt::skip]
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(Error)?
|
||||
.mapv(|f| f as f32)
|
||||
.tap_mut(|arr| {
|
||||
arr.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104, 117, 123])
|
||||
.for_each(|(mut array, pixel)| {
|
||||
let pixel = pixel as f32;
|
||||
array.map_inplace(|v| *v -= pixel);
|
||||
});
|
||||
})
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
use ::tap::*;
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let tensor = resized
|
||||
.as_mnn_tensor_mut()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
.change_context(mnn::error::ErrorKind::TensorError)?;
|
||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
unsafe {
|
||||
let mut input = intptr.input_unresized::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
|
||||
input.view_mut(),
|
||||
1,
|
||||
3,
|
||||
1024,
|
||||
1024,
|
||||
);
|
||||
}
|
||||
intptr.resize_session(session);
|
||||
let mut input = intptr.input::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
input.copy_from_host_tensor(tensor.view())?;
|
||||
|
||||
tracing::info!("Running face detection session");
|
||||
intptr.run_session(&session)?;
|
||||
let output_tensor = intptr
|
||||
.output::<f32>(&session, "bbox")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
|
||||
let output_confidence = intptr
|
||||
.output::<f32>(&session, "confidence")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
|
||||
let output_landmark = intptr
|
||||
.output::<f32>(&session, "landmark")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
|
||||
Ok(FaceDetectionModelOutput {
|
||||
bbox: output_tensor,
|
||||
confidence: output_confidence,
|
||||
landmark: output_landmark,
|
||||
})
|
||||
})
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
apply_nms_and_finalize(processed, &config, (width, height))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user