//! Face embedding generation backed by an MNN inference session.
use crate::errors::*;
|
|
use mnn_bridge::ndarray::*;
|
|
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
|
use std::path::Path;
|
|
mod mnn_impl;
|
|
mod ort_impl;
|
|
|
|
/// Runs a face-embedding neural network through a thread-safe MNN session.
///
/// Construct via [`EmbeddingGenerator::new`], [`EmbeddingGenerator::new_from_bytes`],
/// or the configurable [`EmbeddingGeneratorBuilder`].
#[derive(Debug)]
pub struct EmbeddingGenerator {
    // Owns the MNN interpreter + session pair; all inference goes through
    // `handle.run(...)` closures.
    handle: mnn_sync::SessionHandle,
}
|
|
/// Builder for [`EmbeddingGenerator`] allowing schedule/backend configuration
/// before the MNN session is created.
pub struct EmbeddingGeneratorBuilder {
    // Optional schedule config; a default one is used at build time if unset.
    schedule_config: Option<mnn::ScheduleConfig>,
    // Optional backend config, intended to be applied to the schedule config.
    backend_config: Option<mnn::BackendConfig>,
    // Interpreter loaded eagerly in `new` so load errors surface early.
    model: mnn::Interpreter,
}
|
|
|
|
impl EmbeddingGeneratorBuilder {
|
|
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
|
Ok(Self {
|
|
schedule_config: None,
|
|
backend_config: None,
|
|
model: mnn::Interpreter::from_bytes(model.as_ref())
|
|
.map_err(|e| e.into_inner())
|
|
.change_context(Error)
|
|
.attach_printable("Failed to load model from bytes")?,
|
|
})
|
|
}
|
|
|
|
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
|
self.schedule_config
|
|
.get_or_insert_default()
|
|
.set_type(forward_type);
|
|
self
|
|
}
|
|
|
|
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
|
self.schedule_config = Some(config);
|
|
self
|
|
}
|
|
|
|
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
|
self.backend_config = Some(config);
|
|
self
|
|
}
|
|
|
|
pub fn build(self) -> Result<EmbeddingGenerator> {
|
|
let model = self.model;
|
|
let sc = self.schedule_config.unwrap_or_default();
|
|
let handle = mnn_sync::SessionHandle::new(model, sc)
|
|
.change_context(Error)
|
|
.attach_printable("Failed to create session handle")?;
|
|
Ok(EmbeddingGenerator { handle })
|
|
}
|
|
}
|
|
|
|
impl EmbeddingGenerator {
    // Tensor names baked into the exported model graph. The
    // "serving_default…"/"StatefulPartitionedCall…" naming is the
    // TensorFlow SavedModel signature convention.
    const INPUT_NAME: &'static str = "serving_default_input_6:0";
    const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";

    /// Reads a model file from `path` and builds a generator with the default
    /// configuration of [`Self::new_from_bytes`].
    ///
    /// # Errors
    /// Fails when the file cannot be read or the model cannot be loaded.
    pub fn new(path: impl AsRef<Path>) -> Result<Self> {
        let model = std::fs::read(path)
            .change_context(Error)
            .attach_printable("Failed to read model file")?;
        Self::new_from_bytes(&model)
    }

    /// Returns the [`EmbeddingGeneratorBuilder`] constructor, for callers who
    /// want to customise schedule/backend configuration before building.
    pub fn builder<T: AsRef<[u8]>>()
    -> fn(T) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
        EmbeddingGeneratorBuilder::new
    }

    /// Builds a generator from raw model bytes with an opinionated default
    /// configuration: Release session mode, an on-disk kernel cache
    /// ("facenet.cache"), high memory mode, and the Metal forward type.
    ///
    /// NOTE(review): `ForwardType::Metal` is Apple-platform specific — confirm
    /// this default is intended on other targets (the builder path allows
    /// overriding it).
    ///
    /// # Errors
    /// Fails when the model cannot be parsed, the cache file cannot be set,
    /// or the session handle cannot be created.
    pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
        tracing::info!("Loading face embedding model from bytes");
        let mut model = mnn::Interpreter::from_bytes(model)
            .map_err(|e| e.into_inner())
            .change_context(Error)
            .attach_printable("Failed to load model from bytes")?;
        model.set_session_mode(mnn::SessionMode::Release);
        // Cache compiled kernels between runs; written relative to the CWD.
        model
            .set_cache_file("facenet.cache", 128)
            .change_context(Error)
            .attach_printable("Failed to set cache file")?;
        let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
        let sc = mnn::ScheduleConfig::new()
            .with_type(mnn::ForwardType::Metal)
            .with_backend_config(bc);
        tracing::info!("Creating session handle for face embedding model");
        let handle = mnn_sync::SessionHandle::new(model, sc)
            .change_context(Error)
            .attach_printable("Failed to create session handle")?;
        Ok(Self { handle })
    }

    /// Runs the model on a batch of `u8` face crops and returns one
    /// `f32` embedding row per input image.
    ///
    /// NOTE(review): the commented-out `permuted_axes((0, 3, 1, 2))` suggests
    /// the expected axis order (NHWC vs NCHW) was under revision — confirm
    /// the layout `face` must arrive in.
    ///
    /// # Errors
    /// Fails when the ndarray→MNN tensor conversion, input resize/copy, or
    /// session run fails.
    pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
        // Densify and widen to f32 on the host before entering the session
        // closure (the closure is `move`, so the buffer is captured by value).
        let tensor = face
            // .permuted_axes((0, 3, 1, 2))
            .as_standard_layout()
            .mapv(|x| x as f32);
        // MNN expects i32 dimensions.
        let shape: [usize; 4] = tensor.dim().into();
        let shape = shape.map(|f| f as i32);
        let output = self
            .handle
            .run(move |sr| {
                let tensor = tensor
                    .as_mnn_tensor()
                    .attach_printable("Failed to convert ndarray to mnn tensor")
                    .change_context(mnn::ErrorKind::TensorError)?;
                tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
                let (intptr, session) = sr.both_mut();
                tracing::trace!("Copying input tensor to host");
                // Probe the (unresized) input tensor; if its shape differs
                // from this batch, resize it and remember to re-plan the
                // session below.
                // SAFETY: presumably `input_unresized` is unsafe because the
                // tensor's buffer is not yet valid for reads/writes before
                // `resize_session`; only shape metadata is touched here —
                // confirm against the mnn crate's contract.
                let needs_resize = unsafe {
                    let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
                    tracing::trace!("Input shape: {:?}", input.shape());
                    if *input.shape() != shape {
                        tracing::trace!("Resizing input tensor to shape: {:?}", shape);
                        // input.resize(shape);
                        intptr.resize_tensor(input.view_mut(), shape);
                        true
                    } else {
                        false
                    }
                };
                if needs_resize {
                    tracing::trace!("Resized input tensor to shape: {:?}", shape);
                    let now = std::time::Instant::now();
                    // Re-plan the whole session after the input shape change;
                    // must happen before fetching the resized input below.
                    intptr.resize_session(session);
                    tracing::trace!("Session resized in {:?}", now.elapsed());
                }
                let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
                tracing::trace!("Input shape: {:?}", input.shape());
                input.copy_from_host_tensor(tensor.view())?;

                tracing::info!("Running face detection session");
                intptr.run_session(&session)?;
                // Copy the device output back to host and into an owned ndarray.
                let output_tensor = intptr
                    .output::<f32>(&session, Self::OUTPUT_NAME)?
                    .create_host_tensor_from_device(true)
                    .as_ndarray()
                    .to_owned();
                Ok(output_tensor)
            })
            .change_context(Error)?;
        Ok(output)
    }

    // pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
    //     todo!()
    // }

    // pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
    //     todo!()
    // }
}
|