feat: Added stuff
Some checks failed
build / checks-matrix (push) Successful in 23m6s
build / codecov (push) Failing after 19m30s
docs / docs (push) Failing after 28m54s
build / checks-build (push) Has been cancelled

Author: uttarayan21
Date: 2025-08-13 18:08:03 +05:30
Parent: f5740dc87f
Commit: 2d2309837f
12 changed files with 1151 additions and 227 deletions

Cargo.lock (generated, 919 lines changed): diff suppressed because it is too large.

Cargo.toml (workspace root)

@@ -22,6 +22,7 @@ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0",
 mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
   "tracing",
 ], branch = "restructure-tensor-type" }
+nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }

 [package]
 name = "detector"
@@ -35,7 +36,7 @@ clap_complete = "4.5"
 error-stack = "0.5"
 fast_image_resize = "5.2.0"
 image = "0.25.6"
-nalgebra = "0.33.2"
+nalgebra = { workspace = true }
 ndarray = "0.16.1"
 ndarray-image = { workspace = true }
 ndarray-resize = { workspace = true }
@@ -52,6 +53,7 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
 color = "0.3.1"
 itertools = "0.14.0"
 ordered-float = "5.0.0"
+ort = "2.0.0-rc.10"

 [profile.release]
 debug = true

Makefile.toml (new file, 27 lines)

@@ -0,0 +1,27 @@
[tasks.convert_facenet]
command = "MNNConvert"
args = [
"-f",
"ONNX",
"--modelFile",
"models/facenet.onnx",
"--MNNModel",
"models/facenet.mnn",
"--fp16",
"--bizCode",
"MNN",
]
[tasks.convert_retinaface]
command = "MNNConvert"
args = [
"-f",
"ONNX",
"--modelFile",
"models/retinaface.onnx",
"--MNNModel",
"models/retinaface.mnn",
"--fp16",
"--bizCode",
"MNN",
]

Cargo.toml (workspace member)

@@ -6,7 +6,7 @@ edition = "2024"
 [dependencies]
 color = "0.3.1"
 itertools = "0.14.0"
-nalgebra = "0.33.2"
+nalgebra = { workspace = true }
 ndarray = { version = "0.16.1", optional = true }
 num = "0.4.3"
 ordered-float = "5.0.0"

facenet.mnn (new binary file; contents not shown)

flake.nix (214 lines changed)

@@ -27,20 +27,22 @@
     };
   };
-  outputs = {
+  outputs =
+    {
       self,
       crane,
       flake-utils,
       nixpkgs,
       rust-overlay,
       advisory-db,
       nix-github-actions,
       mnn-overlay,
       mnn-src,
       ...
     }:
     flake-utils.lib.eachDefaultSystem (
-      system: let
+      system:
+      let
         pkgs = import nixpkgs {
           inherit system;
           overlays = [
@@ -61,118 +63,148 @@
         stableToolchain = pkgs.rust-bin.stable.latest.default;
         stableToolchainWithLLvmTools = stableToolchain.override {
-          extensions = ["rust-src" "llvm-tools"];
+          extensions = [
+            "rust-src"
+            "llvm-tools"
+          ];
         };
         stableToolchainWithRustAnalyzer = stableToolchain.override {
-          extensions = ["rust-src" "rust-analyzer"];
+          extensions = [
+            "rust-src"
+            "rust-analyzer"
+          ];
         };
         craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
         craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
-        src = let
-          filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
-          sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
-        in
+        src =
+          let
+            filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
+            sourceFilters =
+              path: type:
+              (craneLib.filterCargoSources path type)
+              || filterBySuffix path [
+                ".c"
+                ".h"
+                ".hpp"
+                ".cpp"
+                ".cc"
+              ];
+          in
           lib.cleanSourceWith {
             filter = sourceFilters;
             src = ./.;
           };
-        commonArgs =
-          {
+        commonArgs = {
           inherit src;
           pname = name;
           stdenv = pkgs.clangStdenv;
           doCheck = false;
           LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
           # nativeBuildInputs = with pkgs; [
           #   cmake
           #   llvmPackages.libclang.lib
           # ];
-          buildInputs = with pkgs;
-            []
+          buildInputs =
+            with pkgs;
+            [ ]
             ++ (lib.optionals pkgs.stdenv.isDarwin [
               libiconv
               apple-sdk_13
             ]);
         }
         // (lib.optionalAttrs pkgs.stdenv.isLinux {
           # BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
         });
         cargoArtifacts = craneLib.buildPackage commonArgs;
-      in {
+      in
+      {
         checks = {
-          "${name}-clippy" = craneLib.cargoClippy (commonArgs
+          "${name}-clippy" = craneLib.cargoClippy (
+            commonArgs
             // {
               inherit cargoArtifacts;
               cargoClippyExtraArgs = "--all-targets -- --deny warnings";
-            });
-          "${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
-          "${name}-fmt" = craneLib.cargoFmt {inherit src;};
+            }
+          );
+          "${name}-docs" = craneLib.cargoDoc (commonArgs // { inherit cargoArtifacts; });
+          "${name}-fmt" = craneLib.cargoFmt { inherit src; };
           "${name}-toml-fmt" = craneLib.taploFmt {
-            src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"];
+            src = pkgs.lib.sources.sourceFilesBySuffices src [ ".toml" ];
           };
           # Audit dependencies
           "${name}-audit" = craneLib.cargoAudit {
             inherit src advisory-db;
           };
           # Audit licenses
           "${name}-deny" = craneLib.cargoDeny {
             inherit src;
           };
-          "${name}-nextest" = craneLib.cargoNextest (commonArgs
+          "${name}-nextest" = craneLib.cargoNextest (
+            commonArgs
             // {
               inherit cargoArtifacts;
               partitions = 1;
               partitionType = "count";
-            });
+            }
+          );
         }
         // lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
-          "${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
+          "${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // { inherit cargoArtifacts; });
         };
-        packages = let
-          pkg = craneLib.buildPackage (commonArgs
-            // {inherit cargoArtifacts;}
+        packages =
+          let
+            pkg = craneLib.buildPackage (
+              commonArgs
+              // {
+                inherit cargoArtifacts;
+              }
              // {
                nativeBuildInputs = with pkgs; [
                  installShellFiles
                ];
                postInstall = ''
                  installShellCompletion --cmd ${name} \
                    --bash <($out/bin/${name} completions bash) \
                    --fish <($out/bin/${name} completions fish) \
                    --zsh <($out/bin/${name} completions zsh)
                '';
-            });
-        in {
-          "${name}" = pkg;
-          default = pkg;
-        };
+              }
+            );
+          in
+          {
+            "${name}" = pkg;
+            default = pkg;
+          };
         devShells = {
-          default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
+          default = pkgs.mkShell.override { stdenv = pkgs.clangStdenv; } (
+            commonArgs
             // {
               LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
-              packages = with pkgs;
+              packages =
+                with pkgs;
                 [
                   stableToolchainWithRustAnalyzer
                   cargo-nextest
                   cargo-deny
                   cmake
                   mnn
+                  cargo-make
                 ]
                 ++ (lib.optionals pkgs.stdenv.isDarwin [
                   apple-sdk_13
                 ]);
-            });
+            }
+          );
         };
       }
     )
     // {
       githubActions = nix-github-actions.lib.mkGithubMatrix {
-        checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks;
+        checks = nixpkgs.lib.getAttrs [ "x86_64-linux" ] self.checks;
       };
     };
 }


@@ -48,6 +48,8 @@ pub struct Detect {
     pub model_type: Models,
     #[clap(short, long)]
     pub output: Option<PathBuf>,
+    #[clap(short, long, default_value = "cpu")]
+    pub forward_type: mnn::ForwardType,
     #[clap(short, long, default_value_t = 0.8)]
     pub threshold: f32,
     #[clap(short, long, default_value_t = 0.3)]


@@ -246,7 +246,56 @@ impl FaceDetectionModelOutput {
     }
 }
+pub struct FaceDetectionBuilder {
+    schedule_config: Option<mnn::ScheduleConfig>,
+    backend_config: Option<mnn::BackendConfig>,
+    model: mnn::Interpreter,
+}
+impl FaceDetectionBuilder {
+    pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
+        Ok(Self {
+            schedule_config: None,
+            backend_config: None,
+            model: mnn::Interpreter::from_bytes(model.as_ref())
+                .map_err(|e| e.into_inner())
+                .change_context(Error)
+                .attach_printable("Failed to load model from bytes")?,
+        })
+    }
+    pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
+        self.schedule_config
+            .get_or_insert_default()
+            .set_type(forward_type);
+        self
+    }
+    pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
+        self.schedule_config = Some(config);
+        self
+    }
+    pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
+        self.backend_config = Some(config);
+        self
+    }
+    pub fn build(self) -> Result<FaceDetection> {
+        let model = self.model;
+        let sc = self.schedule_config.unwrap_or_default();
+        let handle = mnn_sync::SessionHandle::new(model, sc)
+            .change_context(Error)
+            .attach_printable("Failed to create session handle")?;
+        Ok(FaceDetection { handle })
+    }
+}
 impl FaceDetection {
+    pub fn builder<T: AsRef<[u8]>>()
+    -> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
+        FaceDetectionBuilder::new
+    }
     pub fn new(path: impl AsRef<Path>) -> Result<Self> {
         let model = std::fs::read(path)
             .change_context(Error)
@@ -267,7 +316,7 @@ impl FaceDetection {
             .attach_printable("Failed to set cache file")?;
         let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
         let sc = mnn::ScheduleConfig::new()
-            .with_type(mnn::ForwardType::CPU)
+            .with_type(mnn::ForwardType::Metal)
             .with_backend_config(bc);
         tracing::info!("Creating session handle for face detection model");
         let handle = mnn_sync::SessionHandle::new(model, sc)
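Besides with_forward_type, the builder added above also accepts a full ScheduleConfig. A hypothetical sketch of that path (not exercised anywhere in this commit; `model_bytes` is a made-up placeholder, while the BackendConfig/ScheduleConfig calls are the same ones this file already uses):

    // Hypothetical sketch: build a FaceDetection with an explicit ScheduleConfig,
    // reusing the mnn calls already present in this file. `model_bytes` stands in
    // for the raw .mnn model contents.
    let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
    let sc = mnn::ScheduleConfig::new()
        .with_type(mnn::ForwardType::CPU)
        .with_backend_config(bc);
    let detection = FaceDetectionBuilder::new(model_bytes)?
        .with_schedule_config(sc)
        .build()?;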


@@ -2,11 +2,57 @@ use crate::errors::*;
 use mnn_bridge::ndarray::*;
 use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
 use std::path::Path;
+mod mnn_impl;
+mod ort_impl;
 #[derive(Debug)]
 pub struct EmbeddingGenerator {
     handle: mnn_sync::SessionHandle,
 }
+pub struct EmbeddingGeneratorBuilder {
+    schedule_config: Option<mnn::ScheduleConfig>,
+    backend_config: Option<mnn::BackendConfig>,
+    model: mnn::Interpreter,
+}
+impl EmbeddingGeneratorBuilder {
+    pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
+        Ok(Self {
+            schedule_config: None,
+            backend_config: None,
+            model: mnn::Interpreter::from_bytes(model.as_ref())
+                .map_err(|e| e.into_inner())
+                .change_context(Error)
+                .attach_printable("Failed to load model from bytes")?,
+        })
+    }
+    pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
+        self.schedule_config
+            .get_or_insert_default()
+            .set_type(forward_type);
+        self
+    }
+    pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
+        self.schedule_config = Some(config);
+        self
+    }
+    pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
+        self.backend_config = Some(config);
+        self
+    }
+    pub fn build(self) -> Result<EmbeddingGenerator> {
+        let model = self.model;
+        let sc = self.schedule_config.unwrap_or_default();
+        let handle = mnn_sync::SessionHandle::new(model, sc)
+            .change_context(Error)
+            .attach_printable("Failed to create session handle")?;
+        Ok(EmbeddingGenerator { handle })
+    }
+}
 impl EmbeddingGenerator {
     const INPUT_NAME: &'static str = "serving_default_input_6:0";
@@ -18,6 +64,11 @@ impl EmbeddingGenerator {
         Self::new_from_bytes(&model)
     }
+    pub fn builder<T: AsRef<[u8]>>()
+    -> fn(T) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
+        EmbeddingGeneratorBuilder::new
+    }
     pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
         tracing::info!("Loading face embedding model from bytes");
         let mut model = mnn::Interpreter::from_bytes(model)
@@ -57,16 +108,24 @@ impl EmbeddingGenerator {
         tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
         let (intptr, session) = sr.both_mut();
         tracing::trace!("Copying input tensor to host");
-        unsafe {
+        let needs_resize = unsafe {
             let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
             tracing::trace!("Input shape: {:?}", input.shape());
             if *input.shape() != shape {
                 tracing::trace!("Resizing input tensor to shape: {:?}", shape);
                 // input.resize(shape);
                 intptr.resize_tensor(input.view_mut(), shape);
+                true
+            } else {
+                false
             }
+        };
+        if needs_resize {
+            tracing::trace!("Resized input tensor to shape: {:?}", shape);
+            let now = std::time::Instant::now();
+            intptr.resize_session(session);
+            tracing::trace!("Session resized in {:?}", now.elapsed());
         }
-        intptr.resize_session(session);
         let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
         tracing::trace!("Input shape: {:?}", input.shape());
         input.copy_from_host_tensor(tensor.view())?;


@@ -0,0 +1 @@


@@ -0,0 +1,65 @@
use crate::errors::{Result, *};
use ndarray::*;
use ort::*;
use std::path::Path;
#[derive(Debug)]
pub struct EmbeddingGenerator {
handle: ort::session::Session,
}
// impl EmbeddingGeneratorBuilder {
// pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
// Ok(Self {
// schedule_config: None,
// backend_config: None,
// model: mnn::Interpreter::from_bytes(model.as_ref())
// .map_err(|e| e.into_inner())
// .change_context(Error)
// .attach_printable("Failed to load model from bytes")?,
// })
// }
//
// pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
// self.schedule_config
// .get_or_insert_default()
// .set_type(forward_type);
// self
// }
//
// pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
// self.schedule_config = Some(config);
// self
// }
//
// pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
// self.backend_config = Some(config);
// self
// }
//
// pub fn build(self) -> Result<EmbeddingGenerator> {
// let model = self.model;
// let sc = self.schedule_config.unwrap_or_default();
// let handle = mnn_sync::SessionHandle::new(model, sc)
// .change_context(Error)
// .attach_printable("Failed to create session handle")?;
// Ok(EmbeddingGenerator { handle })
// }
// }
impl EmbeddingGenerator {
const INPUT_NAME: &'static str = "serving_default_input_6:0";
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
.attach_printable("Failed to read model file")?;
Self::new_from_bytes(&model)
}
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result<Self> {
todo!()
}
// pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {}
}
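The ort-backed generator is only stubbed in this commit (`new_from_bytes` is a `todo!()`). As a rough sketch of the loading step it might eventually perform, assuming ort 2.0.0-rc's session-builder API (`Session::builder` / `commit_from_memory`; these are assumptions about the ort crate, not something this commit uses):

    // Hypothetical sketch only; assumes ort 2.0.0-rc exposes Session::builder()
    // and SessionBuilder::commit_from_memory(&[u8]). Error handling mirrors the
    // error-stack pattern used in the mnn implementation above.
    pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result<Self> {
        let handle = ort::session::Session::builder()
            .and_then(|builder| builder.commit_from_memory(model.as_ref()))
            .change_context(Error)
            .attach_printable("Failed to build ort session from model bytes")?;
        Ok(Self { handle })
    }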


@@ -4,10 +4,12 @@ use bounding_box::roi::MultiRoi;
 use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
 use errors::*;
 use fast_image_resize::ResizeOptions;
-use nalgebra::zero;
+use ndarray::*;
 use ndarray_image::*;
+use ndarray_resize::NdFir;
 const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
 const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
+const CHUNK_SIZE: usize = 8;
 pub fn main() -> Result<()> {
     tracing_subscriber::fmt()
         .with_env_filter("trace")
@@ -19,10 +21,16 @@ pub fn main() -> Result<()> {
     match args.cmd {
         cli::SubCommand::Detect(detect) => {
             use detector::facedet;
-            let retinaface = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
+            let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL)
+                .change_context(Error)?
+                .with_forward_type(detect.forward_type)
+                .build()
                 .change_context(errors::Error)
                 .attach_printable("Failed to create face detection model")?;
-            let facenet = faceembed::facenet::EmbeddingGenerator::new_from_bytes(FACENET_MODEL)
+            let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL)
+                .change_context(Error)?
+                .with_forward_type(detect.forward_type)
+                .build()
                 .change_context(errors::Error)
                 .attach_printable("Failed to create face embedding model")?;
             let image = image::open(detect.image).change_context(Error)?;
@@ -45,8 +53,6 @@ pub fn main() -> Result<()> {
                 use bounding_box::draw::*;
                 array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
             }
-            use ndarray::{Array2, Array3, Array4, Axis};
-            use ndarray_resize::NdFir;
             let face_rois = array
                 .view()
                 .multi_roi(&output.bbox)
@@ -68,21 +74,19 @@ pub fn main() -> Result<()> {
                 .collect::<Result<Vec<_>>>()?;
             let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
+            let chunk_size = CHUNK_SIZE;
             let embeddings = face_roi_views
-                .chunks(8)
+                .chunks(chunk_size)
                 .map(|chunk| {
                     tracing::info!("Processing chunk of size: {}", chunk.len());
                     if chunk.len() < 8 {
                         tracing::warn!("Chunk size is less than 8, padding with zeros");
                         let zeros = Array3::zeros((512, 512, 3));
-                        let padded: Vec<ndarray::ArrayView3<'_, u8>> = chunk
-                            .iter()
-                            .cloned()
-                            .chain(core::iter::repeat(zeros.view()))
-                            .take(8)
-                            .collect();
-                        let face_rois: Array4<u8> = ndarray::stack(Axis(0), padded.as_slice())
+                        let zero_array = core::iter::repeat(zeros.view())
+                            .take(chunk_size)
+                            .collect::<Vec<_>>();
+                        let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
                            .change_context(errors::Error)
                            .attach_printable("Failed to stack rois together")?;
                        let output = facenet.run_models(face_rois.view()).change_context(Error)?;
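One thing worth noting in the rewritten padding branch above: as far as this hunk shows, `zero_array` holds only repeated zero frames, so for a short final chunk the stacked `face_rois` would not include the real face crops that the removed `.iter().cloned().chain(...)` version kept. A sketch that preserves the old behaviour while using the new `chunk_size` (names as in this file; not part of the commit):

    // Sketch (not in this commit): pad the real chunk with zero frames up to
    // chunk_size instead of replacing it with zeros entirely.
    let zeros = Array3::<u8>::zeros((512, 512, 3));
    let padded: Vec<ndarray::ArrayView3<'_, u8>> = chunk
        .iter()
        .cloned()
        .chain(core::iter::repeat(zeros.view()))
        .take(chunk_size)
        .collect();
    let face_rois: Array4<u8> = ndarray::stack(Axis(0), padded.as_slice())
        .change_context(errors::Error)
        .attach_printable("Failed to stack rois together")?;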