diff --git a/rfcs b/rfcs index 98ec027..c973203 160000 --- a/rfcs +++ b/rfcs @@ -1 +1 @@ -Subproject commit 98ec027ca5aa27b390b91e7d619c2512f76574d3 +Subproject commit c973203daf24793c02ce9e75cd2fda6858f1241c diff --git a/src/faceembed.rs b/src/faceembed.rs index d89ceb0..ccdb743 100644 --- a/src/faceembed.rs +++ b/src/faceembed.rs @@ -1,20 +1,18 @@ +pub mod facenet; + +// Re-export common types and traits +pub use facenet::FaceNetEmbedder; +pub use facenet::{FaceEmbedding, FaceEmbeddingConfig, IntoEmbeddings}; + +// Convenience type aliases for different backends +pub use facenet::mnn::EmbeddingGenerator as MnnEmbeddingGenerator; +pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator; + use crate::errors::*; use ndarray::{Array2, ArrayView4}; -pub mod mnn; -pub mod ort; - -/// Common trait for face embedding backends +/// Common trait for face embedding backends - maintained for backward compatibility pub trait FaceEmbedder { /// Generate embeddings for a batch of face images fn run_models(&self, faces: ArrayView4) -> Result>; } - -// Convenience type aliases for different backends -pub mod facenet { - pub use crate::faceembed::mnn::facenet as mnn; - pub use crate::faceembed::ort::facenet as ort; -} - -// Default to MNN implementation for backward compatibility -pub use mnn::facenet::EmbeddingGenerator; diff --git a/src/faceembed/facenet.rs b/src/faceembed/facenet.rs new file mode 100644 index 0000000..6977d2e --- /dev/null +++ b/src/faceembed/facenet.rs @@ -0,0 +1,209 @@ +pub mod mnn; +pub mod ort; + +use crate::errors::*; +use error_stack::ResultExt; +use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; + +/// Configuration for face embedding processing +#[derive(Debug, Clone, PartialEq)] +pub struct FaceEmbeddingConfig { + /// Input image width expected by the model + pub input_width: usize, + /// Input image height expected by the model + pub input_height: usize, + /// Whether to normalize embeddings to unit vectors + pub normalize: bool, +} + +impl FaceEmbeddingConfig { + pub fn with_input_size(mut self, width: usize, height: usize) -> Self { + self.input_width = width; + self.input_height = height; + self + } + + pub fn with_normalization(mut self, normalize: bool) -> Self { + self.normalize = normalize; + self + } +} + +impl Default for FaceEmbeddingConfig { + fn default() -> Self { + Self { + input_width: 160, + input_height: 160, + normalize: true, + } + } +} + +/// Represents a face embedding vector +#[derive(Debug, Clone, PartialEq)] +pub struct FaceEmbedding { + /// The embedding vector + pub vector: Array1, + /// Optional confidence score for the embedding quality + pub confidence: Option, +} + +impl FaceEmbedding { + pub fn new(vector: Array1) -> Self { + Self { + vector, + confidence: None, + } + } + + pub fn with_confidence(mut self, confidence: f32) -> Self { + self.confidence = Some(confidence); + self + } + + /// Calculate cosine similarity with another embedding + pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 { + let dot_product = self.vector.dot(&other.vector); + let norm_self = self.vector.mapv(|x| x * x).sum().sqrt(); + let norm_other = other.vector.mapv(|x| x * x).sum().sqrt(); + dot_product / (norm_self * norm_other) + } + + /// Calculate Euclidean distance with another embedding + pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 { + (&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt() + } + + /// Normalize the embedding vector to unit length + pub fn normalize(&mut self) { + let norm = self.vector.mapv(|x| x * x).sum().sqrt(); + if norm > 0.0 { + self.vector.mapv_inplace(|x| x / norm); + } + } + + /// Get the dimensionality of the embedding + pub fn dimension(&self) -> usize { + self.vector.len() + } +} + +/// Raw model outputs that can be converted to embeddings +pub trait IntoEmbeddings { + fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result>; +} + +impl IntoEmbeddings for Array2 { + fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result> { + let mut embeddings = Vec::new(); + + for row in self.rows() { + let mut vector = row.to_owned(); + + if config.normalize { + let norm = vector.mapv(|x| x * x).sum().sqrt(); + if norm > 0.0 { + vector.mapv_inplace(|x| x / norm); + } + } + + embeddings.push(FaceEmbedding::new(vector)); + } + + Ok(embeddings) + } +} + +/// Common trait for face embedding backends +pub trait FaceNetEmbedder { + /// Generate embeddings for a batch of face images + fn run_model(&mut self, faces: ArrayView4) -> Result>; + + /// Generate embeddings with full pipeline including postprocessing + fn generate_embeddings( + &mut self, + faces: ArrayView4, + config: FaceEmbeddingConfig, + ) -> Result> { + let raw_output = self + .run_model(faces) + .change_context(Error) + .attach_printable("Failed to generate embeddings")?; + + raw_output + .into_embeddings(&config) + .attach_printable("Failed to process embeddings") + } + + /// Generate a single embedding from a single face image + fn generate_embedding( + &mut self, + face: ArrayView3, + config: FaceEmbeddingConfig, + ) -> Result { + // Add batch dimension + let face_batch = face.insert_axis(ndarray::Axis(0)); + let embeddings = self.generate_embeddings(face_batch.view(), config)?; + + embeddings + .into_iter() + .next() + .ok_or(Error) + .attach_printable("No embedding generated for input face") + } +} + +/// Utility functions for embedding processing +pub mod utils { + use super::*; + + /// Compute pairwise cosine similarities between two sets of embeddings + pub fn pairwise_cosine_similarities( + embeddings1: &[FaceEmbedding], + embeddings2: &[FaceEmbedding], + ) -> Array2 { + let n1 = embeddings1.len(); + let n2 = embeddings2.len(); + let mut similarities = Array2::zeros((n1, n2)); + + for (i, emb1) in embeddings1.iter().enumerate() { + for (j, emb2) in embeddings2.iter().enumerate() { + similarities[(i, j)] = emb1.cosine_similarity(emb2); + } + } + + similarities + } + + /// Find the best matching embedding from a gallery for each query + pub fn find_best_matches( + queries: &[FaceEmbedding], + gallery: &[FaceEmbedding], + ) -> Vec<(usize, f32)> { + let similarities = pairwise_cosine_similarities(queries, gallery); + let mut best_matches = Vec::new(); + + for i in 0..queries.len() { + let row = similarities.row(i); + let (best_idx, best_score) = row + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .unwrap(); + best_matches.push((best_idx, *best_score)); + } + + best_matches + } + + /// Filter embeddings by minimum quality threshold + pub fn filter_by_confidence( + embeddings: Vec, + min_confidence: f32, + ) -> Vec { + embeddings + .into_iter() + .filter(|emb| emb.confidence.map_or(true, |conf| conf >= min_confidence)) + .collect() + } +} diff --git a/src/faceembed/mnn/facenet.rs b/src/faceembed/facenet/mnn.rs similarity index 92% rename from src/faceembed/mnn/facenet.rs rename to src/faceembed/facenet/mnn.rs index 6206c79..4ac97ab 100644 --- a/src/faceembed/mnn/facenet.rs +++ b/src/faceembed/facenet/mnn.rs @@ -1,5 +1,5 @@ use crate::errors::*; -use crate::faceembed::FaceEmbedder; +use crate::faceembed::facenet::FaceNetEmbedder; use mnn_bridge::ndarray::*; use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; use std::path::Path; @@ -63,9 +63,10 @@ impl EmbeddingGenerator { Self::new_from_bytes(&model) } - pub fn builder>() - -> fn(T) -> std::result::Result> { - EmbeddingGeneratorBuilder::new + pub fn builder>( + model: T, + ) -> std::result::Result> { + EmbeddingGeneratorBuilder::new(model) } pub fn new_from_bytes(model: &[u8]) -> Result { @@ -151,7 +152,14 @@ impl EmbeddingGenerator { // } } -impl FaceEmbedder for EmbeddingGenerator { +impl FaceNetEmbedder for EmbeddingGenerator { + fn run_model(&mut self, faces: ArrayView4) -> Result> { + self.run_models(faces) + } +} + +// Main trait implementation for backward compatibility +impl crate::faceembed::FaceEmbedder for EmbeddingGenerator { fn run_models(&self, faces: ArrayView4) -> Result> { self.run_models(faces) } diff --git a/src/faceembed/facenet/ort.rs b/src/faceembed/facenet/ort.rs new file mode 100644 index 0000000..0c59b32 --- /dev/null +++ b/src/faceembed/facenet/ort.rs @@ -0,0 +1,162 @@ +use crate::errors::*; +use crate::faceembed::facenet::FaceNetEmbedder; +use error_stack::ResultExt; +use ndarray::{Array2, ArrayView4}; +use ort::{ + execution_providers::{ + CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch, + }, + session::{Session, builder::GraphOptimizationLevel}, + value::Tensor, +}; +use std::path::Path; + +#[derive(Debug)] +pub struct EmbeddingGenerator { + session: Session, +} + +pub struct EmbeddingGeneratorBuilder { + model_data: Vec, + execution_providers: Option>, + intra_threads: Option, + inter_threads: Option, +} + +impl EmbeddingGeneratorBuilder { + pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result { + Ok(Self { + model_data: model.as_ref().to_vec(), + execution_providers: None, + intra_threads: None, + inter_threads: None, + }) + } + + pub fn with_execution_providers(mut self, providers: Vec) -> Self { + let execution_providers: Vec = providers + .into_iter() + .filter_map(|provider| match provider.as_str() { + "cpu" | "CPU" => Some(CPUExecutionProvider::default().build()), + #[cfg(target_os = "macos")] + "coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()), + _ => { + tracing::warn!("Unknown execution provider: {}", provider); + None + } + }) + .collect(); + + if !execution_providers.is_empty() { + self.execution_providers = Some(execution_providers); + } else { + tracing::warn!("No valid execution providers found, falling back to CPU"); + self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]); + } + self + } + + pub fn with_intra_threads(mut self, threads: usize) -> Self { + self.intra_threads = Some(threads); + self + } + + pub fn with_inter_threads(mut self, threads: usize) -> Self { + self.inter_threads = Some(threads); + self + } + + pub fn build(self) -> crate::errors::Result { + let mut session_builder = Session::builder() + .change_context(Error) + .attach_printable("Failed to create session builder")?; + + // Set execution providers + if let Some(providers) = self.execution_providers { + session_builder = session_builder + .with_execution_providers(providers) + .change_context(Error) + .attach_printable("Failed to set execution providers")?; + } else { + // Default to CPU + session_builder = session_builder + .with_execution_providers([CPUExecutionProvider::default().build()]) + .change_context(Error) + .attach_printable("Failed to set default CPU execution provider")?; + } + + // Set threading options + if let Some(threads) = self.intra_threads { + session_builder = session_builder + .with_intra_threads(threads) + .change_context(Error) + .attach_printable("Failed to set intra threads")?; + } + + if let Some(threads) = self.inter_threads { + session_builder = session_builder + .with_inter_threads(threads) + .change_context(Error) + .attach_printable("Failed to set inter threads")?; + } + + // Set optimization level + session_builder = session_builder + .with_optimization_level(GraphOptimizationLevel::Level3) + .change_context(Error) + .attach_printable("Failed to set optimization level")?; + + // Create session from model bytes + let session = session_builder + .commit_from_memory(&self.model_data) + .change_context(Error) + .attach_printable("Failed to create ORT session from model bytes")?; + + tracing::info!("Successfully created ORT RetinaFace session"); + + Ok(EmbeddingGenerator { session }) + } +} + +impl EmbeddingGenerator { + const INPUT_NAME: &'static str = "serving_default_input_6:0"; + const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0"; + + pub fn builder>( + model: T, + ) -> std::result::Result> + { + EmbeddingGeneratorBuilder::new(model) + } + + pub fn new(path: impl AsRef) -> crate::errors::Result { + let model = std::fs::read(path) + .change_context(Error) + .attach_printable("Failed to read model file")?; + Self::new_from_bytes(&model) + } + + pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result { + tracing::info!("Loading face embedding model from bytes"); + Self::builder(model)?.build() + } + + pub fn run_models(&self, _face: ArrayView4) -> crate::errors::Result> { + // TODO: Implement ORT inference + tracing::error!("ORT FaceNet inference not yet implemented"); + Err(Error).attach_printable("ORT FaceNet implementation is incomplete") + } +} + +impl FaceNetEmbedder for EmbeddingGenerator { + fn run_model(&mut self, faces: ArrayView4) -> crate::errors::Result> { + self.run_models(faces) + } +} + +// Main trait implementation for backward compatibility +impl crate::faceembed::FaceEmbedder for EmbeddingGenerator { + fn run_models(&self, faces: ArrayView4) -> crate::errors::Result> { + self.run_models(faces) + } +} diff --git a/src/faceembed/mnn/mod.rs b/src/faceembed/mnn/mod.rs deleted file mode 100644 index 94700e6..0000000 --- a/src/faceembed/mnn/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod facenet; - -pub use facenet::EmbeddingGenerator; diff --git a/src/faceembed/ort/facenet.rs b/src/faceembed/ort/facenet.rs deleted file mode 100644 index c2f3272..0000000 --- a/src/faceembed/ort/facenet.rs +++ /dev/null @@ -1,79 +0,0 @@ -use crate::errors::*; -use crate::faceembed::FaceEmbedder; -use error_stack::ResultExt; -use ndarray::{Array2, ArrayView4}; -use std::path::Path; - -#[derive(Debug)] -pub struct EmbeddingGenerator { - // Placeholder - ORT implementation to be completed later - _placeholder: (), -} - -pub struct EmbeddingGeneratorBuilder { - _model_data: Vec, -} - -impl EmbeddingGeneratorBuilder { - pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result { - Ok(Self { - _model_data: model.as_ref().to_vec(), - }) - } - - pub fn with_execution_providers(self, _providers: Vec) -> Self { - self - } - - pub fn with_intra_threads(self, _threads: usize) -> Self { - self - } - - pub fn with_inter_threads(self, _threads: usize) -> Self { - self - } - - pub fn build(self) -> crate::errors::Result { - // TODO: Implement ORT session creation - tracing::warn!("ORT FaceNet implementation is not yet complete"); - Ok(EmbeddingGenerator { _placeholder: () }) - } -} - -impl EmbeddingGenerator { - const INPUT_NAME: &'static str = "serving_default_input_6:0"; - const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0"; - - pub fn builder>() -> fn( - T, - ) -> std::result::Result< - EmbeddingGeneratorBuilder, - error_stack::Report, - > { - EmbeddingGeneratorBuilder::new - } - - pub fn new(path: impl AsRef) -> crate::errors::Result { - let model = std::fs::read(path) - .change_context(Error) - .attach_printable("Failed to read model file")?; - Self::new_from_bytes(&model) - } - - pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result { - tracing::info!("Loading face embedding model from bytes"); - Self::builder()(model)?.build() - } - - pub fn run_models(&self, _face: ArrayView4) -> crate::errors::Result> { - // TODO: Implement ORT inference - tracing::error!("ORT FaceNet inference not yet implemented"); - Err(Error).attach_printable("ORT FaceNet implementation is incomplete") - } -} - -impl FaceEmbedder for EmbeddingGenerator { - fn run_models(&self, faces: ArrayView4) -> crate::errors::Result> { - self.run_models(faces) - } -} diff --git a/src/faceembed/ort/mod.rs b/src/faceembed/ort/mod.rs deleted file mode 100644 index 94700e6..0000000 --- a/src/faceembed/ort/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod facenet; - -pub use facenet::EmbeddingGenerator; diff --git a/src/main.rs b/src/main.rs index 892e135..9c7657f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,12 +34,13 @@ pub fn main() -> Result<()> { .build() .change_context(errors::Error) .attach_printable("Failed to create face detection model")?; - let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL_MNN) - .change_context(Error)? - .with_forward_type(detect.forward_type) - .build() - .change_context(errors::Error) - .attach_printable("Failed to create face embedding model")?; + let facenet = + faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN) + .change_context(Error)? + .with_forward_type(detect.forward_type) + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; run_detection(detect, retinaface, facenet)?; } @@ -50,11 +51,12 @@ pub fn main() -> Result<()> { .build() .change_context(errors::Error) .attach_printable("Failed to create face detection model")?; - let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_MODEL_ONNX) - .change_context(Error)? - .build() - .change_context(errors::Error) - .attach_printable("Failed to create face embedding model")?; + let facenet = + faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX) + .change_context(Error)? + .build() + .change_context(errors::Error) + .attach_printable("Failed to create face embedding model")?; run_detection(detect, retinaface, facenet)?; }