feat: implement the facenet implementation for ort
This commit is contained in:
@@ -14,5 +14,5 @@ use ndarray::{Array2, ArrayView4};
|
|||||||
/// Common trait for face embedding backends - maintained for backward compatibility
|
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||||
pub trait FaceEmbedder {
|
pub trait FaceEmbedder {
|
||||||
/// Generate embeddings for a batch of face images
|
/// Generate embeddings for a batch of face images
|
||||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -142,14 +142,6 @@ impl EmbeddingGenerator {
|
|||||||
.change_context(Error)?;
|
.change_context(Error)?;
|
||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn embedding(&self, roi: ArrayView3<u8>) -> Result<Array1<u8>> {
|
|
||||||
// todo!()
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn embeddings(&self, roi: ArrayView4<u8>) -> Result<Array2<u8>> {
|
|
||||||
// todo!()
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FaceNetEmbedder for EmbeddingGenerator {
|
impl FaceNetEmbedder for EmbeddingGenerator {
|
||||||
@@ -160,7 +152,7 @@ impl FaceNetEmbedder for EmbeddingGenerator {
|
|||||||
|
|
||||||
// Main trait implementation for backward compatibility
|
// Main trait implementation for backward compatibility
|
||||||
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||||
self.run_models(faces)
|
EmbeddingGenerator::run_models(self, faces)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ impl EmbeddingGeneratorBuilder {
|
|||||||
.change_context(Error)
|
.change_context(Error)
|
||||||
.attach_printable("Failed to create ORT session from model bytes")?;
|
.attach_printable("Failed to create ORT session from model bytes")?;
|
||||||
|
|
||||||
tracing::info!("Successfully created ORT RetinaFace session");
|
tracing::info!("Successfully created ORT FaceNet session");
|
||||||
|
|
||||||
Ok(EmbeddingGenerator { session })
|
Ok(EmbeddingGenerator { session })
|
||||||
}
|
}
|
||||||
@@ -137,14 +137,63 @@ impl EmbeddingGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||||
tracing::info!("Loading face embedding model from bytes");
|
tracing::info!("Loading ORT face embedding model from bytes");
|
||||||
Self::builder(model)?.build()
|
Self::builder(model)?.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
// TODO: Implement ORT inference
|
// Convert input from u8 to f32 and normalize to [0, 1] range
|
||||||
tracing::error!("ORT FaceNet inference not yet implemented");
|
let input_tensor = faces
|
||||||
Err(Error).attach_printable("ORT FaceNet implementation is incomplete")
|
.mapv(|x| x as f32 / 255.0)
|
||||||
|
.as_standard_layout()
|
||||||
|
.into_owned();
|
||||||
|
|
||||||
|
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||||
|
|
||||||
|
// Create ORT input tensor
|
||||||
|
let input_value = Tensor::from_array(input_tensor)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create input tensor")?;
|
||||||
|
|
||||||
|
// Run inference
|
||||||
|
tracing::debug!("Running ORT FaceNet inference");
|
||||||
|
let outputs = self
|
||||||
|
.session
|
||||||
|
.run(ort::inputs![Self::INPUT_NAME => input_value])
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to run inference")?;
|
||||||
|
|
||||||
|
// Extract output tensor
|
||||||
|
let output = outputs
|
||||||
|
.get(Self::OUTPUT_NAME)
|
||||||
|
.ok_or(Error)
|
||||||
|
.attach_printable("Missing output from FaceNet model")?
|
||||||
|
.try_extract_tensor::<f32>()
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to extract output tensor")?;
|
||||||
|
|
||||||
|
let (output_shape, output_data) = output;
|
||||||
|
|
||||||
|
tracing::trace!("Output shape: {:?}", output_shape);
|
||||||
|
|
||||||
|
// Convert to ndarray format
|
||||||
|
let output_dims = output_shape.as_ref();
|
||||||
|
|
||||||
|
// FaceNet typically outputs embeddings as [batch_size, embedding_dim]
|
||||||
|
let batch_size = output_dims[0] as usize;
|
||||||
|
let embedding_dim = output_dims[1] as usize;
|
||||||
|
|
||||||
|
let output_array =
|
||||||
|
ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec())
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create output ndarray")?;
|
||||||
|
|
||||||
|
tracing::trace!(
|
||||||
|
"Generated embeddings with shape: {:?}",
|
||||||
|
output_array.shape()
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(output_array)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +205,9 @@ impl FaceNetEmbedder for EmbeddingGenerator {
|
|||||||
|
|
||||||
// Main trait implementation for backward compatibility
|
// Main trait implementation for backward compatibility
|
||||||
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||||
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
|
// Need to create a mutable reference for the session
|
||||||
|
// This is a workaround for the trait signature mismatch
|
||||||
self.run_models(faces)
|
self.run_models(faces)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ pub fn main() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, facenet: E) -> Result<()>
|
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
|
||||||
where
|
where
|
||||||
D: facedet::FaceDetector,
|
D: facedet::FaceDetector,
|
||||||
E: faceembed::FaceEmbedder,
|
E: faceembed::FaceEmbedder,
|
||||||
|
|||||||
Reference in New Issue
Block a user