feat: implement the facenet implementation for ort
This commit is contained in:
@@ -14,5 +14,5 @@ use ndarray::{Array2, ArrayView4};
|
||||
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||
pub trait FaceEmbedder {
|
||||
/// Generate embeddings for a batch of face images
|
||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||
}
|
||||
|
||||
@@ -142,14 +142,6 @@ impl EmbeddingGenerator {
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
||||
// todo!()
|
||||
// }
|
||||
}
|
||||
|
||||
impl FaceNetEmbedder for EmbeddingGenerator {
|
||||
@@ -160,7 +152,7 @@ impl FaceNetEmbedder for EmbeddingGenerator {
|
||||
|
||||
// Main trait implementation for backward compatibility
|
||||
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
self.run_models(faces)
|
||||
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||
EmbeddingGenerator::run_models(self, faces)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ impl EmbeddingGeneratorBuilder {
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create ORT session from model bytes")?;
|
||||
|
||||
tracing::info!("Successfully created ORT RetinaFace session");
|
||||
tracing::info!("Successfully created ORT FaceNet session");
|
||||
|
||||
Ok(EmbeddingGenerator { session })
|
||||
}
|
||||
@@ -137,14 +137,63 @@ impl EmbeddingGenerator {
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||
tracing::info!("Loading face embedding model from bytes");
|
||||
tracing::info!("Loading ORT face embedding model from bytes");
|
||||
Self::builder(model)?.build()
|
||||
}
|
||||
|
||||
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// TODO: Implement ORT inference
|
||||
tracing::error!("ORT FaceNet inference not yet implemented");
|
||||
Err(Error).attach_printable("ORT FaceNet implementation is incomplete")
|
||||
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// Convert input from u8 to f32 and normalize to [0, 1] range
|
||||
let input_tensor = faces
|
||||
.mapv(|x| x as f32 / 255.0)
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
|
||||
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||
|
||||
// Create ORT input tensor
|
||||
let input_value = Tensor::from_array(input_tensor)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create input tensor")?;
|
||||
|
||||
// Run inference
|
||||
tracing::debug!("Running ORT FaceNet inference");
|
||||
let outputs = self
|
||||
.session
|
||||
.run(ort::inputs![Self::INPUT_NAME => input_value])
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to run inference")?;
|
||||
|
||||
// Extract output tensor
|
||||
let output = outputs
|
||||
.get(Self::OUTPUT_NAME)
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing output from FaceNet model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract output tensor")?;
|
||||
|
||||
let (output_shape, output_data) = output;
|
||||
|
||||
tracing::trace!("Output shape: {:?}", output_shape);
|
||||
|
||||
// Convert to ndarray format
|
||||
let output_dims = output_shape.as_ref();
|
||||
|
||||
// FaceNet typically outputs embeddings as [batch_size, embedding_dim]
|
||||
let batch_size = output_dims[0] as usize;
|
||||
let embedding_dim = output_dims[1] as usize;
|
||||
|
||||
let output_array =
|
||||
ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create output ndarray")?;
|
||||
|
||||
tracing::trace!(
|
||||
"Generated embeddings with shape: {:?}",
|
||||
output_array.shape()
|
||||
);
|
||||
|
||||
Ok(output_array)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +205,9 @@ impl FaceNetEmbedder for EmbeddingGenerator {
|
||||
|
||||
// Main trait implementation for backward compatibility
|
||||
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||
// Need to create a mutable reference for the session
|
||||
// This is a workaround for the trait signature mismatch
|
||||
self.run_models(faces)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ pub fn main() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()>
|
||||
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
|
||||
where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
|
||||
Reference in New Issue
Block a user