feat: Added cli features for mnn and ort
This commit is contained in:
27
src/main.rs
27
src/main.rs
@@ -11,7 +11,7 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
|
||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
const CHUNK_SIZE: usize = 8;
|
||||
const CHUNK_SIZE: usize = 2;
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
@@ -23,37 +23,52 @@ pub fn main() -> Result<()> {
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
|
||||
|
||||
let executor = detect
|
||||
.mnn_forward_type
|
||||
.map(|f| cli::Executor::Mnn(f))
|
||||
.or_else(|| {
|
||||
if detect.ort_execution_provider.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
// .then_some(cli::Executor::Mnn)
|
||||
// .unwrap_or(cli::Executor::Ort);
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn => {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_detection(detect, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Onnx => {
|
||||
cli::Executor::Ort(ep) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(&ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
Reference in New Issue
Block a user