diff --git a/src/faceembed.rs b/src/faceembed.rs index ccdb743..bdc5294 100644 --- a/src/faceembed.rs +++ b/src/faceembed.rs @@ -14,5 +14,5 @@ use ndarray::{Array2, ArrayView4}; /// Common trait for face embedding backends - maintained for backward compatibility pub trait FaceEmbedder { /// Generate embeddings for a batch of face images - fn run_models(&self, faces: ArrayView4) -> Result>; + fn run_models(&mut self, faces: ArrayView4) -> Result>; } diff --git a/src/faceembed/facenet/mnn.rs b/src/faceembed/facenet/mnn.rs index 4ac97ab..4f305a5 100644 --- a/src/faceembed/facenet/mnn.rs +++ b/src/faceembed/facenet/mnn.rs @@ -142,14 +142,6 @@ impl EmbeddingGenerator { .change_context(Error)?; Ok(output) } - - // pub fn embedding(&self, roi: ArrayView3) -> Result> { - // todo!() - // } - - // pub fn embeddings(&self, roi: ArrayView4) -> Result> { - // todo!() - // } } impl FaceNetEmbedder for EmbeddingGenerator { @@ -160,7 +152,7 @@ impl FaceNetEmbedder for EmbeddingGenerator { // Main trait implementation for backward compatibility impl crate::faceembed::FaceEmbedder for EmbeddingGenerator { - fn run_models(&self, faces: ArrayView4) -> Result> { - self.run_models(faces) + fn run_models(&mut self, faces: ArrayView4) -> Result> { + EmbeddingGenerator::run_models(self, faces) } } diff --git a/src/faceembed/facenet/ort.rs b/src/faceembed/facenet/ort.rs index 0c59b32..3db56a2 100644 --- a/src/faceembed/facenet/ort.rs +++ b/src/faceembed/facenet/ort.rs @@ -112,7 +112,7 @@ impl EmbeddingGeneratorBuilder { .change_context(Error) .attach_printable("Failed to create ORT session from model bytes")?; - tracing::info!("Successfully created ORT RetinaFace session"); + tracing::info!("Successfully created ORT FaceNet session"); Ok(EmbeddingGenerator { session }) } @@ -137,14 +137,63 @@ impl EmbeddingGenerator { } pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result { - tracing::info!("Loading face embedding model from bytes"); + tracing::info!("Loading ORT face 
embedding model from bytes"); Self::builder(model)?.build() } - pub fn run_models(&self, _face: ArrayView4) -> crate::errors::Result> { - // TODO: Implement ORT inference - tracing::error!("ORT FaceNet inference not yet implemented"); - Err(Error).attach_printable("ORT FaceNet implementation is incomplete") + pub fn run_models(&mut self, faces: ArrayView4) -> crate::errors::Result> { + // Convert input from u8 to f32 and normalize to [0, 1] range + let input_tensor = faces + .mapv(|x| x as f32 / 255.0) + .as_standard_layout() + .into_owned(); + + tracing::trace!("Input tensor shape: {:?}", input_tensor.shape()); + + // Create ORT input tensor + let input_value = Tensor::from_array(input_tensor) + .change_context(Error) + .attach_printable("Failed to create input tensor")?; + + // Run inference + tracing::debug!("Running ORT FaceNet inference"); + let outputs = self + .session + .run(ort::inputs![Self::INPUT_NAME => input_value]) + .change_context(Error) + .attach_printable("Failed to run inference")?; + + // Extract output tensor + let output = outputs + .get(Self::OUTPUT_NAME) + .ok_or(Error) + .attach_printable("Missing output from FaceNet model")? 
+ .try_extract_tensor::() + .change_context(Error) + .attach_printable("Failed to extract output tensor")?; + + let (output_shape, output_data) = output; + + tracing::trace!("Output shape: {:?}", output_shape); + + // Convert to ndarray format + let output_dims = output_shape.as_ref(); + + // FaceNet typically outputs embeddings as [batch_size, embedding_dim] + let batch_size = output_dims[0] as usize; + let embedding_dim = output_dims[1] as usize; + + let output_array = + ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec()) + .change_context(Error) + .attach_printable("Failed to create output ndarray")?; + + tracing::trace!( + "Generated embeddings with shape: {:?}", + output_array.shape() + ); + + Ok(output_array) } } @@ -156,7 +205,9 @@ impl FaceNetEmbedder for EmbeddingGenerator { // Main trait implementation for backward compatibility impl crate::faceembed::FaceEmbedder for EmbeddingGenerator { - fn run_models(&self, faces: ArrayView4) -> crate::errors::Result> { + fn run_models(&mut self, faces: ArrayView4) -> crate::errors::Result> { + // Delegates to the inherent `EmbeddingGenerator::run_models`, which also takes `&mut self`; + // inherent methods take precedence over trait methods, so this call does not recurse self.run_models(faces) } } diff --git a/src/main.rs b/src/main.rs index 9c7657f..2546d5f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,7 +72,7 @@ pub fn main() -> Result<()> { Ok(()) } -fn run_detection(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()> +fn run_detection(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()> where D: facedet::FaceDetector, E: faceembed::FaceEmbedder,