diff --git a/Cargo.lock b/Cargo.lock index 05c0241..315667e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,15 +116,6 @@ version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" -[[package]] -name = "approx" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278" -dependencies = [ - "num-traits", -] - [[package]] name = "approx" version = "0.5.1" @@ -253,7 +244,7 @@ dependencies = [ "color", "itertools 0.14.0", "nalgebra", - "ndarray 0.16.1", + "ndarray", "num", "ordered-float", "simba", @@ -506,12 +497,11 @@ dependencies = [ "fast_image_resize", "image", "itertools 0.14.0", - "linfa", "mnn", "mnn-bridge", "mnn-sync", "nalgebra", - "ndarray 0.16.1", + "ndarray", "ndarray-image", "ndarray-resize", "ordered-float", @@ -1098,20 +1088,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "linfa" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f9097edc7c89d03d526efbacf6d90914e3a8fa53bd56c2d1489e3a90819370" -dependencies = [ - "approx 0.4.0", - "ndarray 0.15.6", - "num-traits", - "rand", - "sprs", - "thiserror 1.0.69", -] - [[package]] name = "litemap" version = "0.8.0" @@ -1220,7 +1196,7 @@ source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-ty dependencies = [ "error-stack", "mnn", - "ndarray 0.16.1", + "ndarray", ] [[package]] @@ -1259,7 +1235,7 @@ version = "0.33.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b" dependencies = [ - "approx 0.5.1", + "approx", "matrixmultiply", "nalgebra-macros", "num-complex", @@ -1289,20 +1265,6 @@ dependencies = [ "getrandom 0.2.16", ] -[[package]] -name = "ndarray" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" -dependencies = [ - "approx 0.4.0", - "matrixmultiply", - "num-complex", - "num-integer", - "num-traits", - "rawpointer", -] - [[package]] name = "ndarray" version = "0.16.1" @@ -1323,7 +1285,7 @@ name = "ndarray-image" version = "0.1.0" dependencies = [ "image", - "ndarray 0.16.1", + "ndarray", ] [[package]] @@ -1333,7 +1295,7 @@ dependencies = [ "bytemuck", "error-stack", "fast_image_resize", - "ndarray 0.16.1", + "ndarray", "num", "thiserror 2.0.12", ] @@ -1942,7 +1904,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3a386a501cd104797982c15ae17aafe8b9261315b5d07e3ec803f2ea26be0fa" dependencies = [ - "approx 0.5.1", + "approx", "num-complex", "num-traits", "paste", @@ -1979,18 +1941,6 @@ dependencies = [ "lock_api", ] -[[package]] -name = "sprs" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88bab60b0a18fb9b3e0c26e92796b3c3a278bf5fa4880f5ad5cc3bdfb843d0b1" -dependencies = [ - "ndarray 0.15.6", - "num-complex", - "num-traits", - "smallvec", -] - [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index ea89789..ad37394 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,8 +12,8 @@ mnn = { path = "/Users/fs0c131y/Projects/aftershoot/mnn-rs" } ndarray-image = { path = "ndarray-image" } ndarray-resize = { path = "ndarray-resize" } mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [ - # "metal", - # "coreml", + "metal", + "coreml", "tracing", ], branch = "restructure-tensor-type" } mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [ @@ -35,7 +35,6 @@ clap_complete = "4.5" error-stack = "0.5" fast_image_resize = "5.2.0" image = "0.25.6" -linfa = "0.7.1" nalgebra = "0.33.2" ndarray = "0.16.1" ndarray-image = { workspace = true } diff --git a/bounding-box/src/draw.rs b/bounding-box/src/draw.rs index 8ed4e38..c6dca47 100644 --- a/bounding-box/src/draw.rs +++ b/bounding-box/src/draw.rs @@ -4,11 +4,11 @@ pub use color::Rgba8; use ndarray::{Array1, Array3, ArrayViewMut3}; pub trait Draw { - fn draw(&mut self, item: T, color: color::Rgba8, thickness: usize); + fn draw(&mut self, item: &T, color: color::Rgba8, thickness: usize); } impl Draw> for Array3 { - fn draw(&mut self, item: Aabb2, color: color::Rgba8, thickness: usize) { + fn draw(&mut self, item: &Aabb2, color: color::Rgba8, thickness: usize) { item.draw(self, color, thickness) } } diff --git a/bounding-box/src/roi.rs b/bounding-box/src/roi.rs index 3d07357..2c8cc31 100644 --- a/bounding-box/src/roi.rs +++ b/bounding-box/src/roi.rs @@ -5,10 +5,17 @@ pub trait Roi<'a, Output> { type Error; fn roi(&'a self, aabb: Aabb2) -> Result; } + pub trait RoiMut<'a, Output> { type Error; fn roi_mut(&'a mut self, aabb: Aabb2) -> Result; } + +pub trait MultiRoi<'a, Output> { + type Error; + fn multi_roi(&'a self, aabbs: &[Aabb2]) -> Result; +} + #[derive(thiserror::Error, Debug, Copy, Clone)] pub enum RoiError { #[error("Region of intereset is out of bounds")] @@ -36,7 +43,7 @@ impl<'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for Array3 { let x2 = aabb.x2(); let y1 = aabb.y1(); let y2 = aabb.y2(); - if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] { + if x1 > x2 || y1 > y2 || x2 > self.shape()[1] || y2 > self.shape()[0] { return Err(RoiError::RoiOutOfBounds); } Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..])) @@ -95,3 +102,47 @@ pub fn reborrow_test() { }; dbg!(y); } + +impl<'a> MultiRoi<'a, Vec>> for Array3 { + type Error = RoiError; + fn multi_roi(&'a self, aabbs: &[Aabb2]) -> Result>, Self::Error> { + let (height, width, _channels) = self.dim(); + let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height); + aabbs + .iter() + .map(|aabb| { + let slice_arg = + bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?); + Ok(self.slice(slice_arg)) + }) + .collect::, RoiError>>() + } +} + +impl<'a, 'b> MultiRoi<'a, Vec>> for ArrayView3<'b, u8> { + type Error = RoiError; + fn multi_roi(&'a self, aabbs: &[Aabb2]) -> Result>, Self::Error> { + let (height, width, _channels) = self.dim(); + let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height); + aabbs + .iter() + .map(|aabb| { + let slice_arg = + bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?); + Ok(self.slice_move(slice_arg)) + }) + .collect::, RoiError>>() + } +} + +fn bbox_to_slice_arg( + aabb: Aabb2, +) -> ndarray::SliceInfo<[ndarray::SliceInfoElem; 3], ndarray::Ix3, ndarray::Ix3> { + // This function should convert the bounding box to a slice argument + // For now, we will return a dummy value + let x1 = aabb.x1(); + let x2 = aabb.x2(); + let y1 = aabb.y1(); + let y2 = aabb.y2(); + ndarray::s![y1..y2, x1..x2, ..] +} diff --git a/flake.lock b/flake.lock index c0493d2..ad898d8 100644 --- a/flake.lock +++ b/flake.lock @@ -178,11 +178,11 @@ ] }, "locked": { - "lastModified": 1750732748, - "narHash": "sha256-HR2b3RHsPeJm+Fb+1ui8nXibgniVj7hBNvUbXEyz0DU=", + "lastModified": 1754621349, + "narHash": "sha256-JkXUS/nBHyUqVTuL4EDCvUWauTHV78EYfk+WqiTAMQ4=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "4b4494b2ba7e8a8041b2e28320b2ee02c115c75f", + "rev": "c448ab42002ac39d3337da10420c414fccfb1088", "type": "github" }, "original": { diff --git a/models/retinaface.mnn b/models/retinaface.mnn index 7d702e2..72f5a89 100644 --- a/models/retinaface.mnn +++ b/models/retinaface.mnn @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:456056d9b871bfdf84419882e616589904cc8c10bfb5d74b3678ec0ff0d5d3e0 +oid sha256:72a47d1ceeab1c649a002aca52a12c4f25bcd2245b8674494e970ccb74595202 size 54627324 diff --git a/src/facedet/retinaface.rs b/src/facedet/retinaface.rs index 7ce7754..856b9ad 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/retinaface.rs @@ -261,6 +261,10 @@ impl FaceDetection { .change_context(Error) .attach_printable("Failed to load model from bytes")?; model.set_session_mode(mnn::SessionMode::Release); + model + .set_cache_file("retinaface.cache", 128) + .change_context(Error) + .attach_printable("Failed to set cache file")?; let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High); let sc = mnn::ScheduleConfig::new() .with_type(mnn::ForwardType::CPU) @@ -330,26 +334,26 @@ impl FaceDetection { pub fn run_models(&self, image: ndarray::ArrayView3) -> Result { #[rustfmt::skip] + let mut resized = image + .fast_resize(1024, 1024, None) + .change_context(Error)? + .mapv(|f| f as f32) + .tap_mut(|arr| { + arr.axis_iter_mut(ndarray::Axis(2)) + .zip([104, 117, 123]) + .for_each(|(mut array, pixel)| { + let pixel = pixel as f32; + array.map_inplace(|v| *v -= pixel); + }); + }) + .permuted_axes((2, 0, 1)) + .insert_axis(ndarray::Axis(0)) + .as_standard_layout() + .into_owned(); use ::tap::*; let output = self .handle .run(move |sr| { - let mut resized = image - .fast_resize(1024, 1024, None) - .change_context(mnn::ErrorKind::TensorError)? - .mapv(|f| f as f32) - .tap_mut(|arr| { - arr.axis_iter_mut(ndarray::Axis(2)) - .zip([104, 117, 123]) - .for_each(|(mut array, pixel)| { - let pixel = pixel as f32; - array.map_inplace(|v| *v -= pixel); - }); - }) - .permuted_axes((2, 0, 1)) - .insert_axis(ndarray::Axis(0)) - .as_standard_layout() - .into_owned(); let tensor = resized .as_mnn_tensor_mut() .attach_printable("Failed to convert ndarray to mnn tensor") diff --git a/src/faceembed/facenet.rs b/src/faceembed/facenet.rs index 5dfe4dd..5abae7a 100644 --- a/src/faceembed/facenet.rs +++ b/src/faceembed/facenet.rs @@ -1,4 +1,5 @@ use crate::errors::*; +use mnn_bridge::ndarray::*; use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; use std::path::Path; @@ -8,6 +9,8 @@ pub struct EmbeddingGenerator { } impl EmbeddingGenerator { + const INPUT_NAME: &'static str = "serving_default_input_6:0"; + const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0"; pub fn new(path: impl AsRef) -> Result { let model = std::fs::read(path) .change_context(Error) @@ -22,9 +25,13 @@ impl EmbeddingGenerator { .change_context(Error) .attach_printable("Failed to load model from bytes")?; model.set_session_mode(mnn::SessionMode::Release); + model + .set_cache_file("facenet.cache", 128) + .change_context(Error) + .attach_printable("Failed to set cache file")?; let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High); let sc = mnn::ScheduleConfig::new() - .with_type(mnn::ForwardType::CPU) + .with_type(mnn::ForwardType::Metal) .with_backend_config(bc); tracing::info!("Creating session handle for face embedding model"); let handle = mnn_sync::SessionHandle::new(model, sc) @@ -33,11 +40,55 @@ impl EmbeddingGenerator { Ok(Self { handle }) } - pub fn embedding(&self, roi: ArrayView3) -> Result> { - todo!() + pub fn run_models(&self, face: ArrayView4) -> Result> { + let tensor = face + // .permuted_axes((0, 3, 1, 2)) + .as_standard_layout() + .mapv(|x| x as f32); + let shape: [usize; 4] = tensor.dim().into(); + let shape = shape.map(|f| f as i32); + let output = self + .handle + .run(move |sr| { + let tensor = tensor + .as_mnn_tensor() + .attach_printable("Failed to convert ndarray to mnn tensor") + .change_context(mnn::ErrorKind::TensorError)?; + tracing::trace!("Image Tensor shape: {:?}", tensor.shape()); + let (intptr, session) = sr.both_mut(); + tracing::trace!("Copying input tensor to host"); + unsafe { + let mut input = intptr.input_unresized::(session, Self::INPUT_NAME)?; + tracing::trace!("Input shape: {:?}", input.shape()); + if *input.shape() != shape { + tracing::trace!("Resizing input tensor to shape: {:?}", shape); + // input.resize(shape); + intptr.resize_tensor(input.view_mut(), shape); + } + } + intptr.resize_session(session); + let mut input = intptr.input::(session, Self::INPUT_NAME)?; + tracing::trace!("Input shape: {:?}", input.shape()); + input.copy_from_host_tensor(tensor.view())?; + + tracing::info!("Running face detection session"); + intptr.run_session(&session)?; + let output_tensor = intptr + .output::(&session, Self::OUTPUT_NAME)? + .create_host_tensor_from_device(true) + .as_ndarray() + .to_owned(); + Ok(output_tensor) + }) + .change_context(Error)?; + Ok(output) } - pub fn embeddings(&self, roi: ArrayView4) -> Result> { - todo!() - } + // pub fn embedding(&self, roi: ArrayView3) -> Result> { + // todo!() + // } + + // pub fn embeddings(&self, roi: ArrayView4) -> Result> { + // todo!() + // } } diff --git a/src/main.rs b/src/main.rs index b389049..d06626c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,13 @@ mod cli; mod errors; -use detector::facedet::retinaface::FaceDetectionConfig; +use bounding_box::roi::MultiRoi; +use detector::{facedet::retinaface::FaceDetectionConfig, faceembed}; use errors::*; +use fast_image_resize::ResizeOptions; +use nalgebra::zero; use ndarray_image::*; const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn"); +const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn"); pub fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter("trace") @@ -15,29 +19,84 @@ pub fn main() -> Result<()> { match args.cmd { cli::SubCommand::Detect(detect) => { use detector::facedet; - let model = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL) + let retinaface = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL) .change_context(errors::Error) .attach_printable("Failed to create face detection model")?; + let facenet = faceembed::facenet::EmbeddingGenerator::new_from_bytes(FACENET_MODEL) + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; let image = image::open(detect.image).change_context(Error)?; let image = image.into_rgb8(); let mut array = image .into_ndarray() .change_context(errors::Error) .attach_printable("Failed to convert image to ndarray")?; - let output = model + let output = retinaface .detect_faces( - array.clone(), + array.view(), FaceDetectionConfig::default() .with_threshold(detect.threshold) .with_nms_threshold(detect.nms_threshold), ) .change_context(errors::Error) .attach_printable("Failed to detect faces")?; - for bbox in output.bbox { + for bbox in &output.bbox { tracing::info!("Detected face: {:?}", bbox); use bounding_box::draw::*; array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1); } + use ndarray::{Array2, Array3, Array4, Axis}; + use ndarray_resize::NdFir; + let face_rois = array + .view() + .multi_roi(&output.bbox) + .change_context(Error)? + .into_iter() + // .inspect(|f| { + // tracing::info!("Face ROI shape before resize: {:?}", f.dim()); + // }) + .map(|roi| { + roi.as_standard_layout() + .fast_resize(512, 512, &ResizeOptions::default()) + .change_context(Error) + }) + // .inspect(|f| { + // f.as_ref().inspect(|f| { + // tracing::info!("Face ROI shape after resize: {:?}", f.dim()); + // }); + // }) + .collect::>>()?; + let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::>(); + + let embeddings = face_roi_views + .chunks(8) + .map(|chunk| { + tracing::info!("Processing chunk of size: {}", chunk.len()); + + if chunk.len() < 8 { + tracing::warn!("Chunk size is less than 8, padding with zeros"); + let zeros = Array3::zeros((512, 512, 3)); + let padded: Vec> = chunk + .iter() + .cloned() + .chain(core::iter::repeat(zeros.view())) + .take(8) + .collect(); + let face_rois: Array4 = ndarray::stack(Axis(0), padded.as_slice()) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + let output = facenet.run_models(face_rois.view()).change_context(Error)?; + Ok(output) + } else { + let face_rois: Array4 = ndarray::stack(Axis(0), chunk) + .change_context(errors::Error) + .attach_printable("Failed to stack rois together")?; + let output = facenet.run_models(face_rois.view()).change_context(Error)?; + Ok(output) + } + }) + .collect::>>>(); + let v = array.view(); if let Some(output) = detect.output { let image: image::RgbImage = v