Files
face-detector/src/main.rs
uttarayan21 783320131a
Some checks failed
build / checks-build (push) Has been cancelled
build / checks-matrix (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
feat: Added stuff
2025-08-18 22:10:29 +05:30

176 lines
6.9 KiB
Rust

mod cli;
mod errors;
use bounding_box::roi::MultiRoi;
use detector::{facedet, facedet::FaceDetectionConfig, faceembed};
use errors::*;
use fast_image_resize::ResizeOptions;
use ndarray::*;
use ndarray_image::*;
use ndarray_resize::NdFir;
/// RetinaFace model bytes in MNN format, embedded into the binary at compile time.
/// Passed to the MNN face-detection builder in `main`.
const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
/// FaceNet model bytes in MNN format, embedded into the binary at compile time.
/// Passed to the MNN embedding-generator builder in `main`.
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
/// RetinaFace model bytes in ONNX format, embedded into the binary at compile time.
/// Passed to the ORT face-detection builder in `main`.
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
/// FaceNet model bytes in ONNX format, embedded into the binary at compile time.
/// Passed to the ORT embedding-generator builder in `main`.
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("error")
.with_thread_ids(true)
.with_thread_names(true)
.with_target(false)
.init();
let args = <cli::Cli as clap::Parser>::parse();
match args.cmd {
cli::SubCommand::Detect(detect) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if detect.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?;
}
}
}
cli::SubCommand::List(list) => {
println!("List: {:?}", list);
}
cli::SubCommand::Completions { shell } => {
cli::Cli::completions(shell);
}
}
Ok(())
}
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
let image = image::open(&detect.image)
.change_context(Error)
.attach_printable(detect.image.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let mut array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = retinaface
.detect_faces(
array.view(),
FaceDetectionConfig::default()
.with_threshold(detect.threshold)
.with_nms_threshold(detect.nms_threshold),
)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
for bbox in &output.bbox {
tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
let face_rois = array
.view()
.multi_roi(&output.bbox)
.change_context(Error)?
.into_iter()
// .inspect(|f| {
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
// })
.map(|roi| {
roi.as_standard_layout()
.fast_resize(160, 160, &ResizeOptions::default())
.change_context(Error)
})
// .inspect(|f| {
// f.as_ref().inspect(|f| {
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
// });
// })
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = detect.batch_size;
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
tracing::info!("Processing chunk of size: {}", chunk.len());
if chunk.len() < chunk_size {
tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((160, 160, 3));
let zero_array = core::iter::repeat(zeros.view())
.take(chunk_size)
.collect::<Vec<_>>();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
}
})
.collect::<Result<Vec<Array2<f32>>>>()?;
let v = array.view();
if let Some(output) = detect.output {
let image: image::RgbImage = v
.to_image()
.change_context(errors::Error)
.attach_printable("Failed to convert ndarray to image")?;
image
.save(output)
.change_context(errors::Error)
.attach_printable("Failed to save output image")?;
}
Ok(())
}