feat(compare): add face comparison functionality with cosine similarity
This commit is contained in:
84
src/cli.rs
84
src/cli.rs
@@ -1,6 +1,5 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use mnn::ForwardType;
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct Cli {
|
||||
#[clap(subcommand)]
|
||||
@@ -11,14 +10,16 @@ pub struct Cli {
|
||||
pub enum SubCommand {
|
||||
#[clap(name = "detect")]
|
||||
Detect(Detect),
|
||||
#[clap(name = "list")]
|
||||
List(List),
|
||||
#[clap(name = "detect-multi")]
|
||||
DetectMulti(DetectMulti),
|
||||
#[clap(name = "query")]
|
||||
Query(Query),
|
||||
#[clap(name = "similar")]
|
||||
Similar(Similar),
|
||||
#[clap(name = "stats")]
|
||||
Stats(Stats),
|
||||
#[clap(name = "compare")]
|
||||
Compare(Compare),
|
||||
#[clap(name = "completions")]
|
||||
Completions { shell: clap_complete::Shell },
|
||||
}
|
||||
@@ -74,7 +75,47 @@ pub struct Detect {
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct List {}
|
||||
pub struct DetectMulti {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output_dir: Option<PathBuf>,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(
|
||||
long,
|
||||
help = "Image extensions to process (e.g., jpg,png,jpeg)",
|
||||
default_value = "jpg,jpeg,png,bmp,tiff,webp"
|
||||
)]
|
||||
pub extensions: String,
|
||||
#[clap(help = "Directory containing images to process")]
|
||||
pub input_dir: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Query {
|
||||
@@ -108,6 +149,41 @@ pub struct Stats {
|
||||
pub database: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Compare {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(help = "First image to compare")]
|
||||
pub image1: PathBuf,
|
||||
#[clap(help = "Second image to compare")]
|
||||
pub image2: PathBuf,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
|
||||
136
src/database.rs
136
src/database.rs
@@ -65,7 +65,14 @@ impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
|
||||
unsafe {
|
||||
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||
conn.load_extension(
|
||||
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||
None::<&str>,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
}
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
@@ -190,10 +197,9 @@ impl FaceDatabase {
|
||||
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![file_path, width, height])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
Ok(stmt
|
||||
.insert(params![file_path, width, height])
|
||||
.change_context(Error)?)
|
||||
}
|
||||
|
||||
/// Store face detection results
|
||||
@@ -231,17 +237,16 @@ impl FaceDatabase {
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![
|
||||
image_id,
|
||||
bbox.x1() as f32,
|
||||
bbox.y1() as f32,
|
||||
bbox.x2() as f32,
|
||||
bbox.y2() as f32,
|
||||
confidence
|
||||
])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
Ok(stmt
|
||||
.insert(params![
|
||||
image_id,
|
||||
bbox.x1() as f32,
|
||||
bbox.y1() as f32,
|
||||
bbox.x2() as f32,
|
||||
bbox.y2() as f32,
|
||||
confidence
|
||||
])
|
||||
.change_context(Error)?)
|
||||
}
|
||||
|
||||
/// Store face landmarks
|
||||
@@ -258,22 +263,21 @@ impl FaceDatabase {
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![
|
||||
face_id,
|
||||
landmarks.left_eye.x,
|
||||
landmarks.left_eye.y,
|
||||
landmarks.right_eye.x,
|
||||
landmarks.right_eye.y,
|
||||
landmarks.nose.x,
|
||||
landmarks.nose.y,
|
||||
landmarks.left_mouth.x,
|
||||
landmarks.left_mouth.y,
|
||||
landmarks.right_mouth.x,
|
||||
landmarks.right_mouth.y,
|
||||
])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
Ok(stmt
|
||||
.insert(params![
|
||||
face_id,
|
||||
landmarks.left_eye.x,
|
||||
landmarks.left_eye.y,
|
||||
landmarks.right_eye.x,
|
||||
landmarks.right_eye.y,
|
||||
landmarks.nose.x,
|
||||
landmarks.nose.y,
|
||||
landmarks.left_mouth.x,
|
||||
landmarks.left_mouth.y,
|
||||
landmarks.right_mouth.x,
|
||||
landmarks.right_mouth.y,
|
||||
])
|
||||
.change_context(Error)?)
|
||||
}
|
||||
|
||||
/// Store face embeddings
|
||||
@@ -310,12 +314,12 @@ impl FaceDatabase {
|
||||
embedding: ndarray::ArrayView1<f32>,
|
||||
model_name: &str,
|
||||
) -> Result<i64> {
|
||||
let embedding_bytes =
|
||||
let safe_arrays =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
|
||||
.change_context(Error)?
|
||||
.serialize()
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
|
||||
@@ -462,6 +466,35 @@ impl FaceDatabase {
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT images.id, images.file_path, images.width, images.height, images.created_at
|
||||
FROM images
|
||||
JOIN faces ON faces.image_id = images.id
|
||||
WHERE faces.id = ?1
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_row(params![face_id], |row| {
|
||||
Ok(ImageRecord {
|
||||
id: row.get(0)?,
|
||||
file_path: row.get(1)?,
|
||||
width: row.get(2)?,
|
||||
height: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.optional()
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get database statistics
|
||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let images: usize = self
|
||||
@@ -528,6 +561,39 @@ impl FaceDatabase {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
|
||||
let embedding_bytes =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||
.change_context(Error)
|
||||
.unwrap()
|
||||
.serialize()
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT face_id,
|
||||
cosine_similarity(?1, embedding)
|
||||
FROM embeddings
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
let result_iter = stmt
|
||||
.query_map(params![embedding_bytes], |row| {
|
||||
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||
})
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
for result in result_iter {
|
||||
println!("{:?}", result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||
|
||||
@@ -310,7 +310,7 @@ pub trait FaceDetector {
|
||||
fn detect_faces(
|
||||
&mut self,
|
||||
image: ndarray::ArrayView3<u8>,
|
||||
config: FaceDetectionConfig,
|
||||
config: &FaceDetectionConfig,
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
let (height, width, _channels) = image.dim();
|
||||
let output = self
|
||||
|
||||
@@ -11,6 +11,23 @@ pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
|
||||
use crate::errors::*;
|
||||
use ndarray::{Array2, ArrayView4};
|
||||
|
||||
pub mod preprocessing {
|
||||
use ndarray::*;
|
||||
pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
|
||||
let mut owned = faces.as_standard_layout().mapv(|v| v as f32).to_owned();
|
||||
owned.axis_iter_mut(Axis(0)).for_each(|mut image| {
|
||||
let mean = image.mean().unwrap_or(0.0);
|
||||
let std = image.std(0.0);
|
||||
if std > 0.0 {
|
||||
image.mapv_inplace(|x| (x - mean) / std);
|
||||
} else {
|
||||
image.mapv_inplace(|x| (x - 127.5) / 128.0)
|
||||
}
|
||||
});
|
||||
owned
|
||||
}
|
||||
}
|
||||
|
||||
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||
pub trait FaceEmbedder {
|
||||
/// Generate embeddings for a batch of face images
|
||||
|
||||
@@ -4,6 +4,7 @@ pub mod ort;
|
||||
use crate::errors::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||
use ndarray_math::{CosineSimilarity, EuclideanDistance};
|
||||
|
||||
/// Configuration for face embedding processing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -32,9 +33,9 @@ impl FaceEmbeddingConfig {
|
||||
impl Default for FaceEmbeddingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input_width: 160,
|
||||
input_height: 160,
|
||||
normalize: true,
|
||||
input_width: 320,
|
||||
input_height: 320,
|
||||
normalize: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,15 +64,14 @@ impl FaceEmbedding {
|
||||
|
||||
/// Calculate cosine similarity with another embedding
|
||||
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
|
||||
let dot_product = self.vector.dot(&other.vector);
|
||||
let norm_self = self.vector.mapv(|x| x * x).sum().sqrt();
|
||||
let norm_other = other.vector.mapv(|x| x * x).sum().sqrt();
|
||||
dot_product / (norm_self * norm_other)
|
||||
self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// Calculate Euclidean distance with another embedding
|
||||
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
|
||||
(&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt()
|
||||
self.vector
|
||||
.euclidean_distance(other.vector.view())
|
||||
.unwrap_or(f32::INFINITY)
|
||||
}
|
||||
|
||||
/// Normalize the embedding vector to unit length
|
||||
|
||||
@@ -64,10 +64,7 @@ impl EmbeddingGenerator {
|
||||
}
|
||||
|
||||
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
let tensor = face
|
||||
// .permuted_axes((0, 3, 1, 2))
|
||||
.as_standard_layout()
|
||||
.mapv(|x| x as f32);
|
||||
let tensor = crate::faceembed::preprocessing::preprocess(face);
|
||||
let shape: [usize; 4] = tensor.dim().into();
|
||||
let shape = shape.map(|f| f as i32);
|
||||
let output = self
|
||||
|
||||
@@ -135,10 +135,12 @@ impl EmbeddingGenerator {
|
||||
|
||||
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// Convert input from u8 to f32 and normalize to [0, 1] range
|
||||
let input_tensor = faces
|
||||
.mapv(|x| x as f32 / 255.0)
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
let input_tensor = crate::faceembed::preprocessing::preprocess(faces);
|
||||
|
||||
// face_array = np.asarray(face_resized, 'float32')
|
||||
// mean, std = face_array.mean(), face_array.std()
|
||||
// face_normalized = (face_array - mean) / std
|
||||
// let input_tensor = faces.mean()
|
||||
|
||||
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||
|
||||
|
||||
576
src/main.rs
576
src/main.rs
@@ -75,8 +75,61 @@ pub fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
cli::SubCommand::List(list) => {
|
||||
println!("List: {:?}", list);
|
||||
cli::SubCommand::DetectMulti(detect_multi) => {
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
|
||||
let executor = detect_multi
|
||||
.mnn_forward_type
|
||||
.map(|f| cli::Executor::Mnn(f))
|
||||
.or_else(|| {
|
||||
if detect_multi.ort_execution_provider.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cli::Executor::Ort(
|
||||
detect_multi.ort_execution_provider.clone(),
|
||||
))
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_multi_detection(detect_multi, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Ort(ep) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(&ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_multi_detection(detect_multi, retinaface, facenet)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
cli::SubCommand::Query(query) => {
|
||||
run_query(query)?;
|
||||
@@ -87,6 +140,59 @@ pub fn main() -> Result<()> {
|
||||
cli::SubCommand::Stats(stats) => {
|
||||
run_stats(stats)?;
|
||||
}
|
||||
cli::SubCommand::Compare(compare) => {
|
||||
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
|
||||
let executor = compare
|
||||
.mnn_forward_type
|
||||
.map(|f| cli::Executor::Mnn(f))
|
||||
.or_else(|| {
|
||||
if compare.ort_execution_provider.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cli::Executor::Ort(compare.ort_execution_provider.clone()))
|
||||
}
|
||||
})
|
||||
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
|
||||
|
||||
match executor {
|
||||
cli::Executor::Mnn(forward) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||
.change_context(Error)?
|
||||
.with_forward_type(forward)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_compare(compare, retinaface, facenet)?;
|
||||
}
|
||||
cli::Executor::Ort(ep) => {
|
||||
let retinaface =
|
||||
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(&ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face detection model")?;
|
||||
let facenet =
|
||||
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||
.change_context(Error)?
|
||||
.with_execution_providers(ep)
|
||||
.build()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to create face embedding model")?;
|
||||
|
||||
run_compare(compare, retinaface, facenet)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
cli::SubCommand::Completions { shell } => {
|
||||
cli::Cli::completions(shell);
|
||||
}
|
||||
@@ -122,7 +228,7 @@ where
|
||||
let output = retinaface
|
||||
.detect_faces(
|
||||
array.view(),
|
||||
FaceDetectionConfig::default()
|
||||
&FaceDetectionConfig::default()
|
||||
.with_threshold(detect.threshold)
|
||||
.with_nms_threshold(detect.nms_threshold),
|
||||
)
|
||||
@@ -163,7 +269,7 @@ where
|
||||
// })
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(160, 160, &ResizeOptions::default())
|
||||
.fast_resize(320, 320, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
// .inspect(|f| {
|
||||
@@ -182,11 +288,14 @@ where
|
||||
|
||||
if chunk.len() < chunk_size {
|
||||
tracing::warn!("Chunk size is less than 8, padding with zeros");
|
||||
let zeros = Array3::zeros((160, 160, 3));
|
||||
let zero_array = core::iter::repeat(zeros.view())
|
||||
let zeros = Array3::zeros((320, 320, 3));
|
||||
let chunk: Vec<_> = chunk
|
||||
.iter()
|
||||
.map(|arr| arr.reborrow())
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(chunk_size)
|
||||
.collect::<Vec<_>>();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
|
||||
@@ -328,6 +437,446 @@ fn run_query(query: cli::Query) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_compare<D, E>(compare: cli::Compare, mut retinaface: D, mut facenet: E) -> Result<()>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
// Helper function to detect faces and compute embeddings for an image
|
||||
fn process_image<D, E>(
|
||||
image_path: &std::path::Path,
|
||||
retinaface: &mut D,
|
||||
facenet: &mut E,
|
||||
config: &FaceDetectionConfig,
|
||||
batch_size: usize,
|
||||
) -> Result<(Vec<Array1<f32>>, usize)>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
let image = image::open(image_path)
|
||||
.change_context(Error)
|
||||
.attach_printable(image_path.to_string_lossy().to_string())?;
|
||||
let image = image.into_rgb8();
|
||||
let array = image
|
||||
.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to convert image to ndarray")?;
|
||||
|
||||
let output = retinaface
|
||||
.detect_faces(array.view(), config)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
|
||||
tracing::info!(
|
||||
"Detected {} faces in {}",
|
||||
output.bbox.len(),
|
||||
image_path.display()
|
||||
);
|
||||
|
||||
if output.bbox.is_empty() {
|
||||
return Ok((Vec::new(), 0));
|
||||
}
|
||||
|
||||
let face_rois = array
|
||||
.view()
|
||||
.multi_roi(&output.bbox)
|
||||
.change_context(Error)?
|
||||
.into_iter()
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(320, 320, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
let chunk_size = batch_size;
|
||||
let embeddings = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
if chunk.len() < chunk_size {
|
||||
let zeros = Array3::zeros((320, 320, 3));
|
||||
let chunk: Vec<_> = chunk
|
||||
.iter()
|
||||
.map(|arr| arr.reborrow())
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(chunk_size)
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>()?;
|
||||
|
||||
// Flatten embeddings into individual face embeddings
|
||||
let mut face_embeddings = Vec::new();
|
||||
for embedding_batch in embeddings {
|
||||
for i in 0..output.bbox.len().min(embedding_batch.nrows()) {
|
||||
face_embeddings.push(embedding_batch.row(i).to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((face_embeddings, output.bbox.len()))
|
||||
}
|
||||
|
||||
// Helper function to compute cosine similarity between two embeddings
|
||||
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
|
||||
let dot_product = a.dot(b);
|
||||
let norm_a = a.dot(a).sqrt();
|
||||
let norm_b = b.dot(b).sqrt();
|
||||
dot_product / (norm_a * norm_b)
|
||||
}
|
||||
|
||||
let config = FaceDetectionConfig::default()
|
||||
.with_threshold(compare.threshold)
|
||||
.with_nms_threshold(compare.nms_threshold);
|
||||
|
||||
// Process both images
|
||||
let (embeddings1, face_count1) = process_image(
|
||||
&compare.image1,
|
||||
&mut retinaface,
|
||||
&mut facenet,
|
||||
&config,
|
||||
compare.batch_size,
|
||||
)?;
|
||||
|
||||
let (embeddings2, face_count2) = process_image(
|
||||
&compare.image2,
|
||||
&mut retinaface,
|
||||
&mut facenet,
|
||||
&config,
|
||||
compare.batch_size,
|
||||
)?;
|
||||
|
||||
println!(
|
||||
"Image 1 ({}): {} faces detected",
|
||||
compare.image1.display(),
|
||||
face_count1
|
||||
);
|
||||
println!(
|
||||
"Image 2 ({}): {} faces detected",
|
||||
compare.image2.display(),
|
||||
face_count2
|
||||
);
|
||||
|
||||
if embeddings1.is_empty() && embeddings2.is_empty() {
|
||||
println!("No faces detected in either image");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if embeddings1.is_empty() {
|
||||
println!("No faces detected in image 1");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if embeddings2.is_empty() {
|
||||
println!("No faces detected in image 2");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Compare all faces between the two images
|
||||
println!("\nFace comparison results:");
|
||||
println!("========================");
|
||||
|
||||
let mut max_similarity = f32::NEG_INFINITY;
|
||||
let mut best_match = (0, 0);
|
||||
|
||||
for (i, emb1) in embeddings1.iter().enumerate() {
|
||||
for (j, emb2) in embeddings2.iter().enumerate() {
|
||||
let similarity = cosine_similarity(emb1, emb2);
|
||||
println!(
|
||||
"Face {} (image 1) vs Face {} (image 2): {:.4}",
|
||||
i + 1,
|
||||
j + 1,
|
||||
similarity
|
||||
);
|
||||
|
||||
if similarity > max_similarity {
|
||||
max_similarity = similarity;
|
||||
best_match = (i + 1, j + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"\nBest match: Face {} (image 1) vs Face {} (image 2) with similarity: {:.4}",
|
||||
best_match.0, best_match.1, max_similarity
|
||||
);
|
||||
|
||||
// Interpretation of similarity score
|
||||
if max_similarity > 0.8 {
|
||||
println!("Interpretation: Very likely the same person");
|
||||
} else if max_similarity > 0.6 {
|
||||
println!("Interpretation: Possibly the same person");
|
||||
} else if max_similarity > 0.4 {
|
||||
println!("Interpretation: Unlikely to be the same person");
|
||||
} else {
|
||||
println!("Interpretation: Very unlikely to be the same person");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_multi_detection<D, E>(
|
||||
detect_multi: cli::DetectMulti,
|
||||
mut retinaface: D,
|
||||
mut facenet: E,
|
||||
) -> Result<()>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
use std::fs;
|
||||
|
||||
// Initialize database - always save to database for multi-detection
|
||||
let db = FaceDatabase::new(&detect_multi.database).change_context(Error)?;
|
||||
|
||||
// Parse supported extensions
|
||||
let extensions: std::collections::HashSet<String> = detect_multi
|
||||
.extensions
|
||||
.split(',')
|
||||
.map(|ext| ext.trim().to_lowercase())
|
||||
.collect();
|
||||
|
||||
// Create output directory if specified
|
||||
if let Some(ref output_dir) = detect_multi.output_dir {
|
||||
fs::create_dir_all(output_dir)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create output directory")?;
|
||||
}
|
||||
|
||||
// Read directory and filter image files
|
||||
let entries = fs::read_dir(&detect_multi.input_dir)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read input directory")?;
|
||||
|
||||
let mut image_paths = Vec::new();
|
||||
for entry in entries {
|
||||
let entry = entry.change_context(Error)?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.is_file() {
|
||||
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
|
||||
if extensions.contains(&ext.to_lowercase()) {
|
||||
image_paths.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if image_paths.is_empty() {
|
||||
tracing::warn!(
|
||||
"No image files found in directory: {:?}",
|
||||
detect_multi.input_dir
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::info!("Found {} image files to process", image_paths.len());
|
||||
|
||||
let mut total_faces = 0;
|
||||
let mut processed_images = 0;
|
||||
|
||||
// Process each image
|
||||
for (idx, image_path) in image_paths.iter().enumerate() {
|
||||
tracing::info!(
|
||||
"Processing image {}/{}: {:?}",
|
||||
idx + 1,
|
||||
image_paths.len(),
|
||||
image_path
|
||||
);
|
||||
|
||||
// Load and process image
|
||||
let image = match image::open(image_path) {
|
||||
Ok(img) => img.into_rgb8(),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to load image {:?}: {}", image_path, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (image_width, image_height) = image.dimensions();
|
||||
let mut array = match image.into_ndarray().change_context(errors::Error) {
|
||||
Ok(arr) => arr,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to convert image to ndarray: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let config = FaceDetectionConfig::default()
|
||||
.with_threshold(detect_multi.threshold)
|
||||
.with_nms_threshold(detect_multi.nms_threshold);
|
||||
// Detect faces
|
||||
let output = match retinaface.detect_faces(array.view(), &config) {
|
||||
Ok(output) => output,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to detect faces in {:?}: {:?}", image_path, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let num_faces = output.bbox.len();
|
||||
total_faces += num_faces;
|
||||
|
||||
if num_faces == 0 {
|
||||
tracing::info!("No faces detected in {:?}", image_path);
|
||||
} else {
|
||||
tracing::info!("Detected {} faces in {:?}", num_faces, image_path);
|
||||
}
|
||||
|
||||
// Store image and detections in database
|
||||
let image_path_str = image_path.to_string_lossy();
|
||||
let img_id = match db.store_image(&image_path_str, image_width, image_height) {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to store image in database: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let face_ids = match db.store_face_detections(img_id, &output) {
|
||||
Ok(ids) => ids,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to store face detections in database: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Draw bounding boxes if output directory is specified
|
||||
if detect_multi.output_dir.is_some() {
|
||||
for bbox in &output.bbox {
|
||||
use bounding_box::draw::*;
|
||||
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Process face embeddings if faces were detected
|
||||
if !face_ids.is_empty() {
|
||||
let face_rois = match array.view().multi_roi(&output.bbox).change_context(Error) {
|
||||
Ok(rois) => rois,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to extract face ROIs: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let face_rois: Result<Vec<_>> = face_rois
|
||||
.into_iter()
|
||||
.map(|roi| {
|
||||
roi.as_standard_layout()
|
||||
.fast_resize(320, 320, &ResizeOptions::default())
|
||||
.change_context(Error)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let face_rois = match face_rois {
|
||||
Ok(rois) => rois,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to resize face ROIs: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
|
||||
|
||||
let chunk_size = detect_multi.batch_size;
|
||||
let embeddings: Result<Vec<Array2<f32>>> = face_roi_views
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
if chunk.len() < chunk_size {
|
||||
let zeros = Array3::zeros((320, 320, 3));
|
||||
let chunk: Vec<_> = chunk
|
||||
.iter()
|
||||
.map(|arr| arr.reborrow())
|
||||
.chain(core::iter::repeat(zeros.view()))
|
||||
.take(chunk_size)
|
||||
.collect();
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
} else {
|
||||
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to stack rois together")?;
|
||||
facenet.run_models(face_rois.view()).change_context(Error)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let embeddings = match embeddings {
|
||||
Ok(emb) => emb,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to generate embeddings: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Store embeddings in database
|
||||
if let Err(e) = db.store_embeddings(&face_ids, &embeddings, &detect_multi.model_name) {
|
||||
tracing::error!("Failed to store embeddings in database: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Save output image if directory specified
|
||||
if let Some(ref output_dir) = detect_multi.output_dir {
|
||||
let output_filename = format!(
|
||||
"detected_{}",
|
||||
image_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let output_path = output_dir.join(output_filename);
|
||||
|
||||
let v = array.view();
|
||||
let output_image: image::RgbImage = match v.to_image().change_context(errors::Error) {
|
||||
Ok(img) => img,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to convert ndarray to image: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = output_image.save(&output_path) {
|
||||
tracing::error!("Failed to save output image to {:?}: {}", output_path, e);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!("Saved output image to {:?}", output_path);
|
||||
}
|
||||
|
||||
processed_images += 1;
|
||||
}
|
||||
|
||||
// Print final statistics
|
||||
tracing::info!(
|
||||
"Processing complete: {}/{} images processed successfully, {} total faces detected",
|
||||
processed_images,
|
||||
image_paths.len(),
|
||||
total_faces
|
||||
);
|
||||
|
||||
let (num_images, num_faces, num_landmarks, num_embeddings) =
|
||||
db.get_stats().change_context(Error)?;
|
||||
tracing::info!(
|
||||
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
|
||||
num_images,
|
||||
num_faces,
|
||||
num_landmarks,
|
||||
num_embeddings
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_similar(similar: cli::Similar) -> Result<()> {
|
||||
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
|
||||
|
||||
@@ -341,14 +890,19 @@ fn run_similar(similar: cli::Similar) -> Result<()> {
|
||||
let similar_faces = db
|
||||
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Get image information for the similar faces
|
||||
println!(
|
||||
"Found {} similar faces (threshold: {:.3}):",
|
||||
similar_faces.len(),
|
||||
similar.threshold
|
||||
);
|
||||
for (face_id, similarity) in similar_faces {
|
||||
println!(" Face {}: similarity {:.3}", face_id, similarity);
|
||||
for (face_id, similarity) in &similar_faces {
|
||||
if let Some(image_info) = db.get_image_for_face(*face_id).change_context(Error)? {
|
||||
println!(
|
||||
" Face {}: similarity {:.3}, image: {}",
|
||||
face_id, similarity, image_info.file_path
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user