feat: Changed the struct for retinaface
This commit is contained in:
2
rfcs
2
rfcs
Submodule rfcs updated: ad85f4c819...98ec027ca5
@@ -1,24 +1,8 @@
|
||||
pub mod mnn;
|
||||
pub mod ort;
|
||||
pub mod postprocess;
|
||||
pub mod retinaface;
|
||||
pub mod yolo;
|
||||
|
||||
// Re-export common types and traits
|
||||
pub use postprocess::{
|
||||
pub use retinaface::{
|
||||
FaceDetectionConfig, FaceDetectionModelOutput, FaceDetectionOutput,
|
||||
FaceDetectionProcessedOutput, FaceDetector, FaceLandmarks,
|
||||
};
|
||||
|
||||
// Convenience type aliases for different backends
|
||||
pub mod retinaface {
|
||||
pub use crate::facedet::mnn::retinaface as mnn;
|
||||
pub use crate::facedet::ort::retinaface as ort;
|
||||
|
||||
// Re-export common types
|
||||
pub use crate::facedet::postprocess::{
|
||||
FaceDetectionConfig, FaceDetectionOutput, FaceDetector, FaceLandmarks,
|
||||
};
|
||||
}
|
||||
|
||||
// Default to MNN implementation for backward compatibility
|
||||
pub use mnn::retinaface::FaceDetection;
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
pub mod retinaface;
|
||||
|
||||
pub use retinaface::FaceDetection;
|
||||
@@ -1,3 +0,0 @@
|
||||
pub mod retinaface;
|
||||
|
||||
pub use retinaface::FaceDetection;
|
||||
@@ -1,8 +1,10 @@
|
||||
pub mod mnn;
|
||||
pub mod ort;
|
||||
|
||||
use crate::errors::*;
|
||||
use bounding_box::{Aabb2, nms::nms};
|
||||
use error_stack::ResultExt;
|
||||
use nalgebra::{Point2, Vector2};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for face detection postprocessing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::postprocess::*;
|
||||
use crate::facedet::*;
|
||||
use error_stack::ResultExt;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use ndarray_resize::NdFir;
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::postprocess::*;
|
||||
use crate::facedet::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_resize::NdFir;
|
||||
use ort::{
|
||||
39
src/main.rs
39
src/main.rs
@@ -7,8 +7,10 @@ use fast_image_resize::ResizeOptions;
|
||||
use ndarray::*;
|
||||
use ndarray_image::*;
|
||||
use ndarray_resize::NdFir;
|
||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
|
||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
const CHUNK_SIZE: usize = 8;
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
@@ -25,13 +27,14 @@ pub fn main() -> Result<()> {
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn => {
|
||||
let retinaface = facedet::mnn::FaceDetection::builder()(RETINAFACE_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL)
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder()(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
@@ -41,17 +44,13 @@ pub fn main() -> Result<()> {
|
||||
run_detection(detect, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Onnx => {
|
||||
// Load ONNX models
|
||||
const RETINAFACE_ONNX_MODEL: &[u8] =
|
||||
include_bytes!("../models/retinaface.onnx");
|
||||
const FACENET_ONNX_MODEL: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
|
||||
let retinaface = facedet::ort::FaceDetection::builder()(RETINAFACE_ONNX_MODEL)
|
||||
.change_context(Error)?
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_ONNX_MODEL)
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder()(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
|
||||
Reference in New Issue
Block a user