feat: add CLI features for mnn and ort

This commit is contained in:
uttarayan21
2025-08-18 15:07:17 +05:30
parent e7c9c38ed7
commit 3aa95a2ef5
7 changed files with 181 additions and 54 deletions

View File

@@ -1,4 +1,6 @@
use std::path::PathBuf;
use mnn::ForwardType;
#[derive(Debug, clap::Parser)]
pub struct Cli {
#[clap(subcommand)]
@@ -21,23 +23,10 @@ pub enum Models {
Yolo,
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
#[derive(Debug, Clone)]
pub enum Executor {
Mnn,
Onnx,
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum OnnxEp {
Cpu,
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum MnnEp {
Cpu,
Metal,
OpenCL,
CoreML,
Mnn(mnn::ForwardType),
Ort(Vec<detector::ort_ep::ExecutionProvider>),
}
#[derive(Debug, clap::Args)]
@@ -48,10 +37,21 @@ pub struct Detect {
pub model_type: Models,
#[clap(short, long)]
pub output: Option<PathBuf>,
#[clap(short = 'e', long)]
pub executor: Option<Executor>,
#[clap(short, long, default_value = "cpu")]
pub forward_type: mnn::ForwardType,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
#[clap(
short,
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]

View File

@@ -1,11 +1,10 @@
use crate::errors::*;
use crate::facedet::*;
use crate::ort_ep::*;
use error_stack::ResultExt;
use ndarray_resize::NdFir;
use ort::{
execution_providers::{
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
},
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
session::{Session, builder::GraphOptimizationLevel},
value::Tensor,
};
@@ -33,18 +32,11 @@ impl FaceDetectionBuilder {
})
}
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
let execution_providers: Vec<ExecutionProviderDispatch> = providers
.into_iter()
.filter_map(|provider| match provider.as_str() {
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
#[cfg(target_os = "macos")]
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
_ => {
tracing::warn!("Unknown execution provider: {}", provider);
None
}
})
.as_ref()
.iter()
.filter_map(|provider| provider.to_dispatch())
.collect();
if !execution_providers.is_empty() {

View File

@@ -1,11 +1,10 @@
use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder;
use crate::ort_ep::*;
use error_stack::ResultExt;
use ndarray::{Array2, ArrayView4};
use ort::{
execution_providers::{
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
},
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
session::{Session, builder::GraphOptimizationLevel},
value::Tensor,
};
@@ -33,18 +32,11 @@ impl EmbeddingGeneratorBuilder {
})
}
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
let execution_providers: Vec<ExecutionProviderDispatch> = providers
.into_iter()
.filter_map(|provider| match provider.as_str() {
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
#[cfg(target_os = "macos")]
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
_ => {
tracing::warn!("Unknown execution provider: {}", provider);
None
}
})
.as_ref()
.iter()
.filter_map(|provider| provider.to_dispatch())
.collect();
if !execution_providers.is_empty() {

View File

@@ -2,4 +2,6 @@ pub mod errors;
pub mod facedet;
pub mod faceembed;
pub mod image;
pub mod ort_ep;
use errors::*;

View File

@@ -11,7 +11,7 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
const CHUNK_SIZE: usize = 8;
const CHUNK_SIZE: usize = 2;
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("trace")
@@ -23,37 +23,52 @@ pub fn main() -> Result<()> {
match args.cmd {
cli::SubCommand::Detect(detect) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
let executor = detect
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if detect.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
// .then_some(cli::Executor::Mnn)
// .unwrap_or(cli::Executor::Ort);
match executor {
cli::Executor::Mnn => {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(detect.forward_type)
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(detect.forward_type)
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?;
}
cli::Executor::Onnx => {
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;

119
src/ort_ep.rs Normal file
View File

@@ -0,0 +1,119 @@
#[cfg(feature = "ort-cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(feature = "ort-coreml")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(feature = "ort-tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
/// Supported execution providers for ONNX Runtime.
///
/// Variants are always present so they can be parsed from the CLI; whether a
/// provider can actually be dispatched depends on the target OS and the cargo
/// features compiled in (see `is_available` / `to_dispatch`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecutionProvider {
    /// CPU execution provider (always available)
    CPU,
    /// CoreML execution provider (macOS only; requires the `ort-coreml` feature)
    CoreML,
    /// CUDA execution provider (requires the `ort-cuda` feature)
    CUDA,
    /// TensorRT execution provider (requires the `ort-tensorrt` feature)
    TensorRT,
}
impl std::fmt::Display for ExecutionProvider {
    /// Formats the provider under its canonical name ("CPU", "CoreML",
    /// "CUDA", "TensorRT"), the same casing used in user-facing output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ExecutionProvider::CPU => "CPU",
            ExecutionProvider::CoreML => "CoreML",
            ExecutionProvider::CUDA => "CUDA",
            ExecutionProvider::TensorRT => "TensorRT",
        };
        f.write_str(name)
    }
}
impl std::str::FromStr for ExecutionProvider {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"cpu" => Ok(ExecutionProvider::CPU),
"coreml" => Ok(ExecutionProvider::CoreML),
"cuda" => Ok(ExecutionProvider::CUDA),
"tensorrt" => Ok(ExecutionProvider::TensorRT),
_ => Err(format!("Unknown execution provider: {}", s)),
}
}
}
impl ExecutionProvider {
    /// Returns all available execution providers for the current platform and features.
    ///
    /// CPU is unconditional; the accelerated providers appear only when the
    /// matching cargo feature (and, for CoreML, macOS) is compiled in.
    pub fn available_providers() -> Vec<ExecutionProvider> {
        let mut providers = vec![ExecutionProvider::CPU];
        #[cfg(all(target_os = "macos", feature = "ort-coreml"))]
        providers.push(ExecutionProvider::CoreML);
        #[cfg(feature = "ort-cuda")]
        providers.push(ExecutionProvider::CUDA);
        #[cfg(feature = "ort-tensorrt")]
        providers.push(ExecutionProvider::TensorRT);
        providers
    }

    /// Check if this execution provider is available on the current platform.
    pub fn is_available(&self) -> bool {
        match self {
            ExecutionProvider::CPU => true,
            ExecutionProvider::CoreML => cfg!(all(target_os = "macos", feature = "ort-coreml")),
            ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
            ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
        }
    }
}
impl ExecutionProvider {
    /// Converts this provider into an `ort` dispatch handle.
    ///
    /// Returns `None` (after logging a warning) when the provider is not
    /// usable in this build: missing cargo feature, or CoreML off macOS.
    pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
        match self {
            ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
            ExecutionProvider::CoreML => {
                // Exactly one of these three cfg blocks survives compilation.
                #[cfg(all(target_os = "macos", feature = "ort-coreml"))]
                {
                    Some(CoreMLExecutionProvider::default().build())
                }
                #[cfg(all(target_os = "macos", not(feature = "ort-coreml")))]
                {
                    tracing::warn!("coreml support not compiled in");
                    None
                }
                #[cfg(not(target_os = "macos"))]
                {
                    tracing::warn!("CoreML is only available on macOS");
                    None
                }
            }
            ExecutionProvider::CUDA => {
                #[cfg(feature = "ort-cuda")]
                {
                    Some(CUDAExecutionProvider::default().build())
                }
                #[cfg(not(feature = "ort-cuda"))]
                {
                    tracing::warn!("CUDA support not compiled in");
                    None
                }
            }
            ExecutionProvider::TensorRT => {
                #[cfg(feature = "ort-tensorrt")]
                {
                    Some(TensorRTExecutionProvider::default().build())
                }
                #[cfg(not(feature = "ort-tensorrt"))]
                {
                    tracing::warn!("TensorRT support not compiled in");
                    None
                }
            }
        }
    }
}