feat: Added stuff
Some checks failed
build / checks-build (push) Has been cancelled
build / checks-matrix (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-18 22:10:29 +05:30
parent 7fc958b299
commit 783320131a
12 changed files with 73 additions and 464 deletions

View File

@@ -11,10 +11,9 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
const CHUNK_SIZE: usize = 2;
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("trace")
.with_env_filter("error")
.with_thread_ids(true)
.with_thread_names(true)
.with_target(false)
@@ -35,8 +34,6 @@ pub fn main() -> Result<()> {
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
// .then_some(cli::Executor::Mnn)
// .unwrap_or(cli::Executor::Ort);
match executor {
cli::Executor::Mnn(forward) => {
@@ -92,7 +89,9 @@ where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
let image = image::open(detect.image).change_context(Error)?;
let image = image::open(&detect.image)
.change_context(Error)
.attach_printable(detect.image.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let mut array = image
.into_ndarray()
@@ -122,7 +121,7 @@ where
// })
.map(|roi| {
roi.as_standard_layout()
.fast_resize(512, 512, &ResizeOptions::default())
.fast_resize(160, 160, &ResizeOptions::default())
.change_context(Error)
})
// .inspect(|f| {
@@ -133,15 +132,15 @@ where
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = CHUNK_SIZE;
let chunk_size = detect.batch_size;
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
tracing::info!("Processing chunk of size: {}", chunk.len());
if chunk.len() < 8 {
if chunk.len() < chunk_size {
tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((512, 512, 3));
let zeros = Array3::zeros((160, 160, 3));
let zero_array = core::iter::repeat(zeros.view())
.take(chunk_size)
.collect::<Vec<_>>();
@@ -158,7 +157,7 @@ where
Ok(output)
}
})
.collect::<Result<Vec<Array2<f32>>>>();
.collect::<Result<Vec<Array2<f32>>>>()?;
let v = array.view();
if let Some(output) = detect.output {