265 lines
8.9 KiB
Rust
265 lines
8.9 KiB
Rust
use crate::errors::*;
|
|
use crate::facedet::postprocess::*;
|
|
use error_stack::ResultExt;
|
|
use ndarray_resize::NdFir;
|
|
use ort::{
|
|
execution_providers::{
|
|
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
|
},
|
|
session::{Session, builder::GraphOptimizationLevel},
|
|
value::Tensor,
|
|
};
|
|
use std::path::Path;
|
|
|
|
#[derive(Debug)]
|
|
pub struct FaceDetection {
|
|
session: Session,
|
|
}
|
|
|
|
pub struct FaceDetectionBuilder {
|
|
model_data: Vec<u8>,
|
|
execution_providers: Option<Vec<ExecutionProviderDispatch>>,
|
|
intra_threads: Option<usize>,
|
|
inter_threads: Option<usize>,
|
|
}
|
|
|
|
impl FaceDetectionBuilder {
|
|
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
|
Ok(Self {
|
|
model_data: model.as_ref().to_vec(),
|
|
execution_providers: None,
|
|
intra_threads: None,
|
|
inter_threads: None,
|
|
})
|
|
}
|
|
|
|
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
|
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
|
.into_iter()
|
|
.filter_map(|provider| match provider.as_str() {
|
|
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
|
#[cfg(target_os = "macos")]
|
|
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
|
_ => {
|
|
tracing::warn!("Unknown execution provider: {}", provider);
|
|
None
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
if !execution_providers.is_empty() {
|
|
self.execution_providers = Some(execution_providers);
|
|
} else {
|
|
tracing::warn!("No valid execution providers found, falling back to CPU");
|
|
self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn with_intra_threads(mut self, threads: usize) -> Self {
|
|
self.intra_threads = Some(threads);
|
|
self
|
|
}
|
|
|
|
pub fn with_inter_threads(mut self, threads: usize) -> Self {
|
|
self.inter_threads = Some(threads);
|
|
self
|
|
}
|
|
|
|
pub fn build(self) -> crate::errors::Result<FaceDetection> {
|
|
let mut session_builder = Session::builder()
|
|
.change_context(Error)
|
|
.attach_printable("Failed to create session builder")?;
|
|
|
|
// Set execution providers
|
|
if let Some(providers) = self.execution_providers {
|
|
session_builder = session_builder
|
|
.with_execution_providers(providers)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to set execution providers")?;
|
|
} else {
|
|
// Default to CPU
|
|
session_builder = session_builder
|
|
.with_execution_providers([CPUExecutionProvider::default().build()])
|
|
.change_context(Error)
|
|
.attach_printable("Failed to set default CPU execution provider")?;
|
|
}
|
|
|
|
// Set threading options
|
|
if let Some(threads) = self.intra_threads {
|
|
session_builder = session_builder
|
|
.with_intra_threads(threads)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to set intra threads")?;
|
|
}
|
|
|
|
if let Some(threads) = self.inter_threads {
|
|
session_builder = session_builder
|
|
.with_inter_threads(threads)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to set inter threads")?;
|
|
}
|
|
|
|
// Set optimization level
|
|
session_builder = session_builder
|
|
.with_optimization_level(GraphOptimizationLevel::Level3)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to set optimization level")?;
|
|
|
|
// Create session from model bytes
|
|
let session = session_builder
|
|
.commit_from_memory(&self.model_data)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to create ORT session from model bytes")?;
|
|
|
|
tracing::info!("Successfully created ORT RetinaFace session");
|
|
|
|
Ok(FaceDetection { session })
|
|
}
|
|
}
|
|
|
|
impl FaceDetection {
|
|
pub fn builder<T: AsRef<[u8]>>()
|
|
-> fn(T) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>>
|
|
{
|
|
FaceDetectionBuilder::new
|
|
}
|
|
|
|
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
|
|
let model = std::fs::read(path)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to read model file")?;
|
|
Self::new_from_bytes(&model)
|
|
}
|
|
|
|
pub fn new_from_bytes(model: &[u8]) -> crate::errors::Result<Self> {
|
|
tracing::info!("Loading ORT RetinaFace model from bytes");
|
|
Self::builder()(model)?.build()
|
|
}
|
|
}
|
|
|
|
impl FaceDetector for FaceDetection {
|
|
fn run_model(
|
|
&mut self,
|
|
image: ndarray::ArrayView3<u8>,
|
|
) -> crate::errors::Result<FaceDetectionModelOutput> {
|
|
// Resize image to 1024x1024
|
|
let mut resized = image
|
|
.fast_resize(1024, 1024, None)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to resize image")?
|
|
.mapv(|f| f as f32);
|
|
|
|
// Apply mean subtraction: [104, 117, 123] for BGR format
|
|
resized
|
|
.axis_iter_mut(ndarray::Axis(2))
|
|
.zip([104.0, 117.0, 123.0])
|
|
.for_each(|(mut array, mean)| {
|
|
array.map_inplace(|v| *v -= mean);
|
|
});
|
|
|
|
// Convert from HWC to NCHW format (add batch dimension and transpose)
|
|
let input_tensor = resized
|
|
.permuted_axes((2, 0, 1))
|
|
.insert_axis(ndarray::Axis(0))
|
|
.as_standard_layout()
|
|
.into_owned();
|
|
|
|
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
|
|
|
// Create ORT input tensor
|
|
let input_value = Tensor::from_array(input_tensor)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to create input tensor")?;
|
|
|
|
// Run inference
|
|
tracing::debug!("Running ORT RetinaFace inference");
|
|
let outputs = self
|
|
.session
|
|
.run(ort::inputs!["input" => input_value])
|
|
.change_context(Error)
|
|
.attach_printable("Failed to run inference")?;
|
|
|
|
// Extract outputs by name
|
|
let bbox_output = outputs
|
|
.get("bbox")
|
|
.ok_or(Error)
|
|
.attach_printable("Missing bbox output from model")?
|
|
.try_extract_tensor::<f32>()
|
|
.change_context(Error)
|
|
.attach_printable("Failed to extract bbox tensor")?;
|
|
|
|
let confidence_output = outputs
|
|
.get("confidence")
|
|
.ok_or(Error)
|
|
.attach_printable("Missing confidence output from model")?
|
|
.try_extract_tensor::<f32>()
|
|
.change_context(Error)
|
|
.attach_printable("Failed to extract confidence tensor")?;
|
|
|
|
let landmark_output = outputs
|
|
.get("landmark")
|
|
.ok_or(Error)
|
|
.attach_printable("Missing landmark output from model")?
|
|
.try_extract_tensor::<f32>()
|
|
.change_context(Error)
|
|
.attach_printable("Failed to extract landmark tensor")?;
|
|
|
|
// Get tensor shapes and data
|
|
let (bbox_shape, bbox_data) = bbox_output;
|
|
let (confidence_shape, confidence_data) = confidence_output;
|
|
let (landmark_shape, landmark_data) = landmark_output;
|
|
|
|
tracing::trace!(
|
|
"Output shapes - bbox: {:?}, confidence: {:?}, landmark: {:?}",
|
|
bbox_shape,
|
|
confidence_shape,
|
|
landmark_shape
|
|
);
|
|
|
|
// Convert to ndarray format
|
|
let bbox_dims = bbox_shape.as_ref();
|
|
let confidence_dims = confidence_shape.as_ref();
|
|
let landmark_dims = landmark_shape.as_ref();
|
|
|
|
let bbox_array = ndarray::Array3::from_shape_vec(
|
|
(
|
|
bbox_dims[0] as usize,
|
|
bbox_dims[1] as usize,
|
|
bbox_dims[2] as usize,
|
|
),
|
|
bbox_data.to_vec(),
|
|
)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to create bbox ndarray")?;
|
|
|
|
let confidence_array = ndarray::Array3::from_shape_vec(
|
|
(
|
|
confidence_dims[0] as usize,
|
|
confidence_dims[1] as usize,
|
|
confidence_dims[2] as usize,
|
|
),
|
|
confidence_data.to_vec(),
|
|
)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to create confidence ndarray")?;
|
|
|
|
let landmark_array = ndarray::Array3::from_shape_vec(
|
|
(
|
|
landmark_dims[0] as usize,
|
|
landmark_dims[1] as usize,
|
|
landmark_dims[2] as usize,
|
|
),
|
|
landmark_data.to_vec(),
|
|
)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to create landmark ndarray")?;
|
|
|
|
Ok(FaceDetectionModelOutput {
|
|
bbox: bbox_array,
|
|
confidence: confidence_array,
|
|
landmark: landmark_array,
|
|
})
|
|
}
|
|
}
|