feat: Initial commit
This commit is contained in:
64
src/cli.rs
Normal file
64
src/cli.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use std::path::PathBuf;
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct Cli {
|
||||
#[clap(subcommand)]
|
||||
pub cmd: SubCommand,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
pub enum SubCommand {
|
||||
#[clap(name = "detect")]
|
||||
Detect(Detect),
|
||||
#[clap(name = "list")]
|
||||
List(List),
|
||||
#[clap(name = "completions")]
|
||||
Completions { shell: clap_complete::Shell },
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum Models {
|
||||
RetinaFace,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum Executor {
|
||||
Mnn,
|
||||
Onnx,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum OnnxEp {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum MnnEp {
|
||||
Cpu,
|
||||
Metal,
|
||||
OpenCL,
|
||||
CoreML,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Detect {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
pub image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct List {}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
clap_complete::generate(
|
||||
shell,
|
||||
&mut command,
|
||||
env!("CARGO_BIN_NAME"),
|
||||
&mut std::io::stdout(),
|
||||
);
|
||||
}
|
||||
}
|
||||
6
src/errors.rs
Normal file
6
src/errors.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub use error_stack::{Report, ResultExt};
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("An error occurred")]
|
||||
pub struct Error;
|
||||
|
||||
pub type Result<T, E = error_stack::Report<Error>> = core::result::Result<T, E>;
|
||||
73
src/facedet.rs
Normal file
73
src/facedet.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use crate::errors::*;
|
||||
use error_stack::ResultExt;
|
||||
use mnn_bridge::ndarray::NdarrayToMnn;
|
||||
use std::path::Path;
|
||||
|
||||
pub struct FaceDetection {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||
tracing::info!("Loading face detection model from bytes");
|
||||
let mut model = mnn::Interpreter::from_bytes(model)
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.with_type(mnn::ForwardType::CPU)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face detection model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
|
||||
pub fn detect_faces(&self, image: ndarray::Array3<u8>) -> Result<ndarray::Array2<u8>> {
|
||||
use mnn_bridge::ndarray::MnnToNdarray;
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let tensor = image
|
||||
.as_mnn_tensor()
|
||||
.ok_or_else(|| Error)
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
.change_context(mnn::error::ErrorKind::TensorError)?;
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
// let input = intptr.input::<u8>(session, "input")?;
|
||||
// dbg!(input.shape());
|
||||
// let mut t = input.create_host_tensor_from_device(false);
|
||||
// tensor.copy_to_host_tensor(&mut t)?;
|
||||
//
|
||||
// intptr.run_session(&session)?;
|
||||
// let output = intptr.output::<u8>(&session, "output").unwrap();
|
||||
// let output_tensor = output.create_host_tensor_from_device(true);
|
||||
// let output_array = output_tensor
|
||||
// .try_as_ndarray()
|
||||
// .change_context(mnn::error::ErrorKind::TensorError)?
|
||||
// .to_owned();
|
||||
// Ok(output_array)
|
||||
Ok(ndarray::Array2::<u8>::zeros((1, 1))) // Placeholder for actual output
|
||||
})
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error);
|
||||
output
|
||||
}
|
||||
}
|
||||
5
src/image.rs
Normal file
5
src/image.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
// pub struct Image {
|
||||
// pub width: u32,
|
||||
// pub height: u32,
|
||||
// pub data: Vec<u8>,
|
||||
// }
|
||||
4
src/lib.rs
Normal file
4
src/lib.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod errors;
|
||||
pub mod facedet;
|
||||
pub mod image;
|
||||
use errors::*;
|
||||
37
src/main.rs
Normal file
37
src/main.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
mod cli;
|
||||
mod errors;
|
||||
use errors::*;
|
||||
use ndarray_image::*;
|
||||
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_target(false)
|
||||
.init();
|
||||
let args = <cli::Cli as clap::Parser>::parse();
|
||||
match args.cmd {
|
||||
cli::SubCommand::Detect(detect) => {
|
||||
use detector::facedet;
|
||||
let model = facedet::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let image = image::open(detect.image).change_context(Error)?;
|
||||
let image = image.into_rgb8();
|
||||
let array = image.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
model.detect_faces(array)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
}
|
||||
cli::SubCommand::List(list) => {
|
||||
println!("List: {:?}", list);
|
||||
}
|
||||
cli::SubCommand::Completions { shell } => {
|
||||
cli::Cli::completions(shell);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user