feat: Changed the struct for retinaface
Some checks failed
build / checks-matrix (push) Has been cancelled
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-18 11:59:09 +05:30
parent 33afbfc2b8
commit 050e937408
8 changed files with 27 additions and 48 deletions

View File

@@ -0,0 +1,174 @@
use crate::errors::*;
use crate::facedet::*;
use error_stack::ResultExt;
use mnn_bridge::ndarray::*;
use ndarray_resize::NdFir;
use std::path::Path;
#[derive(Debug)]
pub struct FaceDetection {
handle: mnn_sync::SessionHandle,
}
pub struct FaceDetectionBuilder {
schedule_config: Option<mnn::ScheduleConfig>,
backend_config: Option<mnn::BackendConfig>,
model: mnn::Interpreter,
}
impl FaceDetectionBuilder {
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
Ok(Self {
schedule_config: None,
backend_config: None,
model: mnn::Interpreter::from_bytes(model.as_ref())
.map_err(|e| e.into_inner())
.change_context(Error)
.attach_printable("Failed to load model from bytes")?,
})
}
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
self.schedule_config
.get_or_insert_default()
.set_type(forward_type);
self
}
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
self.schedule_config = Some(config);
self
}
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
self.backend_config = Some(config);
self
}
pub fn build(self) -> Result<FaceDetection> {
let model = self.model;
let sc = self.schedule_config.unwrap_or_default();
let handle = mnn_sync::SessionHandle::new(model, sc)
.change_context(Error)
.attach_printable("Failed to create session handle")?;
Ok(FaceDetection { handle })
}
}
impl FaceDetection {
pub fn builder<T: AsRef<[u8]>>()
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
FaceDetectionBuilder::new
}
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
.attach_printable("Failed to read model file")?;
Self::new_from_bytes(&model)
}
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
tracing::info!("Loading face detection model from bytes");
let mut model = mnn::Interpreter::from_bytes(model)
.map_err(|e| e.into_inner())
.change_context(Error)
.attach_printable("Failed to load model from bytes")?;
model.set_session_mode(mnn::SessionMode::Release);
model
.set_cache_file("retinaface.cache", 128)
.change_context(Error)
.attach_printable("Failed to set cache file")?;
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
let sc = mnn::ScheduleConfig::new()
.with_type(mnn::ForwardType::Metal)
.with_backend_config(bc);
tracing::info!("Creating session handle for face detection model");
let handle = mnn_sync::SessionHandle::new(model, sc)
.change_context(Error)
.attach_printable("Failed to create session handle")?;
Ok(FaceDetection { handle })
}
}
impl FaceDetector for FaceDetection {
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
#[rustfmt::skip]
let mut resized = image
.fast_resize(1024, 1024, None)
.change_context(Error)?
.mapv(|f| f as f32);
// Apply mean subtraction: [104, 117, 123]
resized
.axis_iter_mut(ndarray::Axis(2))
.zip([104, 117, 123])
.for_each(|(mut array, pixel)| {
let pixel = pixel as f32;
array.map_inplace(|v| *v -= pixel);
});
let mut resized = resized
.permuted_axes((2, 0, 1))
.insert_axis(ndarray::Axis(0))
.as_standard_layout()
.into_owned();
use ::tap::*;
let output = self
.handle
.run(move |sr| {
let tensor = resized
.as_mnn_tensor_mut()
.attach_printable("Failed to convert ndarray to mnn tensor")
.change_context(mnn::error::ErrorKind::TensorError)?;
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
let (intptr, session) = sr.both_mut();
tracing::trace!("Copying input tensor to host");
unsafe {
let mut input = intptr.input_unresized::<f32>(session, "input")?;
tracing::trace!("Input shape: {:?}", input.shape());
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
input.view_mut(),
1,
3,
1024,
1024,
);
}
intptr.resize_session(session);
let mut input = intptr.input::<f32>(session, "input")?;
tracing::trace!("Input shape: {:?}", input.shape());
input.copy_from_host_tensor(tensor.view())?;
tracing::info!("Running face detection session");
intptr.run_session(&session)?;
let output_tensor = intptr
.output::<f32>(&session, "bbox")?
.create_host_tensor_from_device(true)
.as_ndarray()
.to_owned();
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
let output_confidence = intptr
.output::<f32>(&session, "confidence")?
.create_host_tensor_from_device(true)
.as_ndarray::<ndarray::Ix3>()
.to_owned();
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
let output_landmark = intptr
.output::<f32>(&session, "landmark")?
.create_host_tensor_from_device(true)
.as_ndarray::<ndarray::Ix3>()
.to_owned();
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
Ok(FaceDetectionModelOutput {
bbox: output_tensor,
confidence: output_confidence,
landmark: output_landmark,
})
})
.map_err(|e| e.into_inner())
.change_context(Error)?;
Ok(output)
}
}