use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder;
use mnn_bridge::ndarray::*;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use std::path::Path;

/// Generates face embeddings by running a FaceNet model through an MNN session.
#[derive(Debug)]
pub struct EmbeddingGenerator {
    handle: mnn_sync::SessionHandle,
}

/// Builder for [`EmbeddingGenerator`]; configures the MNN schedule and backend
/// before the session handle is created.
pub struct EmbeddingGeneratorBuilder {
    schedule_config: Option<mnn::ScheduleConfig>,
    backend_config: Option<mnn::BackendConfig>,
    model: mnn::Interpreter,
}

impl EmbeddingGeneratorBuilder {
    pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
        Ok(Self {
            schedule_config: None,
            backend_config: None,
            model: mnn::Interpreter::from_bytes(model.as_ref())
                .map_err(|e| e.into_inner())
                .change_context(Error)
                .attach_printable("Failed to load model from bytes")?,
        })
    }

    pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
        self.schedule_config
            .get_or_insert_default()
            .set_type(forward_type);
        self
    }

    pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
        self.schedule_config = Some(config);
        self
    }

    pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
        self.backend_config = Some(config);
        self
    }

    pub fn build(self) -> Result<EmbeddingGenerator> {
        let model = self.model;
        let sc = self.schedule_config.unwrap_or_default();
        let handle = mnn_sync::SessionHandle::new(model, sc)
            .change_context(Error)
            .attach_printable("Failed to create session handle")?;
        Ok(EmbeddingGenerator { handle })
    }
}

impl EmbeddingGenerator {
    const INPUT_NAME: &'static str = "serving_default_input_6:0";
    const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";

    pub fn builder<T: AsRef<[u8]>>(
        model: T,
    ) -> std::result::Result<EmbeddingGeneratorBuilder, error_stack::Report<Error>> {
        EmbeddingGeneratorBuilder::new(model)
    }

    /// Runs the embedding model on a batch of face crops and returns one
    /// embedding row per face.
    pub fn run_models(&self, face: ArrayView4<f32>) -> Result<Array2<f32>> {
        let tensor = crate::faceembed::preprocessing::preprocess(face);
        let shape: [usize; 4] = tensor.dim().into();
        let shape = shape.map(|f| f as i32);
        let output = self
            .handle
            .run(move |sr| {
                let tensor = tensor
                    .as_mnn_tensor()
                    .attach_printable("Failed to convert ndarray to mnn tensor")
                    .change_context(mnn::ErrorKind::TensorError)?;
                tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
                let (intptr, session) = sr.both_mut();
                tracing::trace!("Copying input tensor to host");
                // Resize the input tensor only when the incoming batch shape differs
                // from the shape the session was last resized to.
                let needs_resize = unsafe {
                    let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
                    tracing::trace!("Input shape: {:?}", input.shape());
                    if *input.shape() != shape {
                        tracing::trace!("Resizing input tensor to shape: {:?}", shape);
                        // input.resize(shape);
                        intptr.resize_tensor(input.view_mut(), shape);
                        true
                    } else {
                        false
                    }
                };
                if needs_resize {
                    tracing::trace!("Resized input tensor to shape: {:?}", shape);
                    let now = std::time::Instant::now();
                    intptr.resize_session(session);
                    tracing::trace!("Session resized in {:?}", now.elapsed());
                }
                let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
                tracing::trace!("Input shape: {:?}", input.shape());
                input.copy_from_host_tensor(tensor.view())?;
                tracing::info!("Running face embedding session");
                intptr.run_session(&session)?;
                let output_tensor = intptr
                    .output::<f32>(&session, Self::OUTPUT_NAME)?
                    .create_host_tensor_from_device(true)
                    .as_ndarray()
                    .to_owned();
                Ok(output_tensor)
            })
            .change_context(Error)?;
        Ok(output)
    }
}

impl FaceNetEmbedder for EmbeddingGenerator {
    fn run_model(&mut self, faces: ArrayView4<f32>) -> Result<Array2<f32>> {
        self.run_models(faces)
    }
}

// Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
    fn run_models(&mut self, faces: ArrayView4<f32>) -> Result<Array2<f32>> {
        EmbeddingGenerator::run_models(self, faces)
    }
}
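
// Minimal usage sketch (illustration only, not part of this module's API):
// the model path, the 160x160 RGB batch shape, and the f32 element type below
// are assumptions for the example, not requirements dictated by this file.
//
//     let model = std::fs::read("facenet.mnn").expect("read model bytes"); // hypothetical path
//     let generator = EmbeddingGenerator::builder(model)?.build()?;
//     let faces = ndarray::Array4::<f32>::zeros((1, 160, 160, 3));          // assumed NHWC face batch
//     let embeddings: ndarray::Array2<f32> = generator.run_models(faces.view())?;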