feat: Added cli features for mnn and ort
This commit is contained in:
@@ -57,3 +57,10 @@ ort = "2.0.0-rc.10"
|
|||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = true
|
debug = true
|
||||||
|
|
||||||
|
[features]
|
||||||
|
ort-cuda = ["ort/cuda"]
|
||||||
|
ort-coreml = ["ort/coreml"]
|
||||||
|
ort-tensorrt = ["ort/tensorrt"]
|
||||||
|
|
||||||
|
default = ["ort-coreml"]
|
||||||
|
|||||||
40
src/cli.rs
40
src/cli.rs
@@ -1,4 +1,6 @@
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use mnn::ForwardType;
|
||||||
#[derive(Debug, clap::Parser)]
|
#[derive(Debug, clap::Parser)]
|
||||||
pub struct Cli {
|
pub struct Cli {
|
||||||
#[clap(subcommand)]
|
#[clap(subcommand)]
|
||||||
@@ -21,23 +23,10 @@ pub enum Models {
|
|||||||
Yolo,
|
Yolo,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Executor {
|
pub enum Executor {
|
||||||
Mnn,
|
Mnn(mnn::ForwardType),
|
||||||
Onnx,
|
Ort(Vec<detector::ort_ep::ExecutionProvider>),
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
|
||||||
pub enum OnnxEp {
|
|
||||||
Cpu,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
|
||||||
pub enum MnnEp {
|
|
||||||
Cpu,
|
|
||||||
Metal,
|
|
||||||
OpenCL,
|
|
||||||
CoreML,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, clap::Args)]
|
#[derive(Debug, clap::Args)]
|
||||||
@@ -48,10 +37,21 @@ pub struct Detect {
|
|||||||
pub model_type: Models,
|
pub model_type: Models,
|
||||||
#[clap(short, long)]
|
#[clap(short, long)]
|
||||||
pub output: Option<PathBuf>,
|
pub output: Option<PathBuf>,
|
||||||
#[clap(short = 'e', long)]
|
#[clap(
|
||||||
pub executor: Option<Executor>,
|
short = 'p',
|
||||||
#[clap(short, long, default_value = "cpu")]
|
long,
|
||||||
pub forward_type: mnn::ForwardType,
|
default_value = "cpu",
|
||||||
|
group = "execution_provider",
|
||||||
|
required_unless_present = "mnn_forward_type"
|
||||||
|
)]
|
||||||
|
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||||
|
#[clap(
|
||||||
|
short,
|
||||||
|
long,
|
||||||
|
group = "execution_provider",
|
||||||
|
required_unless_present = "ort_execution_provider"
|
||||||
|
)]
|
||||||
|
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||||
#[clap(short, long, default_value_t = 0.8)]
|
#[clap(short, long, default_value_t = 0.8)]
|
||||||
pub threshold: f32,
|
pub threshold: f32,
|
||||||
#[clap(short, long, default_value_t = 0.3)]
|
#[clap(short, long, default_value_t = 0.3)]
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use crate::facedet::*;
|
use crate::facedet::*;
|
||||||
|
use crate::ort_ep::*;
|
||||||
use error_stack::ResultExt;
|
use error_stack::ResultExt;
|
||||||
use ndarray_resize::NdFir;
|
use ndarray_resize::NdFir;
|
||||||
use ort::{
|
use ort::{
|
||||||
execution_providers::{
|
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
|
||||||
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
|
||||||
},
|
|
||||||
session::{Session, builder::GraphOptimizationLevel},
|
session::{Session, builder::GraphOptimizationLevel},
|
||||||
value::Tensor,
|
value::Tensor,
|
||||||
};
|
};
|
||||||
@@ -33,18 +32,11 @@ impl FaceDetectionBuilder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
|
||||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||||
.into_iter()
|
.as_ref()
|
||||||
.filter_map(|provider| match provider.as_str() {
|
.iter()
|
||||||
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
.filter_map(|provider| provider.to_dispatch())
|
||||||
#[cfg(target_os = "macos")]
|
|
||||||
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
|
||||||
_ => {
|
|
||||||
tracing::warn!("Unknown execution provider: {}", provider);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if !execution_providers.is_empty() {
|
if !execution_providers.is_empty() {
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use crate::faceembed::facenet::FaceNetEmbedder;
|
use crate::faceembed::facenet::FaceNetEmbedder;
|
||||||
|
use crate::ort_ep::*;
|
||||||
use error_stack::ResultExt;
|
use error_stack::ResultExt;
|
||||||
use ndarray::{Array2, ArrayView4};
|
use ndarray::{Array2, ArrayView4};
|
||||||
use ort::{
|
use ort::{
|
||||||
execution_providers::{
|
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
|
||||||
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
|
||||||
},
|
|
||||||
session::{Session, builder::GraphOptimizationLevel},
|
session::{Session, builder::GraphOptimizationLevel},
|
||||||
value::Tensor,
|
value::Tensor,
|
||||||
};
|
};
|
||||||
@@ -33,18 +32,11 @@ impl EmbeddingGeneratorBuilder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
|
||||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||||
.into_iter()
|
.as_ref()
|
||||||
.filter_map(|provider| match provider.as_str() {
|
.iter()
|
||||||
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
.filter_map(|provider| provider.to_dispatch())
|
||||||
#[cfg(target_os = "macos")]
|
|
||||||
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
|
||||||
_ => {
|
|
||||||
tracing::warn!("Unknown execution provider: {}", provider);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if !execution_providers.is_empty() {
|
if !execution_providers.is_empty() {
|
||||||
|
|||||||
@@ -2,4 +2,6 @@ pub mod errors;
|
|||||||
pub mod facedet;
|
pub mod facedet;
|
||||||
pub mod faceembed;
|
pub mod faceembed;
|
||||||
pub mod image;
|
pub mod image;
|
||||||
|
pub mod ort_ep;
|
||||||
|
|
||||||
use errors::*;
|
use errors::*;
|
||||||
|
|||||||
27
src/main.rs
27
src/main.rs
@@ -11,7 +11,7 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
|
|||||||
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
|
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||||
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
|
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
|
||||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||||
const CHUNK_SIZE: usize = 8;
|
const CHUNK_SIZE: usize = 2;
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter("trace")
|
.with_env_filter("trace")
|
||||||
@@ -23,37 +23,52 @@ pub fn main() -> Result<()> {
|
|||||||
match args.cmd {
|
match args.cmd {
|
||||||
cli::SubCommand::Detect(detect) => {
|
cli::SubCommand::Detect(detect) => {
|
||||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||||
let executor = detect.executor.unwrap_or(cli::Executor::Mnn);
|
|
||||||
|
let executor = detect
|
||||||
|
.mnn_forward_type
|
||||||
|
.map(|f| cli::Executor::Mnn(f))
|
||||||
|
.or_else(|| {
|
||||||
|
if detect.ort_execution_provider.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||||
|
// .then_some(cli::Executor::Mnn)
|
||||||
|
// .unwrap_or(cli::Executor::Ort);
|
||||||
|
|
||||||
match executor {
|
match executor {
|
||||||
cli::Executor::Mnn => {
|
cli::Executor::Mnn(forward) => {
|
||||||
let retinaface =
|
let retinaface =
|
||||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
.with_forward_type(detect.forward_type)
|
.with_forward_type(forward)
|
||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face detection model")?;
|
.attach_printable("Failed to create face detection model")?;
|
||||||
let facenet =
|
let facenet =
|
||||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
.with_forward_type(detect.forward_type)
|
.with_forward_type(forward)
|
||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face embedding model")?;
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|
||||||
run_detection(detect, retinaface, facenet)?;
|
run_detection(detect, retinaface, facenet)?;
|
||||||
}
|
}
|
||||||
cli::Executor::Onnx => {
|
cli::Executor::Ort(ep) => {
|
||||||
let retinaface =
|
let retinaface =
|
||||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
|
.with_execution_providers(&ep)
|
||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face detection model")?;
|
.attach_printable("Failed to create face detection model")?;
|
||||||
let facenet =
|
let facenet =
|
||||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
|
.with_execution_providers(ep)
|
||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face embedding model")?;
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
|
|||||||
119
src/ort_ep.rs
Normal file
119
src/ort_ep.rs
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
use ort::execution_providers::CUDAExecutionProvider;
|
||||||
|
#[cfg(feature = "ort-coreml")]
|
||||||
|
use ort::execution_providers::CoreMLExecutionProvider;
|
||||||
|
#[cfg(feature = "ort-tensorrt")]
|
||||||
|
use ort::execution_providers::TensorRTExecutionProvider;
|
||||||
|
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
|
||||||
|
|
||||||
|
/// Supported execution providers for ONNX Runtime
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ExecutionProvider {
|
||||||
|
/// CPU execution provider (always available)
|
||||||
|
CPU,
|
||||||
|
/// CoreML execution provider (macOS only)
|
||||||
|
CoreML,
|
||||||
|
/// CUDA execution provider (requires cuda feature)
|
||||||
|
CUDA,
|
||||||
|
/// TensorRT execution provider (requires tensorrt feature)
|
||||||
|
TensorRT,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ExecutionProvider {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
ExecutionProvider::CPU => write!(f, "CPU"),
|
||||||
|
ExecutionProvider::CoreML => write!(f, "CoreML"),
|
||||||
|
ExecutionProvider::CUDA => write!(f, "CUDA"),
|
||||||
|
ExecutionProvider::TensorRT => write!(f, "TensorRT"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::str::FromStr for ExecutionProvider {
|
||||||
|
type Err = String;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"cpu" => Ok(ExecutionProvider::CPU),
|
||||||
|
"coreml" => Ok(ExecutionProvider::CoreML),
|
||||||
|
"cuda" => Ok(ExecutionProvider::CUDA),
|
||||||
|
"tensorrt" => Ok(ExecutionProvider::TensorRT),
|
||||||
|
_ => Err(format!("Unknown execution provider: {}", s)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExecutionProvider {
|
||||||
|
/// Returns all available execution providers for the current platform and features
|
||||||
|
pub fn available_providers() -> Vec<ExecutionProvider> {
|
||||||
|
vec![
|
||||||
|
ExecutionProvider::CPU,
|
||||||
|
#[cfg(all(target_os = "macos", feature = "ort-coreml"))]
|
||||||
|
ExecutionProvider::CoreML,
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
ExecutionProvider::CUDA,
|
||||||
|
#[cfg(feature = "ort-tensorrt")]
|
||||||
|
ExecutionProvider::TensorRT,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this execution provider is available on the current platform
|
||||||
|
pub fn is_available(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
ExecutionProvider::CPU => true,
|
||||||
|
ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
|
||||||
|
ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
|
||||||
|
ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExecutionProvider {
|
||||||
|
pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
|
||||||
|
match self {
|
||||||
|
ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
|
||||||
|
ExecutionProvider::CoreML => {
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
{
|
||||||
|
#[cfg(feature = "ort-coreml")]
|
||||||
|
{
|
||||||
|
Some(CoreMLExecutionProvider::default().build())
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "ort-coreml"))]
|
||||||
|
{
|
||||||
|
tracing::warn!("coreml support not compiled in");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
{
|
||||||
|
tracing::warn!("CoreML is only available on macOS");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExecutionProvider::CUDA => {
|
||||||
|
#[cfg(feature = "ort-cuda")]
|
||||||
|
{
|
||||||
|
Some(CUDAExecutionProvider::default().build())
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "ort-cuda"))]
|
||||||
|
{
|
||||||
|
tracing::warn!("CUDA support not compiled in");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExecutionProvider::TensorRT => {
|
||||||
|
#[cfg(feature = "ort-tensorrt")]
|
||||||
|
{
|
||||||
|
Some(TensorRTExecutionProvider::default().build())
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "ort-tensorrt"))]
|
||||||
|
{
|
||||||
|
tracing::warn!("TensorRT support not compiled in");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user