feat: Added more ort execution_provider
This commit is contained in:
@@ -46,7 +46,7 @@ pub struct Detect {
|
||||
)]
|
||||
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short,
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
|
||||
@@ -2,6 +2,12 @@
|
||||
use ort::execution_providers::CUDAExecutionProvider;
|
||||
#[cfg(feature = "ort-coreml")]
|
||||
use ort::execution_providers::CoreMLExecutionProvider;
|
||||
#[cfg(feature = "ort-directml")]
|
||||
use ort::execution_providers::DirectMLExecutionProvider;
|
||||
#[cfg(feature = "ort-openvino")]
|
||||
use ort::execution_providers::OpenVINOExecutionProvider;
|
||||
#[cfg(feature = "ort-tvm")]
|
||||
use ort::execution_providers::TVMExecutionProvider;
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
use ort::execution_providers::TensorRTExecutionProvider;
|
||||
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
|
||||
@@ -17,6 +23,12 @@ pub enum ExecutionProvider {
|
||||
CUDA,
|
||||
/// TensorRT execution provider (requires tensorrt feature)
|
||||
TensorRT,
|
||||
/// TVM execution provider (requires tvm feature)
|
||||
TVM,
|
||||
/// OpenVINO execution provider (requires openvino feature)
|
||||
OpenVINO,
|
||||
/// DirectML execution provider (Windows only, requires directml feature)
|
||||
DirectML,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ExecutionProvider {
|
||||
@@ -26,6 +38,9 @@ impl std::fmt::Display for ExecutionProvider {
|
||||
ExecutionProvider::CoreML => write!(f, "CoreML"),
|
||||
ExecutionProvider::CUDA => write!(f, "CUDA"),
|
||||
ExecutionProvider::TensorRT => write!(f, "TensorRT"),
|
||||
ExecutionProvider::TVM => write!(f, "TVM"),
|
||||
ExecutionProvider::OpenVINO => write!(f, "OpenVINO"),
|
||||
ExecutionProvider::DirectML => write!(f, "DirectML"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -39,6 +54,9 @@ impl std::str::FromStr for ExecutionProvider {
|
||||
"coreml" => Ok(ExecutionProvider::CoreML),
|
||||
"cuda" => Ok(ExecutionProvider::CUDA),
|
||||
"tensorrt" => Ok(ExecutionProvider::TensorRT),
|
||||
"tvm" => Ok(ExecutionProvider::TVM),
|
||||
"openvino" => Ok(ExecutionProvider::OpenVINO),
|
||||
"directml" => Ok(ExecutionProvider::DirectML),
|
||||
_ => Err(format!("Unknown execution provider: {}", s)),
|
||||
}
|
||||
}
|
||||
@@ -55,6 +73,12 @@ impl ExecutionProvider {
|
||||
ExecutionProvider::CUDA,
|
||||
#[cfg(feature = "ort-tensorrt")]
|
||||
ExecutionProvider::TensorRT,
|
||||
#[cfg(feature = "ort-tvm")]
|
||||
ExecutionProvider::TVM,
|
||||
#[cfg(feature = "ort-openvino")]
|
||||
ExecutionProvider::OpenVINO,
|
||||
#[cfg(all(target_os = "windows", feature = "ort-directml"))]
|
||||
ExecutionProvider::DirectML,
|
||||
]
|
||||
}
|
||||
|
||||
@@ -65,6 +89,11 @@ impl ExecutionProvider {
|
||||
ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
|
||||
ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
|
||||
ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
|
||||
ExecutionProvider::TVM => cfg!(feature = "ort-tvm"),
|
||||
ExecutionProvider::OpenVINO => cfg!(feature = "ort-openvino"),
|
||||
ExecutionProvider::DirectML => {
|
||||
cfg!(target_os = "windows") && cfg!(feature = "ort-directml")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -82,13 +111,13 @@ impl ExecutionProvider {
|
||||
}
|
||||
#[cfg(not(feature = "ort-coreml"))]
|
||||
{
|
||||
tracing::warn!("coreml support not compiled in");
|
||||
tracing::error!("coreml support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
tracing::warn!("CoreML is only available on macOS");
|
||||
tracing::error!("CoreML is only available on macOS");
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -99,7 +128,7 @@ impl ExecutionProvider {
|
||||
}
|
||||
#[cfg(not(feature = "ort-cuda"))]
|
||||
{
|
||||
tracing::warn!("CUDA support not compiled in");
|
||||
tracing::error!("CUDA support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -110,7 +139,48 @@ impl ExecutionProvider {
|
||||
}
|
||||
#[cfg(not(feature = "ort-tensorrt"))]
|
||||
{
|
||||
tracing::warn!("TensorRT support not compiled in");
|
||||
tracing::error!("TensorRT support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::TVM => {
|
||||
#[cfg(feature = "ort-tvm")]
|
||||
{
|
||||
Some(TVMExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-tvm"))]
|
||||
{
|
||||
tracing::error!("TVM support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::OpenVINO => {
|
||||
#[cfg(feature = "ort-openvino")]
|
||||
{
|
||||
Some(OpenVINOExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-openvino"))]
|
||||
{
|
||||
tracing::error!("OpenVINO support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
ExecutionProvider::DirectML => {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
#[cfg(feature = "ort-directml")]
|
||||
{
|
||||
Some(DirectMLExecutionProvider::default().build())
|
||||
}
|
||||
#[cfg(not(feature = "ort-directml"))]
|
||||
{
|
||||
tracing::error!("DirectML support not compiled in");
|
||||
None
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
tracing::error!("DirectML is only available on Windows");
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user