use crate::errors::*; use mnn_bridge::ndarray::*; use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; use std::path::Path; mod mnn_impl; mod ort_impl; #[derive(Debug)] pub struct EmbeddingGenerator { handle: mnn_sync::SessionHandle, } pub struct EmbeddingGeneratorBuilder { schedule_config: Option, backend_config: Option, model: mnn::Interpreter, } impl EmbeddingGeneratorBuilder { pub fn new(model: impl AsRef<[u8]>) -> Result { Ok(Self { schedule_config: None, backend_config: None, model: mnn::Interpreter::from_bytes(model.as_ref()) .map_err(|e| e.into_inner()) .change_context(Error) .attach_printable("Failed to load model from bytes")?, }) } pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self { self.schedule_config .get_or_insert_default() .set_type(forward_type); self } pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self { self.schedule_config = Some(config); self } pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self { self.backend_config = Some(config); self } pub fn build(self) -> Result { let model = self.model; let sc = self.schedule_config.unwrap_or_default(); let handle = mnn_sync::SessionHandle::new(model, sc) .change_context(Error) .attach_printable("Failed to create session handle")?; Ok(EmbeddingGenerator { handle }) } } impl EmbeddingGenerator { const INPUT_NAME: &'static str = "serving_default_input_6:0"; const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0"; pub fn new(path: impl AsRef) -> Result { let model = std::fs::read(path) .change_context(Error) .attach_printable("Failed to read model file")?; Self::new_from_bytes(&model) } pub fn builder>() -> fn(T) -> std::result::Result> { EmbeddingGeneratorBuilder::new } pub fn new_from_bytes(model: &[u8]) -> Result { tracing::info!("Loading face embedding model from bytes"); let mut model = mnn::Interpreter::from_bytes(model) .map_err(|e| e.into_inner()) .change_context(Error) .attach_printable("Failed to load model from bytes")?; model.set_session_mode(mnn::SessionMode::Release); model .set_cache_file("facenet.cache", 128) .change_context(Error) .attach_printable("Failed to set cache file")?; let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High); let sc = mnn::ScheduleConfig::new() .with_type(mnn::ForwardType::Metal) .with_backend_config(bc); tracing::info!("Creating session handle for face embedding model"); let handle = mnn_sync::SessionHandle::new(model, sc) .change_context(Error) .attach_printable("Failed to create session handle")?; Ok(Self { handle }) } pub fn run_models(&self, face: ArrayView4) -> Result> { let tensor = face // .permuted_axes((0, 3, 1, 2)) .as_standard_layout() .mapv(|x| x as f32); let shape: [usize; 4] = tensor.dim().into(); let shape = shape.map(|f| f as i32); let output = self .handle .run(move |sr| { let tensor = tensor .as_mnn_tensor() .attach_printable("Failed to convert ndarray to mnn tensor") .change_context(mnn::ErrorKind::TensorError)?; tracing::trace!("Image Tensor shape: {:?}", tensor.shape()); let (intptr, session) = sr.both_mut(); tracing::trace!("Copying input tensor to host"); let needs_resize = unsafe { let mut input = intptr.input_unresized::(session, Self::INPUT_NAME)?; tracing::trace!("Input shape: {:?}", input.shape()); if *input.shape() != shape { tracing::trace!("Resizing input tensor to shape: {:?}", shape); // input.resize(shape); intptr.resize_tensor(input.view_mut(), shape); true } else { false } }; if needs_resize { tracing::trace!("Resized input tensor to shape: {:?}", shape); let now = std::time::Instant::now(); intptr.resize_session(session); tracing::trace!("Session resized in {:?}", now.elapsed()); } let mut input = intptr.input::(session, Self::INPUT_NAME)?; tracing::trace!("Input shape: {:?}", input.shape()); input.copy_from_host_tensor(tensor.view())?; tracing::info!("Running face detection session"); intptr.run_session(&session)?; let output_tensor = intptr .output::(&session, Self::OUTPUT_NAME)? .create_host_tensor_from_device(true) .as_ndarray() .to_owned(); Ok(output_tensor) }) .change_context(Error)?; Ok(output) } // pub fn embedding(&self, roi: ArrayView3) -> Result> { // todo!() // } // pub fn embeddings(&self, roi: ArrayView4) -> Result> { // todo!() // } }