feat: Added more ort execution_provider
Some checks failed
build / checks-matrix (push) Failing after 19m0s
build / checks-build (push) Has been skipped
build / codecov (push) Failing after 19m3s
docs / docs (push) Failing after 28m31s

This commit is contained in:
uttarayan21
2025-08-18 16:31:16 +05:30
parent 3aa95a2ef5
commit 7fc958b299
8 changed files with 261 additions and 199 deletions

View File

@@ -46,7 +46,7 @@ pub struct Detect {
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
#[clap(
short,
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"

View File

@@ -2,6 +2,12 @@
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(feature = "ort-coreml")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(feature = "ort-directml")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "ort-openvino")]
use ort::execution_providers::OpenVINOExecutionProvider;
#[cfg(feature = "ort-tvm")]
use ort::execution_providers::TVMExecutionProvider;
#[cfg(feature = "ort-tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
@@ -17,6 +23,12 @@ pub enum ExecutionProvider {
CUDA,
/// TensorRT execution provider (requires tensorrt feature)
TensorRT,
/// TVM execution provider (requires tvm feature)
TVM,
/// OpenVINO execution provider (requires openvino feature)
OpenVINO,
/// DirectML execution provider (Windows only, requires directml feature)
DirectML,
}
impl std::fmt::Display for ExecutionProvider {
@@ -26,6 +38,9 @@ impl std::fmt::Display for ExecutionProvider {
ExecutionProvider::CoreML => write!(f, "CoreML"),
ExecutionProvider::CUDA => write!(f, "CUDA"),
ExecutionProvider::TensorRT => write!(f, "TensorRT"),
ExecutionProvider::TVM => write!(f, "TVM"),
ExecutionProvider::OpenVINO => write!(f, "OpenVINO"),
ExecutionProvider::DirectML => write!(f, "DirectML"),
}
}
}
@@ -39,6 +54,9 @@ impl std::str::FromStr for ExecutionProvider {
"coreml" => Ok(ExecutionProvider::CoreML),
"cuda" => Ok(ExecutionProvider::CUDA),
"tensorrt" => Ok(ExecutionProvider::TensorRT),
"tvm" => Ok(ExecutionProvider::TVM),
"openvino" => Ok(ExecutionProvider::OpenVINO),
"directml" => Ok(ExecutionProvider::DirectML),
_ => Err(format!("Unknown execution provider: {}", s)),
}
}
@@ -55,6 +73,12 @@ impl ExecutionProvider {
ExecutionProvider::CUDA,
#[cfg(feature = "ort-tensorrt")]
ExecutionProvider::TensorRT,
#[cfg(feature = "ort-tvm")]
ExecutionProvider::TVM,
#[cfg(feature = "ort-openvino")]
ExecutionProvider::OpenVINO,
#[cfg(all(target_os = "windows", feature = "ort-directml"))]
ExecutionProvider::DirectML,
]
}
@@ -65,6 +89,11 @@ impl ExecutionProvider {
ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
ExecutionProvider::TVM => cfg!(feature = "ort-tvm"),
ExecutionProvider::OpenVINO => cfg!(feature = "ort-openvino"),
ExecutionProvider::DirectML => {
cfg!(target_os = "windows") && cfg!(feature = "ort-directml")
}
}
}
}
@@ -82,13 +111,13 @@ impl ExecutionProvider {
}
#[cfg(not(feature = "ort-coreml"))]
{
tracing::warn!("coreml support not compiled in");
tracing::error!("coreml support not compiled in");
None
}
}
#[cfg(not(target_os = "macos"))]
{
tracing::warn!("CoreML is only available on macOS");
tracing::error!("CoreML is only available on macOS");
None
}
}
@@ -99,7 +128,7 @@ impl ExecutionProvider {
}
#[cfg(not(feature = "ort-cuda"))]
{
tracing::warn!("CUDA support not compiled in");
tracing::error!("CUDA support not compiled in");
None
}
}
@@ -110,7 +139,48 @@ impl ExecutionProvider {
}
#[cfg(not(feature = "ort-tensorrt"))]
{
tracing::warn!("TensorRT support not compiled in");
tracing::error!("TensorRT support not compiled in");
None
}
}
ExecutionProvider::TVM => {
#[cfg(feature = "ort-tvm")]
{
Some(TVMExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-tvm"))]
{
tracing::error!("TVM support not compiled in");
None
}
}
ExecutionProvider::OpenVINO => {
#[cfg(feature = "ort-openvino")]
{
Some(OpenVINOExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-openvino"))]
{
tracing::error!("OpenVINO support not compiled in");
None
}
}
ExecutionProvider::DirectML => {
#[cfg(target_os = "windows")]
{
#[cfg(feature = "ort-directml")]
{
Some(DirectMLExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-directml"))]
{
tracing::error!("DirectML support not compiled in");
None
}
}
#[cfg(not(target_os = "windows"))]
{
tracing::error!("DirectML is only available on Windows");
None
}
}