feat: Move facenet to same structure as facedet
This commit is contained in:
2
rfcs
2
rfcs
Submodule rfcs updated: 98ec027ca5...c973203daf
@@ -1,20 +1,18 @@
|
|||||||
|
pub mod facenet;
|
||||||
|
|
||||||
|
// Re-export common types and traits
|
||||||
|
pub use facenet::FaceNetEmbedder;
|
||||||
|
pub use facenet::{FaceEmbedding, FaceEmbeddingConfig, IntoEmbeddings};
|
||||||
|
|
||||||
|
// Convenience type aliases for different backends
|
||||||
|
pub use facenet::mnn::EmbeddingGenerator as MnnEmbeddingGenerator;
|
||||||
|
pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
|
||||||
|
|
||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use ndarray::{Array2, ArrayView4};
|
use ndarray::{Array2, ArrayView4};
|
||||||
|
|
||||||
pub mod mnn;
|
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||||
pub mod ort;
|
|
||||||
|
|
||||||
/// Common trait for face embedding backends
|
|
||||||
pub trait FaceEmbedder {
|
pub trait FaceEmbedder {
|
||||||
/// Generate embeddings for a batch of face images
|
/// Generate embeddings for a batch of face images
|
||||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convenience type aliases for different backends
|
|
||||||
pub mod facenet {
|
|
||||||
pub use crate::faceembed::mnn::facenet as mnn;
|
|
||||||
pub use crate::faceembed::ort::facenet as ort;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to MNN implementation for backward compatibility
|
|
||||||
pub use mnn::facenet::EmbeddingGenerator;
|
|
||||||
|
|||||||
209
src/faceembed/facenet.rs
Normal file
209
src/faceembed/facenet.rs
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
pub mod mnn;
|
||||||
|
pub mod ort;
|
||||||
|
|
||||||
|
use crate::errors::*;
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||||
|
|
||||||
|
/// Configuration for face embedding processing
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct FaceEmbeddingConfig {
|
||||||
|
/// Input image width expected by the model
|
||||||
|
pub input_width: usize,
|
||||||
|
/// Input image height expected by the model
|
||||||
|
pub input_height: usize,
|
||||||
|
/// Whether to normalize embeddings to unit vectors
|
||||||
|
pub normalize: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceEmbeddingConfig {
|
||||||
|
pub fn with_input_size(mut self, width: usize, height: usize) -> Self {
|
||||||
|
self.input_width = width;
|
||||||
|
self.input_height = height;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_normalization(mut self, normalize: bool) -> Self {
|
||||||
|
self.normalize = normalize;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FaceEmbeddingConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
input_width: 160,
|
||||||
|
input_height: 160,
|
||||||
|
normalize: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents a face embedding vector
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct FaceEmbedding {
|
||||||
|
/// The embedding vector
|
||||||
|
pub vector: Array1<f32>,
|
||||||
|
/// Optional confidence score for the embedding quality
|
||||||
|
pub confidence: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceEmbedding {
|
||||||
|
pub fn new(vector: Array1<f32>) -> Self {
|
||||||
|
Self {
|
||||||
|
vector,
|
||||||
|
confidence: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_confidence(mut self, confidence: f32) -> Self {
|
||||||
|
self.confidence = Some(confidence);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate cosine similarity with another embedding
|
||||||
|
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
|
||||||
|
let dot_product = self.vector.dot(&other.vector);
|
||||||
|
let norm_self = self.vector.mapv(|x| x * x).sum().sqrt();
|
||||||
|
let norm_other = other.vector.mapv(|x| x * x).sum().sqrt();
|
||||||
|
dot_product / (norm_self * norm_other)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate Euclidean distance with another embedding
|
||||||
|
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
|
||||||
|
(&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalize the embedding vector to unit length
|
||||||
|
pub fn normalize(&mut self) {
|
||||||
|
let norm = self.vector.mapv(|x| x * x).sum().sqrt();
|
||||||
|
if norm > 0.0 {
|
||||||
|
self.vector.mapv_inplace(|x| x / norm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the dimensionality of the embedding
|
||||||
|
pub fn dimension(&self) -> usize {
|
||||||
|
self.vector.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw model outputs that can be converted to embeddings
|
||||||
|
pub trait IntoEmbeddings {
|
||||||
|
fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result<Vec<FaceEmbedding>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoEmbeddings for Array2<f32> {
|
||||||
|
fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result<Vec<FaceEmbedding>> {
|
||||||
|
let mut embeddings = Vec::new();
|
||||||
|
|
||||||
|
for row in self.rows() {
|
||||||
|
let mut vector = row.to_owned();
|
||||||
|
|
||||||
|
if config.normalize {
|
||||||
|
let norm = vector.mapv(|x| x * x).sum().sqrt();
|
||||||
|
if norm > 0.0 {
|
||||||
|
vector.mapv_inplace(|x| x / norm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings.push(FaceEmbedding::new(vector));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Common trait for face embedding backends
|
||||||
|
pub trait FaceNetEmbedder {
|
||||||
|
/// Generate embeddings for a batch of face images
|
||||||
|
fn run_model(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||||
|
|
||||||
|
/// Generate embeddings with full pipeline including postprocessing
|
||||||
|
fn generate_embeddings(
|
||||||
|
&mut self,
|
||||||
|
faces: ArrayView4<u8>,
|
||||||
|
config: FaceEmbeddingConfig,
|
||||||
|
) -> Result<Vec<FaceEmbedding>> {
|
||||||
|
let raw_output = self
|
||||||
|
.run_model(faces)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to generate embeddings")?;
|
||||||
|
|
||||||
|
raw_output
|
||||||
|
.into_embeddings(&config)
|
||||||
|
.attach_printable("Failed to process embeddings")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a single embedding from a single face image
|
||||||
|
fn generate_embedding(
|
||||||
|
&mut self,
|
||||||
|
face: ArrayView3<u8>,
|
||||||
|
config: FaceEmbeddingConfig,
|
||||||
|
) -> Result<FaceEmbedding> {
|
||||||
|
// Add batch dimension
|
||||||
|
let face_batch = face.insert_axis(ndarray::Axis(0));
|
||||||
|
let embeddings = self.generate_embeddings(face_batch.view(), config)?;
|
||||||
|
|
||||||
|
embeddings
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.ok_or(Error)
|
||||||
|
.attach_printable("No embedding generated for input face")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utility functions for embedding processing
|
||||||
|
pub mod utils {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Compute pairwise cosine similarities between two sets of embeddings
|
||||||
|
pub fn pairwise_cosine_similarities(
|
||||||
|
embeddings1: &[FaceEmbedding],
|
||||||
|
embeddings2: &[FaceEmbedding],
|
||||||
|
) -> Array2<f32> {
|
||||||
|
let n1 = embeddings1.len();
|
||||||
|
let n2 = embeddings2.len();
|
||||||
|
let mut similarities = Array2::zeros((n1, n2));
|
||||||
|
|
||||||
|
for (i, emb1) in embeddings1.iter().enumerate() {
|
||||||
|
for (j, emb2) in embeddings2.iter().enumerate() {
|
||||||
|
similarities[(i, j)] = emb1.cosine_similarity(emb2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
similarities
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the best matching embedding from a gallery for each query
|
||||||
|
pub fn find_best_matches(
|
||||||
|
queries: &[FaceEmbedding],
|
||||||
|
gallery: &[FaceEmbedding],
|
||||||
|
) -> Vec<(usize, f32)> {
|
||||||
|
let similarities = pairwise_cosine_similarities(queries, gallery);
|
||||||
|
let mut best_matches = Vec::new();
|
||||||
|
|
||||||
|
for i in 0..queries.len() {
|
||||||
|
let row = similarities.row(i);
|
||||||
|
let (best_idx, best_score) = row
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
best_matches.push((best_idx, *best_score));
|
||||||
|
}
|
||||||
|
|
||||||
|
best_matches
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter embeddings by minimum quality threshold
|
||||||
|
pub fn filter_by_confidence(
|
||||||
|
embeddings: Vec<FaceEmbedding>,
|
||||||
|
min_confidence: f32,
|
||||||
|
) -> Vec<FaceEmbedding> {
|
||||||
|
embeddings
|
||||||
|
.into_iter()
|
||||||
|
.filter(|emb| emb.confidence.map_or(true, |conf| conf >= min_confidence))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::errors::*;
|
use crate::errors::*;
|
||||||
use crate::faceembed::FaceEmbedder;
|
use crate::faceembed::facenet::FaceNetEmbedder;
|
||||||
use mnn_bridge::ndarray::*;
|
use mnn_bridge::ndarray::*;
|
||||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
@@ -63,9 +63,10 @@ impl EmbeddingGenerator {
|
|||||||
Self::new_from_bytes(&model)
|
Self::new_from_bytes(&model)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn builder<T: AsRef<[u8]>>()
|
pub fn builder<T: AsRef<[u8]>>(
|
||||||
-> fn(T) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
|
model: T,
|
||||||
EmbeddingGeneratorBuilder::new
|
) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
|
||||||
|
EmbeddingGeneratorBuilder::new(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||||
@@ -151,7 +152,14 @@ impl EmbeddingGenerator {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FaceEmbedder for EmbeddingGenerator {
|
impl FaceNetEmbedder for EmbeddingGenerator {
|
||||||
|
fn run_model(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||||
|
self.run_models(faces)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main trait implementation for backward compatibility
|
||||||
|
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||||
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
fn run_models(&self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
|
||||||
self.run_models(faces)
|
self.run_models(faces)
|
||||||
}
|
}
|
||||||
162
src/faceembed/facenet/ort.rs
Normal file
162
src/faceembed/facenet/ort.rs
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
use crate::errors::*;
|
||||||
|
use crate::faceembed::facenet::FaceNetEmbedder;
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use ndarray::{Array2, ArrayView4};
|
||||||
|
use ort::{
|
||||||
|
execution_providers::{
|
||||||
|
CPUExecutionProvider, CoreMLExecutionProvider, ExecutionProviderDispatch,
|
||||||
|
},
|
||||||
|
session::{Session, builder::GraphOptimizationLevel},
|
||||||
|
value::Tensor,
|
||||||
|
};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EmbeddingGenerator {
|
||||||
|
session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EmbeddingGeneratorBuilder {
|
||||||
|
model_data: Vec<u8>,
|
||||||
|
execution_providers: Option<Vec<ExecutionProviderDispatch>>,
|
||||||
|
intra_threads: Option<usize>,
|
||||||
|
inter_threads: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingGeneratorBuilder {
|
||||||
|
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
model_data: model.as_ref().to_vec(),
|
||||||
|
execution_providers: None,
|
||||||
|
intra_threads: None,
|
||||||
|
inter_threads: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_execution_providers(mut self, providers: Vec<String>) -> Self {
|
||||||
|
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|provider| match provider.as_str() {
|
||||||
|
"cpu" | "CPU" => Some(CPUExecutionProvider::default().build()),
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
"coreml" | "CoreML" => Some(CoreMLExecutionProvider::default().build()),
|
||||||
|
_ => {
|
||||||
|
tracing::warn!("Unknown execution provider: {}", provider);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if !execution_providers.is_empty() {
|
||||||
|
self.execution_providers = Some(execution_providers);
|
||||||
|
} else {
|
||||||
|
tracing::warn!("No valid execution providers found, falling back to CPU");
|
||||||
|
self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_intra_threads(mut self, threads: usize) -> Self {
|
||||||
|
self.intra_threads = Some(threads);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_inter_threads(mut self, threads: usize) -> Self {
|
||||||
|
self.inter_threads = Some(threads);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> crate::errors::Result<EmbeddingGenerator> {
|
||||||
|
let mut session_builder = Session::builder()
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create session builder")?;
|
||||||
|
|
||||||
|
// Set execution providers
|
||||||
|
if let Some(providers) = self.execution_providers {
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_execution_providers(providers)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set execution providers")?;
|
||||||
|
} else {
|
||||||
|
// Default to CPU
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_execution_providers([CPUExecutionProvider::default().build()])
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set default CPU execution provider")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set threading options
|
||||||
|
if let Some(threads) = self.intra_threads {
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_intra_threads(threads)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set intra threads")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(threads) = self.inter_threads {
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_inter_threads(threads)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set inter threads")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set optimization level
|
||||||
|
session_builder = session_builder
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to set optimization level")?;
|
||||||
|
|
||||||
|
// Create session from model bytes
|
||||||
|
let session = session_builder
|
||||||
|
.commit_from_memory(&self.model_data)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to create ORT session from model bytes")?;
|
||||||
|
|
||||||
|
tracing::info!("Successfully created ORT RetinaFace session");
|
||||||
|
|
||||||
|
Ok(EmbeddingGenerator { session })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingGenerator {
|
||||||
|
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
||||||
|
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
||||||
|
|
||||||
|
pub fn builder<T: AsRef<[u8]>>(
|
||||||
|
model: T,
|
||||||
|
) -> std::result::Result<EmbeddingGeneratorBuilder, error_stack::Report<crate::errors::Error>>
|
||||||
|
{
|
||||||
|
EmbeddingGeneratorBuilder::new(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
|
||||||
|
let model = std::fs::read(path)
|
||||||
|
.change_context(Error)
|
||||||
|
.attach_printable("Failed to read model file")?;
|
||||||
|
Self::new_from_bytes(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||||
|
tracing::info!("Loading face embedding model from bytes");
|
||||||
|
Self::builder(model)?.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
|
// TODO: Implement ORT inference
|
||||||
|
tracing::error!("ORT FaceNet inference not yet implemented");
|
||||||
|
Err(Error).attach_printable("ORT FaceNet implementation is incomplete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceNetEmbedder for EmbeddingGenerator {
|
||||||
|
fn run_model(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
|
self.run_models(faces)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main trait implementation for backward compatibility
|
||||||
|
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
|
||||||
|
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
||||||
|
self.run_models(faces)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
pub mod facenet;
|
|
||||||
|
|
||||||
pub use facenet::EmbeddingGenerator;
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
use crate::errors::*;
|
|
||||||
use crate::faceembed::FaceEmbedder;
|
|
||||||
use error_stack::ResultExt;
|
|
||||||
use ndarray::{Array2, ArrayView4};
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct EmbeddingGenerator {
|
|
||||||
// Placeholder - ORT implementation to be completed later
|
|
||||||
_placeholder: (),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct EmbeddingGeneratorBuilder {
|
|
||||||
_model_data: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EmbeddingGeneratorBuilder {
|
|
||||||
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
|
||||||
Ok(Self {
|
|
||||||
_model_data: model.as_ref().to_vec(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_execution_providers(self, _providers: Vec<String>) -> Self {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_intra_threads(self, _threads: usize) -> Self {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_inter_threads(self, _threads: usize) -> Self {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build(self) -> crate::errors::Result<EmbeddingGenerator> {
|
|
||||||
// TODO: Implement ORT session creation
|
|
||||||
tracing::warn!("ORT FaceNet implementation is not yet complete");
|
|
||||||
Ok(EmbeddingGenerator { _placeholder: () })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EmbeddingGenerator {
|
|
||||||
const INPUT_NAME: &'static str = "serving_default_input_6:0";
|
|
||||||
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
|
|
||||||
|
|
||||||
pub fn builder<T: AsRef<[u8]>>() -> fn(
|
|
||||||
T,
|
|
||||||
) -> std::result::Result<
|
|
||||||
EmbeddingGeneratorBuilder,
|
|
||||||
error_stack::Report<crate::errors::Error>,
|
|
||||||
> {
|
|
||||||
EmbeddingGeneratorBuilder::new
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
|
|
||||||
let model = std::fs::read(path)
|
|
||||||
.change_context(Error)
|
|
||||||
.attach_printable("Failed to read model file")?;
|
|
||||||
Self::new_from_bytes(&model)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
|
||||||
tracing::info!("Loading face embedding model from bytes");
|
|
||||||
Self::builder()(model)?.build()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn run_models(&self, _face: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
|
||||||
// TODO: Implement ORT inference
|
|
||||||
tracing::error!("ORT FaceNet inference not yet implemented");
|
|
||||||
Err(Error).attach_printable("ORT FaceNet implementation is incomplete")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FaceEmbedder for EmbeddingGenerator {
|
|
||||||
fn run_models(&self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
|
|
||||||
self.run_models(faces)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
pub mod facenet;
|
|
||||||
|
|
||||||
pub use facenet::EmbeddingGenerator;
|
|
||||||
@@ -34,7 +34,8 @@ pub fn main() -> Result<()> {
|
|||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face detection model")?;
|
.attach_printable("Failed to create face detection model")?;
|
||||||
let facenet = faceembed::mnn::EmbeddingGenerator::builder()(FACENET_MODEL_MNN)
|
let facenet =
|
||||||
|
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
.with_forward_type(detect.forward_type)
|
.with_forward_type(detect.forward_type)
|
||||||
.build()
|
.build()
|
||||||
@@ -50,7 +51,8 @@ pub fn main() -> Result<()> {
|
|||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to create face detection model")?;
|
.attach_printable("Failed to create face detection model")?;
|
||||||
let facenet = faceembed::ort::EmbeddingGenerator::builder()(FACENET_MODEL_ONNX)
|
let facenet =
|
||||||
|
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
|
||||||
.change_context(Error)?
|
.change_context(Error)?
|
||||||
.build()
|
.build()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
|
|||||||
Reference in New Issue
Block a user