diff --git a/flake.lock b/flake.lock index ddecf2c..c0493d2 100644 --- a/flake.lock +++ b/flake.lock @@ -109,16 +109,16 @@ "mnn-src": { "flake": false, "locked": { - "lastModified": 1749173738, - "narHash": "sha256-pNljvQ4xMZ4VmuxQyXt+boNBZD0+UZNpNLrWrj8Rtfw=", + "lastModified": 1753256753, + "narHash": "sha256-aTpwVZBkpQiwOVVXDfQIVEx9CswNiPbvNftw8KsoW+Q=", "owner": "alibaba", "repo": "MNN", - "rev": "ebdada82634300956e08bd4056ecfeb1e4f23b32", + "rev": "a739ea5870a4a45680f0e36ba9662ca39f2f4eec", "type": "github" }, "original": { "owner": "alibaba", - "ref": "3.2.0", + "ref": "3.2.2", "repo": "MNN", "type": "github" } diff --git a/flake.nix b/flake.nix index ddd2e70..9aae400 100644 --- a/flake.nix +++ b/flake.nix @@ -22,7 +22,7 @@ inputs.nixpkgs.follows = "nixpkgs"; }; mnn-src = { - url = "github:alibaba/MNN/3.2.0"; + url = "github:alibaba/MNN/3.2.2"; flake = false; }; }; diff --git a/models/retinaface.mnn b/models/retinaface.mnn index 45b7d85..7d702e2 100644 Binary files a/models/retinaface.mnn and b/models/retinaface.mnn differ diff --git a/src/facedet/retinaface.rs b/src/facedet/retinaface.rs index 20911b0..7ce7754 100644 --- a/src/facedet/retinaface.rs +++ b/src/facedet/retinaface.rs @@ -274,7 +274,7 @@ impl FaceDetection { pub fn detect_faces( &self, - image: ndarray::Array3, + image: ndarray::ArrayView3, config: FaceDetectionConfig, ) -> Result { let (height, width, _channels) = image.dim(); @@ -299,7 +299,8 @@ impl FaceDetection { .map(|((b, s), l)| (b, s, l)) .multiunzip(); - let keep_indices = nms(&boxes, &scores, config.threshold, config.nms_threshold); + let keep_indices = + nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?; let bboxes = boxes .into_iter() @@ -327,7 +328,7 @@ impl FaceDetection { }) } - pub fn run_models(&self, image: ndarray::Array3) -> Result { + pub fn run_models(&self, image: ndarray::ArrayView3) -> Result { #[rustfmt::skip] use ::tap::*; let output = self diff --git a/src/faceembed/facenet.rs b/src/faceembed/facenet.rs index 441a9c3..5dfe4dd 100644 --- a/src/faceembed/facenet.rs +++ b/src/faceembed/facenet.rs @@ -1,5 +1,5 @@ use crate::errors::*; -use ndarray::{Array1, ArrayView3}; +use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; use std::path::Path; #[derive(Debug)] @@ -36,4 +36,8 @@ impl EmbeddingGenerator { pub fn embedding(&self, roi: ArrayView3) -> Result> { todo!() } + + pub fn embeddings(&self, roi: ArrayView4) -> Result> { + todo!() + } }