feat: Added facenet

This commit is contained in:
uttarayan21
2025-08-08 15:01:25 +05:30
parent a3ea01b7b6
commit d52b69911f
9 changed files with 208 additions and 94 deletions

View File

@@ -1,4 +1,5 @@
use crate::errors::*;
use mnn_bridge::ndarray::*;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use std::path::Path;
@@ -8,6 +9,8 @@ pub struct EmbeddingGenerator {
}
impl EmbeddingGenerator {
const INPUT_NAME: &'static str = "serving_default_input_6:0";
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
@@ -22,9 +25,13 @@ impl EmbeddingGenerator {
.change_context(Error)
.attach_printable("Failed to load model from bytes")?;
model.set_session_mode(mnn::SessionMode::Release);
model
.set_cache_file("facenet.cache", 128)
.change_context(Error)
.attach_printable("Failed to set cache file")?;
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
let sc = mnn::ScheduleConfig::new()
.with_type(mnn::ForwardType::CPU)
.with_type(mnn::ForwardType::Metal)
.with_backend_config(bc);
tracing::info!("Creating session handle for face embedding model");
let handle = mnn_sync::SessionHandle::new(model, sc)
@@ -33,11 +40,55 @@ impl EmbeddingGenerator {
Ok(Self { handle })
}
pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
todo!()
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
let tensor = face
// .permuted_axes((0, 3, 1, 2))
.as_standard_layout()
.mapv(|x| x as f32);
let shape: [usize; 4] = tensor.dim().into();
let shape = shape.map(|f| f as i32);
let output = self
.handle
.run(move |sr| {
let tensor = tensor
.as_mnn_tensor()
.attach_printable("Failed to convert ndarray to mnn tensor")
.change_context(mnn::ErrorKind::TensorError)?;
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
let (intptr, session) = sr.both_mut();
tracing::trace!("Copying input tensor to host");
unsafe {
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
tracing::trace!("Input shape: {:?}", input.shape());
if *input.shape() != shape {
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
// input.resize(shape);
intptr.resize_tensor(input.view_mut(), shape);
}
}
intptr.resize_session(session);
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
tracing::trace!("Input shape: {:?}", input.shape());
input.copy_from_host_tensor(tensor.view())?;
tracing::info!("Running face detection session");
intptr.run_session(&session)?;
let output_tensor = intptr
.output::<f32>(&session, Self::OUTPUT_NAME)?
.create_host_tensor_from_device(true)
.as_ndarray()
.to_owned();
Ok(output_tensor)
})
.change_context(Error)?;
Ok(output)
}
pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
todo!()
}
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
// todo!()
// }
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
// todo!()
// }
}