feat(compare): add face comparison functionality with cosine similarity
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m18s
docs / docs (push) Failing after 28m50s
build / checks-build (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-21 17:34:07 +05:30
parent f8122892e0
commit bfa389b497
15 changed files with 1188 additions and 107 deletions

View File

@@ -1,6 +1,5 @@
use std::path::PathBuf;
use mnn::ForwardType;
#[derive(Debug, clap::Parser)]
pub struct Cli {
#[clap(subcommand)]
@@ -11,14 +10,16 @@ pub struct Cli {
// Top-level subcommands of the CLI. Every tuple variant wraps the argument
// struct of the same name; `Completions` carries the target shell directly.
// (Plain `//` comments are used so that no clap help text changes.)
// NOTE(review): this listing shows both `list` and `detect-multi`; it looks
// like a diff rendering with removed and added lines interleaved — confirm
// which variants exist in the checked-out file.
pub enum SubCommand {
#[clap(name = "detect")]
Detect(Detect),
#[clap(name = "list")]
List(List),
#[clap(name = "detect-multi")]
DetectMulti(DetectMulti),
#[clap(name = "query")]
Query(Query),
#[clap(name = "similar")]
Similar(Similar),
#[clap(name = "stats")]
Stats(Stats),
#[clap(name = "compare")]
Compare(Compare),
// Takes the shell to generate completions for as a positional value.
#[clap(name = "completions")]
Completions { shell: clap_complete::Shell },
}
@@ -74,7 +75,47 @@ pub struct Detect {
}
#[derive(Debug, clap::Args)]
pub struct List {}
// Arguments of the `detect-multi` subcommand, which runs detection and
// embedding over the image files of one directory.
// Plain `//` comments are used here so that no clap help text changes.
pub struct DetectMulti {
// Optional model path (not referenced by the handler code visible in this view).
#[clap(short, long)]
pub model: Option<PathBuf>,
// Detector model family; defaults to "retina-face".
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
// When set, annotated copies of the processed images are written here.
#[clap(short, long)]
pub output_dir: Option<PathBuf>,
// ONNX Runtime execution provider(s). Shares the `execution_provider` group
// with `mnn_forward_type`; note it also carries a default of "cpu".
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
// MNN forward type; when present the MNN backend is selected by `main`.
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
// Detection confidence threshold passed to `FaceDetectionConfig::with_threshold`.
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
// Threshold passed to `FaceDetectionConfig::with_nms_threshold`.
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
// Number of face crops stacked into one embedding-model call.
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
// SQLite database file that receives images, detections and embeddings.
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
// Model name stored alongside each embedding row.
#[clap(long, default_value = "facenet")]
pub model_name: String,
// Comma-separated list; entries are trimmed and lower-cased before matching.
#[clap(
long,
help = "Image extensions to process (e.g., jpg,png,jpeg)",
default_value = "jpg,jpeg,png,bmp,tiff,webp"
)]
pub extensions: String,
// Only direct children of this directory are scanned (no recursion).
#[clap(help = "Directory containing images to process")]
pub input_dir: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct Query {
@@ -108,6 +149,41 @@ pub struct Stats {
pub database: PathBuf,
}
// Arguments of the `compare` subcommand, which embeds the faces of two images
// and reports their pairwise cosine similarity.
// Plain `//` comments are used here so that no clap help text changes.
#[derive(Debug, clap::Args)]
pub struct Compare {
// Optional model path (not referenced by the handler code visible in this view).
#[clap(short, long)]
pub model: Option<PathBuf>,
// Detector model family; defaults to "retina-face".
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
// ONNX Runtime execution provider(s). Shares the `execution_provider` group
// with `mnn_forward_type`; note it also carries a default of "cpu".
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
// MNN forward type; when present the MNN backend is selected by `main`.
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
// Detection confidence threshold passed to `FaceDetectionConfig::with_threshold`.
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
// Threshold passed to `FaceDetectionConfig::with_nms_threshold`.
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
// Number of face crops stacked into one embedding-model call.
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
// NOTE(review): not read by the `compare` handler visible in this view.
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(help = "First image to compare")]
pub image1: PathBuf,
#[clap(help = "Second image to compare")]
pub image2: PathBuf,
}
impl Cli {
pub fn completions(shell: clap_complete::Shell) {
let mut command = <Cli as clap::CommandFactory>::command();

View File

@@ -65,7 +65,14 @@ impl FaceDatabase {
/// Create a new database connection and initialize tables
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let conn = Connection::open(db_path).change_context(Error)?;
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
unsafe {
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
conn.load_extension(
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
None::<&str>,
)
.change_context(Error)?;
}
let db = Self { conn };
db.create_tables()?;
Ok(db)
@@ -190,10 +197,9 @@ impl FaceDatabase {
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
.change_context(Error)?;
stmt.execute(params![file_path, width, height])
.change_context(Error)?;
Ok(self.conn.last_insert_rowid())
Ok(stmt
.insert(params![file_path, width, height])
.change_context(Error)?)
}
/// Store face detection results
@@ -231,17 +237,16 @@ impl FaceDatabase {
)
.change_context(Error)?;
stmt.execute(params![
image_id,
bbox.x1() as f32,
bbox.y1() as f32,
bbox.x2() as f32,
bbox.y2() as f32,
confidence
])
.change_context(Error)?;
Ok(self.conn.last_insert_rowid())
Ok(stmt
.insert(params![
image_id,
bbox.x1() as f32,
bbox.y1() as f32,
bbox.x2() as f32,
bbox.y2() as f32,
confidence
])
.change_context(Error)?)
}
/// Store face landmarks
@@ -258,22 +263,21 @@ impl FaceDatabase {
)
.change_context(Error)?;
stmt.execute(params![
face_id,
landmarks.left_eye.x,
landmarks.left_eye.y,
landmarks.right_eye.x,
landmarks.right_eye.y,
landmarks.nose.x,
landmarks.nose.y,
landmarks.left_mouth.x,
landmarks.left_mouth.y,
landmarks.right_mouth.x,
landmarks.right_mouth.y,
])
.change_context(Error)?;
Ok(self.conn.last_insert_rowid())
Ok(stmt
.insert(params![
face_id,
landmarks.left_eye.x,
landmarks.left_eye.y,
landmarks.right_eye.x,
landmarks.right_eye.y,
landmarks.nose.x,
landmarks.nose.y,
landmarks.left_mouth.x,
landmarks.left_mouth.y,
landmarks.right_mouth.x,
landmarks.right_mouth.y,
])
.change_context(Error)?)
}
/// Store face embeddings
@@ -310,12 +314,12 @@ impl FaceDatabase {
embedding: ndarray::ArrayView1<f32>,
model_name: &str,
) -> Result<i64> {
let embedding_bytes =
let safe_arrays =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
.change_context(Error)?
.serialize()
.change_context(Error)?;
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
let mut stmt = self
.conn
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
@@ -462,6 +466,35 @@ impl FaceDatabase {
Ok(embeddings)
}
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT images.id, images.file_path, images.width, images.height, images.created_at
FROM images
JOIN faces ON faces.image_id = images.id
WHERE faces.id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get database statistics
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
let images: usize = self
@@ -528,6 +561,39 @@ impl FaceDatabase {
Ok(result)
}
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)
.unwrap()
.serialize()
.change_context(Error)
.unwrap();
let mut stmt = self
.conn
.prepare(
r#"
SELECT face_id,
cosine_similarity(?1, embedding)
FROM embeddings
"#,
)
.change_context(Error)
.unwrap();
let result_iter = stmt
.query_map(params![embedding_bytes], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)
.unwrap();
for result in result_iter {
println!("{:?}", result);
}
}
}
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {

View File

@@ -310,7 +310,7 @@ pub trait FaceDetector {
fn detect_faces(
&mut self,
image: ndarray::ArrayView3<u8>,
config: FaceDetectionConfig,
config: &FaceDetectionConfig,
) -> Result<FaceDetectionOutput> {
let (height, width, _channels) = image.dim();
let output = self

View File

@@ -11,6 +11,23 @@ pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
use crate::errors::*;
use ndarray::{Array2, ArrayView4};
/// Shared input preprocessing for the face-embedding backends.
pub mod preprocessing {
    use ndarray::*;

    /// Offset used by the fixed fallback scaling.
    const FALLBACK_OFFSET: f32 = 127.5;
    /// Divisor used by the fixed fallback scaling.
    const FALLBACK_SCALE: f32 = 128.0;

    /// Convert a batch of `u8` face images to `f32` and standardize each one.
    ///
    /// Every sub-array along axis 0 is processed independently: it is shifted
    /// by its own mean and divided by its own population standard deviation
    /// (`ddof = 0`). When that deviation is not strictly positive (for example
    /// a constant image) the fixed mapping `(x - 127.5) / 128.0` is applied
    /// instead, which avoids a division by zero.
    ///
    /// The returned array is in standard (row-major) layout.
    pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
        // `mapv` already allocates a fresh owned array, so no extra
        // `.to_owned()` copy of the whole batch is needed afterwards.
        let mut batch: Array4<f32> = faces.as_standard_layout().mapv(|v| v as f32);

        for mut image in batch.axis_iter_mut(Axis(0)) {
            let mean = image.mean().unwrap_or(0.0);
            let std_dev = image.std(0.0);
            if std_dev > 0.0 {
                image.mapv_inplace(|x| (x - mean) / std_dev);
            } else {
                image.mapv_inplace(|x| (x - FALLBACK_OFFSET) / FALLBACK_SCALE);
            }
        }

        batch
    }
}
/// Common trait for face embedding backends - maintained for backward compatibility
pub trait FaceEmbedder {
/// Generate embeddings for a batch of face images

View File

@@ -4,6 +4,7 @@ pub mod ort;
use crate::errors::*;
use error_stack::ResultExt;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use ndarray_math::{CosineSimilarity, EuclideanDistance};
/// Configuration for face embedding processing
#[derive(Debug, Clone, PartialEq)]
@@ -32,9 +33,9 @@ impl FaceEmbeddingConfig {
// Default embedding configuration.
// NOTE(review): two value sets appear below (160x160 with `normalize: true`
// and 320x320 with `normalize: false`); this looks like a diff rendering with
// removed and added lines interleaved — confirm which set is live. The other
// code in this view resizes face crops to 320x320.
impl Default for FaceEmbeddingConfig {
fn default() -> Self {
Self {
input_width: 160,
input_height: 160,
normalize: true,
input_width: 320,
input_height: 320,
normalize: false,
}
}
}
@@ -63,15 +64,14 @@ impl FaceEmbedding {
/// Calculate cosine similarity with another embedding
// NOTE(review): two bodies appear below — a manual dot-product / norm
// computation and a call to the `CosineSimilarity` trait method that falls
// back to 0.0 via `unwrap_or` — which looks like removed and added diff lines
// interleaved; confirm which one is live. The manual form divides by the
// product of the norms without guarding against zero.
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
let dot_product = self.vector.dot(&other.vector);
let norm_self = self.vector.mapv(|x| x * x).sum().sqrt();
let norm_other = other.vector.mapv(|x| x * x).sum().sqrt();
dot_product / (norm_self * norm_other)
self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
}
/// Calculate Euclidean distance with another embedding
// NOTE(review): two bodies appear below — a manual square-root of the summed
// squared differences, and a call to the `EuclideanDistance` trait method
// that falls back to `f32::INFINITY` via `unwrap_or` — which looks like
// removed and added diff lines interleaved; confirm which one is live.
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
(&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt()
self.vector
.euclidean_distance(other.vector.view())
.unwrap_or(f32::INFINITY)
}
/// Normalize the embedding vector to unit length

View File

@@ -64,10 +64,7 @@ impl EmbeddingGenerator {
}
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
let tensor = face
// .permuted_axes((0, 3, 1, 2))
.as_standard_layout()
.mapv(|x| x as f32);
let tensor = crate::faceembed::preprocessing::preprocess(face);
let shape: [usize; 4] = tensor.dim().into();
let shape = shape.map(|f| f as i32);
let output = self

View File

@@ -135,10 +135,12 @@ impl EmbeddingGenerator {
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// Convert input from u8 to f32 and normalize to [0, 1] range
let input_tensor = faces
.mapv(|x| x as f32 / 255.0)
.as_standard_layout()
.into_owned();
let input_tensor = crate::faceembed::preprocessing::preprocess(faces);
// face_array = np.asarray(face_resized, 'float32')
// mean, std = face_array.mean(), face_array.std()
// face_normalized = (face_array - mean) / std
// let input_tensor = faces.mean()
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());

View File

@@ -75,8 +75,61 @@ pub fn main() -> Result<()> {
}
}
}
cli::SubCommand::List(list) => {
println!("List: {:?}", list);
cli::SubCommand::DetectMulti(detect_multi) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect_multi
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if detect_multi.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(
detect_multi.ort_execution_provider.clone(),
))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_multi_detection(detect_multi, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_multi_detection(detect_multi, retinaface, facenet)?;
}
}
}
cli::SubCommand::Query(query) => {
run_query(query)?;
@@ -87,6 +140,59 @@ pub fn main() -> Result<()> {
cli::SubCommand::Stats(stats) => {
run_stats(stats)?;
}
cli::SubCommand::Compare(compare) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = compare
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if compare.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(compare.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_compare(compare, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_compare(compare, retinaface, facenet)?;
}
}
}
cli::SubCommand::Completions { shell } => {
cli::Cli::completions(shell);
}
@@ -122,7 +228,7 @@ where
let output = retinaface
.detect_faces(
array.view(),
FaceDetectionConfig::default()
&FaceDetectionConfig::default()
.with_threshold(detect.threshold)
.with_nms_threshold(detect.nms_threshold),
)
@@ -163,7 +269,7 @@ where
// })
.map(|roi| {
roi.as_standard_layout()
.fast_resize(160, 160, &ResizeOptions::default())
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
// .inspect(|f| {
@@ -182,11 +288,14 @@ where
if chunk.len() < chunk_size {
tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((160, 160, 3));
let zero_array = core::iter::repeat(zeros.view())
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect::<Vec<_>>();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
@@ -328,6 +437,446 @@ fn run_query(query: cli::Query) -> Result<()> {
Ok(())
}
fn run_compare<D, E>(compare: cli::Compare, mut retinaface: D, mut facenet: E) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
// Helper function to detect faces and compute embeddings for an image
fn process_image<D, E>(
image_path: &std::path::Path,
retinaface: &mut D,
facenet: &mut E,
config: &FaceDetectionConfig,
batch_size: usize,
) -> Result<(Vec<Array1<f32>>, usize)>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
let image = image::open(image_path)
.change_context(Error)
.attach_printable(image_path.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = retinaface
.detect_faces(array.view(), config)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
tracing::info!(
"Detected {} faces in {}",
output.bbox.len(),
image_path.display()
);
if output.bbox.is_empty() {
return Ok((Vec::new(), 0));
}
let face_rois = array
.view()
.multi_roi(&output.bbox)
.change_context(Error)?
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = batch_size;
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
if chunk.len() < chunk_size {
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
}
})
.collect::<Result<Vec<Array2<f32>>>>()?;
// Flatten embeddings into individual face embeddings
let mut face_embeddings = Vec::new();
for embedding_batch in embeddings {
for i in 0..output.bbox.len().min(embedding_batch.nrows()) {
face_embeddings.push(embedding_batch.row(i).to_owned());
}
}
Ok((face_embeddings, output.bbox.len()))
}
// Helper function to compute cosine similarity between two embeddings
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.dot(a).sqrt();
let norm_b = b.dot(b).sqrt();
dot_product / (norm_a * norm_b)
}
let config = FaceDetectionConfig::default()
.with_threshold(compare.threshold)
.with_nms_threshold(compare.nms_threshold);
// Process both images
let (embeddings1, face_count1) = process_image(
&compare.image1,
&mut retinaface,
&mut facenet,
&config,
compare.batch_size,
)?;
let (embeddings2, face_count2) = process_image(
&compare.image2,
&mut retinaface,
&mut facenet,
&config,
compare.batch_size,
)?;
println!(
"Image 1 ({}): {} faces detected",
compare.image1.display(),
face_count1
);
println!(
"Image 2 ({}): {} faces detected",
compare.image2.display(),
face_count2
);
if embeddings1.is_empty() && embeddings2.is_empty() {
println!("No faces detected in either image");
return Ok(());
}
if embeddings1.is_empty() {
println!("No faces detected in image 1");
return Ok(());
}
if embeddings2.is_empty() {
println!("No faces detected in image 2");
return Ok(());
}
// Compare all faces between the two images
println!("\nFace comparison results:");
println!("========================");
let mut max_similarity = f32::NEG_INFINITY;
let mut best_match = (0, 0);
for (i, emb1) in embeddings1.iter().enumerate() {
for (j, emb2) in embeddings2.iter().enumerate() {
let similarity = cosine_similarity(emb1, emb2);
println!(
"Face {} (image 1) vs Face {} (image 2): {:.4}",
i + 1,
j + 1,
similarity
);
if similarity > max_similarity {
max_similarity = similarity;
best_match = (i + 1, j + 1);
}
}
}
println!(
"\nBest match: Face {} (image 1) vs Face {} (image 2) with similarity: {:.4}",
best_match.0, best_match.1, max_similarity
);
// Interpretation of similarity score
if max_similarity > 0.8 {
println!("Interpretation: Very likely the same person");
} else if max_similarity > 0.6 {
println!("Interpretation: Possibly the same person");
} else if max_similarity > 0.4 {
println!("Interpretation: Unlikely to be the same person");
} else {
println!("Interpretation: Very unlikely to be the same person");
}
Ok(())
}
fn run_multi_detection<D, E>(
detect_multi: cli::DetectMulti,
mut retinaface: D,
mut facenet: E,
) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
use std::fs;
// Initialize database - always save to database for multi-detection
let db = FaceDatabase::new(&detect_multi.database).change_context(Error)?;
// Parse supported extensions
let extensions: std::collections::HashSet<String> = detect_multi
.extensions
.split(',')
.map(|ext| ext.trim().to_lowercase())
.collect();
// Create output directory if specified
if let Some(ref output_dir) = detect_multi.output_dir {
fs::create_dir_all(output_dir)
.change_context(Error)
.attach_printable("Failed to create output directory")?;
}
// Read directory and filter image files
let entries = fs::read_dir(&detect_multi.input_dir)
.change_context(Error)
.attach_printable("Failed to read input directory")?;
let mut image_paths = Vec::new();
for entry in entries {
let entry = entry.change_context(Error)?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
if extensions.contains(&ext.to_lowercase()) {
image_paths.push(path);
}
}
}
}
if image_paths.is_empty() {
tracing::warn!(
"No image files found in directory: {:?}",
detect_multi.input_dir
);
return Ok(());
}
tracing::info!("Found {} image files to process", image_paths.len());
let mut total_faces = 0;
let mut processed_images = 0;
// Process each image
for (idx, image_path) in image_paths.iter().enumerate() {
tracing::info!(
"Processing image {}/{}: {:?}",
idx + 1,
image_paths.len(),
image_path
);
// Load and process image
let image = match image::open(image_path) {
Ok(img) => img.into_rgb8(),
Err(e) => {
tracing::error!("Failed to load image {:?}: {}", image_path, e);
continue;
}
};
let (image_width, image_height) = image.dimensions();
let mut array = match image.into_ndarray().change_context(errors::Error) {
Ok(arr) => arr,
Err(e) => {
tracing::error!("Failed to convert image to ndarray: {:?}", e);
continue;
}
};
let config = FaceDetectionConfig::default()
.with_threshold(detect_multi.threshold)
.with_nms_threshold(detect_multi.nms_threshold);
// Detect faces
let output = match retinaface.detect_faces(array.view(), &config) {
Ok(output) => output,
Err(e) => {
tracing::error!("Failed to detect faces in {:?}: {:?}", image_path, e);
continue;
}
};
let num_faces = output.bbox.len();
total_faces += num_faces;
if num_faces == 0 {
tracing::info!("No faces detected in {:?}", image_path);
} else {
tracing::info!("Detected {} faces in {:?}", num_faces, image_path);
}
// Store image and detections in database
let image_path_str = image_path.to_string_lossy();
let img_id = match db.store_image(&image_path_str, image_width, image_height) {
Ok(id) => id,
Err(e) => {
tracing::error!("Failed to store image in database: {:?}", e);
continue;
}
};
let face_ids = match db.store_face_detections(img_id, &output) {
Ok(ids) => ids,
Err(e) => {
tracing::error!("Failed to store face detections in database: {:?}", e);
continue;
}
};
// Draw bounding boxes if output directory is specified
if detect_multi.output_dir.is_some() {
for bbox in &output.bbox {
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
}
// Process face embeddings if faces were detected
if !face_ids.is_empty() {
let face_rois = match array.view().multi_roi(&output.bbox).change_context(Error) {
Ok(rois) => rois,
Err(e) => {
tracing::error!("Failed to extract face ROIs: {:?}", e);
continue;
}
};
let face_rois: Result<Vec<_>> = face_rois
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
.collect();
let face_rois = match face_rois {
Ok(rois) => rois,
Err(e) => {
tracing::error!("Failed to resize face ROIs: {:?}", e);
continue;
}
};
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = detect_multi.batch_size;
let embeddings: Result<Vec<Array2<f32>>> = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
if chunk.len() < chunk_size {
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
}
})
.collect();
let embeddings = match embeddings {
Ok(emb) => emb,
Err(e) => {
tracing::error!("Failed to generate embeddings: {:?}", e);
continue;
}
};
// Store embeddings in database
if let Err(e) = db.store_embeddings(&face_ids, &embeddings, &detect_multi.model_name) {
tracing::error!("Failed to store embeddings in database: {:?}", e);
continue;
}
}
// Save output image if directory specified
if let Some(ref output_dir) = detect_multi.output_dir {
let output_filename = format!(
"detected_{}",
image_path.file_name().unwrap().to_string_lossy()
);
let output_path = output_dir.join(output_filename);
let v = array.view();
let output_image: image::RgbImage = match v.to_image().change_context(errors::Error) {
Ok(img) => img,
Err(e) => {
tracing::error!("Failed to convert ndarray to image: {:?}", e);
continue;
}
};
if let Err(e) = output_image.save(&output_path) {
tracing::error!("Failed to save output image to {:?}: {}", output_path, e);
continue;
}
tracing::info!("Saved output image to {:?}", output_path);
}
processed_images += 1;
}
// Print final statistics
tracing::info!(
"Processing complete: {}/{} images processed successfully, {} total faces detected",
processed_images,
image_paths.len(),
total_faces
);
let (num_images, num_faces, num_landmarks, num_embeddings) =
db.get_stats().change_context(Error)?;
tracing::info!(
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
num_images,
num_faces,
num_landmarks,
num_embeddings
);
Ok(())
}
fn run_similar(similar: cli::Similar) -> Result<()> {
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
@@ -341,14 +890,19 @@ fn run_similar(similar: cli::Similar) -> Result<()> {
let similar_faces = db
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
.change_context(Error)?;
// Get image information for the similar faces
println!(
"Found {} similar faces (threshold: {:.3}):",
similar_faces.len(),
similar.threshold
);
for (face_id, similarity) in similar_faces {
println!(" Face {}: similarity {:.3}", face_id, similarity);
for (face_id, similarity) in &similar_faces {
if let Some(image_info) = db.get_image_for_face(*face_id).change_context(Error)? {
println!(
" Face {}: similarity {:.3}, image: {}",
face_id, similarity, image_info.file_path
);
}
}
Ok(())