feat: Added cli features for mnn and ort
This commit is contained in:
@@ -57,3 +57,10 @@ ort = "2.0.0-rc.10"
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
||||
[features]
|
||||
ort-cuda = ["ort/cuda"]
|
||||
ort-coreml = ["ort/coreml"]
|
||||
ort-tensorrt = ["ort/tensorrt"]
|
||||
|
||||
default = ["ort-coreml"]
|
||||
|
||||
40
src/cli.rs
40
src/cli.rs
@@ -1,4 +1,6 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use mnn::ForwardType;
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct Cli {
|
||||
#[clap(subcommand)]
|
||||
@@ -21,23 +23,10 @@ pub enum Models {
|
||||
Yolo,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Executor {
|
||||
Mnn,
|
||||
Onnx,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum OnnxEp {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum MnnEp {
|
||||
Cpu,
|
||||
Metal,
|
||||
OpenCL,
|
||||
CoreML,
|
||||
Mnn(mnn::ForwardType),
|
||||
Ort(Vec<detector::ort_ep::ExecutionProvider>),
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
@@ -48,10 +37,21 @@ pub struct Detect {
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
#[clap(short = 'e', long)]
|
||||
pub executor: Option<Executor>,
|
||||
#[clap(short, long, default_value = "cpu")]
|
||||
pub forward_type: mnn::ForwardType,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::*;
|
||||
use crate::ort_ep::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_resize::NdFir;
|
||||
use ort::{
|
||||
execution_providers::{
|
||||
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
||||
},
|
||||
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
value::Tensor,
|
||||
};
|
||||
@@ -33,18 +32,11 @@ impl FaceDetectionBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
||||
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
|
||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||
.into_iter()
|
||||
.filter_map(|provider| match provider.as_str() {
|
||||
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
||||
#[cfg(target_os = "macos")]
|
||||
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
||||
_ => {
|
||||
tracing::warn!("Unknown execution provider: {}", provider);
|
||||
None
|
||||
}
|
||||
})
|
||||
.as_ref()
|
||||
.iter()
|
||||
.filter_map(|provider| provider.to_dispatch())
|
||||
.collect();
|
||||
|
||||
if !execution_providers.is_empty() {
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use crate::errors::*;
|
||||
use crate::faceembed::facenet::FaceNetEmbedder;
|
||||
use crate::ort_ep::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray::{Array2, ArrayView4};
|
||||
use ort::{
|
||||
execution_providers::{
|
||||
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
||||
},
|
||||
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
value::Tensor,
|
||||
};
|
||||
@@ -33,18 +32,11 @@ impl EmbeddingGeneratorBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
||||
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
|
||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||
.into_iter()
|
||||
.filter_map(|provider| match provider.as_str() {
|
||||
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
||||
#[cfg(target_os = "macos")]
|
||||
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
||||
_ => {
|
||||
tracing::warn!("Unknown execution provider: {}", provider);
|
||||
None
|
||||
}
|
||||
})
|
||||
.as_ref()
|
||||
.iter()
|
||||
.filter_map(|provider| provider.to_dispatch())
|
||||
.collect();
|
||||
|
||||
if !execution_providers.is_empty() {
|
||||
|
||||
@@ -2,4 +2,6 @@ pub mod errors;
|
||||
pub mod facedet;
|
||||
pub mod faceembed;
|
||||
pub mod image;
|
||||
pub mod ort_ep;
|
||||
|
||||
use errors::*;
|
||||
|
||||
27
src/main.rs
27
src/main.rs
@@ -11,7 +11,7 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
|
||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
const CHUNK_SIZE: usize = 8;
|
||||
const CHUNK_SIZE: usize = 2;
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
@@ -23,37 +23,52 @@ pub fn main() -> Result<()> {
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
|
||||
|
||||
let executor = detect
|
||||
.mnn_forward_type
|
||||
.map(|f| cli::Executor::Mnn(f))
|
||||
.or_else(|| {
|
||||
if detect.ort_execution_provider.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
// .then_some(cli::Executor::Mnn)
|
||||
// .unwrap_or(cli::Executor::Ort);
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn => {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_detection(detect, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Onnx => {
|
||||
cli::Executor::Ort(ep) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(&ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
119
src/ort_ep.rs
Normal file
119
src/ort_ep.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
use ort::execution_providers::CUDAExecutionProvider;
|
||||
#[cfg(feature = "ort-coreml")]
|
||||
use ort::execution_providers::CoreMLExecutionProvider;
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
use ort::execution_providers::TensorRTExecutionProvider;
|
||||
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
|
||||
|
||||
/// Execution providers that can be requested for ONNX Runtime sessions.
///
/// Every variant always exists in the type, regardless of platform or enabled
/// Cargo features; whether a variant can actually be used is decided at
/// compile time by the `ort-*` features (and, for CoreML, the target OS).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecutionProvider {
    /// CPU execution provider (always available).
    CPU,
    /// CoreML execution provider (macOS only, requires the `ort-coreml` feature).
    CoreML,
    /// CUDA execution provider (requires the `ort-cuda` feature).
    CUDA,
    /// TensorRT execution provider (requires the `ort-tensorrt` feature).
    TensorRT,
}
|
||||
|
||||
impl std::fmt::Display for ExecutionProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ExecutionProvider::CPU => write!(f, "CPU"),
|
||||
ExecutionProvider::CoreML => write!(f, "CoreML"),
|
||||
ExecutionProvider::CUDA => write!(f, "CUDA"),
|
||||
ExecutionProvider::TensorRT => write!(f, "TensorRT"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for ExecutionProvider {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"cpu" => Ok(ExecutionProvider::CPU),
|
||||
"coreml" => Ok(ExecutionProvider::CoreML),
|
||||
"cuda" => Ok(ExecutionProvider::CUDA),
|
||||
"tensorrt" => Ok(ExecutionProvider::TensorRT),
|
||||
_ => Err(format!("Unknown execution provider: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider {
|
||||
/// Returns all available execution providers for the current platform and features
|
||||
pub fn available_providers() -> Vec<ExecutionProvider> {
|
||||
vec![
|
||||
ExecutionProvider::CPU,
|
||||
#[cfg(all(target_os = "macos", feature = "ort-coreml"))]
|
||||
ExecutionProvider::CoreML,
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
ExecutionProvider::CUDA,
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
ExecutionProvider::TensorRT,
|
||||
]
|
||||
}
|
||||
|
||||
/// Check if this execution provider is available on the current platform
|
||||
pub fn is_available(&self) -> bool {
|
||||
match self {
|
||||
ExecutionProvider::CPU => true,
|
||||
ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
|
||||
ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
|
||||
ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionProvider {
|
||||
pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
|
||||
match self {
|
||||
ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
|
||||
ExecutionProvider::CoreML => {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
#[cfg(feature = "ort-coreml")]
|
||||
{
|
||||
Some(CoreMLExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-coreml"))]
|
||||
{
|
||||
tracing::warn!("coreml support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
tracing::warn!("CoreML is only available on macOS");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::CUDA => {
|
||||
#[cfg(feature = "ort-cuda")]
|
||||
{
|
||||
Some(CUDAExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-cuda"))]
|
||||
{
|
||||
tracing::warn!("CUDA support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::TensorRT => {
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
{
|
||||
Some(TensorRTExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-tensorrt"))]
|
||||
{
|
||||
tracing::warn!("TensorRT support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user