feat: Added stuff
This commit is contained in:
919
Cargo.lock
generated
919
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,7 @@ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0",
|
||||
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
||||
"tracing",
|
||||
], branch = "restructure-tensor-type" }
|
||||
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
|
||||
|
||||
[package]
|
||||
name = "detector"
|
||||
@@ -35,7 +36,7 @@ clap_complete = "4.5"
|
||||
error-stack = "0.5"
|
||||
fast_image_resize = "5.2.0"
|
||||
image = "0.25.6"
|
||||
nalgebra = "0.33.2"
|
||||
nalgebra = { workspace = true }
|
||||
ndarray = "0.16.1"
|
||||
ndarray-image = { workspace = true }
|
||||
ndarray-resize = { workspace = true }
|
||||
@@ -52,6 +53,7 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
|
||||
color = "0.3.1"
|
||||
itertools = "0.14.0"
|
||||
ordered-float = "5.0.0"
|
||||
ort = "2.0.0-rc.10"
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
||||
27
Makefile.toml
Normal file
27
Makefile.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
[tasks.convert_facenet]
|
||||
command = "MNNConvert"
|
||||
args = [
|
||||
"-f",
|
||||
"ONNX",
|
||||
"--modelFile",
|
||||
"models/facenet.onnx",
|
||||
"--MNNModel",
|
||||
"models/facenet.mnn",
|
||||
"--fp16",
|
||||
"--bizCode",
|
||||
"MNN",
|
||||
]
|
||||
|
||||
[tasks.convert_retinaface]
|
||||
command = "MNNConvert"
|
||||
args = [
|
||||
"-f",
|
||||
"ONNX",
|
||||
"--modelFile",
|
||||
"models/retinaface.onnx",
|
||||
"--MNNModel",
|
||||
"models/retinaface.mnn",
|
||||
"--fp16",
|
||||
"--bizCode",
|
||||
"MNN",
|
||||
]
|
||||
@@ -6,7 +6,7 @@ edition = "2024"
|
||||
[dependencies]
|
||||
color = "0.3.1"
|
||||
itertools = "0.14.0"
|
||||
nalgebra = "0.33.2"
|
||||
nalgebra = { workspace = true }
|
||||
ndarray = { version = "0.16.1", optional = true }
|
||||
num = "0.4.3"
|
||||
ordered-float = "5.0.0"
|
||||
|
||||
BIN
facenet.mnn
Normal file
BIN
facenet.mnn
Normal file
Binary file not shown.
90
flake.nix
90
flake.nix
@@ -27,7 +27,8 @@
|
||||
};
|
||||
};
|
||||
|
||||
outputs = {
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
crane,
|
||||
flake-utils,
|
||||
@@ -40,7 +41,8 @@
|
||||
...
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system: let
|
||||
system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
overlays = [
|
||||
@@ -61,24 +63,39 @@
|
||||
|
||||
stableToolchain = pkgs.rust-bin.stable.latest.default;
|
||||
stableToolchainWithLLvmTools = stableToolchain.override {
|
||||
extensions = ["rust-src" "llvm-tools"];
|
||||
extensions = [
|
||||
"rust-src"
|
||||
"llvm-tools"
|
||||
];
|
||||
};
|
||||
stableToolchainWithRustAnalyzer = stableToolchain.override {
|
||||
extensions = ["rust-src" "rust-analyzer"];
|
||||
extensions = [
|
||||
"rust-src"
|
||||
"rust-analyzer"
|
||||
];
|
||||
};
|
||||
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
||||
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
||||
|
||||
src = let
|
||||
src =
|
||||
let
|
||||
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
|
||||
sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
|
||||
sourceFilters =
|
||||
path: type:
|
||||
(craneLib.filterCargoSources path type)
|
||||
|| filterBySuffix path [
|
||||
".c"
|
||||
".h"
|
||||
".hpp"
|
||||
".cpp"
|
||||
".cc"
|
||||
];
|
||||
in
|
||||
lib.cleanSourceWith {
|
||||
filter = sourceFilters;
|
||||
src = ./.;
|
||||
};
|
||||
commonArgs =
|
||||
{
|
||||
commonArgs = {
|
||||
inherit src;
|
||||
pname = name;
|
||||
stdenv = pkgs.clangStdenv;
|
||||
@@ -88,8 +105,9 @@
|
||||
# cmake
|
||||
# llvmPackages.libclang.lib
|
||||
# ];
|
||||
buildInputs = with pkgs;
|
||||
[]
|
||||
buildInputs =
|
||||
with pkgs;
|
||||
[ ]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
libiconv
|
||||
apple-sdk_13
|
||||
@@ -99,18 +117,20 @@
|
||||
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
|
||||
});
|
||||
cargoArtifacts = craneLib.buildPackage commonArgs;
|
||||
in {
|
||||
checks =
|
||||
in
|
||||
{
|
||||
"${name}-clippy" = craneLib.cargoClippy (commonArgs
|
||||
checks = {
|
||||
"${name}-clippy" = craneLib.cargoClippy (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||
});
|
||||
"${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
|
||||
"${name}-fmt" = craneLib.cargoFmt {inherit src;};
|
||||
}
|
||||
);
|
||||
"${name}-docs" = craneLib.cargoDoc (commonArgs // { inherit cargoArtifacts; });
|
||||
"${name}-fmt" = craneLib.cargoFmt { inherit src; };
|
||||
"${name}-toml-fmt" = craneLib.taploFmt {
|
||||
src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"];
|
||||
src = pkgs.lib.sources.sourceFilesBySuffices src [ ".toml" ];
|
||||
};
|
||||
# Audit dependencies
|
||||
"${name}-audit" = craneLib.cargoAudit {
|
||||
@@ -121,20 +141,26 @@
|
||||
"${name}-deny" = craneLib.cargoDeny {
|
||||
inherit src;
|
||||
};
|
||||
"${name}-nextest" = craneLib.cargoNextest (commonArgs
|
||||
"${name}-nextest" = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
partitions = 1;
|
||||
partitionType = "count";
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
|
||||
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
|
||||
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // { inherit cargoArtifacts; });
|
||||
};
|
||||
|
||||
packages = let
|
||||
pkg = craneLib.buildPackage (commonArgs
|
||||
// {inherit cargoArtifacts;}
|
||||
packages =
|
||||
let
|
||||
pkg = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
}
|
||||
// {
|
||||
nativeBuildInputs = with pkgs; [
|
||||
installShellFiles
|
||||
@@ -145,34 +171,40 @@
|
||||
--fish <($out/bin/${name} completions fish) \
|
||||
--zsh <($out/bin/${name} completions zsh)
|
||||
'';
|
||||
});
|
||||
in {
|
||||
}
|
||||
);
|
||||
in
|
||||
{
|
||||
"${name}" = pkg;
|
||||
default = pkg;
|
||||
};
|
||||
|
||||
devShells = {
|
||||
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
|
||||
default = pkgs.mkShell.override { stdenv = pkgs.clangStdenv; } (
|
||||
commonArgs
|
||||
// {
|
||||
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
||||
packages = with pkgs;
|
||||
packages =
|
||||
with pkgs;
|
||||
[
|
||||
stableToolchainWithRustAnalyzer
|
||||
cargo-nextest
|
||||
cargo-deny
|
||||
cmake
|
||||
mnn
|
||||
cargo-make
|
||||
]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
apple-sdk_13
|
||||
]);
|
||||
});
|
||||
}
|
||||
);
|
||||
};
|
||||
}
|
||||
)
|
||||
// {
|
||||
githubActions = nix-github-actions.lib.mkGithubMatrix {
|
||||
checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks;
|
||||
checks = nixpkgs.lib.getAttrs [ "x86_64-linux" ] self.checks;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -48,6 +48,8 @@ pub struct Detect {
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
#[clap(short, long, default_value = "cpu")]
|
||||
pub forward_type: mnn::ForwardType,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
|
||||
@@ -246,7 +246,56 @@ impl FaceDetectionModelOutput {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FaceDetectionBuilder {
|
||||
schedule_config: Option<mnn::ScheduleConfig>,
|
||||
backend_config: Option<mnn::BackendConfig>,
|
||||
model: mnn::Interpreter,
|
||||
}
|
||||
|
||||
impl FaceDetectionBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
schedule_config: None,
|
||||
backend_config: None,
|
||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
self.schedule_config
|
||||
.get_or_insert_default()
|
||||
.set_type(forward_type);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
self.schedule_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
self.backend_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<FaceDetection> {
|
||||
let model = self.model;
|
||||
let sc = self.schedule_config.unwrap_or_default();
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn builder<T: AsRef<[u8]>>()
|
||||
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
|
||||
FaceDetectionBuilder::new
|
||||
}
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
@@ -267,7 +316,7 @@ impl FaceDetection {
|
||||
.attach_printable("Failed to set cache file")?;
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::CPU)
|
||||
.with_type(mnn::ForwardType::Metal)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face detection model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
|
||||
@@ -2,11 +2,57 @@ use crate::errors::*;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||
use std::path::Path;
|
||||
mod mnn_impl;
|
||||
mod ort_impl;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingGenerator {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
pub struct EmbeddingGeneratorBuilder {
|
||||
schedule_config: Option<mnn::ScheduleConfig>,
|
||||
backend_config: Option<mnn::BackendConfig>,
|
||||
model: mnn::Interpreter,
|
||||
}
|
||||
|
||||
impl EmbeddingGeneratorBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
schedule_config: None,
|
||||
backend_config: None,
|
||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
self.schedule_config
|
||||
.get_or_insert_default()
|
||||
.set_type(forward_type);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
self.schedule_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
self.backend_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<EmbeddingGenerator> {
|
||||
let model = self.model;
|
||||
let sc = self.schedule_config.unwrap_or_default();
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(EmbeddingGenerator { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingGenerator {
|
||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||
@@ -18,6 +64,11 @@ impl EmbeddingGenerator {
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn builder<T: AsRef<[u8]>>()
|
||||
-> fn(T) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
|
||||
EmbeddingGeneratorBuilder::new
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||
tracing::info!("Loading face embedding model from bytes");
|
||||
let mut model = mnn::Interpreter::from_bytes(model)
|
||||
@@ -57,16 +108,24 @@ impl EmbeddingGenerator {
|
||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
unsafe {
|
||||
let needs_resize = unsafe {
|
||||
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
if *input.shape() != shape {
|
||||
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
|
||||
// input.resize(shape);
|
||||
intptr.resize_tensor(input.view_mut(), shape);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
};
|
||||
if needs_resize {
|
||||
tracing::trace!("Resized input tensor to shape: {:?}", shape);
|
||||
let now = std::time::Instant::now();
|
||||
intptr.resize_session(session);
|
||||
tracing::trace!("Session resized in {:?}", now.elapsed());
|
||||
}
|
||||
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
input.copy_from_host_tensor(tensor.view())?;
|
||||
|
||||
1
src/faceembed/facenet/mnn_impl.rs
Normal file
1
src/faceembed/facenet/mnn_impl.rs
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
65
src/faceembed/facenet/ort_impl.rs
Normal file
65
src/faceembed/facenet/ort_impl.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use crate::errors::{Result, *};
|
||||
use ndarray::*;
|
||||
use ort::*;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingGenerator {
|
||||
handle: ort::session::Session,
|
||||
}
|
||||
|
||||
// impl EmbeddingGeneratorBuilder {
|
||||
// pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
// Ok(Self {
|
||||
// schedule_config: None,
|
||||
// backend_config: None,
|
||||
// model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
// .map_err(|e| e.into_inner())
|
||||
// .change_context(Error)
|
||||
// .attach_printable("Failed to load model from bytes")?,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
// self.schedule_config
|
||||
// .get_or_insert_default()
|
||||
// .set_type(forward_type);
|
||||
// self
|
||||
// }
|
||||
//
|
||||
// pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
// self.schedule_config = Some(config);
|
||||
// self
|
||||
// }
|
||||
//
|
||||
// pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
// self.backend_config = Some(config);
|
||||
// self
|
||||
// }
|
||||
//
|
||||
// pub fn build(self) -> Result<EmbeddingGenerator> {
|
||||
// let model = self.model;
|
||||
// let sc = self.schedule_config.unwrap_or_default();
|
||||
// let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
// .change_context(Error)
|
||||
// .attach_printable("Failed to create session handle")?;
|
||||
// Ok(EmbeddingGenerator { handle })
|
||||
// }
|
||||
// }
|
||||
|
||||
impl EmbeddingGenerator {
|
||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {}
|
||||
}
|
||||
30
src/main.rs
30
src/main.rs
@@ -4,10 +4,12 @@ use bounding_box::roi::MultiRoi;
|
||||
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
||||
use errors::*;
|
||||
use fast_image_resize::ResizeOptions;
|
||||
use nalgebra::zero;
|
||||
use ndarray::*;
|
||||
use ndarray_image::*;
|
||||
use ndarray_resize::NdFir;
|
||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const CHUNK_SIZE: usize = 8;
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
@@ -19,10 +21,16 @@ pub fn main() -> Result<()> {
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
use detector::facedet;
|
||||
let retinaface = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
||||
let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet = faceembed::facenet::EmbeddingGenerator::new_from_bytes(FACENET_MODEL)
|
||||
let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(detect.forward_type)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
let image = image::open(detect.image).change_context(Error)?;
|
||||
@@ -45,8 +53,6 @@ pub fn main() -> Result<()> {
|
||||
use bounding_box::draw::*;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||
}
|
||||
use ndarray::{Array2, Array3, Array4, Axis};
|
||||
use ndarray_resize::NdFir;
|
||||
let face_rois = array
|
||||
.view()
|
||||
.multi_roi(&output.bbox)
|
||||
@@ -68,21 +74,19 @@ pub fn main() -> Result<()> {
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
|
||||
let chunk_size = CHUNK_SIZE;
|
||||
let embeddings = face_roi_views
|
||||
.chunks(8)
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||
|
||||
if chunk.len() < 8 {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((512, 512, 3));
|
||||
let padded: Vec<ndarray::ArrayView3<'_, u8>> = chunk
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(8)
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), padded.as_slice())
|
||||
let zero_array = core::iter::repeat(zeros.view())
|
||||
.take(chunk_size)
|
||||
.collect::<Vec<_>>();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
|
||||
Reference in New Issue
Block a user