feat: Added stuff
This commit is contained in:
@@ -56,6 +56,8 @@ pub struct Detect {
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
pub image: PathBuf,
|
||||
}
|
||||
|
||||
|
||||
19
src/main.rs
19
src/main.rs
@@ -11,10 +11,9 @@ const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
|
||||
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
|
||||
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
|
||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
const CHUNK_SIZE: usize = 2;
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("trace")
|
||||
.with_env_filter("error")
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_target(false)
|
||||
@@ -35,8 +34,6 @@ pub fn main() -> Result<()> {
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
// .then_some(cli::Executor::Mnn)
|
||||
// .unwrap_or(cli::Executor::Ort);
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
@@ -92,7 +89,9 @@ where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
let image = image::open(detect.image).change_context(Error)?;
|
||||
let image = image::open(&detect.image)
|
||||
.change_context(Error)
|
||||
.attach_printable(detect.image.to_string_lossy().to_string())?;
|
||||
let image = image.into_rgb8();
|
||||
let mut array = image
|
||||
.into_ndarray()
|
||||
@@ -122,7 +121,7 @@ where
|
||||
// })
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(512, 512, &ResizeOptions::default())
|
||||
.fast_resize(160, 160, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
// .inspect(|f| {
|
||||
@@ -133,15 +132,15 @@ where
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
|
||||
let chunk_size = CHUNK_SIZE;
|
||||
let chunk_size = detect.batch_size;
|
||||
let embeddings = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
tracing::info!("Processing chunk of size: {}", chunk.len());
|
||||
|
||||
if chunk.len() < 8 {
|
||||
if chunk.len() < chunk_size {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((512, 512, 3));
|
||||
let zeros = Array3::zeros((160, 160, 3));
|
||||
let zero_array = core::iter::repeat(zeros.view())
|
||||
.take(chunk_size)
|
||||
.collect::<Vec<_>>();
|
||||
@@ -158,7 +157,7 @@ where
|
||||
Ok(output)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>();
|
||||
.collect::<Result<Vec<Array2<f32>>>>()?;
|
||||
|
||||
let v = array.view();
|
||||
if let Some(output) = detect.output {
|
||||
|
||||
@@ -107,7 +107,15 @@ impl ExecutionProvider {
|
||||
{
|
||||
#[cfg(feature = "ort-coreml")]
|
||||
{
|
||||
Some(CoreMLExecutionProvider::default().build())
|
||||
use tap::Tap;
|
||||
|
||||
Some(
|
||||
CoreMLExecutionProvider::default()
|
||||
.with_model_format(
|
||||
ort::execution_providers::coreml::CoreMLModelFormat::MLProgram,
|
||||
)
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
#[cfg(not(feature = "ort-coreml"))]
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user