feat: implement the facenet implementation for ort

This commit is contained in:
uttarayan21
2025-08-18 13:20:55 +05:30
parent 5a1f4b9ef6
commit e7c9c38ed7
4 changed files with 62 additions and 19 deletions

View File

@@ -14,5 +14,5 @@ use ndarray::{Array2, ArrayView4};
/// Common trait for face embedding backends - maintained for backward compatibility
pub trait FaceEmbedder {
/// Generate embeddings for a batch of face images
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
}

View File

@@ -142,14 +142,6 @@ impl EmbeddingGenerator {
.change_context(Error)?;
Ok(output)
}
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
// todo!()
// }
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
// todo!()
// }
}
impl FaceNetEmbedder for EmbeddingGenerator {
@@ -160,7 +152,7 @@ impl FaceNetEmbedder for EmbeddingGenerator {
// Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
self.run_models(faces)
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
EmbeddingGenerator::run_models(self, faces)
}
}

View File

@@ -112,7 +112,7 @@ impl EmbeddingGeneratorBuilder {
.change_context(Error)
.attach_printable("Failed to create ORT session from model bytes")?;
tracing::info!("Successfully created ORT RetinaFace session");
tracing::info!("Successfully created ORT FaceNet session");
Ok(EmbeddingGenerator { session })
}
@@ -137,14 +137,63 @@ impl EmbeddingGenerator {
}
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
tracing::info!("Loading face embedding model from bytes");
tracing::info!("Loading ORT face embedding model from bytes");
Self::builder(model)?.build()
}
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// TODO: Implement ORT inference
tracing::error!("ORT FaceNet inference not yet implemented");
Err(Error).attach_printable("ORT FaceNet implementation is incomplete")
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// Convert input from u8 to f32 and normalize to [0, 1] range
let input_tensor = faces
.mapv(|x| x as f32 / 255.0)
.as_standard_layout()
.into_owned();
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
// Create ORT input tensor
let input_value = Tensor::from_array(input_tensor)
.change_context(Error)
.attach_printable("Failed to create input tensor")?;
// Run inference
tracing::debug!("Running ORT FaceNet inference");
let outputs = self
.session
.run(ort::inputs![Self::INPUT_NAME => input_value])
.change_context(Error)
.attach_printable("Failed to run inference")?;
// Extract output tensor
let output = outputs
.get(Self::OUTPUT_NAME)
.ok_or(Error)
.attach_printable("Missing output from FaceNet model")?
.try_extract_tensor::<f32>()
.change_context(Error)
.attach_printable("Failed to extract output tensor")?;
let (output_shape, output_data) = output;
tracing::trace!("Output shape: {:?}", output_shape);
// Convert to ndarray format
let output_dims = output_shape.as_ref();
// FaceNet typically outputs embeddings as [batch_size, embedding_dim]
let batch_size = output_dims[0] as usize;
let embedding_dim = output_dims[1] as usize;
let output_array =
ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec())
.change_context(Error)
.attach_printable("Failed to create output ndarray")?;
tracing::trace!(
"Generated embeddings with shape: {:?}",
output_array.shape()
);
Ok(output_array)
}
}
@@ -156,7 +205,9 @@ impl FaceNetEmbedder for EmbeddingGenerator {
// Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// Need to create a mutable reference for the session
// This is a workaround for the trait signature mismatch
self.run_models(faces)
}
}

View File

@@ -72,7 +72,7 @@ pub fn main() -> Result<()> {
Ok(())
}
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()>
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,