refactor: replace bbox::BBox with bounding_box::Aabb2 across codebase
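The hunks below are from the GUI comparison bridge. As a rough illustration of the rename only (the struct and field here are hypothetical, not code from this commit): values that previously carried the old bbox::BBox type now carry bounding_box::Aabb2<usize>.

// Hypothetical before/after for the type swap named in the commit title.
// FaceOutput is an illustrative name, not a type from this repository.

// Before:
// struct FaceOutput { bbox: Vec<bbox::BBox> }

// After (matches the types that appear in this diff):
use bounding_box::Aabb2;

struct FaceOutput {
    bbox: Vec<Aabb2<usize>>, // axis-aligned 2D boxes, presumably in pixel coordinates
}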
@@ -1,11 +1,17 @@
use std::path::PathBuf;

use crate::errors;
use crate::facedet::{FaceDetectionConfig, FaceDetector, retinaface};
use crate::faceembed::facenet;
use crate::faceembed::{FaceNetEmbedder, facenet};
use crate::gui::app::{ComparisonResult, DetectionResult, ExecutorType};
use bounding_box::Aabb2;
use bounding_box::roi::MultiRoi as _;
use error_stack::ResultExt;
use fast_image_resize::ResizeOptions;
use ndarray::{Array1, Array2, Array3, Array4};
use ndarray_image::ImageToNdarray;
use ndarray_math::CosineSimilarity;
use ndarray_resize::NdFir;

const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../../models/facenet.mnn");
@@ -176,12 +182,12 @@ impl FaceDetectionBridge {
        executor_type: ExecutorType,
    ) -> Result<(usize, usize, f32), Box<dyn std::error::Error + Send + Sync>> {
        // Load both images
        let img1 = image::open(&image1_path)?.to_rgb8();
        let img2 = image::open(&image2_path)?.to_rgb8();
        // let img1 = image::open(&image1_path)?.to_rgb8();
        // let img2 = image::open(&image2_path)?.to_rgb8();

        // Convert to ndarray format
        let image1_array = img1.as_ndarray()?;
        let image2_array = img2.as_ndarray()?;
        // let image1_array = img1.as_ndarray()?;
        // let image2_array = img2.as_ndarray()?;

        // Create detection configuration
        let config1 = FaceDetectionConfig::default()
@@ -212,112 +218,171 @@ impl FaceDetectionBridge {
                    .build()
                    .map_err(|e| format!("Failed to build MNN detector: {}", e))?;

                let embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
                let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
                    .map_err(|e| format!("Failed to create MNN embedder: {}", e))?
                    .with_forward_type(forward_type)
                    .build()
                    .map_err(|e| format!("Failed to build MNN embedder: {}", e))?;

                // Detect faces in both images
                let faces1 = detector
                    .detect_faces(image1_array.view(), &config1)
                    .map_err(|e| format!("Detection failed for image 1: {}", e))?;
                let faces2 = detector
                    .detect_faces(image2_array.view(), &config2)
                    .map_err(|e| format!("Detection failed for image 2: {}", e))?;
                let img_1 = run_detection(
                    image1_path,
                    &mut detector,
                    &mut embedder,
                    threshold,
                    nms_threshold,
                    2,
                )?;
                let img_2 = run_detection(
                    image2_path,
                    &mut detector,
                    &mut embedder,
                    threshold,
                    nms_threshold,
                    2,
                )?;

                // Extract face crops and generate embeddings
                let mut best_similarity = 0.0f32;

                (faces1, faces2, best_similarity)
            }
            ExecutorType::OnnxCpu => {
                let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
                    .map_err(|e| format!("Failed to create ONNX detector: {}", e))?
                    .build()
                    .map_err(|e| format!("Failed to build ONNX detector: {}", e))?;

                let mut embedder = facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
                    .map_err(|e| format!("Failed to create ONNX embedder: {}", e))?
                    .build()
                    .map_err(|e| format!("Failed to build ONNX embedder: {}", e))?;

                // Detect faces in both images
                let faces1 = detector
                    .detect_faces(image1_array.view(), &config1)
                    .map_err(|e| format!("Detection failed for image 1: {}", e))?;
                let faces2 = detector
                    .detect_faces(image2_array.view(), &config2)
                    .map_err(|e| format!("Detection failed for image 2: {}", e))?;

                // Extract face crops and generate embeddings
                let mut best_similarity = 0.0f32;

                if faces1.bbox.is_empty() || faces2.bbox.is_empty() {
                    return Ok((faces1.bbox.len(), faces2.bbox.len(), 0.0));
                }
                if faces1.bbox.len() != faces2.bbox.len() {
                    return Err(Box::new(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Face count mismatch between images",
                    )));
                }

                (faces1, faces2, best_similarity)
                let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
                (img_1, img_2, best_similarity)
            }
            ExecutorType::OnnxCpu => unimplemented!(),
        };

        Ok((faces1.bbox.len(), faces2.bbox.len(), best_similarity))
    }
}
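With removed and added lines interleaved above, the new shape of the MNN arm is easy to miss. Keeping only the lines that survive this commit, the arm reduces to roughly the following sketch; it assumes the match result is still bound by a let (faces1, faces2, best_similarity) = match executor_type binding, which the unchanged Ok((faces1.bbox.len(), ...)) line after the match suggests.

// Sketch: the MNN arm after this commit, detector/embedder setup as shown above.
let img_1 = run_detection(image1_path, &mut detector, &mut embedder, threshold, nms_threshold, 2)?;
let img_2 = run_detection(image2_path, &mut detector, &mut embedder, threshold, nms_threshold, 2)?;
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(img_1, img_2, best_similarity)

Each arm now yields a (DetectionOutput, DetectionOutput, f32) tuple, which lets the final Ok((faces1.bbox.len(), faces2.bbox.len(), best_similarity)) stay unchanged, while the ONNX arm is stubbed out with unimplemented!().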

// for bbox1 in &faces1.bbox {
//     let crop1 = Self::crop_face_from_image(&img1, bbox1)?;
//     let crop1_array = ndarray::Array::from_shape_vec(
//         (1, crop1.height() as usize, crop1.width() as usize, 3),
//         crop1
//             .pixels()
//             .flat_map(|p| [p.0[0], p.0[1], p.0[2]])
//             .collect(),
//     )?;

//     let embedding1 = embedder
//         .run_models(crop1_array.view())
//         .map_err(|e| format!("Embedding generation failed: {}", e))?;

//     for bbox2 in &faces2.bbox {
//         let crop2 = Self::crop_face_from_image(&img2, bbox2)?;
//         let crop2_array = ndarray::Array::from_shape_vec(
//             (1, crop2.height() as usize, crop2.width() as usize, 3),
//             crop2
//                 .pixels()
//                 .flat_map(|p| [p.0[0], p.0[1], p.0[2]])
//                 .collect(),
//         )?;

//         let embedding2 = embedder
//             .run_models(crop2_array.view())
//             .map_err(|e| format!("Embedding generation failed: {}", e))?;

//         let similarity = Self::cosine_similarity(
//             embedding1.row(0).as_slice().unwrap(),
//             embedding2.row(0).as_slice().unwrap(),
//         );
//         best_similarity = best_similarity.max(similarity);
//     }
// }

use crate::errors::Error;
pub fn compare_faces(
    image1: &[Aabb2<usize>],
    image2: &[Aabb2<usize>],
    faces_1: &[Array1<f32>],
    faces_2: &[Array1<f32>],
) -> Result<f32, error_stack::Report<crate::errors::Error>> {
    use error_stack::Report;

    if image1.is_empty() || image2.is_empty() {
    if faces_1.is_empty() || faces_2.is_empty() {
        Err(Report::new(crate::errors::Error))
            .change_context(Report::new(crate::errors::Error))
            .attach_printable("One or both images have no detected faces")
            .attach_printable("One or both images have no detected faces")?;
    }
    Ok(0.0)
    if faces_1.len() != faces_2.len() {
        Err(Report::new(crate::errors::Error))
            .attach_printable("Face count mismatch between images")?;
    }
    Ok(faces_1
        .iter()
        .zip(faces_2)
        .flat_map(|(face_1, face_2)| face_1.cosine_similarity(face_2))
        .inspect(|v| tracing::info!("Cosine similarity: {}", v))
        .map(|v| ordered_float::OrderedFloat(v))
        .max()
        .map(|v| v.0)
        .ok_or(Report::new(Error))?)
}

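A note on the new compare_faces: the iterator chain pairs embeddings index by index via zip rather than comparing every face against every other, then takes the maximum cosine similarity. Written as an explicit loop it is roughly equivalent to the sketch below; the sketch assumes cosine_similarity returns something iterable such as an Option<f32>, which is what the flat_map call implies.

// Sketch: an explicit-loop equivalent of the iterator chain in compare_faces.
let mut best: Option<f32> = None;
for (face_1, face_2) in faces_1.iter().zip(faces_2) {
    // flat_map drops pairs for which cosine_similarity yields no value.
    for v in face_1.cosine_similarity(face_2) {
        tracing::info!("Cosine similarity: {}", v);
        best = Some(best.map_or(v, |b| b.max(v)));
    }
}
best.ok_or(Report::new(Error))

The OrderedFloat wrapper in the original exists only because f32 is not Ord and Iterator::max needs a total order; the length check above it guarantees that zip does not silently drop any embedding.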
#[derive(Debug)]
pub struct DetectionOutput {
    /// Bounding boxes of the detected faces, as returned by the detector.
    bbox: Vec<Aabb2<usize>>,
    /// Face crops extracted from the input image and resized to 320x320.
    rois: Vec<ndarray::Array3<u8>>,
    /// One embedding vector per detected face, produced by the embedder.
    embeddings: Vec<Array1<f32>>,
}

fn run_detection<D, E>(
    image: impl AsRef<std::path::Path>,
    retinaface: &mut D,
    facenet: &mut E,
    threshold: f32,
    nms_threshold: f32,
    chunk_size: usize,
) -> crate::errors::Result<DetectionOutput>
where
    D: crate::facedet::FaceDetector,
    E: crate::faceembed::FaceEmbedder,
{
    use errors::*;
    // Load the input image and convert it to an ndarray
    let image = image.as_ref();
    let image = image::open(image)
        .change_context(Error)
        .attach_printable(image.to_string_lossy().to_string())?;
    let image = image.into_rgb8();
    // let (image_width, image_height) = image.dimensions();
    let mut array = image
        .into_ndarray()
        .change_context(errors::Error)
        .attach_printable("Failed to convert image to ndarray")?;
    let output = retinaface
        .detect_faces(
            array.view(),
            &FaceDetectionConfig::default()
                .with_threshold(threshold)
                .with_nms_threshold(nms_threshold),
        )
        .change_context(errors::Error)
        .attach_printable("Failed to detect faces")?;
    dbg!(&output);

    for bbox in &output.bbox {
        tracing::info!("Detected face: {:?}", bbox);
        use bounding_box::draw::*;
        array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
    }
    let face_rois = array
        .view()
        .multi_roi(&output.bbox)
        .change_context(Error)?
        .into_iter()
        .map(|roi| {
            roi.as_standard_layout()
                .fast_resize(320, 320, &ResizeOptions::default())
                .change_context(Error)
        })
        .collect::<Result<Vec<_>>>()?;
    let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();

    let embeddings: Vec<Array1<f32>> = face_roi_views
        .chunks(chunk_size)
        .map(|chunk| {
            tracing::info!("Processing chunk of size: {}", chunk.len());

            let og_size = chunk.len();
            if chunk.len() < chunk_size {
                tracing::warn!("Chunk is smaller than chunk_size, padding with zeros");
                let zeros = Array3::zeros((320, 320, 3));
                let chunk: Vec<_> = chunk
                    .iter()
                    .map(|arr| arr.reborrow())
                    .chain(core::iter::repeat(zeros.view()))
                    .take(chunk_size)
                    .collect();
                let face_rois: Array4<u8> = ndarray::stack(ndarray::Axis(0), chunk.as_slice())
                    .change_context(errors::Error)
                    .attach_printable("Failed to stack rois together")?;
                let output = facenet.run_models(face_rois.view()).change_context(Error)?;
                Ok((output, og_size))
            } else {
                let face_rois: Array4<u8> = ndarray::stack(ndarray::Axis(0), chunk)
                    .change_context(errors::Error)
                    .attach_printable("Failed to stack rois together")?;
                let output = facenet.run_models(face_rois.view()).change_context(Error)?;
                Ok((output, og_size))
            }
        })
        .collect::<Result<Vec<(Array2<f32>, usize)>>>()?
        .into_iter()
        .map(|(chunk, size): (Array2<f32>, usize)| {
            use itertools::Itertools;
            chunk
                .rows()
                .into_iter()
                .take(size)
                .map(|row| row.to_owned())
                .collect_vec()
                .into_iter()
        })
        .flatten()
        .collect::<Vec<Array1<f32>>>();

    Ok(DetectionOutput {
        bbox: output.bbox,
        rois: face_rois,
        embeddings,
    })
}
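The chunking logic in run_detection pads a trailing partial chunk with zero images so that every batch handed to the embedder has exactly chunk_size entries, then keeps only the first og_size output rows so the padding never contributes embeddings. A stripped-down sketch of that scheme follows; pad_and_embed and embed_batch are illustrative names standing in for the facenet.run_models call above, not functions from this codebase.

use ndarray::{Array1, Array2, Array3, Array4, ArrayView3, Axis};

// Illustrative sketch of the pad-then-truncate batching used above.
fn pad_and_embed(
    chunk: &[ArrayView3<u8>],
    chunk_size: usize,
    embed_batch: impl Fn(Array4<u8>) -> Array2<f32>,
) -> Vec<Array1<f32>> {
    let real = chunk.len();
    let zeros = Array3::<u8>::zeros((320, 320, 3));
    // Pad the views with zero images until the batch has exactly chunk_size entries.
    let padded: Vec<_> = chunk
        .iter()
        .map(|v| v.view())
        .chain(core::iter::repeat(zeros.view()))
        .take(chunk_size)
        .collect();
    let batch: Array4<u8> = ndarray::stack(Axis(0), &padded).expect("all crops share one shape");
    // Run the embedder on the fixed-size batch, then drop the rows produced by padding.
    embed_batch(batch)
        .rows()
        .into_iter()
        .take(real)
        .map(|row| row.to_owned())
        .collect()
}

The else branch in the diff is the no-padding fast path; take(real) here plays the same role as take(size) in the second map above.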