diff --git a/rfcs b/rfcs index ad85f4c..98ec027 160000 --- a/rfcs +++ b/rfcs @@ -1 +1 @@ -Subproject commit ad85f4c8197b2b7e052a2fea062b855fff4533bf +Subproject commit 98ec027ca5aa27b390b91e7d619c2512f76574d3 diff --git a/src/facedet.rs b/src/facedet.rs index c50a491..2a366be 100644 --- a/src/facedet.rs +++ b/src/facedet.rs @@ -1,24 +1,8 @@ -pub mod mnn; -pub mod ort; -pub mod postprocess; +pub mod retinaface; pub mod yolo; // Re-export common types and traits -pub use postprocess::{ +pub use retinaface::{ FaceDetectionConfig, FaceDetectionModelOutput, FaceDetectionOutput, FaceDetectionProcessedOutput, FaceDetector, FaceLandmarks, }; - -// Convenience type aliases for different backends -pub mod retinaface { - pub use crate::facedet::mnn::retinaface as mnn; - pub use crate::facedet::ort::retinaface as ort; - - // Re-export common types - pub use crate::facedet::postprocess::{ - FaceDetectionConfig, FaceDetectionOutput, FaceDetector, FaceLandmarks, - }; -} - -// Default to MNN implementation for backward compatibility -pub use mnn::retinaface::FaceDetection; diff --git a/src/facedet/mnn/mod.rs b/src/facedet/mnn/mod.rs deleted file mode 100644 index 067752d..0000000 --- a/src/facedet/mnn/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod retinaface; - -pub use retinaface::FaceDetection; diff --git a/src/facedet/ort/mod.rs b/src/facedet/ort/mod.rs deleted file mode 100644 index 067752d..0000000 --- a/src/facedet/ort/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod retinaface; - -pub use retinaface::FaceDetection; diff --git a/src/facedet/postprocess.rs b/src/facedet/retinaface.rs similarity index 99% rename from src/facedet/postprocess.rs rename to src/facedet/retinaface.rs index 22aa9af..874618e 100644 --- a/src/facedet/postprocess.rs +++ b/src/facedet/retinaface.rs @@ -1,8 +1,10 @@ +pub mod mnn; +pub mod ort; + use crate::errors::*; use bounding_box::{Aabb2, nms::nms}; use error_stack::ResultExt; use nalgebra::{Point2, Vector2}; -use std::collections::HashMap; /// Configuration for face detection postprocessing #[derive(Debug, Clone, PartialEq)] diff --git a/src/facedet/mnn/retinaface.rs b/src/facedet/retinaface/mnn.rs similarity index 99% rename from src/facedet/mnn/retinaface.rs rename to src/facedet/retinaface/mnn.rs index c02ada1..4a513f8 100644 --- a/src/facedet/mnn/retinaface.rs +++ b/src/facedet/retinaface/mnn.rs @@ -1,5 +1,5 @@ use crate::errors::*; -use crate::facedet::postprocess::*; +use crate::facedet::*; use error_stack::ResultExt; use mnn_bridge::ndarray::*; use ndarray_resize::NdFir; diff --git a/src/facedet/ort/retinaface.rs b/src/facedet/retinaface/ort.rs similarity index 99% rename from src/facedet/ort/retinaface.rs rename to src/facedet/retinaface/ort.rs index 10a870b..e57cbbc 100644 --- a/src/facedet/ort/retinaface.rs +++ b/src/facedet/retinaface/ort.rs @@ -1,5 +1,5 @@ use crate::errors::*; -use crate::facedet::postprocess::*; +use crate::facedet::*; use error_stack::ResultExt; use ndarray_resize::NdFir; use ort::{ diff --git a/src/main.rs b/src/main.rs index 5685524..4b15bec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,8 +7,10 @@ use fast_image_resize::ResizeOptions; use ndarray::*; use ndarray_image::*; use ndarray_resize::NdFir; -const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn"); -const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn"); +const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn"); +const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn"); +const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx"); +const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx"); const CHUNK_SIZE: usize = 8; pub fn main() -> Result<()> { tracing_subscriber::fmt() @@ -25,13 +27,14 @@ pub fn main() -> Result<()> { match executor { cli::Executor::Mnn => { - let retinaface = facedet::mnn::FaceDetection::builder()(RETINAFACE_MODEL) - .change_context(Error)? - .with_forward_type(detect.forward_type) - .build() - .change_context(errors::Error) - .attach_printable("Failed to create face detection model")?; - let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL) + let retinaface = + facedet::retinaface::mnn::FaceDetection::builder()(RETINAFACE_MODEL_MNN) + .change_context(Error)? + .with_forward_type(detect.forward_type) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL_MNN) .change_context(Error)? .with_forward_type(detect.forward_type) .build() @@ -41,17 +44,13 @@ pub fn main() -> Result<()> { run_detection(detect, retinaface, facenet)?; } cli::Executor::Onnx => { - // Load ONNX models - const RETINAFACE_ONNX_MODEL: &[u8] = - include_bytes!("../models/retinaface.onnx"); - const FACENET_ONNX_MODEL: &[u8] = include_bytes!("../models/facenet.onnx"); - - let retinaface = facedet::ort::FaceDetection::builder()(RETINAFACE_ONNX_MODEL) - .change_context(Error)? - .build() - .change_context(errors::Error) - .attach_printable("Failed to create face detection model")?; - let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_ONNX_MODEL) + let retinaface = + facedet::retinaface::ort::FaceDetection::builder()(RETINAFACE_MODEL_ONNX) + .change_context(Error)? + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face detection model")?; + let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_MODEL_ONNX) .change_context(Error)? .build() .change_context(errors::Error)