feat: Working RetinaFace
@@ -6,15 +6,77 @@ use nalgebra::{Point2, Vector2};
 use ndarray_resize::NdFir;
 use std::path::Path;
 
-pub struct FaceDetectionConfig {}
+/// Configuration for face detection postprocessing
+#[derive(Debug, Clone, PartialEq)]
+pub struct FaceDetectionConfig {
+    /// Minimum confidence to keep a detection
+    pub threshold: f32,
+    /// NMS threshold for suppressing overlapping boxes
+    pub nms_threshold: f32,
+    /// Variances for bounding box decoding
+    pub variances: [f32; 2],
+    /// The step size (stride) for each feature map
+    pub steps: Vec<usize>,
+    /// The minimum anchor sizes for each feature map
+    pub min_sizes: Vec<Vec<usize>>,
+    /// Whether to clip bounding boxes to the image dimensions
+    pub clamp: bool,
+    /// Input image width (used for anchor generation)
+    pub input_width: usize,
+    /// Input image height (used for anchor generation)
+    pub input_height: usize,
+}
 
-impl FaceDetectionConfig {}
+impl FaceDetectionConfig {
+    pub fn with_threshold(mut self, threshold: f32) -> Self {
+        self.threshold = threshold;
+        self
+    }
+    pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
+        self.nms_threshold = nms_threshold;
+        self
+    }
+    pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
+        self.variances = variances;
+        self
+    }
+    pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
+        self.steps = steps;
+        self
+    }
+    pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
+        self.min_sizes = min_sizes;
+        self
+    }
+    pub fn with_clip(mut self, clip: bool) -> Self {
+        self.clamp = clip;
+        self
+    }
+    pub fn with_input_width(mut self, input_width: usize) -> Self {
+        self.input_width = input_width;
+        self
+    }
+    pub fn with_input_height(mut self, input_height: usize) -> Self {
+        self.input_height = input_height;
+        self
+    }
+}
 
 impl Default for FaceDetectionConfig {
     fn default() -> Self {
-        FaceDetectionConfig {}
+        Self {
+            threshold: 0.5,
+            nms_threshold: 0.4,
+            variances: [0.1, 0.2],
+            steps: vec![8, 16, 32],
+            min_sizes: vec![vec![16, 32], vec![64, 128], vec![256, 512]],
+            clamp: true,
+            input_width: 1024,
+            input_height: 1024,
+        }
     }
 }
 
 pub struct FaceDetection {
     handle: mnn_sync::SessionHandle,
 }
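For reference, a minimal caller-side sketch of the builder API added above; the 640x640 input size and threshold values are illustrative only, not anything this commit sets:

// Hypothetical usage: start from the defaults and override per-model settings.
let config = FaceDetectionConfig::default()
    .with_input_width(640)
    .with_input_height(640)
    .with_threshold(0.6)
    .with_nms_threshold(0.3);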
@@ -50,8 +112,112 @@ pub struct FaceDetectionOutput {
     pub landmark: Vec<FaceLandmarks>,
 }
 
+fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
+    let mut anchors = Vec::new();
+    let feature_maps: Vec<(usize, usize)> = config
+        .steps
+        .iter()
+        .map(|&step| {
+            (
+                (config.input_height as f32 / step as f32).ceil() as usize,
+                (config.input_width as f32 / step as f32).ceil() as usize,
+            )
+        })
+        .collect();
+
+    for (k, f) in feature_maps.iter().enumerate() {
+        let min_sizes = &config.min_sizes[k];
+        for i in 0..f.0 {
+            for j in 0..f.1 {
+                for &min_size in min_sizes {
+                    let s_kx = min_size as f32 / config.input_width as f32;
+                    let s_ky = min_size as f32 / config.input_height as f32;
+                    let dense_cx =
+                        (j as f32 + 0.5) * config.steps[k] as f32 / config.input_width as f32;
+                    let dense_cy =
+                        (i as f32 + 0.5) * config.steps[k] as f32 / config.input_height as f32;
+                    anchors.push([
+                        dense_cx - s_kx / 2.,
+                        dense_cy - s_ky / 2.,
+                        dense_cx + s_kx / 2.,
+                        dense_cy + s_ky / 2.,
+                    ]);
+                }
+            }
+        }
+    }
+
+    ndarray::Array2::from_shape_vec((anchors.len(), 4), anchors.into_iter().flatten().collect())
+        .unwrap()
+}
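// Sketch of a sanity check for the anchor grid (not part of this commit): with the
// default config the feature maps are 1024/8 = 128, 1024/16 = 64 and 1024/32 = 32
// cells per side, each cell emitting one anchor per min_size, so the prior count is
// 128*128*2 + 64*64*2 + 32*32*2 = 43_008.
#[test]
fn default_config_anchor_count() {
    let priors = generate_anchors(&FaceDetectionConfig::default());
    assert_eq!(priors.shape(), &[43_008, 4]);
}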
+
+impl FaceDetectionModelOutput {
+    pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
+        use ndarray::s;
+
+        let priors = generate_anchors(config);
+
+        let scores = self.confidence.slice(s![0, .., 1]);
+        let boxes = self.bbox.slice(s![0, .., ..]);
+        let landmarks_raw = self.landmark.slice(s![0, .., ..]);
+
+        let mut decoded_boxes = Vec::new();
+        let mut decoded_landmarks = Vec::new();
+        let mut confidences = Vec::new();
+
+        for i in 0..priors.shape()[0] {
+            if scores[i] > config.threshold {
+                let prior = priors.row(i);
+                let loc = boxes.row(i);
+                let landm = landmarks_raw.row(i);
+
+                // Decode bounding box
+                let prior_cx = (prior[0] + prior[2]) / 2.0;
+                let prior_cy = (prior[1] + prior[3]) / 2.0;
+                let prior_w = prior[2] - prior[0];
+                let prior_h = prior[3] - prior[1];
+
+                let var = config.variances;
+                let cx = prior_cx + loc[0] * var[0] * prior_w;
+                let cy = prior_cy + loc[1] * var[0] * prior_h;
+                let w = prior_w * (loc[2] * var[1]).exp();
+                let h = prior_h * (loc[3] * var[1]).exp();
+
+                let xmin = cx - w / 2.0;
+                let ymin = cy - h / 2.0;
+                let xmax = cx + w / 2.0;
+                let ymax = cy + h / 2.0;
+
+                let mut bbox =
+                    Aabb2::from_min_max_vertices(Point2::new(xmin, ymin), Point2::new(xmax, ymax));
+                if config.clamp {
+                    bbox.component_clamp(0.0, 1.0);
+                }
+                decoded_boxes.push(bbox);
+
+                // Decode landmarks
+                let mut points = [Point2::new(0.0, 0.0); 5];
+                for j in 0..5 {
+                    points[j].x = prior_cx + landm[j * 2] * var[0] * prior_w;
+                    points[j].y = prior_cy + landm[j * 2 + 1] * var[0] * prior_h;
+                }
+                let landmarks = FaceLandmarks {
+                    left_eye: points[0],
+                    right_eye: points[1],
+                    nose: points[2],
+                    left_mouth: points[3],
+                    right_mouth: points[4],
+                };
+                decoded_landmarks.push(landmarks);
+                confidences.push(scores[i]);
+            }
+        }
+
+        Ok(FaceDetectionProcessedOutput {
+            bbox: decoded_boxes,
+            confidence: confidences,
+            landmarks: decoded_landmarks,
+        })
+    }
+}
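A rough caller-side sketch of the decode step, assuming model_output is the FaceDetectionModelOutput produced by running the network; obtaining it, and the follow-up NMS using nms_threshold, are outside this hunk:

// Hypothetical usage: confidence filtering + box/landmark decoding happen here;
// NMS over the decoded, normalized [0, 1] boxes would follow as a separate step.
let config = FaceDetectionConfig::default().with_threshold(0.7);
let detections = model_output.postprocess(&config)?;
for (bbox, score) in detections.bbox.iter().zip(&detections.confidence) {
    // Boxes are normalized; scale by the original image size before drawing.
    println!("face at {:?} with confidence {:.2}", bbox, score);
}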