diff --git a/Cargo.toml b/Cargo.toml
index ccea3d8..a5c263e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -57,3 +57,10 @@
 ort = "2.0.0-rc.10"
 [profile.release]
 debug = true
+
+[features]
+ort-cuda = ["ort/cuda"]
+ort-coreml = ["ort/coreml"]
+ort-tensorrt = ["ort/tensorrt"]
+
+default = ["ort-coreml"]
diff --git a/src/cli.rs b/src/cli.rs
index 482f2cd..e2c78d6 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -1,4 +1,6 @@
 use std::path::PathBuf;
+
+use mnn::ForwardType;
 #[derive(Debug, clap::Parser)]
 pub struct Cli {
     #[clap(subcommand)]
@@ -21,23 +23,10 @@
     Yolo,
 }
 
-#[derive(Debug, clap::ValueEnum, Clone, Copy)]
+#[derive(Debug, Clone)]
 pub enum Executor {
-    Mnn,
-    Onnx,
-}
-
-#[derive(Debug, clap::ValueEnum, Clone, Copy)]
-pub enum OnnxEp {
-    Cpu,
-}
-
-#[derive(Debug, clap::ValueEnum, Clone, Copy)]
-pub enum MnnEp {
-    Cpu,
-    Metal,
-    OpenCL,
-    CoreML,
+    Mnn(mnn::ForwardType),
+    Ort(Vec<crate::ort_ep::ExecutionProvider>),
 }
 
 #[derive(Debug, clap::Args)]
@@ -48,10 +37,21 @@ pub struct Detect {
     pub model_type: Models,
     #[clap(short, long)]
     pub output: Option<PathBuf>,
-    #[clap(short = 'e', long)]
-    pub executor: Option<Executor>,
-    #[clap(short, long, default_value = "cpu")]
-    pub forward_type: mnn::ForwardType,
+    #[clap(
+        short = 'p',
+        long,
+        default_value = "cpu",
+        group = "execution_provider",
+        required_unless_present = "mnn_forward_type"
+    )]
+    pub ort_execution_provider: Vec<crate::ort_ep::ExecutionProvider>,
+    #[clap(
+        short,
+        long,
+        group = "execution_provider",
+        required_unless_present = "ort_execution_provider"
+    )]
+    pub mnn_forward_type: Option<ForwardType>,
     #[clap(short, long, default_value_t = 0.8)]
     pub threshold: f32,
     #[clap(short, long, default_value_t = 0.3)]
diff --git a/src/facedet/retinaface/ort.rs b/src/facedet/retinaface/ort.rs
index 6349423..afd33ac 100644
--- a/src/facedet/retinaface/ort.rs
+++ b/src/facedet/retinaface/ort.rs
@@ -1,11 +1,10 @@
 use crate::errors::*;
 use crate::facedet::*;
+use crate::ort_ep::*;
 use error_stack::ResultExt;
 use ndarray_resize::NdFir;
 use ort::{
-    execution_providers::{
-        CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
-    },
+    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
     session::{Session, builder::GraphOptimizationLevel},
     value::Tensor,
 };
@@ -33,18 +32,11 @@ impl FaceDetectionBuilder {
         })
     }
 
-    pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
+    pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
         let execution_providers: Vec<ExecutionProviderDispatch> = providers
-            .into_iter()
-            .filter_map(|provider| match provider.as_str() {
-                "cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
-                #[cfg(target_os = "macos")]
-                "coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
-                _ => {
-                    tracing::warn!("Unknown execution provider: {}", provider);
-                    None
-                }
-            })
+            .as_ref()
+            .iter()
+            .filter_map(|provider| provider.to_dispatch())
             .collect();
 
         if !execution_providers.is_empty() {
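Note on the builder change above: `with_execution_providers` now takes typed values instead of `Vec<String>`, so arrays, slices, and `Vec`s all work via `impl AsRef<[ExecutionProvider]>`, and the per-call-site string matching (with its inline `#[cfg]` handling) collapses into `ExecutionProvider::to_dispatch()`. Whether CUDA/CoreML/TensorRT can actually dispatch is governed by the new Cargo features, e.g. `cargo build --no-default-features --features ort-cuda`. A minimal sketch of the new call shape, assuming the `Result<T>` alias from `errors` that `main.rs` implies (`build_ort_detector` is a hypothetical helper, not part of this diff):

```rust
use crate::errors::*;
use crate::facedet;
use crate::ort_ep::ExecutionProvider;
use error_stack::ResultExt;

// Hypothetical helper mirroring the builder chain in `src/main.rs`.
fn build_ort_detector(model: &[u8]) -> Result<facedet::retinaface::ort::FaceDetection> {
    facedet::retinaface::ort::FaceDetection::builder(model)
        .change_context(Error)?
        // Arrays, slices, and `Vec`s all satisfy `impl AsRef<[ExecutionProvider]>`;
        // entries whose `to_dispatch()` returns `None` are filtered out.
        .with_execution_providers([ExecutionProvider::CUDA, ExecutionProvider::CPU])
        .build()
        .change_context(Error)
}
```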
diff --git a/src/faceembed/facenet/ort.rs b/src/faceembed/facenet/ort.rs
index 3db56a2..1cd272f 100644
--- a/src/faceembed/facenet/ort.rs
+++ b/src/faceembed/facenet/ort.rs
@@ -1,11 +1,10 @@
 use crate::errors::*;
 use crate::faceembed::facenet::FaceNetEmbedder;
+use crate::ort_ep::*;
 use error_stack::ResultExt;
 use ndarray::{Array2, ArrayView4};
 use ort::{
-    execution_providers::{
-        CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
-    },
+    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
     session::{Session, builder::GraphOptimizationLevel},
     value::Tensor,
 };
@@ -33,18 +32,11 @@ impl EmbeddingGeneratorBuilder {
         })
     }
 
-    pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
+    pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
         let execution_providers: Vec<ExecutionProviderDispatch> = providers
-            .into_iter()
-            .filter_map(|provider| match provider.as_str() {
-                "cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
-                #[cfg(target_os = "macos")]
-                "coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
-                _ => {
-                    tracing::warn!("Unknown execution provider: {}", provider);
-                    None
-                }
-            })
+            .as_ref()
+            .iter()
+            .filter_map(|provider| provider.to_dispatch())
             .collect();
 
         if !execution_providers.is_empty() {
diff --git a/src/lib.rs b/src/lib.rs
index 965eab5..14dc5c0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,4 +2,6 @@ pub mod errors;
 pub mod facedet;
 pub mod faceembed;
 pub mod image;
+pub mod ort_ep;
+
 use errors::*;
diff --git a/src/main.rs b/src/main.rs
index 2546d5f..200fdcf 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -11,7 +11,7 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
 const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
 const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
 const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
-const CHUNK_SIZE: usize = 8;
+const CHUNK_SIZE: usize = 2;
 pub fn main() -> Result<()> {
     tracing_subscriber::fmt()
         .with_env_filter("trace")
@@ -23,37 +23,52 @@
     match args.cmd {
         cli::SubCommand::Detect(detect) => {
             // Choose backend based on executor type (defaulting to MNN for backward compatibility)
-            let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
+
+            let executor = detect
+                .mnn_forward_type
+                .map(|f| cli::Executor::Mnn(f))
+                .or_else(|| {
+                    if detect.ort_execution_provider.is_empty() {
+                        None
+                    } else {
+                        Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
+                    }
+                })
+                .unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
+            // .then_some(cli::Executor::Mnn)
+            // .unwrap_or(cli::Executor::Ort);
             match executor {
-                cli::Executor::Mnn => {
+                cli::Executor::Mnn(forward) => {
                     let retinaface =
                         facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
                             .change_context(Error)?
-                            .with_forward_type(detect.forward_type)
+                            .with_forward_type(forward)
                             .build()
                             .change_context(errors::Error)
                             .attach_printable("Failed to create face detection model")?;
 
                     let facenet =
                         faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
                             .change_context(Error)?
-                            .with_forward_type(detect.forward_type)
+                            .with_forward_type(forward)
                             .build()
                             .change_context(errors::Error)
                             .attach_printable("Failed to create face embedding model")?;
                     run_detection(detect, retinaface, facenet)?;
                 }
-                cli::Executor::Onnx => {
+                cli::Executor::Ort(ep) => {
                     let retinaface =
                         facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
                             .change_context(Error)?
+                            .with_execution_providers(&ep)
                             .build()
                             .change_context(errors::Error)
                             .attach_printable("Failed to create face detection model")?;
 
                     let facenet =
                         faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
                             .change_context(Error)?
+                            .with_execution_providers(ep)
                             .build()
                             .change_context(errors::Error)
                             .attach_printable("Failed to create face embedding model")?;
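One behavioral wrinkle in the selection chain above: `ort_execution_provider` carries `default_value = "cpu"`, so the list is never empty, the trailing `unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU))` can never fire, and a bare `detect` invocation now resolves to ort-on-CPU rather than MNN, despite the surviving "defaulting to MNN" comment. A sketch that spells the intended precedence out (`resolve_executor` is hypothetical, not part of this diff):

```rust
use crate::cli;
use crate::ort_ep::ExecutionProvider;

// Hypothetical helper with the same precedence as the chain in `main`:
// an explicit MNN forward type wins, then any ort providers, then MNN/CPU.
fn resolve_executor(
    mnn_forward_type: Option<mnn::ForwardType>,
    ort_execution_provider: &[ExecutionProvider],
) -> cli::Executor {
    match (mnn_forward_type, ort_execution_provider) {
        (Some(forward), _) => cli::Executor::Mnn(forward),
        (None, providers) if !providers.is_empty() => cli::Executor::Ort(providers.to_vec()),
        // Only reachable if the clap default is removed; today `detect`
        // without flags lands in the arm above with `["cpu"]`.
        (None, _) => cli::Executor::Mnn(mnn::ForwardType::CPU),
    }
}
```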
diff --git a/src/ort_ep.rs b/src/ort_ep.rs
new file mode 100644
index 0000000..b30ef3a
--- /dev/null
+++ b/src/ort_ep.rs
@@ -0,0 +1,119 @@
+#[cfg(feature = "ort-cuda")]
+use ort::execution_providers::CUDAExecutionProvider;
+#[cfg(feature = "ort-coreml")]
+use ort::execution_providers::CoreMLExecutionProvider;
+#[cfg(feature = "ort-tensorrt")]
+use ort::execution_providers::TensorRTExecutionProvider;
+use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
+
+/// Supported execution providers for ONNX Runtime
+#[derive(Debug, Clone)]
+pub enum ExecutionProvider {
+    /// CPU execution provider (always available)
+    CPU,
+    /// CoreML execution provider (macOS only)
+    CoreML,
+    /// CUDA execution provider (requires cuda feature)
+    CUDA,
+    /// TensorRT execution provider (requires tensorrt feature)
+    TensorRT,
+}
+
+impl std::fmt::Display for ExecutionProvider {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ExecutionProvider::CPU => write!(f, "CPU"),
+            ExecutionProvider::CoreML => write!(f, "CoreML"),
+            ExecutionProvider::CUDA => write!(f, "CUDA"),
+            ExecutionProvider::TensorRT => write!(f, "TensorRT"),
+        }
+    }
+}
+
+impl std::str::FromStr for ExecutionProvider {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "cpu" => Ok(ExecutionProvider::CPU),
+            "coreml" => Ok(ExecutionProvider::CoreML),
+            "cuda" => Ok(ExecutionProvider::CUDA),
+            "tensorrt" => Ok(ExecutionProvider::TensorRT),
+            _ => Err(format!("Unknown execution provider: {}", s)),
+        }
+    }
+}
+
+impl ExecutionProvider {
+    /// Returns all available execution providers for the current platform and features
+    pub fn available_providers() -> Vec<ExecutionProvider> {
+        vec![
+            ExecutionProvider::CPU,
+            #[cfg(all(target_os = "macos", feature = "ort-coreml"))]
+            ExecutionProvider::CoreML,
+            #[cfg(feature = "ort-cuda")]
+            ExecutionProvider::CUDA,
+            #[cfg(feature = "ort-tensorrt")]
+            ExecutionProvider::TensorRT,
+        ]
+    }
+
+    /// Check if this execution provider is available on the current platform
+    pub fn is_available(&self) -> bool {
+        match self {
+            ExecutionProvider::CPU => true,
+            ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
+            ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
+            ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
+        }
+    }
+}
+
+impl ExecutionProvider {
+    pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
+        match self {
+            ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
+            ExecutionProvider::CoreML => {
+                #[cfg(target_os = "macos")]
+                {
+                    #[cfg(feature = "ort-coreml")]
+                    {
+                        Some(CoreMLExecutionProvider::default().build())
+                    }
+                    #[cfg(not(feature = "ort-coreml"))]
+                    {
+                        tracing::warn!("coreml support not compiled in");
+                        None
+                    }
+                }
+                #[cfg(not(target_os = "macos"))]
+                {
+                    tracing::warn!("CoreML is only available on macOS");
+                    None
+                }
+            }
+            ExecutionProvider::CUDA => {
+                #[cfg(feature = "ort-cuda")]
+                {
+                    Some(CUDAExecutionProvider::default().build())
+                }
+                #[cfg(not(feature = "ort-cuda"))]
+                {
+                    tracing::warn!("CUDA support not compiled in");
+                    None
+                }
+            }
+            ExecutionProvider::TensorRT => {
+                #[cfg(feature = "ort-tensorrt")]
+                {
+                    Some(TensorRTExecutionProvider::default().build())
+                }
+                #[cfg(not(feature = "ort-tensorrt"))]
+                {
+                    tracing::warn!("TensorRT support not compiled in");
+                    None
+                }
+            }
+        }
+    }
+}
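Taken together, `FromStr` covers the clap side and `to_dispatch()` covers the ort side, with graceful degradation in between; a short walkthrough as if written inside this crate (`provider_funnel` is a hypothetical illustration):

```rust
use std::str::FromStr;

use crate::ort_ep::ExecutionProvider;
use ort::execution_providers::ExecutionProviderDispatch;

// Hypothetical end-to-end flow from CLI strings to ort dispatch handles.
fn provider_funnel() -> Vec<ExecutionProviderDispatch> {
    // `from_str` lowercases its input, so "CUDA", "Cuda", and "cuda" all parse.
    let requested: Vec<ExecutionProvider> = ["tensorrt", "cuda", "cpu"]
        .iter()
        .map(|s| ExecutionProvider::from_str(s).expect("known provider name"))
        .collect();

    // Providers that were not compiled in (or CoreML off macOS) log a warning
    // and yield `None`, so the session builders only ever see usable ones;
    // CPU always survives.
    requested.iter().filter_map(|p| p.to_dispatch()).collect()
}
```

`available_providers()` and `is_available()` look ready-made for `--help` output or up-front validation of the parsed list, though nothing in this diff calls them yet.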