feat: Added stuff
This commit is contained in:
919
Cargo.lock
generated
919
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,7 @@ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0",
|
|||||||
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
|
||||||
"tracing",
|
"tracing",
|
||||||
], branch = "restructure-tensor-type" }
|
], branch = "restructure-tensor-type" }
|
||||||
|
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
|
||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "detector"
|
name = "detector"
|
||||||
@@ -35,7 +36,7 @@ clap_complete = "4.5"
|
|||||||
error-stack = "0.5"
|
error-stack = "0.5"
|
||||||
fast_image_resize = "5.2.0"
|
fast_image_resize = "5.2.0"
|
||||||
image = "0.25.6"
|
image = "0.25.6"
|
||||||
nalgebra = "0.33.2"
|
nalgebra = { workspace = true }
|
||||||
ndarray = "0.16.1"
|
ndarray = "0.16.1"
|
||||||
ndarray-image = { workspace = true }
|
ndarray-image = { workspace = true }
|
||||||
ndarray-resize = { workspace = true }
|
ndarray-resize = { workspace = true }
|
||||||
@@ -52,6 +53,7 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
|
|||||||
color = "0.3.1"
|
color = "0.3.1"
|
||||||
itertools = "0.14.0"
|
itertools = "0.14.0"
|
||||||
ordered-float = "5.0.0"
|
ordered-float = "5.0.0"
|
||||||
|
ort = "2.0.0-rc.10"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = true
|
debug = true
|
||||||
|
|||||||
27
Makefile.toml
Normal file
27
Makefile.toml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
[tasks.convert_facenet]
|
||||||
|
command = "MNNConvert"
|
||||||
|
args = [
|
||||||
|
"-f",
|
||||||
|
"ONNX",
|
||||||
|
"--modelFile",
|
||||||
|
"models/facenet.onnx",
|
||||||
|
"--MNNModel",
|
||||||
|
"models/facenet.mnn",
|
||||||
|
"--fp16",
|
||||||
|
"--bizCode",
|
||||||
|
"MNN",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tasks.convert_retinaface]
|
||||||
|
command = "MNNConvert"
|
||||||
|
args = [
|
||||||
|
"-f",
|
||||||
|
"ONNX",
|
||||||
|
"--modelFile",
|
||||||
|
"models/retinaface.onnx",
|
||||||
|
"--MNNModel",
|
||||||
|
"models/retinaface.mnn",
|
||||||
|
"--fp16",
|
||||||
|
"--bizCode",
|
||||||
|
"MNN",
|
||||||
|
]
|
||||||
@@ -6,7 +6,7 @@ edition = "2024"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
color = "0.3.1"
|
color = "0.3.1"
|
||||||
itertools = "0.14.0"
|
itertools = "0.14.0"
|
||||||
nalgebra = "0.33.2"
|
nalgebra = { workspace = true }
|
||||||
ndarray = { version = "0.16.1", optional = true }
|
ndarray = { version = "0.16.1", optional = true }
|
||||||
num = "0.4.3"
|
num = "0.4.3"
|
||||||
ordered-float = "5.0.0"
|
ordered-float = "5.0.0"
|
||||||
|
|||||||
BIN
facenet.mnn
Normal file
BIN
facenet.mnn
Normal file
Binary file not shown.
214
flake.nix
214
flake.nix
@@ -27,20 +27,22 @@
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
outputs = {
|
outputs =
|
||||||
self,
|
{
|
||||||
crane,
|
self,
|
||||||
flake-utils,
|
crane,
|
||||||
nixpkgs,
|
flake-utils,
|
||||||
rust-overlay,
|
nixpkgs,
|
||||||
advisory-db,
|
rust-overlay,
|
||||||
nix-github-actions,
|
advisory-db,
|
||||||
mnn-overlay,
|
nix-github-actions,
|
||||||
mnn-src,
|
mnn-overlay,
|
||||||
...
|
mnn-src,
|
||||||
}:
|
...
|
||||||
|
}:
|
||||||
flake-utils.lib.eachDefaultSystem (
|
flake-utils.lib.eachDefaultSystem (
|
||||||
system: let
|
system:
|
||||||
|
let
|
||||||
pkgs = import nixpkgs {
|
pkgs = import nixpkgs {
|
||||||
inherit system;
|
inherit system;
|
||||||
overlays = [
|
overlays = [
|
||||||
@@ -61,118 +63,148 @@
|
|||||||
|
|
||||||
stableToolchain = pkgs.rust-bin.stable.latest.default;
|
stableToolchain = pkgs.rust-bin.stable.latest.default;
|
||||||
stableToolchainWithLLvmTools = stableToolchain.override {
|
stableToolchainWithLLvmTools = stableToolchain.override {
|
||||||
extensions = ["rust-src" "llvm-tools"];
|
extensions = [
|
||||||
|
"rust-src"
|
||||||
|
"llvm-tools"
|
||||||
|
];
|
||||||
};
|
};
|
||||||
stableToolchainWithRustAnalyzer = stableToolchain.override {
|
stableToolchainWithRustAnalyzer = stableToolchain.override {
|
||||||
extensions = ["rust-src" "rust-analyzer"];
|
extensions = [
|
||||||
|
"rust-src"
|
||||||
|
"rust-analyzer"
|
||||||
|
];
|
||||||
};
|
};
|
||||||
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
||||||
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
||||||
|
|
||||||
src = let
|
src =
|
||||||
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
|
let
|
||||||
sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
|
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
|
||||||
in
|
sourceFilters =
|
||||||
|
path: type:
|
||||||
|
(craneLib.filterCargoSources path type)
|
||||||
|
|| filterBySuffix path [
|
||||||
|
".c"
|
||||||
|
".h"
|
||||||
|
".hpp"
|
||||||
|
".cpp"
|
||||||
|
".cc"
|
||||||
|
];
|
||||||
|
in
|
||||||
lib.cleanSourceWith {
|
lib.cleanSourceWith {
|
||||||
filter = sourceFilters;
|
filter = sourceFilters;
|
||||||
src = ./.;
|
src = ./.;
|
||||||
};
|
};
|
||||||
commonArgs =
|
commonArgs = {
|
||||||
{
|
inherit src;
|
||||||
inherit src;
|
pname = name;
|
||||||
pname = name;
|
stdenv = pkgs.clangStdenv;
|
||||||
stdenv = pkgs.clangStdenv;
|
doCheck = false;
|
||||||
doCheck = false;
|
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
# nativeBuildInputs = with pkgs; [
|
||||||
# nativeBuildInputs = with pkgs; [
|
# cmake
|
||||||
# cmake
|
# llvmPackages.libclang.lib
|
||||||
# llvmPackages.libclang.lib
|
# ];
|
||||||
# ];
|
buildInputs =
|
||||||
buildInputs = with pkgs;
|
with pkgs;
|
||||||
[]
|
[ ]
|
||||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||||
libiconv
|
libiconv
|
||||||
apple-sdk_13
|
apple-sdk_13
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
// (lib.optionalAttrs pkgs.stdenv.isLinux {
|
// (lib.optionalAttrs pkgs.stdenv.isLinux {
|
||||||
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
|
# BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include";
|
||||||
});
|
});
|
||||||
cargoArtifacts = craneLib.buildPackage commonArgs;
|
cargoArtifacts = craneLib.buildPackage commonArgs;
|
||||||
in {
|
in
|
||||||
checks =
|
{
|
||||||
{
|
checks = {
|
||||||
"${name}-clippy" = craneLib.cargoClippy (commonArgs
|
"${name}-clippy" = craneLib.cargoClippy (
|
||||||
// {
|
commonArgs
|
||||||
inherit cargoArtifacts;
|
// {
|
||||||
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
inherit cargoArtifacts;
|
||||||
});
|
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||||
"${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
|
}
|
||||||
"${name}-fmt" = craneLib.cargoFmt {inherit src;};
|
);
|
||||||
"${name}-toml-fmt" = craneLib.taploFmt {
|
"${name}-docs" = craneLib.cargoDoc (commonArgs // { inherit cargoArtifacts; });
|
||||||
src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"];
|
"${name}-fmt" = craneLib.cargoFmt { inherit src; };
|
||||||
};
|
"${name}-toml-fmt" = craneLib.taploFmt {
|
||||||
# Audit dependencies
|
src = pkgs.lib.sources.sourceFilesBySuffices src [ ".toml" ];
|
||||||
"${name}-audit" = craneLib.cargoAudit {
|
};
|
||||||
inherit src advisory-db;
|
# Audit dependencies
|
||||||
};
|
"${name}-audit" = craneLib.cargoAudit {
|
||||||
|
inherit src advisory-db;
|
||||||
# Audit licenses
|
|
||||||
"${name}-deny" = craneLib.cargoDeny {
|
|
||||||
inherit src;
|
|
||||||
};
|
|
||||||
"${name}-nextest" = craneLib.cargoNextest (commonArgs
|
|
||||||
// {
|
|
||||||
inherit cargoArtifacts;
|
|
||||||
partitions = 1;
|
|
||||||
partitionType = "count";
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
|
|
||||||
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
packages = let
|
# Audit licenses
|
||||||
pkg = craneLib.buildPackage (commonArgs
|
"${name}-deny" = craneLib.cargoDeny {
|
||||||
// {inherit cargoArtifacts;}
|
inherit src;
|
||||||
|
};
|
||||||
|
"${name}-nextest" = craneLib.cargoNextest (
|
||||||
|
commonArgs
|
||||||
// {
|
// {
|
||||||
nativeBuildInputs = with pkgs; [
|
inherit cargoArtifacts;
|
||||||
installShellFiles
|
partitions = 1;
|
||||||
];
|
partitionType = "count";
|
||||||
postInstall = ''
|
}
|
||||||
installShellCompletion --cmd ${name} \
|
);
|
||||||
--bash <($out/bin/${name} completions bash) \
|
}
|
||||||
--fish <($out/bin/${name} completions fish) \
|
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
|
||||||
--zsh <($out/bin/${name} completions zsh)
|
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // { inherit cargoArtifacts; });
|
||||||
'';
|
|
||||||
});
|
|
||||||
in {
|
|
||||||
"${name}" = pkg;
|
|
||||||
default = pkg;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
packages =
|
||||||
|
let
|
||||||
|
pkg = craneLib.buildPackage (
|
||||||
|
commonArgs
|
||||||
|
// {
|
||||||
|
inherit cargoArtifacts;
|
||||||
|
}
|
||||||
|
// {
|
||||||
|
nativeBuildInputs = with pkgs; [
|
||||||
|
installShellFiles
|
||||||
|
];
|
||||||
|
postInstall = ''
|
||||||
|
installShellCompletion --cmd ${name} \
|
||||||
|
--bash <($out/bin/${name} completions bash) \
|
||||||
|
--fish <($out/bin/${name} completions fish) \
|
||||||
|
--zsh <($out/bin/${name} completions zsh)
|
||||||
|
'';
|
||||||
|
}
|
||||||
|
);
|
||||||
|
in
|
||||||
|
{
|
||||||
|
"${name}" = pkg;
|
||||||
|
default = pkg;
|
||||||
|
};
|
||||||
|
|
||||||
devShells = {
|
devShells = {
|
||||||
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
|
default = pkgs.mkShell.override { stdenv = pkgs.clangStdenv; } (
|
||||||
|
commonArgs
|
||||||
// {
|
// {
|
||||||
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
||||||
packages = with pkgs;
|
packages =
|
||||||
|
with pkgs;
|
||||||
[
|
[
|
||||||
stableToolchainWithRustAnalyzer
|
stableToolchainWithRustAnalyzer
|
||||||
cargo-nextest
|
cargo-nextest
|
||||||
cargo-deny
|
cargo-deny
|
||||||
cmake
|
cmake
|
||||||
mnn
|
mnn
|
||||||
|
cargo-make
|
||||||
]
|
]
|
||||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||||
apple-sdk_13
|
apple-sdk_13
|
||||||
]);
|
]);
|
||||||
});
|
}
|
||||||
|
);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
// {
|
// {
|
||||||
githubActions = nix-github-actions.lib.mkGithubMatrix {
|
githubActions = nix-github-actions.lib.mkGithubMatrix {
|
||||||
checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks;
|
checks = nixpkgs.lib.getAttrs [ "x86_64-linux" ] self.checks;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ pub struct Detect {
|
|||||||
pub model_type: Models,
|
pub model_type: Models,
|
||||||
#[clap(short, long)]
|
#[clap(short, long)]
|
||||||
pub output: Option<PathBuf>,
|
pub output: Option<PathBuf>,
|
||||||
|
#[clap(short, long, default_value = "cpu")]
|
||||||
|
pub forward_type: mnn::ForwardType,
|
||||||
#[clap(short, long, default_value_t = 0.8)]
|
#[clap(short, long, default_value_t = 0.8)]
|
||||||
pub threshold: f32,
|
pub threshold: f32,
|
||||||
#[clap(short, long, default_value_t = 0.3)]
|
#[clap(short, long, default_value_t = 0.3)]
|
||||||
|
|||||||
@@ -246,7 +246,56 @@ impl FaceDetectionModelOutput {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct FaceDetectionBuilder {
|
||||||
|
schedule_config: Option<mnn::ScheduleConfig>,
|
||||||
|
backend_config: Option<mnn::BackendConfig>,
|
||||||
|
model: mnn::Interpreter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDetectionBuilder {
|
||||||
|
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
schedule_config: None,
|
||||||
|
backend_config: None,
|
||||||
|
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||||
|
.map_err(|e| e.into_inner())
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to load model from bytes")?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||||
|
self.schedule_config
|
||||||
|
.get_or_insert_default()
|
||||||
|
.set_type(forward_type);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||||
|
self.schedule_config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||||
|
self.backend_config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> Result<FaceDetection> {
|
||||||
|
let model = self.model;
|
||||||
|
let sc = self.schedule_config.unwrap_or_default();
|
||||||
|
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create session handle")?;
|
||||||
|
Ok(FaceDetection { handle })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl FaceDetection {
|
impl FaceDetection {
|
||||||
|
pub fn builder<T: AsRef<[u8]>>()
|
||||||
|
-> fn(T) -> std::result::Result<FaceDetectionBuilder, Report<Error>> {
|
||||||
|
FaceDetectionBuilder::new
|
||||||
|
}
|
||||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||||
let model = std::fs::read(path)
|
let model = std::fs::read(path)
|
||||||
.change_context(Error)
|
.change_context(Error)
|
||||||
@@ -267,7 +316,7 @@ impl FaceDetection {
|
|||||||
.attach_printable("Failed to set cache file")?;
|
.attach_printable("Failed to set cache file")?;
|
||||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||||
let sc = mnn::ScheduleConfig::new()
|
let sc = mnn::ScheduleConfig::new()
|
||||||
.with_type(mnn::ForwardType::CPU)
|
.with_type(mnn::ForwardType::Metal)
|
||||||
.with_backend_config(bc);
|
.with_backend_config(bc);
|
||||||
tracing::info!("Creating session handle for face detection model");
|
tracing::info!("Creating session handle for face detection model");
|
||||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||||
|
|||||||
@@ -2,11 +2,57 @@ use crate::errors::*;
|
|||||||
use mnn_bridge::ndarray::*;
|
use mnn_bridge::ndarray::*;
|
||||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
mod mnn_impl;
|
||||||
|
mod ort_impl;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct EmbeddingGenerator {
|
pub struct EmbeddingGenerator {
|
||||||
handle: mnn_sync::SessionHandle,
|
handle: mnn_sync::SessionHandle,
|
||||||
}
|
}
|
||||||
|
pub struct EmbeddingGeneratorBuilder {
|
||||||
|
schedule_config: Option<mnn::ScheduleConfig>,
|
||||||
|
backend_config: Option<mnn::BackendConfig>,
|
||||||
|
model: mnn::Interpreter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingGeneratorBuilder {
|
||||||
|
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
schedule_config: None,
|
||||||
|
backend_config: None,
|
||||||
|
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||||
|
.map_err(|e| e.into_inner())
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to load model from bytes")?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||||
|
self.schedule_config
|
||||||
|
.get_or_insert_default()
|
||||||
|
.set_type(forward_type);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||||
|
self.schedule_config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||||
|
self.backend_config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> Result<EmbeddingGenerator> {
|
||||||
|
let model = self.model;
|
||||||
|
let sc = self.schedule_config.unwrap_or_default();
|
||||||
|
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create session handle")?;
|
||||||
|
Ok(EmbeddingGenerator { handle })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl EmbeddingGenerator {
|
impl EmbeddingGenerator {
|
||||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||||
@@ -18,6 +64,11 @@ impl EmbeddingGenerator {
|
|||||||
Self::new_from_bytes(&model)
|
Self::new_from_bytes(&model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn builder<T: AsRef<[u8]>>()
|
||||||
|
-> fn(T) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
|
||||||
|
EmbeddingGeneratorBuilder::new
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||||
tracing::info!("Loading face embedding model from bytes");
|
tracing::info!("Loading face embedding model from bytes");
|
||||||
let mut model = mnn::Interpreter::from_bytes(model)
|
let mut model = mnn::Interpreter::from_bytes(model)
|
||||||
@@ -57,16 +108,24 @@ impl EmbeddingGenerator {
|
|||||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||||
let (intptr, session) = sr.both_mut();
|
let (intptr, session) = sr.both_mut();
|
||||||
tracing::trace!("Copying input tensor to host");
|
tracing::trace!("Copying input tensor to host");
|
||||||
unsafe {
|
let needs_resize = unsafe {
|
||||||
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
|
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
|
||||||
tracing::trace!("Input shape: {:?}", input.shape());
|
tracing::trace!("Input shape: {:?}", input.shape());
|
||||||
if *input.shape() != shape {
|
if *input.shape() != shape {
|
||||||
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
|
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
|
||||||
// input.resize(shape);
|
// input.resize(shape);
|
||||||
intptr.resize_tensor(input.view_mut(), shape);
|
intptr.resize_tensor(input.view_mut(), shape);
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
if needs_resize {
|
||||||
|
tracing::trace!("Resized input tensor to shape: {:?}", shape);
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
intptr.resize_session(session);
|
||||||
|
tracing::trace!("Session resized in {:?}", now.elapsed());
|
||||||
}
|
}
|
||||||
intptr.resize_session(session);
|
|
||||||
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
|
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
|
||||||
tracing::trace!("Input shape: {:?}", input.shape());
|
tracing::trace!("Input shape: {:?}", input.shape());
|
||||||
input.copy_from_host_tensor(tensor.view())?;
|
input.copy_from_host_tensor(tensor.view())?;
|
||||||
|
|||||||
1
src/faceembed/facenet/mnn_impl.rs
Normal file
1
src/faceembed/facenet/mnn_impl.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
65
src/faceembed/facenet/ort_impl.rs
Normal file
65
src/faceembed/facenet/ort_impl.rs
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
use crate::errors::{Result, *};
|
||||||
|
use ndarray::*;
|
||||||
|
use ort::*;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EmbeddingGenerator {
|
||||||
|
handle: ort::session::Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
// impl EmbeddingGeneratorBuilder {
|
||||||
|
// pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||||
|
// Ok(Self {
|
||||||
|
// schedule_config: None,
|
||||||
|
// backend_config: None,
|
||||||
|
// model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||||
|
// .map_err(|e| e.into_inner())
|
||||||
|
// .change_context(Error)
|
||||||
|
// .attach_printable("Failed to load model from bytes")?,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||||
|
// self.schedule_config
|
||||||
|
// .get_or_insert_default()
|
||||||
|
// .set_type(forward_type);
|
||||||
|
// self
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||||
|
// self.schedule_config = Some(config);
|
||||||
|
// self
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||||
|
// self.backend_config = Some(config);
|
||||||
|
// self
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub fn build(self) -> Result<EmbeddingGenerator> {
|
||||||
|
// let model = self.model;
|
||||||
|
// let sc = self.schedule_config.unwrap_or_default();
|
||||||
|
// let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||||
|
// .change_context(Error)
|
||||||
|
// .attach_printable("Failed to create session handle")?;
|
||||||
|
// Ok(EmbeddingGenerator { handle })
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
impl EmbeddingGenerator {
|
||||||
|
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||||
|
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||||
|
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||||
|
let model = std::fs::read(path)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to read model file")?;
|
||||||
|
Self::new_from_bytes(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {}
|
||||||
|
}
|
||||||
30
src/main.rs
30
src/main.rs
@@ -4,10 +4,12 @@ use bounding_box::roi::MultiRoi;
|
|||||||
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
use detector::{facedet::retinaface::FaceDetectionConfig, faceembed};
|
||||||
use errors::*;
|
use errors::*;
|
||||||
use fast_image_resize::ResizeOptions;
|
use fast_image_resize::ResizeOptions;
|
||||||
use nalgebra::zero;
|
use ndarray::*;
|
||||||
use ndarray_image::*;
|
use ndarray_image::*;
|
||||||
|
use ndarray_resize::NdFir;
|
||||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||||
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
|
const FACENET_MODEL: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||||
|
const CHUNK_SIZE: usize = 8;
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter("trace")
|
.with_env_filter("trace")
|
||||||
@@ -19,10 +21,16 @@ pub fn main() -> Result<()> {
|
|||||||
match args.cmd {
|
match args.cmd {
|
||||||
cli::SubCommand::Detect(detect) => {
|
cli::SubCommand::Detect(detect) => {
|
||||||
use detector::facedet;
|
use detector::facedet;
|
||||||
let retinaface = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
let retinaface = facedet::retinaface::FaceDetection::builder()(RETINAFACE_MODEL)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_forward_type(detect.forward_type)
|
||||||
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face detection model")?;
|
.attach_printable("Failed to create face detection model")?;
|
||||||
let facenet = faceembed::facenet::EmbeddingGenerator::new_from_bytes(FACENET_MODEL)
|
let facenet = faceembed::facenet::EmbeddingGenerator::builder()(FACENET_MODEL)
|
||||||
|
.change_context(Error)?
|
||||||
|
.with_forward_type(detect.forward_type)
|
||||||
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face embedding model")?;
|
.attach_printable("Failed to create face embedding model")?;
|
||||||
let image = image::open(detect.image).change_context(Error)?;
|
let image = image::open(detect.image).change_context(Error)?;
|
||||||
@@ -45,8 +53,6 @@ pub fn main() -> Result<()> {
|
|||||||
use bounding_box::draw::*;
|
use bounding_box::draw::*;
|
||||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||||
}
|
}
|
||||||
use ndarray::{Array2, Array3, Array4, Axis};
|
|
||||||
use ndarray_resize::NdFir;
|
|
||||||
let face_rois = array
|
let face_rois = array
|
||||||
.view()
|
.view()
|
||||||
.multi_roi(&output.bbox)
|
.multi_roi(&output.bbox)
|
||||||
@@ -68,21 +74,19 @@ pub fn main() -> Result<()> {
|
|||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let chunk_size = CHUNK_SIZE;
|
||||||
let embeddings = face_roi_views
|
let embeddings = face_roi_views
|
||||||
.chunks(8)
|
.chunks(chunk_size)
|
||||||
.map(|chunk| {
|
.map(|chunk| {
|
||||||
tracing::info!("Processing chunk of size: {}", chunk.len());
|
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||||
|
|
||||||
if chunk.len() < 8 {
|
if chunk.len() < 8 {
|
||||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||||
let zeros = Array3::zeros((512, 512, 3));
|
let zeros = Array3::zeros((512, 512, 3));
|
||||||
let padded: Vec<ndarray::ArrayView3<'_, u8>> = chunk
|
let zero_array = core::iter::repeat(zeros.view())
|
||||||
.iter()
|
.take(chunk_size)
|
||||||
.cloned()
|
.collect::<Vec<_>>();
|
||||||
.chain(core::iter::repeat(zeros.view()))
|
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
||||||
.take(8)
|
|
||||||
.collect();
|
|
||||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), padded.as_slice())
|
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to stack rois together")?;
|
.attach_printable("Failed to stack rois together")?;
|
||||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||||
|
|||||||
Reference in New Issue
Block a user