Files
face-detector/src/faceembed/facenet/mnn.rs
uttarayan21 bfa389b497
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m18s
docs / docs (push) Failing after 28m50s
build / checks-build (push) Has been cancelled
feat(compare): add face comparison functionality with cosine similarity
2025-08-21 17:34:07 +05:30

128 lines
4.7 KiB
Rust

use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder;
use mnn_bridge::ndarray::*;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use std::path::Path;
/// FaceNet embedding generator backed by an MNN session.
///
/// Instances are created through [`EmbeddingGeneratorBuilder`]; inference is
/// executed through the wrapped session handle.
#[derive(Debug)]
pub struct EmbeddingGenerator {
/// Session handle through which `run_models` drives the MNN interpreter and session.
handle: mnn_sync::SessionHandle,
}
/// Builder that collects optional MNN configuration and a loaded model before
/// producing an [`EmbeddingGenerator`].
pub struct EmbeddingGeneratorBuilder {
/// Optional schedule configuration; `build` falls back to the default when this is `None`.
schedule_config: Option<mnn::ScheduleConfig>,
/// Optional backend configuration supplied through `with_backend_config`.
backend_config: Option<mnn::BackendConfig>,
/// Interpreter created from the model bytes passed to `new`.
model: mnn::Interpreter,
}
impl EmbeddingGeneratorBuilder {
    /// Creates a builder from the raw bytes of an MNN model.
    ///
    /// # Errors
    ///
    /// Returns an error if the MNN interpreter cannot be created from `model`.
    pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
        let model = mnn::Interpreter::from_bytes(model.as_ref())
            .map_err(|e| e.into_inner())
            .change_context(Error)
            .attach_printable("Failed to load model from bytes")?;
        Ok(Self {
            schedule_config: None,
            backend_config: None,
            model,
        })
    }

    /// Sets the forward (backend) type on the schedule configuration.
    ///
    /// A default schedule configuration is created first if none has been set.
    /// Calling [`Self::with_schedule_config`] afterwards replaces the whole
    /// configuration, including the forward type set here.
    pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
        self.schedule_config
            .get_or_insert_default()
            .set_type(forward_type);
        self
    }

    /// Replaces the schedule configuration with `config`.
    pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
        self.schedule_config = Some(config);
        self
    }

    /// Sets the backend configuration that `build` applies to the schedule
    /// configuration.
    pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
        self.backend_config = Some(config);
        self
    }

    /// Consumes the builder and creates the [`EmbeddingGenerator`].
    ///
    /// The schedule configuration defaults to `ScheduleConfig::default()` when
    /// none was provided. A backend configuration supplied through
    /// [`Self::with_backend_config`] is applied on top of it; when none was
    /// supplied the schedule configuration is used untouched.
    ///
    /// # Errors
    ///
    /// Returns an error if the session handle cannot be created.
    pub fn build(self) -> Result<EmbeddingGenerator> {
        let Self {
            schedule_config,
            backend_config,
            model,
        } = self;
        let mut sc = schedule_config.unwrap_or_default();
        // Previously the backend configuration was stored but never used, so
        // callers' backend settings were silently dropped.
        if let Some(bc) = backend_config {
            sc.set_backend_config(bc);
        }
        let handle = mnn_sync::SessionHandle::new(model, sc)
            .change_context(Error)
            .attach_printable("Failed to create session handle")?;
        Ok(EmbeddingGenerator { handle })
    }
}
impl EmbeddingGenerator {
    /// Name of the model input tensor that receives the preprocessed faces.
    const INPUT_NAME: &'static str = "serving_default_input_6:0";
    /// Name of the model output tensor that is read back after inference.
    const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";

    /// Creates an [`EmbeddingGeneratorBuilder`] from the raw bytes of an MNN model.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded from `model`.
    pub fn builder<T: AsRef<[u8]>>(model: T) -> Result<EmbeddingGeneratorBuilder> {
        EmbeddingGeneratorBuilder::new(model)
    }

    /// Preprocesses `face`, runs the model and returns its output as an owned
    /// 2-D array.
    ///
    /// The session input tensor is resized (followed by a session resize)
    /// whenever its current shape differs from the shape of the preprocessed
    /// tensor.
    ///
    /// NOTE(review): the output is presumably one embedding row per input
    /// face; confirm against the model.
    ///
    /// # Errors
    ///
    /// Returns an error if the tensor conversion, the input/output tensor
    /// lookup, the host-to-input copy or the session run fails.
    pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
        let tensor = crate::faceembed::preprocessing::preprocess(face);
        let shape: [usize; 4] = tensor.dim().into();
        // MNN shapes are `i32`. This is a truncating cast; dimensions are
        // assumed to fit in `i32` (TODO confirm no caller can exceed that).
        let shape = shape.map(|f| f as i32);
        let output = self
            .handle
            .run(move |sr| {
                let tensor = tensor
                    .as_mnn_tensor()
                    .attach_printable("Failed to convert ndarray to mnn tensor")
                    .change_context(mnn::ErrorKind::TensorError)?;
                tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
                let (intptr, session) = sr.both_mut();
                tracing::trace!("Copying input tensor to host");
                // SAFETY: the unresized input tensor is only used to read its
                // shape and to request a resize; no tensor data is accessed
                // through it, and it is dropped before `resize_session` runs.
                // NOTE(review): confirm this matches the contract documented
                // on `input_unresized`.
                let needs_resize = unsafe {
                    let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
                    tracing::trace!("Input shape: {:?}", input.shape());
                    if *input.shape() != shape {
                        tracing::trace!("Resizing input tensor to shape: {:?}", shape);
                        intptr.resize_tensor(input.view_mut(), shape);
                        true
                    } else {
                        false
                    }
                };
                if needs_resize {
                    tracing::trace!("Resized input tensor to shape: {:?}", shape);
                    let now = std::time::Instant::now();
                    intptr.resize_session(session);
                    tracing::trace!("Session resized in {:?}", now.elapsed());
                }
                // Fetch the input again now that the session has its final shape.
                let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
                tracing::trace!("Input shape: {:?}", input.shape());
                input.copy_from_host_tensor(tensor.view())?;
                // This is the embedding model, not the detector; the message
                // previously read "face detection".
                tracing::info!("Running face embedding session");
                intptr.run_session(&session)?;
                let output_tensor = intptr
                    .output::<f32>(&session, Self::OUTPUT_NAME)?
                    .create_host_tensor_from_device(true)
                    .as_ndarray()
                    .to_owned();
                Ok(output_tensor)
            })
            .change_context(Error)?;
        Ok(output)
    }
}
impl FaceNetEmbedder for EmbeddingGenerator {
fn run_model(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
self.run_models(faces)
}
}
// Main trait implementation, kept for backward compatibility.
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
    /// Runs the model on `faces` by delegating to the inherent
    /// `EmbeddingGenerator::run_models`.
    fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
        // Go through a shared reference so the call resolves to the inherent
        // `&self` method rather than recursing into this trait method.
        let generator: &EmbeddingGenerator = self;
        generator.run_models(faces)
    }
}